diff options
-rw-r--r-- | ATRI/plugins/setu/__init__.py | 7 | ||||
-rw-r--r-- | ATRI/plugins/setu/data_source.py | 6 | ||||
-rw-r--r-- | ATRI/plugins/setu/nsfw_checker.py (renamed from ATRI/plugins/setu/tf_dealer.py) | 47 |
3 files changed, 23 insertions, 37 deletions
diff --git a/ATRI/plugins/setu/__init__.py b/ATRI/plugins/setu/__init__.py index 71243eb..7962af8 100644 --- a/ATRI/plugins/setu/__init__.py +++ b/ATRI/plugins/setu/__init__.py @@ -99,8 +99,8 @@ async def _setu_catcher(bot: Bot, event: MessageEvent): data = await Setu().detecter(i, _catcher_max_file_size) except Exception: return - if data[1] > 0.7: - hso.append(data[1]) + if data > 0.7: + hso.append(data) hso.sort(reverse=True) @@ -135,8 +135,7 @@ async def _deal_check(bot: Bot, event: MessageEvent): if not args: await nsfw_checker.reject("请发送图片而不是其他东西!!") - data = await Setu().detecter(args[0], _catcher_max_file_size) - hso = data[1] + hso = await Setu().detecter(args[0], _catcher_max_file_size) if not hso: await nsfw_checker.finish("图太小了!不测!") diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py index 61b1cf8..71649fe 100644 --- a/ATRI/plugins/setu/data_source.py +++ b/ATRI/plugins/setu/data_source.py @@ -5,7 +5,7 @@ from ATRI.service import Service from ATRI.rule import is_in_service from ATRI.utils import request from ATRI.config import Setu as ST -from .tf_dealer import detect_image, init_module +from .nsfw_checker import detect_image, init_model LOLICON_URL = "https://api.lolicon.app/setu/v2" @@ -70,7 +70,7 @@ class Setu(Service): return repo, setu @staticmethod - async def detecter(url: str, file_size: int) -> list: + async def detecter(url: str, file_size: int) -> float: """ 涩值检测. """ @@ -85,4 +85,4 @@ class Setu(Service): from ATRI import driver -driver().on_startup(init_module) +driver().on_startup(init_model) diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/nsfw_checker.py index 4f4b374..2f045c1 100644 --- a/ATRI/plugins/setu/tf_dealer.py +++ b/ATRI/plugins/setu/nsfw_checker.py @@ -1,14 +1,12 @@ import io import re import skimage -import skimage.io +import onnxruntime import numpy as np from PIL import Image from pathlib import Path from sys import getsizeof -import tensorflow as tf - from ATRI.log import logger as log from ATRI.utils import request from ATRI.exceptions import RequestError, WriteFileError @@ -20,7 +18,7 @@ SETU_PATH.mkdir(parents=True, exist_ok=True) TEMP_PATH.mkdir(parents=True, exist_ok=True) -MODULE_URL = "https://jsd.imki.moe/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite" +MODEL_URL = "https://res.imki.moe/nsfw.onnx" VGG_MEAN = [104, 117, 123] @@ -42,7 +40,7 @@ def prepare_image(img): return image -async def detect_image(url: str, file_size: int) -> list: +async def detect_image(url: str, file_size: int) -> float: try: req = await request.get(url) except Exception: @@ -50,7 +48,7 @@ async def detect_image(url: str, file_size: int) -> list: img_byte = getsizeof(req.read()) // 1024 if img_byte < file_size: - return [0, 0] + return 0 try: pattern = r"-(.*?)\/" @@ -61,18 +59,8 @@ async def detect_image(url: str, file_size: int) -> list: except Exception: raise WriteFileError("Writing file failed!") - await init_module() - model_path = str((SETU_PATH / "nsfw.tflite").absolute()) - - try: - interpreter = tf.Interpreter(model_path=model_path) # type: ignore - except Exception: - interpreter = tf.lite.Interpreter(model_path=model_path) - - interpreter.allocate_tensors() - - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() + model_path = str(SETU_PATH / "nsfw.onnx") + session = onnxruntime.InferenceSession(model_path) im = Image.open(path) @@ -84,26 +72,25 @@ async def detect_image(url: str, file_size: int) -> list: fh_im.seek(0) image = skimage.img_as_float32(skimage.io.imread(fh_im)) - final = prepare_image(image) - interpreter.set_tensor(input_details[0]["index"], final) - interpreter.invoke() - output_data = interpreter.get_tensor(output_details[0]["index"]) + input_feed = {session.get_inputs()[0].name: final} + outputs = [output.name for output in session.get_outputs()] + result = session.run(outputs, input_feed) - result = np.squeeze(output_data).tolist() - return result + return result[0][0][1] -async def init_module(): - file_name = "nsfw.tflite" +async def init_model(): + file_name = "nsfw.onnx" path = SETU_PATH / file_name if not path.is_file(): - log.warning("缺少模型文件,装载中") + log.warning("插件 setu 缺少资源, 装载中...") try: - data = await request.get(MODULE_URL) + data = await request.get(MODEL_URL) with open(path, "wb") as w: w.write(data.read()) - log.info("模型装载完成") except Exception: - log.error("装载模型失败") + log.error("插件 setu 装载资源失败, 命令 '/nsfw' 将失效") + + log.success("插件 setu 装载资源完成") |