diff options
-rw-r--r-- | ATRI/config.py | 8 | ||||
-rw-r--r-- | ATRI/plugins/atri_chat_bot.py | 55 | ||||
-rw-r--r-- | ATRI/plugins/chat/__init__.py | 7 | ||||
-rw-r--r-- | ATRI/plugins/chatbot/__init__.py | 31 | ||||
-rw-r--r-- | config.yml | 12 | ||||
-rw-r--r-- | requirements.txt | 3 |
6 files changed, 113 insertions, 3 deletions
diff --git a/ATRI/config.py b/ATRI/config.py index d1a9bfd..3350423 100644 --- a/ATRI/config.py +++ b/ATRI/config.py @@ -36,6 +36,14 @@ class SauceNAO: key: str = config.get("key", "") +class ChatterBot: + config: dict = config["ChatterBot"] + + mongo_database_uri: str = config.get("mongo_database_uri", None) + maximum_similarity_threshold: float = float(config.get("maximum_similarity_threshold", 0.05)) + default_response: set = set(config.get("default_response", ["咱听不明白(o_ _)ノ"])) + group_random_response_rate: float = float(config.get("group_random_response_rate", 0.1)) + RUNTIME_CONFIG = { "host": BotSelfConfig.host, "port": BotSelfConfig.port, diff --git a/ATRI/plugins/atri_chat_bot.py b/ATRI/plugins/atri_chat_bot.py new file mode 100644 index 0000000..9000006 --- /dev/null +++ b/ATRI/plugins/atri_chat_bot.py @@ -0,0 +1,55 @@ +from ATRI.config import ChatterBot +from chatterbot import ChatBot +from chatterbot.trainers import ListTrainer +from chatterbot.trainers import ChatterBotCorpusTrainer +from ATRI.log import logger as log + +__doc__ = """ +可以不断学习的聊天(胡言乱语/复读)机器人 +https://chatterbot.readthedocs.io/ +""" + +MONGO_ADAPTER = "chatterbot.storage.MongoDatabaseAdapter" +SQLITE_ADAPTER = "chatterbot.storage.SQLStorageAdapter" + +class ATRIChatBot: + bot = ChatBot( + "ATRI", + storage_adapter=MONGO_ADAPTER if ChatterBot.mongo_database_uri else SQLITE_ADAPTER, + logic_adapters=[ + { + 'import_path': 'chatterbot.logic.BestMatch', + 'default_response': ChatterBot.default_response, + 'maximum_similarity_threshold': ChatterBot.maximum_similarity_threshold + } + ], + database_uri=ChatterBot.mongo_database_uri, + read_only=True # 只能通过 learn 函数学习 + ) + list_trainer = ListTrainer(bot) + session_text_dict = dict() + + @staticmethod + def learn_from_corpus(): + trainer = ChatterBotCorpusTrainer(ATRIChatBot.bot) + # 从 corpus 的中文语料库学习,yaml 包太新的话需要把 corpus.py 的 yaml.load() 改成 yaml.full_load() + # 可以尝试用 https://github.com/hbwzhsh/chinese_chatbot_corpus 里面的语料进行训练 + trainer.train("chatterbot.corpus.chinese") + + @staticmethod + def learn(session_id: str, text: str): + # 查找上一条消息并训练模型 + last_text = ATRIChatBot.session_text_dict.get(session_id) + if last_text: + ATRIChatBot.list_trainer.train([ + last_text, # 问(可多个) + text # 答 + ]) + # 更新最后一条消息 + ATRIChatBot.session_text_dict[session_id] = text + + @staticmethod + async def get_response(text: str) -> str: + response = ATRIChatBot.bot.get_response(text) + log.info(f"人工智障回复:{text} -> {response.text}") + return response.text
\ No newline at end of file diff --git a/ATRI/plugins/chat/__init__.py b/ATRI/plugins/chat/__init__.py index ab22db9..7436645 100644 --- a/ATRI/plugins/chat/__init__.py +++ b/ATRI/plugins/chat/__init__.py @@ -7,7 +7,7 @@ from ATRI.utils import CoolqCodeChecker from ATRI.utils.limit import FreqLimiter from ATRI.utils.apscheduler import scheduler from .data_source import Chat - +from ATRI.plugins.atri_chat_bot import ATRIChatBot _chat_flmt = FreqLimiter(3) _chat_flmt_notice = choice(["慢...慢一..点❤", "冷静1下", "歇会歇会~~", "我开始为你以后的伴侣担心了..."]) @@ -27,7 +27,10 @@ async def _chat(bot: Bot, event: MessageEvent): repo = await Chat().deal(msg, user_id) _chat_flmt.start_cd(user_id) try: - await chat.finish(repo) + if repo: + await chat.finish(repo) + else: # 实在没话说就尝试 chatterbot + await chat.finish(await ATRIChatBot.get_response(msg)) except Exception: return diff --git a/ATRI/plugins/chatbot/__init__.py b/ATRI/plugins/chatbot/__init__.py new file mode 100644 index 0000000..53b3772 --- /dev/null +++ b/ATRI/plugins/chatbot/__init__.py @@ -0,0 +1,31 @@ +import random +from ATRI.config import ChatterBot +from ATRI.plugins.atri_chat_bot import ATRIChatBot +from nonebot import on_message +from nonebot import on_command +from nonebot.adapters.cqhttp import ( + Bot, + GroupMessageEvent, + MessageEvent, +) +from nonebot.permission import SUPERUSER + +chatbot = on_message(priority=114514) + +async def _learn_from_group(bot: Bot, event: MessageEvent): + text = event.get_plaintext().strip() + if not text: + return + if isinstance(event, GroupMessageEvent): # 从群友那学习说话 + ATRIChatBot.learn(event.get_session_id(), text) + if random.random() <= ChatterBot.group_random_response_rate: # 随机回话 + await chatbot.finish(await ATRIChatBot.get_response(text)) + + +chatbot_learn = on_command("/learn_corpus", permission=SUPERUSER) + +@chatbot_learn.handle() +async def _learn_from_corpus(bot: Bot, event: MessageEvent): + ATRIChatBot.learn_from_corpus() + await chatbot.finish("咱从corpus那学习完了!")
\ No newline at end of file @@ -11,3 +11,15 @@ BotSelfConfig: SauceNAO: key: "" + +ChatterBot: + # mongodb存储数据库地址(数据库将自动创建) + # 示例 mongodb://localhost:27017/atri-chatterbot + # 如果为空将用SQLite代替(Python3.8后不再支持time.clock将导致初始化SQLite报错) + mongo_database_uri: "mongodb://localhost:27017/atri-chatterbot" + # 置信率阈值 + maximum_similarity_threshold: 0.05 + # 生成的回复低于置信率阈值时将使用的默认回复 + default_response: ["咱听不明白(o_ _)ノ", "不懂欸", "?_?"] + # 群聊天随机回复概率 + group_random_response_rate: 0.05
\ No newline at end of file diff --git a/requirements.txt b/requirements.txt index bef0a64..2adedb0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ pathlib>=1.0.1 pytz>=2020.1 pyyaml>=5.4 scikit-image>=0.18.3 -tensorflow>=2.6.0
\ No newline at end of file +tensorflow>=2.6.0 +chatterbot>=1.0.0
\ No newline at end of file |