diff options
author | Kyomotoi <[email protected]> | 2021-10-24 16:47:29 +0800 |
---|---|---|
committer | Kyomotoi <[email protected]> | 2021-10-24 16:47:29 +0800 |
commit | 16838e0d83c8dd3f1be1e6ec5fb1cecd9b319d8a (patch) | |
tree | 77630910cac9cc38396cabe3dcc11e2d856a6c2f | |
parent | d789b1ae77f4415dab062c4af516e303dc447ddc (diff) | |
download | ATRI-16838e0d83c8dd3f1be1e6ec5fb1cecd9b319d8a.tar.gz ATRI-16838e0d83c8dd3f1be1e6ec5fb1cecd9b319d8a.tar.bz2 ATRI-16838e0d83c8dd3f1be1e6ec5fb1cecd9b319d8a.zip |
🔖 更新版本:YHN-001-A04
新增:
- nsfw检测(主动/被动) 又名 涩图嗅探
- 可选代理
修复:
- plugin/chat 在 nb2-a14+ 版本 finish 内为空时会报错
其他:
- 对定时任务进行中文命名
-rw-r--r-- | ATRI/config.py | 8 | ||||
-rw-r--r-- | ATRI/plugins/chat/__init__.py | 7 | ||||
-rw-r--r-- | ATRI/plugins/essential.py | 14 | ||||
-rw-r--r-- | ATRI/plugins/setu/__init__.py | 89 | ||||
-rw-r--r-- | ATRI/plugins/setu/data_source.py | 9 | ||||
-rw-r--r-- | ATRI/plugins/setu/tf_dealer.py | 119 | ||||
-rw-r--r-- | ATRI/plugins/status/__init__.py | 2 | ||||
-rw-r--r-- | ATRI/utils/request.py | 10 | ||||
-rw-r--r-- | README.md | 3 | ||||
-rw-r--r-- | changelog.md | 4 | ||||
-rw-r--r-- | config.yml | 1 | ||||
-rw-r--r-- | requirements.txt | 6 |
12 files changed, 262 insertions, 10 deletions
diff --git a/ATRI/config.py b/ATRI/config.py index 7d5127c..d1a9bfd 100644 --- a/ATRI/config.py +++ b/ATRI/config.py @@ -1,8 +1,13 @@ +import yaml from pathlib import Path from datetime import timedelta from ipaddress import IPv4Address -from .utils import * + +def load_yml(file: Path, encoding="utf-8") -> dict: + with open(file, "r", encoding=encoding) as f: + data = yaml.safe_load(f) + return data CONFIG_PATH = Path(".") / "config.yml" @@ -22,6 +27,7 @@ class BotSelfConfig: session_expire_timeout: timedelta = timedelta( seconds=config.get("session_expire_timeout", 60) ) + proxy: str = config.get("proxy", None) class SauceNAO: diff --git a/ATRI/plugins/chat/__init__.py b/ATRI/plugins/chat/__init__.py index 9f7f26c..ab22db9 100644 --- a/ATRI/plugins/chat/__init__.py +++ b/ATRI/plugins/chat/__init__.py @@ -26,7 +26,10 @@ async def _chat(bot: Bot, event: MessageEvent): msg = str(event.message) repo = await Chat().deal(msg, user_id) _chat_flmt.start_cd(user_id) - await chat.finish(repo) + try: + await chat.finish(repo) + except Exception: + return my_name_is = Chat().on_command("叫我", "更改闲聊(划掉 文爱)时的称呼", aliases={"我是"}, priority=1) @@ -111,7 +114,7 @@ async def _deal_say(bot: Bot, event: MessageEvent, state: T_State): await say.finish(msg) [email protected]_job("interval", hours=3, misfire_grace_time=60) [email protected]_job("interval", name="闲聊词库检查更新", hours=3, misfire_grace_time=60) async def _check_kimo(): try: await Chat().update_data() diff --git a/ATRI/plugins/essential.py b/ATRI/plugins/essential.py index 3bbf66e..4cc0ac2 100644 --- a/ATRI/plugins/essential.py +++ b/ATRI/plugins/essential.py @@ -1,5 +1,6 @@ import os import json +import shutil import asyncio from datetime import datetime from pydantic.main import BaseModel @@ -30,14 +31,19 @@ from ATRI.service import Service from ATRI.log import logger as log from ATRI.config import BotSelfConfig from ATRI.utils import CoolqCodeChecker +from ATRI.utils.apscheduler import scheduler + driver = ATRI.driver() bots = nonebot.get_bots() + ESSENTIAL_DIR = Path(".") / "data" / "database" / "essential" MANEGE_DIR = Path(".") / "data" / "database" / "manege" +TEMP_PATH = Path(".") / "data" / "temp" os.makedirs(ESSENTIAL_DIR, exist_ok=True) os.makedirs(MANEGE_DIR, exist_ok=True) +os.makedirs(TEMP_PATH, exist_ok=True) @driver.on_startup @@ -298,3 +304,11 @@ async def _recall_private_event(bot: Bot, event: FriendRecallNoticeEvent): msg = "主人,咱拿到了一条撤回信息!\n" f"{user}@[私聊]" "撤回了\n" f"{repo}" for superuser in BotSelfConfig.superusers: await bot.send_private_msg(user_id=int(superuser), message=msg) + + [email protected]_job("interval", name="清除缓存", minutes=30, misfire_grace_time=5) +async def _clear_cache(): + try: + shutil.rmtree(TEMP_PATH) + except Exception: + log.warning("清除缓存失败,请手动清除:data/temp") diff --git a/ATRI/plugins/setu/__init__.py b/ATRI/plugins/setu/__init__.py index eec5281..3e0e452 100644 --- a/ATRI/plugins/setu/__init__.py +++ b/ATRI/plugins/setu/__init__.py @@ -2,7 +2,10 @@ import re import asyncio from random import choice from nonebot.adapters.cqhttp import Bot, MessageEvent, Message +from nonebot.adapters.cqhttp.message import MessageSegment +from nonebot.typing import T_State +from ATRI.config import BotSelfConfig from ATRI.utils.limit import FreqLimiter, DailyLimiter from ATRI.utils.apscheduler import scheduler from .data_source import Setu @@ -74,7 +77,91 @@ async def _tag_setu(bot: Bot, event: MessageEvent): await bot.delete_msg(message_id=event_id) [email protected]_job("interval", hours=1, misfire_grace_time=60, args=[Bot]) +setu_catcher = Setu().on_message("涩图嗅探") + + +@setu_catcher.handle() +async def _setu_catcher(bot: Bot, event: MessageEvent): + msg = str(event.message) + pattern = r"url=(.*?)]" + args = re.findall(pattern, msg) + if not args: + return + else: + hso = list() + for i in args: + try: + data = await Setu().detecter(i) + except Exception: + return + if data[1] > 0.7: + hso.append(data[1]) + + hso.sort(reverse=True) + + if not hso: + return + elif len(hso) == 1: + u_repo = f"hso! 涩值:{'{:.2%}'.format(hso[0])}\n不行我要发给别人看" + s_repo = ( + f"涩图来咧!\n{MessageSegment.image(args[0])}\n涩值:{'{:.2%}'.format(hso[0])}" + ) + + else: + u_repo = f"hso! 最涩的达到:{'{:.2%}'.format(hso[0])}\n不行我一定要发给别人看" + + ss = list() + for s in args: + ss.append(MessageSegment.image(s)) + ss = "\n".join(ss) + s_repo = f"多张涩图来咧!\n{ss}\n最涩的达到:{'{:.2%}'.format(hso[0])}" + + await bot.send(event, u_repo) + for superuser in BotSelfConfig.superusers: + await bot.send_private_msg(user_id=superuser, message=s_repo) + + +nsfw_checker = Setu().on_command("/nsfw", "涩值检测") + + +@nsfw_checker.handle() +async def _nsfw_checker(bot: Bot, event: MessageEvent, state: T_State): + msg = str(event.message).strip() + if msg: + state["nsfw_img"] = msg + + +@nsfw_checker.got("nsfw_img", "图呢?") +async def _deal_check(bot: Bot, event: MessageEvent, state: T_State): + msg = state["nsfw_img"] + pattern = r"url=(.*?)]" + args = re.findall(pattern, msg) + if not args: + await nsfw_checker.reject("请发送图片而不是其他东西!!") + + data = await Setu().detecter(args[0]) + hso = data[1] + if not hso: + await nsfw_checker.finish("图太小了!不测!") + + resu = f"涩值:{'{:.2%}'.format(hso)}\n" + if hso >= 0.75: + resu += "hso!不行我要发给别人看" + repo = f"涩图来咧!\n{MessageSegment.image(args[0])}\n涩值:{'{:.2%}'.format(hso[0])}" + for superuser in BotSelfConfig.superusers: + await bot.send_private_msg(user_id=superuser, message=repo) + + elif 0.75 > hso >= 0.5: + resu += "嗯。可冲" + else: + resu += "还行8" + + await nsfw_checker.finish(resu) + + [email protected]_job( + "interval", name="涩批诱捕器", hours=1, misfire_grace_time=60, args=[Bot] +) async def _scheduler_setu(bot): try: group_list = await bot.get_group_list() diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py index 2c665ae..39f3815 100644 --- a/ATRI/plugins/setu/data_source.py +++ b/ATRI/plugins/setu/data_source.py @@ -7,6 +7,7 @@ from nonebot.adapters.cqhttp import MessageSegment from ATRI.service import Service from ATRI.rule import is_in_service from ATRI.utils import request +from .tf_dealer import detect_image LOLICON_URL = "https://api.lolicon.app/setu/v2" @@ -61,6 +62,14 @@ class Setu(Service): return repo, setu @staticmethod + async def detecter(url: str) -> list: + """ + 涩值检测. + """ + data = await detect_image(url) + return data + + @staticmethod async def scheduler() -> str: """ 每隔指定时间随机抽取一个群发送涩图. diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/tf_dealer.py new file mode 100644 index 0000000..f41de49 --- /dev/null +++ b/ATRI/plugins/setu/tf_dealer.py @@ -0,0 +1,119 @@ +import io +import os +import re +import string +import asyncio +import skimage +import skimage.io +import numpy as np +from PIL import Image +from pathlib import Path +from sys import getsizeof +from random import sample + +try: + import tflite_runtime.interpreter as tf # type: ignore +except Exception: + import tensorflow as tf + +from ATRI.log import logger as log +from ATRI.utils import request +from ATRI.exceptions import RequestError, WriteError + + +SETU_PATH = Path(".") / "data" / "database" / "setu" +TEMP_PATH = Path(".") / "data" / "temp" +os.makedirs(SETU_PATH, exist_ok=True) +os.makedirs(TEMP_PATH, exist_ok=True) + + +MODULE_URL = "https://cdn.jsdelivr.net/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite" +VGG_MEAN = [104, 117, 123] + + +def prepare_image(img): + H, W, _ = img.shape + h, w = (224, 224) + + h_off = max((H - h) // 2, 0) + w_off = max((W - w) // 2, 0) + image = img[h_off : h_off + h, w_off : w_off + w, :] + + image = image[:, :, ::-1] + + image = image.astype(np.float32, copy=False) + image = image * 255.0 + image = image - np.array(VGG_MEAN, dtype=np.float32) + + image = np.expand_dims(image, axis=0) + return image + + +async def detect_image(url) -> list: + try: + req = await request.get(url) + except RequestError: + raise RequestError("Get info from download image failed!") + + img_byte = getsizeof(req.read()) // 1024 + if img_byte < 256: + return [0, 0] + + try: + pattern = r"-(.*?)\/" + file_name = re.findall(pattern, url)[0] + path = TEMP_PATH / f"{file_name}.jpg" + with open(path, "wb") as f: + f.write(req.read()) + except WriteError: + raise WriteError("Writing file failed!") + + await init_module() + model_path = str((SETU_PATH / "nsfw.tflite").absolute()) + + try: + interpreter = tf.Interpreter(model_path=model_path) # type: ignore + except Exception: + interpreter = tf.lite.Interpreter(model_path=model_path) + + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + im = Image.open(path) + + if im.mode != "RGB": + im = im.convert("RGB") + imr = im.resize((256, 256), resample=Image.BILINEAR) + fh_im = io.BytesIO() + imr.save(fh_im, format="JPEG") + fh_im.seek(0) + + image = skimage.img_as_float32(skimage.io.imread(fh_im)) + + final = prepare_image(image) + interpreter.set_tensor(input_details[0]["index"], final) + + interpreter.invoke() + output_data = interpreter.get_tensor(output_details[0]["index"]) + + result = np.squeeze(output_data).tolist() + return result + + +async def init_module(): + file_name = "nsfw.tflite" + path = SETU_PATH / file_name + if not path.is_file(): + log.warning("缺少模型文件,装载中") + data = await request.get(MODULE_URL) + try: + with open(path, "wb") as w: + w.write(data.read()) + log.info("模型装载完成") + except WriteError: + raise WriteError("NSFW TF module init failed!") + + +asyncio.get_event_loop().run_until_complete(init_module()) diff --git a/ATRI/plugins/status/__init__.py b/ATRI/plugins/status/__init__.py index 2746953..359344b 100644 --- a/ATRI/plugins/status/__init__.py +++ b/ATRI/plugins/status/__init__.py @@ -24,7 +24,7 @@ async def _status(bot: Bot, event: MessageEvent): info_msg = "アトリは高性能ですから!" [email protected]_job("interval", minutes=10, misfire_grace_time=15) [email protected]_job("interval", name="状态检查", minutes=10, misfire_grace_time=15) async def _status_checking(): global info_msg msg, stat = IsSurvive().get_status() diff --git a/ATRI/utils/request.py b/ATRI/utils/request.py index 8dec9c7..2cfce85 100644 --- a/ATRI/utils/request.py +++ b/ATRI/utils/request.py @@ -1,11 +1,17 @@ import httpx +from ATRI.config import BotSelfConfig + + +proxy = BotSelfConfig.proxy +if not proxy: + proxy = dict() async def get(url: str, **kwargs): - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(proxies=proxy) as client: return await client.get(url, **kwargs) async def post(url: str, **kwargs): - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(proxies=proxy) as client: return await client.post(url, **kwargs) @@ -35,6 +35,7 @@ 涩批: - 文爱 - 涩图 +- 涩图嗅探 - 涩批翻译机 实用: @@ -57,7 +58,7 @@ - 状态查看 **TODO**: - - [x] 网页控制台 (目前仅支持部分数据可视化,请于启动后点击,uvicorn给予的url以访问) + - [ ] 网页控制台 - [ ] RSS订阅 - [ ] B站动态订阅 - [ ] 冷重启 diff --git a/changelog.md b/changelog.md index 6355903..47e5578 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,9 @@ > 此处仅为记录新功能更新,修复 BUG/以及其它 请关注[`GitHub commits`](https://github.com/Kyomotoi/ATRI/commits/main) +## Oct 24, 2021 +- 新增: + - nsfw检测(主动/被动) + ## Jul 31, 2021 - 新增: - 前端(单主页) @@ -7,6 +7,7 @@ BotSelfConfig: command_start: ["", "/"] command_sep: ["."] session_expire_timeout: 60 + proxy: "" SauceNAO: key: "" diff --git a/requirements.txt b/requirements.txt index b02b614..bef0a64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,12 @@ aiohttp>=3.6.2 aiofiles>=0.6.0 APScheduler>=3.7.0 Pillow>=8.1.1 -nonebot-adapter-cqhttp>=2.0.0a12 +nonebot-adapter-cqhttp>=2.0.0a14 nonebot-plugin-test>=0.2.0 nonebot2>=2.0.0a13.post1 psutil>=5.7.2 pathlib>=1.0.1 pytz>=2020.1 -pyyaml>=5.4
\ No newline at end of file +pyyaml>=5.4 +scikit-image>=0.18.3 +tensorflow>=2.6.0
\ No newline at end of file |