diff options
Diffstat (limited to 'ATRI/plugins/atri_chat_bot.py')
-rw-r--r-- | ATRI/plugins/atri_chat_bot.py | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/ATRI/plugins/atri_chat_bot.py b/ATRI/plugins/atri_chat_bot.py new file mode 100644 index 0000000..9000006 --- /dev/null +++ b/ATRI/plugins/atri_chat_bot.py @@ -0,0 +1,55 @@ +from ATRI.config import ChatterBot +from chatterbot import ChatBot +from chatterbot.trainers import ListTrainer +from chatterbot.trainers import ChatterBotCorpusTrainer +from ATRI.log import logger as log + +__doc__ = """ +可以不断学习的聊天(胡言乱语/复读)机器人 +https://chatterbot.readthedocs.io/ +""" + +MONGO_ADAPTER = "chatterbot.storage.MongoDatabaseAdapter" +SQLITE_ADAPTER = "chatterbot.storage.SQLStorageAdapter" + +class ATRIChatBot: + bot = ChatBot( + "ATRI", + storage_adapter=MONGO_ADAPTER if ChatterBot.mongo_database_uri else SQLITE_ADAPTER, + logic_adapters=[ + { + 'import_path': 'chatterbot.logic.BestMatch', + 'default_response': ChatterBot.default_response, + 'maximum_similarity_threshold': ChatterBot.maximum_similarity_threshold + } + ], + database_uri=ChatterBot.mongo_database_uri, + read_only=True # 只能通过 learn 函数学习 + ) + list_trainer = ListTrainer(bot) + session_text_dict = dict() + + @staticmethod + def learn_from_corpus(): + trainer = ChatterBotCorpusTrainer(ATRIChatBot.bot) + # 从 corpus 的中文语料库学习,yaml 包太新的话需要把 corpus.py 的 yaml.load() 改成 yaml.full_load() + # 可以尝试用 https://github.com/hbwzhsh/chinese_chatbot_corpus 里面的语料进行训练 + trainer.train("chatterbot.corpus.chinese") + + @staticmethod + def learn(session_id: str, text: str): + # 查找上一条消息并训练模型 + last_text = ATRIChatBot.session_text_dict.get(session_id) + if last_text: + ATRIChatBot.list_trainer.train([ + last_text, # 问(可多个) + text # 答 + ]) + # 更新最后一条消息 + ATRIChatBot.session_text_dict[session_id] = text + + @staticmethod + async def get_response(text: str) -> str: + response = ATRIChatBot.bot.get_response(text) + log.info(f"人工智障回复:{text} -> {response.text}") + return response.text
\ No newline at end of file |