summaryrefslogtreecommitdiff
path: root/ATRI/plugins
diff options
context:
space:
mode:
Diffstat (limited to 'ATRI/plugins')
-rw-r--r--ATRI/plugins/console/__init__.py171
-rw-r--r--ATRI/plugins/console/data_source.py37
-rw-r--r--ATRI/plugins/console/driver/__init__.py11
-rw-r--r--ATRI/plugins/console/models.py2
-rw-r--r--ATRI/plugins/kimo/data_source.py5
-rw-r--r--ATRI/plugins/rss/__init__.py39
-rw-r--r--ATRI/plugins/rss/rss_mikanan/__init__.py0
-rw-r--r--ATRI/plugins/rss/rss_mikanan/data_source.py0
-rw-r--r--ATRI/plugins/rss/rss_mikanan/db.py0
-rw-r--r--ATRI/plugins/rss/rss_rsshub/__init__.py162
-rw-r--r--ATRI/plugins/rss/rss_rsshub/data_source.py115
-rw-r--r--ATRI/plugins/rss/rss_rsshub/db.py26
-rw-r--r--ATRI/plugins/setu/__init__.py7
-rw-r--r--ATRI/plugins/setu/data_source.py6
-rw-r--r--ATRI/plugins/setu/nsfw_checker.py (renamed from ATRI/plugins/setu/tf_dealer.py)47
-rw-r--r--ATRI/plugins/twitter/__init__.py7
16 files changed, 441 insertions, 194 deletions
diff --git a/ATRI/plugins/console/__init__.py b/ATRI/plugins/console/__init__.py
index e4fec29..2bd14c8 100644
--- a/ATRI/plugins/console/__init__.py
+++ b/ATRI/plugins/console/__init__.py
@@ -4,7 +4,8 @@ from nonebot.params import ArgPlainText
from nonebot.adapters.onebot.v11 import PrivateMessageEvent, GroupMessageEvent
from ATRI.config import BotSelfConfig
-from ATRI.exceptions import WriteFileError, ReadFileError
+from ATRI.exceptions import WriteFileError
+
from .data_source import Console, CONSOLE_DIR
from .models import AuthData
@@ -14,52 +15,29 @@ gen_console_key = Console().cmd_as_group("auth", "获取进入网页后台的凭
@gen_console_key.got("is_pub_n", "咱的运行环境是否有公网(y/n)")
async def _(event: PrivateMessageEvent, is_pub_n: str = ArgPlainText("is_pub_n")):
+ data_path = CONSOLE_DIR / "data.json"
+ if not data_path.is_file:
+ with open(data_path, "w", encoding="utf-8") as w:
+ w.write(json.dumps(dict()))
+
if is_pub_n != "y":
- ip = str(await Console().get_host_ip(False))
+ host = str(await Console().get_host_ip(False))
await gen_console_key.send("没有公网吗...嗯知道了")
else:
- ip = str(await Console().get_host_ip(True))
-
- p = BotSelfConfig.port
- rs = Console().get_random_str(20)
-
- df = CONSOLE_DIR / "data.json"
- try:
- if not df.is_file():
- with open(df, "w", encoding="utf-8") as w:
- w.write(json.dumps({}))
+ host = str(await Console().get_host_ip(True))
- d = json.loads(df.read_bytes())
+ port = BotSelfConfig.port
+ token = Console().get_random_str(20)
- ca = d.get("data", None)
- if ca:
- # 此处原本想用 matcher.finish 但这是在 try 里啊!
- await gen_console_key.send("咱已经告诉你了嗷!啊!忘了.../con.load 获取吧")
- return
+ data = json.loads(data_path.read_bytes())
+ data["data"] = AuthData(token=token).dict()
+ with open(data_path, "w", encoding="utf-8") as w:
+ w.write(json.dumps(data))
- d["data"] = AuthData(ip=ip, port=str(p), token=rs).dict()
-
- with open(df, "w", encoding="utf-8") as w:
- w.write(json.dumps(d))
- except Exception:
- msg = f"""
- 哦吼!写入文件失败了...还请自行记下哦...
- IP: {ip}
- PORT: {p}
- TOKEN: {rs}
- 一定要保管好哦!切勿告诉他人哦!
- 入口: atri-console.imki.moe
- """.strip()
- await gen_console_key.send(msg)
-
- raise WriteFileError("Writing file: " + str(df) + " failed!")
-
- msg = f"""
- 该信息已保存!可通过 /gauth 获取~
- IP: {ip}
- PORT: {p}
- TOKEN: {rs}
- 一定要保管好哦!切勿告诉他人哦!
+ msg = f"""控制台信息已生成!
+ 请访问: {host}:{port}
+ Token: {token}
+ 该 token 有效时间为 15min
""".strip()
await gen_console_key.finish(msg)
@@ -69,37 +47,6 @@ async def _(event: GroupMessageEvent):
await gen_console_key.finish("请私戳咱获取(")
-load_console_key = Console().cmd_as_group("load", "获取已生成的后台凭证")
-
-
-@load_console_key.handle()
-async def _(event: PrivateMessageEvent):
- df = CONSOLE_DIR / "data.json"
- if not df.is_file():
- await load_console_key.finish("你还没有问咱索要奥!/con.auth 以获取")
-
- try:
- d = json.loads(df.read_bytes())
- except Exception:
- await load_console_key.send("获取数据失败了...请自行打开文件查看吧:\n" + str(df))
- raise ReadFileError("Reading file: " + str(df) + " failed!")
-
- data = d["data"]
- msg = f"""
- 诶嘿嘿嘿——凭证信息来咯!
- IP: {data['ip']}
- PORT: {data['port']}
- TOKEN: {data['token']}
- 切记!不要告诉他人!!
- """.strip()
- await load_console_key.finish(msg)
-
-
-@load_console_key.handle()
-async def _(event: GroupMessageEvent):
- await load_console_key.finish("请私戳咱获取(")
-
-
del_console_key = Console().cmd_as_group("del", "销毁进入网页后台的凭证")
@@ -126,82 +73,10 @@ async def _(is_sure: str = ArgPlainText("is_sure_d")):
await del_console_key.finish("销毁成功!如需再次获取: /con.auth")
-res_console_key = Console().cmd_as_group("reauth", "重置进入网页后台的凭证")
-
-
-@res_console_key.got("is_sure_r", "...你确定吗(y/n)")
-async def _(is_sure: str = ArgPlainText("is_sure_r")):
- if is_sure != "y":
- await res_console_key.finish("反悔了呢...")
-
- df = CONSOLE_DIR / "data.json"
- if not df.is_file():
- await del_console_key.finish("你还没向咱索取凭证呢.../con.auth 以获取")
-
- try:
- data: dict = json.loads(df.read_bytes())
-
- del data["data"]
-
- with open(df, "w", encoding="utf-8") as w:
- w.write(json.dumps(data))
- except Exception:
- await del_console_key.send("销毁失败了...请至此处自行删除文件:\n" + str(df))
- raise WriteFileError("Writing / Reading file: " + str(df) + " failed!")
-
-
-@res_console_key.got("is_pub_r_n", "咱的运行环境是否有公网(y/n)")
-async def _(event: PrivateMessageEvent, is_pub_n: str = ArgPlainText("is_pub_n")):
- if is_pub_n != "y":
- ip = str(await Console().get_host_ip(False))
- await res_console_key.send("没有公网吗...嗯知道了")
- else:
- ip = str(await Console().get_host_ip(True))
-
- p = BotSelfConfig.port
- rs = Console().get_random_str(20)
-
- df = CONSOLE_DIR / "data.json"
- try:
- if not df.is_file():
- with open(df, "w", encoding="utf-8") as w:
- w.write(json.dumps({}))
-
- d = json.loads(df.read_bytes())
-
- ca = d.get("data", None)
- if ca:
- await res_console_key.send("咱已经告诉你了嗷!啊!忘了.../con.load 获取吧")
- return
-
- d["data"] = AuthData(ip=ip, port=str(p), token=rs).dict()
-
- with open(df, "w", encoding="utf-8") as w:
- w.write(json.dumps(d))
- except Exception:
- msg = f"""
- 哦吼!写入文件失败了...还请自行记下哦...
- IP: {ip}
- PORT: {p}
- TOKEN: {rs}
- 一定要保管好哦!切勿告诉他人哦!
- """.strip()
- await res_console_key.send(msg)
-
- raise WriteFileError("Writing file: " + str(df) + " failed!")
-
- msg = f"""
- 该信息已保存!可通过 /con.load 获取~
- IP: {ip}
- PORT: {p}
- TOKEN: {rs}
- 一定要保管好哦!切勿告诉他人哦!
- """.strip()
- await res_console_key.finish(msg)
-
-
from ATRI import driver as dr
-from .driver import init
+from .data_source import init_resource
+from .driver import init_driver
-dr().on_startup(init)
+dr().on_startup(init_resource)
+dr().on_startup(init_driver)
diff --git a/ATRI/plugins/console/data_source.py b/ATRI/plugins/console/data_source.py
index eee862c..db0b6b1 100644
--- a/ATRI/plugins/console/data_source.py
+++ b/ATRI/plugins/console/data_source.py
@@ -1,6 +1,7 @@
import json
import socket
import string
+import zipfile
from random import sample
from pathlib import Path
@@ -8,7 +9,9 @@ from nonebot.permission import SUPERUSER
from ATRI.service import Service
from ATRI.utils import request
+from ATRI.rule import is_in_service
from ATRI.exceptions import WriteFileError
+from ATRI.log import logger as log
CONSOLE_DIR = Path(".") / "data" / "plugins" / "console"
@@ -18,7 +21,13 @@ CONSOLE_DIR.mkdir(parents=True, exist_ok=True)
class Console(Service):
def __init__(self):
Service.__init__(
- self, "控制台", "前端管理页面", True, main_cmd="/con", permission=SUPERUSER
+ self,
+ "控制台",
+ "前端管理页面",
+ True,
+ is_in_service("控制台"),
+ main_cmd="/con",
+ permission=SUPERUSER,
)
@staticmethod
@@ -56,3 +65,29 @@ class Console(Service):
if not data:
return {"data": None}
return data
+
+
+FRONTEND_DIR = CONSOLE_DIR / "frontend"
+FRONTEND_DIR.mkdir(parents=True, exist_ok=True)
+CONSOLE_RESOURCE_URL = (
+ "https://jsd.imki.moe/gh/kyomotoi/Project-ATRI-Console@main/archive/dist.zip"
+)
+
+
+async def init_resource():
+ log.info("控制台初始化中...")
+
+ try:
+ resp = await request.get(CONSOLE_RESOURCE_URL)
+ except Exception:
+ log.error("控制台资源装载失败, 将无法访问管理界面")
+ return
+ print(len(resp.read()))
+ file_path = CONSOLE_DIR / "dist.zip"
+ with open(file_path, "wb") as w:
+ w.write(resp.read())
+
+ with zipfile.ZipFile(file_path, "r") as zr:
+ zr.extractall(FRONTEND_DIR)
+
+ log.success("控制台初始化完成")
diff --git a/ATRI/plugins/console/driver/__init__.py b/ATRI/plugins/console/driver/__init__.py
index 22afd4a..04c714d 100644
--- a/ATRI/plugins/console/driver/__init__.py
+++ b/ATRI/plugins/console/driver/__init__.py
@@ -1,7 +1,9 @@
from nonebot.drivers.fastapi import Driver
+from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
+from ATRI.plugins.console.data_source import FRONTEND_DIR
from .view import (
handle_auther,
handle_base_uri,
@@ -49,8 +51,15 @@ def register_routes(driver: Driver):
app.get(CONSOLE_API_BLOCK_LIST_URI)(handle_get_block_list)
app.get(CONSOLE_API_BLOCK_EDIT_URI)(handle_edit_block)
+ static_path = str(FRONTEND_DIR)
+ app.mount(
+ "/",
+ StaticFiles(directory=static_path, html=True),
+ name="atri-console",
+ )
+
-def init():
+def init_driver():
from ATRI import driver
register_routes(driver()) # type: ignore
diff --git a/ATRI/plugins/console/models.py b/ATRI/plugins/console/models.py
index 973d05d..acbf916 100644
--- a/ATRI/plugins/console/models.py
+++ b/ATRI/plugins/console/models.py
@@ -2,8 +2,6 @@ from pydantic import BaseModel
class AuthData(BaseModel):
- ip: str
- port: str
token: str
diff --git a/ATRI/plugins/kimo/data_source.py b/ATRI/plugins/kimo/data_source.py
index f2c56ee..2135a01 100644
--- a/ATRI/plugins/kimo/data_source.py
+++ b/ATRI/plugins/kimo/data_source.py
@@ -32,15 +32,16 @@ class Kimo(Service):
file_name = "kimo.json"
path = CHAT_PATH / file_name
if not path.is_file():
- log.warning("未检测到词库,生成中")
+ log.warning("插件 kimo 缺少资源, 装载中...")
data = await cls._request(KIMO_URL)
try:
with open(path, "w", encoding="utf-8") as w:
w.write(json.dumps(data, indent=4))
- log.info("生成完成")
except Exception:
raise WriteFileError("Writing kimo words failed!")
+ log.success("插件 kimo 资源装载完成")
+
@classmethod
async def _load_data(cls) -> dict:
file_name = "kimo.json"
diff --git a/ATRI/plugins/rss/__init__.py b/ATRI/plugins/rss/__init__.py
new file mode 100644
index 0000000..a29f385
--- /dev/null
+++ b/ATRI/plugins/rss/__init__.py
@@ -0,0 +1,39 @@
+from pathlib import Path
+
+from nonebot.adapters.onebot.v11 import MessageEvent
+from nonebot.permission import SUPERUSER
+from nonebot.adapters.onebot.v11 import GROUP_OWNER, GROUP_ADMIN
+
+from ATRI.service import Service
+
+
+RSS_PLUGIN_DIR = Path(".") / "ATRI" / "plugins" / "rss"
+
+
+class RssHelper(Service):
+ def __init__(self):
+ Service.__init__(
+ self,
+ "rss",
+ "Rss系插件助手",
+ True,
+ permission=SUPERUSER | GROUP_OWNER | GROUP_ADMIN,
+ main_cmd="/rss",
+ )
+
+
+rss_menu = RssHelper().on_command("/rss", "Rss帮助菜单")
+
+
+@rss_menu.handle()
+async def _rss_menu(event: MessageEvent):
+ raw_rss_list = RSS_PLUGIN_DIR.glob("rss_*")
+ rss_list = [str(i).split("/")[-1] for i in raw_rss_list]
+ if not rss_list:
+ rss_list = [str(i).split("\\")[-1] for i in raw_rss_list]
+
+ result = f"""Rss Helper:
+ 可用订阅源: {"、".join(map(str, rss_list)).replace("rss_", str())}
+ 命令: /rss.(订阅源名称)
+ """.strip()
+ await rss_menu.finish(result)
diff --git a/ATRI/plugins/rss/rss_mikanan/__init__.py b/ATRI/plugins/rss/rss_mikanan/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/ATRI/plugins/rss/rss_mikanan/__init__.py
diff --git a/ATRI/plugins/rss/rss_mikanan/data_source.py b/ATRI/plugins/rss/rss_mikanan/data_source.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/ATRI/plugins/rss/rss_mikanan/data_source.py
diff --git a/ATRI/plugins/rss/rss_mikanan/db.py b/ATRI/plugins/rss/rss_mikanan/db.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/ATRI/plugins/rss/rss_mikanan/db.py
diff --git a/ATRI/plugins/rss/rss_rsshub/__init__.py b/ATRI/plugins/rss/rss_rsshub/__init__.py
new file mode 100644
index 0000000..2d23f16
--- /dev/null
+++ b/ATRI/plugins/rss/rss_rsshub/__init__.py
@@ -0,0 +1,162 @@
+import pytz
+import asyncio
+from tabulate import tabulate
+from datetime import timedelta, datetime
+
+from apscheduler.triggers.base import BaseTrigger
+from apscheduler.triggers.combining import AndTrigger
+from apscheduler.triggers.interval import IntervalTrigger
+
+from nonebot import get_bot
+from nonebot.matcher import Matcher
+from nonebot.params import CommandArg, ArgPlainText
+from nonebot.permission import Permission
+from nonebot.adapters.onebot.v11 import Message, GroupMessageEvent
+
+from ATRI.log import logger as log
+from ATRI.utils import timestamp2datetime
+from ATRI.utils.apscheduler import scheduler
+from ATRI.database import RssRsshubSubcription
+
+from .data_source import RssHubSubscriptor
+
+
+add_sub = RssHubSubscriptor().cmd_as_group("add", "为本群添加 RSSHub 订阅")
+
+
+@add_sub.handle()
+async def _(matcher: Matcher, args: Message = CommandArg()):
+ msg = args.extract_plain_text()
+ if msg:
+ matcher.set_arg("rrh_add_url", args)
+
+
+@add_sub.got("rrh_add_url", "RSSHub 链接呢?速速")
+async def _(event: GroupMessageEvent, _url: str = ArgPlainText("rrh_add_url")):
+ group_id = event.group_id
+ sub = RssHubSubscriptor()
+
+ result = await sub.add_sub(_url, group_id)
+ await add_sub.finish(result)
+
+
+del_sub = RssHubSubscriptor().cmd_as_group("del", "删除本群 RSSHub 订阅")
+
+
+@del_sub.handle()
+async def _(event: GroupMessageEvent):
+ group_id = event.group_id
+ sub = RssHubSubscriptor()
+
+ query_result = await sub.get_sub_list({"group_id": group_id})
+ if not query_result:
+ await del_sub.finish("本群还没有任何订阅呢...")
+
+ subs = list()
+ for i in query_result:
+ subs.append([i._id, i.title])
+
+ output = "本群的 RSSHub 订阅列表如下~\n" + tabulate(
+ subs, headers=["ID", "title"], tablefmt="plain"
+ )
+ await del_sub.send(output)
+
+
+@del_sub.got("rrh_del_sub_id", "要取消的ID呢? 速速\n(键入 q 以取消)")
+async def _(event: GroupMessageEvent, _id: str = ArgPlainText("rrh_del_sub_id")):
+ if _id == "q":
+ await del_sub.finish("已取消操作~")
+
+ group_id = event.group_id
+ sub = RssHubSubscriptor()
+
+ result = await sub.del_sub(_id, group_id)
+ await del_sub.finish(result)
+
+
+get_sub_list = RssHubSubscriptor().cmd_as_group(
+ "list", "获取本群 RSSHub 订阅列表", permission=Permission()
+)
+
+
+@get_sub_list.handle()
+async def _(event: GroupMessageEvent):
+ group_id = event.group_id
+ sub = RssHubSubscriptor()
+
+ query_result = await sub.get_sub_list({"group_id": group_id})
+ if not query_result:
+ await get_sub_list.finish("本群还没有任何订阅呢...")
+
+ subs = list()
+ for i in query_result:
+ subs.append([i.update_time, i.title])
+
+ output = "本群的 RSSHub 订阅列表如下~\n" + tabulate(
+ subs, headers=["最后更新时间", "标题"], tablefmt="plain"
+ )
+ await get_sub_list.send(output)
+
+
+tq = asyncio.Queue()
+
+
+class RssHubDynamicChecker(BaseTrigger):
+ def get_next_fire_time(self, previous_fire_time, now):
+ conf = RssHubSubscriptor().load_service("rss.rsshub")
+ if conf["enabled"]:
+ return now
+
+
+ AndTrigger([IntervalTrigger(seconds=120), RssHubDynamicChecker()]),
+ name="RssHub 订阅检查",
+ max_instances=3, # type: ignore
+ misfire_grace_time=60, # type: ignore
+)
+async def _():
+ sub = RssHubSubscriptor()
+ try:
+ all_dy = await sub.get_all_subs()
+ except Exception:
+ log.debug("RssHub 订阅列表为空 跳过")
+ return
+
+ if tq.empty():
+ for i in all_dy:
+ await tq.put(i)
+ else:
+ m: RssRsshubSubcription = tq.get_nowait()
+ log.info(f"准备查询 RssHub: {m.rss_link} 的动态, 队列剩余 {tq.qsize()}")
+
+ raw_ts = m.update_time.replace(
+ tzinfo=pytz.timezone("Asia/Shanghai")
+ ) + timedelta(hours=8)
+ ts = raw_ts.timestamp()
+
+ info: dict = await sub.get_rsshub_info(m.rss_link)
+ if not info:
+ log.warning(f"无法获取 RssHub: {m.rss_link} 的动态")
+ return
+
+ t_time = info["item"][0]["pubDate"]
+ time_patt = "%a, %d %b %Y %H:%M:%S GMT"
+
+ raw_t = datetime.strptime(t_time, time_patt) + timedelta(hours=8)
+ ts_t = raw_t.timestamp()
+
+ if ts < ts_t:
+ item = info["item"][0]
+ title = item["title"]
+ link = item["link"]
+
+ repo = f"""本群订阅的 RssHub 更新啦!
+ {title}
+ {link}
+ """
+
+ bot = get_bot()
+ await bot.send_group_msg(group_id=m.group_id, message=repo)
+ await sub.update_sub(
+ m._id, m.group_id, {"update_time": timestamp2datetime(ts_t)}
+ )
diff --git a/ATRI/plugins/rss/rss_rsshub/data_source.py b/ATRI/plugins/rss/rss_rsshub/data_source.py
new file mode 100644
index 0000000..0dc0ebd
--- /dev/null
+++ b/ATRI/plugins/rss/rss_rsshub/data_source.py
@@ -0,0 +1,115 @@
+import xmltodict
+
+from nonebot.permission import SUPERUSER
+from nonebot.adapters.onebot.v11 import GROUP_OWNER, GROUP_ADMIN
+
+from ATRI.service import Service
+from ATRI.rule import is_in_service
+from ATRI.exceptions import RssError
+from ATRI.utils import request, gen_random_str
+
+from .db import DB
+
+
+class RssHubSubscriptor(Service):
+ def __init__(self):
+ Service.__init__(
+ self,
+ "rss.rsshub",
+ "Rss的Rsshub支持",
+ rule=is_in_service("rss.rsshub"),
+ permission=SUPERUSER | GROUP_OWNER | GROUP_ADMIN,
+ main_cmd="/rss.rsshub",
+ )
+
+ async def __add_sub(self, _id: str, group_id: int):
+ try:
+ async with DB() as db:
+ await db.add_sub(_id, group_id)
+ except Exception:
+ raise RssError("rss.rsshub: 添加订阅失败")
+
+ async def update_sub(self, _id: str, group_id: int, update_map: dict):
+ try:
+ async with DB() as db:
+ await db.update_sub(_id, group_id, update_map)
+ except Exception:
+ raise RssError("rss.rsshub: 更新订阅失败")
+
+ async def __del_sub(self, _id: str, group_id: int):
+ try:
+ async with DB() as db:
+ await db.del_sub({"_id": _id, "group_id": group_id})
+ except Exception:
+ raise RssError("rss.rsshub: 删除订阅失败")
+
+ async def get_sub_list(self, query_map: dict) -> list:
+ try:
+ async with DB() as db:
+ return await db.get_sub_list(query_map)
+ except Exception:
+ raise RssError("rss.rsshub: 获取订阅列表失败")
+
+ async def get_all_subs(self) -> list:
+ try:
+ async with DB() as db:
+ return await db.get_all_subs()
+ except Exception:
+ raise RssError("rss.rsshub: 获取所有订阅失败")
+
+ async def add_sub(self, url: str, group_id: int) -> str:
+ try:
+ resp = await request.get(url)
+ except Exception:
+ raise RssError("rss.rsshub: 请求链接失败")
+
+ if "RSSHub" not in resp.text:
+ return "该链接不含RSSHub内容"
+
+ xml_data = resp.read()
+ data = xmltodict.parse(xml_data)
+ check_url = data["rss"]["channel"]["link"]
+
+ query_result = await self.get_sub_list(
+ {"raw_link": check_url, "group_id": group_id}
+ )
+ if query_result:
+ _id = query_result[0]._id
+ return f"该链接已经订阅过啦! ID: {_id}"
+
+ _id = gen_random_str(6)
+ title = data["rss"]["channel"]["title"]
+ disc = data["rss"]["channel"]["description"]
+
+ await self.__add_sub(_id, group_id)
+ await self.update_sub(
+ _id,
+ group_id,
+ {
+ "title": title,
+ "rss_link": url,
+ "discription": disc,
+ },
+ )
+ return f"订阅成功! ID: {_id}"
+
+ async def del_sub(self, _id: str, group_id: int) -> str:
+ query_result = await self.get_sub_list({"_id": _id, "group_id": group_id})
+ if not query_result:
+ return "没有找到该订阅..."
+
+ await self.__del_sub(_id, group_id)
+ return f"成功取消ID为 {_id} 的订阅"
+
+ async def get_rsshub_info(self, url: str) -> dict:
+ try:
+ resp = await request.get(url)
+ except Exception:
+ raise RssError("rss.rsshub: 请求链接失败")
+
+ if "RSSHub" not in resp.text:
+ return dict()
+
+ xml_data = resp.read()
+ data = xmltodict.parse(xml_data)
+ return data["rss"]["channel"]
diff --git a/ATRI/plugins/rss/rss_rsshub/db.py b/ATRI/plugins/rss/rss_rsshub/db.py
new file mode 100644
index 0000000..3f614a3
--- /dev/null
+++ b/ATRI/plugins/rss/rss_rsshub/db.py
@@ -0,0 +1,26 @@
+from ATRI.database import RssRsshubSubcription
+
+
+class DB:
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ async def add_sub(self, _id: str, group_id: int):
+ await RssRsshubSubcription.create(_id=_id, group_id=group_id)
+
+ async def update_sub(self, _id, group_id, update_map: dict):
+ await RssRsshubSubcription.filter(_id=_id, group_id=group_id).update(
+ **update_map
+ )
+
+ async def del_sub(self, query_map: dict):
+ await RssRsshubSubcription.filter(**query_map).delete()
+
+ async def get_sub_list(self, query_map: dict) -> list:
+ return await RssRsshubSubcription.filter(**query_map)
+
+ async def get_all_subs(self) -> list:
+ return await RssRsshubSubcription.all()
diff --git a/ATRI/plugins/setu/__init__.py b/ATRI/plugins/setu/__init__.py
index 71243eb..7962af8 100644
--- a/ATRI/plugins/setu/__init__.py
+++ b/ATRI/plugins/setu/__init__.py
@@ -99,8 +99,8 @@ async def _setu_catcher(bot: Bot, event: MessageEvent):
data = await Setu().detecter(i, _catcher_max_file_size)
except Exception:
return
- if data[1] > 0.7:
- hso.append(data[1])
+ if data > 0.7:
+ hso.append(data)
hso.sort(reverse=True)
@@ -135,8 +135,7 @@ async def _deal_check(bot: Bot, event: MessageEvent):
if not args:
await nsfw_checker.reject("请发送图片而不是其他东西!!")
- data = await Setu().detecter(args[0], _catcher_max_file_size)
- hso = data[1]
+ hso = await Setu().detecter(args[0], _catcher_max_file_size)
if not hso:
await nsfw_checker.finish("图太小了!不测!")
diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py
index 61b1cf8..71649fe 100644
--- a/ATRI/plugins/setu/data_source.py
+++ b/ATRI/plugins/setu/data_source.py
@@ -5,7 +5,7 @@ from ATRI.service import Service
from ATRI.rule import is_in_service
from ATRI.utils import request
from ATRI.config import Setu as ST
-from .tf_dealer import detect_image, init_module
+from .nsfw_checker import detect_image, init_model
LOLICON_URL = "https://api.lolicon.app/setu/v2"
@@ -70,7 +70,7 @@ class Setu(Service):
return repo, setu
@staticmethod
- async def detecter(url: str, file_size: int) -> list:
+ async def detecter(url: str, file_size: int) -> float:
"""
涩值检测.
"""
@@ -85,4 +85,4 @@ class Setu(Service):
from ATRI import driver
-driver().on_startup(init_module)
+driver().on_startup(init_model)
diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/nsfw_checker.py
index 4f4b374..2f045c1 100644
--- a/ATRI/plugins/setu/tf_dealer.py
+++ b/ATRI/plugins/setu/nsfw_checker.py
@@ -1,14 +1,12 @@
import io
import re
import skimage
-import skimage.io
+import onnxruntime
import numpy as np
from PIL import Image
from pathlib import Path
from sys import getsizeof
-import tensorflow as tf
-
from ATRI.log import logger as log
from ATRI.utils import request
from ATRI.exceptions import RequestError, WriteFileError
@@ -20,7 +18,7 @@ SETU_PATH.mkdir(parents=True, exist_ok=True)
TEMP_PATH.mkdir(parents=True, exist_ok=True)
-MODULE_URL = "https://jsd.imki.moe/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite"
+MODEL_URL = "https://res.imki.moe/nsfw.onnx"
VGG_MEAN = [104, 117, 123]
@@ -42,7 +40,7 @@ def prepare_image(img):
return image
-async def detect_image(url: str, file_size: int) -> list:
+async def detect_image(url: str, file_size: int) -> float:
try:
req = await request.get(url)
except Exception:
@@ -50,7 +48,7 @@ async def detect_image(url: str, file_size: int) -> list:
img_byte = getsizeof(req.read()) // 1024
if img_byte < file_size:
- return [0, 0]
+ return 0
try:
pattern = r"-(.*?)\/"
@@ -61,18 +59,8 @@ async def detect_image(url: str, file_size: int) -> list:
except Exception:
raise WriteFileError("Writing file failed!")
- await init_module()
- model_path = str((SETU_PATH / "nsfw.tflite").absolute())
-
- try:
- interpreter = tf.Interpreter(model_path=model_path) # type: ignore
- except Exception:
- interpreter = tf.lite.Interpreter(model_path=model_path)
-
- interpreter.allocate_tensors()
-
- input_details = interpreter.get_input_details()
- output_details = interpreter.get_output_details()
+ model_path = str(SETU_PATH / "nsfw.onnx")
+ session = onnxruntime.InferenceSession(model_path)
im = Image.open(path)
@@ -84,26 +72,25 @@ async def detect_image(url: str, file_size: int) -> list:
fh_im.seek(0)
image = skimage.img_as_float32(skimage.io.imread(fh_im))
-
final = prepare_image(image)
- interpreter.set_tensor(input_details[0]["index"], final)
- interpreter.invoke()
- output_data = interpreter.get_tensor(output_details[0]["index"])
+ input_feed = {session.get_inputs()[0].name: final}
+ outputs = [output.name for output in session.get_outputs()]
+ result = session.run(outputs, input_feed)
- result = np.squeeze(output_data).tolist()
- return result
+ return result[0][0][1]
-async def init_module():
- file_name = "nsfw.tflite"
+async def init_model():
+ file_name = "nsfw.onnx"
path = SETU_PATH / file_name
if not path.is_file():
- log.warning("缺少模型文件,装载中")
+ log.warning("插件 setu 缺少资源, 装载中...")
try:
- data = await request.get(MODULE_URL)
+ data = await request.get(MODEL_URL)
with open(path, "wb") as w:
w.write(data.read())
- log.info("模型装载完成")
except Exception:
- log.error("装载模型失败")
+ log.error("插件 setu 装载资源失败, 命令 '/nsfw' 将失效")
+
+ log.success("插件 setu 装载资源完成")
diff --git a/ATRI/plugins/twitter/__init__.py b/ATRI/plugins/twitter/__init__.py
index ab5d76d..ebf673d 100644
--- a/ATRI/plugins/twitter/__init__.py
+++ b/ATRI/plugins/twitter/__init__.py
@@ -66,7 +66,7 @@ async def _td_del_sub(event: GroupMessageEvent):
await del_sub.send(output)
-@del_sub.got("td_del_sub_tid", "要取消的tid呢?速速\n(键入 1 以取消)")
+@del_sub.got("td_del_sub_tid", "要取消的tid呢?速速\n(键入 q 以取消)")
async def _td_deal_del_sub(
event: GroupMessageEvent, _tid: str = ArgPlainText("td_del_sub_tid")
):
@@ -74,7 +74,7 @@ async def _td_deal_del_sub(
if not re.match(patt, _tid):
await del_sub.reject("这似乎不是tid呢,请重新输入:")
- if _tid == "1":
+ if _tid == "q":
await del_sub.finish("已取消操作~")
group_id = event.group_id
@@ -175,6 +175,7 @@ async def _check_td():
tzinfo=pytz.timezone("Asia/Shanghai")
) + timedelta(hours=8, minutes=8)
ts = raw_ts.timestamp()
+
info: dict = await sub.get_twitter_user_info(m.screen_name)
if not info.get("status", list()):
log.warning(f"无法获取推主 {m.name}@{m.screen_name} 的动态")
@@ -185,8 +186,8 @@ async def _check_td():
raw_t = datetime.strptime(t_time, time_patt) + timedelta(hours=8)
ts_t = raw_t.timestamp()
- if ts < ts_t:
+ if ts < ts_t:
raw_media = info["status"]["entities"].get("media", dict())
_pic = raw_media[0]["media_url"] if raw_media else str()