1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
import io
import re
import skimage
import onnxruntime
import numpy as np
from PIL import Image
from pathlib import Path
from sys import getsizeof
from ATRI.log import log
from ATRI.utils import request
from ATRI.exceptions import RequestError, WriteFileError
# Directory holding the plugin's persistent resources (the ONNX model lives here).
SETU_PATH = Path(".").joinpath("data", "plugins", "setu")
SETU_PATH.mkdir(parents=True, exist_ok=True)

# Scratch directory for images downloaded for analysis.
TEMP_PATH = Path(".").joinpath("data", "temp")
TEMP_PATH.mkdir(parents=True, exist_ok=True)

# Remote location the model file is fetched from when it is missing locally.
MODEL_URL = "https://res.imki.moe/nsfw.onnx"

# Per-channel values subtracted from the image in prepare_image(), after the
# channel axis has been reversed.
VGG_MEAN = [104, 117, 123]
def prepare_image(img):
    """Turn an H x W x C image array into a single-item batch for the model.

    Steps, in order:
      1. take a centred 224 x 224 window (offsets are clamped at 0, so an
         input smaller than the window is returned at its own size);
      2. reverse the order of the last (channel) axis;
      3. cast to float32 and multiply by 255;
      4. subtract ``VGG_MEAN`` channel-wise;
      5. prepend a batch axis of length 1.

    NOTE(review): the multiplication by 255 presumably expects pixel values
    in [0, 1] -- confirm against the caller.
    """
    crop_h, crop_w = 224, 224
    src_h, src_w, _channels = img.shape

    # Top-left corner of the centred window; never negative.
    top = max((src_h - crop_h) // 2, 0)
    left = max((src_w - crop_w) // 2, 0)

    # Crop and flip the channel order in one indexing operation.
    flipped = img[top : top + crop_h, left : left + crop_w, ::-1]

    scaled = flipped.astype(np.float32, copy=False) * 255.0
    centred = scaled - np.array(VGG_MEAN, dtype=np.float32)

    # Add the leading batch dimension.
    return centred[np.newaxis, ...]
async def detect_image(url: str, max_size: int, disab_gif: bool) -> float:
    """Download the image at *url* and score it with the local ONNX model.

    Args:
        url: Address of the image. It must contain a ``-<name>/`` segment;
            ``<name>`` becomes the temp file name.
        max_size: Size threshold in KiB. Images whose payload is *smaller*
            than this are not analysed and score ``0``.
            NOTE(review): despite the name this acts as a lower bound --
            confirm that this is the intended behaviour.
        disab_gif: When true, a response whose ``Content-Type`` header is
            exactly ``image/gif`` is skipped and scores ``0``.

    Returns:
        ``0`` for skipped images, otherwise element ``[0][0][1]`` of the
        model output (presumably the NSFW probability -- confirm against
        the model's output layout).

    Raises:
        RequestError: The download (or header inspection) failed.
        WriteFileError: No file name could be derived from *url*, or the
            temp file could not be written.
    """
    try:
        req = await request.get(url)
        if itype := req.headers.get("Content-Type"):
            if disab_gif and itype == "image/gif":
                return 0
    except Exception as err:
        # Chain the cause so the underlying network error stays visible.
        raise RequestError("Get info from download image failed!") from err

    # Read the body once and reuse it for both the size check and the write.
    content = req.read()

    # len() is the payload size; sys.getsizeof() would also count the bytes
    # object's own header and therefore overstate the image size.
    if len(content) // 1024 < max_size:
        return 0

    try:
        pattern = r"-(.*?)\/"
        file_name = re.findall(pattern, url)[0]
        path = TEMP_PATH / f"{file_name}.jpg"
        with open(path, "wb") as f:
            f.write(content)
    except Exception as err:
        raise WriteFileError("Writing file failed!") from err

    model_path = str(SETU_PATH / "nsfw.onnx")
    session = onnxruntime.InferenceSession(model_path)

    # Open through a context manager so the underlying file handle is
    # released as soon as the resized copy has been produced.
    with Image.open(path) as im:
        rgb = im if im.mode == "RGB" else im.convert("RGB")
        imr = rgb.resize((256, 256), resample=Image.BILINEAR)

    # Round-trip through an in-memory JPEG before handing the pixels to
    # skimage, exactly as the original pipeline did.
    with io.BytesIO() as fh_im:
        imr.save(fh_im, format="JPEG")
        fh_im.seek(0)
        image = skimage.img_as_float32(skimage.io.imread(fh_im))

    final = prepare_image(image)

    input_feed = {session.get_inputs()[0].name: final}
    outputs = [output.name for output in session.get_outputs()]
    result = session.run(outputs, input_feed)

    return result[0][0][1]
async def init_model():
    """Make sure the ONNX model file exists locally, downloading it if needed.

    If ``SETU_PATH / "nsfw.onnx"`` is missing, the file is fetched from
    ``MODEL_URL``. Failures are logged rather than raised, so start-up
    continues without the model. The success message is only logged when
    the model file is actually available.
    """
    file_name = "nsfw.onnx"
    path = SETU_PATH / file_name

    if not path.is_file():
        log.warning("插件 setu 缺少资源, 装载中...")
        try:
            data = await request.get(MODEL_URL)
            # Pull the payload into memory before touching the disk, so a
            # failed download does not leave an empty file behind that
            # would pass the is_file() check on the next start.
            payload = data.read()
            with open(path, "wb") as w:
                w.write(payload)
        except Exception:
            log.error("插件 setu 装载资源失败, 命令 '/nsfw' 将失效")
            # Do not fall through to the success message after a failure.
            return

    log.success("插件 setu 装载资源完成")
|