Skip to content

Commit

Permalink
hive-chat: Ensure ChatMessage.matrix is parsed if present
Browse files Browse the repository at this point in the history
  • Loading branch information
gbenson committed Dec 3, 2024
1 parent 710fb7d commit 72839bb
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
5 changes: 5 additions & 0 deletions libs/chat/hive/chat/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ class ClientEvent:
def __init__(self, serialized: dict[str, Any]):
self._event = serialized

def __eq__(self, other):
if not isinstance(other, ClientEvent):
return False
return self._event == other._event

def json(self):
return self._event

Expand Down
5 changes: 4 additions & 1 deletion libs/chat/hive/chat/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ChatMessage:
default_factory=lambda: datetime.now(tz=timezone.utc))
uuid: str | UUID = field(default_factory=uuid4)
in_reply_to: Optional[str | UUID | ChatMessage] = None
matrix: Optional[MatrixEvent] = None
matrix: Optional[dict | MatrixEvent] = None
_unhandled: Optional[dict[str, Any]] = field(default=None, repr=False)

def __post_init__(self):
Expand All @@ -49,6 +49,9 @@ def __post_init__(self):
if self.in_reply_to == self.uuid:
raise ValueError

if not isinstance(self.matrix, (MatrixEvent, NoneType)):
self.matrix = MatrixEvent(self.matrix)

@classmethod
def json_keys(cls) -> list[str]:
names = (field.name for field in fields(cls))
Expand Down
3 changes: 3 additions & 0 deletions libs/chat/tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_basic():
assert message.in_reply_to is None
assert message.matrix.json() == event
assert not message.has_unhandled_fields
assert ChatMessage.from_json(message.json()) == message


def test_html():
Expand All @@ -36,6 +37,7 @@ def test_html():
assert message.in_reply_to is None
assert message.matrix.json() == event
assert not message.has_unhandled_fields
assert ChatMessage.from_json(message.json()) == message


def test_image():
Expand All @@ -50,3 +52,4 @@ def test_image():
assert message.in_reply_to is None
assert message.matrix.json() == event
assert not message.has_unhandled_fields
assert ChatMessage.from_json(message.json()) == message

0 comments on commit 72839bb

Please sign in to comment.