Document how to add custom model providers #420

Merged · 2 commits · Oct 30, 2023
90 changes: 89 additions & 1 deletion docs/source/users/index.md
@@ -159,6 +159,94 @@ responsible for all charges they incur when they make API requests. Review your
provider's pricing information before submitting requests via Jupyter AI.
:::

### Custom model providers

You can define a new provider by building on the LangChain framework API. The provider
must inherit from both `jupyter-ai`'s ``BaseProvider`` and `langchain`'s [``LLM``][LLM].
You can either import a pre-defined model from the [LangChain LLM list][langchain_llms],
or define a [custom LLM][custom_llm].
In the example below, we demonstrate defining a provider with two models using
a dummy ``FakeListLLM`` model, which returns responses from the ``responses``
keyword argument.

```python
# my_package/my_provider.py
from jupyter_ai_magics import BaseProvider
from langchain.llms import FakeListLLM


class MyProvider(BaseProvider, FakeListLLM):
    id = "my_provider"
    name = "My Provider"
    model_id_key = "model"
    models = [
        "model_a",
        "model_b",
    ]

    def __init__(self, **kwargs):
        # Jupyter AI passes the selected model as `model_id`;
        # choose the canned response accordingly.
        model = kwargs.get("model_id")
        kwargs["responses"] = (
            ["This is a response from model 'a'"]
            if model == "model_a"
            else ["This is a response from model 'b'"]
        )
        super().__init__(**kwargs)
```


The provider will be available in both the chat interface and the magics if it
inherits from [``BaseChatModel``][BaseChatModel]; otherwise it will only be available in the magics.
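
For example, a provider usable from the chat could be built on a LangChain chat model. The sketch below is a minimal, hypothetical example: it assumes ``FakeListChatModel`` is importable from ``langchain.chat_models.fake`` in your LangChain version, and the file and class names are illustrative.

```python
# my_package/my_chat_provider.py — a minimal sketch, not part of the
# documented example; assumes langchain provides FakeListChatModel here.
from jupyter_ai_magics import BaseProvider
from langchain.chat_models.fake import FakeListChatModel


class MyChatProvider(BaseProvider, FakeListChatModel):
    id = "my_chat_provider"
    name = "My Chat Provider"
    model_id_key = "model"
    models = ["chat_model_a"]

    def __init__(self, **kwargs):
        # FakeListChatModel replies with the canned responses below.
        kwargs["responses"] = ["Hello from the chat model!"]
        super().__init__(**kwargs)
```

Since ``FakeListChatModel`` derives from ``BaseChatModel``, a provider like this should appear in the chat interface as well as in the magics.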

To plug in the new provider, you will need to declare it via an [entry point](https://setuptools.pypa.io/en/latest/userguide/entry_point.html):

```toml
# my_package/pyproject.toml
[project]
name = "my_package"
version = "0.0.1"

[project.entry-points."jupyter_ai.model_providers"]
my-provider = "my_provider:MyProvider"
```

To test that the above minimal provider package works, install it with:

```sh
# from `my_package` directory
pip install -e .
```
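
Before restarting, you can optionally confirm that the entry point is discoverable by the interpreter. This is a quick sanity check, assuming Python 3.10+ for the ``group`` keyword of ``entry_points``:

```python
# Optional sanity check; requires Python 3.10+ for entry_points(group=...).
from importlib.metadata import entry_points

eps = entry_points(group="jupyter_ai.model_providers")
print([ep.name for ep in eps])  # expected to include: 'my-provider'
```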

Then restart JupyterLab; the server log should now include a line such as:

```
[I 2023-10-29 13:56:16.915 AiExtension] Registered model provider `my_provider`.
```
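
Once registered, the new models can be selected in the chat settings panel, or addressed from the magics using the ``provider-id:model-id`` syntax, for example:

```
%%ai my_provider:model_a
Any prompt will return the canned response defined above.
```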

[langchain_llms]: https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.llms
[custom_llm]: https://python.langchain.com/docs/modules/model_io/models/llms/custom_llm
[LLM]: https://api.python.langchain.com/en/latest/llms/langchain.llms.base.LLM.html#langchain.llms.base.LLM
[BaseChatModel]: https://api.python.langchain.com/en/latest/chat_models/langchain.chat_models.base.BaseChatModel.html


### Customising prompt templates

To modify the prompt template for a given format, override the implementation of the ``get_prompt_template`` method:

```python
from jupyter_ai_magics import BaseProvider
from langchain.llms import FakeListLLM
from langchain.prompts import PromptTemplate


class MyProvider(BaseProvider, FakeListLLM):
    # (... properties as above ...)

    def get_prompt_template(self, format) -> PromptTemplate:
        if format == "code":
            return PromptTemplate.from_template(
                "{prompt}\n\nProduce output as source code only, "
                "with no text or explanation before or after it."
            )
        return super().get_prompt_template(format)
```
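
To verify the override, you can render the template directly; this is a hypothetical check, assuming ``MyProvider`` is instantiable as defined above:

```python
provider = MyProvider(model_id="model_a")
template = provider.get_prompt_template("code")
# The rendered prompt should end with the custom code-only instruction.
print(template.format(prompt="Write a function that squares a number."))
```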

## The chat interface

The easiest way to get started with Jupyter AI is to use the chat interface.
@@ -689,7 +777,7 @@ Write a poem about C++.

You can also define a custom LangChain chain:

```python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
```
10 changes: 6 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
@@ -30,9 +30,10 @@ def get_lm_providers(
     for model_provider_ep in model_provider_eps:
         try:
             provider = model_provider_ep.load()
-        except:
+        except Exception as e:
             log.error(
-                f"Unable to load model provider class from entry point `{model_provider_ep.name}`."
+                f"Unable to load model provider class from entry point `{model_provider_ep.name}`: %s.",
+                e,
             )
             continue
         if not is_provider_allowed(provider.id, restrictions):
@@ -58,9 +59,10 @@ def get_em_providers(
     for model_provider_ep in model_provider_eps:
         try:
             provider = model_provider_ep.load()
-        except:
+        except Exception as e:
             log.error(
-                f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`."
+                f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`: %s.",
+                e,
             )
             continue
         if not is_provider_allowed(provider.id, restrictions):
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/handlers.py
@@ -182,7 +182,7 @@ def broadcast_message(self, message: Message):
         self.chat_history.append(message)

     async def on_message(self, message):
-        self.log.debug("Message recieved: %s", message)
+        self.log.debug("Message received: %s", message)

         try:
             message = json.loads(message)