summaryrefslogtreecommitdiff
path: root/ATRI/plugins/atri_chat_bot.py
blob: 3226b2f9ecb915830b71f7555ef1bc153d91ff97 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from ATRI.config import ChatterBot
from chatterbot import ChatBot
from chatterbot.trainers import ListTrainer
from chatterbot.trainers import ChatterBotCorpusTrainer
from ATRI.log import logger as log

# Module description. The text is a runtime string, so it is kept verbatim.
# English gloss: "A chat (nonsense / parroting) bot that can keep learning."
__doc__ = """
可以不断学习的聊天(胡言乱语/复读)机器人
https://chatterbot.readthedocs.io/
"""

# Dotted paths of the two ChatterBot storage adapters this module can hand to
# ChatBot(storage_adapter=...): MongoDB-backed and SQL-backed respectively.
MONGO_ADAPTER = "chatterbot.storage.MongoDatabaseAdapter"
SQLITE_ADAPTER = "chatterbot.storage.SQLStorageAdapter"


class ATRIChatBot:
    """Shared ChatterBot instance plus helpers to train and query it.

    The bot, its list trainer and the per-session message cache are class
    attributes, so they are built once when this class body executes and are
    shared by every caller of the static methods below.
    """

    bot = ChatBot(
        "ATRI",
        # Fall back to the SQL adapter unless a Mongo URI has been configured.
        storage_adapter=(
            SQLITE_ADAPTER if not ChatterBot.mongo_database_uri else MONGO_ADAPTER
        ),
        logic_adapters=[
            {
                "import_path": "chatterbot.logic.BestMatch",
                "default_response": ChatterBot.default_response,
                "maximum_similarity_threshold": ChatterBot.maximum_similarity_threshold,
            }
        ],
        # NOTE(review): the Mongo URI is forwarded even when the SQL adapter
        # is selected -- confirm that is the intended database location.
        database_uri=ChatterBot.mongo_database_uri,
        # Author's note: the bot should only learn through learn().
        read_only=True,
    )
    list_trainer = ListTrainer(bot)
    # session id -> text of the latest message recorded for that session.
    session_text_dict = {}

    @staticmethod
    def learn_from_corpus():
        """Train the shared bot on ChatterBot's Chinese corpus.

        Author's notes: with a very recent yaml package, ``yaml.load()`` in
        ChatterBot's ``corpus.py`` has to be changed to ``yaml.full_load()``.
        The corpora at https://github.com/hbwzhsh/chinese_chatbot_corpus may
        also be worth trying as training data.
        """
        ChatterBotCorpusTrainer(ATRIChatBot.bot).train("chatterbot.corpus.chinese")

    @staticmethod
    def learn(session_id: str, text: str):
        """Train on (previous message, ``text``) and remember ``text``.

        Args:
            session_id: Key identifying the conversation the message is from.
            text: The newly received message.
        """
        history = ATRIChatBot.session_text_dict
        previous = history.get(session_id)
        if previous:
            # The list is a conversation: earlier entries are prompts (there
            # may be several), the final entry is the answer to learn.
            ATRIChatBot.list_trainer.train([previous, text])
        # Record the newest message only after training has been attempted.
        history[session_id] = text

    @staticmethod
    async def get_response(text: str) -> str:
        """Return the shared bot's reply text for ``text`` and log the pair.

        Args:
            text: The message to answer.

        Returns:
            The ``text`` attribute of the response produced by the bot.
        """
        # NOTE(review): bot.get_response is called directly rather than
        # awaited; presumably it blocks the event loop while it runs -- confirm.
        reply = ATRIChatBot.bot.get_response(text).text
        log.info(f"人工智障回复:{text}  ->  {reply}")
        return reply