From 6738a22ea9835f517a723b99cb3c616bca7a6662 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Fri, 18 Oct 2024 09:50:12 -0700 Subject: [PATCH] fix typing issues raised by mypy --- packages/jupyter-ai/jupyter_ai/chat_handlers/base.py | 9 ++++++--- packages/jupyter-ai/jupyter_ai/handlers.py | 8 ++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index af7ca559a..5bbd91add 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -9,6 +9,7 @@ Awaitable, ClassVar, Dict, + get_args as get_type_args, List, Literal, Optional, @@ -265,8 +266,8 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): def broadcast_message(self, message: Message): """ Broadcasts a message to all WebSocket connections. If there are no - WebSocket connections, this method directly appends to - `self.chat_history`. + WebSocket connections and the message is a chat message, this method + directly appends to `self.chat_history`. """ broadcast = False for websocket in self._root_chat_handlers.values(): @@ -278,7 +279,9 @@ def broadcast_message(self, message: Message): break if not broadcast: - self._chat_history.append(message) + if isinstance(message, get_type_args(ChatMessage)): + cast(ChatMessage, message) + self._chat_history.append(message) def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None): """ diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 06175961b..e22d2ae3b 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -121,6 +121,10 @@ def loop(self) -> AbstractEventLoop: def pending_messages(self) -> List[PendingMessage]: return self.settings["pending_messages"] + @pending_messages.setter + def pending_messages(self, new_pending_messages): + self.settings["pending_messages"] = new_pending_messages + @property def cleared_message_ids(self) -> Set[str]: """Set of `HumanChatMessage.id` that were cleared via `ClearRequest`.""" @@ -128,10 +132,6 @@ def cleared_message_ids(self) -> Set[str]: self.settings["cleared_message_ids"] = set() return self.settings["cleared_message_ids"] - @pending_messages.setter - def pending_messages(self, new_pending_messages): - self.settings["pending_messages"] = new_pending_messages - def initialize(self): self.log.debug("Initializing websocket connection %s", self.request.path)