summaryrefslogtreecommitdiff
path: root/ATRI/plugins/setu/tf_dealer.py
blob: 58b43371b1f0d9a567d8da7a6dbb4909daf48149 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import io
import os
import re
import asyncio
import skimage
import skimage.io
import numpy as np
from PIL import Image
from pathlib import Path
from sys import getsizeof

import tensorflow as tf

from ATRI.log import logger as log
from ATRI.utils import request
from ATRI.exceptions import RequestError, WriteFileError


# Working directories, resolved relative to the current working directory:
# SETU_PATH holds the downloaded NSFW model, TEMP_PATH the images fetched for
# classification. Both are created eagerly at import time.
SETU_PATH = Path("data", "database", "setu")
TEMP_PATH = Path("data", "temp")
SETU_PATH.mkdir(parents=True, exist_ok=True)
TEMP_PATH.mkdir(parents=True, exist_ok=True)


# Source that init_module() downloads the TFLite model from when it is missing.
MODULE_URL = "https://cdn.jsdelivr.net/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite"
# Per-channel mean subtracted in prepare_image() after the channel axis is
# reversed (presumably Caffe-style BGR VGG means — confirm against the model).
VGG_MEAN = [104, 117, 123]


def prepare_image(img):
    """Build the batched model input from a single image array.

    Steps: center-crop to 224x224 (offsets are clamped at 0, so a side shorter
    than 224 is kept whole), reverse the channel axis (RGB -> BGR when given
    RGB input), cast to float32, multiply by 255 and subtract ``VGG_MEAN`` per
    channel, then prepend a batch axis of length 1.

    Args:
        img: array indexed as (height, width, channels). The caller in this
            file passes the output of ``skimage.img_as_float32`` on an RGB
            image, so values are presumably in [0, 1] — confirm for other callers.

    Returns:
        float32 array of shape (1, min(H, 224), min(W, 224), C).
    """
    crop_h = crop_w = 224
    src_h, src_w, _ = img.shape

    top = max((src_h - crop_h) // 2, 0)
    left = max((src_w - crop_w) // 2, 0)
    window = img[top : top + crop_h, left : left + crop_w, :]

    # Flip channel order, then undo the [0, 1] scaling before mean-centering.
    bgr = window[:, :, ::-1].astype(np.float32, copy=False)
    centered = bgr * 255.0 - np.array(VGG_MEAN, dtype=np.float32)

    return centered[np.newaxis, ...]


async def detect_image(url: str, file_size: int) -> list:
    """Download an image and classify it with the NSFW TFLite model.

    The image is fetched from ``url``, saved under ``TEMP_PATH``, resized to
    256x256 and passed through ``prepare_image`` before inference.

    Args:
        url: Image URL. The saved file is named after the text captured by
            the first match of the regex ``-(.*?)/`` in the URL.
        file_size: Minimum payload size in KiB; smaller downloads are not
            classified.

    Returns:
        The model output squeezed and converted with ``tolist()``, or
        ``[0, 0]`` when the payload is smaller than ``file_size`` KiB
        (presumably two class scores — confirm against the model).

    Raises:
        RequestError: downloading the image failed.
        WriteFileError: the image could not be written to ``TEMP_PATH``.
        IndexError: ``url`` has no segment matching ``-(.*?)/``.
    """
    try:
        req = await request.get(url)
    except RequestError as err:
        raise RequestError("Get info from download image failed!") from err

    # Read the body once and reuse it for the size check and the file write.
    data = req.read()
    # Measure the payload itself; sys.getsizeof would also count the bytes
    # object's own header.
    if len(data) // 1024 < file_size:
        return [0, 0]

    file_name = re.findall(r"-(.*?)\/", url)[0]
    path = TEMP_PATH / f"{file_name}.jpg"
    try:
        with open(path, "wb") as f:
            f.write(data)
    except OSError as err:
        # open()/write() report failures as OSError, never WriteFileError,
        # so translate into the project's exception here.
        raise WriteFileError("Writing file failed!") from err

    await init_module()
    interpreter = _load_interpreter(str((SETU_PATH / "nsfw.tflite").absolute()))
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    final = prepare_image(_load_image_array(path))
    interpreter.set_tensor(input_details[0]["index"], final)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]["index"])

    return np.squeeze(output_data).tolist()


def _load_interpreter(model_path: str):
    """Create a TFLite interpreter for ``model_path`` with tensors allocated."""
    try:
        # tflite_runtime-style API: Interpreter exposed directly on ``tf``.
        interpreter = tf.Interpreter(model_path=model_path)  # type: ignore
    except AttributeError:
        # Full TensorFlow only provides the interpreter under ``tf.lite``.
        # Other errors (e.g. an unreadable model) propagate unmasked.
        interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return interpreter


def _load_image_array(path: Path) -> np.ndarray:
    """Load ``path`` as an RGB, 256x256, float32 array (``skimage.img_as_float32``).

    The resized image is round-tripped through an in-memory JPEG before
    decoding, as the original pipeline did. The source file handle is closed
    once the resized copy has been produced.
    """
    with Image.open(path) as im:
        if im.mode != "RGB":
            im = im.convert("RGB")
        resized = im.resize((256, 256), resample=Image.BILINEAR)
    buffer = io.BytesIO()
    resized.save(buffer, format="JPEG")
    buffer.seek(0)
    return skimage.img_as_float32(skimage.io.imread(buffer))


async def init_module():
    """Download the NSFW TFLite model into ``SETU_PATH`` if it is not present.

    The payload is written to ``nsfw.tflite.part`` first and then moved over
    ``nsfw.tflite``. A failed or interrupted write therefore never leaves a
    truncated model behind that the ``is_file()`` check would accept on the
    next call.

    Raises:
        WriteFileError: the model could not be written to disk.
    """
    file_name = "nsfw.tflite"
    path = SETU_PATH / file_name
    if path.is_file():
        return

    log.warning("缺少模型文件,装载中")
    data = await request.get(MODULE_URL)
    # NOTE(review): the HTTP status is not checked here; confirm whether
    # request.get raises on error responses, or an error page could be saved.
    part_path = path.with_name(file_name + ".part")
    try:
        with open(part_path, "wb") as w:
            w.write(data.read())
        # os.replace overwrites any stale target; the rename is atomic on POSIX
        # when both paths are on the same filesystem.
        os.replace(part_path, path)
    except OSError as err:
        raise WriteFileError("NSFW TF module init failed!") from err
    log.info("模型装载完成")


def _report_init_failure(task: "asyncio.Task") -> None:
    """Log the exception raised by the background model-download task, if any."""
    if not task.cancelled() and task.exception() is not None:
        log.error(f"NSFW TF module init failed: {task.exception()!r}")


# NOTE(review): asyncio.get_event_loop() without a running loop is deprecated
# in recent Python versions; confirm behaviour on the target interpreter.
loop = asyncio.get_event_loop()
# The event loop keeps only weak references to tasks, so hold a strong one to
# stop the background download from being garbage-collected mid-flight.
_init_task = loop.create_task(init_module())
# A task that stays referenced never triggers asyncio's "exception was never
# retrieved" warning, so surface failures explicitly.
_init_task.add_done_callback(_report_init_failure)