summaryrefslogtreecommitdiff
path: root/ATRI/plugins/setu/modules/data_source.py
blob: bda7363a3f8acba1bfa067d74df9909214d311d6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import os
import json
import string
import aiosqlite
from aiosqlite.core import Connection
from pathlib import Path
from random import sample, choice
from aiohttp import ClientSession
from nonebot.adapters.cqhttp.message import MessageSegment, Message

from ATRI.log import logger as log
from ATRI.config import NsfwCheck
from ATRI.exceptions import RequestError, WriteError
from ATRI.utils.request import get_bytes
from ATRI.utils.img import compress_image


# Scratch directory where downloaded images are written before compression.
TEMP_DIR: Path = Path(".", "ATRI", "data", "temp", "setu")
# Directory that holds the sqlite database used by SetuData.
SETU_DIR = Path(".", "ATRI", "data", "database", "setu")
TEMP_DIR.mkdir(parents=True, exist_ok=True)
SETU_DIR.mkdir(parents=True, exist_ok=True)
# Endpoint of the NSFW-scoring service; the image URL is appended to it.
NSFW_URL = f"http://{NsfwCheck.host}:{NsfwCheck.port}/?url="
# When True, images are downloaded and compressed locally before sending.
SIZE_REDUCE: bool = True


class Hso:
    """Helpers that turn setu records into reply text containing an image segment."""

    @staticmethod
    async def nsfw_check(url: str) -> float:
        """Ask the NSFW-check service to score the image at *url*.

        Args:
            url: Image URL, appended verbatim to ``NSFW_URL``.

        Returns:
            The ``score`` field of the service's JSON reply, rounded to 4 places.

        Raises:
            RequestError: if ``get_bytes`` raises ``RequestError``.
        """
        try:
            data = json.loads(await get_bytes(NSFW_URL + url))
        except RequestError as err:
            raise RequestError("Request failed!") from err
        return round(data["score"], 4)

    @staticmethod
    async def _comp_setu(url: str) -> str:
        """Download the image at *url* into ``TEMP_DIR`` and compress it.

        The temporary file is named with 8 random alphanumeric characters and
        a ``.png`` suffix. Its absolute path is passed to ``compress_image``,
        whose return value is returned unchanged.

        Raises:
            RequestError: if the HTTP request fails or returns an error status.
            WriteError: if the downloaded bytes cannot be written to disk.
        """
        # aiohttp is already a dependency of this module; ClientError is the
        # base class of the network/HTTP errors raised by ClientSession.
        from aiohttp import ClientError

        temp_id = "".join(sample(string.ascii_letters + string.digits, 8))
        file = TEMP_DIR / f"{temp_id}.png"

        try:
            async with ClientSession() as session:
                async with session.get(url) as resp:
                    # Without this, an error page (404/5xx) would be saved as an image.
                    resp.raise_for_status()
                    data = await resp.read()
        except ClientError as err:
            raise RequestError("Request img failed!") from err

        try:
            with open(file, "wb") as fp:
                fp.write(data)
        except OSError as err:
            raise WriteError("Writing img failed!") from err

        return compress_image(os.path.abspath(file))

    @classmethod
    async def _image_segment(cls, url: str):
        """Build an image ``MessageSegment`` for *url*.

        When ``SIZE_REDUCE`` is set, the image is first downloaded and
        compressed locally, and the segment points at the local file.
        """
        if SIZE_REDUCE:
            return MessageSegment.image(
                "file:///" + await cls._comp_setu(url), proxy=False
            )
        return MessageSegment.image(url, proxy=False)

    @staticmethod
    def _format_msg(pid, title, img) -> str:
        """Render the Pid / Title / image reply shared by ``setu`` and ``acc_setu``."""
        return f"Pid: {pid}\n" f"Title: {title}\n" f"{img}"

    @staticmethod
    def _pick_image_url(data: dict) -> str:
        """Return an original-size image URL for an illust, with the host swapped to pixiv.cat.

        Single-page works carry the URL in ``meta_single_page.original_image_url``.
        Otherwise, one page is picked at random from ``meta_pages``.
        """
        single = data.get("meta_single_page") or {}
        pic = single.get("original_image_url")
        if not pic:
            # Pixiv app-API page layout: {"image_urls": {"original": <url>, ...}}
            pic = choice(data["meta_pages"])["image_urls"]["original"]
        return pic.replace("pximg.net", "pixiv.cat")

    @classmethod
    async def setu(cls, data: dict) -> str:
        """Format a record with ``pid``, ``title`` and ``url`` keys as a reply message."""
        pid = data["pid"]
        title = data["title"]
        img = await cls._image_segment(data["url"])
        return cls._format_msg(pid, title, img)

    @classmethod
    async def acc_setu(cls, d: list) -> str:
        """Pick a random illust from *d* and format it as a reply message.

        If the chosen illust has a tag named ``R-18``, a refusal string is
        returned instead.
        """
        data: dict = choice(d)

        if any(tag["name"] == "R-18" for tag in data["tags"]):
            return "太涩了不方便发w"

        pid = data["id"]
        title = data["title"]
        pic = cls._pick_image_url(data)
        img = await cls._image_segment(pic)
        return cls._format_msg(pid, title, img)


