summaryrefslogtreecommitdiff
path: root/ATRI/plugins/atri_chat_bot.py
blob: 900000604e98adfaf6ae23240e00efe407504920 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from ATRI.config import ChatterBot
from chatterbot import ChatBot
from chatterbot.trainers import ListTrainer
from chatterbot.trainers import ChatterBotCorpusTrainer
from ATRI.log import logger as log

# Plugin help text, kept verbatim. In English: "A chat (nonsense / parroting)
# bot that keeps learning", followed by the ChatterBot documentation URL.
__doc__ = """
可以不断学习的聊天(胡言乱语/复读)机器人
https://chatterbot.readthedocs.io/
"""

# Dotted import paths of the two ChatterBot storage backends this plugin can
# choose between (MongoDB when a Mongo URI is configured, SQLite otherwise).
_STORAGE_PACKAGE = "chatterbot.storage"
MONGO_ADAPTER = _STORAGE_PACKAGE + ".MongoDatabaseAdapter"
SQLITE_ADAPTER = _STORAGE_PACKAGE + ".SQLStorageAdapter"

def _storage_options():
    """Return the ``ChatBot`` storage keyword arguments for the configured backend.

    Uses MongoDB (with the configured URI) when ``ChatterBot.mongo_database_uri``
    is truthy, and SQLite otherwise.
    """
    mongo_uri = ChatterBot.mongo_database_uri
    if mongo_uri:
        return {"storage_adapter": MONGO_ADAPTER, "database_uri": mongo_uri}
    # Deliberately leave database_uri out for SQLite. ChatterBot's
    # SQLStorageAdapter maps database_uri=None to an in-memory database
    # ("sqlite://"), which would silently lose everything learned on every
    # restart. Without the key it falls back to its on-disk default database.
    return {"storage_adapter": SQLITE_ADAPTER}


class ATRIChatBot:
    """Chat bot that learns from conversations and replies with the best match.

    All state lives on the class itself (the ``ChatBot`` instance, its trainer
    and the per-session message history); every method is a ``staticmethod``.
    """

    bot = ChatBot(
        "ATRI",
        logic_adapters=[
            {
                'import_path': 'chatterbot.logic.BestMatch',
                'default_response': ChatterBot.default_response,
                'maximum_similarity_threshold': ChatterBot.maximum_similarity_threshold
            }
        ],
        # read_only: get_response() must not learn on its own; the bot is only
        # meant to learn through the explicit learn() function.
        read_only=True,
        **_storage_options()
    )
    list_trainer = ListTrainer(bot)
    # Maps session id -> text of the most recent message seen in that session.
    session_text_dict = dict()

    @staticmethod
    def learn_from_corpus():
        """Train the bot on the Chinese corpus shipped with chatterbot-corpus."""
        trainer = ChatterBotCorpusTrainer(ATRIChatBot.bot)
        # With a recent PyYAML release, the yaml.load() call in chatterbot's
        # corpus.py has to be changed to yaml.full_load() for this to work.
        # The corpora at https://github.com/hbwzhsh/chinese_chatbot_corpus can
        # also be used for training.
        trainer.train("chatterbot.corpus.chinese")

    @staticmethod
    def learn(session_id: str, text: str):
        """Learn ``text`` as the reply to the previous message of the same session.

        The first message seen for a session (or any message following an
        empty one) is only remembered. Otherwise the previous message and
        ``text`` are trained as a prompt/response pair. In both cases ``text``
        then becomes the session's most recent message.

        :param session_id: key identifying the conversation.
        :param text: the message just received in that conversation.
        """
        last_text = ATRIChatBot.session_text_dict.get(session_id)
        if last_text:
            ATRIChatBot.list_trainer.train([
                last_text,  # prompt (further earlier lines could precede it)
                text  # response
            ])
        # Remember this message so the next one is learned as its reply.
        ATRIChatBot.session_text_dict[session_id] = text

    @staticmethod
    async def get_response(text: str) -> str:
        """Return the bot's reply to ``text`` and log the exchange.

        The reply is chosen by the BestMatch logic adapter configured on
        ``bot``, using the default response and similarity threshold from
        the ChatterBot config.
        """
        # NOTE(review): ChatBot.get_response is synchronous (storage lookups and
        # text processing), so this coroutine blocks the event loop while it
        # runs. Consider offloading it to an executor once the storage backend
        # and tagger are confirmed safe to use from a worker thread.
        response = ATRIChatBot.bot.get_response(text)
        log.info(f"人工智障回复:{text}  ->  {response.text}")
        return response.text