From 60be1eb2bd16fc89827a50ad18377642aae176be Mon Sep 17 00:00:00 2001
From: Kyomotoi <0w0@imki.moe>
Date: Thu, 6 Apr 2023 16:07:46 +0800
Subject: =?UTF-8?q?=F0=9F=8E=A8=20=E4=BC=98=E5=8C=96=E6=95=B0=E6=8D=AE?=
 =?UTF-8?q?=E5=BA=93=E4=BB=A3=E7=A0=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ATRI/plugins/bilibili_dynamic/data_source.py | 18 +++++------
 ATRI/plugins/bilibili_dynamic/db.py          | 26 ---------------
 ATRI/plugins/rss/rss_mikanan/data_source.py  | 18 +++++------
 ATRI/plugins/rss/rss_mikanan/db.py           | 26 ---------------
 ATRI/plugins/rss/rss_rsshub/data_source.py   | 19 +++++------
 ATRI/plugins/rss/rss_rsshub/db.py            | 26 ---------------
 ATRI/plugins/thesaurus/data_source.py        | 46 ++++++++++++---------------
 ATRI/plugins/thesaurus/db.py                 | 47 ----------------------------
 8 files changed, 42 insertions(+), 184 deletions(-)
 delete mode 100644 ATRI/plugins/bilibili_dynamic/db.py
 delete mode 100644 ATRI/plugins/rss/rss_mikanan/db.py
 delete mode 100644 ATRI/plugins/rss/rss_rsshub/db.py
 delete mode 100644 ATRI/plugins/thesaurus/db.py

(limited to 'ATRI/plugins')

diff --git a/ATRI/plugins/bilibili_dynamic/data_source.py b/ATRI/plugins/bilibili_dynamic/data_source.py
index d629454..15babe5 100644
--- a/ATRI/plugins/bilibili_dynamic/data_source.py
+++ b/ATRI/plugins/bilibili_dynamic/data_source.py
@@ -5,8 +5,8 @@ from operator import itemgetter
 from ATRI.message import MessageBuilder
 from ATRI.utils import TimeDealer
 from ATRI.exceptions import BilibiliDynamicError
+from ATRI.database import DatabaseWrapper, BilibiliSubscription
 
-from .db import DB
 from .api import API
 
 
@@ -17,27 +17,25 @@ _OUTPUT_FORMAT = (
     .text("链接: {up_dy_link}")
     .done()
 )
+DB = DatabaseWrapper(BilibiliSubscription)
 
 
 class BilibiliDynamicSubscriptor:
     async def __add_sub(self, uid: int, group_id: int):
         try:
-            async with DB() as db:
-                await db.add_sub(uid, group_id)
+            await DB.add_sub(uid=uid, group_id=group_id)
         except Exception:
             raise BilibiliDynamicError("添加订阅失败")
 
     async def update_sub(self, uid: int, group_id: int, update_map: dict):
         try:
-            async with DB() as db:
-                await db.update_sub(uid, group_id, update_map)
+            await DB.update_sub(update_map=update_map, uid=uid, group_id=group_id)
         except Exception:
             raise BilibiliDynamicError("更新订阅失败")
 
     async def __del_sub(self, uid: int, group_id: int):
         try:
-            async with DB() as db:
-                await db.del_sub({"uid": uid, "group_id": group_id})
+            await DB.del_sub({"uid": uid, "group_id": group_id})
         except Exception:
             raise BilibiliDynamicError("删除订阅失败")
 
@@ -48,15 +46,13 @@ class BilibiliDynamicSubscriptor:
             query_map = {"uid": uid, "group_id": group_id}
 
         try:
-            async with DB() as db:
-                return await db.get_sub_list(query_map)
+            return await DB.get_sub_list(query_map)
         except Exception:
             raise BilibiliDynamicError("获取订阅列表失败")
 
     async def get_all_subs(self) -> list:
         try:
-            async with DB() as db:
-                return await db.get_all_subs()
+            return await DB.get_all_subs()
         except Exception:
             raise BilibiliDynamicError("获取全部订阅列表失败")
 
