From 16838e0d83c8dd3f1be1e6ec5fb1cecd9b319d8a Mon Sep 17 00:00:00 2001 From: Kyomotoi Date: Sun, 24 Oct 2021 16:47:29 +0800 Subject: =?UTF-8?q?=F0=9F=94=96=20=E6=9B=B4=E6=96=B0=E7=89=88=E6=9C=AC?= =?UTF-8?q?=EF=BC=9AYHN-001-A04?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增: - nsfw检测(主动/被动) 又名 涩图嗅探 - 可选代理 修复: - plugin/chat 在 nb2-a14+ 版本 finish 内为空时会报错 其他: - 对定时任务进行中文命名 --- ATRI/config.py | 8 ++- ATRI/plugins/chat/__init__.py | 7 ++- ATRI/plugins/essential.py | 14 +++++ ATRI/plugins/setu/__init__.py | 89 ++++++++++++++++++++++++++++- ATRI/plugins/setu/data_source.py | 9 +++ ATRI/plugins/setu/tf_dealer.py | 119 +++++++++++++++++++++++++++++++++++++++ ATRI/plugins/status/__init__.py | 2 +- ATRI/utils/request.py | 10 +++- README.md | 3 +- changelog.md | 4 ++ config.yml | 1 + requirements.txt | 6 +- 12 files changed, 262 insertions(+), 10 deletions(-) create mode 100644 ATRI/plugins/setu/tf_dealer.py diff --git a/ATRI/config.py b/ATRI/config.py index 7d5127c..d1a9bfd 100644 --- a/ATRI/config.py +++ b/ATRI/config.py @@ -1,8 +1,13 @@ +import yaml from pathlib import Path from datetime import timedelta from ipaddress import IPv4Address -from .utils import * + +def load_yml(file: Path, encoding="utf-8") -> dict: + with open(file, "r", encoding=encoding) as f: + data = yaml.safe_load(f) + return data CONFIG_PATH = Path(".") / "config.yml" @@ -22,6 +27,7 @@ class BotSelfConfig: session_expire_timeout: timedelta = timedelta( seconds=config.get("session_expire_timeout", 60) ) + proxy: str = config.get("proxy", None) class SauceNAO: diff --git a/ATRI/plugins/chat/__init__.py b/ATRI/plugins/chat/__init__.py index 9f7f26c..ab22db9 100644 --- a/ATRI/plugins/chat/__init__.py +++ b/ATRI/plugins/chat/__init__.py @@ -26,7 +26,10 @@ async def _chat(bot: Bot, event: MessageEvent): msg = str(event.message) repo = await Chat().deal(msg, user_id) _chat_flmt.start_cd(user_id) - await chat.finish(repo) + try: + await chat.finish(repo) + except Exception: + return my_name_is = Chat().on_command("叫我", "更改闲聊(划掉 文爱)时的称呼", aliases={"我是"}, priority=1) @@ -111,7 +114,7 @@ async def _deal_say(bot: Bot, event: MessageEvent, state: T_State): await say.finish(msg) -@scheduler.scheduled_job("interval", hours=3, misfire_grace_time=60) +@scheduler.scheduled_job("interval", name="闲聊词库检查更新", hours=3, misfire_grace_time=60) async def _check_kimo(): try: await Chat().update_data() diff --git a/ATRI/plugins/essential.py b/ATRI/plugins/essential.py index 3bbf66e..4cc0ac2 100644 --- a/ATRI/plugins/essential.py +++ b/ATRI/plugins/essential.py @@ -1,5 +1,6 @@ import os import json +import shutil import asyncio from datetime import datetime from pydantic.main import BaseModel @@ -30,14 +31,19 @@ from ATRI.service import Service from ATRI.log import logger as log from ATRI.config import BotSelfConfig from ATRI.utils import CoolqCodeChecker +from ATRI.utils.apscheduler import scheduler + driver = ATRI.driver() bots = nonebot.get_bots() + ESSENTIAL_DIR = Path(".") / "data" / "database" / "essential" MANEGE_DIR = Path(".") / "data" / "database" / "manege" +TEMP_PATH = Path(".") / "data" / "temp" os.makedirs(ESSENTIAL_DIR, exist_ok=True) os.makedirs(MANEGE_DIR, exist_ok=True) +os.makedirs(TEMP_PATH, exist_ok=True) @driver.on_startup @@ -298,3 +304,11 @@ async def _recall_private_event(bot: Bot, event: FriendRecallNoticeEvent): msg = "主人,咱拿到了一条撤回信息!\n" f"{user}@[私聊]" "撤回了\n" f"{repo}" for superuser in BotSelfConfig.superusers: await bot.send_private_msg(user_id=int(superuser), message=msg) + + +@scheduler.scheduled_job("interval", name="清除缓存", minutes=30, misfire_grace_time=5) +async def _clear_cache(): + try: + shutil.rmtree(TEMP_PATH) + except Exception: + log.warning("清除缓存失败,请手动清除:data/temp") diff --git a/ATRI/plugins/setu/__init__.py b/ATRI/plugins/setu/__init__.py index eec5281..3e0e452 100644 --- a/ATRI/plugins/setu/__init__.py +++ b/ATRI/plugins/setu/__init__.py @@ -2,7 +2,10 @@ import re import asyncio from random import choice from nonebot.adapters.cqhttp import Bot, MessageEvent, Message +from nonebot.adapters.cqhttp.message import MessageSegment +from nonebot.typing import T_State +from ATRI.config import BotSelfConfig from ATRI.utils.limit import FreqLimiter, DailyLimiter from ATRI.utils.apscheduler import scheduler from .data_source import Setu @@ -74,7 +77,91 @@ async def _tag_setu(bot: Bot, event: MessageEvent): await bot.delete_msg(message_id=event_id) -@scheduler.scheduled_job("interval", hours=1, misfire_grace_time=60, args=[Bot]) +setu_catcher = Setu().on_message("涩图嗅探") + + +@setu_catcher.handle() +async def _setu_catcher(bot: Bot, event: MessageEvent): + msg = str(event.message) + pattern = r"url=(.*?)]" + args = re.findall(pattern, msg) + if not args: + return + else: + hso = list() + for i in args: + try: + data = await Setu().detecter(i) + except Exception: + return + if data[1] > 0.7: + hso.append(data[1]) + + hso.sort(reverse=True) + + if not hso: + return + elif len(hso) == 1: + u_repo = f"hso! 涩值:{'{:.2%}'.format(hso[0])}\n不行我要发给别人看" + s_repo = ( + f"涩图来咧!\n{MessageSegment.image(args[0])}\n涩值:{'{:.2%}'.format(hso[0])}" + ) + + else: + u_repo = f"hso! 最涩的达到:{'{:.2%}'.format(hso[0])}\n不行我一定要发给别人看" + + ss = list() + for s in args: + ss.append(MessageSegment.image(s)) + ss = "\n".join(ss) + s_repo = f"多张涩图来咧!\n{ss}\n最涩的达到:{'{:.2%}'.format(hso[0])}" + + await bot.send(event, u_repo) + for superuser in BotSelfConfig.superusers: + await bot.send_private_msg(user_id=superuser, message=s_repo) + + +nsfw_checker = Setu().on_command("/nsfw", "涩值检测") + + +@nsfw_checker.handle() +async def _nsfw_checker(bot: Bot, event: MessageEvent, state: T_State): + msg = str(event.message).strip() + if msg: + state["nsfw_img"] = msg + + +@nsfw_checker.got("nsfw_img", "图呢?") +async def _deal_check(bot: Bot, event: MessageEvent, state: T_State): + msg = state["nsfw_img"] + pattern = r"url=(.*?)]" + args = re.findall(pattern, msg) + if not args: + await nsfw_checker.reject("请发送图片而不是其他东西!!") + + data = await Setu().detecter(args[0]) + hso = data[1] + if not hso: + await nsfw_checker.finish("图太小了!不测!") + + resu = f"涩值:{'{:.2%}'.format(hso)}\n" + if hso >= 0.75: + resu += "hso!不行我要发给别人看" + repo = f"涩图来咧!\n{MessageSegment.image(args[0])}\n涩值:{'{:.2%}'.format(hso[0])}" + for superuser in BotSelfConfig.superusers: + await bot.send_private_msg(user_id=superuser, message=repo) + + elif 0.75 > hso >= 0.5: + resu += "嗯。可冲" + else: + resu += "还行8" + + await nsfw_checker.finish(resu) + + +@scheduler.scheduled_job( + "interval", name="涩批诱捕器", hours=1, misfire_grace_time=60, args=[Bot] +) async def _scheduler_setu(bot): try: group_list = await bot.get_group_list() diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py index 2c665ae..39f3815 100644 --- a/ATRI/plugins/setu/data_source.py +++ b/ATRI/plugins/setu/data_source.py @@ -7,6 +7,7 @@ from nonebot.adapters.cqhttp import MessageSegment from ATRI.service import Service from ATRI.rule import is_in_service from ATRI.utils import request +from .tf_dealer import detect_image LOLICON_URL = "https://api.lolicon.app/setu/v2" @@ -60,6 +61,14 @@ class Setu(Service): repo = f"Title: {title}\nPid: {p_id}" return repo, setu + @staticmethod + async def detecter(url: str) -> list: + """ + 涩值检测. + """ + data = await detect_image(url) + return data + @staticmethod async def scheduler() -> str: """ diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/tf_dealer.py new file mode 100644 index 0000000..f41de49 --- /dev/null +++ b/ATRI/plugins/setu/tf_dealer.py @@ -0,0 +1,119 @@ +import io +import os +import re +import string +import asyncio +import skimage +import skimage.io +import numpy as np +from PIL import Image +from pathlib import Path +from sys import getsizeof +from random import sample + +try: + import tflite_runtime.interpreter as tf # type: ignore +except Exception: + import tensorflow as tf + +from ATRI.log import logger as log +from ATRI.utils import request +from ATRI.exceptions import RequestError, WriteError + + +SETU_PATH = Path(".") / "data" / "database" / "setu" +TEMP_PATH = Path(".") / "data" / "temp" +os.makedirs(SETU_PATH, exist_ok=True) +os.makedirs(TEMP_PATH, exist_ok=True) + + +MODULE_URL = "https://cdn.jsdelivr.net/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite" +VGG_MEAN = [104, 117, 123] + + +def prepare_image(img): + H, W, _ = img.shape + h, w = (224, 224) + + h_off = max((H - h) // 2, 0) + w_off = max((W - w) // 2, 0) + image = img[h_off : h_off + h, w_off : w_off + w, :] + + image = image[:, :, ::-1] + + image = image.astype(np.float32, copy=False) + image = image * 255.0 + image = image - np.array(VGG_MEAN, dtype=np.float32) + + image = np.expand_dims(image, axis=0) + return image + + +async def detect_image(url) -> list: + try: + req = await request.get(url) + except RequestError: + raise RequestError("Get info from download image failed!") + + img_byte = getsizeof(req.read()) // 1024 + if img_byte < 256: + return [0, 0] + + try: + pattern = r"-(.*?)\/" + file_name = re.findall(pattern, url)[0] + path = TEMP_PATH / f"{file_name}.jpg" + with open(path, "wb") as f: + f.write(req.read()) + except WriteError: + raise WriteError("Writing file failed!") + + await init_module() + model_path = str((SETU_PATH / "nsfw.tflite").absolute()) + + try: + interpreter = tf.Interpreter(model_path=model_path) # type: ignore + except Exception: + interpreter = tf.lite.Interpreter(model_path=model_path) + + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + im = Image.open(path) + + if im.mode != "RGB": + im = im.convert("RGB") + imr = im.resize((256, 256), resample=Image.BILINEAR) + fh_im = io.BytesIO() + imr.save(fh_im, format="JPEG") + fh_im.seek(0) + + image = skimage.img_as_float32(skimage.io.imread(fh_im)) + + final = prepare_image(image) + interpreter.set_tensor(input_details[0]["index"], final) + + interpreter.invoke() + output_data = interpreter.get_tensor(output_details[0]["index"]) + + result = np.squeeze(output_data).tolist() + return result + + +async def init_module(): + file_name = "nsfw.tflite" + path = SETU_PATH / file_name + if not path.is_file(): + log.warning("缺少模型文件,装载中") + data = await request.get(MODULE_URL) + try: + with open(path, "wb") as w: + w.write(data.read()) + log.info("模型装载完成") + except WriteError: + raise WriteError("NSFW TF module init failed!") + + +asyncio.get_event_loop().run_until_complete(init_module()) diff --git a/ATRI/plugins/status/__init__.py b/ATRI/plugins/status/__init__.py index 2746953..359344b 100644 --- a/ATRI/plugins/status/__init__.py +++ b/ATRI/plugins/status/__init__.py @@ -24,7 +24,7 @@ async def _status(bot: Bot, event: MessageEvent): info_msg = "アトリは高性能ですから!" -@scheduler.scheduled_job("interval", minutes=10, misfire_grace_time=15) +@scheduler.scheduled_job("interval", name="状态检查", minutes=10, misfire_grace_time=15) async def _status_checking(): global info_msg msg, stat = IsSurvive().get_status() diff --git a/ATRI/utils/request.py b/ATRI/utils/request.py index 8dec9c7..2cfce85 100644 --- a/ATRI/utils/request.py +++ b/ATRI/utils/request.py @@ -1,11 +1,17 @@ import httpx +from ATRI.config import BotSelfConfig + + +proxy = BotSelfConfig.proxy +if not proxy: + proxy = dict() async def get(url: str, **kwargs): - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(proxies=proxy) as client: return await client.get(url, **kwargs) async def post(url: str, **kwargs): - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(proxies=proxy) as client: return await client.post(url, **kwargs) diff --git a/README.md b/README.md index a94c270..89a405f 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ 涩批: - 文爱 - 涩图 +- 涩图嗅探 - 涩批翻译机 实用: @@ -57,7 +58,7 @@ - 状态查看 **TODO**: - - [x] 网页控制台 (目前仅支持部分数据可视化,请于启动后点击,uvicorn给予的url以访问) + - [ ] 网页控制台 - [ ] RSS订阅 - [ ] B站动态订阅 - [ ] 冷重启 diff --git a/changelog.md b/changelog.md index 6355903..47e5578 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,9 @@ > 此处仅为记录新功能更新,修复 BUG/以及其它 请关注[`GitHub commits`](https://github.com/Kyomotoi/ATRI/commits/main) +## Oct 24, 2021 +- 新增: + - nsfw检测(主动/被动) + ## Jul 31, 2021 - 新增: - 前端(单主页) diff --git a/config.yml b/config.yml index ad29648..afe9fde 100644 --- a/config.yml +++ b/config.yml @@ -7,6 +7,7 @@ BotSelfConfig: command_start: ["", "/"] command_sep: ["."] session_expire_timeout: 60 + proxy: "" SauceNAO: key: "" diff --git a/requirements.txt b/requirements.txt index b02b614..bef0a64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,12 @@ aiohttp>=3.6.2 aiofiles>=0.6.0 APScheduler>=3.7.0 Pillow>=8.1.1 -nonebot-adapter-cqhttp>=2.0.0a12 +nonebot-adapter-cqhttp>=2.0.0a14 nonebot-plugin-test>=0.2.0 nonebot2>=2.0.0a13.post1 psutil>=5.7.2 pathlib>=1.0.1 pytz>=2020.1 -pyyaml>=5.4 \ No newline at end of file +pyyaml>=5.4 +scikit-image>=0.18.3 +tensorflow>=2.6.0 \ No newline at end of file -- cgit v1.2.3