summaryrefslogtreecommitdiff
path: root/ATRI/plugins/setu/nsfw_checker.py
blob: d7ad6e1fdc6ce54f962c27e98d3133810f99e147 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import io
import re
from functools import lru_cache
from pathlib import Path
from sys import getsizeof

import numpy as np
import onnxruntime
import skimage
from PIL import Image

from ATRI.log import log
from ATRI.utils import request
from ATRI.exceptions import RequestError, WriteFileError


# Working directories for this plugin, relative to the process's current
# directory. Both are created at import time so later code can write into
# them without checking first.
SETU_PATH = Path("data", "plugins", "setu")  # holds the downloaded model
TEMP_PATH = Path("data", "temp")  # scratch space for downloaded images
SETU_PATH.mkdir(parents=True, exist_ok=True)
TEMP_PATH.mkdir(parents=True, exist_ok=True)


# Source of the ONNX model that init_model() downloads into SETU_PATH.
MODEL_URL = "https://res.imki.moe/nsfw.onnx"
# Per-channel mean subtracted by prepare_image(); presumably in BGR order
# (it is applied after the channel axis is reversed) -- TODO confirm.
VGG_MEAN = [104, 117, 123]


def prepare_image(img):
    """Convert a decoded image array into a one-item batch for the model.

    Takes the centred 224x224 window of ``img``. An input smaller than the
    window is not padded, so the result is smaller in that dimension. It then
    reverses the channel axis, scales values by 255 and subtracts
    ``VGG_MEAN`` channel-wise.

    Args:
        img: Array of shape (H, W, C). Unpacking ``img.shape`` requires rank 3,
            and C must broadcast against the three ``VGG_MEAN`` values.
            Values are assumed to lie in [0, 1] (``detect_image`` passes
            ``skimage.img_as_float32`` output) -- TODO confirm for other callers.

    Returns:
        A float32 array of shape (1, h, w, C) with ``h <= 224`` and ``w <= 224``.
    """
    crop_h, crop_w = 224, 224
    src_h, src_w, _ = img.shape

    # Top-left corner of the centred crop, clamped so it never goes negative.
    top = max(0, (src_h - crop_h) // 2)
    left = max(0, (src_w - crop_w) // 2)
    window = img[top : top + crop_h, left : left + crop_w, :]

    # Reverse the channel order (presumably RGB -> BGR to match VGG_MEAN).
    flipped = window[:, :, ::-1].astype(np.float32, copy=False)
    centred = flipped * 255.0 - np.array(VGG_MEAN, dtype=np.float32)

    # Prepend the batch axis.
    return centred[np.newaxis, ...]


@lru_cache(maxsize=1)
def _load_session(model_path: str):
    """Create the ONNX inference session for ``model_path`` once and reuse it.

    Building a session loads the whole model, so it is cached rather than
    rebuilt on every ``detect_image`` call. ``lru_cache`` does not cache
    exceptions, so if the model file is missing the next call retries.
    """
    return onnxruntime.InferenceSession(model_path)


async def detect_image(url: str, max_size: int, disab_gif: bool) -> float:
    """Download the image at ``url`` and return the model's score for it.

    Args:
        url: Image URL. It must contain a ``-<name>/`` segment; ``<name>`` is
            used as the temporary file name under ``TEMP_PATH``.
        max_size: Threshold in KiB (body length // 1024). Bodies smaller than
            this are not scored and ``0`` is returned.
        disab_gif: When true, a response whose ``Content-Type`` is exactly
            ``image/gif`` is not scored and ``0`` is returned.

    Returns:
        ``result[0][0][1]`` of the first model output (presumably the NSFW
        probability -- TODO confirm against the model), or ``0`` when the
        image was skipped.

    Raises:
        RequestError: The download failed.
        WriteFileError: No file name could be extracted from ``url``, or the
            temporary file could not be written.
    """
    try:
        req = await request.get(url)
    except Exception as err:
        raise RequestError("Get info from download image failed!") from err

    if disab_gif and req.headers.get("Content-Type") == "image/gif":
        return 0

    # Read the body once. len() is the payload size; sys.getsizeof would
    # also count the bytes object's interpreter overhead.
    content = req.read()
    if len(content) // 1024 < max_size:
        return 0

    try:
        file_name = re.findall(r"-(.*?)\/", url)[0]
        path = TEMP_PATH / f"{file_name}.jpg"
        with open(path, "wb") as f:
            f.write(content)
    except Exception as err:
        raise WriteFileError("Writing file failed!") from err

    session = _load_session(str(SETU_PATH / "nsfw.onnx"))

    # The context manager guarantees the file handle is released, including
    # for multi-frame formats that Pillow does not auto-close after loading.
    with Image.open(path) as src:
        rgb = src if src.mode == "RGB" else src.convert("RGB")
        resized = rgb.resize((256, 256), resample=Image.BILINEAR)

    # Round-trip through an in-memory JPEG before decoding with skimage,
    # presumably so the input matches the preprocessing the model was trained
    # with -- TODO confirm.
    jpeg_buf = io.BytesIO()
    resized.save(jpeg_buf, format="JPEG")
    jpeg_buf.seek(0)

    image = skimage.img_as_float32(skimage.io.imread(jpeg_buf))
    final = prepare_image(image)

    input_feed = {session.get_inputs()[0].name: final}
    outputs = [output.name for output in session.get_outputs()]
    result = session.run(outputs, input_feed)

    return result[0][0][1]


async def init_model():
    """Make sure the ONNX model exists at ``SETU_PATH / "nsfw.onnx"``.

    If the file is missing, it is downloaded from ``MODEL_URL``. If the
    download fails, an error is logged (per the log message, the '/nsfw'
    command will not work) and the function returns without logging success.
    Any partially written file is removed so the download is retried on the
    next call. This function does not raise on download failure.
    """
    file_name = "nsfw.onnx"
    path = SETU_PATH / file_name
    if not path.is_file():
        log.warning("插件 setu 缺少资源, 装载中...")
        try:
            data = await request.get(MODEL_URL)
            # Read the whole body before opening the file, so a failed
            # transfer cannot leave an empty model file behind.
            content = data.read()
            with open(path, "wb") as w:
                w.write(content)
        except Exception:
            # A truncated file would pass the is_file() check above forever
            # and block any retry, so remove it (best-effort).
            try:
                path.unlink(missing_ok=True)
            except OSError:
                pass
            log.error("插件 setu 装载资源失败, 命令 '/nsfw' 将失效")
            return

    log.success("插件 setu 装载资源完成")