diff --git a/ATRI/plugins/bilibili_dynamic/db.py b/ATRI/plugins/bilibili_dynamic/db.py
deleted file mode 100644
index e6bb8bc..0000000
--- a/ATRI/plugins/bilibili_dynamic/db.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from ATRI.database import BilibiliSubscription
-
-
-class DB:
-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        pass
-
-    async def add_sub(self, uid: int, group_id: int):
-        await BilibiliSubscription.create(uid=uid, group_id=group_id)
-
-    async def update_sub(self, uid: int, group_id: int, update_map: dict):
-        await BilibiliSubscription.filter(uid=uid, group_id=group_id).update(
-            **update_map
-        )
-
-    async def del_sub(self, query_map: dict):
-        await BilibiliSubscription.filter(**query_map).delete()
-
-    async def get_sub_list(self, query_map: dict) -> list:
-        return await BilibiliSubscription.filter(**query_map)
-
-    async def get_all_subs(self) -> list:
-        return await BilibiliSubscription.all()
diff --git a/ATRI/plugins/rss/rss_mikanan/data_source.py b/ATRI/plugins/rss/rss_mikanan/data_source.py
index 6aec1b5..289d4f3 100644
--- a/ATRI/plugins/rss/rss_mikanan/data_source.py
+++ b/ATRI/plugins/rss/rss_mikanan/data_source.py
@@ -2,44 +2,40 @@ import xmltodict
 
 from ATRI.exceptions import RssError
 from ATRI.utils import request, gen_random_str
+from ATRI.database import DatabaseWrapper, RssMikananiSubcription
 
 
-from .db import DB
+DB = DatabaseWrapper(RssMikananiSubcription)
 
 
 class RssMikananSubscriptor:
     async def __add_sub(self, _id: str, group_id: int):
         try:
-            async with DB() as db:
-                await db.add_sub(_id, group_id)
+            await DB.add_sub(_id=_id, group_id=group_id)
         except Exception:
             raise RssError("rss.mikan: 添加订阅失败")
 
     async def update_sub(self, _id: str, group_id: int, update_map: dict):
         try:
-            async with DB() as db:
-                await db.update_sub(_id, group_id, update_map)
+            await DB.update_sub(update_map=update_map, _id=_id, group_id=group_id)
         except Exception:
             raise RssError("rss.mikan: 更新订阅失败")
 
     async def __del_sub(self, _id: str, group_id: int):
         try:
-            async with DB() as db:
-                await db.del_sub({"_id": _id, "group_id": group_id})
+            await DB.del_sub({"_id": _id, "group_id": group_id})
         except Exception:
             raise RssError("rss.mikan: 删除订阅失败")
 
     async def get_sub_list(self, query_map: dict) -> list:
         try:
-            async with DB() as db:
-                return await db.get_sub_list(query_map)
+            return await DB.get_sub_list(query_map)
         except Exception:
             raise RssError("rss.mikan: 获取订阅列表失败")
 
     async def get_all_subs(self) -> list:
         try:
-            async with DB() as db:
-                return await db.get_all_subs()
+            return await DB.get_all_subs()
         except Exception:
             raise RssError("rss.mikan: 获取所有订阅失败")
 
diff --git a/ATRI/plugins/rss/rss_mikanan/db.py b/ATRI/plugins/rss/rss_mikanan/db.py
deleted file mode 100644
index ac3385d..0000000
--- a/ATRI/plugins/rss/rss_mikanan/db.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from ATRI.database import RssMikananiSubcription
-
-
-class DB:
-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        pass
-
-    async def add_sub(self, _id: str, group_id: int):
-        await RssMikananiSubcription.create(_id=_id, group_id=group_id)
-
-    async def update_sub(self, _id: str, group_id: int, update_map: dict):
-        await RssMikananiSubcription.filter(_id=_id, group_id=group_id).update(
-            **update_map
-        )
-
-    async def del_sub(self, query_map: dict):
-        await RssMikananiSubcription.filter(**query_map).delete()
-
-    async def get_sub_list(self, query_map: dict) -> list:
-        return await RssMikananiSubcription.filter(**query_map)
-
-    async def get_all_subs(self) -> list:
-        return await RssMikananiSubcription.all()
diff --git a/ATRI/plugins/rss/rss_rsshub/data_source.py b/ATRI/plugins/rss/rss_rsshub/data_source.py
index 9ca1c04..d829f1a 100644
--- a/ATRI/plugins/rss/rss_rsshub/data_source.py
+++ b/ATRI/plugins/rss/rss_rsshub/data_source.py
@@ -2,43 +2,40 @@ import xmltodict
 
 from ATRI.exceptions import RssError
 from ATRI.utils import request, gen_random_str
+from ATRI.database import DatabaseWrapper, RssRsshubSubcription
 
-from .db import DB
+
+DB = DatabaseWrapper(RssRsshubSubcription)
 
 
 class RssHubSubscriptor:
     async def __add_sub(self, _id: str, group_id: int):
         try:
