diff options
Diffstat (limited to 'ATRI/plugins')
| -rw-r--r-- | ATRI/plugins/chat/__init__.py | 7 | ||||
| -rw-r--r-- | ATRI/plugins/essential.py | 14 | ||||
| -rw-r--r-- | ATRI/plugins/setu/__init__.py | 89 | ||||
| -rw-r--r-- | ATRI/plugins/setu/data_source.py | 9 | ||||
| -rw-r--r-- | ATRI/plugins/setu/tf_dealer.py | 119 | ||||
| -rw-r--r-- | ATRI/plugins/status/__init__.py | 2 | 
6 files changed, 236 insertions, 4 deletions
diff --git a/ATRI/plugins/chat/__init__.py b/ATRI/plugins/chat/__init__.py index 9f7f26c..ab22db9 100644 --- a/ATRI/plugins/chat/__init__.py +++ b/ATRI/plugins/chat/__init__.py @@ -26,7 +26,10 @@ async def _chat(bot: Bot, event: MessageEvent):      msg = str(event.message)      repo = await Chat().deal(msg, user_id)      _chat_flmt.start_cd(user_id) -    await chat.finish(repo) +    try: +        await chat.finish(repo) +    except Exception: +        return  my_name_is = Chat().on_command("叫我", "更改闲聊(划掉 文爱)时的称呼", aliases={"我是"}, priority=1) @@ -111,7 +114,7 @@ async def _deal_say(bot: Bot, event: MessageEvent, state: T_State):      await say.finish(msg) -@scheduler.scheduled_job("interval", hours=3, misfire_grace_time=60) +@scheduler.scheduled_job("interval", name="闲聊词库检查更新", hours=3, misfire_grace_time=60)  async def _check_kimo():      try:          await Chat().update_data() diff --git a/ATRI/plugins/essential.py b/ATRI/plugins/essential.py index 3bbf66e..4cc0ac2 100644 --- a/ATRI/plugins/essential.py +++ b/ATRI/plugins/essential.py @@ -1,5 +1,6 @@  import os  import json +import shutil  import asyncio  from datetime import datetime  from pydantic.main import BaseModel @@ -30,14 +31,19 @@ from ATRI.service import Service  from ATRI.log import logger as log  from ATRI.config import BotSelfConfig  from ATRI.utils import CoolqCodeChecker +from ATRI.utils.apscheduler import scheduler +  driver = ATRI.driver()  bots = nonebot.get_bots() +  ESSENTIAL_DIR = Path(".") / "data" / "database" / "essential"  MANEGE_DIR = Path(".") / "data" / "database" / "manege" +TEMP_PATH = Path(".") / "data" / "temp"  os.makedirs(ESSENTIAL_DIR, exist_ok=True)  os.makedirs(MANEGE_DIR, exist_ok=True) +os.makedirs(TEMP_PATH, exist_ok=True)  @driver.on_startup @@ -298,3 +304,11 @@ async def _recall_private_event(bot: Bot, event: FriendRecallNoticeEvent):      msg = "主人,咱拿到了一条撤回信息!\n" f"{user}@[私聊]" "撤回了\n" f"{repo}"      for superuser in BotSelfConfig.superusers:          await bot.send_private_msg(user_id=int(superuser), message=msg) + + +@scheduler.scheduled_job("interval", name="清除缓存", minutes=30, misfire_grace_time=5) +async def _clear_cache(): +    try: +        shutil.rmtree(TEMP_PATH) +    except Exception: +        log.warning("清除缓存失败,请手动清除:data/temp") diff --git a/ATRI/plugins/setu/__init__.py b/ATRI/plugins/setu/__init__.py index eec5281..3e0e452 100644 --- a/ATRI/plugins/setu/__init__.py +++ b/ATRI/plugins/setu/__init__.py @@ -2,7 +2,10 @@ import re  import asyncio  from random import choice  from nonebot.adapters.cqhttp import Bot, MessageEvent, Message +from nonebot.adapters.cqhttp.message import MessageSegment +from nonebot.typing import T_State +from ATRI.config import BotSelfConfig  from ATRI.utils.limit import FreqLimiter, DailyLimiter  from ATRI.utils.apscheduler import scheduler  from .data_source import Setu @@ -74,7 +77,91 @@ async def _tag_setu(bot: Bot, event: MessageEvent):      await bot.delete_msg(message_id=event_id) -@scheduler.scheduled_job("interval", hours=1, misfire_grace_time=60, args=[Bot]) +setu_catcher = Setu().on_message("涩图嗅探") + + +@setu_catcher.handle() +async def _setu_catcher(bot: Bot, event: MessageEvent): +    msg = str(event.message) +    pattern = r"url=(.*?)]" +    args = re.findall(pattern, msg) +    if not args: +        return +    else: +        hso = list() +        for i in args: +            try: +                data = await Setu().detecter(i) +            except Exception: +                return +            if data[1] > 0.7: +                hso.append(data[1]) + +        hso.sort(reverse=True) + +        if not hso: +            return +        elif len(hso) == 1: +            u_repo = f"hso! 涩值:{'{:.2%}'.format(hso[0])}\n不行我要发给别人看" +            s_repo = ( +                f"涩图来咧!\n{MessageSegment.image(args[0])}\n涩值:{'{:.2%}'.format(hso[0])}" +            ) + +        else: +            u_repo = f"hso! 最涩的达到:{'{:.2%}'.format(hso[0])}\n不行我一定要发给别人看" + +            ss = list() +            for s in args: +                ss.append(MessageSegment.image(s)) +            ss = "\n".join(ss) +            s_repo = f"多张涩图来咧!\n{ss}\n最涩的达到:{'{:.2%}'.format(hso[0])}" + +        await bot.send(event, u_repo) +        for superuser in BotSelfConfig.superusers: +            await bot.send_private_msg(user_id=superuser, message=s_repo) + + +nsfw_checker = Setu().on_command("/nsfw", "涩值检测") + + +@nsfw_checker.handle() +async def _nsfw_checker(bot: Bot, event: MessageEvent, state: T_State): +    msg = str(event.message).strip() +    if msg: +        state["nsfw_img"] = msg + + +@nsfw_checker.got("nsfw_img", "图呢?") +async def _deal_check(bot: Bot, event: MessageEvent, state: T_State): +    msg = state["nsfw_img"] +    pattern = r"url=(.*?)]" +    args = re.findall(pattern, msg) +    if not args: +        await nsfw_checker.reject("请发送图片而不是其他东西!!") + +    data = await Setu().detecter(args[0]) +    hso = data[1] +    if not hso: +        await nsfw_checker.finish("图太小了!不测!") + +    resu = f"涩值:{'{:.2%}'.format(hso)}\n" +    if hso >= 0.75: +        resu += "hso!不行我要发给别人看" +        repo = f"涩图来咧!\n{MessageSegment.image(args[0])}\n涩值:{'{:.2%}'.format(hso[0])}" +        for superuser in BotSelfConfig.superusers: +            await bot.send_private_msg(user_id=superuser, message=repo) + +    elif 0.75 > hso >= 0.5: +        resu += "嗯。可冲" +    else: +        resu += "还行8" + +    await nsfw_checker.finish(resu) + + +@scheduler.scheduled_job( +    "interval", name="涩批诱捕器", hours=1, misfire_grace_time=60, args=[Bot] +)  async def _scheduler_setu(bot):      try:          group_list = await bot.get_group_list() diff --git a/ATRI/plugins/setu/data_source.py b/ATRI/plugins/setu/data_source.py index 2c665ae..39f3815 100644 --- a/ATRI/plugins/setu/data_source.py +++ b/ATRI/plugins/setu/data_source.py @@ -7,6 +7,7 @@ from nonebot.adapters.cqhttp import MessageSegment  from ATRI.service import Service  from ATRI.rule import is_in_service  from ATRI.utils import request +from .tf_dealer import detect_image  LOLICON_URL = "https://api.lolicon.app/setu/v2" @@ -61,6 +62,14 @@ class Setu(Service):          return repo, setu      @staticmethod +    async def detecter(url: str) -> list: +        """ +        涩值检测. +        """ +        data = await detect_image(url) +        return data + +    @staticmethod      async def scheduler() -> str:          """          每隔指定时间随机抽取一个群发送涩图. diff --git a/ATRI/plugins/setu/tf_dealer.py b/ATRI/plugins/setu/tf_dealer.py new file mode 100644 index 0000000..f41de49 --- /dev/null +++ b/ATRI/plugins/setu/tf_dealer.py @@ -0,0 +1,119 @@ +import io +import os +import re +import string +import asyncio +import skimage +import skimage.io +import numpy as np +from PIL import Image +from pathlib import Path +from sys import getsizeof +from random import sample + +try: +    import tflite_runtime.interpreter as tf  # type: ignore +except Exception: +    import tensorflow as tf + +from ATRI.log import logger as log +from ATRI.utils import request +from ATRI.exceptions import RequestError, WriteError + + +SETU_PATH = Path(".") / "data" / "database" / "setu" +TEMP_PATH = Path(".") / "data" / "temp" +os.makedirs(SETU_PATH, exist_ok=True) +os.makedirs(TEMP_PATH, exist_ok=True) + + +MODULE_URL = "https://cdn.jsdelivr.net/gh/Kyomotoi/CDN@master/project/ATRI/nsfw.tflite" +VGG_MEAN = [104, 117, 123] + + +def prepare_image(img): +    H, W, _ = img.shape +    h, w = (224, 224) + +    h_off = max((H - h) // 2, 0) +    w_off = max((W - w) // 2, 0) +    image = img[h_off : h_off + h, w_off : w_off + w, :] + +    image = image[:, :, ::-1] + +    image = image.astype(np.float32, copy=False) +    image = image * 255.0 +    image = image - np.array(VGG_MEAN, dtype=np.float32) + +    image = np.expand_dims(image, axis=0) +    return image + + +async def detect_image(url) -> list: +    try: +        req = await request.get(url) +    except RequestError: +        raise RequestError("Get info from download image failed!") + +    img_byte = getsizeof(req.read()) // 1024 +    if img_byte < 256: +        return [0, 0] + +    try: +        pattern = r"-(.*?)\/" +        file_name = re.findall(pattern, url)[0] +        path = TEMP_PATH / f"{file_name}.jpg" +        with open(path, "wb") as f: +            f.write(req.read()) +    except WriteError: +        raise WriteError("Writing file failed!") + +    await init_module() +    model_path = str((SETU_PATH / "nsfw.tflite").absolute()) + +    try: +        interpreter = tf.Interpreter(model_path=model_path)  # type: ignore +    except Exception: +        interpreter = tf.lite.Interpreter(model_path=model_path) + +    interpreter.allocate_tensors() + +    input_details = interpreter.get_input_details() +    output_details = interpreter.get_output_details() + +    im = Image.open(path) + +    if im.mode != "RGB": +        im = im.convert("RGB") +    imr = im.resize((256, 256), resample=Image.BILINEAR) +    fh_im = io.BytesIO() +    imr.save(fh_im, format="JPEG") +    fh_im.seek(0) + +    image = skimage.img_as_float32(skimage.io.imread(fh_im)) + +    final = prepare_image(image) +    interpreter.set_tensor(input_details[0]["index"], final) + +    interpreter.invoke() +    output_data = interpreter.get_tensor(output_details[0]["index"]) + +    result = np.squeeze(output_data).tolist() +    return result + + +async def init_module(): +    file_name = "nsfw.tflite" +    path = SETU_PATH / file_name +    if not path.is_file(): +        log.warning("缺少模型文件,装载中") +        data = await request.get(MODULE_URL) +        try: +            with open(path, "wb") as w: +                w.write(data.read()) +            log.info("模型装载完成") +        except WriteError: +            raise WriteError("NSFW TF module init failed!") + + +asyncio.get_event_loop().run_until_complete(init_module()) diff --git a/ATRI/plugins/status/__init__.py b/ATRI/plugins/status/__init__.py index 2746953..359344b 100644 --- a/ATRI/plugins/status/__init__.py +++ b/ATRI/plugins/status/__init__.py @@ -24,7 +24,7 @@ async def _status(bot: Bot, event: MessageEvent):  info_msg = "アトリは高性能ですから!" -@scheduler.scheduled_job("interval", minutes=10, misfire_grace_time=15) +@scheduler.scheduled_job("interval", name="状态检查", minutes=10, misfire_grace_time=15)  async def _status_checking():      global info_msg      msg, stat = IsSurvive().get_status()  | 
