summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKyomotoi <[email protected]>2022-06-18 00:24:53 +0800
committerKyomotoi <[email protected]>2022-06-18 00:24:53 +0800
commit6c02470858bf8e544d8655c039e741161c36ca54 (patch)
treedee0a9c8237c90241bc0cf686b93cea808f4980d
parentc1576c50f5acf1dd7f5d6730eae9d3d251fed954 (diff)
downloadATRI-6c02470858bf8e544d8655c039e741161c36ca54.tar.gz
ATRI-6c02470858bf8e544d8655c039e741161c36ca54.tar.bz2
ATRI-6c02470858bf8e544d8655c039e741161c36ca54.zip
🐛 解决无法正常终止程序
-rw-r--r--ATRI/plugins/setu/data_source.py7
-rw-r--r--ATRI/plugins/setu/tf_dealer.py9
2 files changed, 8 insertions, 8 deletions
diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py
index e7d769d..61b1cf8 100644
--- a/ATRI/plugins/setu/data_source.py
+++ b/ATRI/plugins/setu/data_source.py
@@ -5,7 +5,7 @@ from ATRI.service import Service
from ATRI.rule import is_in_service
from ATRI.utils import request
from ATRI.config import Setu as ST
-from .tf_dealer import detect_image
+from .tf_dealer import detect_image, init_module
LOLICON_URL = "https://api.lolicon.app/setu/v2"
@@ -81,3 +81,8 @@ class Setu(Service):
async def async_recall(bot: Bot, event_id):
await asyncio.sleep(30)
await bot.delete_msg(message_id=event_id)
+
+
+from ATRI import driver
+
+driver().on_startup(init_module)
diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/tf_dealer.py
index 4e1eafa..466ab1f 100644
--- a/ATRI/plugins/setu/tf_dealer.py
+++ b/ATRI/plugins/setu/tf_dealer.py
@@ -1,6 +1,5 @@
import io
import re
-import asyncio
import skimage
import skimage.io
import numpy as np
@@ -48,7 +47,7 @@ def prepare_image(img):
async def detect_image(url: str, file_size: int) -> list:
try:
req = await request.get(url)
- except RequestError:
+ except Exception:
raise RequestError("Get info from download image failed!")
img_byte = getsizeof(req.read()) // 1024
@@ -61,7 +60,7 @@ async def detect_image(url: str, file_size: int) -> list:
path = TEMP_PATH / f"{file_name}.jpg"
with open(path, "wb") as f:
f.write(req.read())
- except WriteFileError:
+ except Exception:
raise WriteFileError("Writing file failed!")
await init_module()
@@ -110,7 +109,3 @@ async def init_module():
log.info("模型装载完成")
except WriteFileError:
raise WriteFileError("NSFW TF module init failed!")
-
-
-loop = asyncio.get_event_loop()
-loop.create_task(init_module())