-            async with DB() as db:
-                await db.add_sub(_id, group_id)
+            await DB.add_sub(_id=_id, group_id=group_id)
         except Exception:
             raise RssError("rss.rsshub: 添加订阅失败")
 
     async def update_sub(self, _id: str, group_id: int, update_map: dict):
         try:
-            async with DB() as db:
-                await db.update_sub(_id, group_id, update_map)
+            await DB.update_sub(update_map=update_map, _id=_id, group_id=group_id)
         except Exception:
             raise RssError("rss.rsshub: 更新订阅失败")
 
     async def __del_sub(self, _id: str, group_id: int):
         try:
-            async with DB() as db:
-                await db.del_sub({"_id": _id, "group_id": group_id})
+            await DB.del_sub({"_id": _id, "group_id": group_id})
         except Exception:
             raise RssError("rss.rsshub: 删除订阅失败")
 
     async def get_sub_list(self, query_map: dict) -> list:
         try:
-            async with DB() as db:
-                return await db.get_sub_list(query_map)
+            return await DB.get_sub_list(query_map)
         except Exception:
             raise RssError("rss.rsshub: 获取订阅列表失败")
 
     async def get_all_subs(self) -> list:
         try:
-            async with DB() as db:
-                return await db.get_all_subs()
+            return await DB.get_all_subs()
         except Exception:
             raise RssError("rss.rsshub: 获取所有订阅失败")
 
diff --git a/ATRI/plugins/rss/rss_rsshub/db.py b/ATRI/plugins/rss/rss_rsshub/db.py
deleted file mode 100644
index 3f614a3..0000000
--- a/ATRI/plugins/rss/rss_rsshub/db.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from ATRI.database import RssRsshubSubcription
-
-
-class DB:
-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        pass
-
-    async def add_sub(self, _id: str, group_id: int):
-        await RssRsshubSubcription.create(_id=_id, group_id=group_id)
-
-    async def update_sub(self, _id, group_id, update_map: dict):
-        await RssRsshubSubcription.filter(_id=_id, group_id=group_id).update(
-            **update_map
-        )
-
-    async def del_sub(self, query_map: dict):
-        await RssRsshubSubcription.filter(**query_map).delete()
-
-    async def get_sub_list(self, query_map: dict) -> list:
-        return await RssRsshubSubcription.filter(**query_map)
-
-    async def get_all_subs(self) -> list:
-        return await RssRsshubSubcription.all()
diff --git a/ATRI/plugins/thesaurus/data_source.py b/ATRI/plugins/thesaurus/data_source.py
index 4f99b2c..281a777 100644
--- a/ATRI/plugins/thesaurus/data_source.py
+++ b/ATRI/plugins/thesaurus/data_source.py
@@ -2,23 +2,23 @@ from datetime import datetime, timedelta, timezone as tz
 
 from ATRI.message import MessageBuilder
 from ATRI.exceptions import ThesaurusError
+from ATRI.database import DatabaseWrapper, ThesaurusStoragor, ThesaurusAuditList
 
-from .db import DBForTS, DBForTAL
-from .db import ThesaurusStoragor
+
+DBForTS = DatabaseWrapper(ThesaurusStoragor)
+DBForTAL = DatabaseWrapper(ThesaurusAuditList)
 
 
 class ThesaurusManager:
     async def __add_item(self, _id: str, group_id: int, is_main: bool = False):
         if is_main:
             try:
-                async with DBForTS() as db:
-                    await db.add_item(_id, group_id)
+                await DBForTS.add_sub(_id=_id, group_id=group_id)
             except Exception:
                 raise ThesaurusError(f"添加词库(ts)数据失败 目标词id: {_id}")
         else:
             try:
-                async with DBForTAL() as db:
-                    await db.add_item(_id, group_id)
+                await DBForTAL.add_sub(_id=_id, group_id=group_id)
             except Exception:
                 raise ThesaurusError(f"添加词库(tal)数据失败 目标词id: {_id}")
 
@@ -27,56 +27,52 @@ class ThesaurusManager:
     ):
         if is_main:
             try:
-                async with DBForTS() as db:
-                    await db.update_item(_id, group_id, update_map)
+                await DBForTS.update_sub(
+                    update_map=update_map, _id=_id, group_id=group_id
+                )
             except Exception:
                 raise ThesaurusError(f"更新词库(ts)数据失败 目标词id: {_id}")
         else:
             try:
