Document how to add custom model providers #420

Merged · 2 commits · Oct 30, 2023
92 changes: 91 additions & 1 deletion docs/source/users/index.md
@@ -159,6 +159,96 @@ responsible for all charges they incur when they make API requests. Review your
provider's pricing information before submitting requests via Jupyter AI.
:::

### Custom model providers

You can define new providers using the LangChain framework API. Custom providers
inherit from both `jupyter-ai`'s ``BaseProvider`` and `langchain`'s [``LLM``][LLM].
You can either import a pre-defined model from the [LangChain LLM list][langchain_llms]
or define a [custom LLM][custom_llm].
In the example below, we define a provider with two models using
a dummy ``FakeListLLM`` model, which returns responses from the ``responses``
keyword argument.

```python
# my_package/my_provider.py
from jupyter_ai_magics import BaseProvider
from langchain.llms import FakeListLLM


class MyProvider(BaseProvider, FakeListLLM):
    id = "my_provider"
    name = "My Provider"
    model_id_key = "model"
    models = [
        "model_a",
        "model_b",
    ]

    def __init__(self, **kwargs):
        # FakeListLLM ignores the prompt and returns canned responses,
        # so pick the response list based on the requested model id.
        model = kwargs.get("model_id")
        kwargs["responses"] = (
            ["This is a response from model 'a'"]
            if model == "model_a"
            else ["This is a response from model 'b'"]
        )
        super().__init__(**kwargs)
```


If the new provider inherits from [``BaseChatModel``][BaseChatModel], it will be available
both in the chat UI and with magic commands, as in the sketch below. Otherwise, users can
only use the new provider with magic commands.
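
A minimal sketch of such a chat provider, using LangChain's dummy
``FakeListChatModel`` (the import path below is an assumption and may differ
between LangChain versions):

```python
# my_package/my_chat_provider.py -- illustrative sketch, not part of this PR
from jupyter_ai_magics import BaseProvider
from langchain.chat_models.fake import FakeListChatModel


class MyChatProvider(BaseProvider, FakeListChatModel):
    id = "my_chat_provider"
    name = "My Chat Provider"
    model_id_key = "model"
    models = ["chat_model_a"]

    def __init__(self, **kwargs):
        # FakeListChatModel cycles through these canned responses.
        kwargs["responses"] = ["Hello from the chat UI!"]
        super().__init__(**kwargs)
```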

To make the new provider available, you need to declare it as an [entry point](https://setuptools.pypa.io/en/latest/userguide/entry_point.html):

```toml
# my_package/pyproject.toml
[project]
name = "my_package"
version = "0.0.1"

[project.entry-points."jupyter_ai.model_providers"]
my-provider = "my_provider:MyProvider"
```

To test that the above minimal provider package works, install it with:

```sh
# from `my_package` directory
pip install -e .
```
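
Before restarting JupyterLab, you can check that the entry point is visible to
Python at all. This is a quick sanity check using the standard library's
``importlib.metadata`` (the ``group=`` filter shown here requires Python 3.10
or newer):

```python
from importlib.metadata import entry_points

# List everything registered under jupyter-ai's provider entry point group.
for ep in entry_points(group="jupyter_ai.model_providers"):
    print(ep.name, "->", ep.value)  # expected: my-provider -> my_provider:MyProvider
```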

Then, restart JupyterLab. You should now see an info message in the log that mentions
your new provider's `id`:

```
[I 2023-10-29 13:56:16.915 AiExtension] Registered model provider `my_provider`.
```
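
You can then use the new models from a notebook via the ``%%ai`` magic,
addressing them as ``provider_id:model_id``. For example, load the magics in
one cell:

```
%load_ext jupyter_ai_magics
```

and invoke a dummy model in another (``FakeListLLM`` ignores the prompt and
returns its canned response):

```
%%ai my_provider:model_a
Say hello
```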

[langchain_llms]: https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.llms
[custom_llm]: https://python.langchain.com/docs/modules/model_io/models/llms/custom_llm
[LLM]: https://api.python.langchain.com/en/latest/llms/langchain.llms.base.LLM.html#langchain.llms.base.LLM
[BaseChatModel]: https://api.python.langchain.com/en/latest/chat_models/langchain.chat_models.base.BaseChatModel.html


### Customizing prompt templates

To modify the prompt template for a given format, override the ``get_prompt_template`` method:

```python
from jupyter_ai_magics import BaseProvider
from langchain.llms import FakeListLLM
from langchain.prompts import PromptTemplate


class MyProvider(BaseProvider, FakeListLLM):
    # (... properties as above ...)
    def get_prompt_template(self, format) -> PromptTemplate:
        if format == "code":
            return PromptTemplate.from_template(
                "{prompt}\n\nProduce output as source code only, "
                "with no text or explanation before or after it."
            )
        return super().get_prompt_template(format)
```
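
With this override in place, the custom template is used whenever code output
is requested, for example via the magic's ``-f code`` (``--format code``)
option:

```
%%ai my_provider:model_a -f code
Write a function that squares a number.
```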

## The chat interface

The easiest way to get started with Jupyter AI is to use the chat interface.
@@ -689,7 +779,7 @@ Write a poem about C++.

You can also define a custom LangChain chain:

-```
+```python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
10 changes: 6 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
@@ -30,9 +30,10 @@ def get_lm_providers(
    for model_provider_ep in model_provider_eps:
        try:
            provider = model_provider_ep.load()
-        except:
+        except Exception as e:
            log.error(
-                f"Unable to load model provider class from entry point `{model_provider_ep.name}`."
+                f"Unable to load model provider class from entry point `{model_provider_ep.name}`: %s.",
+                e,
            )
            continue
        if not is_provider_allowed(provider.id, restrictions):
@@ -58,9 +59,10 @@ def get_em_providers(
    for model_provider_ep in model_provider_eps:
        try:
            provider = model_provider_ep.load()
-        except:
+        except Exception as e:
            log.error(
-                f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`."
+                f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`: %s.",
+                e,
            )
            continue
        if not is_provider_allowed(provider.id, restrictions):
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/handlers.py
@@ -182,7 +182,7 @@ def broadcast_message(self, message: Message):
        self.chat_history.append(message)

    async def on_message(self, message):
-        self.log.debug("Message recieved: %s", message)
+        self.log.debug("Message received: %s", message)

        try:
            message = json.loads(message)