summaryrefslogtreecommitdiff
path: root/ATRI/plugins/manage/data_source.py
diff options
context:
space:
mode:
Diffstat (limited to 'ATRI/plugins/manage/data_source.py')
-rw-r--r--ATRI/plugins/manage/data_source.py371
1 files changed, 161 insertions, 210 deletions
diff --git a/ATRI/plugins/manage/data_source.py b/ATRI/plugins/manage/data_source.py
index ff344e7..37edefc 100644
--- a/ATRI/plugins/manage/data_source.py
+++ b/ATRI/plugins/manage/data_source.py
@@ -1,273 +1,224 @@
-import json
+from typing import Dict
from pathlib import Path
from datetime import datetime
+from nonebot import get_bot
+from nonebot.adapters import Bot
+from nonebot.adapters.onebot.v11 import MessageEvent, GroupMessageEvent
+
+from ATRI.utils import FileDealer
from ATRI.service import ServiceTools
from ATRI.message import MessageBuilder
from ATRI.exceptions import load_error
+from .models import RequestInfo
+
-MANAGE_DIR = Path(".") / "data" / "plugins" / "manege"
-ESSENTIAL_DIR = Path(".") / "data" / "plugins" / "essential"
+MANAGE_DIR = Path(".") / "data" / "plugins" / "manage"
MANAGE_DIR.mkdir(parents=True, exist_ok=True)
-ESSENTIAL_DIR.mkdir(parents=True, exist_ok=True)
-_TRACK_BACK_FORMAT = (
- MessageBuilder("Track ID: {track_id}")
- .text("Prompt: {prompt}")
- .text("Time: {time}")
+_TRACEBACK_FORMAT = (
+ MessageBuilder("追踪ID:{trace_id}")
+ .text("关键词:{prompt}")
+ .text("时间:{time}")
.text("{content}")
.done()
)
-class Manage:
- @staticmethod
- def _load_block_user_list() -> dict:
- """
- 文件结构:
- {
- "Block user ID": {
- "time": "Block time"
- }
- }
- """
- file_name = "block_user.json"
+class BotManager:
+ async def __load_data(self, file_name: str) -> dict:
path = MANAGE_DIR / file_name
+ dealer = FileDealer(path)
if not path.is_file():
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps({}))
- return dict()
+ await dealer.write_json(dict())
+
try:
- data = json.loads(path.read_bytes())
+ data = dealer.json()
except Exception:
data = dict()
return data
- @staticmethod
- def _save_block_user_list(data: dict) -> None:
- file_name = "block_user.json"
+ async def __store_data(self, file_name: str, data: dict) -> None:
path = MANAGE_DIR / file_name
+ dealer = FileDealer(path)
if not path.is_file():
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps({}))
-
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps(data, indent=4))
-
- @staticmethod
- def _load_block_group_list() -> dict:
- """
- 文件结构:
- {
- "Block group ID": {
- "time": "Block time"
- }
- }
- """
- file_name = "block_group.json"
- path = MANAGE_DIR / file_name
- if not path.is_file():
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps({}))
- return dict()
+ await dealer.write_json(dict())
- try:
- data = json.loads(path.read_bytes())
- except Exception:
- data = dict()
- return data
+ await dealer.write_json(data)
- @staticmethod
- def _save_block_group_list(data: dict) -> None:
- file_name = "block_group.json"
- path = MANAGE_DIR / file_name
- if not path.is_file():
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps({}))
+ async def __load_block_group(self) -> dict:
+ return await self.__load_data("block_group.json")
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps(data, indent=4))
+ async def __store_block_group(self, data: dict) -> None:
+ await self.__store_data("block_group.json", data)
- @classmethod
- def block_user(cls, user_id: str) -> bool:
- data = cls._load_block_user_list()
- now_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- data[user_id] = {"time": now_time}
- try:
- cls._save_block_user_list(data)
- return True
- except Exception:
- return False
+ async def __load_block_user(self) -> dict:
+ return await self.__load_data("block_user.json")
- @classmethod
- def unblock_user(cls, user_id: str) -> bool:
- data: dict = cls._load_block_user_list()
- if user_id not in data:
- return False
+ async def __store_block_user(self, data: dict) -> None:
+ await self.__store_data("block_user.json", data)
- try:
- data.pop(user_id)
- cls._save_block_user_list(data)
- return True
- except Exception:
- return False
+ async def load_friend_req(self) -> Dict[str, RequestInfo]:
+ return await self.__load_data("friend_add.json")
+
+ async def store_friend_req(self, data: dict) -> None:
+ await self.__store_data("friend_add.json", data)
+
+ async def load_group_req(self) -> Dict[str, RequestInfo]:
+ return await self.__load_data("group_invite.json")
+
+ async def store_group_req(self, data: dict) -> None:
+ await self.__store_data("group_invite.json", data)
- @classmethod
- def block_group(cls, group_id: str) -> bool:
- data = cls._load_block_group_list()
- now_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- data[group_id] = {"time": now_time}
+ async def block_group(self, group_id: str) -> None:
+ data = await self.__load_block_group()
+ data[group_id] = {"time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
try:
- cls._save_block_group_list(data)
- return True
+ await self.__store_block_group(data)
except Exception:
- return False
+ raise Exception("写入文件时失败")
- @classmethod
- def unblock_group(cls, group_id: str) -> bool:
- data: dict = cls._load_block_group_list()
+ async def unblock_group(self, group_id: str) -> None:
+ data = await self.__load_block_group()
if group_id not in data:
- return False
+ raise Exception("群不存在于封禁名单")
try:
data.pop(group_id)
- cls._save_block_group_list(data)
- return True
+ await self.__store_block_group(data)
except Exception:
- return False
+ raise Exception("写入文件时失败")
- @staticmethod
- def control_global_service(service: str, is_enabled: bool) -> bool:
- """
- Only SUPERUSER.
- """
+ async def block_user(self, user_id: str) -> None:
+ data = await self.__load_block_user()
+ data[user_id] = {"time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
try:
- data = ServiceTools(service).load_service()
+ await self.__store_block_user(data)
except Exception:
- return False
- data.enabled = is_enabled
- ServiceTools(service).save_service(data.dict())
- return True
-
- @staticmethod
- def control_user_service(service: str, user_id: str, is_enabled: bool) -> bool:
- """
- Only SUPERUSER.
- """
- try:
- data = ServiceTools(service).load_service()
- except Exception:
- return False
- temp_list: list = data.disable_user
-
- if is_enabled:
- try:
- temp_list.remove(user_id)
- except Exception:
- return False
- else:
- if user_id in temp_list:
- return True
-
- temp_list.append(user_id)
+ raise Exception("写入文件时失败")
- data.disable_user = temp_list
- ServiceTools(service).save_service(data.dict())
- return True
+ async def unblock_user(self, user_id: str) -> None:
+ data = await self.__load_block_user()
+ if user_id not in data:
+ raise Exception("用户不存在于封禁名单")
- @staticmethod
- def control_group_service(service: str, group_id: str, is_enabled: bool) -> bool:
- """
- SUPERUSER and GROUPADMIN or GROUPOWNER.
- Only current group.
- """
try:
- data = ServiceTools(service).load_service()
+ data.pop(user_id)
+ await self.__store_block_user(data)
except Exception:
- return False
- temp_list: list = data.disable_group
+ raise Exception("写入文件时失败")
- if is_enabled:
+ def toggle_global_service(self, service: str) -> bool:
+ serv = ServiceTools(service)
+ try:
+ data = serv.load_service()
+ except Exception as e:
+ error_msg = str(e)
+ raise Exception(error_msg)
+
+ data.enabled = not data.enabled
+ serv.save_service(data)
+ return data.enabled
+
+ def toggle_group_service(self, service: str, event) -> bool:
+ if isinstance(event, GroupMessageEvent):
+ group_id = str(event.group_id)
+ serv = ServiceTools(service)
try:
- temp_list.remove(group_id)
- except Exception:
- return False
+ data = serv.load_service()
+ except Exception as e:
+ error_msg = str(e)
+ raise Exception(error_msg)
+
+ if group_id in data.disable_group:
+ data.disable_group.remove(group_id)
+ result = True
+ else:
+ data.disable_group.append(group_id)
+ result = False
+ serv.save_service(data)
+ return result
+ raise Exception("该功能只能在群聊中使用")
+
+ def toggle_user_service(self, service: str, event: MessageEvent) -> bool:
+ user_id = event.get_user_id()
+ serv = ServiceTools(service)
+ try:
+ data = serv.load_service()
+ except Exception as e:
+ error_msg = str(e)
+ raise Exception(error_msg)
+
+ if user_id in data.disable_user:
+ data.disable_user.remove(user_id)
+ result = True
else:
- if group_id in temp_list:
- return True
-
- temp_list.append(group_id)
-
- data.disable_group = temp_list
- ServiceTools(service).save_service(data.dict())
- return True
-
- @staticmethod
- def load_friend_apply_list() -> dict:
- file_name = "friend_add.json"
- path = ESSENTIAL_DIR / file_name
- if not path.is_file():
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps({}))
- return dict()
+ data.disable_user.append(user_id)
+ result = False
+ serv.save_service(data)
+ return result
+ async def track_error(self, trace_id: str) -> str:
try:
- data = json.loads(path.read_bytes())
+ data = load_error(trace_id)
except Exception:
- data = dict()
- return data
-
- @staticmethod
- def save_friend_apply_list(data: dict) -> None:
- file_name = "friend_add.json"
- path = ESSENTIAL_DIR / file_name
- if not path.is_file():
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps({}))
-
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps(data, indent=4))
+ raise Exception("未找到对应ID的信息")
- @staticmethod
- def load_invite_apply_list() -> dict:
- file_name = "group_invite.json"
- path = ESSENTIAL_DIR / file_name
- if not path.is_file():
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps({}))
- return dict()
+ return _TRACEBACK_FORMAT.format(
+ trace_id=data.track_id,
+ prompt=data.prompt,
+ time=data.time,
+ content=data.content,
+ )
+ def __get_bot(self) -> Bot:
try:
- data = json.loads(path.read_bytes())
+ return get_bot()
except Exception:
- data = dict()
- return data
+ raise Exception("无法获取 bot 实例")
- @staticmethod
- def save_invite_apply_list(data: dict) -> None:
- file_name = "group_invite.json"
- path = ESSENTIAL_DIR / file_name
- if not path.is_file():
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps({}))
-
- with open(path, "w", encoding="utf-8") as w:
- w.write(json.dumps(data, indent=4))
+ async def apply_friend_req(self, code: str) -> None:
+ bot = self.__get_bot()
+ try:
+ await bot.call_api("set_friend_add_request", flag=code, approve=True)
+ except Exception:
+ raise Exception("同意失败,请尝试手动同意")
+ data = await self.load_friend_req()
+ data.pop(code)
+ await self.store_friend_req(data)
- @staticmethod
- async def track_error(track_id: str) -> str:
+ async def reject_friend_req(self, code: str) -> None:
+ bot = self.__get_bot()
try:
- data = load_error(track_id)
+ await bot.call_api("set_friend_add_request", flag=code, approve=False)
except Exception:
- return "请检查ID是否正确..."
+ raise Exception("拒绝失败,请尝试手动拒绝")
+ data = await self.load_friend_req()
+ data.pop(code)
+ await self.store_friend_req(data)
- prompt = data.get("prompt", "ignore")
- time = data.get("time", "ignore")
- content = data.get("content", "ignore")
+ async def apply_group_req(self, code: str) -> None:
+ bot = self.__get_bot()
+ try:
+ await bot.call_api(
+ "set_group_add_request", flag=code, sub_type="invite", approve=True
+ )
+ except Exception:
+ raise Exception("同意失败,请尝试手动同意")
+ data = await self.load_group_req()
+ data.pop(code)
+ await self.store_group_req(data)
- repo = _TRACK_BACK_FORMAT.format(
- track_id=track_id, prompt=prompt, time=time, content=content
- )
- return repo
+ async def reject_group_req(self, code: str) -> None:
+ bot = self.__get_bot()
+ try:
+ await bot.call_api(
+ "set_group_add_request", flag=code, sub_type="invite", approve=False
+ )
+ except Exception:
+ raise Exception("拒绝失败,请尝试手动拒绝")
+ data = await self.load_group_req()
+ data.pop(code)
+ await self.store_group_req(data)