Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ 为 MessageFactory 和 MessageSegmentFactory 添加类型提示 #127

Closed
wants to merge 8 commits into from
255 changes: 219 additions & 36 deletions nonebot_plugin_saa/abstract_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from inspect import signature
from typing_extensions import Self
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Type,
Expand All @@ -17,6 +19,7 @@
Optional,
Awaitable,
cast,
overload,
)

from nonebot.adapters import Bot, Event, Message, MessageSegment
Expand All @@ -32,7 +35,11 @@
extract_adapter_type,
)

if TYPE_CHECKING:
from .types import Text

TMSF = TypeVar("TMSF", bound="MessageSegmentFactory")
TMSFO = TypeVar("TMSFO", bound="MessageSegmentFactory")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

作用?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

表明是两个可能不同的子类

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有两个都用到的地方吗

TMF = TypeVar("TMF", bound="MessageFactory")
BuildFunc = Union[
Callable[[TMSF], Union[MessageSegment, Awaitable[MessageSegment]]],
Expand Down Expand Up @@ -105,7 +112,7 @@
]
]

data: dict
data: Dict[str, Any]
_custom_builders: Dict[SupportedAdapters, CustomBuildFunc]

def _register_custom_builder(
Expand All @@ -131,8 +138,21 @@
cls._builders = {}
return super().__init_subclass__()

def __eq__(self, other: Self) -> bool:
return self.data == other.data
def __eq__(self, other: object) -> bool:
if isinstance(other, MessageSegmentFactory):
return self.data == other.data
elif isinstance(other, str):
return self.data["text"] == other
else:
return False

def __str__(self) -> str:
kvstr = ",".join([f"{k}={v!r}" for k, v in self.data.items()])
return f"[SAA:{self.__class__.__name__}|{kvstr}]"

def __repr__(self) -> str:
attrs = ", ".join([f"{k}={v!r}" for k, v in self.data.items()])
return f"{self.__class__.__name__}({attrs})"

def overwrite(
self,
Expand All @@ -151,11 +171,67 @@
return await do_build(self, builder, bot)
raise AdapterNotInstalled(adapter_name)

def __add__(self: TMSF, other: Union[str, TMSF, Iterable[TMSF]]):
return MessageFactory(self) + other

def __radd__(self: TMSF, other: Union[str, TMSF, Iterable[TMSF]]):
return MessageFactory(other) + self
@overload
def __add__(
self: Self, other: Union[str, Iterable[str]]
) -> "MessageFactory[Union[Self, Text]]":
... # pragma: no cover

@overload
def __add__(
self: Self, other: Union[TMSFO, Iterable[TMSFO]]
) -> "MessageFactory[Union[Self, TMSFO]]":
... # pragma: no cover

@overload
def __add__(
self: Self, other: Iterable[Union[str, TMSFO]]
) -> "MessageFactory[Union[Self, Text, TMSFO]]":
... # pragma: no cover

def __add__(
self: Self, other: Union[str, TMSFO, Iterable[Union[str, TMSFO]]]
) -> "MessageFactory":
if isinstance(other, str):
text = MessageFactory.get_text_factory()(other)
return MessageFactory([self, text])
elif isinstance(other, MessageSegmentFactory):
return MessageFactory([self, other])
elif isinstance(other, Iterable):
return MessageFactory([self, *other])
else:
raise TypeError(f"unsupported type {type(other)}")

@overload
def __radd__(
self: Self, other: Union[str, Iterable[str]]
) -> "MessageFactory[Union[Self, Text]]":
... # pragma: no cover

@overload
def __radd__(
self: Self, other: Union[TMSFO, Iterable[TMSFO]]
) -> "MessageFactory[Union[Self, TMSFO]]":
... # pragma: no cover

@overload
def __radd__(
self: Self, other: Iterable[Union[str, TMSFO]]
) -> "MessageFactory[Union[Self, Text, TMSFO]]":
... # pragma: no cover

def __radd__(
self: Self, other: Union[str, TMSFO, Iterable[Union[str, TMSFO]]]
) -> "MessageFactory":
if isinstance(other, str):
text = MessageFactory.get_text_factory()(other)
return MessageFactory([text, self])
elif isinstance(other, MessageSegmentFactory):
return MessageFactory([other, self])
elif isinstance(other, Iterable):
return MessageFactory([*other, self])
else:
raise TypeError(f"unsupported type {type(other)}")

async def send(self, *, at_sender=False, reply=False):
"回复消息,仅能用在事件响应器中"
Expand Down Expand Up @@ -241,52 +317,159 @@
return message_type(ms)
raise AdapterNotInstalled(adapter_name)

def __init__(self, message: Union[str, Iterable[TMSF], TMSF]):
super().__init__()
@overload
def __init__(self: "MessageFactory[Text]", ms: Union[str, Iterable[str]]) -> None:
... # pragma: no cover

if message is None:
return
@overload
def __init__(
self: "MessageFactory[TMSFO]", ms: Union[TMSFO, Iterable[TMSFO]]
) -> None:
... # pragma: no cover

if isinstance(message, str):
self.append(self.get_text_factory()(message))
elif isinstance(message, MessageSegmentFactory):
self.append(message)
elif isinstance(message, Iterable):
self.extend(message)
@overload
def __init__(
self: "MessageFactory[Text | TMSFO]",
ms: Iterable[Union[str, TMSFO]],
) -> None:
... # pragma: no cover

def __add__(self: TMF, other: Union[str, TMSF, Iterable[TMSF]]) -> TMF:
result = self.copy()
result += other
return result
@overload
def __init__(self: "MessageFactory") -> None:
... # pragma: no cover

def __radd__(self: TMF, other: Union[str, TMSF, Iterable[TMSF]]) -> TMF:
result = self.__class__(other)
return result + self
def __init__(self, ms: Union[str, TMSFO, Iterable[Union[str, TMSFO]], None] = None):
super().__init__()
if ms is None:
return

def __iadd__(self: TMF, other: Union[str, TMSF, Iterable[TMSF]]) -> TMF:
if isinstance(ms, (str, MessageSegmentFactory)):
self.__iadd__(ms)
elif isinstance(ms, Iterable):
for i in ms:
self.__iadd__(i)

@overload
def __add__(
self: "MessageFactory[TMSF]", other: Union[str, Iterable[str]]
) -> "MessageFactory[Union[TMSF, Text]]":
... # pragma: no cover

@overload
def __add__(
self: "MessageFactory[TMSF]", other: Union[TMSFO, Iterable[TMSFO]]
) -> "MessageFactory[Union[TMSF, TMSFO]]":
... # pragma: no cover

@overload
def __add__(
self: "MessageFactory[TMSF]",
other: Iterable[Union[str, TMSFO]],
) -> "MessageFactory[Union[TMSF, Text, TMSFO]]":
... # pragma: no cover

def __add__(
self: "MessageFactory[TMSF]",
other: Union[str, TMSFO, Iterable[Union[str, TMSFO]]],
) -> "MessageFactory":
copied = self.copy()
if isinstance(other, str):
copied.append(self.get_text_factory()(other))
return copied
elif isinstance(other, MessageSegmentFactory):
copied.append(other)
return copied
elif isinstance(other, Iterable):
for i in other:
copied += i
return copied
else:
raise TypeError(

Check warning on line 387 in nonebot_plugin_saa/abstract_factories.py

View check run for this annotation

Codecov / codecov/patch

nonebot_plugin_saa/abstract_factories.py#L387

Added line #L387 was not covered by tests
f"unsupported operand type(s) for +: '{self.__class__.__name__}' and '{type(other)}'" # noqa: E501
)

@overload
def __radd__(
self: "MessageFactory[TMSF]", other: Union[str, Iterable[str]]
) -> "MessageFactory[Union[TMSF, Text]]":
... # pragma: no cover

@overload
def __radd__(
self: "MessageFactory[TMSF]", other: Union[TMSFO, Iterable[TMSFO]]
) -> "MessageFactory[Union[TMSF, TMSFO]]":
... # pragma: no cover

@overload
def __radd__(
self: "MessageFactory[TMSF]",
other: Iterable[Union[str, TMSFO]],
) -> "MessageFactory[Union[TMSF, Text, TMSFO]]":
... # pragma: no cover

def __radd__(
self: "MessageFactory[TMSF]",
other: Union[str, TMSFO, Iterable[Union[str, TMSFO]]],
) -> "MessageFactory":
if isinstance(other, (str, MessageSegmentFactory)):
return MessageFactory(other) + self
elif isinstance(other, Iterable):
return MessageFactory(other) + self # type: ignore
else:
raise TypeError(

Check warning on line 419 in nonebot_plugin_saa/abstract_factories.py

View check run for this annotation

Codecov / codecov/patch

nonebot_plugin_saa/abstract_factories.py#L419

Added line #L419 was not covered by tests
f"unsupported operand type(s) for +: '{type(other)}' and '{self.__class__.__name__}'" # noqa: E501
)

@overload
def __iadd__(
self: "MessageFactory[TMSF]", other: Union[str, Iterable[str]]
) -> "MessageFactory[Union[TMSF, Text]]":
... # pragma: no cover

@overload
def __iadd__(
self: "MessageFactory[TMSF]", other: Union[TMSFO, Iterable[TMSFO]]
) -> "MessageFactory[Union[TMSF, TMSFO]]":
... # pragma: no cover

@overload
def __iadd__(
self: "MessageFactory[TMSF]",
other: Iterable[Union[str, TMSFO]],
) -> "MessageFactory[Union[TMSF, Text, TMSFO]]":
... # pragma: no cover

def __iadd__(
self: "MessageFactory[TMSF]",
other: Union[str, TMSFO, Iterable[Union[str, TMSFO]]],
) -> "MessageFactory":
if isinstance(other, str):
self.append(self.get_text_factory()(other))
return self
elif isinstance(other, MessageSegmentFactory):
self.append(other)
return self
elif isinstance(other, Iterable):
self.extend(other)
return self
else:
raise TypeError(
f"unsupported operand type(s) for +=: '{self.__class__.__name__}' and '{type(other)}'" # noqa: E501
)

return self

def append(self: TMF, obj: Union[str, TMSF]) -> TMF:
if isinstance(obj, MessageSegmentFactory):
super().append(obj)
elif isinstance(obj, str):
def append(self, obj: Union[str, TMSFO]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可能会break

if isinstance(obj, str):
super().append(self.get_text_factory()(obj))

elif isinstance(obj, MessageSegmentFactory):
super().append(obj) # type: ignore
else:
raise TypeError(f"unsupported type {type(obj)}")
return self

def extend(self: TMF, obj: Union[TMF, Iterable[TMSF]]) -> TMF:
def extend(self: TMF, obj: Union[TMF, Iterable[Union[str, TMSFO]]]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同样

for message_segment_factory in obj:
self.append(message_segment_factory)

return self

def copy(self: TMF) -> TMF:
return deepcopy(self)

Expand Down
8 changes: 8 additions & 0 deletions nonebot_plugin_saa/types/common_message_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,11 @@ def __init__(self, message_id: MessageId):

super().__init__()
self.data = message_id

def __str__(self) -> str:
kvstr = ",".join([f"{k}={v!r}" for k, v in self.data.dict().items()])
return f"[SAA:{self.__class__.__name__}|{kvstr}]"

def __repr__(self) -> str:
kvrepr = ", ".join([f"{k}={v!r}" for k, v in self.data.dict().items()])
return f"{self.__class__.__name__}({kvrepr})"
6 changes: 5 additions & 1 deletion tests/test_feishu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import httpx
from nonebug import App
from nonebot import get_driver
from nonebot.adapters.feishu.bot import BotInfo
from nonebot.adapters.feishu import Bot, Message
from nonebot.adapters.feishu.models import BotInfo
from nonebot.adapters.feishu.config import BotConfig

from nonebot_plugin_saa.utils import SupportedAdapters
Expand Down Expand Up @@ -86,6 +86,8 @@ def mock_feishu_message_event(message: Message, group=False):
message=PrivateEventMessage(chat_type="p2p", **event_message_dict),
),
reply=None,
original_message=Message("original_message"),
_message=Message("message"),
)
else:
return GroupMessageEvent(
Expand All @@ -96,6 +98,8 @@ def mock_feishu_message_event(message: Message, group=False):
message=GroupEventMessage(chat_type="group", **event_message_dict),
),
reply=None,
original_message=Message("original_message"),
_message=Message("message"),
)


Expand Down
Loading