blob: 900000604e98adfaf6ae23240e00efe407504920 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
|
from ATRI.config import ChatterBot
from chatterbot import ChatBot
from chatterbot.trainers import ListTrainer
from chatterbot.trainers import ChatterBotCorpusTrainer
from ATRI.log import logger as log
# Module description, assigned explicitly to __doc__. The text is a runtime
# string and is therefore kept verbatim; in English it reads: "A chat
# (nonsense / parroting) bot that can keep learning", followed by the
# ChatterBot documentation URL.
__doc__ = """
可以不断学习的聊天(胡言乱语/复读)机器人
https://chatterbot.readthedocs.io/
"""
# Dotted import paths of the two ChatterBot storage adapters; one of them is
# passed as ``storage_adapter`` when the bot is constructed below.
MONGO_ADAPTER = "chatterbot.storage.MongoDatabaseAdapter"
SQLITE_ADAPTER = "chatterbot.storage.SQLStorageAdapter"
class ATRIChatBot:
    """Class-level wrapper around one shared ChatterBot ``ChatBot``.

    All state (the bot, its list trainer and the per-session "last message"
    table) is stored in class attributes and every method is static, so the
    class is used directly, e.g. ``ATRIChatBot.learn(session_id, text)``.
    """

    # The shared bot; constructed when this class body is executed.
    bot = ChatBot(
        "ATRI",
        # Mongo adapter when a Mongo URI is configured, SQL adapter otherwise.
        storage_adapter=(
            MONGO_ADAPTER if ChatterBot.mongo_database_uri else SQLITE_ADAPTER
        ),
        logic_adapters=[
            {
                "import_path": "chatterbot.logic.BestMatch",
                "default_response": ChatterBot.default_response,
                "maximum_similarity_threshold": (
                    ChatterBot.maximum_similarity_threshold
                ),
            },
        ],
        # NOTE(review): the Mongo URI is forwarded even when the SQL adapter
        # was selected above (i.e. when it is falsy) - confirm that the SQL
        # adapter's handling of a None/empty ``database_uri`` is intended.
        database_uri=ChatterBot.mongo_database_uri,
        # The bot may only learn through the ``learn`` function.
        read_only=True,
    )

    # Trainer used by ``learn`` to feed [previous message, new message] pairs.
    list_trainer = ListTrainer(bot)

    # Maps a session id to the most recent text recorded for that session.
    session_text_dict = {}

    @staticmethod
    def learn_from_corpus():
        """Train the shared bot on the ``chatterbot.corpus.chinese`` corpus."""
        # Learn from the Chinese corpus shipped with the corpus package. If the
        # installed yaml package is too new, yaml.load() in corpus.py has to be
        # changed to yaml.full_load().
        # Training with the corpora from
        # https://github.com/hbwzhsh/chinese_chatbot_corpus is also worth a try.
        ChatterBotCorpusTrainer(ATRIChatBot.bot).train("chatterbot.corpus.chinese")

    @staticmethod
    def learn(session_id: str, text: str):
        """Record ``text`` for a session and train on the previous/new pair.

        If a non-empty previous text is stored for ``session_id``, the list
        trainer is trained with ``[previous, text]``. Afterwards ``text``
        becomes the stored last message for that session.
        """
        history = ATRIChatBot.session_text_dict

        # Look up the previous message of this session and train the model.
        previous = history.get(session_id)
        if previous:
            # First item(s): the question (may be several); last item: the answer.
            exchange = [previous, text]
            ATRIChatBot.list_trainer.train(exchange)

        # Remember this message as the latest one for the session.
        history[session_id] = text

    @staticmethod
    async def get_response(text: str) -> str:
        """Return the bot's reply text for ``text`` and log the exchange."""
        # NOTE(review): the bot is called without ``await``; this looks like a
        # synchronous call made from a coroutine - confirm that is acceptable.
        answer = ATRIChatBot.bot.get_response(text).text
        log.info(f"人工智障回复:{text} -> {answer}")
        return answer
|