summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKyomotoi <[email protected]>2022-06-18 00:19:13 +0800
committerKyomotoi <[email protected]>2022-06-18 00:19:13 +0800
commit48ccaabf66828594051de919c9b52098debf95a3 (patch)
treeb93a614b37799adc852976fd23d1902017b81b53
parente842f5cc40ad50fb32f671cba405ba23ad34004f (diff)
downloadATRI-48ccaabf66828594051de919c9b52098debf95a3.tar.gz
ATRI-48ccaabf66828594051de919c9b52098debf95a3.tar.bz2
ATRI-48ccaabf66828594051de919c9b52098debf95a3.zip
🚚💩 移动数据库相关函数, 优化代码
-rw-r--r--ATRI/__init__.py2
-rw-r--r--ATRI/database/__init__.py2
-rw-r--r--ATRI/database/db.py12
-rw-r--r--ATRI/plugins/essential.py17
4 files changed, 21 insertions, 12 deletions
diff --git a/ATRI/__init__.py b/ATRI/__init__.py
index fd166fe..7f17778 100644
--- a/ATRI/__init__.py
+++ b/ATRI/__init__.py
@@ -4,7 +4,6 @@ import nonebot
from nonebot.adapters.onebot.v11 import Adapter
from .config import RUNTIME_CONFIG, InlineGoCQHTTP
-from .database import init_database
__version__ = "YHN-001-A05.fix1"
@@ -23,7 +22,6 @@ def init():
nonebot.load_plugins("ATRI/plugins")
if InlineGoCQHTTP.enabled:
nonebot.load_plugin("nonebot_plugin_gocqhttp")
- init_database()
sleep(3)
diff --git a/ATRI/database/__init__.py b/ATRI/database/__init__.py
index 6c58d20..497bea4 100644
--- a/ATRI/database/__init__.py
+++ b/ATRI/database/__init__.py
@@ -1,2 +1,2 @@
-from .db import init_database
+from .db import init_database, close_database_connection
from .models import BilibiliSubscription, TwitterSubscription
diff --git a/ATRI/database/db.py b/ATRI/database/db.py
index 6e20ad3..c8782de 100644
--- a/ATRI/database/db.py
+++ b/ATRI/database/db.py
@@ -1,5 +1,5 @@
from pathlib import Path
-from tortoise import Tortoise, run_async
+from tortoise import Tortoise
from ATRI.log import logger as log
@@ -21,7 +21,13 @@ async def run():
await Tortoise.generate_schemas()
-def init_database():
+async def init_database():
log.info("正在初始化数据库...")
- run_async(run())
+ await run()
log.success("数据库初始化完成")
+
+
+async def close_database_connection():
+ log.info("正在关闭数据库连接...")
+ await Tortoise.close_connections()
+ log.info("数据库成功关闭")
diff --git a/ATRI/plugins/essential.py b/ATRI/plugins/essential.py
index 763b137..b0c1d02 100644
--- a/ATRI/plugins/essential.py
+++ b/ATRI/plugins/essential.py
@@ -31,6 +31,7 @@ import ATRI
from ATRI.service import Service
from ATRI.log import logger as log
from ATRI.config import BotSelfConfig
+from ATRI.database import init_database, close_database_connection
from ATRI.utils import MessageChecker
from ATRI.utils.apscheduler import scheduler
from ATRI.utils.check_update import CheckUpdate
@@ -77,6 +78,10 @@ async def shutdown():
log.info("Thanks for using.")
+driver.on_startup(init_database)
+driver.on_shutdown(close_database_connection)
+
+
@run_preprocessor
async def _check_block(event: MessageEvent):
user_file = "block_user.json"
@@ -87,7 +92,7 @@ async def _check_block(event: MessageEvent):
try:
data = json.loads(path.read_bytes())
- except BaseException:
+ except Exception:
data = dict()
user_id = event.get_user_id()
@@ -103,7 +108,7 @@ async def _check_block(event: MessageEvent):
try:
data = json.loads(path.read_bytes())
- except BaseException:
+ except Exception:
data = dict()
group_id = str(event.group_id)
@@ -290,7 +295,7 @@ async def _recall_group_event(bot: Bot, event: GroupRecallNoticeEvent):
try:
repo = await bot.get_msg(message_id=event.message_id)
- except BaseException:
+ except Exception:
return
log.debug(f"Recall raw msg:\n{repo}")
@@ -300,7 +305,7 @@ async def _recall_group_event(bot: Bot, event: GroupRecallNoticeEvent):
try:
m = recall_msg_dealer(repo)
- except:
+ except Exception:
check = MessageChecker(repo).check_cq_code
if not check:
m = repo
@@ -322,7 +327,7 @@ async def _recall_private_event(bot: Bot, event: FriendRecallNoticeEvent):
try:
repo = await bot.get_msg(message_id=event.message_id)
- except BaseException:
+ except Exception:
return
log.debug(f"Recall raw msg:\n{repo}")
@@ -331,7 +336,7 @@ async def _recall_private_event(bot: Bot, event: FriendRecallNoticeEvent):
try:
m = recall_msg_dealer(repo)
- except:
+ except Exception:
check = MessageChecker(repo).check_cq_code
if not check:
m = repo