"""NSFW image detection backed by a TFLite model that is downloaded on demand."""

import io
import re
import skimage
import skimage.io
import numpy as np
from PIL import Image
from pathlib import Path
from sys import getsizeof  # noqa: F401  (unused here; kept in case other code imports it from this module)

import tensorflow as tf

from ATRI.log import logger as log
from ATRI.utils import request
from ATRI.exceptions import RequestError, WriteFileError


SETU_PATH = Path(".") / "data" / "database" / "setu"
TEMP_PATH = Path(".") / "data" / "temp"
SETU_PATH.mkdir(parents=True, exist_ok=True)
TEMP_PATH.mkdir(parents=True, exist_ok=True)

MODULE_URL = (
    "https://fastly.jsdelivr.net/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite"
)

VGG_MEAN = [104, 117, 123]

# File name of the model inside SETU_PATH (shared by detect_image and init_module).
_MODEL_FILE = "nsfw.tflite"
# Temp file name = text between the first "-" and the following "/" of the URL.
_FILE_NAME_PATTERN = re.compile(r"-(.*?)\/")


def prepare_image(img: np.ndarray) -> np.ndarray:
    """Center-crop, channel-reverse and mean-shift an image into a model batch.

    Args:
        img: 3-D array indexed as ``img[row, col, channel]``. ``detect_image``
            passes the output of ``skimage.img_as_float32``, i.e. float values
            in [0, 1].

    Returns:
        float32 array of shape ``(1, h, w, 3)``. It holds a center crop of at
        most 224x224 (smaller when the input is smaller), with the channel
        order reversed, values multiplied by 255 and ``VGG_MEAN`` subtracted
        per channel.
    """
    H, W, _ = img.shape
    h, w = (224, 224)
    # Center-crop offsets; clamped to 0 when the input is smaller than the crop.
    h_off = max((H - h) // 2, 0)
    w_off = max((W - w) // 2, 0)
    image = img[h_off : h_off + h, w_off : w_off + w, :]
    # Reverse the channel axis (RGB -> BGR for an RGB input).
    image = image[:, :, ::-1]
    image = image.astype(np.float32, copy=False)
    # Back to the 0-255 scale before subtracting the per-channel means.
    image = image * 255.0
    image = image - np.array(VGG_MEAN, dtype=np.float32)
    return np.expand_dims(image, axis=0)


async def detect_image(url: str, file_size: int) -> list:
    """Download an image and score it with the NSFW TFLite model.

    Args:
        url: Image URL. The temp file name is derived from the text between
            the first "-" and the following "/" in the URL.
        file_size: Minimum payload size in KiB. Smaller downloads are not
            scored.

    Returns:
        ``[0, 0]`` when the download is smaller than ``file_size`` KiB.
        Otherwise, the squeezed model output converted with ``tolist()``.
        NOTE(review): presumably two class scores — confirm against the model.

    Raises:
        RequestError: The image download failed.
        WriteFileError: No file name could be derived from ``url``, the temp
            file could not be written, or the model could not be saved
            (raised by ``init_module``).
    """
    try:
        req = await request.get(url)
    except Exception as err:
        raise RequestError("Get info from download image failed!") from err

    content = req.read()
    # len() is the payload size; sys.getsizeof would add CPython object overhead.
    if len(content) // 1024 < file_size:
        return [0, 0]

    path = _save_temp_image(url, content)

    await init_module()
    interpreter = _load_interpreter(SETU_PATH / _MODEL_FILE)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    final = prepare_image(_load_image(path))
    interpreter.set_tensor(input_details[0]["index"], final)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]["index"])
    return np.squeeze(output_data).tolist()


def _save_temp_image(url: str, content: bytes) -> Path:
    """Write the downloaded bytes to ``TEMP_PATH/<name>.jpg`` and return that path."""
    match = _FILE_NAME_PATTERN.search(url)
    if match is None:
        raise WriteFileError("Writing file failed!")
    path = TEMP_PATH / f"{match.group(1)}.jpg"
    try:
        path.write_bytes(content)
    except (OSError, ValueError) as err:
        raise WriteFileError("Writing file failed!") from err
    return path


def _load_interpreter(model_path: Path):
    """Build a TFLite interpreter for ``model_path`` with its tensors allocated."""
    model = str(model_path.absolute())
    # Prefer a top-level ``tf.Interpreter`` when the installed package exposes
    # one; otherwise use ``tf.lite.Interpreter``. Probing with getattr (instead
    # of a blanket except) lets real model-loading errors surface.
    interpreter_cls = getattr(tf, "Interpreter", None) or tf.lite.Interpreter
    interpreter = interpreter_cls(model_path=model)
    interpreter.allocate_tensors()
    return interpreter


def _load_image(path: Path) -> np.ndarray:
    """Load ``path`` as RGB, resize it to 256x256, re-encode it as JPEG and return float32 pixels."""
    # Pillow >= 9.1 moved the filters to Image.Resampling; Pillow 10 removed
    # the old top-level aliases such as Image.BILINEAR.
    resample = getattr(Image, "Resampling", Image).BILINEAR
    with Image.open(path) as im:
        if im.mode != "RGB":
            im = im.convert("RGB")
        resized = im.resize((256, 256), resample=resample)
    buffer = io.BytesIO()
    # Round-trip through JPEG before decoding with skimage.
    # NOTE(review): presumably mirrors the preprocessing used to train the
    # model — confirm before removing.
    resized.save(buffer, format="JPEG")
    buffer.seek(0)
    return skimage.img_as_float32(skimage.io.imread(buffer))


async def init_module():
    """Download the NSFW model into ``SETU_PATH`` unless it is already present.

    Raises:
        WriteFileError: The downloaded model could not be written to disk.
    """
    path = SETU_PATH / _MODEL_FILE
    if path.is_file():
        return

    log.warning("缺少模型文件,装载中")
    data = await request.get(MODULE_URL)
    # Write to a sibling temp file, then atomically rename it. A failed write
    # therefore never leaves a truncated model behind for the is_file() check
    # to accept next time.
    tmp_path = path.with_name(path.name + ".part")
    try:
        tmp_path.write_bytes(data.read())
        tmp_path.replace(path)
    except OSError as err:
        try:
            tmp_path.unlink()
        except OSError:
            pass  # best-effort cleanup; the original write error is what matters
        raise WriteFileError("NSFW TF module init failed!") from err
    log.info("模型装载完成")