From 920173c8faaab5a65459ce176d36812bac6feb08 Mon Sep 17 00:00:00 2001 From: Yuki-Asuuna <10174503104@stu.ecnu.edu.cn> Date: Fri, 25 Feb 2022 21:22:48 +0800 Subject: =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0b=E7=AB=99=E5=8A=A8=E6=80=81?= =?UTF-8?q?=E8=AE=A2=E9=98=85=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I8b74e3a286901379b8337e33d1b581524cb80d97 --- ATRI/database/__init__.py | 1 + ATRI/database/db.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++ ATRI/database/models.py | 22 ++++++++++++++ 3 files changed, 100 insertions(+) create mode 100644 ATRI/database/__init__.py create mode 100644 ATRI/database/db.py create mode 100644 ATRI/database/models.py (limited to 'ATRI/database') diff --git a/ATRI/database/__init__.py b/ATRI/database/__init__.py new file mode 100644 index 0000000..6840881 --- /dev/null +++ b/ATRI/database/__init__.py @@ -0,0 +1 @@ +from .db import DB diff --git a/ATRI/database/db.py b/ATRI/database/db.py new file mode 100644 index 0000000..c2aa015 --- /dev/null +++ b/ATRI/database/db.py @@ -0,0 +1,77 @@ +from tortoise import Tortoise + +from ATRI.database import models +from nonebot import get_driver + + +# 关于数据库的操作类,只实现与数据库有关的CRUD +# 请不要把业务逻辑写进去 +class DB: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + async def init(self): + from ATRI.database import models + + await Tortoise.init( + db_url="sqlite://ATRI/database/db.sqlite3", + modules={"models": [locals()["models"]]}, + ) + # Generate the schema + await Tortoise.generate_schemas() + + async def add_subscription(self, uid: int, groupid: int) -> bool: + try: + _ = await models.Subscription.create(uid=uid, groupid=groupid) + return True + except: + return False + + async def get_all_subscriptions_by_gid(self, groupid: int) -> list: + try: + subs = await self.get_subscriptions(query_map={"groupid": groupid}) + return subs + except: + return [] + + async def remove_subscription(self, query_map: dict) -> bool: + try: + ret = await models.Subscription.filter(**query_map).delete() + return True + except: + return False + + async def get_subscriptions(self, query_map: dict) -> list: + try: + ret = await models.Subscription.filter(**query_map) + return ret + except: + return [] + + async def get_all_subscriptions(self) -> list: + try: + ret = await models.Subscription.all() + return ret + except: + return [] + + async def update_subscriptions_by_uid(self, uid: int, update_map: dict) -> bool: + try: + # why use ** ? + # Reference: https://stackoverflow.com/questions/5710391/converting-python-dict-to-kwargs + _ = await models.Subscription.filter(uid=uid).update(**update_map) + return True + except: + return False + + +async def init(): + async with DB() as db: + await db.init() + + +driver = get_driver() +driver.on_startup(init) diff --git a/ATRI/database/models.py b/ATRI/database/models.py new file mode 100644 index 0000000..b3953df --- /dev/null +++ b/ATRI/database/models.py @@ -0,0 +1,22 @@ +""" + 定义SQLITE数据库的关系模式(表) + 数据库采用了tortoise orm,可以很好地支持异步 +""" + +from tortoise.models import Model +from tortoise import fields +from datetime import datetime + +# b站订阅表 +class Subscription(Model): + uid = fields.IntField(pk=True) # up的uid + groupid = fields.IntField() # 群号 + nickname = fields.TextField(null=True) # 订阅up的名称 + last_update = fields.DatetimeField( + default=datetime.fromordinal(1) + ) # 上一条动态更新时间 默认0001-01-01 00:00:00 + + def __str__(self): + return "[{nickname}|{uid}|{groupid}]".format( + nickname=self.nickname, uid=self.uid, groupid=self.groupid + ) -- cgit v1.2.3