From 4006dce2581e7d6597958180021974a6de01294e Mon Sep 17 00:00:00 2001 From: Kyomotoi <0w0@imki.moe> Date: Sun, 18 Sep 2022 15:25:50 +0800 Subject: =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20=E6=9B=B4=E6=8D=A2=E6=B6=A9?= =?UTF-8?q?=E5=9B=BE=E6=A3=80=E6=B5=8B=E6=A8=A1=E5=9E=8B=E5=8F=8A=E5=85=B6?= =?UTF-8?q?=20runtime?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ATRI/plugins/setu/__init__.py | 7 ++- ATRI/plugins/setu/data_source.py | 6 +-- ATRI/plugins/setu/nsfw_checker.py | 96 +++++++++++++++++++++++++++++++++ ATRI/plugins/setu/tf_dealer.py | 109 -------------------------------------- 4 files changed, 102 insertions(+), 116 deletions(-) create mode 100644 ATRI/plugins/setu/nsfw_checker.py delete mode 100644 ATRI/plugins/setu/tf_dealer.py (limited to 'ATRI') diff --git a/ATRI/plugins/setu/__init__.py b/ATRI/plugins/setu/__init__.py index 71243eb..7962af8 100644 --- a/ATRI/plugins/setu/__init__.py +++ b/ATRI/plugins/setu/__init__.py @@ -99,8 +99,8 @@ async def _setu_catcher(bot: Bot, event: MessageEvent): data = await Setu().detecter(i, _catcher_max_file_size) except Exception: return - if data[1] > 0.7: - hso.append(data[1]) + if data > 0.7: + hso.append(data) hso.sort(reverse=True) @@ -135,8 +135,7 @@ async def _deal_check(bot: Bot, event: MessageEvent): if not args: await nsfw_checker.reject("请发送图片而不是其他东西!!") - data = await Setu().detecter(args[0], _catcher_max_file_size) - hso = data[1] + hso = await Setu().detecter(args[0], _catcher_max_file_size) if not hso: await nsfw_checker.finish("图太小了!不测!") diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py index 61b1cf8..71649fe 100644 --- a/ATRI/plugins/setu/data_source.py +++ b/ATRI/plugins/setu/data_source.py @@ -5,7 +5,7 @@ from ATRI.service import Service from ATRI.rule import is_in_service from ATRI.utils import request from ATRI.config import Setu as ST -from .tf_dealer import detect_image, init_module +from .nsfw_checker import detect_image, init_model LOLICON_URL = "https://api.lolicon.app/setu/v2" @@ -70,7 +70,7 @@ class Setu(Service): return repo, setu @staticmethod - async def detecter(url: str, file_size: int) -> list: + async def detecter(url: str, file_size: int) -> float: """ 涩值检测. """ @@ -85,4 +85,4 @@ class Setu(Service): from ATRI import driver -driver().on_startup(init_module) +driver().on_startup(init_model) diff --git a/ATRI/plugins/setu/nsfw_checker.py b/ATRI/plugins/setu/nsfw_checker.py new file mode 100644 index 0000000..2f045c1 --- /dev/null +++ b/ATRI/plugins/setu/nsfw_checker.py @@ -0,0 +1,96 @@ +import io +import re +import skimage +import onnxruntime +import numpy as np +from PIL import Image +from pathlib import Path +from sys import getsizeof + +from ATRI.log import logger as log +from ATRI.utils import request +from ATRI.exceptions import RequestError, WriteFileError + + +SETU_PATH = Path(".") / "data" / "plugins" / "setu" +TEMP_PATH = Path(".") / "data" / "temp" +SETU_PATH.mkdir(parents=True, exist_ok=True) +TEMP_PATH.mkdir(parents=True, exist_ok=True) + + +MODEL_URL = "https://res.imki.moe/nsfw.onnx" +VGG_MEAN = [104, 117, 123] + + +def prepare_image(img): + H, W, _ = img.shape + h, w = (224, 224) + + h_off = max((H - h) // 2, 0) + w_off = max((W - w) // 2, 0) + image = img[h_off : h_off + h, w_off : w_off + w, :] + + image = image[:, :, ::-1] + + image = image.astype(np.float32, copy=False) + image = image * 255.0 + image = image - np.array(VGG_MEAN, dtype=np.float32) + + image = np.expand_dims(image, axis=0) + return image + + +async def detect_image(url: str, file_size: int) -> float: + try: + req = await request.get(url) + except Exception: + raise RequestError("Get info from download image failed!") + + img_byte = getsizeof(req.read()) // 1024 + if img_byte < file_size: + return 0 + + try: + pattern = r"-(.*?)\/" + file_name = re.findall(pattern, url)[0] + path = TEMP_PATH / f"{file_name}.jpg" + with open(path, "wb") as f: + f.write(req.read()) + except Exception: + raise WriteFileError("Writing file failed!") + + model_path = str(SETU_PATH / "nsfw.onnx") + session = onnxruntime.InferenceSession(model_path) + + im = Image.open(path) + + if im.mode != "RGB": + im = im.convert("RGB") + imr = im.resize((256, 256), resample=Image.BILINEAR) + fh_im = io.BytesIO() + imr.save(fh_im, format="JPEG") + fh_im.seek(0) + + image = skimage.img_as_float32(skimage.io.imread(fh_im)) + final = prepare_image(image) + + input_feed = {session.get_inputs()[0].name: final} + outputs = [output.name for output in session.get_outputs()] + result = session.run(outputs, input_feed) + + return result[0][0][1] + + +async def init_model(): + file_name = "nsfw.onnx" + path = SETU_PATH / file_name + if not path.is_file(): + log.warning("插件 setu 缺少资源, 装载中...") + try: + data = await request.get(MODEL_URL) + with open(path, "wb") as w: + w.write(data.read()) + except Exception: + log.error("插件 setu 装载资源失败, 命令 '/nsfw' 将失效") + + log.success("插件 setu 装载资源完成") diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/tf_dealer.py deleted file mode 100644 index 4f4b374..0000000 --- a/ATRI/plugins/setu/tf_dealer.py +++ /dev/null @@ -1,109 +0,0 @@ -import io -import re -import skimage -import skimage.io -import numpy as np -from PIL import Image -from pathlib import Path -from sys import getsizeof - -import tensorflow as tf - -from ATRI.log import logger as log -from ATRI.utils import request -from ATRI.exceptions import RequestError, WriteFileError - - -SETU_PATH = Path(".") / "data" / "plugins" / "setu" -TEMP_PATH = Path(".") / "data" / "temp" -SETU_PATH.mkdir(parents=True, exist_ok=True) -TEMP_PATH.mkdir(parents=True, exist_ok=True) - - -MODULE_URL = "https://jsd.imki.moe/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite" -VGG_MEAN = [104, 117, 123] - - -def prepare_image(img): - H, W, _ = img.shape - h, w = (224, 224) - - h_off = max((H - h) // 2, 0) - w_off = max((W - w) // 2, 0) - image = img[h_off : h_off + h, w_off : w_off + w, :] - - image = image[:, :, ::-1] - - image = image.astype(np.float32, copy=False) - image = image * 255.0 - image = image - np.array(VGG_MEAN, dtype=np.float32) - - image = np.expand_dims(image, axis=0) - return image - - -async def detect_image(url: str, file_size: int) -> list: - try: - req = await request.get(url) - except Exception: - raise RequestError("Get info from download image failed!") - - img_byte = getsizeof(req.read()) // 1024 - if img_byte < file_size: - return [0, 0] - - try: - pattern = r"-(.*?)\/" - file_name = re.findall(pattern, url)[0] - path = TEMP_PATH / f"{file_name}.jpg" - with open(path, "wb") as f: - f.write(req.read()) - except Exception: - raise WriteFileError("Writing file failed!") - - await init_module() - model_path = str((SETU_PATH / "nsfw.tflite").absolute()) - - try: - interpreter = tf.Interpreter(model_path=model_path) # type: ignore - except Exception: - interpreter = tf.lite.Interpreter(model_path=model_path) - - interpreter.allocate_tensors() - - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - im = Image.open(path) - - if im.mode != "RGB": - im = im.convert("RGB") - imr = im.resize((256, 256), resample=Image.BILINEAR) - fh_im = io.BytesIO() - imr.save(fh_im, format="JPEG") - fh_im.seek(0) - - image = skimage.img_as_float32(skimage.io.imread(fh_im)) - - final = prepare_image(image) - interpreter.set_tensor(input_details[0]["index"], final) - - interpreter.invoke() - output_data = interpreter.get_tensor(output_details[0]["index"]) - - result = np.squeeze(output_data).tolist() - return result - - -async def init_module(): - file_name = "nsfw.tflite" - path = SETU_PATH / file_name - if not path.is_file(): - log.warning("缺少模型文件,装载中") - try: - data = await request.get(MODULE_URL) - with open(path, "wb") as w: - w.write(data.read()) - log.info("模型装载完成") - except Exception: - log.error("装载模型失败") -- cgit v1.2.3