forked from jupyterlab/jupyter-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
202 additions
and
241 deletions.
There are no files selected for viewing
183 changes: 138 additions & 45 deletions
183
packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,76 +1,169 @@ | ||
import json | ||
import time | ||
import traceback | ||
from asyncio import AbstractEventLoop | ||
from typing import Any, AsyncIterator, Dict, Union | ||
|
||
# necessary to prevent circular import | ||
from typing import TYPE_CHECKING, AsyncIterator, Dict | ||
|
||
import tornado | ||
from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin | ||
from jupyter_ai.completions.models import ( | ||
CompletionError, | ||
InlineCompletionList, | ||
InlineCompletionReply, | ||
InlineCompletionRequest, | ||
InlineCompletionStreamChunk, | ||
ModelChangedNotification, | ||
) | ||
from jupyter_ai.config_manager import ConfigManager, Logger | ||
|
||
if TYPE_CHECKING: | ||
from jupyter_ai.handlers import InlineCompletionHandler | ||
from jupyter_server.base.handlers import JupyterHandler | ||
from langchain.pydantic_v1 import BaseModel, ValidationError | ||
|
||
|
||
class BaseInlineCompletionHandler(
    LLMHandlerMixin, JupyterHandler, tornado.websocket.WebSocketHandler
):
    """A Tornado WebSocket handler that receives inline completion requests and
    fulfills them accordingly. This class is instantiated once per WebSocket
    connection.

    Subclasses implement `handle_request()` (required) and, optionally,
    `handle_stream_request()`. Everything else -- authentication, message
    parsing, dispatch, timing, and error reporting -- is provided here.
    """

    handler_kind = "completion"

    ##
    # Interface for subclasses
    ##
    async def handle_request(
        self, message: InlineCompletionRequest
    ) -> InlineCompletionReply:
        """
        Handles an inline completion request, without streaming. Subclasses
        must define this method and write a reply via `self.write_message()`.

        The base class awaits this method and ignores its return value, so the
        reply must be sent through `self.write_message()`.

        The method definition does not need to be wrapped in a try/except
        block: exceptions raised here are passed to `self.handle_exc()`.
        """
        raise NotImplementedError(
            "The required method `self.handle_request()` is not defined by this subclass."
        )

    async def handle_stream_request(
        self, message: InlineCompletionRequest
    ) -> AsyncIterator[InlineCompletionStreamChunk]:
        """
        Handles an inline completion request, **with streaming**.
        Implementations may optionally define this method. Implementations that
        do so should stream replies via successive calls to
        `self.write_message()`.

        The method definition does not need to be wrapped in a try/except
        block: exceptions raised here are passed to `self.handle_exc()`.
        """
        raise NotImplementedError(
            "The optional method `self.handle_stream_request()` is not defined by this subclass."
        )

    ##
    # Definition of base class
    ##
    @property
    def loop(self) -> AbstractEventLoop:
        """Event loop stored under the `jai_event_loop` key of the handler
        settings; request tasks are scheduled on it."""
        return self.settings["jai_event_loop"]

    def write_message(
        self,
        message: Union[bytes, str, Dict[str, Any], BaseModel],
        binary: bool = False,
    ):
        """
        Write a bytes, string, dict, or Pydantic model object to the WebSocket
        connection. The base definition of this method is provided by Tornado.

        Pydantic models are converted to a dict via `.dict()` before being
        handed to the parent implementation. `binary` is forwarded unchanged.

        Returns whatever the parent `write_message()` returns, so that callers
        may await it if they need to.
        """
        if isinstance(message, BaseModel):
            message = message.dict()

        return super().write_message(message, binary=binary)

    def initialize(self):
        """Logs the path of the WebSocket connection being initialized."""
        self.log.debug("Initializing websocket connection %s", self.request.path)

    def pre_get(self):
        """Handles authentication/authorization.

        Raises `tornado.web.HTTPError(403)` when there is no current user or
        when the authorizer rejects the "execute" action on "events".
        """
        # authenticate the request before opening the websocket
        user = self.current_user
        if user is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise tornado.web.HTTPError(403)

        # authorize the user.
        if not self.authorizer.is_authorized(self, user, "execute", "events"):
            raise tornado.web.HTTPError(403)

    async def get(self, *args, **kwargs):
        """Get an event socket. Runs `pre_get()` first so that the checks it
        performs happen before the parent `get()` is awaited."""
        self.pre_get()
        await super().get(*args, **kwargs)

    async def on_message(self, message):
        """Public Tornado method that is called when the client sends a message
        over this connection. This should **not** be overridden by subclasses."""

        # first, verify that the message is an `InlineCompletionRequest`.
        # Malformed JSON and payloads that are not JSON objects are rejected
        # the same way as payloads that fail model validation: log and drop.
        self.log.debug("Message received: %s", message)
        try:
            payload = json.loads(message)
            request = InlineCompletionRequest(**payload)
        except (json.JSONDecodeError, TypeError, ValidationError) as e:
            self.log.error(e)
            return

        # next, dispatch the request to the correct handler and create the
        # `handle_request` coroutine object. Creating the coroutine object does
        # not run any of its body, so nothing can be raised at this point; a
        # missing `handle_stream_request()` implementation surfaces as a
        # `NotImplementedError` when the coroutine is awaited below.
        if request.stream:
            handle_request = self._handle_stream_request(request)
        else:
            handle_request = self._handle_request(request)

        # finally, wrap `handle_request` in an exception handler, and start the
        # task on the event loop.
        async def handle_request_and_catch():
            try:
                await handle_request
            except Exception as e:
                await self.handle_exc(e, request)

        self.loop.create_task(handle_request_and_catch())

    async def handle_exc(self, e: Exception, request: InlineCompletionRequest):
        """
        Handles an exception raised in either `handle_request()` or
        `handle_stream_request()`. This base class provides a default
        implementation, which may be overridden by subclasses.

        The default implementation writes an `InlineCompletionReply` with an
        empty completion list and a `CompletionError` describing `e`, addressed
        to `request.number`.
        """
        error = CompletionError(
            type=e.__class__.__name__,
            title=e.args[0] if e.args else "Exception",
            # relies on being called while the exception is still being
            # handled, which is how `on_message()` invokes this method.
            traceback=traceback.format_exc(),
        )
        self.write_message(
            InlineCompletionReply(
                list=InlineCompletionList(items=[]),
                error=error,
                reply_to=request.number,
            )
        )

    async def _handle_request(self, request: InlineCompletionRequest):
        """Private wrapper around `self.handle_request()` that logs how long
        the handler took."""
        start = time.time()
        await self.handle_request(request)
        latency_ms = round((time.time() - start) * 1000)
        self.log.info(f"Inline completion handler resolved in {latency_ms} ms.")

    async def _handle_stream_request(self, request: InlineCompletionRequest):
        """Private wrapper around `self.handle_stream_request()` that logs how
        long streaming took."""
        start = time.time()
        # delegate to the public hook; awaiting this wrapper itself would
        # recurse without bound.
        await self.handle_stream_request(request)
        latency_ms = round((time.time() - start) * 1000)
        self.log.info(f"Inline completion streaming completed in {latency_ms} ms.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.