From c69d3b5bd9fcd7053fdae8febb68f9f316cbf299 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Fri, 24 Jan 2025 12:24:50 +0100 Subject: [PATCH] core: Add ruff rules D102 (missing docstring) --- .../langchain_core/beta/runnables/context.py | 38 ++++++++ libs/core/langchain_core/caches.py | 3 +- .../document_loaders/langsmith.py | 5 +- libs/core/langchain_core/documents/base.py | 5 ++ libs/core/langchain_core/embeddings/fake.py | 5 ++ .../language_models/chat_models.py | 56 ++++++++++++ .../langchain_core/language_models/fake.py | 8 ++ .../language_models/fake_chat_models.py | 11 +++ .../langchain_core/language_models/llms.py | 12 +++ libs/core/langchain_core/load/load.py | 1 + libs/core/langchain_core/load/serializable.py | 1 + libs/core/langchain_core/messages/base.py | 2 + libs/core/langchain_core/messages/tool.py | 1 + .../langchain_core/output_parsers/base.py | 4 + .../langchain_core/output_parsers/list.py | 3 + libs/core/langchain_core/prompts/prompt.py | 5 +- libs/core/langchain_core/runnables/base.py | 90 +++++++++++++++++-- libs/core/langchain_core/runnables/branch.py | 5 +- .../langchain_core/runnables/configurable.py | 19 +++- .../langchain_core/runnables/fallbacks.py | 13 ++- libs/core/langchain_core/runnables/history.py | 4 +- .../langchain_core/runnables/passthrough.py | 34 ++++++- libs/core/langchain_core/runnables/retry.py | 7 +- libs/core/langchain_core/runnables/router.py | 11 ++- libs/core/langchain_core/tools/base.py | 6 ++ libs/core/langchain_core/tools/convert.py | 11 +-- libs/core/langchain_core/tools/simple.py | 2 + libs/core/langchain_core/tools/structured.py | 2 + libs/core/langchain_core/tracers/base.py | 19 ++++ libs/core/langchain_core/utils/iter.py | 1 + libs/core/langchain_core/utils/json.py | 5 +- libs/core/langchain_core/utils/loading.py | 1 + libs/core/langchain_core/utils/mustache.py | 1 + libs/core/langchain_core/vectorstores/base.py | 9 +- .../langchain_core/vectorstores/in_memory.py | 47 ++++++++++ libs/core/pyproject.toml | 5 +- libs/core/tests/unit_tests/test_tools.py | 9 +- 37 files changed, 425 insertions(+), 36 deletions(-) diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py index 2be721387cbcd..9aa8610dc7ac2 100644 --- a/libs/core/langchain_core/beta/runnables/context.py +++ b/libs/core/langchain_core/beta/runnables/context.py @@ -13,6 +13,7 @@ ) from pydantic import ConfigDict +from typing_extensions import override from langchain_core._api.beta_decorator import beta from langchain_core.runnables.base import ( @@ -166,6 +167,7 @@ def __str__(self) -> str: @property def ids(self) -> list[str]: + """The context getter ids.""" prefix = self.prefix + "/" if self.prefix else "" keys = self.key if isinstance(self.key, list) else [self.key] return [ @@ -174,6 +176,7 @@ def ids(self) -> list[str]: ] @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: return super().config_specs + [ ConfigurableFieldSpec( @@ -183,6 +186,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]: for id_ in self.ids ] + @override def invoke( self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: @@ -193,6 +197,7 @@ def invoke( else: return configurable[self.ids[0]]() + @override async def ainvoke( self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: @@ -253,6 +258,7 @@ def __str__(self) -> str: @property def ids(self) -> list[str]: + """The context setter ids.""" prefix = self.prefix + "/" if self.prefix else "" return [ 
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}" @@ -260,6 +266,7 @@ def ids(self) -> list[str]: ] @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: mapper_config_specs = [ s @@ -281,6 +288,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]: for id_ in self.ids ] + @override def invoke( self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: @@ -293,6 +301,7 @@ def invoke( configurable[id_](input) return input + @override async def ainvoke( self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: @@ -361,6 +370,11 @@ def create_scope(scope: str, /) -> "PrefixContext": @staticmethod def getter(key: Union[str, list[str]], /) -> ContextGet: + """Return a context getter. + + Args: + key: The context getter key. + """ return ContextGet(key=key) @staticmethod @@ -370,6 +384,13 @@ def setter( /, **kwargs: SetValue, ) -> ContextSet: + """Return a context setter. + + Args: + _key: The context setter key. + _value: The context setter value. + **kwargs: Additional context setter key-value pairs. + """ return ContextSet(_key, _value, prefix="", **kwargs) @@ -379,9 +400,19 @@ class PrefixContext: prefix: str = "" def __init__(self, prefix: str = ""): + """Initialize the prefix context. + + Args: + prefix: The prefix. + """ self.prefix = prefix def getter(self, key: Union[str, list[str]], /) -> ContextGet: + """Return a prefixed context getter. + + Args: + key: The context getter key. + """ return ContextGet(key=key, prefix=self.prefix) def setter( @@ -391,6 +422,13 @@ def setter( /, **kwargs: SetValue, ) -> ContextSet: + """Return a prefixed context setter. + + Args: + _key: The context setter key. + _value: The context setter value. + **kwargs: Additional context setter key-value pairs. + """ return ContextSet(_key, _value, prefix=self.prefix, **kwargs) diff --git a/libs/core/langchain_core/caches.py b/libs/core/langchain_core/caches.py index d534d70d25ffc..bc0bc1c041ca5 100644 --- a/libs/core/langchain_core/caches.py +++ b/libs/core/langchain_core/caches.py @@ -1,5 +1,4 @@ -""" -.. warning:: +""".. warning:: Beta Feature! **Cache** provides an optional caching layer for LLMs. diff --git a/libs/core/langchain_core/document_loaders/langsmith.py b/libs/core/langchain_core/document_loaders/langsmith.py index 39fda02af5792..165c34b1e417d 100644 --- a/libs/core/langchain_core/document_loaders/langsmith.py +++ b/libs/core/langchain_core/document_loaders/langsmith.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Optional, Union from langsmith import Client as LangSmithClient +from typing_extensions import override from langchain_core.document_loaders.base import BaseLoader from langchain_core.documents import Document @@ -53,7 +54,8 @@ def __init__( client: Optional[LangSmithClient] = None, **client_kwargs: Any, ) -> None: - """ + """Initialize a LangSmith loader. + Args: dataset_id: The ID of the dataset to filter by. Defaults to None. dataset_name: The name of the dataset to filter by. Defaults to None. 
@@ -95,6 +97,7 @@ def __init__( self.metadata = metadata self.filter = filter + @override def lazy_load(self) -> Iterator[Document]: for example in self._client.list_examples( dataset_id=self.dataset_id, diff --git a/libs/core/langchain_core/documents/base.py b/libs/core/langchain_core/documents/base.py index 2adfe1a718397..33143b16cc101 100644 --- a/libs/core/langchain_core/documents/base.py +++ b/libs/core/langchain_core/documents/base.py @@ -43,6 +43,11 @@ class BaseMedia(Serializable): @field_validator("id", mode="before") def cast_id_to_str(cls, id_value: Any) -> Optional[str]: + """Coerce the id field to a string. + + Args: + id_value: The id value to coerce. + """ if id_value is not None: return str(id_value) else: diff --git a/libs/core/langchain_core/embeddings/fake.py b/libs/core/langchain_core/embeddings/fake.py index 6f7c4241d54b2..d11ea9bd30d19 100644 --- a/libs/core/langchain_core/embeddings/fake.py +++ b/libs/core/langchain_core/embeddings/fake.py @@ -4,6 +4,7 @@ import hashlib from pydantic import BaseModel +from typing_extensions import override from langchain_core.embeddings import Embeddings @@ -55,9 +56,11 @@ def _get_embedding(self) -> list[float]: return list(np.random.default_rng().normal(size=self.size)) + @override def embed_documents(self, texts: list[str]) -> list[list[float]]: return [self._get_embedding() for _ in texts] + @override def embed_query(self, text: str) -> list[float]: return self._get_embedding() @@ -116,8 +119,10 @@ def _get_seed(self, text: str) -> int: """Get a seed for the random generator, using the hash of the text.""" return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8 + @override def embed_documents(self, texts: list[str]) -> list[list[float]]: return [self._get_embedding(seed=self._get_seed(_)) for _ in texts] + @override def embed_query(self, text: str) -> list[float]: return self._get_embedding(seed=self._get_seed(text)) diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 6aaaf7d4ca80a..d55b3517fe940 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -270,6 +270,7 @@ def _convert_input(self, input: LanguageModelInput) -> PromptValue: ) raise ValueError(msg) # noqa: TRY004 + @override def invoke( self, input: LanguageModelInput, @@ -293,6 +294,7 @@ def invoke( ).generations[0][0], ).message + @override async def ainvoke( self, input: LanguageModelInput, @@ -349,6 +351,7 @@ def _should_stream( handlers = run_manager.handlers if run_manager else [] return any(isinstance(h, _StreamingCallbackHandler) for h in handlers) + @override def stream( self, input: LanguageModelInput, @@ -423,6 +426,7 @@ def stream( run_manager.on_llm_end(LLMResult(generations=[[generation]])) + @override async def astream( self, input: LanguageModelInput, @@ -780,6 +784,7 @@ async def agenerate( ] return output + @override def generate_prompt( self, prompts: list[PromptValue], @@ -790,6 +795,7 @@ def generate_prompt( prompt_messages = [p.to_messages() for p in prompts] return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs) + @override async def agenerate_prompt( self, prompts: list[PromptValue], @@ -1019,6 +1025,20 @@ def __call__( callbacks: Callbacks = None, **kwargs: Any, ) -> BaseMessage: + """Call the model. + + Args: + messages: List of messages. + stop: Stop words to use when generating. 
Model output is cut off at the + first occurrence of any of these substrings. + callbacks: Callbacks to pass through. Used for executing additional + functionality, such as logging or streaming, throughout generation. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + The model output message. + """ generation = self.generate( [messages], stop=stop, callbacks=callbacks, **kwargs ).generations[0][0] @@ -1049,12 +1069,37 @@ async def _call_async( def call_as_llm( self, message: str, stop: Optional[list[str]] = None, **kwargs: Any ) -> str: + """Call the model. + + Args: + message: The input message. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + The model output string. + """ return self.predict(message, stop=stop, **kwargs) @deprecated("0.1.7", alternative="invoke", removal="1.0") + @override def predict( self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any ) -> str: + """Predict the next message. + + Args: + text: The input message. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + The predicted output string. + """ _stop = None if stop is None else list(stop) result = self([HumanMessage(content=text)], stop=_stop, **kwargs) if isinstance(result.content, str): @@ -1064,6 +1109,7 @@ def predict( raise ValueError(msg) # noqa: TRY004 @deprecated("0.1.7", alternative="invoke", removal="1.0") + @override def predict_messages( self, messages: list[BaseMessage], @@ -1075,6 +1121,7 @@ def predict_messages( return self(messages, stop=_stop, **kwargs) @deprecated("0.1.7", alternative="ainvoke", removal="1.0") + @override async def apredict( self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any ) -> str: @@ -1089,6 +1136,7 @@ async def apredict( raise ValueError(msg) # noqa: TRY004 @deprecated("0.1.7", alternative="ainvoke", removal="1.0") + @override async def apredict_messages( self, messages: list[BaseMessage], @@ -1117,6 +1165,14 @@ def bind_tools( ], **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tools to the model. + + Args: + tools: Sequence of tools to bind to the model. + + Returns: + A Runnable that returns a message. 
+ """ raise NotImplementedError def with_structured_output( diff --git a/libs/core/langchain_core/language_models/fake.py b/libs/core/langchain_core/language_models/fake.py index 64bb637068dab..3a71073c0b08c 100644 --- a/libs/core/langchain_core/language_models/fake.py +++ b/libs/core/langchain_core/language_models/fake.py @@ -3,6 +3,8 @@ from collections.abc import AsyncIterator, Iterator, Mapping from typing import Any, Optional +from typing_extensions import override + from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -31,10 +33,12 @@ class FakeListLLM(LLM): """ @property + @override def _llm_type(self) -> str: """Return type of llm.""" return "fake-list" + @override def _call( self, prompt: str, @@ -50,6 +54,7 @@ def _call( self.i = 0 return response + @override async def _acall( self, prompt: str, @@ -66,6 +71,7 @@ async def _acall( return response @property + @override def _identifying_params(self) -> Mapping[str, Any]: return {"responses": self.responses} @@ -86,6 +92,7 @@ class FakeStreamingListLLM(FakeListLLM): error_on_chunk_number: Optional[int] = None """If set, will raise an exception on the specified chunk number.""" + @override def stream( self, input: LanguageModelInput, @@ -106,6 +113,7 @@ def stream( raise FakeListLLMError yield c + @override async def astream( self, input: LanguageModelInput, diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py index 9bd62f1267ef7..1f9cd9fab4c50 100644 --- a/libs/core/langchain_core/language_models/fake_chat_models.py +++ b/libs/core/langchain_core/language_models/fake_chat_models.py @@ -6,6 +6,8 @@ from collections.abc import AsyncIterator, Iterator from typing import Any, Optional, Union, cast +from typing_extensions import override + from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -26,6 +28,7 @@ class FakeMessagesListChatModel(BaseChatModel): i: int = 0 """Internally incremented after every model invocation.""" + @override def _generate( self, messages: list[BaseMessage], @@ -42,6 +45,7 @@ def _generate( return ChatResult(generations=[generation]) @property + @override def _llm_type(self) -> str: return "fake-messages-list-chat-model" @@ -62,9 +66,11 @@ class FakeListChatModel(SimpleChatModel): """Internally incremented after every model invocation.""" @property + @override def _llm_type(self) -> str: return "fake-list-chat-model" + @override def _call( self, messages: list[BaseMessage], @@ -80,6 +86,7 @@ def _call( self.i = 0 return response + @override def _stream( self, messages: list[BaseMessage], @@ -103,6 +110,7 @@ def _stream( yield ChatGenerationChunk(message=AIMessageChunk(content=c)) + @override async def _astream( self, messages: list[BaseMessage], @@ -126,9 +134,11 @@ async def _astream( yield ChatGenerationChunk(message=AIMessageChunk(content=c)) @property + @override def _identifying_params(self) -> dict[str, Any]: return {"responses": self.responses} + @override # manually override batch to preserve batch ordering with no concurrency def batch( self, @@ -142,6 +152,7 @@ def batch( return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)] return [self.invoke(m, config, **kwargs) for m in inputs] + @override async def abatch( self, inputs: list[Any], diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 4ba16f516965d..a383d2dad137c 100644 --- 
a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -374,6 +374,7 @@ def _get_ls_params( return ls_params + @override def invoke( self, input: LanguageModelInput, @@ -398,6 +399,7 @@ def invoke( .text ) + @override async def ainvoke( self, input: LanguageModelInput, @@ -419,6 +421,7 @@ async def ainvoke( ) return llm_result.generations[0][0].text + @override def batch( self, inputs: list[LanguageModelInput], @@ -466,6 +469,7 @@ def batch( ) ] + @override async def abatch( self, inputs: list[LanguageModelInput], @@ -512,6 +516,7 @@ async def abatch( ) ] + @override def stream( self, input: LanguageModelInput, @@ -578,6 +583,7 @@ def stream( run_manager.on_llm_end(LLMResult(generations=[[generation]])) + @override async def astream( self, input: LanguageModelInput, @@ -749,6 +755,7 @@ async def _astream( break yield item # type: ignore[misc] + @override def generate_prompt( self, prompts: list[PromptValue], @@ -759,6 +766,7 @@ def generate_prompt( prompt_strings = [p.to_string() for p in prompts] return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs) + @override async def agenerate_prompt( self, prompts: list[PromptValue], @@ -1329,6 +1337,7 @@ async def _call_async( return result.generations[0][0].text @deprecated("0.1.7", alternative="invoke", removal="1.0") + @override def predict( self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any ) -> str: @@ -1336,6 +1345,7 @@ def predict( return self(text, stop=_stop, **kwargs) @deprecated("0.1.7", alternative="invoke", removal="1.0") + @override def predict_messages( self, messages: list[BaseMessage], @@ -1349,6 +1359,7 @@ def predict_messages( return AIMessage(content=content) @deprecated("0.1.7", alternative="ainvoke", removal="1.0") + @override async def apredict( self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any ) -> str: @@ -1356,6 +1367,7 @@ async def apredict( return await self._call_async(text, stop=_stop, **kwargs) @deprecated("0.1.7", alternative="ainvoke", removal="1.0") + @override async def apredict_messages( self, messages: list[BaseMessage], diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py index ff991789f4528..606eb4d8ade74 100644 --- a/libs/core/langchain_core/load/load.py +++ b/libs/core/langchain_core/load/load.py @@ -86,6 +86,7 @@ def __init__( ) def __call__(self, value: dict[str, Any]) -> Any: + """Revive the value.""" if ( value.get("lc") == 1 and value.get("type") == "secret" diff --git a/libs/core/langchain_core/load/serializable.py b/libs/core/langchain_core/load/serializable.py index 7655438be97af..6bf5a91b813c8 100644 --- a/libs/core/langchain_core/load/serializable.py +++ b/libs/core/langchain_core/load/serializable.py @@ -269,6 +269,7 @@ def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]: } def to_json_not_implemented(self) -> SerializedNotImplemented: + """Serialize a "not implemented" object.""" return to_json_not_implemented(self) diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 9eab1ed431af2..1b92143ef8d55 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -59,6 +59,7 @@ class BaseMessage(Serializable): @field_validator("id", mode="before") def cast_id_to_str(cls, id_value: Any) -> Optional[str]: + """Coerce the id field to a string.""" if id_value is not None: return str(id_value) else: @@ -116,6 +117,7 @@ def 
pretty_repr(self, html: bool = False) -> str: return f"{title}\n\n{self.content}" def pretty_print(self) -> None: + """Print a pretty representation of the message.""" print(self.pretty_repr(html=is_interactive_env())) # noqa: T201 diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 5c14ae045af3e..2cf0b8dad9966 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -96,6 +96,7 @@ def get_lc_namespace(cls) -> list[str]: @model_validator(mode="before") @classmethod def coerce_args(cls, values: dict) -> dict: + """Coerce the model arguments to the correct types.""" content = values["content"] if isinstance(content, tuple): content = list(content) diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 9d080cef300bc..c7f5af86a775d 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -79,6 +79,7 @@ def OutputType(self) -> type[T]: # it is good enough for pydantic to build the schema from return T # type: ignore[misc] + @override def invoke( self, input: Union[str, BaseMessage], @@ -102,6 +103,7 @@ def invoke( run_type="parser", ) + @override async def ainvoke( self, input: Union[str, BaseMessage], @@ -183,6 +185,7 @@ def OutputType(self) -> type[T]: ) raise TypeError(msg) + @override def invoke( self, input: Union[str, BaseMessage], @@ -206,6 +209,7 @@ def invoke( run_type="parser", ) + @override async def ainvoke( self, input: Union[str, BaseMessage], diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py index bedbdf47b7aa0..76f72a41c01e1 100644 --- a/libs/core/langchain_core/output_parsers/list.py +++ b/libs/core/langchain_core/output_parsers/list.py @@ -9,6 +9,8 @@ from typing import Optional as Optional from typing import TypeVar, Union +from typing_extensions import override + from langchain_core.messages import BaseMessage from langchain_core.output_parsers.transform import BaseTransformOutputParser @@ -186,6 +188,7 @@ class NumberedListOutputParser(ListOutputParser): pattern: str = r"\d+\.\s([^\n]+)" """The pattern to match a numbered list item.""" + @override def get_format_instructions(self) -> str: return ( "Your response should be a numbered list with each item on a new line. " diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index 37f7eda64acff..9db6f9c74834b 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -7,6 +7,7 @@ from typing import Any, Optional, Union from pydantic import BaseModel, model_validator +from typing_extensions import override from langchain_core.prompts.string import ( DEFAULT_FORMATTER_MAPPING, @@ -56,14 +57,15 @@ class PromptTemplate(StringPromptTemplate): """ @property + @override def lc_attributes(self) -> dict[str, Any]: return { "template_format": self.template_format, } @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "prompts", "prompt"] template: str @@ -114,6 +116,7 @@ def pre_init_validation(cls, values: dict) -> Any: return values + @override def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: """Get the input schema for the prompt. 
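The pattern this patch applies file by file is worth stating once: ruff's D102 rule flags every public method that lacks a docstring, so each flagged method either gains a short docstring or is marked with typing_extensions.override so the linter can exempt it while help() still finds the base-class documentation. A minimal sketch of the idea, with illustrative class names that do not appear in the patch:

    from typing_extensions import override

    class Base:
        def invoke(self, value: str) -> str:
            """Process a value and return the result."""
            return value

    class Child(Base):
        @override  # Exempted from D102 when pydocstyle's ignore-decorators
        def invoke(self, value: str) -> str:  # setting lists this decorator;
            return value.upper()  # inspect.getdoc() falls back to Base.invoke.
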
diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index bf805c3a4f07c..a3a6d4d0f545c 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2455,6 +2455,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): protected_namespaces=(), ) + @override def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]: """Serialize the Runnable to JSON. @@ -2776,8 +2777,8 @@ def __init__( ) @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @property @@ -2790,6 +2791,7 @@ def steps(self) -> list[Runnable[Any, Any]]: return [self.first] + self.middle + [self.last] @classmethod + @override def is_lc_serializable(cls) -> bool: """Check if the object is serializable. @@ -2815,6 +2817,7 @@ def OutputType(self) -> type[Output]: """The type of the output of the Runnable.""" return self.last.OutputType + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -2828,6 +2831,7 @@ def get_input_schema( """ return _seq_input_schema(self.steps, config) + @override def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -2842,6 +2846,7 @@ def get_output_schema( return _seq_output_schema(self.steps, config) @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: """Get the config specs of the Runnable. @@ -2892,6 +2897,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]: return get_unique_config_specs(spec for spec, _ in all_specs) + @override def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: """Get the graph representation of the Runnable. 
@@ -2985,6 +2991,7 @@ def __ror__( name=self.name, ) + @override def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -3022,6 +3029,7 @@ def invoke( run_manager.on_chain_end(input) return cast(Output, input) + @override async def ainvoke( self, input: Input, @@ -3066,6 +3074,7 @@ async def ainvoke( await run_manager.on_chain_end(input) return cast(Output, input) + @override def batch( self, inputs: list[Input], @@ -3193,6 +3202,7 @@ def batch( else: raise first_exception + @override async def abatch( self, inputs: list[Input], @@ -3379,6 +3389,7 @@ async def _atransform( async for output in final_pipeline: yield output + @override def transform( self, input: Iterator[Input], @@ -3392,6 +3403,7 @@ def transform( **kwargs, ) + @override def stream( self, input: Input, @@ -3400,6 +3412,7 @@ def stream( ) -> Iterator[Output]: yield from self.transform(iter([input]), config, **kwargs) + @override async def atransform( self, input: AsyncIterator[Input], @@ -3414,6 +3427,7 @@ async def atransform( ): yield chunk + @override async def astream( self, input: Input, @@ -3534,18 +3548,20 @@ def __init__( ) @classmethod + @override def is_lc_serializable(cls) -> bool: return True @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] model_config = ConfigDict( arbitrary_types_allowed=True, ) + @override def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None ) -> str: @@ -3571,6 +3587,7 @@ def InputType(self) -> Any: return Any + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -3600,6 +3617,7 @@ def get_input_schema( return super().get_input_schema(config) + @override def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -3615,6 +3633,7 @@ def get_output_schema( return create_model_v2(self.get_name("Output"), field_definitions=fields) @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: """Get the config specs of the Runnable. @@ -3625,6 +3644,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]: spec for step in self.steps__.values() for spec in step.config_specs ) + @override def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: """Get the graph representation of the Runnable. 
@@ -3668,6 +3688,7 @@ def __repr__(self) -> str: ) return "{\n " + map_for_repr + "\n}" + @override def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> dict[str, Any]: @@ -3727,6 +3748,7 @@ def _invoke_step( run_manager.on_chain_end(output) return output + @override async def ainvoke( self, input: Input, @@ -3832,6 +3854,7 @@ def _transform( except StopIteration: pass + @override def transform( self, input: Iterator[Input], @@ -3842,6 +3865,7 @@ def transform( input, self._transform, config, **kwargs ) + @override def stream( self, input: Input, @@ -3901,6 +3925,7 @@ async def get_next_chunk(generator: AsyncIterator) -> Optional[Output]: except StopAsyncIteration: pass + @override async def atransform( self, input: AsyncIterator[Input], @@ -3912,6 +3937,7 @@ async def atransform( ): yield chunk + @override async def astream( self, input: Input, @@ -4072,6 +4098,7 @@ def InputType(self) -> Any: except ValueError: return Any + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -4113,6 +4140,7 @@ def OutputType(self) -> Any: except ValueError: return Any + @override def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -4153,6 +4181,7 @@ def __eq__(self, other: Any) -> bool: def __repr__(self) -> str: return f"RunnableGenerator({self.name})" + @override def transform( self, input: Iterator[Input], @@ -4169,6 +4198,7 @@ def transform( **kwargs, # type: ignore[arg-type] ) + @override def stream( self, input: Input, @@ -4177,6 +4207,7 @@ def stream( ) -> Iterator[Output]: return self.transform(iter([input]), config, **kwargs) + @override def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -4185,6 +4216,7 @@ def invoke( final = output if final is None else final + output # type: ignore[operator] return cast(Output, final) + @override def atransform( self, input: AsyncIterator[Input], @@ -4199,6 +4231,7 @@ def atransform( input, self._atransform, config, **kwargs ) + @override def astream( self, input: Input, @@ -4210,6 +4243,7 @@ async def input_aiter() -> AsyncIterator[Input]: return self.atransform(input_aiter(), config, **kwargs) + @override async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -4364,6 +4398,7 @@ def InputType(self) -> Any: except ValueError: return Any + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -4433,6 +4468,7 @@ def OutputType(self) -> Any: except ValueError: return Any + @override def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -4483,11 +4519,13 @@ def deps(self) -> list[Runnable]: return deps @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: return get_unique_config_specs( spec for dep in self.deps for spec in dep.config_specs ) + @override def get_graph(self, config: RunnableConfig | None = None) -> Graph: if deps := self.deps: graph = Graph() @@ -4690,6 +4728,7 @@ def _config( ) -> RunnableConfig: return ensure_config(config) + @override def invoke( self, input: Input, @@ -4723,6 +4762,7 @@ def invoke( ) raise TypeError(msg) + @override async def ainvoke( self, input: Input, @@ -4809,6 +4849,7 @@ def _transform( # Otherwise, just yield it yield cast(Output, output) + @override def transform( self, input: Iterator[Input], @@ -4829,6 +4870,7 @@ def transform( ) raise TypeError(msg) + @override def stream( 
self, input: Input, @@ -4932,6 +4974,7 @@ async def f(*args, **kwargs): # type: ignore[no-untyped-def] # Otherwise, just yield it yield cast(Output, output) + @override async def atransform( self, input: AsyncIterator[Input], @@ -4946,6 +4989,7 @@ async def atransform( ): yield output + @override async def astream( self, input: Input, @@ -4979,6 +5023,7 @@ class RunnableEachBase(RunnableSerializable[list[Input], list[Output]]): def InputType(self) -> Any: return list[self.bound.InputType] # type: ignore[name-defined] + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -5003,6 +5048,7 @@ def get_input_schema( def OutputType(self) -> type[list[Output]]: return list[self.bound.OutputType] # type: ignore[name-defined] + @override def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -5021,19 +5067,22 @@ def get_output_schema( ) @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: return self.bound.config_specs + @override def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: return self.bound.get_graph(config) @classmethod + @override def is_lc_serializable(cls) -> bool: return True @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] def _invoke( @@ -5048,6 +5097,7 @@ def _invoke( ] return self.bound.batch(inputs, configs, **kwargs) + @override def invoke( self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any ) -> list[Output]: @@ -5065,11 +5115,13 @@ async def _ainvoke( ] return await self.bound.abatch(inputs, configs, **kwargs) + @override async def ainvoke( self, input: list[Input], config: Optional[RunnableConfig] = None, **kwargs: Any ) -> list[Output]: return await self._acall_with_config(self._ainvoke, input, config, **kwargs) + @override async def astream_events( self, input: Input, @@ -5111,24 +5163,28 @@ class RunnableEach(RunnableEachBase[Input, Output]): """ @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] + @override def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None ) -> str: name = name or self.name or f"RunnableEach<{self.bound.get_name()}>" return super().get_name(suffix, name=name) + @override def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: return RunnableEach(bound=self.bound.bind(**kwargs)) + @override def with_config( self, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> RunnableEach[Input, Output]: return RunnableEach(bound=self.bound.with_config(config, **kwargs)) + @override def with_listeners( self, *, @@ -5286,6 +5342,7 @@ def __init__( # fields even though total=False on the typed dict. 
self.config = config or {} + @override def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None ) -> str: @@ -5309,6 +5366,7 @@ def OutputType(self) -> type[Output]: else self.bound.OutputType ) + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -5316,6 +5374,7 @@ def get_input_schema( return super().get_input_schema(config) return self.bound.get_input_schema(merge_configs(self.config, config)) + @override def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -5324,25 +5383,29 @@ def get_output_schema( return self.bound.get_output_schema(merge_configs(self.config, config)) @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: return self.bound.config_specs + @override def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: return self.bound.get_graph(self._merge_configs(config)) @classmethod + @override def is_lc_serializable(cls) -> bool: return True @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: config = merge_configs(self.config, *configs) return merge_configs(config, *(f(config) for f in self.config_factories)) + @override def invoke( self, input: Input, @@ -5355,6 +5418,7 @@ def invoke( **{**self.kwargs, **kwargs}, ) + @override async def ainvoke( self, input: Input, @@ -5367,6 +5431,7 @@ async def ainvoke( **{**self.kwargs, **kwargs}, ) + @override def batch( self, inputs: list[Input], @@ -5389,6 +5454,7 @@ def batch( **{**self.kwargs, **kwargs}, ) + @override async def abatch( self, inputs: list[Input], @@ -5431,6 +5497,7 @@ def batch_as_completed( **kwargs: Any, ) -> Iterator[tuple[int, Union[Output, Exception]]]: ... + @override def batch_as_completed( self, inputs: Sequence[Input], @@ -5482,6 +5549,7 @@ def abatch_as_completed( **kwargs: Optional[Any], ) -> AsyncIterator[tuple[int, Union[Output, Exception]]]: ... + @override async def abatch_as_completed( self, inputs: Sequence[Input], @@ -5514,6 +5582,7 @@ async def abatch_as_completed( ): yield item + @override def stream( self, input: Input, @@ -5526,6 +5595,7 @@ def stream( **{**self.kwargs, **kwargs}, ) + @override async def astream( self, input: Input, @@ -5539,6 +5609,7 @@ async def astream( ): yield item + @override async def astream_events( self, input: Input, @@ -5550,6 +5621,7 @@ async def astream_events( ): yield item + @override def transform( self, input: Iterator[Input], @@ -5562,6 +5634,7 @@ def transform( **{**self.kwargs, **kwargs}, ) + @override async def atransform( self, input: AsyncIterator[Input], @@ -5628,10 +5701,11 @@ class RunnableBinding(RunnableBindingBase[Input, Output]): """ @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] + @override def bind(self, **kwargs: Any) -> Runnable[Input, Output]: """Bind additional kwargs to a Runnable, returning a new Runnable. 
@@ -5650,6 +5724,7 @@ def bind(self, **kwargs: Any) -> Runnable[Input, Output]: custom_output_type=self.custom_output_type, ) + @override def with_config( self, config: Optional[RunnableConfig] = None, @@ -5664,6 +5739,7 @@ def with_config( custom_output_type=self.custom_output_type, ) + @override def with_listeners( self, *, @@ -5714,6 +5790,7 @@ def with_listeners( custom_output_type=self.custom_output_type, ) + @override def with_types( self, input_type: Optional[Union[type[Input], BaseModel]] = None, @@ -5731,6 +5808,7 @@ def with_types( ), ) + @override def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: return self.__class__( bound=self.bound.with_retry(**kwargs), diff --git a/libs/core/langchain_core/runnables/branch.py b/libs/core/langchain_core/runnables/branch.py index 56c438861899e..0dd7f545177ac 100644 --- a/libs/core/langchain_core/runnables/branch.py +++ b/libs/core/langchain_core/runnables/branch.py @@ -8,6 +8,7 @@ ) from pydantic import BaseModel, ConfigDict +from typing_extensions import override from langchain_core.runnables.base import ( Runnable, @@ -144,10 +145,11 @@ def is_lc_serializable(cls) -> bool: return True @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -167,6 +169,7 @@ def get_input_schema( return super().get_input_schema(config) @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: from langchain_core.beta.runnables.context import ( CONTEXT_CONFIG_PREFIX, diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index b59d0239fb1f3..1ea2ff6ac0800 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -60,12 +60,13 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): ) @classmethod + @override def is_lc_serializable(cls) -> bool: return True @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @property @@ -78,22 +79,26 @@ def InputType(self) -> type[Input]: def OutputType(self) -> type[Output]: return self.default.OutputType + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: runnable, config = self.prepare(config) return runnable.get_input_schema(config) + @override def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: runnable, config = self.prepare(config) return runnable.get_output_schema(config) + @override def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: runnable, config = self.prepare(config) return runnable.get_graph(config) + @override def with_config( self, config: Optional[RunnableConfig] = None, @@ -126,18 +131,21 @@ def _prepare( self, config: Optional[RunnableConfig] = None ) -> tuple[Runnable[Input, Output], RunnableConfig]: ... 
+ @override def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: runnable, config = self.prepare(config) return runnable.invoke(input, config, **kwargs) + @override async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: runnable, config = self.prepare(config) return await runnable.ainvoke(input, config, **kwargs) + @override def batch( self, inputs: list[Input], @@ -180,6 +188,7 @@ def invoke( with get_executor_for_config(configs[0]) as executor: return cast(list[Output], list(executor.map(invoke, prepared, inputs))) + @override async def abatch( self, inputs: list[Input], @@ -218,6 +227,7 @@ async def ainvoke( coros = map(ainvoke, prepared, inputs) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) + @override def stream( self, input: Input, @@ -227,6 +237,7 @@ def stream( runnable, config = self.prepare(config) return runnable.stream(input, config, **kwargs) + @override async def astream( self, input: Input, @@ -237,6 +248,7 @@ async def astream( async for chunk in runnable.astream(input, config, **kwargs): yield chunk + @override def transform( self, input: Iterator[Input], @@ -246,6 +258,7 @@ def transform( runnable, config = self.prepare(config) return runnable.transform(input, config, **kwargs) + @override async def atransform( self, input: AsyncIterator[Input], @@ -543,11 +556,12 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): the alternative named "gpt3" becomes "model==gpt3/temperature".""" @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: with _enums_for_spec_lock: if which_enum := _enums_for_spec.get(self.which): @@ -595,6 +609,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]: ] ) + @override def configurable_fields( self, **kwargs: AnyConfigurableField ) -> RunnableSerializable[Input, Output]: diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index f932ce3589e00..05b51773ade89 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -116,17 +116,20 @@ def InputType(self) -> type[Input]: def OutputType(self) -> type[Output]: return self.runnable.OutputType + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: return self.runnable.get_input_schema(config) + @override def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: return self.runnable.get_output_schema(config) @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: return get_unique_config_specs( spec @@ -135,19 +138,22 @@ def config_specs(self) -> list[ConfigurableFieldSpec]: ) @classmethod + @override def is_lc_serializable(cls) -> bool: return True @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @property def runnables(self) -> Iterator[Runnable[Input, Output]]: + """Iterator over the Runnable and its fallbacks.""" yield self.runnable yield from self.fallbacks + @override def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -198,6 +204,7 @@ def invoke( run_manager.on_chain_error(first_error) 
raise first_error + @override async def ainvoke( self, input: Input, @@ -251,6 +258,7 @@ async def ainvoke( await run_manager.on_chain_error(first_error) raise first_error + @override def batch( self, inputs: list[Input], @@ -344,6 +352,7 @@ def batch( to_return.update(handled_exceptions) return [output for _, output in sorted(to_return.items())] + @override async def abatch( self, inputs: list[Input], @@ -445,6 +454,7 @@ async def abatch( to_return.update(handled_exceptions) return [output for _, output in sorted(to_return.items())] # type: ignore + @override def stream( self, input: Input, @@ -510,6 +520,7 @@ def stream( raise run_manager.on_chain_end(output) + @override async def astream( self, input: Input, diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index d2040a3f3fb6f..c38d2a6cef7a9 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -237,8 +237,8 @@ def get_session_history( history_factory_config: Sequence[ConfigurableFieldSpec] @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] def __init__( @@ -365,12 +365,14 @@ async def _call_runnable_async(_input: Any) -> Runnable: self._history_chain = history_chain @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: """Get the configuration specs for the RunnableWithMessageHistory.""" return get_unique_config_specs( super().config_specs + list(self.history_factory_config) ) + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index b0da175ae3291..82da60a300ddc 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -152,6 +152,7 @@ def fake_llm(prompt: str) -> str: # Fake LLM for the example ] ] = None + @override def __repr_args__(self) -> Any: # Without this repr(self) raises a RecursionError # See https://github.com/pydantic/pydantic/issues/7327 @@ -185,12 +186,13 @@ def __init__( super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs) # type: ignore[call-arg] @classmethod + @override def is_lc_serializable(cls) -> bool: return True @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @property @@ -204,6 +206,7 @@ def OutputType(self) -> Any: return self.input_type or Any @classmethod + @override def assign( cls, **kwargs: Union[ @@ -227,6 +230,7 @@ def assign( """ return RunnableAssign(RunnableParallel[dict[str, Any]](kwargs)) + @override def invoke( self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Other: @@ -236,6 +240,7 @@ def invoke( ) return self._call_with_config(identity, input, config) + @override async def ainvoke( self, input: Other, @@ -252,6 +257,7 @@ async def ainvoke( ) return await self._acall_with_config(aidentity, input, config) + @override def transform( self, input: Iterator[Other], @@ -282,6 +288,7 @@ def transform( self.func, final, ensure_config(config), **kwargs ) + @override async def atransform( self, input: AsyncIterator[Other], @@ -324,6 +331,7 @@ async def atransform( elif self.func is not None: call_func_with_variable_args(self.func, final, config, **kwargs) + @override def 
stream( self, input: Other, @@ -332,6 +340,7 @@ def stream( ) -> Iterator[Other]: return self.transform(iter([input]), config, **kwargs) + @override async def astream( self, input: Other, @@ -395,14 +404,16 @@ def __init__(self, mapper: RunnableParallel[dict[str, Any]], **kwargs: Any) -> N super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg] @classmethod + @override def is_lc_serializable(cls) -> bool: return True @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] + @override def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None ) -> str: @@ -413,6 +424,7 @@ def get_name( ) return super().get_name(suffix, name=name) + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -423,6 +435,7 @@ def get_input_schema( return super().get_input_schema(config) + @override def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -450,9 +463,11 @@ def get_output_schema( return super().get_output_schema(config) @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: return self.mapper.config_specs + @override def get_graph(self, config: RunnableConfig | None = None) -> Graph: # get graph from mapper graph = self.mapper.get_graph(config) @@ -485,6 +500,7 @@ def _invoke( ), } + @override def invoke( self, input: dict[str, Any], @@ -513,6 +529,7 @@ async def _ainvoke( ), } + @override async def ainvoke( self, input: dict[str, Any], @@ -567,6 +584,7 @@ def _transform( for chunk in map_output: yield chunk + @override def transform( self, input: Iterator[dict[str, Any]], @@ -618,6 +636,7 @@ async def _atransform( async for chunk in map_output: yield chunk + @override async def atransform( self, input: AsyncIterator[dict[str, Any]], @@ -629,6 +648,7 @@ async def atransform( ): yield chunk + @override def stream( self, input: dict[str, Any], @@ -637,6 +657,7 @@ def stream( ) -> Iterator[dict[str, Any]]: return self.transform(iter([input]), config, **kwargs) + @override async def astream( self, input: dict[str, Any], @@ -687,14 +708,17 @@ def __init__(self, keys: Union[str, list[str]], **kwargs: Any) -> None: super().__init__(keys=keys, **kwargs) # type: ignore[call-arg] @classmethod + @override def is_lc_serializable(cls) -> bool: return True @classmethod + @override def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] + @override def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None ) -> str: @@ -725,6 +749,7 @@ def _invoke( ) -> dict[str, Any]: return self._pick(input) + @override def invoke( self, input: dict[str, Any], @@ -739,6 +764,7 @@ async def _ainvoke( ) -> dict[str, Any]: return self._pick(input) + @override async def ainvoke( self, input: dict[str, Any], @@ -756,6 +782,7 @@ def _transform( if picked is not None: yield picked + @override def transform( self, input: Iterator[dict[str, Any]], @@ -775,6 +802,7 @@ async def _atransform( if picked is not None: yield picked + @override async def atransform( self, input: AsyncIterator[dict[str, Any]], @@ -786,6 +814,7 @@ async def atransform( ): yield chunk + @override def stream( self, input: dict[str, Any], @@ -794,6 +823,7 @@ def stream( ) -> Iterator[dict[str, Any]]: return self.transform(iter([input]), config, **kwargs) + @override async def astream( self, input: dict[str, Any], diff --git 
a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index 9300a35d899fa..885de50875875 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -16,6 +16,7 @@ stop_after_attempt, wait_exponential_jitter, ) +from typing_extensions import override from langchain_core.runnables.base import Input, Output, RunnableBindingBase from langchain_core.runnables.config import RunnableConfig, patch_config @@ -110,8 +111,8 @@ def foo(input) -> None: """The maximum number of attempts to retry the Runnable.""" @classmethod + @override def get_lc_namespace(cls) -> list[str]: - """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] @property @@ -173,6 +174,7 @@ def _invoke( attempt.retry_state.set_result(result) return result + @override def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -196,6 +198,7 @@ async def _ainvoke( attempt.retry_state.set_result(result) return result + @override async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -255,6 +258,7 @@ def pending(iterable: list[U]) -> list[U]: outputs.append(result.pop(0)) return outputs + @override def batch( self, inputs: list[Input], @@ -321,6 +325,7 @@ def pending(iterable: list[U]) -> list[U]: outputs.append(result.pop(0)) return outputs + @override async def abatch( self, inputs: list[Input], diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py index 29c6359c69aef..af4a35967840f 100644 --- a/libs/core/langchain_core/runnables/router.py +++ b/libs/core/langchain_core/runnables/router.py @@ -11,7 +11,7 @@ ) from pydantic import ConfigDict -from typing_extensions import TypedDict +from typing_extensions import TypedDict, override from langchain_core.runnables.base import ( Input, @@ -68,6 +68,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): runnables: Mapping[str, Runnable[Any, Output]] @property + @override def config_specs(self) -> list[ConfigurableFieldSpec]: return get_unique_config_specs( spec for step in self.runnables.values() for spec in step.config_specs @@ -86,15 +87,18 @@ def __init__( ) @classmethod + @override def is_lc_serializable(cls) -> bool: """Return whether this class is serializable.""" return True @classmethod + @override def get_lc_namespace(cls) -> list[str]: """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] + @override def invoke( self, input: RouterInput, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: @@ -107,6 +111,7 @@ def invoke( runnable = self.runnables[key] return runnable.invoke(actual_input, config) + @override async def ainvoke( self, input: RouterInput, @@ -122,6 +127,7 @@ async def ainvoke( runnable = self.runnables[key] return await runnable.ainvoke(actual_input, config) + @override def batch( self, inputs: list[RouterInput], @@ -158,6 +164,7 @@ def invoke( list(executor.map(invoke, runnables, actual_inputs, configs)), ) + @override async def abatch( self, inputs: list[RouterInput], @@ -193,6 +200,7 @@ async def ainvoke( *starmap(ainvoke, zip(runnables, actual_inputs, configs)), ) + @override def stream( self, input: RouterInput, @@ -208,6 +216,7 @@ def stream( runnable = self.runnables[key] yield from runnable.stream(actual_input, config) + @override async def astream( self, input: RouterInput, diff --git 
a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index f0833cdc24a14..765dca53993be 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -37,6 +37,7 @@ from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 from pydantic.v1 import validate_arguments as validate_arguments_v1 +from typing_extensions import override from langchain_core._api import deprecated from langchain_core.callbacks import ( @@ -443,10 +444,12 @@ def is_single_input(self) -> bool: @property def args(self) -> dict: + """The arguments of the tool.""" return self.get_input_schema().model_json_schema()["properties"] @property def tool_call_schema(self) -> type[BaseModel]: + """The schema for a tool call.""" full_schema = self.get_input_schema() fields = [] for name, type_ in get_all_basemodel_annotations(full_schema).items(): @@ -458,6 +461,7 @@ def tool_call_schema(self) -> type[BaseModel]: # --- Runnable --- + @override def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> type[BaseModel]: @@ -474,6 +478,7 @@ def get_input_schema( else: return create_schema_from_function(self.name, self._run) + @override def invoke( self, input: Union[str, dict, ToolCall], @@ -483,6 +488,7 @@ def invoke( tool_input, kwargs = _prep_run_args(input, config, **kwargs) return self.run(tool_input, **kwargs) + @override async def ainvoke( self, input: Union[str, dict, ToolCall], diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index bb8b85f5558cc..047826e4d0e09 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -116,13 +116,13 @@ def tool( @tool def search_api(query: str) -> str: # Searches the API for the query. - return + return @tool("search", return_direct=True) def search_api(query: str) -> str: # Searches the API for the query. - return + return @tool(response_format="content_and_artifact") def search_api(query: str) -> Tuple[str, dict]: return "partial json of results", {"full": "object of results"} @@ -136,7 +136,7 @@ def search_api(query: str) -> Tuple[str, dict]: def foo(bar: str, baz: int) -> str: \"\"\"The foo. - Args: + Args: bar: The bar. baz: The baz. \"\"\" @@ -183,7 +183,8 @@ def invalid_docstring_1(bar: str, baz: int) -> str: # Improper whitespace between summary and args section def invalid_docstring_2(bar: str, baz: int) -> str: \"\"\"The foo. - Args: + + Args: bar: The bar. baz: The baz. \"\"\" @@ -193,7 +194,7 @@ def invalid_docstring_2(bar: str, baz: int) -> str: \"\"\"The foo. - Args: + Args: banana: The bar. monkey: The baz.
\"\"\" diff --git a/libs/core/langchain_core/tools/simple.py b/libs/core/langchain_core/tools/simple.py index d9e38ba227c8b..99dc09cb61220 100644 --- a/libs/core/langchain_core/tools/simple.py +++ b/libs/core/langchain_core/tools/simple.py @@ -10,6 +10,7 @@ ) from pydantic import BaseModel +from typing_extensions import override from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, @@ -35,6 +36,7 @@ class Tool(BaseTool): # --- Runnable --- + @override async def ainvoke( self, input: Union[str, dict, ToolCall], diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index ef185b3e1e844..72d7433ab5f80 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -13,6 +13,7 @@ ) from pydantic import BaseModel, Field, SkipValidation +from typing_extensions import override from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, @@ -45,6 +46,7 @@ class StructuredTool(BaseTool): # --- Runnable --- # TODO: Is this needed? + @override async def ainvoke( self, input: Union[str, dict, ToolCall], diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index f3ae965f6025c..33f9fc08ff7d0 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -15,6 +15,7 @@ from uuid import UUID from tenacity import RetryCallState +from typing_extensions import override from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.exceptions import TracerException # noqa @@ -527,9 +528,11 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): """Async Base interface for tracers.""" @abstractmethod + @override async def _persist_run(self, run: Run) -> None: """Persist a run.""" + @override async def _start_trace(self, run: Run) -> None: """Start a trace for a run. @@ -539,6 +542,7 @@ async def _start_trace(self, run: Run) -> None: super()._start_trace(run) await self._on_run_create(run) + @override async def _end_trace(self, run: Run) -> None: """End a trace for a run. 
diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py
index f3ae965f6025c..33f9fc08ff7d0 100644
--- a/libs/core/langchain_core/tracers/base.py
+++ b/libs/core/langchain_core/tracers/base.py
@@ -15,6 +15,7 @@
 from uuid import UUID
 
 from tenacity import RetryCallState
+from typing_extensions import override
 
 from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
 from langchain_core.exceptions import TracerException  # noqa
@@ -527,9 +528,11 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
     """Async Base interface for tracers."""
 
     @abstractmethod
+    @override
     async def _persist_run(self, run: Run) -> None:
         """Persist a run."""
 
+    @override
     async def _start_trace(self, run: Run) -> None:
         """Start a trace for a run.
 
@@ -539,6 +542,7 @@ async def _start_trace(self, run: Run) -> None:
         super()._start_trace(run)
         await self._on_run_create(run)
 
+    @override
     async def _end_trace(self, run: Run) -> None:
         """End a trace for a run.
 
@@ -550,6 +554,7 @@ async def _end_trace(self, run: Run) -> None:
         self.run_map.pop(str(run.id))
         await self._on_run_update(run)
 
+    @override
     async def on_chat_model_start(
         self,
         serialized: dict[str, Any],
@@ -579,6 +584,7 @@ async def on_chat_model_start(
         await asyncio.gather(*tasks)
         return chat_model_run
 
+    @override
     async def on_llm_start(
         self,
         serialized: dict[str, Any],
@@ -602,6 +608,7 @@ async def on_llm_start(
         tasks = [self._start_trace(llm_run), self._on_llm_start(llm_run)]
         await asyncio.gather(*tasks)
 
+    @override
     async def on_llm_new_token(
         self,
         token: str,
@@ -620,6 +627,7 @@ async def on_llm_new_token(
         )
         await self._on_llm_new_token(llm_run, token, chunk)
 
+    @override
     async def on_retry(
         self,
         retry_state: RetryCallState,
@@ -632,6 +640,7 @@ async def on_retry(
             run_id=run_id,
         )
 
+    @override
     async def on_llm_end(
         self,
         response: LLMResult,
@@ -648,6 +657,7 @@ async def on_llm_end(
         tasks = [self._on_llm_end(llm_run), self._end_trace(llm_run)]
         await asyncio.gather(*tasks)
 
+    @override
     async def on_llm_error(
         self,
         error: BaseException,
@@ -664,6 +674,7 @@ async def on_llm_error(
         tasks = [self._on_llm_error(llm_run), self._end_trace(llm_run)]
         await asyncio.gather(*tasks)
 
+    @override
     async def on_chain_start(
         self,
         serialized: dict[str, Any],
@@ -691,6 +702,7 @@ async def on_chain_start(
         tasks = [self._start_trace(chain_run), self._on_chain_start(chain_run)]
         await asyncio.gather(*tasks)
 
+    @override
     async def on_chain_end(
         self,
         outputs: dict[str, Any],
@@ -708,6 +720,7 @@ async def on_chain_end(
         tasks = [self._end_trace(chain_run), self._on_chain_end(chain_run)]
         await asyncio.gather(*tasks)
 
+    @override
     async def on_chain_error(
         self,
         error: BaseException,
@@ -725,6 +738,7 @@ async def on_chain_error(
         tasks = [self._end_trace(chain_run), self._on_chain_error(chain_run)]
         await asyncio.gather(*tasks)
 
+    @override
     async def on_tool_start(
         self,
         serialized: dict[str, Any],
@@ -751,6 +765,7 @@ async def on_tool_start(
         tasks = [self._start_trace(tool_run), self._on_tool_start(tool_run)]
         await asyncio.gather(*tasks)
 
+    @override
     async def on_tool_end(
         self,
         output: Any,
@@ -766,6 +781,7 @@ async def on_tool_end(
         tasks = [self._end_trace(tool_run), self._on_tool_end(tool_run)]
         await asyncio.gather(*tasks)
 
+    @override
     async def on_tool_error(
         self,
         error: BaseException,
@@ -782,6 +798,7 @@ async def on_tool_error(
         tasks = [self._end_trace(tool_run), self._on_tool_error(tool_run)]
         await asyncio.gather(*tasks)
 
+    @override
     async def on_retriever_start(
         self,
         serialized: dict[str, Any],
@@ -809,6 +826,7 @@ async def on_retriever_start(
         ]
         await asyncio.gather(*tasks)
 
+    @override
     async def on_retriever_error(
         self,
         error: BaseException,
@@ -829,6 +847,7 @@ async def on_retriever_error(
         ]
         await asyncio.gather(*tasks)
 
+    @override
     async def on_retriever_end(
         self,
         documents: Sequence[Document],
diff --git a/libs/core/langchain_core/utils/iter.py b/libs/core/langchain_core/utils/iter.py
index 7868119caedd2..8ff7b7e49bf6d 100644
--- a/libs/core/langchain_core/utils/iter.py
+++ b/libs/core/langchain_core/utils/iter.py
@@ -171,6 +171,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
         return False
 
     def close(self) -> None:
+        """Close all child iterators."""
        for child in self._children:
            child.close()
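
Note: every async tracer callback touched above follows the same two-task shape: run the trace bookkeeping and the subclass hook concurrently, then await both. A stripped-down sketch of that pattern with stand-in coroutines, not the real tracer internals:

    import asyncio


    async def _start_trace(run_id: str) -> None:
        """Stand-in for AsyncBaseTracer._start_trace."""


    async def _on_llm_start(run_id: str) -> None:
        """Stand-in for the _on_llm_start hook."""


    async def on_llm_start(run_id: str) -> None:
        # Kick off both coroutines and wait for the pair, as the
        # on_*_start/on_*_end callbacks in the diff do.
        tasks = [_start_trace(run_id), _on_llm_start(run_id)]
        await asyncio.gather(*tasks)


    asyncio.run(on_llm_start("run-1"))
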
diff --git a/libs/core/langchain_core/utils/json.py b/libs/core/langchain_core/utils/json.py
index 8aedfaf339b71..c76b53d7c27f9 100644
--- a/libs/core/langchain_core/utils/json.py
+++ b/libs/core/langchain_core/utils/json.py
@@ -160,8 +160,9 @@ def _parse_json(
 
 
 def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
-    """Parse a JSON string from a Markdown string and check that it
-    contains the expected keys.
+    """Parse and check a JSON string from a Markdown string.
+
+    Checks that it contains the expected keys.
 
     Args:
         text: The Markdown string.
diff --git a/libs/core/langchain_core/utils/loading.py b/libs/core/langchain_core/utils/loading.py
index ae911c4f4d280..27db390a16dd4 100644
--- a/libs/core/langchain_core/utils/loading.py
+++ b/libs/core/langchain_core/utils/loading.py
@@ -19,6 +19,7 @@ def try_load_from_hub(
     *args: Any,
     **kwargs: Any,
 ) -> Any:
+    """[DEPRECATED] Try to load from the old Hub."""
     warnings.warn(
         "Loading from the deprecated github-based Hub is no longer supported. "
         "Please use the new LangChain Hub at https://smith.langchain.com/hub instead.",
diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py
index ee2ed8f2528f8..10f22b822722b 100644
--- a/libs/core/langchain_core/utils/mustache.py
+++ b/libs/core/langchain_core/utils/mustache.py
@@ -1,4 +1,5 @@
 """Adapted from https://github.com/noahmorrison/chevron
+
 MIT License.
 """
diff --git a/libs/core/langchain_core/vectorstores/base.py b/libs/core/langchain_core/vectorstores/base.py
index b154a14b98191..6e03631be1618 100644
--- a/libs/core/langchain_core/vectorstores/base.py
+++ b/libs/core/langchain_core/vectorstores/base.py
@@ -293,8 +293,7 @@ def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]:
     async def aadd_documents(
         self, documents: list[Document], **kwargs: Any
     ) -> list[str]:
-        """Async run more documents through the embeddings and add to
-        the vectorstore.
+        """Async run more documents through the embeddings and add to the vectorstore.
 
         Args:
             documents: Documents to add to the vectorstore.
@@ -434,6 +433,7 @@ def _max_inner_product_relevance_score_fn(distance: float) -> float:
 
     def _select_relevance_score_fn(self) -> Callable[[float], float]:
         """The 'correct' relevance function
+
         may differ depending on a few things, including:
         - the distance / similarity metric used by the VectorStore
         - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
@@ -509,8 +509,9 @@ async def _asimilarity_search_with_relevance_scores(
         k: int = 4,
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
-        """Default similarity search with relevance scores. Modify if necessary
-        in subclass.
+        """Default similarity search with relevance scores.
+
+        Modify if necessary in subclass.
 
         Return docs and relevance scores in the range [0, 1].
 
         0 is dissimilar, 1 is most similar.
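
Note: the docstring fixes above touch the scoring contract: `similarity_search_with_relevance_scores` normalizes scores into [0, 1], while `similarity_search_with_score` (next file) returns the store's raw similarity. A quick usage sketch against the in-memory store, using langchain_core's fake embeddings; illustrative only:

    from langchain_core.documents import Document
    from langchain_core.embeddings import DeterministicFakeEmbedding
    from langchain_core.vectorstores import InMemoryVectorStore

    store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=16))
    store.add_documents([Document(page_content="alpha"), Document(page_content="beta")])

    # (Document, score) pairs, most similar first
    for doc, score in store.similarity_search_with_score("alpha", k=2):
        print(doc.page_content, score)
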
diff --git a/libs/core/langchain_core/vectorstores/in_memory.py b/libs/core/langchain_core/vectorstores/in_memory.py
index ab32c7cdacbe9..e0023300b5351 100644
--- a/libs/core/langchain_core/vectorstores/in_memory.py
+++ b/libs/core/langchain_core/vectorstores/in_memory.py
@@ -11,6 +11,8 @@
     Optional,
 )
 
+from typing_extensions import override
+
 from langchain_core._api import deprecated
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
@@ -153,17 +155,21 @@ def __init__(self, embedding: Embeddings) -> None:
         self.embedding = embedding
 
     @property
+    @override
     def embeddings(self) -> Embeddings:
         return self.embedding
 
+    @override
     def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
         if ids:
             for _id in ids:
                 self.store.pop(_id, None)
 
+    @override
     async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
         self.delete(ids)
 
+    @override
     def add_documents(
         self,
         documents: list[Document],
@@ -200,6 +206,7 @@ def add_documents(
 
         return ids_
 
+    @override
     async def aadd_documents(
         self, documents: list[Document], ids: Optional[list[str]] = None, **kwargs: Any
     ) -> list[str]:
@@ -232,6 +239,7 @@ async def aadd_documents(
 
         return ids_
 
+    @override
     def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
         """Get documents by their ids.
 
@@ -264,6 +272,14 @@ def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
         removal="1.0",
     )
     def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
+        """[DEPRECATED] Upsert documents into the store.
+
+        Args:
+            items: The documents to upsert.
+
+        Returns:
+            The upsert response.
+        """
         vectors = self.embedding.embed_documents([item.page_content for item in items])
         ids = []
         for item, vector in zip(items, vectors):
@@ -291,6 +307,14 @@ def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
     async def aupsert(
         self, items: Sequence[Document], /, **kwargs: Any
     ) -> UpsertResponse:
+        """[DEPRECATED] Upsert documents into the store.
+
+        Args:
+            items: The documents to upsert.
+
+        Returns:
+            The upsert response.
+        """
         vectors = await self.embedding.aembed_documents(
             [item.page_content for item in items]
         )
@@ -309,6 +333,7 @@ async def aupsert(
             "failed": [],
         }
 
+    @override
     async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]:
         """Async get documents by their ids.
 
@@ -367,6 +392,17 @@ def similarity_search_with_score_by_vector(
         filter: Optional[Callable[[Document], bool]] = None,
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
+        """Search for the most similar documents to the given embedding.
+
+        Args:
+            embedding: The embedding to search for.
+            k: The number of documents to return.
+            filter: A function to filter the documents.
+            **kwargs: Additional arguments.
+
+        Returns:
+            A list of tuples of Document objects and their similarity scores.
+ """ return [ (doc, similarity) for doc, similarity, _ in self._similarity_search_with_score_by_vector( @@ -374,6 +410,7 @@ def similarity_search_with_score_by_vector( ) ] + @override def similarity_search_with_score( self, query: str, @@ -388,6 +425,7 @@ def similarity_search_with_score( ) return docs + @override async def asimilarity_search_with_score( self, query: str, k: int = 4, **kwargs: Any ) -> list[tuple[Document, float]]: @@ -399,6 +437,7 @@ async def asimilarity_search_with_score( ) return docs + @override def similarity_search_by_vector( self, embedding: list[float], @@ -412,16 +451,19 @@ def similarity_search_by_vector( ) return [doc for doc, _ in docs_and_scores] + @override async def asimilarity_search_by_vector( self, embedding: list[float], k: int = 4, **kwargs: Any ) -> list[Document]: return self.similarity_search_by_vector(embedding, k, **kwargs) + @override def similarity_search( self, query: str, k: int = 4, **kwargs: Any ) -> list[Document]: return [doc for doc, _ in self.similarity_search_with_score(query, k, **kwargs)] + @override async def asimilarity_search( self, query: str, k: int = 4, **kwargs: Any ) -> list[Document]: @@ -430,6 +472,7 @@ async def asimilarity_search( for doc, _ in await self.asimilarity_search_with_score(query, k, **kwargs) ] + @override def max_marginal_relevance_search_by_vector( self, embedding: list[float], @@ -461,6 +504,7 @@ def max_marginal_relevance_search_by_vector( ) return [prefetch_hits[idx][0] for idx in mmr_chosen_indices] + @override def max_marginal_relevance_search( self, query: str, @@ -478,6 +522,7 @@ def max_marginal_relevance_search( **kwargs, ) + @override async def amax_marginal_relevance_search( self, query: str, @@ -496,6 +541,7 @@ async def amax_marginal_relevance_search( ) @classmethod + @override def from_texts( cls, texts: list[str], @@ -510,6 +556,7 @@ def from_texts( return store @classmethod + @override async def afrom_texts( cls, texts: list[str], diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index f9b731c730163..05622857f7ee8 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -44,7 +44,8 @@ python = ">=3.12.4" [tool.poetry.extras] [tool.ruff.lint] -select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",] +pydocstyle.convention = "google" +select = [ "ASYNC", "B", "C4", "COM", "D102", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",] ignore = [ "COM812", "UP007", "S110", "S112",] [tool.coverage.run] @@ -78,7 +79,7 @@ classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_ini "tests/unit_tests/prompts/test_chat.py" = [ "E501",] "tests/unit_tests/runnables/test_runnable.py" = [ "E501",] "tests/unit_tests/runnables/test_graph.py" = [ "E501",] -"tests/**" = [ "S",] +"tests/**" = [ "D", "S",] "scripts/**" = [ "S",] [tool.poetry.group.lint.dependencies] diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 4fd7fb567e885..120d0b9613ce0 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -117,7 +117,7 @@ def test_structured_args() -> None: def test_misannotated_base_tool_raises_error() -> None: - """Test that a BaseTool with the incorrect typehint raises an exception.""" "" + 
"""Test that a BaseTool with the incorrect typehint raises an exception.""" with pytest.raises(SchemaAnnotationError): class _MisAnnotatedTool(BaseTool): @@ -136,7 +136,7 @@ async def _arun( def test_forward_ref_annotated_base_tool_accepted() -> None: - """Test that a using forward ref annotation syntax is accepted.""" "" + """Test that a using forward ref annotation syntax is accepted.""" class _ForwardRefAnnotatedTool(BaseTool): name: str = "structured_api" @@ -1296,6 +1296,7 @@ def foo3(bar: str, baz: int) -> str: def foo4(bar: str, baz: int) -> str: """The foo. + Args: bar: The bar. baz: The baz. @@ -1504,7 +1505,7 @@ def h(x: str) -> str: @tool("foo", parse_docstring=True) def injected_tool(x: int, y: Annotated[str, InjectedToolArg]) -> str: - """foo. + """Foo. Args: x: abc @@ -1518,7 +1519,7 @@ class InjectedTool(BaseTool): description: str = "foo." def _run(self, x: int, y: Annotated[str, InjectedToolArg]) -> Any: - """foo. + """Foo. Args: x: abc