From 6c02470858bf8e544d8655c039e741161c36ca54 Mon Sep 17 00:00:00 2001 From: Kyomotoi Date: Sat, 18 Jun 2022 00:24:53 +0800 Subject: =?UTF-8?q?=F0=9F=90=9B=20=E8=A7=A3=E5=86=B3=E6=97=A0=E6=B3=95?= =?UTF-8?q?=E6=AD=A3=E5=B8=B8=E7=BB=88=E6=AD=A2=E7=A8=8B=E5=BA=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ATRI/plugins/setu/data_source.py | 7 ++++++- ATRI/plugins/setu/tf_dealer.py | 9 ++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py index e7d769d..61b1cf8 100644 --- a/ATRI/plugins/setu/data_source.py +++ b/ATRI/plugins/setu/data_source.py @@ -5,7 +5,7 @@ from ATRI.service import Service from ATRI.rule import is_in_service from ATRI.utils import request from ATRI.config import Setu as ST -from .tf_dealer import detect_image +from .tf_dealer import detect_image, init_module LOLICON_URL = "https://api.lolicon.app/setu/v2" @@ -81,3 +81,8 @@ class Setu(Service): async def async_recall(bot: Bot, event_id): await asyncio.sleep(30) await bot.delete_msg(message_id=event_id) + + +from ATRI import driver + +driver().on_startup(init_module) diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/tf_dealer.py index 4e1eafa..466ab1f 100644 --- a/ATRI/plugins/setu/tf_dealer.py +++ b/ATRI/plugins/setu/tf_dealer.py @@ -1,6 +1,5 @@ import io import re -import asyncio import skimage import skimage.io import numpy as np @@ -48,7 +47,7 @@ def prepare_image(img): async def detect_image(url: str, file_size: int) -> list: try: req = await request.get(url) - except RequestError: + except Exception: raise RequestError("Get info from download image failed!") img_byte = getsizeof(req.read()) // 1024 @@ -61,7 +60,7 @@ async def detect_image(url: str, file_size: int) -> list: path = TEMP_PATH / f"{file_name}.jpg" with open(path, "wb") as f: f.write(req.read()) - except WriteFileError: + except Exception: raise WriteFileError("Writing file failed!") await init_module() @@ -110,7 +109,3 @@ async def init_module(): log.info("模型装载完成") except WriteFileError: raise WriteFileError("NSFW TF module init failed!") - - -loop = asyncio.get_event_loop() -loop.create_task(init_module()) -- cgit v1.2.3