diff options
| -rw-r--r-- | ATRI/plugins/setu/data_source.py | 7 | ||||
| -rw-r--r-- | ATRI/plugins/setu/tf_dealer.py | 9 | 
2 files changed, 8 insertions, 8 deletions
| diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py index e7d769d..61b1cf8 100644 --- a/ATRI/plugins/setu/data_source.py +++ b/ATRI/plugins/setu/data_source.py @@ -5,7 +5,7 @@ from ATRI.service import Service  from ATRI.rule import is_in_service  from ATRI.utils import request  from ATRI.config import Setu as ST -from .tf_dealer import detect_image +from .tf_dealer import detect_image, init_module  LOLICON_URL = "https://api.lolicon.app/setu/v2" @@ -81,3 +81,8 @@ class Setu(Service):      async def async_recall(bot: Bot, event_id):          await asyncio.sleep(30)          await bot.delete_msg(message_id=event_id) + + +from ATRI import driver + +driver().on_startup(init_module) diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/tf_dealer.py index 4e1eafa..466ab1f 100644 --- a/ATRI/plugins/setu/tf_dealer.py +++ b/ATRI/plugins/setu/tf_dealer.py @@ -1,6 +1,5 @@  import io  import re -import asyncio  import skimage  import skimage.io  import numpy as np @@ -48,7 +47,7 @@ def prepare_image(img):  async def detect_image(url: str, file_size: int) -> list:      try:          req = await request.get(url) -    except RequestError: +    except Exception:          raise RequestError("Get info from download image failed!")      img_byte = getsizeof(req.read()) // 1024 @@ -61,7 +60,7 @@ async def detect_image(url: str, file_size: int) -> list:          path = TEMP_PATH / f"{file_name}.jpg"          with open(path, "wb") as f:              f.write(req.read()) -    except WriteFileError: +    except Exception:          raise WriteFileError("Writing file failed!")      await init_module() @@ -110,7 +109,3 @@ async def init_module():              log.info("模型装载完成")          except WriteFileError:              raise WriteFileError("NSFW TF module init failed!") - - -loop = asyncio.get_event_loop() -loop.create_task(init_module()) | 
