summaryrefslogtreecommitdiff
path: root/ATRI/plugins
diff options
context:
space:
mode:
Diffstat (limited to 'ATRI/plugins')
-rw-r--r--ATRI/plugins/atri_chat_bot.py20
-rw-r--r--ATRI/plugins/chatbot/__init__.py6
2 files changed, 14 insertions, 12 deletions
diff --git a/ATRI/plugins/atri_chat_bot.py b/ATRI/plugins/atri_chat_bot.py
index 9000006..3226b2f 100644
--- a/ATRI/plugins/atri_chat_bot.py
+++ b/ATRI/plugins/atri_chat_bot.py
@@ -12,19 +12,22 @@ https://chatterbot.readthedocs.io/
MONGO_ADAPTER = "chatterbot.storage.MongoDatabaseAdapter"
SQLITE_ADAPTER = "chatterbot.storage.SQLStorageAdapter"
+
class ATRIChatBot:
bot = ChatBot(
"ATRI",
- storage_adapter=MONGO_ADAPTER if ChatterBot.mongo_database_uri else SQLITE_ADAPTER,
+ storage_adapter=MONGO_ADAPTER
+ if ChatterBot.mongo_database_uri
+ else SQLITE_ADAPTER,
logic_adapters=[
{
- 'import_path': 'chatterbot.logic.BestMatch',
- 'default_response': ChatterBot.default_response,
- 'maximum_similarity_threshold': ChatterBot.maximum_similarity_threshold
+ "import_path": "chatterbot.logic.BestMatch",
+ "default_response": ChatterBot.default_response,
+ "maximum_similarity_threshold": ChatterBot.maximum_similarity_threshold,
}
],
database_uri=ChatterBot.mongo_database_uri,
- read_only=True # 只能通过 learn 函数学习
+ read_only=True, # 只能通过 learn 函数学习
)
list_trainer = ListTrainer(bot)
session_text_dict = dict()
@@ -41,10 +44,7 @@ class ATRIChatBot:
# 查找上一条消息并训练模型
last_text = ATRIChatBot.session_text_dict.get(session_id)
if last_text:
- ATRIChatBot.list_trainer.train([
- last_text, # 问(可多个)
- text # 答
- ])
+ ATRIChatBot.list_trainer.train([last_text, text]) # 问(可多个) # 答
# 更新最后一条消息
ATRIChatBot.session_text_dict[session_id] = text
@@ -52,4 +52,4 @@ class ATRIChatBot:
async def get_response(text: str) -> str:
response = ATRIChatBot.bot.get_response(text)
log.info(f"人工智障回复:{text} -> {response.text}")
- return response.text \ No newline at end of file
+ return response.text
diff --git a/ATRI/plugins/chatbot/__init__.py b/ATRI/plugins/chatbot/__init__.py
index 53b3772..e9061cc 100644
--- a/ATRI/plugins/chatbot/__init__.py
+++ b/ATRI/plugins/chatbot/__init__.py
@@ -12,6 +12,7 @@ from nonebot.permission import SUPERUSER
chatbot = on_message(priority=114514)
+
@chatbot.handle()
async def _learn_from_group(bot: Bot, event: MessageEvent):
text = event.get_plaintext().strip()
@@ -19,13 +20,14 @@ async def _learn_from_group(bot: Bot, event: MessageEvent):
return
if isinstance(event, GroupMessageEvent): # 从群友那学习说话
ATRIChatBot.learn(event.get_session_id(), text)
- if random.random() <= ChatterBot.group_random_response_rate: # 随机回话
+ if random.random() <= ChatterBot.group_random_response_rate: # 随机回话
await chatbot.finish(await ATRIChatBot.get_response(text))
chatbot_learn = on_command("/learn_corpus", permission=SUPERUSER)
+
@chatbot_learn.handle()
async def _learn_from_corpus(bot: Bot, event: MessageEvent):
ATRIChatBot.learn_from_corpus()
- await chatbot.finish("咱从corpus那学习完了!") \ No newline at end of file
+ await chatbot.finish("咱从corpus那学习完了!")