From 3e32ca3964ff8f40e0b491e87f153040f2348fd0 Mon Sep 17 00:00:00 2001 From: Kyomotoi Date: Thu, 3 Feb 2022 14:36:24 +0800 Subject: =?UTF-8?q?=F0=9F=94=96=20=E6=9B=B4=E6=96=B0=E7=89=88=E6=9C=AC:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 更新记录请参考文档: atri.kyomotoi.moe/changelog/overview/ --- test/utils.py | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 test/utils.py (limited to 'test/utils.py') diff --git a/test/utils.py b/test/utils.py new file mode 100644 index 0000000..70d1a27 --- /dev/null +++ b/test/utils.py @@ -0,0 +1,87 @@ +""" +Fork from: https://github.com/nonebot/nonebot2/blob/master/tests/utils.py +""" +from typing import TYPE_CHECKING, Type, Optional + +from pydantic import create_model + +if TYPE_CHECKING: + from nonebot.adapters import Event, Message + + +def make_fake_message() -> Type["Message"]: + from nonebot.adapters import Message, MessageSegment + + class FakeMessageSegment(MessageSegment): + @classmethod + def get_message_class(cls): + return FakeMessage + + def __str__(self) -> str: + return self.data["text"] if self.type == "text" else f"[fake:{self.type}]" + + @classmethod + def text(cls, text: str): + return cls("text", {"text": text}) + + @classmethod + def image(cls, url: str): + return cls("image", {"url": url}) + + def is_text(self) -> bool: + return self.type == "text" + + class FakeMessage(Message): + @classmethod + def get_segment_class(cls): + return FakeMessageSegment + + @staticmethod + def _construct(msg: str): + yield FakeMessageSegment.text(msg) + + return FakeMessage + + +def make_fake_event( + _type: str = "message", + _name: str = "test", + _description: str = "test", + _user_id: str = "test", + _session_id: str = "test", + _message: Optional["Message"] = None, + _to_me: bool = True, + **fields, +) -> Type["Event"]: + from nonebot.adapters import Event + + _Fake = create_model("_Fake", __base__=Event, **fields) + + class FakeEvent(_Fake): + def get_type(self) -> str: + return _type + + def get_event_name(self) -> str: + return _name + + def get_event_description(self) -> str: + return _description + + def get_user_id(self) -> str: + return _user_id + + def get_session_id(self) -> str: + return _session_id + + def get_message(self) -> "Message": + if _message is not None: + return _message + raise NotImplementedError + + def is_tome(self) -> bool: + return _to_me + + class Config: + extra = "forbid" + + return FakeEvent -- cgit v1.2.3