-                async with DBForTAL() as db:
-                    await db.update_item(_id, group_id, update_map)
+                await DBForTAL.update_sub(
+                    update_map=update_map, _id=_id, group_id=group_id
+                )
             except Exception:
                 raise ThesaurusError(f"更新词库(tal)数据失败 目标词id: {_id}")
 
     async def __del_item(self, _id: str, group_id: int, is_main: bool = False):
         if is_main:
             try:
-                async with DBForTS() as db:
-                    await db.del_item({"_id": _id, "group_id": group_id})
+                await DBForTS.del_sub({"_id": _id, "group_id": group_id})
             except Exception:
                 raise ThesaurusError(f"删除词库(ts)数据失败 目标词id: {_id}")
         else:
             try:
-                async with DBForTAL() as db:
-                    await db.del_item({"_id": _id, "group_id": group_id})
+                await DBForTAL.del_sub({"_id": _id, "group_id": group_id})
             except Exception:
                 raise ThesaurusError(f"删除词库(tal)数据失败 目标词id: {_id}")
 
     async def get_item_list(self, query_map: dict, is_main: bool = False) -> list:
         if is_main:
             try:
-                async with DBForTS() as db:
-                    return await db.get_item_list(query_map)
+                return await DBForTS.get_sub_list(query_map)
             except Exception:
                 raise ThesaurusError("获取词库(ts)列表数据失败")
         else:
             try:
-                async with DBForTAL() as db:
-                    return await db.get_item_list(query_map)
+                return await DBForTAL.get_sub_list(query_map)
             except Exception:
                 raise ThesaurusError("获取词库(tal)列表数据失败")
 
     async def get_all_items(self, is_main: bool = False) -> list:
         if is_main:
             try:
-                async with DBForTS() as db:
-                    return await db.get_all_items()
+                return await DBForTS.get_all_subs()
             except Exception:
                 raise ThesaurusError("获取全部词库(ts)列表数据失败")
         else:
             try:
-                async with DBForTAL() as db:
-                    return await db.get_all_items()
+                return await DBForTAL.get_all_subs()
             except Exception:
                 raise ThesaurusError("获取全部词库(tal)列表数据失败")
 
@@ -165,8 +161,7 @@ class ThesaurusManager:
 class ThesaurusListener:
     async def get_item_by_id(self, _id: str) -> ThesaurusStoragor:
         try:
-            async with DBForTS() as db:
-                data = await db.get_item_list({"_id": _id})
+            data = await DBForTS.get_sub_list({"_id": _id})
         except Exception:
             raise ThesaurusError(f"获取词库(ts)数据失败 词条ID: {_id}")
 
@@ -174,7 +169,6 @@ class ThesaurusListener:
 
     async def get_item_list(self, group_id: int):
         try:
-            async with DBForTS() as db:
-                return await db.get_item_list({"group_id": group_id})
+            return await DBForTS.get_sub_list({"group_id": group_id})
         except Exception:
             raise ThesaurusError(f"获取词库(ts)数据失败 目标群号: {group_id}")
diff --git a/ATRI/plugins/thesaurus/db.py b/ATRI/plugins/thesaurus/db.py
deleted file mode 100644
index b5394fc..0000000
--- a/ATRI/plugins/thesaurus/db.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from ATRI.database import ThesaurusStoragor, ThesaurusAuditList
-
-
-class DBForTS:
-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        pass
-
-    async def add_item(self, _id: str, group_id: int):
-        await ThesaurusStoragor.create(_id=_id, group_id=group_id)
-
-    async def update_item(self, _id: str, group_id: int, update_map: dict):
-        await ThesaurusStoragor.filter(_id=_id, group_id=group_id).update(**update_map)
-
-    async def del_item(self, query_map: dict):
-        await ThesaurusStoragor.filter(**query_map).delete()
-
-    async def get_item_list(self, query_map: dict) -> list:
-        return await ThesaurusStoragor.filter(**query_map)
-
-    async def get_all_items(self) -> list:
-        return await ThesaurusStoragor.all()
-
-
-class DBForTAL:
-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        pass
-
-    async def add_item(self, _id: str, group_id: int):
-        await ThesaurusAuditList.create(_id=_id, group_id=group_id)
-
-    async def update_item(self, _id: str, group_id: int, update_map: dict):
-        await ThesaurusAuditList.filter(_id=_id, group_id=group_id).update(**update_map)
-
-    async def del_item(self, query_map: dict):
-        await ThesaurusAuditList.filter(**query_map).delete()
-
-    async def get_item_list(self, query_map: dict) -> list:
-        return await ThesaurusAuditList.filter(**query_map)
-
-    async def get_all_items(self) -> list:
-        return await ThesaurusAuditList.all()
-- 
cgit v1.2.3