summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ATRI/plugins/setu/__init__.py7
-rw-r--r--ATRI/plugins/setu/data_source.py6
-rw-r--r--ATRI/plugins/setu/nsfw_checker.py (renamed from ATRI/plugins/setu/tf_dealer.py)47
3 files changed, 23 insertions, 37 deletions
diff --git a/ATRI/plugins/setu/__init__.py b/ATRI/plugins/setu/__init__.py
index 71243eb..7962af8 100644
--- a/ATRI/plugins/setu/__init__.py
+++ b/ATRI/plugins/setu/__init__.py
@@ -99,8 +99,8 @@ async def _setu_catcher(bot: Bot, event: MessageEvent):
data = await Setu().detecter(i, _catcher_max_file_size)
except Exception:
return
- if data[1] > 0.7:
- hso.append(data[1])
+ if data > 0.7:
+ hso.append(data)
hso.sort(reverse=True)
@@ -135,8 +135,7 @@ async def _deal_check(bot: Bot, event: MessageEvent):
if not args:
await nsfw_checker.reject("请发送图片而不是其他东西!!")
- data = await Setu().detecter(args[0], _catcher_max_file_size)
- hso = data[1]
+ hso = await Setu().detecter(args[0], _catcher_max_file_size)
if not hso:
await nsfw_checker.finish("图太小了!不测!")
diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py
index 61b1cf8..71649fe 100644
--- a/ATRI/plugins/setu/data_source.py
+++ b/ATRI/plugins/setu/data_source.py
@@ -5,7 +5,7 @@ from ATRI.service import Service
from ATRI.rule import is_in_service
from ATRI.utils import request
from ATRI.config import Setu as ST
-from .tf_dealer import detect_image, init_module
+from .nsfw_checker import detect_image, init_model
LOLICON_URL = "https://api.lolicon.app/setu/v2"
@@ -70,7 +70,7 @@ class Setu(Service):
return repo, setu
@staticmethod
- async def detecter(url: str, file_size: int) -> list:
+ async def detecter(url: str, file_size: int) -> float:
"""
涩值检测.
"""
@@ -85,4 +85,4 @@ class Setu(Service):
from ATRI import driver
-driver().on_startup(init_module)
+driver().on_startup(init_model)
diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/nsfw_checker.py
index 4f4b374..2f045c1 100644
--- a/ATRI/plugins/setu/tf_dealer.py
+++ b/ATRI/plugins/setu/nsfw_checker.py
@@ -1,14 +1,12 @@
import io
import re
import skimage
-import skimage.io
+import onnxruntime
import numpy as np
from PIL import Image
from pathlib import Path
from sys import getsizeof
-import tensorflow as tf
-
from ATRI.log import logger as log
from ATRI.utils import request
from ATRI.exceptions import RequestError, WriteFileError
@@ -20,7 +18,7 @@ SETU_PATH.mkdir(parents=True, exist_ok=True)
TEMP_PATH.mkdir(parents=True, exist_ok=True)
-MODULE_URL = "https://jsd.imki.moe/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite"
+MODEL_URL = "https://res.imki.moe/nsfw.onnx"
VGG_MEAN = [104, 117, 123]
@@ -42,7 +40,7 @@ def prepare_image(img):
return image
-async def detect_image(url: str, file_size: int) -> list:
+async def detect_image(url: str, file_size: int) -> float:
try:
req = await request.get(url)
except Exception:
@@ -50,7 +48,7 @@ async def detect_image(url: str, file_size: int) -> list:
img_byte = getsizeof(req.read()) // 1024
if img_byte < file_size:
- return [0, 0]
+ return 0
try:
pattern = r"-(.*?)\/"
@@ -61,18 +59,8 @@ async def detect_image(url: str, file_size: int) -> list:
except Exception:
raise WriteFileError("Writing file failed!")
- await init_module()
- model_path = str((SETU_PATH / "nsfw.tflite").absolute())
-
- try:
- interpreter = tf.Interpreter(model_path=model_path) # type: ignore
- except Exception:
- interpreter = tf.lite.Interpreter(model_path=model_path)
-
- interpreter.allocate_tensors()
-
- input_details = interpreter.get_input_details()
- output_details = interpreter.get_output_details()
+ model_path = str(SETU_PATH / "nsfw.onnx")
+ session = onnxruntime.InferenceSession(model_path)
im = Image.open(path)
@@ -84,26 +72,25 @@ async def detect_image(url: str, file_size: int) -> list:
fh_im.seek(0)
image = skimage.img_as_float32(skimage.io.imread(fh_im))
-
final = prepare_image(image)
- interpreter.set_tensor(input_details[0]["index"], final)
- interpreter.invoke()
- output_data = interpreter.get_tensor(output_details[0]["index"])
+ input_feed = {session.get_inputs()[0].name: final}
+ outputs = [output.name for output in session.get_outputs()]
+ result = session.run(outputs, input_feed)
- result = np.squeeze(output_data).tolist()
- return result
+ return result[0][0][1]
-async def init_module():
- file_name = "nsfw.tflite"
+async def init_model():
+ file_name = "nsfw.onnx"
path = SETU_PATH / file_name
if not path.is_file():
- log.warning("缺少模型文件,装载中")
+ log.warning("插件 setu 缺少资源, 装载中...")
try:
- data = await request.get(MODULE_URL)
+ data = await request.get(MODEL_URL)
with open(path, "wb") as w:
w.write(data.read())
- log.info("模型装载完成")
except Exception:
- log.error("装载模型失败")
+ log.error("插件 setu 装载资源失败, 命令 '/nsfw' 将失效")
+
+ log.success("插件 setu 装载资源完成")