summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKyomotoi <[email protected]>2021-10-24 16:47:29 +0800
committerKyomotoi <[email protected]>2021-10-24 16:47:29 +0800
commit16838e0d83c8dd3f1be1e6ec5fb1cecd9b319d8a (patch)
tree77630910cac9cc38396cabe3dcc11e2d856a6c2f
parentd789b1ae77f4415dab062c4af516e303dc447ddc (diff)
downloadATRI-16838e0d83c8dd3f1be1e6ec5fb1cecd9b319d8a.tar.gz
ATRI-16838e0d83c8dd3f1be1e6ec5fb1cecd9b319d8a.tar.bz2
ATRI-16838e0d83c8dd3f1be1e6ec5fb1cecd9b319d8a.zip
🔖 更新版本:YHN-001-A04
新增: - nsfw检测(主动/被动) 又名 涩图嗅探 - 可选代理 修复: - plugin/chat 在 nb2-a14+ 版本 finish 内为空时会报错 其他: - 对定时任务进行中文命名
-rw-r--r--ATRI/config.py8
-rw-r--r--ATRI/plugins/chat/__init__.py7
-rw-r--r--ATRI/plugins/essential.py14
-rw-r--r--ATRI/plugins/setu/__init__.py89
-rw-r--r--ATRI/plugins/setu/data_source.py9
-rw-r--r--ATRI/plugins/setu/tf_dealer.py119
-rw-r--r--ATRI/plugins/status/__init__.py2
-rw-r--r--ATRI/utils/request.py10
-rw-r--r--README.md3
-rw-r--r--changelog.md4
-rw-r--r--config.yml1
-rw-r--r--requirements.txt6
12 files changed, 262 insertions, 10 deletions
diff --git a/ATRI/config.py b/ATRI/config.py
index 7d5127c..d1a9bfd 100644
--- a/ATRI/config.py
+++ b/ATRI/config.py
@@ -1,8 +1,13 @@
+import yaml
from pathlib import Path
from datetime import timedelta
from ipaddress import IPv4Address
-from .utils import *
+
+def load_yml(file: Path, encoding="utf-8") -> dict:
+ with open(file, "r", encoding=encoding) as f:
+ data = yaml.safe_load(f)
+ return data
CONFIG_PATH = Path(".") / "config.yml"
@@ -22,6 +27,7 @@ class BotSelfConfig:
session_expire_timeout: timedelta = timedelta(
seconds=config.get("session_expire_timeout", 60)
)
+ proxy: str = config.get("proxy", None)
class SauceNAO:
diff --git a/ATRI/plugins/chat/__init__.py b/ATRI/plugins/chat/__init__.py
index 9f7f26c..ab22db9 100644
--- a/ATRI/plugins/chat/__init__.py
+++ b/ATRI/plugins/chat/__init__.py
@@ -26,7 +26,10 @@ async def _chat(bot: Bot, event: MessageEvent):
msg = str(event.message)
repo = await Chat().deal(msg, user_id)
_chat_flmt.start_cd(user_id)
- await chat.finish(repo)
+ try:
+ await chat.finish(repo)
+ except Exception:
+ return
my_name_is = Chat().on_command("叫我", "更改闲聊(划掉 文爱)时的称呼", aliases={"我是"}, priority=1)
@@ -111,7 +114,7 @@ async def _deal_say(bot: Bot, event: MessageEvent, state: T_State):
await say.finish(msg)
[email protected]_job("interval", hours=3, misfire_grace_time=60)
[email protected]_job("interval", name="闲聊词库检查更新", hours=3, misfire_grace_time=60)
async def _check_kimo():
try:
await Chat().update_data()
diff --git a/ATRI/plugins/essential.py b/ATRI/plugins/essential.py
index 3bbf66e..4cc0ac2 100644
--- a/ATRI/plugins/essential.py
+++ b/ATRI/plugins/essential.py
@@ -1,5 +1,6 @@
import os
import json
+import shutil
import asyncio
from datetime import datetime
from pydantic.main import BaseModel
@@ -30,14 +31,19 @@ from ATRI.service import Service
from ATRI.log import logger as log
from ATRI.config import BotSelfConfig
from ATRI.utils import CoolqCodeChecker
+from ATRI.utils.apscheduler import scheduler
+
driver = ATRI.driver()
bots = nonebot.get_bots()
+
ESSENTIAL_DIR = Path(".") / "data" / "database" / "essential"
MANEGE_DIR = Path(".") / "data" / "database" / "manege"
+TEMP_PATH = Path(".") / "data" / "temp"
os.makedirs(ESSENTIAL_DIR, exist_ok=True)
os.makedirs(MANEGE_DIR, exist_ok=True)
+os.makedirs(TEMP_PATH, exist_ok=True)
@driver.on_startup
@@ -298,3 +304,11 @@ async def _recall_private_event(bot: Bot, event: FriendRecallNoticeEvent):
msg = "主人,咱拿到了一条撤回信息!\n" f"{user}@[私聊]" "撤回了\n" f"{repo}"
for superuser in BotSelfConfig.superusers:
await bot.send_private_msg(user_id=int(superuser), message=msg)
+
+
[email protected]_job("interval", name="清除缓存", minutes=30, misfire_grace_time=5)
+async def _clear_cache():
+ try:
+ shutil.rmtree(TEMP_PATH)
+ except Exception:
+ log.warning("清除缓存失败,请手动清除:data/temp")
diff --git a/ATRI/plugins/setu/__init__.py b/ATRI/plugins/setu/__init__.py
index eec5281..3e0e452 100644
--- a/ATRI/plugins/setu/__init__.py
+++ b/ATRI/plugins/setu/__init__.py
@@ -2,7 +2,10 @@ import re
import asyncio
from random import choice
from nonebot.adapters.cqhttp import Bot, MessageEvent, Message
+from nonebot.adapters.cqhttp.message import MessageSegment
+from nonebot.typing import T_State
+from ATRI.config import BotSelfConfig
from ATRI.utils.limit import FreqLimiter, DailyLimiter
from ATRI.utils.apscheduler import scheduler
from .data_source import Setu
@@ -74,7 +77,91 @@ async def _tag_setu(bot: Bot, event: MessageEvent):
await bot.delete_msg(message_id=event_id)
[email protected]_job("interval", hours=1, misfire_grace_time=60, args=[Bot])
+setu_catcher = Setu().on_message("涩图嗅探")
+
+
+@setu_catcher.handle()
+async def _setu_catcher(bot: Bot, event: MessageEvent):
+ msg = str(event.message)
+ pattern = r"url=(.*?)]"
+ args = re.findall(pattern, msg)
+ if not args:
+ return
+ else:
+ hso = list()
+ for i in args:
+ try:
+ data = await Setu().detecter(i)
+ except Exception:
+ return
+ if data[1] > 0.7:
+ hso.append(data[1])
+
+ hso.sort(reverse=True)
+
+ if not hso:
+ return
+ elif len(hso) == 1:
+ u_repo = f"hso! 涩值:{'{:.2%}'.format(hso[0])}\n不行我要发给别人看"
+ s_repo = (
+ f"涩图来咧!\n{MessageSegment.image(args[0])}\n涩值:{'{:.2%}'.format(hso[0])}"
+ )
+
+ else:
+ u_repo = f"hso! 最涩的达到:{'{:.2%}'.format(hso[0])}\n不行我一定要发给别人看"
+
+ ss = list()
+ for s in args:
+ ss.append(MessageSegment.image(s))
+ ss = "\n".join(ss)
+ s_repo = f"多张涩图来咧!\n{ss}\n最涩的达到:{'{:.2%}'.format(hso[0])}"
+
+ await bot.send(event, u_repo)
+ for superuser in BotSelfConfig.superusers:
+ await bot.send_private_msg(user_id=superuser, message=s_repo)
+
+
+nsfw_checker = Setu().on_command("/nsfw", "涩值检测")
+
+
+@nsfw_checker.handle()
+async def _nsfw_checker(bot: Bot, event: MessageEvent, state: T_State):
+ msg = str(event.message).strip()
+ if msg:
+ state["nsfw_img"] = msg
+
+
+@nsfw_checker.got("nsfw_img", "图呢?")
+async def _deal_check(bot: Bot, event: MessageEvent, state: T_State):
+ msg = state["nsfw_img"]
+ pattern = r"url=(.*?)]"
+ args = re.findall(pattern, msg)
+ if not args:
+ await nsfw_checker.reject("请发送图片而不是其他东西!!")
+
+ data = await Setu().detecter(args[0])
+ hso = data[1]
+ if not hso:
+ await nsfw_checker.finish("图太小了!不测!")
+
+ resu = f"涩值:{'{:.2%}'.format(hso)}\n"
+ if hso >= 0.75:
+ resu += "hso!不行我要发给别人看"
+ repo = f"涩图来咧!\n{MessageSegment.image(args[0])}\n涩值:{'{:.2%}'.format(hso[0])}"
+ for superuser in BotSelfConfig.superusers:
+ await bot.send_private_msg(user_id=superuser, message=repo)
+
+ elif 0.75 > hso >= 0.5:
+ resu += "嗯。可冲"
+ else:
+ resu += "还行8"
+
+ await nsfw_checker.finish(resu)
+
+
+ "interval", name="涩批诱捕器", hours=1, misfire_grace_time=60, args=[Bot]
+)
async def _scheduler_setu(bot):
try:
group_list = await bot.get_group_list()
diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py
index 2c665ae..39f3815 100644
--- a/ATRI/plugins/setu/data_source.py
+++ b/ATRI/plugins/setu/data_source.py
@@ -7,6 +7,7 @@ from nonebot.adapters.cqhttp import MessageSegment
from ATRI.service import Service
from ATRI.rule import is_in_service
from ATRI.utils import request
+from .tf_dealer import detect_image
LOLICON_URL = "https://api.lolicon.app/setu/v2"
@@ -61,6 +62,14 @@ class Setu(Service):
return repo, setu
@staticmethod
+ async def detecter(url: str) -> list:
+ """
+ 涩值检测.
+ """
+ data = await detect_image(url)
+ return data
+
+ @staticmethod
async def scheduler() -> str:
"""
每隔指定时间随机抽取一个群发送涩图.
diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/tf_dealer.py
new file mode 100644
index 0000000..f41de49
--- /dev/null
+++ b/ATRI/plugins/setu/tf_dealer.py
@@ -0,0 +1,119 @@
+import io
+import os
+import re
+import string
+import asyncio
+import skimage
+import skimage.io
+import numpy as np
+from PIL import Image
+from pathlib import Path
+from sys import getsizeof
+from random import sample
+
+try:
+ import tflite_runtime.interpreter as tf # type: ignore
+except Exception:
+ import tensorflow as tf
+
+from ATRI.log import logger as log
+from ATRI.utils import request
+from ATRI.exceptions import RequestError, WriteError
+
+
+SETU_PATH = Path(".") / "data" / "database" / "setu"
+TEMP_PATH = Path(".") / "data" / "temp"
+os.makedirs(SETU_PATH, exist_ok=True)
+os.makedirs(TEMP_PATH, exist_ok=True)
+
+
+MODULE_URL = "https://cdn.jsdelivr.net/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite"
+VGG_MEAN = [104, 117, 123]
+
+
+def prepare_image(img):
+ H, W, _ = img.shape
+ h, w = (224, 224)
+
+ h_off = max((H - h) // 2, 0)
+ w_off = max((W - w) // 2, 0)
+ image = img[h_off : h_off + h, w_off : w_off + w, :]
+
+ image = image[:, :, ::-1]
+
+ image = image.astype(np.float32, copy=False)
+ image = image * 255.0
+ image = image - np.array(VGG_MEAN, dtype=np.float32)
+
+ image = np.expand_dims(image, axis=0)
+ return image
+
+
+async def detect_image(url) -> list:
+ try:
+ req = await request.get(url)
+ except RequestError:
+ raise RequestError("Get info from download image failed!")
+
+ img_byte = getsizeof(req.read()) // 1024
+ if img_byte < 256:
+ return [0, 0]
+
+ try:
+ pattern = r"-(.*?)\/"
+ file_name = re.findall(pattern, url)[0]
+ path = TEMP_PATH / f"{file_name}.jpg"
+ with open(path, "wb") as f:
+ f.write(req.read())
+ except WriteError:
+ raise WriteError("Writing file failed!")
+
+ await init_module()
+ model_path = str((SETU_PATH / "nsfw.tflite").absolute())
+
+ try:
+ interpreter = tf.Interpreter(model_path=model_path) # type: ignore
+ except Exception:
+ interpreter = tf.lite.Interpreter(model_path=model_path)
+
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ im = Image.open(path)
+
+ if im.mode != "RGB":
+ im = im.convert("RGB")
+ imr = im.resize((256, 256), resample=Image.BILINEAR)
+ fh_im = io.BytesIO()
+ imr.save(fh_im, format="JPEG")
+ fh_im.seek(0)
+
+ image = skimage.img_as_float32(skimage.io.imread(fh_im))
+
+ final = prepare_image(image)
+ interpreter.set_tensor(input_details[0]["index"], final)
+
+ interpreter.invoke()
+ output_data = interpreter.get_tensor(output_details[0]["index"])
+
+ result = np.squeeze(output_data).tolist()
+ return result
+
+
+async def init_module():
+ file_name = "nsfw.tflite"
+ path = SETU_PATH / file_name
+ if not path.is_file():
+ log.warning("缺少模型文件,装载中")
+ data = await request.get(MODULE_URL)
+ try:
+ with open(path, "wb") as w:
+ w.write(data.read())
+ log.info("模型装载完成")
+ except WriteError:
+ raise WriteError("NSFW TF module init failed!")
+
+
+asyncio.get_event_loop().run_until_complete(init_module())
diff --git a/ATRI/plugins/status/__init__.py b/ATRI/plugins/status/__init__.py
index 2746953..359344b 100644
--- a/ATRI/plugins/status/__init__.py
+++ b/ATRI/plugins/status/__init__.py
@@ -24,7 +24,7 @@ async def _status(bot: Bot, event: MessageEvent):
info_msg = "アトリは高性能ですから!"
[email protected]_job("interval", minutes=10, misfire_grace_time=15)
[email protected]_job("interval", name="状态检查", minutes=10, misfire_grace_time=15)
async def _status_checking():
global info_msg
msg, stat = IsSurvive().get_status()
diff --git a/ATRI/utils/request.py b/ATRI/utils/request.py
index 8dec9c7..2cfce85 100644
--- a/ATRI/utils/request.py
+++ b/ATRI/utils/request.py
@@ -1,11 +1,17 @@
import httpx
+from ATRI.config import BotSelfConfig
+
+
+proxy = BotSelfConfig.proxy
+if not proxy:
+ proxy = dict()
async def get(url: str, **kwargs):
- async with httpx.AsyncClient() as client:
+ async with httpx.AsyncClient(proxies=proxy) as client:
return await client.get(url, **kwargs)
async def post(url: str, **kwargs):
- async with httpx.AsyncClient() as client:
+ async with httpx.AsyncClient(proxies=proxy) as client:
return await client.post(url, **kwargs)
diff --git a/README.md b/README.md
index a94c270..89a405f 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,7 @@
涩批:
- 文爱
- 涩图
+- 涩图嗅探
- 涩批翻译机
实用:
@@ -57,7 +58,7 @@
- 状态查看
**TODO**:
- - [x] 网页控制台 (目前仅支持部分数据可视化,请于启动后点击,uvicorn给予的url以访问)
+ - [ ] 网页控制台
- [ ] RSS订阅
- [ ] B站动态订阅
- [ ] 冷重启
diff --git a/changelog.md b/changelog.md
index 6355903..47e5578 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,5 +1,9 @@
> 此处仅为记录新功能更新,修复 BUG/以及其它 请关注[`GitHub commits`](https://github.com/Kyomotoi/ATRI/commits/main)
+## Oct 24, 2021
+- 新增:
+ - nsfw检测(主动/被动)
+
## Jul 31, 2021
- 新增:
- 前端(单主页)
diff --git a/config.yml b/config.yml
index ad29648..afe9fde 100644
--- a/config.yml
+++ b/config.yml
@@ -7,6 +7,7 @@ BotSelfConfig:
command_start: ["", "/"]
command_sep: ["."]
session_expire_timeout: 60
+ proxy: ""
SauceNAO:
key: ""
diff --git a/requirements.txt b/requirements.txt
index b02b614..bef0a64 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,10 +2,12 @@ aiohttp>=3.6.2
aiofiles>=0.6.0
APScheduler>=3.7.0
Pillow>=8.1.1
-nonebot-adapter-cqhttp>=2.0.0a12
+nonebot-adapter-cqhttp>=2.0.0a14
nonebot-plugin-test>=0.2.0
nonebot2>=2.0.0a13.post1
psutil>=5.7.2
pathlib>=1.0.1
pytz>=2020.1
-pyyaml>=5.4 \ No newline at end of file
+pyyaml>=5.4
+scikit-image>=0.18.3
+tensorflow>=2.6.0 \ No newline at end of file