[2.x] Upgrade to LangChain v0.3 and Pydantic v2 (#1199)
* remove import-linter from project

* initial upgrade to langchain~=0.3, pydantic~=2.0

* default to `None` for all `Optional` fields explicitly

* fix history impl for Pydantic v2, fixes chat

* prefer `.model_dump_json()` over `.json()`

Addresses a Pydantic v2 deprecation warning, as `BaseModel.json()` is
now deprecated in favor of `BaseModel.model_dump_json()`.

* replace `.dict()` with `.model_dump()`.

`BaseModel.dict()` is deprecated in favor of `BaseModel.model_dump()` in
Pydantic v2.

* fix BaseProvider.server_settings

* fix OpenRouterProvider

* fix remaining unit tests

* address all Pydantic v1 deprecation warnings

* pre-commit

* fix mypy
dlqqq authored Jan 15, 2025
1 parent b2bb5b5 commit 26593dc
Showing 27 changed files with 189 additions and 271 deletions.
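The file diffs below repeat a handful of Pydantic v1 → v2 migration patterns named in the commit message: the inner `class Config` becomes a `model_config = ConfigDict(...)` attribute, `.dict()` becomes `.model_dump()`, and `.json()` becomes `.model_dump_json()`. A minimal sketch of these patterns, using a hypothetical `ExampleModel` that is not part of jupyter-ai:

    from pydantic import BaseModel, ConfigDict

    class ExampleModel(BaseModel):
        # Pydantic v1 spelled this as an inner class: `class Config: extra = Extra.allow`
        # Pydantic v2 uses a `model_config` attribute built from ConfigDict
        model_config = ConfigDict(extra="allow")

        name: str

    m = ExampleModel(name="chat", custom_field="kept because extra='allow'")
    print(m.model_dump())       # Pydantic v1 equivalent: m.dict()
    print(m.model_dump_json())  # Pydantic v1 equivalent: m.json()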
34 changes: 0 additions & 34 deletions .github/workflows/lint.yml
@@ -17,37 +17,3 @@ jobs:
run: jlpm
- name: Lint TypeScript source
run: jlpm lerna run lint:check

lint_py_imports:
name: Lint Python imports
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Echo environment details
run: |
which python
which pip
python --version
pip --version
# see #546 for context on why this is necessary
- name: Create venv
run: |
python -m venv lint_py_imports
- name: Install job dependencies
run: |
source ./lint_py_imports/bin/activate
pip install jupyterlab~=4.0
pip install import-linter~=1.12.1
- name: Install Jupyter AI packages from source
run: |
source ./lint_py_imports/bin/activate
jlpm install
jlpm install-from-src
- name: Lint Python imports
run: |
source ./lint_py_imports/bin/activate
lint-imports
@@ -6,19 +6,20 @@
Field,
MultiEnvAuthStrategy,
)
from langchain.pydantic_v1 import BaseModel, Extra
from langchain_community.embeddings import (
GPT4AllEmbeddings,
HuggingFaceHubEmbeddings,
QianfanEmbeddingsEndpoint,
)
from pydantic import BaseModel, ConfigDict


class BaseEmbeddingsProvider(BaseModel):
"""Base class for embedding providers"""

class Config:
extra = Extra.allow
# pydantic v2 model config
# upstream docs: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra
model_config = ConfigDict(extra="allow")

id: ClassVar[str] = ...
"""ID for this provider class."""
2 changes: 1 addition & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
@@ -442,7 +442,7 @@ def handle_error(self, args: ErrorArgs):

prompt = f"Explain the following error:\n\n{last_error}"
# Set CellArgs based on ErrorArgs
values = args.dict()
values = args.model_dump()
values["type"] = "root"
cell_args = CellArgs(**values)

@@ -1,6 +1,6 @@
from typing import List, Literal, Optional

from langchain.pydantic_v1 import BaseModel
from pydantic import BaseModel


class InlineCompletionRequest(BaseModel):
@@ -21,12 +21,12 @@ class InlineCompletionRequest(BaseModel):
# whether to stream the response (if supported by the model)
stream: bool
# path to the notebook or file for which the completions are generated
path: Optional[str]
path: Optional[str] = None
# language inferred from the document mime type (if possible)
language: Optional[str]
language: Optional[str] = None
# identifier of the cell for which the completions are generated if in a notebook
# previous cells and following cells can be used to learn the wider context
cell_id: Optional[str]
cell_id: Optional[str] = None


class InlineCompletionItem(BaseModel):
@@ -36,9 +36,9 @@ class InlineCompletionItem(BaseModel):
"""

insertText: str
filterText: Optional[str]
isIncomplete: Optional[bool]
token: Optional[str]
filterText: Optional[str] = None
isIncomplete: Optional[bool] = None
token: Optional[str] = None


class CompletionError(BaseModel):
@@ -59,7 +59,7 @@ class InlineCompletionReply(BaseModel):
list: InlineCompletionList
# number of request for which we are replying
reply_to: int
error: Optional[CompletionError]
error: Optional[CompletionError] = None


class InlineCompletionStreamChunk(BaseModel):
@@ -69,7 +69,7 @@ class InlineCompletionStreamChunk(BaseModel):
response: InlineCompletionItem
reply_to: int
done: bool
error: Optional[CompletionError]
error: Optional[CompletionError] = None


__all__ = [
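The `= None` defaults added above (and throughout this commit) reflect a behavior change in Pydantic v2: an `Optional[...]` annotation by itself no longer implies a default, so a field such as `path: Optional[str]` would otherwise become required. A small sketch of the difference, using throwaway model names rather than the real request/reply models:

    from typing import Optional

    from pydantic import BaseModel, ValidationError

    class WithoutDefault(BaseModel):
        path: Optional[str]          # required in v2: may be None, but must be passed

    class WithDefault(BaseModel):
        path: Optional[str] = None   # truly optional, matching the v1 behavior

    try:
        WithoutDefault()             # `path` was not passed
    except ValidationError as err:
        print(err.errors()[0]["type"])  # prints "missing"

    print(WithDefault())             # prints "path=None"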
@@ -1,4 +1,4 @@
from langchain.pydantic_v1 import BaseModel
from pydantic import BaseModel


class Persona(BaseModel):
20 changes: 10 additions & 10 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
@@ -2,7 +2,7 @@
from typing import Literal, Optional, get_args

import click
from langchain.pydantic_v1 import BaseModel
from pydantic import BaseModel

FORMAT_CHOICES_TYPE = Literal[
"code", "html", "image", "json", "markdown", "math", "md", "text"
@@ -46,23 +46,23 @@ class CellArgs(BaseModel):
type: Literal["root"] = "root"
model_id: str
format: FORMAT_CHOICES_TYPE
model_parameters: Optional[str]
model_parameters: Optional[str] = None
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
response_path: Optional[str]
region_name: Optional[str] = None
request_schema: Optional[str] = None
response_path: Optional[str] = None


# Should match CellArgs
class ErrorArgs(BaseModel):
type: Literal["error"] = "error"
model_id: str
format: FORMAT_CHOICES_TYPE
model_parameters: Optional[str]
model_parameters: Optional[str] = None
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
response_path: Optional[str]
region_name: Optional[str] = None
request_schema: Optional[str] = None
response_path: Optional[str] = None


class HelpArgs(BaseModel):
@@ -75,7 +75,7 @@ class VersionArgs(BaseModel):

class ListArgs(BaseModel):
type: Literal["list"] = "list"
provider_id: Optional[str]
provider_id: Optional[str] = None


class RegisterArgs(BaseModel):
@@ -2,8 +2,7 @@

from jupyter_ai_magics import BaseProvider
from jupyter_ai_magics.providers import EnvAuthStrategy, TextField
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import get_from_dict_or_env
from langchain_openai import ChatOpenAI


@@ -31,7 +30,9 @@ class OpenRouterProvider(BaseProvider, ChatOpenRouter):
]

def __init__(self, **kwargs):
openrouter_api_key = kwargs.pop("openrouter_api_key", None)
openrouter_api_key = get_from_dict_or_env(
kwargs, key="openrouter_api_key", env_key="OPENROUTER_API_KEY", default=None
)
openrouter_api_base = kwargs.pop(
"openai_api_base", "https://openrouter.ai/api/v1"
)
@@ -42,14 +43,6 @@ def __init__(self, **kwargs):
**kwargs,
)

@root_validator(pre=False, skip_on_failure=True, allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "openai_api_key", "OPENROUTER_API_KEY")
)
return super().validate_environment(values)

@classmethod
def is_api_key_exc(cls, e: Exception):
import openai
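For context on the `get_from_dict_or_env` call that replaces the removed `root_validator`: it returns the value from the passed dict if present, falls back to the named environment variable otherwise, and only then uses the default. A rough sketch, assuming `langchain_core` is installed and using placeholder key values:

    import os

    from langchain_core.utils import get_from_dict_or_env

    os.environ["OPENROUTER_API_KEY"] = "sk-or-from-env"

    # An explicit kwarg takes precedence over the environment variable ...
    print(get_from_dict_or_env({"openrouter_api_key": "sk-or-from-kwargs"},
                               key="openrouter_api_key",
                               env_key="OPENROUTER_API_KEY"))

    # ... otherwise the environment variable is used.
    print(get_from_dict_or_env({}, key="openrouter_api_key",
                               env_key="OPENROUTER_API_KEY"))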
82 changes: 21 additions & 61 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -24,22 +24,14 @@
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.pydantic_v1 import BaseModel, Extra
from langchain.schema import LLMResult
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.llms import AI21, GPT4All, HuggingFaceEndpoint, Together
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM

# this is necessary because `langchain.pydantic_v1.main` does not include
# `ModelMetaclass`, as it is not listed in `__all__` by the `pydantic.main`
# subpackage.
try:
from pydantic.v1.main import ModelMetaclass
except:
from pydantic.main import ModelMetaclass
from pydantic import BaseModel, ConfigDict

from . import completion_utils as completion
from .models.completion import (
@@ -122,7 +114,7 @@ class EnvAuthStrategy(BaseModel):
name: str
"""The name of the environment variable, e.g. `'ANTHROPIC_API_KEY'`."""

keyword_param: Optional[str]
keyword_param: Optional[str] = None
"""
If unset (default), the authentication token is provided as a keyword
argument with the parameter equal to the environment variable name in
@@ -177,51 +169,10 @@ class IntegerField(BaseModel):
Field = Union[TextField, MultilineTextField, IntegerField]


class ProviderMetaclass(ModelMetaclass):
"""
A metaclass that ensures all class attributes defined inline within the
class definition are accessible and included in `Class.__dict__`.
This is necessary because Pydantic drops any ClassVars that are defined as
an instance field by a parent class, even if they are defined inline within
the class definition. We encountered this case when `langchain` added a
`name` attribute to a parent class shared by all `Provider`s, which caused
`Provider.name` to be inaccessible. See #558 for more info.
"""

def __new__(mcs, name, bases, namespace, **kwargs):
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
for key in namespace:
# skip private class attributes
if key.startswith("_"):
continue
# skip class attributes already listed in `cls.__dict__`
if key in cls.__dict__:
continue

setattr(cls, key, namespace[key])

return cls

@property
def server_settings(cls):
return cls._server_settings

@server_settings.setter
def server_settings(cls, value):
if cls._server_settings is not None:
raise AttributeError("'server_settings' attribute was already set")
cls._server_settings = value

_server_settings = None


class BaseProvider(BaseModel, metaclass=ProviderMetaclass):
#
# pydantic config
#
class Config:
extra = Extra.allow
class BaseProvider(BaseModel):
# pydantic v2 model config
# upstream docs: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.extra
model_config = ConfigDict(extra="allow")

#
# class attrs
Expand All @@ -236,15 +187,25 @@ class Config:
"""List of supported models by their IDs. For registry providers, this will
be just ["*"]."""

help: ClassVar[str] = None
help: ClassVar[Optional[str]] = None
"""Text to display in lieu of a model list for a registry provider that does
not provide a list of models."""

model_id_key: ClassVar[str] = ...
"""Kwarg expected by the upstream LangChain provider."""
model_id_key: ClassVar[Optional[str]] = None
"""
Optional field which specifies the key under which `model_id` is passed to
the parent LangChain class.
If unset, this defaults to "model_id".
"""

model_id_label: ClassVar[str] = ""
"""Human-readable label of the model ID."""
model_id_label: ClassVar[Optional[str]] = None
"""
Optional field which sets the label shown in the UI allowing users to
select/type a model ID.
If unset, the label shown in the UI defaults to "Model ID".
"""

pypi_package_deps: ClassVar[List[str]] = []
"""List of PyPi package dependencies."""
@@ -586,7 +547,6 @@ def __init__(self, **kwargs):

id = "gpt4all"
name = "GPT4All"
docs = "https://docs.gpt4all.io/gpt4all_python.html"
models = [
"ggml-gpt4all-j-v1.2-jazzy",
"ggml-gpt4all-j-v1.3-groovy",
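The reworked `model_id_key` and `model_id_label` class attributes documented above are the attributes a provider subclass overrides. A hypothetical provider, loosely following the custom-provider pattern from the jupyter-ai documentation (the class, its models, and its canned responses are made up and not part of this commit; assumes `jupyter_ai_magics` and `langchain_community` are installed):

    from jupyter_ai_magics import BaseProvider
    from langchain_community.llms import FakeListLLM

    class MyProvider(BaseProvider, FakeListLLM):
        id = "my_provider"
        name = "My Provider"
        models = ["model_a", "model_b"]
        model_id_key = "model"         # `model_id` is passed to the parent class under this key
        model_id_label = "Model name"  # replaces the default "Model ID" label in the UI

        def __init__(self, **kwargs):
            kwargs["responses"] = ["This is a canned response."]
            super().__init__(**kwargs)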
@@ -0,0 +1,27 @@
from typing import ClassVar, Optional

from pydantic import BaseModel

from ..providers import BaseProvider


def test_provider_classvars():
"""
Asserts that class attributes are not omitted due to parent classes defining
an instance field of the same name. This was a bug present in Pydantic v1,
which led to an issue documented in #558.
This bug is fixed as of `pydantic==2.10.2`, but we will keep this test in
case this behavior changes in future releases.
"""

class Parent(BaseModel):
test: Optional[str] = None

class Base(BaseModel):
test: ClassVar[str]

class Child(Base, Parent):
test: ClassVar[str] = "expected"

assert Child.test == "expected"