Commit

refactor inline completion backend
dlqqq committed Dec 30, 2023
1 parent b4358bc commit 606e510
Showing 9 changed files with 202 additions and 241 deletions.
183 changes: 138 additions & 45 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
@@ -1,76 +1,169 @@
import json
import time
import traceback
from asyncio import AbstractEventLoop
from typing import Any, AsyncIterator, Dict, Union

# necessary to prevent circular import
from typing import TYPE_CHECKING, AsyncIterator, Dict

import tornado
from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin
from jupyter_ai.completions.models import (
CompletionError,
InlineCompletionList,
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
ModelChangedNotification,
)
from jupyter_ai.config_manager import ConfigManager, Logger

if TYPE_CHECKING:
from jupyter_ai.handlers import InlineCompletionHandler
from jupyter_server.base.handlers import JupyterHandler
from langchain.pydantic_v1 import BaseModel, ValidationError


class BaseInlineCompletionHandler(LLMHandlerMixin):
"""Class implementing completion handling."""
class BaseInlineCompletionHandler(
LLMHandlerMixin, JupyterHandler, tornado.websocket.WebSocketHandler
):
"""A Tornado WebSocket handler that receives inline completion requests and
fulfills them accordingly. This class is instantiated once per WebSocket
connection."""

handler_kind = "completion"

def __init__(
self,
log: Logger,
config_manager: ConfigManager,
model_parameters: Dict[str, Dict],
ws_sessions: Dict[str, "InlineCompletionHandler"],
):
super().__init__(log, config_manager, model_parameters)
self.ws_sessions = ws_sessions

async def on_message(
self, message: InlineCompletionRequest
) -> InlineCompletionReply:
try:
return await self.process_message(message)
except Exception as e:
return await self._handle_exc(e, message)

async def process_message(
##
# Interface for subclasses
##
async def handle_request(
self, message: InlineCompletionRequest
) -> InlineCompletionReply:
"""
Processes an inline completion request. Completion handlers
(subclasses) must implement this method.
Handles an inline completion request, without streaming. Subclasses
must define this method and write a reply via `self.write_message()`.
The method definition does not need to be wrapped in a try/except block.
"""
raise NotImplementedError("Should be implemented by subclasses.")
raise NotImplementedError(
"The required method `self.handle_request()` is not defined by this subclass."
)

async def stream(
async def handle_stream_request(
self, message: InlineCompletionRequest
) -> AsyncIterator[InlineCompletionStreamChunk]:
""" "
Stream the inline completion as it is generated. Completion handlers
(subclasses) can implement this method.
"""
raise NotImplementedError()
"""
Handles an inline completion request, **with streaming**.
Implementations may optionally define this method. Implementations that
do so should stream replies via successive calls to
`self.write_message()`.
The method definition does not need to be wrapped in a try/except block.
"""
raise NotImplementedError(
"The optional method `self.handle_stream_request()` is not defined by this subclass."
)

##
# Definition of base class
##
handler_kind = "completion"

@property
def loop(self) -> AbstractEventLoop:
return self.settings["jai_event_loop"]

def write_message(self, message: Union[bytes, str, Dict[str, Any], BaseModel]):
"""
Write a bytes, string, dict, or Pydantic model object to the WebSocket
connection. The base definition of this method is provided by Tornado.
"""
if isinstance(message, BaseModel):
message = message.dict()

super().write_message(message)

def initialize(self):
self.log.debug("Initializing websocket connection %s", self.request.path)

def pre_get(self):
"""Handles authentication/authorization."""
# authenticate the request before opening the websocket
user = self.current_user
if user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise tornado.web.HTTPError(403)

async def _handle_exc(self, e: Exception, message: InlineCompletionRequest):
# authorize the user.
if not self.authorizer.is_authorized(self, user, "execute", "events"):
raise tornado.web.HTTPError(403)

async def get(self, *args, **kwargs):
"""Get an event socket."""
self.pre_get()
res = super().get(*args, **kwargs)
await res

async def on_message(self, message):
"""Public Tornado method that is called when the client sends a message
over this connection. This should **not** be overridden by subclasses."""

# first, verify that the message is an `InlineCompletionRequest`.
self.log.debug("Message received: %s", message)
try:
message = json.loads(message)
request = InlineCompletionRequest(**message)
except ValidationError as e:
self.log.error(e)
return

# next, dispatch the request to the correct handler and create the
# `handle_request` coroutine object
handle_request = None
if request.stream:
try:
handle_request = self._handle_stream_request(request)
except NotImplementedError:
self.log.error(
"Unable to handle stream request. The current `InlineCompletionHandler` does not implement the `handle_stream_request()` method."
)
return

else:
handle_request = self._handle_request(request)

# finally, wrap `handle_request` in an exception handler, and start the
# task on the event loop.
async def handle_request_and_catch():
try:
await handle_request
except Exception as e:
await self.handle_exc(e, request)

self.loop.create_task(handle_request_and_catch())

async def handle_exc(self, e: Exception, request: InlineCompletionRequest):
"""
Handles an exception raised in either `handle_request()` or
`handle_stream_request()`. This base class provides a default
implementation, which may be overridden by subclasses.
"""
error = CompletionError(
type=e.__class__.__name__,
title=e.args[0] if e.args else "Exception",
traceback=traceback.format_exc(),
)
return InlineCompletionReply(
list=InlineCompletionList(items=[]), error=error, reply_to=message.number
self.write_message(
InlineCompletionReply(
list=InlineCompletionList(items=[]),
error=error,
reply_to=request.number,
)
)

def broadcast(self, message: ModelChangedNotification):
for session in self.ws_sessions.values():
session.write_message(message.dict())
async def _handle_request(self, request: InlineCompletionRequest):
"""Private wrapper around `self.handle_request()`."""
start = time.time()
await self.handle_request(request)
latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Inline completion handler resolved in {latency_ms} ms.")

async def _handle_stream_request(self, request: InlineCompletionRequest):
"""Private wrapper around `self.handle_stream_request()`."""
start = time.time()
await self.handle_stream_request(request)
latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Inline completion streaming completed in {latency_ms} ms.")
90 changes: 49 additions & 41 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -15,7 +15,6 @@
InlineCompletionReply,
InlineCompletionRequest,
InlineCompletionStreamChunk,
ModelChangedNotification,
)
from .base import BaseInlineCompletionHandler

@@ -55,15 +54,6 @@ def __init__(self, *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
lm_provider = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params
next_lm_id = (
f'{lm_provider.id}:{lm_provider_params["model_id"]}'
if lm_provider
else None
)
self.broadcast(ModelChangedNotification(model=next_lm_id))

model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)

@@ -83,39 +73,52 @@ def create_llm_chain(
self.llm = llm
self.llm_chain = prompt_template | llm | StrOutputParser()

async def process_message(
async def handle_request(
self, request: InlineCompletionRequest
) -> InlineCompletionReply:
if request.stream:
token = self._token_from_request(request, 0)
return InlineCompletionReply(
list=InlineCompletionList(
items=[
{
# insert text starts empty as we do not pre-generate any part
"insertText": "",
"isIncomplete": True,
"token": token,
}
]
),
reply_to=request.number,
)
else:
self.get_llm_chain()
model_arguments = self._template_inputs_from_request(request)
suggestion = await self.llm_chain.ainvoke(input=model_arguments)
suggestion = self._post_process_suggestion(suggestion, request)
return InlineCompletionReply(
"""Handles an inline completion request without streaming."""
self.get_llm_chain()
model_arguments = self._template_inputs_from_request(request)
suggestion = await self.llm_chain.ainvoke(input=model_arguments)
suggestion = self._post_process_suggestion(suggestion, request)
self.write_message(
InlineCompletionReply(
list=InlineCompletionList(items=[{"insertText": suggestion}]),
reply_to=request.number,
)
)

async def stream(self, request: InlineCompletionRequest):
def _write_incomplete_reply(self, request: InlineCompletionRequest):
"""Writes an incomplete `InlineCompletionReply`, indicating to the
client that LLM output is about to be streamed across this connection.
Should be called first in `self.handle_stream_request()`."""

token = self._token_from_request(request, 0)
reply = InlineCompletionReply(
list=InlineCompletionList(
items=[
{
# insert text starts empty as we do not pre-generate any part
"insertText": "",
"isIncomplete": True,
"token": token,
}
]
),
reply_to=request.number,
)
self.write_message(reply)

async def handle_stream_request(self, request: InlineCompletionRequest):
# first, send empty initial reply.
self._write_incomplete_reply(request)

# then, generate and stream LLM output over this connection.
self.get_llm_chain()
token = self._token_from_request(request, 0)
model_arguments = self._template_inputs_from_request(request)
suggestion = ""

async for fragment in self.llm_chain.astream(input=model_arguments):
suggestion += fragment
if suggestion.startswith("```"):
Expand All @@ -124,18 +127,23 @@ async def stream(self, request: InlineCompletionRequest):
continue
else:
suggestion = self._post_process_suggestion(suggestion, request)
yield InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=False,
)
# at the end send a message confirming that we are done
yield InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=True,
)
self.write_message(
InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=False,
)
)

# finally, send a message confirming that we are done
self.write_message(
InlineCompletionStreamChunk(
type="stream",
response={"insertText": suggestion, "token": token},
reply_to=request.number,
done=True,
)
)

def _token_from_request(self, request: InlineCompletionRequest, suggestion: int):
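
To make the streaming contract in `handle_stream_request()` concrete, the messages written over the WebSocket for a single streamed request look roughly like the sketch below. Field values (the token string, suggestion text, and request number) are illustrative; the exact schema lives in `jupyter_ai/completions/models.py`, which is not shown in this diff, and the wire format follows from `write_message()` calling `.dict()` on each Pydantic model.

# 1. initial reply from _write_incomplete_reply(): an empty, incomplete item
#    carrying the stream token that later chunks refer to
{"list": {"items": [{"insertText": "", "isIncomplete": True, "token": "<token>"}]},
 "reply_to": 1}

# 2. zero or more chunks, each carrying the accumulated (post-processed) suggestion
{"type": "stream",
 "response": {"insertText": "def add(a, b):", "token": "<token>"},
 "reply_to": 1, "done": False}

# 3. a final chunk confirming the stream is complete
{"type": "stream",
 "response": {"insertText": "def add(a, b):\n    return a + b", "token": "<token>"},
 "reply_to": 1, "done": True}
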
28 changes: 12 additions & 16 deletions packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py
@@ -1,6 +1,6 @@
from typing import Dict, Type
from typing import Any, Dict, Type

from jupyter_ai.config_manager import ConfigManager, Logger
from jupyter_ai.config_manager import ConfigManager
from jupyter_ai_magics.providers import BaseProvider


@@ -12,23 +12,20 @@ class LLMHandlerMixin:

handler_kind: str

def __init__(
self,
log: Logger,
config_manager: ConfigManager,
model_parameters: Dict[str, Dict],
):
self.log = log
self.config_manager = config_manager
self.model_parameters = model_parameters
@property
def config_manager(self) -> ConfigManager:
return self.settings["jai_config_manager"]

@property
def model_parameters(self) -> Dict[str, Dict[str, Any]]:
return self.settings["model_parameters"]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.llm = None
self.llm_params = None
self.llm_chain = None

def model_changed_callback(self):
"""Method which can be overridden in sub-classes to listen to model change."""
pass

def get_llm_chain(self):
lm_provider = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params
@@ -50,7 +47,6 @@ def get_llm_chain(self):
f"Switching {self.handler_kind} language model from {curr_lm_id} to {next_lm_id}."
)
self.create_llm_chain(lm_provider, lm_provider_params)
self.model_changed_callback()
elif self.llm_params != lm_provider_params:
self.log.info(
f"{self.handler_kind} model params changed, updating the llm chain."
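
Since `config_manager` and `model_parameters` are now read from the Tornado settings dictionary instead of being passed to `__init__`, the server extension must register them under the keys these properties expect. A rough sketch of that wiring, under the assumption that it lives in `jupyter_ai/extension.py` (one of the changed files not shown here) and that `config_manager` and `model_parameters` already exist in that scope:

import asyncio

# Assumed wiring in the server extension (not shown in this diff): the keys
# match the properties above and the `loop` property in base.py.
web_app.settings["jai_config_manager"] = config_manager        # ConfigManager instance
web_app.settings["model_parameters"] = model_parameters        # Dict[str, Dict[str, Any]]
web_app.settings["jai_event_loop"] = asyncio.get_event_loop()  # loop used to schedule completion tasks
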
