diff options
author | Kyomotoi <0w0@imki.moe> | 2022-09-18 15:25:50 +0800 |
---|---|---|
committer | Kyomotoi <0w0@imki.moe> | 2022-09-18 15:25:50 +0800 |
commit | 4006dce2581e7d6597958180021974a6de01294e (patch) | |
tree | dbb1ad82008dd456b072c4cda93e9cf044490145 /ATRI/plugins/setu/tf_dealer.py | |
parent | fdb70b4d9ec99e1d2f372a2d9c11e9da344d1b6d (diff) | |
download | ATRI-4006dce2581e7d6597958180021974a6de01294e.tar.gz ATRI-4006dce2581e7d6597958180021974a6de01294e.tar.bz2 ATRI-4006dce2581e7d6597958180021974a6de01294e.zip |
⚡️ 更换涩图检测模型及其 runtime
Diffstat (limited to 'ATRI/plugins/setu/tf_dealer.py')
-rw-r--r-- | ATRI/plugins/setu/tf_dealer.py | 109 |
1 files changed, 0 insertions, 109 deletions
diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/tf_dealer.py deleted file mode 100644 index 4f4b374..0000000 --- a/ATRI/plugins/setu/tf_dealer.py +++ /dev/null @@ -1,109 +0,0 @@ -import io -import re -import skimage -import skimage.io -import numpy as np -from PIL import Image -from pathlib import Path -from sys import getsizeof - -import tensorflow as tf - -from ATRI.log import logger as log -from ATRI.utils import request -from ATRI.exceptions import RequestError, WriteFileError - - -SETU_PATH = Path(".") / "data" / "plugins" / "setu" -TEMP_PATH = Path(".") / "data" / "temp" -SETU_PATH.mkdir(parents=True, exist_ok=True) -TEMP_PATH.mkdir(parents=True, exist_ok=True) - - -MODULE_URL = "https://jsd.imki.moe/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite" -VGG_MEAN = [104, 117, 123] - - -def prepare_image(img): - H, W, _ = img.shape - h, w = (224, 224) - - h_off = max((H - h) // 2, 0) - w_off = max((W - w) // 2, 0) - image = img[h_off : h_off + h, w_off : w_off + w, :] - - image = image[:, :, ::-1] - - image = image.astype(np.float32, copy=False) - image = image * 255.0 - image = image - np.array(VGG_MEAN, dtype=np.float32) - - image = np.expand_dims(image, axis=0) - return image - - -async def detect_image(url: str, file_size: int) -> list: - try: - req = await request.get(url) - except Exception: - raise RequestError("Get info from download image failed!") - - img_byte = getsizeof(req.read()) // 1024 - if img_byte < file_size: - return [0, 0] - - try: - pattern = r"-(.*?)\/" - file_name = re.findall(pattern, url)[0] - path = TEMP_PATH / f"{file_name}.jpg" - with open(path, "wb") as f: - f.write(req.read()) - except Exception: - raise WriteFileError("Writing file failed!") - - await init_module() - model_path = str((SETU_PATH / "nsfw.tflite").absolute()) - - try: - interpreter = tf.Interpreter(model_path=model_path) # type: ignore - except Exception: - interpreter = tf.lite.Interpreter(model_path=model_path) - - interpreter.allocate_tensors() - - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - im = Image.open(path) - - if im.mode != "RGB": - im = im.convert("RGB") - imr = im.resize((256, 256), resample=Image.BILINEAR) - fh_im = io.BytesIO() - imr.save(fh_im, format="JPEG") - fh_im.seek(0) - - image = skimage.img_as_float32(skimage.io.imread(fh_im)) - - final = prepare_image(image) - interpreter.set_tensor(input_details[0]["index"], final) - - interpreter.invoke() - output_data = interpreter.get_tensor(output_details[0]["index"]) - - result = np.squeeze(output_data).tolist() - return result - - -async def init_module(): - file_name = "nsfw.tflite" - path = SETU_PATH / file_name - if not path.is_file(): - log.warning("缺少模型文件,装载中") - try: - data = await request.get(MODULE_URL) - with open(path, "wb") as w: - w.write(data.read()) - log.info("模型装载完成") - except Exception: - log.error("装载模型失败") |