diff options
Diffstat (limited to 'ATRI/plugins')
-rw-r--r-- | ATRI/plugins/console/__init__.py | 171 | ||||
-rw-r--r-- | ATRI/plugins/console/data_source.py | 37 | ||||
-rw-r--r-- | ATRI/plugins/console/driver/__init__.py | 11 | ||||
-rw-r--r-- | ATRI/plugins/console/models.py | 2 | ||||
-rw-r--r-- | ATRI/plugins/kimo/data_source.py | 5 | ||||
-rw-r--r-- | ATRI/plugins/rss/__init__.py | 39 | ||||
-rw-r--r-- | ATRI/plugins/rss/rss_mikanan/__init__.py | 0 | ||||
-rw-r--r-- | ATRI/plugins/rss/rss_mikanan/data_source.py | 0 | ||||
-rw-r--r-- | ATRI/plugins/rss/rss_mikanan/db.py | 0 | ||||
-rw-r--r-- | ATRI/plugins/rss/rss_rsshub/__init__.py | 162 | ||||
-rw-r--r-- | ATRI/plugins/rss/rss_rsshub/data_source.py | 115 | ||||
-rw-r--r-- | ATRI/plugins/rss/rss_rsshub/db.py | 26 | ||||
-rw-r--r-- | ATRI/plugins/setu/__init__.py | 7 | ||||
-rw-r--r-- | ATRI/plugins/setu/data_source.py | 6 | ||||
-rw-r--r-- | ATRI/plugins/setu/nsfw_checker.py (renamed from ATRI/plugins/setu/tf_dealer.py) | 47 | ||||
-rw-r--r-- | ATRI/plugins/twitter/__init__.py | 7 |
16 files changed, 441 insertions, 194 deletions
diff --git a/ATRI/plugins/console/__init__.py b/ATRI/plugins/console/__init__.py index e4fec29..2bd14c8 100644 --- a/ATRI/plugins/console/__init__.py +++ b/ATRI/plugins/console/__init__.py @@ -4,7 +4,8 @@ from nonebot.params import ArgPlainText from nonebot.adapters.onebot.v11 import PrivateMessageEvent, GroupMessageEvent from ATRI.config import BotSelfConfig -from ATRI.exceptions import WriteFileError, ReadFileError +from ATRI.exceptions import WriteFileError + from .data_source import Console, CONSOLE_DIR from .models import AuthData @@ -14,52 +15,29 @@ gen_console_key = Console().cmd_as_group("auth", "获取进入网页后台的凭 @gen_console_key.got("is_pub_n", "咱的运行环境是否有公网(y/n)") async def _(event: PrivateMessageEvent, is_pub_n: str = ArgPlainText("is_pub_n")): + data_path = CONSOLE_DIR / "data.json" + if not data_path.is_file: + with open(data_path, "w", encoding="utf-8") as w: + w.write(json.dumps(dict())) + if is_pub_n != "y": - ip = str(await Console().get_host_ip(False)) + host = str(await Console().get_host_ip(False)) await gen_console_key.send("没有公网吗...嗯知道了") else: - ip = str(await Console().get_host_ip(True)) - - p = BotSelfConfig.port - rs = Console().get_random_str(20) - - df = CONSOLE_DIR / "data.json" - try: - if not df.is_file(): - with open(df, "w", encoding="utf-8") as w: - w.write(json.dumps({})) + host = str(await Console().get_host_ip(True)) - d = json.loads(df.read_bytes()) + port = BotSelfConfig.port + token = Console().get_random_str(20) - ca = d.get("data", None) - if ca: - # 此处原本想用 matcher.finish 但这是在 try 里啊! - await gen_console_key.send("咱已经告诉你了嗷!啊!忘了.../con.load 获取吧") - return + data = json.loads(data_path.read_bytes()) + data["data"] = AuthData(token=token).dict() + with open(data_path, "w", encoding="utf-8") as w: + w.write(json.dumps(data)) - d["data"] = AuthData(ip=ip, port=str(p), token=rs).dict() - - with open(df, "w", encoding="utf-8") as w: - w.write(json.dumps(d)) - except Exception: - msg = f""" - 哦吼!写入文件失败了...还请自行记下哦... - IP: {ip} - PORT: {p} - TOKEN: {rs} - 一定要保管好哦!切勿告诉他人哦! - 入口: atri-console.imki.moe - """.strip() - await gen_console_key.send(msg) - - raise WriteFileError("Writing file: " + str(df) + " failed!") - - msg = f""" - 该信息已保存!可通过 /gauth 获取~ - IP: {ip} - PORT: {p} - TOKEN: {rs} - 一定要保管好哦!切勿告诉他人哦! + msg = f"""控制台信息已生成! + 请访问: {host}:{port} + Token: {token} + 该 token 有效时间为 15min """.strip() await gen_console_key.finish(msg) @@ -69,37 +47,6 @@ async def _(event: GroupMessageEvent): await gen_console_key.finish("请私戳咱获取(") -load_console_key = Console().cmd_as_group("load", "获取已生成的后台凭证") - - -@load_console_key.handle() -async def _(event: PrivateMessageEvent): - df = CONSOLE_DIR / "data.json" - if not df.is_file(): - await load_console_key.finish("你还没有问咱索要奥!/con.auth 以获取") - - try: - d = json.loads(df.read_bytes()) - except Exception: - await load_console_key.send("获取数据失败了...请自行打开文件查看吧:\n" + str(df)) - raise ReadFileError("Reading file: " + str(df) + " failed!") - - data = d["data"] - msg = f""" - 诶嘿嘿嘿——凭证信息来咯! - IP: {data['ip']} - PORT: {data['port']} - TOKEN: {data['token']} - 切记!不要告诉他人!! - """.strip() - await load_console_key.finish(msg) - - -@load_console_key.handle() -async def _(event: GroupMessageEvent): - await load_console_key.finish("请私戳咱获取(") - - del_console_key = Console().cmd_as_group("del", "销毁进入网页后台的凭证") @@ -126,82 +73,10 @@ async def _(is_sure: str = ArgPlainText("is_sure_d")): await del_console_key.finish("销毁成功!如需再次获取: /con.auth") -res_console_key = Console().cmd_as_group("reauth", "重置进入网页后台的凭证") - - -@res_console_key.got("is_sure_r", "...你确定吗(y/n)") -async def _(is_sure: str = ArgPlainText("is_sure_r")): - if is_sure != "y": - await res_console_key.finish("反悔了呢...") - - df = CONSOLE_DIR / "data.json" - if not df.is_file(): - await del_console_key.finish("你还没向咱索取凭证呢.../con.auth 以获取") - - try: - data: dict = json.loads(df.read_bytes()) - - del data["data"] - - with open(df, "w", encoding="utf-8") as w: - w.write(json.dumps(data)) - except Exception: - await del_console_key.send("销毁失败了...请至此处自行删除文件:\n" + str(df)) - raise WriteFileError("Writing / Reading file: " + str(df) + " failed!") - - -@res_console_key.got("is_pub_r_n", "咱的运行环境是否有公网(y/n)") -async def _(event: PrivateMessageEvent, is_pub_n: str = ArgPlainText("is_pub_n")): - if is_pub_n != "y": - ip = str(await Console().get_host_ip(False)) - await res_console_key.send("没有公网吗...嗯知道了") - else: - ip = str(await Console().get_host_ip(True)) - - p = BotSelfConfig.port - rs = Console().get_random_str(20) - - df = CONSOLE_DIR / "data.json" - try: - if not df.is_file(): - with open(df, "w", encoding="utf-8") as w: - w.write(json.dumps({})) - - d = json.loads(df.read_bytes()) - - ca = d.get("data", None) - if ca: - await res_console_key.send("咱已经告诉你了嗷!啊!忘了.../con.load 获取吧") - return - - d["data"] = AuthData(ip=ip, port=str(p), token=rs).dict() - - with open(df, "w", encoding="utf-8") as w: - w.write(json.dumps(d)) - except Exception: - msg = f""" - 哦吼!写入文件失败了...还请自行记下哦... - IP: {ip} - PORT: {p} - TOKEN: {rs} - 一定要保管好哦!切勿告诉他人哦! - """.strip() - await res_console_key.send(msg) - - raise WriteFileError("Writing file: " + str(df) + " failed!") - - msg = f""" - 该信息已保存!可通过 /con.load 获取~ - IP: {ip} - PORT: {p} - TOKEN: {rs} - 一定要保管好哦!切勿告诉他人哦! - """.strip() - await res_console_key.finish(msg) - - from ATRI import driver as dr -from .driver import init +from .data_source import init_resource +from .driver import init_driver -dr().on_startup(init) +dr().on_startup(init_resource) +dr().on_startup(init_driver) diff --git a/ATRI/plugins/console/data_source.py b/ATRI/plugins/console/data_source.py index eee862c..db0b6b1 100644 --- a/ATRI/plugins/console/data_source.py +++ b/ATRI/plugins/console/data_source.py @@ -1,6 +1,7 @@ import json import socket import string +import zipfile from random import sample from pathlib import Path @@ -8,7 +9,9 @@ from nonebot.permission import SUPERUSER from ATRI.service import Service from ATRI.utils import request +from ATRI.rule import is_in_service from ATRI.exceptions import WriteFileError +from ATRI.log import logger as log CONSOLE_DIR = Path(".") / "data" / "plugins" / "console" @@ -18,7 +21,13 @@ CONSOLE_DIR.mkdir(parents=True, exist_ok=True) class Console(Service): def __init__(self): Service.__init__( - self, "控制台", "前端管理页面", True, main_cmd="/con", permission=SUPERUSER + self, + "控制台", + "前端管理页面", + True, + is_in_service("控制台"), + main_cmd="/con", + permission=SUPERUSER, ) @staticmethod @@ -56,3 +65,29 @@ class Console(Service): if not data: return {"data": None} return data + + +FRONTEND_DIR = CONSOLE_DIR / "frontend" +FRONTEND_DIR.mkdir(parents=True, exist_ok=True) +CONSOLE_RESOURCE_URL = ( + "https://jsd.imki.moe/gh/kyomotoi/Project-ATRI-Console@main/archive/dist.zip" +) + + +async def init_resource(): + log.info("控制台初始化中...") + + try: + resp = await request.get(CONSOLE_RESOURCE_URL) + except Exception: + log.error("控制台资源装载失败, 将无法访问管理界面") + return + print(len(resp.read())) + file_path = CONSOLE_DIR / "dist.zip" + with open(file_path, "wb") as w: + w.write(resp.read()) + + with zipfile.ZipFile(file_path, "r") as zr: + zr.extractall(FRONTEND_DIR) + + log.success("控制台初始化完成") diff --git a/ATRI/plugins/console/driver/__init__.py b/ATRI/plugins/console/driver/__init__.py index 22afd4a..04c714d 100644 --- a/ATRI/plugins/console/driver/__init__.py +++ b/ATRI/plugins/console/driver/__init__.py @@ -1,7 +1,9 @@ from nonebot.drivers.fastapi import Driver +from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware +from ATRI.plugins.console.data_source import FRONTEND_DIR from .view import ( handle_auther, handle_base_uri, @@ -49,8 +51,15 @@ def register_routes(driver: Driver): app.get(CONSOLE_API_BLOCK_LIST_URI)(handle_get_block_list) app.get(CONSOLE_API_BLOCK_EDIT_URI)(handle_edit_block) + static_path = str(FRONTEND_DIR) + app.mount( + "/", + StaticFiles(directory=static_path, html=True), + name="atri-console", + ) + -def init(): +def init_driver(): from ATRI import driver register_routes(driver()) # type: ignore diff --git a/ATRI/plugins/console/models.py b/ATRI/plugins/console/models.py index 973d05d..acbf916 100644 --- a/ATRI/plugins/console/models.py +++ b/ATRI/plugins/console/models.py @@ -2,8 +2,6 @@ from pydantic import BaseModel class AuthData(BaseModel): - ip: str - port: str token: str diff --git a/ATRI/plugins/kimo/data_source.py b/ATRI/plugins/kimo/data_source.py index f2c56ee..2135a01 100644 --- a/ATRI/plugins/kimo/data_source.py +++ b/ATRI/plugins/kimo/data_source.py @@ -32,15 +32,16 @@ class Kimo(Service): file_name = "kimo.json" path = CHAT_PATH / file_name if not path.is_file(): - log.warning("未检测到词库,生成中") + log.warning("插件 kimo 缺少资源, 装载中...") data = await cls._request(KIMO_URL) try: with open(path, "w", encoding="utf-8") as w: w.write(json.dumps(data, indent=4)) - log.info("生成完成") except Exception: raise WriteFileError("Writing kimo words failed!") + log.success("插件 kimo 资源装载完成") + @classmethod async def _load_data(cls) -> dict: file_name = "kimo.json" diff --git a/ATRI/plugins/rss/__init__.py b/ATRI/plugins/rss/__init__.py new file mode 100644 index 0000000..a29f385 --- /dev/null +++ b/ATRI/plugins/rss/__init__.py @@ -0,0 +1,39 @@ +from pathlib import Path + +from nonebot.adapters.onebot.v11 import MessageEvent +from nonebot.permission import SUPERUSER +from nonebot.adapters.onebot.v11 import GROUP_OWNER, GROUP_ADMIN + +from ATRI.service import Service + + +RSS_PLUGIN_DIR = Path(".") / "ATRI" / "plugins" / "rss" + + +class RssHelper(Service): + def __init__(self): + Service.__init__( + self, + "rss", + "Rss系插件助手", + True, + permission=SUPERUSER | GROUP_OWNER | GROUP_ADMIN, + main_cmd="/rss", + ) + + +rss_menu = RssHelper().on_command("/rss", "Rss帮助菜单") + + +@rss_menu.handle() +async def _rss_menu(event: MessageEvent): + raw_rss_list = RSS_PLUGIN_DIR.glob("rss_*") + rss_list = [str(i).split("/")[-1] for i in raw_rss_list] + if not rss_list: + rss_list = [str(i).split("\\")[-1] for i in raw_rss_list] + + result = f"""Rss Helper: + 可用订阅源: {"、".join(map(str, rss_list)).replace("rss_", str())} + 命令: /rss.(订阅源名称) + """.strip() + await rss_menu.finish(result) diff --git a/ATRI/plugins/rss/rss_mikanan/__init__.py b/ATRI/plugins/rss/rss_mikanan/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/ATRI/plugins/rss/rss_mikanan/__init__.py diff --git a/ATRI/plugins/rss/rss_mikanan/data_source.py b/ATRI/plugins/rss/rss_mikanan/data_source.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/ATRI/plugins/rss/rss_mikanan/data_source.py diff --git a/ATRI/plugins/rss/rss_mikanan/db.py b/ATRI/plugins/rss/rss_mikanan/db.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/ATRI/plugins/rss/rss_mikanan/db.py diff --git a/ATRI/plugins/rss/rss_rsshub/__init__.py b/ATRI/plugins/rss/rss_rsshub/__init__.py new file mode 100644 index 0000000..2d23f16 --- /dev/null +++ b/ATRI/plugins/rss/rss_rsshub/__init__.py @@ -0,0 +1,162 @@ +import pytz +import asyncio +from tabulate import tabulate +from datetime import timedelta, datetime + +from apscheduler.triggers.base import BaseTrigger +from apscheduler.triggers.combining import AndTrigger +from apscheduler.triggers.interval import IntervalTrigger + +from nonebot import get_bot +from nonebot.matcher import Matcher +from nonebot.params import CommandArg, ArgPlainText +from nonebot.permission import Permission +from nonebot.adapters.onebot.v11 import Message, GroupMessageEvent + +from ATRI.log import logger as log +from ATRI.utils import timestamp2datetime +from ATRI.utils.apscheduler import scheduler +from ATRI.database import RssRsshubSubcription + +from .data_source import RssHubSubscriptor + + +add_sub = RssHubSubscriptor().cmd_as_group("add", "为本群添加 RSSHub 订阅") + + +@add_sub.handle() +async def _(matcher: Matcher, args: Message = CommandArg()): + msg = args.extract_plain_text() + if msg: + matcher.set_arg("rrh_add_url", args) + + +@add_sub.got("rrh_add_url", "RSSHub 链接呢?速速") +async def _(event: GroupMessageEvent, _url: str = ArgPlainText("rrh_add_url")): + group_id = event.group_id + sub = RssHubSubscriptor() + + result = await sub.add_sub(_url, group_id) + await add_sub.finish(result) + + +del_sub = RssHubSubscriptor().cmd_as_group("del", "删除本群 RSSHub 订阅") + + +@del_sub.handle() +async def _(event: GroupMessageEvent): + group_id = event.group_id + sub = RssHubSubscriptor() + + query_result = await sub.get_sub_list({"group_id": group_id}) + if not query_result: + await del_sub.finish("本群还没有任何订阅呢...") + + subs = list() + for i in query_result: + subs.append([i._id, i.title]) + + output = "本群的 RSSHub 订阅列表如下~\n" + tabulate( + subs, headers=["ID", "title"], tablefmt="plain" + ) + await del_sub.send(output) + + +@del_sub.got("rrh_del_sub_id", "要取消的ID呢? 速速\n(键入 q 以取消)") +async def _(event: GroupMessageEvent, _id: str = ArgPlainText("rrh_del_sub_id")): + if _id == "q": + await del_sub.finish("已取消操作~") + + group_id = event.group_id + sub = RssHubSubscriptor() + + result = await sub.del_sub(_id, group_id) + await del_sub.finish(result) + + +get_sub_list = RssHubSubscriptor().cmd_as_group( + "list", "获取本群 RSSHub 订阅列表", permission=Permission() +) + + +@get_sub_list.handle() +async def _(event: GroupMessageEvent): + group_id = event.group_id + sub = RssHubSubscriptor() + + query_result = await sub.get_sub_list({"group_id": group_id}) + if not query_result: + await get_sub_list.finish("本群还没有任何订阅呢...") + + subs = list() + for i in query_result: + subs.append([i.update_time, i.title]) + + output = "本群的 RSSHub 订阅列表如下~\n" + tabulate( + subs, headers=["最后更新时间", "标题"], tablefmt="plain" + ) + await get_sub_list.send(output) + + +tq = asyncio.Queue() + + +class RssHubDynamicChecker(BaseTrigger): + def get_next_fire_time(self, previous_fire_time, now): + conf = RssHubSubscriptor().load_service("rss.rsshub") + if conf["enabled"]: + return now + + [email protected]_job( + AndTrigger([IntervalTrigger(seconds=120), RssHubDynamicChecker()]), + name="RssHub 订阅检查", + max_instances=3, # type: ignore + misfire_grace_time=60, # type: ignore +) +async def _(): + sub = RssHubSubscriptor() + try: + all_dy = await sub.get_all_subs() + except Exception: + log.debug("RssHub 订阅列表为空 跳过") + return + + if tq.empty(): + for i in all_dy: + await tq.put(i) + else: + m: RssRsshubSubcription = tq.get_nowait() + log.info(f"准备查询 RssHub: {m.rss_link} 的动态, 队列剩余 {tq.qsize()}") + + raw_ts = m.update_time.replace( + tzinfo=pytz.timezone("Asia/Shanghai") + ) + timedelta(hours=8) + ts = raw_ts.timestamp() + + info: dict = await sub.get_rsshub_info(m.rss_link) + if not info: + log.warning(f"无法获取 RssHub: {m.rss_link} 的动态") + return + + t_time = info["item"][0]["pubDate"] + time_patt = "%a, %d %b %Y %H:%M:%S GMT" + + raw_t = datetime.strptime(t_time, time_patt) + timedelta(hours=8) + ts_t = raw_t.timestamp() + + if ts < ts_t: + item = info["item"][0] + title = item["title"] + link = item["link"] + + repo = f"""本群订阅的 RssHub 更新啦! + {title} + {link} + """ + + bot = get_bot() + await bot.send_group_msg(group_id=m.group_id, message=repo) + await sub.update_sub( + m._id, m.group_id, {"update_time": timestamp2datetime(ts_t)} + ) diff --git a/ATRI/plugins/rss/rss_rsshub/data_source.py b/ATRI/plugins/rss/rss_rsshub/data_source.py new file mode 100644 index 0000000..0dc0ebd --- /dev/null +++ b/ATRI/plugins/rss/rss_rsshub/data_source.py @@ -0,0 +1,115 @@ +import xmltodict + +from nonebot.permission import SUPERUSER +from nonebot.adapters.onebot.v11 import GROUP_OWNER, GROUP_ADMIN + +from ATRI.service import Service +from ATRI.rule import is_in_service +from ATRI.exceptions import RssError +from ATRI.utils import request, gen_random_str + +from .db import DB + + +class RssHubSubscriptor(Service): + def __init__(self): + Service.__init__( + self, + "rss.rsshub", + "Rss的Rsshub支持", + rule=is_in_service("rss.rsshub"), + permission=SUPERUSER | GROUP_OWNER | GROUP_ADMIN, + main_cmd="/rss.rsshub", + ) + + async def __add_sub(self, _id: str, group_id: int): + try: + async with DB() as db: + await db.add_sub(_id, group_id) + except Exception: + raise RssError("rss.rsshub: 添加订阅失败") + + async def update_sub(self, _id: str, group_id: int, update_map: dict): + try: + async with DB() as db: + await db.update_sub(_id, group_id, update_map) + except Exception: + raise RssError("rss.rsshub: 更新订阅失败") + + async def __del_sub(self, _id: str, group_id: int): + try: + async with DB() as db: + await db.del_sub({"_id": _id, "group_id": group_id}) + except Exception: + raise RssError("rss.rsshub: 删除订阅失败") + + async def get_sub_list(self, query_map: dict) -> list: + try: + async with DB() as db: + return await db.get_sub_list(query_map) + except Exception: + raise RssError("rss.rsshub: 获取订阅列表失败") + + async def get_all_subs(self) -> list: + try: + async with DB() as db: + return await db.get_all_subs() + except Exception: + raise RssError("rss.rsshub: 获取所有订阅失败") + + async def add_sub(self, url: str, group_id: int) -> str: + try: + resp = await request.get(url) + except Exception: + raise RssError("rss.rsshub: 请求链接失败") + + if "RSSHub" not in resp.text: + return "该链接不含RSSHub内容" + + xml_data = resp.read() + data = xmltodict.parse(xml_data) + check_url = data["rss"]["channel"]["link"] + + query_result = await self.get_sub_list( + {"raw_link": check_url, "group_id": group_id} + ) + if query_result: + _id = query_result[0]._id + return f"该链接已经订阅过啦! ID: {_id}" + + _id = gen_random_str(6) + title = data["rss"]["channel"]["title"] + disc = data["rss"]["channel"]["description"] + + await self.__add_sub(_id, group_id) + await self.update_sub( + _id, + group_id, + { + "title": title, + "rss_link": url, + "discription": disc, + }, + ) + return f"订阅成功! ID: {_id}" + + async def del_sub(self, _id: str, group_id: int) -> str: + query_result = await self.get_sub_list({"_id": _id, "group_id": group_id}) + if not query_result: + return "没有找到该订阅..." + + await self.__del_sub(_id, group_id) + return f"成功取消ID为 {_id} 的订阅" + + async def get_rsshub_info(self, url: str) -> dict: + try: + resp = await request.get(url) + except Exception: + raise RssError("rss.rsshub: 请求链接失败") + + if "RSSHub" not in resp.text: + return dict() + + xml_data = resp.read() + data = xmltodict.parse(xml_data) + return data["rss"]["channel"] diff --git a/ATRI/plugins/rss/rss_rsshub/db.py b/ATRI/plugins/rss/rss_rsshub/db.py new file mode 100644 index 0000000..3f614a3 --- /dev/null +++ b/ATRI/plugins/rss/rss_rsshub/db.py @@ -0,0 +1,26 @@ +from ATRI.database import RssRsshubSubcription + + +class DB: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def add_sub(self, _id: str, group_id: int): + await RssRsshubSubcription.create(_id=_id, group_id=group_id) + + async def update_sub(self, _id, group_id, update_map: dict): + await RssRsshubSubcription.filter(_id=_id, group_id=group_id).update( + **update_map + ) + + async def del_sub(self, query_map: dict): + await RssRsshubSubcription.filter(**query_map).delete() + + async def get_sub_list(self, query_map: dict) -> list: + return await RssRsshubSubcription.filter(**query_map) + + async def get_all_subs(self) -> list: + return await RssRsshubSubcription.all() diff --git a/ATRI/plugins/setu/__init__.py b/ATRI/plugins/setu/__init__.py index 71243eb..7962af8 100644 --- a/ATRI/plugins/setu/__init__.py +++ b/ATRI/plugins/setu/__init__.py @@ -99,8 +99,8 @@ async def _setu_catcher(bot: Bot, event: MessageEvent): data = await Setu().detecter(i, _catcher_max_file_size) except Exception: return - if data[1] > 0.7: - hso.append(data[1]) + if data > 0.7: + hso.append(data) hso.sort(reverse=True) @@ -135,8 +135,7 @@ async def _deal_check(bot: Bot, event: MessageEvent): if not args: await nsfw_checker.reject("请发送图片而不是其他东西!!") - data = await Setu().detecter(args[0], _catcher_max_file_size) - hso = data[1] + hso = await Setu().detecter(args[0], _catcher_max_file_size) if not hso: await nsfw_checker.finish("图太小了!不测!") diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py index 61b1cf8..71649fe 100644 --- a/ATRI/plugins/setu/data_source.py +++ b/ATRI/plugins/setu/data_source.py @@ -5,7 +5,7 @@ from ATRI.service import Service from ATRI.rule import is_in_service from ATRI.utils import request from ATRI.config import Setu as ST -from .tf_dealer import detect_image, init_module +from .nsfw_checker import detect_image, init_model LOLICON_URL = "https://api.lolicon.app/setu/v2" @@ -70,7 +70,7 @@ class Setu(Service): return repo, setu @staticmethod - async def detecter(url: str, file_size: int) -> list: + async def detecter(url: str, file_size: int) -> float: """ 涩值检测. """ @@ -85,4 +85,4 @@ class Setu(Service): from ATRI import driver -driver().on_startup(init_module) +driver().on_startup(init_model) diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/nsfw_checker.py index 4f4b374..2f045c1 100644 --- a/ATRI/plugins/setu/tf_dealer.py +++ b/ATRI/plugins/setu/nsfw_checker.py @@ -1,14 +1,12 @@ import io import re import skimage -import skimage.io +import onnxruntime import numpy as np from PIL import Image from pathlib import Path from sys import getsizeof -import tensorflow as tf - from ATRI.log import logger as log from ATRI.utils import request from ATRI.exceptions import RequestError, WriteFileError @@ -20,7 +18,7 @@ SETU_PATH.mkdir(parents=True, exist_ok=True) TEMP_PATH.mkdir(parents=True, exist_ok=True) -MODULE_URL = "https://jsd.imki.moe/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite" +MODEL_URL = "https://res.imki.moe/nsfw.onnx" VGG_MEAN = [104, 117, 123] @@ -42,7 +40,7 @@ def prepare_image(img): return image -async def detect_image(url: str, file_size: int) -> list: +async def detect_image(url: str, file_size: int) -> float: try: req = await request.get(url) except Exception: @@ -50,7 +48,7 @@ async def detect_image(url: str, file_size: int) -> list: img_byte = getsizeof(req.read()) // 1024 if img_byte < file_size: - return [0, 0] + return 0 try: pattern = r"-(.*?)\/" @@ -61,18 +59,8 @@ async def detect_image(url: str, file_size: int) -> list: except Exception: raise WriteFileError("Writing file failed!") - await init_module() - model_path = str((SETU_PATH / "nsfw.tflite").absolute()) - - try: - interpreter = tf.Interpreter(model_path=model_path) # type: ignore - except Exception: - interpreter = tf.lite.Interpreter(model_path=model_path) - - interpreter.allocate_tensors() - - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + model_path = str(SETU_PATH / "nsfw.onnx") + session = onnxruntime.InferenceSession(model_path) im = Image.open(path) @@ -84,26 +72,25 @@ async def detect_image(url: str, file_size: int) -> list: fh_im.seek(0) image = skimage.img_as_float32(skimage.io.imread(fh_im)) - final = prepare_image(image) - interpreter.set_tensor(input_details[0]["index"], final) - interpreter.invoke() - output_data = interpreter.get_tensor(output_details[0]["index"]) + input_feed = {session.get_inputs()[0].name: final} + outputs = [output.name for output in session.get_outputs()] + result = session.run(outputs, input_feed) - result = np.squeeze(output_data).tolist() - return result + return result[0][0][1] -async def init_module(): - file_name = "nsfw.tflite" +async def init_model(): + file_name = "nsfw.onnx" path = SETU_PATH / file_name if not path.is_file(): - log.warning("缺少模型文件,装载中") + log.warning("插件 setu 缺少资源, 装载中...") try: - data = await request.get(MODULE_URL) + data = await request.get(MODEL_URL) with open(path, "wb") as w: w.write(data.read()) - log.info("模型装载完成") except Exception: - log.error("装载模型失败") + log.error("插件 setu 装载资源失败, 命令 '/nsfw' 将失效") + + log.success("插件 setu 装载资源完成") diff --git a/ATRI/plugins/twitter/__init__.py b/ATRI/plugins/twitter/__init__.py index ab5d76d..ebf673d 100644 --- a/ATRI/plugins/twitter/__init__.py +++ b/ATRI/plugins/twitter/__init__.py @@ -66,7 +66,7 @@ async def _td_del_sub(event: GroupMessageEvent): await del_sub.send(output) -@del_sub.got("td_del_sub_tid", "要取消的tid呢?速速\n(键入 1 以取消)") +@del_sub.got("td_del_sub_tid", "要取消的tid呢?速速\n(键入 q 以取消)") async def _td_deal_del_sub( event: GroupMessageEvent, _tid: str = ArgPlainText("td_del_sub_tid") ): @@ -74,7 +74,7 @@ async def _td_deal_del_sub( if not re.match(patt, _tid): await del_sub.reject("这似乎不是tid呢,请重新输入:") - if _tid == "1": + if _tid == "q": await del_sub.finish("已取消操作~") group_id = event.group_id @@ -175,6 +175,7 @@ async def _check_td(): tzinfo=pytz.timezone("Asia/Shanghai") ) + timedelta(hours=8, minutes=8) ts = raw_ts.timestamp() + info: dict = await sub.get_twitter_user_info(m.screen_name) if not info.get("status", list()): log.warning(f"无法获取推主 {m.name}@{m.screen_name} 的动态") @@ -185,8 +186,8 @@ async def _check_td(): raw_t = datetime.strptime(t_time, time_patt) + timedelta(hours=8) ts_t = raw_t.timestamp() - if ts < ts_t: + if ts < ts_t: raw_media = info["status"]["entities"].get("media", dict()) _pic = raw_media[0]["media_url"] if raw_media else str() |