class SetuData:
    """Async access to the ``setu`` table in the local sqlite database."""

    # Path of the sqlite database file.
    SETU_DATA = SETU_DIR / "setu.db"

    @classmethod
    async def _check_database(cls) -> bool:
        """Create the database file and the ``setu`` table if the file is missing.

        Always returns ``True``; the other methods use it as a guard before
        querying.
        """
        if not cls.SETU_DATA.exists():
            log.warning(f"未发现数据库\n-> {cls.SETU_DATA}\n将开始创建")
            async with aiosqlite.connect(cls.SETU_DATA) as db:
                cur = await db.cursor()
                # IF NOT EXISTS: two concurrent first calls can both see the
                # file missing; the second CREATE must not fail.
                await cur.execute(
                    """
                    CREATE TABLE IF NOT EXISTS setu(
                        pid PID, title TITLE, tags TAGS,
                        user_id USER_ID, user_name USER_NAME,
                        user_account USER_ACCOUNT, url URL,
                        UNIQUE(
                            pid, title, tags, user_id,
                            user_name, user_account, url
                        )
                    );
                    """
                )
                await db.commit()
            log.warning(f"...创建数据库\n-> {cls.SETU_DATA}\n完成!")
        return True

    @classmethod
    async def add_data(cls, d: dict) -> None:
        """Insert one record into ``setu``.

        *d* must provide the keys ``pid``, ``title``, ``tags``, ``user_id``,
        ``user_name``, ``user_account`` and ``url``. The table's UNIQUE
        constraint spans all columns, so SQLite rejects an exact duplicate
        row with an error.
        """
        data = (
            d["pid"],
            d["title"],
            d["tags"],
            d["user_id"],
            d["user_name"],
            d["user_account"],
            d["url"],
        )

        if not await cls._check_database():
            return None
        async with aiosqlite.connect(cls.SETU_DATA) as db:
            await db.execute(
                """
                INSERT INTO setu(
                    pid, title, tags, user_id,
                    user_name, user_account, url
                ) VALUES(
                    ?, ?, ?, ?, ?, ?, ?
                );
                """,
                data,
            )
            await db.commit()

    @classmethod
    async def del_data(cls, pid: int) -> None:
        """Delete every row whose ``pid`` equals *pid*.

        Raises:
            ValueError: if *pid* is not an ``int``.
        """
        # Type check is kept as part of the public contract; the query itself
        # is parameterized, so it no longer relies on this for injection safety.
        if not isinstance(pid, int):
            raise ValueError("Please provide int.")

        if not await cls._check_database():
            return None
        async with aiosqlite.connect(cls.SETU_DATA) as db:
            await db.execute("DELETE FROM setu WHERE pid = ?;", (pid,))
            await db.commit()

    @classmethod
    async def count(cls):
        """Return the number of rows in ``setu``."""
        if not await cls._check_database():
            return None
        async with aiosqlite.connect(cls.SETU_DATA) as db:
            # Let SQLite count instead of materializing every row.
            async with db.execute("SELECT COUNT(*) FROM setu") as cursor:
                row = await cursor.fetchone()
                return row[0]  # type: ignore

    @classmethod
    async def get_setu(cls):
        """Return a list with at most one randomly chosen row (empty if the table is empty)."""
        if not await cls._check_database():
            return None
        async with aiosqlite.connect(cls.SETU_DATA) as db:
            async with db.execute(
                "SELECT * FROM setu ORDER BY RANDOM() limit 1;"
            ) as cursor:
                return await cursor.fetchall()