diff options
Diffstat (limited to 'ATRI/service.py')
-rw-r--r-- | ATRI/service.py | 54 |
1 files changed, 27 insertions, 27 deletions
diff --git a/ATRI/service.py b/ATRI/service.py index 9f680c8..c9509df 100644 --- a/ATRI/service.py +++ b/ATRI/service.py @@ -4,7 +4,7 @@ import json from pathlib import Path from types import ModuleType from pydantic import BaseModel -from typing import List, Set, Tuple, Type, Union, Optional, TYPE_CHECKING +from typing import List, Set, Tuple, Type, Union, Optional from nonebot.matcher import Matcher from nonebot.permission import Permission @@ -19,9 +19,6 @@ from nonebot.rule import Rule, command, keyword, regex from ATRI.exceptions import ReadFileError, WriteError -if TYPE_CHECKING: - from nonebot.adapters import Bot, Event - SERVICE_DIR = Path(".") / "data" / "service" SERVICES_DIR = SERVICE_DIR / "services" @@ -69,7 +66,7 @@ class Service: def __init__( self, service: str, - docs: str = None, + docs: str, only_admin: bool = False, rule: Optional[Union[Rule, T_RuleChecker]] = None, permission: Optional[Permission] = None, @@ -77,6 +74,7 @@ class Service: temp: bool = False, priority: int = 1, state: Optional[T_State] = None, + main_cmd: str = str(), ): self.service = service self.docs = docs @@ -87,13 +85,9 @@ class Service: self.temp = temp self.priority = priority self.state = state + self.main_cmd = (main_cmd,) - def _generate_service_config(self, service: str = None, docs: str = None) -> None: - if not service: - service = self.service - if not docs: - docs = self.docs or str() - + def _generate_service_config(self, service: str, docs: str = str()) -> None: path = SERVICES_DIR / f"{service}.json" data = ServiceInfo( service=service, @@ -110,31 +104,28 @@ class Service: except WriteError: raise WriteError("Write service info failed!") - def save_service(self, service_data: dict, service: str = None) -> None: + def save_service(self, service_data: dict, service: str) -> None: if not service: service = self.service path = SERVICES_DIR / f"{service}.json" if not path.is_file(): - self._generate_service_config() + self._generate_service_config(service, self.docs) with open(path, "w", encoding="utf-8") as w: w.write(json.dumps(service_data, indent=4)) - def load_service(self, service: str = None) -> dict: - if not service: - service = self.service - + def load_service(self, service: str) -> dict: path = SERVICES_DIR / f"{service}.json" if not path.is_file(): - self._generate_service_config() + self._generate_service_config(service, self.docs) try: data = json.loads(path.read_bytes()) except ReadFileError: with open(path, "w", encoding="utf-8") as w: w.write(json.dumps({})) - self._generate_service_config() + self._generate_service_config(service, self.docs) data = json.loads(path.read_bytes()) return data @@ -142,25 +133,25 @@ class Service: data = self.load_service(self.service) temp_data: dict = data["cmd_list"] temp_data.update(cmds) - self.save_service(data) + self.save_service(data, self.service) def _load_cmds(self) -> dict: path = SERVICES_DIR / f"{self.service}.json" if not path.is_file(): - self._generate_service_config() + self._generate_service_config(self.service, self.docs) data = json.loads(path.read_bytes()) return data["cmd_list"] def on_message( self, - name: str = None, + name: str = str(), docs: str = str(), rule: Optional[Union[Rule, T_RuleChecker]] = None, permission: Optional[Union[Permission, T_PermissionChecker]] = None, handlers: Optional[List[Union[T_Handler, Dependent]]] = None, block: bool = True, - priority: int = None, + priority: int = 1, state: Optional[T_State] = None, ) -> Type[Matcher]: if not rule: @@ -169,8 +160,6 @@ class Service: permission = self.permission if not handlers: handlers = self.handlers - if not priority: - priority = self.priority if not state: state = self.state @@ -253,11 +242,13 @@ class Service: if not aliases: aliases = set() + if isinstance(cmd, tuple): + cmd = ".".join(map(str, cmd)) + cmd_list[cmd] = CommandInfo( type="command", docs=docs, aliases=list(aliases) ).dict() self._save_cmds(cmd_list) - commands = set([cmd]) | (aliases or set()) return self.on_message(rule=command(*commands) & rule, block=True, **kwargs) @@ -297,6 +288,15 @@ class Service: return self.on_message(rule=regex(pattern, flags) & rule, **kwargs) + def cmd_as_group(self, cmd: str, docs: str, **kwargs) -> Type[Matcher]: + sub_cmd = (cmd,) if isinstance(cmd, str) else cmd + _cmd = self.main_cmd + sub_cmd + + if "aliases" in kwargs: + del kwargs["aliases"] + + return self.on_command(_cmd, docs, **kwargs) + class ServiceTools(object): @staticmethod @@ -327,7 +327,7 @@ class ServiceTools(object): return data @classmethod - def auth_service(cls, service, user_id: str = None, group_id: str = None) -> bool: + def auth_service(cls, service, user_id: str = str(), group_id: str = str()) -> bool: data = cls.load_service(service) auth_global = data.get("enabled", True) |