summaryrefslogtreecommitdiff
path: root/ATRI/plugins/setu/tf_dealer.py
diff options
context:
space:
mode:
Diffstat (limited to 'ATRI/plugins/setu/tf_dealer.py')
-rw-r--r--ATRI/plugins/setu/tf_dealer.py119
1 files changed, 119 insertions, 0 deletions
diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/tf_dealer.py
new file mode 100644
index 0000000..f41de49
--- /dev/null
+++ b/ATRI/plugins/setu/tf_dealer.py
@@ -0,0 +1,119 @@
+import io
+import os
+import re
+import string
+import asyncio
+import skimage
+import skimage.io
+import numpy as np
+from PIL import Image
+from pathlib import Path
+from sys import getsizeof
+from random import sample
+
+try:
+ import tflite_runtime.interpreter as tf # type: ignore
+except Exception:
+ import tensorflow as tf
+
+from ATRI.log import logger as log
+from ATRI.utils import request
+from ATRI.exceptions import RequestError, WriteError
+
+
+SETU_PATH = Path(".") / "data" / "database" / "setu"
+TEMP_PATH = Path(".") / "data" / "temp"
+os.makedirs(SETU_PATH, exist_ok=True)
+os.makedirs(TEMP_PATH, exist_ok=True)
+
+
+MODULE_URL = "https://cdn.jsdelivr.net/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite"
+VGG_MEAN = [104, 117, 123]
+
+
+def prepare_image(img):
+ H, W, _ = img.shape
+ h, w = (224, 224)
+
+ h_off = max((H - h) // 2, 0)
+ w_off = max((W - w) // 2, 0)
+ image = img[h_off : h_off + h, w_off : w_off + w, :]
+
+ image = image[:, :, ::-1]
+
+ image = image.astype(np.float32, copy=False)
+ image = image * 255.0
+ image = image - np.array(VGG_MEAN, dtype=np.float32)
+
+ image = np.expand_dims(image, axis=0)
+ return image
+
+
+async def detect_image(url) -> list:
+ try:
+ req = await request.get(url)
+ except RequestError:
+ raise RequestError("Get info from download image failed!")
+
+ img_byte = getsizeof(req.read()) // 1024
+ if img_byte < 256:
+ return [0, 0]
+
+ try:
+ pattern = r"-(.*?)\/"
+ file_name = re.findall(pattern, url)[0]
+ path = TEMP_PATH / f"{file_name}.jpg"
+ with open(path, "wb") as f:
+ f.write(req.read())
+ except WriteError:
+ raise WriteError("Writing file failed!")
+
+ await init_module()
+ model_path = str((SETU_PATH / "nsfw.tflite").absolute())
+
+ try:
+ interpreter = tf.Interpreter(model_path=model_path) # type: ignore
+ except Exception:
+ interpreter = tf.lite.Interpreter(model_path=model_path)
+
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ im = Image.open(path)
+
+ if im.mode != "RGB":
+ im = im.convert("RGB")
+ imr = im.resize((256, 256), resample=Image.BILINEAR)
+ fh_im = io.BytesIO()
+ imr.save(fh_im, format="JPEG")
+ fh_im.seek(0)
+
+ image = skimage.img_as_float32(skimage.io.imread(fh_im))
+
+ final = prepare_image(image)
+ interpreter.set_tensor(input_details[0]["index"], final)
+
+ interpreter.invoke()
+ output_data = interpreter.get_tensor(output_details[0]["index"])
+
+ result = np.squeeze(output_data).tolist()
+ return result
+
+
+async def init_module():
+ file_name = "nsfw.tflite"
+ path = SETU_PATH / file_name
+ if not path.is_file():
+ log.warning("缺少模型文件,装载中")
+ data = await request.get(MODULE_URL)
+ try:
+ with open(path, "wb") as w:
+ w.write(data.read())
+ log.info("模型装载完成")
+ except WriteError:
+ raise WriteError("NSFW TF module init failed!")
+
+
+asyncio.get_event_loop().run_until_complete(init_module())