summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ATRI/config.py8
-rw-r--r--ATRI/plugins/atri_chat_bot.py55
-rw-r--r--ATRI/plugins/chat/__init__.py7
-rw-r--r--ATRI/plugins/chatbot/__init__.py31
-rw-r--r--config.yml12
-rw-r--r--requirements.txt3
6 files changed, 113 insertions, 3 deletions
diff --git a/ATRI/config.py b/ATRI/config.py
index d1a9bfd..3350423 100644
--- a/ATRI/config.py
+++ b/ATRI/config.py
@@ -36,6 +36,14 @@ class SauceNAO:
key: str = config.get("key", "")
+class ChatterBot:
+ config: dict = config["ChatterBot"]
+
+ mongo_database_uri: str = config.get("mongo_database_uri", None)
+ maximum_similarity_threshold: float = float(config.get("maximum_similarity_threshold", 0.05))
+ default_response: set = set(config.get("default_response", ["咱听不明白(o_ _)ノ"]))
+ group_random_response_rate: float = float(config.get("group_random_response_rate", 0.1))
+
RUNTIME_CONFIG = {
"host": BotSelfConfig.host,
"port": BotSelfConfig.port,
diff --git a/ATRI/plugins/atri_chat_bot.py b/ATRI/plugins/atri_chat_bot.py
new file mode 100644
index 0000000..9000006
--- /dev/null
+++ b/ATRI/plugins/atri_chat_bot.py
@@ -0,0 +1,55 @@
+from ATRI.config import ChatterBot
+from chatterbot import ChatBot
+from chatterbot.trainers import ListTrainer
+from chatterbot.trainers import ChatterBotCorpusTrainer
+from ATRI.log import logger as log
+
+__doc__ = """
+可以不断学习的聊天(胡言乱语/复读)机器人
+https://chatterbot.readthedocs.io/
+"""
+
+MONGO_ADAPTER = "chatterbot.storage.MongoDatabaseAdapter"
+SQLITE_ADAPTER = "chatterbot.storage.SQLStorageAdapter"
+
+class ATRIChatBot:
+ bot = ChatBot(
+ "ATRI",
+ storage_adapter=MONGO_ADAPTER if ChatterBot.mongo_database_uri else SQLITE_ADAPTER,
+ logic_adapters=[
+ {
+ 'import_path': 'chatterbot.logic.BestMatch',
+ 'default_response': ChatterBot.default_response,
+ 'maximum_similarity_threshold': ChatterBot.maximum_similarity_threshold
+ }
+ ],
+ database_uri=ChatterBot.mongo_database_uri,
+ read_only=True # 只能通过 learn 函数学习
+ )
+ list_trainer = ListTrainer(bot)
+ session_text_dict = dict()
+
+ @staticmethod
+ def learn_from_corpus():
+ trainer = ChatterBotCorpusTrainer(ATRIChatBot.bot)
+ # 从 corpus 的中文语料库学习,yaml 包太新的话需要把 corpus.py 的 yaml.load() 改成 yaml.full_load()
+ # 可以尝试用 https://github.com/hbwzhsh/chinese_chatbot_corpus 里面的语料进行训练
+ trainer.train("chatterbot.corpus.chinese")
+
+ @staticmethod
+ def learn(session_id: str, text: str):
+ # 查找上一条消息并训练模型
+ last_text = ATRIChatBot.session_text_dict.get(session_id)
+ if last_text:
+ ATRIChatBot.list_trainer.train([
+ last_text, # 问(可多个)
+ text # 答
+ ])
+ # 更新最后一条消息
+ ATRIChatBot.session_text_dict[session_id] = text
+
+ @staticmethod
+ async def get_response(text: str) -> str:
+ response = ATRIChatBot.bot.get_response(text)
+ log.info(f"人工智障回复:{text} -> {response.text}")
+ return response.text \ No newline at end of file
diff --git a/ATRI/plugins/chat/__init__.py b/ATRI/plugins/chat/__init__.py
index ab22db9..7436645 100644
--- a/ATRI/plugins/chat/__init__.py
+++ b/ATRI/plugins/chat/__init__.py
@@ -7,7 +7,7 @@ from ATRI.utils import CoolqCodeChecker
from ATRI.utils.limit import FreqLimiter
from ATRI.utils.apscheduler import scheduler
from .data_source import Chat
-
+from ATRI.plugins.atri_chat_bot import ATRIChatBot
_chat_flmt = FreqLimiter(3)
_chat_flmt_notice = choice(["慢...慢一..点❤", "冷静1下", "歇会歇会~~", "我开始为你以后的伴侣担心了..."])
@@ -27,7 +27,10 @@ async def _chat(bot: Bot, event: MessageEvent):
repo = await Chat().deal(msg, user_id)
_chat_flmt.start_cd(user_id)
try:
- await chat.finish(repo)
+ if repo:
+ await chat.finish(repo)
+ else: # 实在没话说就尝试 chatterbot
+ await chat.finish(await ATRIChatBot.get_response(msg))
except Exception:
return
diff --git a/ATRI/plugins/chatbot/__init__.py b/ATRI/plugins/chatbot/__init__.py
new file mode 100644
index 0000000..53b3772
--- /dev/null
+++ b/ATRI/plugins/chatbot/__init__.py
@@ -0,0 +1,31 @@
+import random
+from ATRI.config import ChatterBot
+from ATRI.plugins.atri_chat_bot import ATRIChatBot
+from nonebot import on_message
+from nonebot import on_command
+from nonebot.adapters.cqhttp import (
+ Bot,
+ GroupMessageEvent,
+ MessageEvent,
+)
+from nonebot.permission import SUPERUSER
+
+chatbot = on_message(priority=114514)
+
+async def _learn_from_group(bot: Bot, event: MessageEvent):
+ text = event.get_plaintext().strip()
+ if not text:
+ return
+ if isinstance(event, GroupMessageEvent): # 从群友那学习说话
+ ATRIChatBot.learn(event.get_session_id(), text)
+ if random.random() <= ChatterBot.group_random_response_rate: # 随机回话
+ await chatbot.finish(await ATRIChatBot.get_response(text))
+
+
+chatbot_learn = on_command("/learn_corpus", permission=SUPERUSER)
+
+@chatbot_learn.handle()
+async def _learn_from_corpus(bot: Bot, event: MessageEvent):
+ ATRIChatBot.learn_from_corpus()
+ await chatbot.finish("咱从corpus那学习完了!") \ No newline at end of file
diff --git a/config.yml b/config.yml
index afe9fde..84eee7e 100644
--- a/config.yml
+++ b/config.yml
@@ -11,3 +11,15 @@ BotSelfConfig:
SauceNAO:
key: ""
+
+ChatterBot:
+ # mongodb存储数据库地址(数据库将自动创建)
+ # 示例 mongodb://localhost:27017/atri-chatterbot
+ # 如果为空将用SQLite代替(Python3.8后不再支持time.clock将导致初始化SQLite报错)
+ mongo_database_uri: "mongodb://localhost:27017/atri-chatterbot"
+ # 置信率阈值
+ maximum_similarity_threshold: 0.05
+ # 生成的回复低于置信率阈值时将使用的默认回复
+ default_response: ["咱听不明白(o_ _)ノ", "不懂欸", "?_?"]
+ # 群聊天随机回复概率
+ group_random_response_rate: 0.05 \ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index bef0a64..2adedb0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,5 @@ pathlib>=1.0.1
pytz>=2020.1
pyyaml>=5.4
scikit-image>=0.18.3
-tensorflow>=2.6.0 \ No newline at end of file
+tensorflow>=2.6.0
+chatterbot>=1.0.0 \ No newline at end of file