Skip to content

Commit

Permalink
core: Add ruff rules D102 (missing docstring)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jan 24, 2025
1 parent dbb6b7b commit d7d9134
Show file tree
Hide file tree
Showing 37 changed files with 427 additions and 36 deletions.
38 changes: 38 additions & 0 deletions libs/core/langchain_core/beta/runnables/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)

from pydantic import ConfigDict
from typing_extensions import override

from langchain_core._api.beta_decorator import beta
from langchain_core.runnables.base import (
Expand Down Expand Up @@ -166,6 +167,7 @@ def __str__(self) -> str:

@property
def ids(self) -> list[str]:
"""The context getter ids."""
prefix = self.prefix + "/" if self.prefix else ""
keys = self.key if isinstance(self.key, list) else [self.key]
return [
Expand All @@ -174,6 +176,7 @@ def ids(self) -> list[str]:
]

@property
@override
def config_specs(self) -> list[ConfigurableFieldSpec]:
return super().config_specs + [
ConfigurableFieldSpec(
Expand All @@ -183,6 +186,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]:
for id_ in self.ids
]

@override
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
Expand All @@ -193,6 +197,7 @@ def invoke(
else:
return configurable[self.ids[0]]()

@override
async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
Expand Down Expand Up @@ -253,13 +258,15 @@ def __str__(self) -> str:

@property
def ids(self) -> list[str]:
"""The context setter ids."""
prefix = self.prefix + "/" if self.prefix else ""
return [
f"{CONTEXT_CONFIG_PREFIX}{prefix}{key}{CONTEXT_CONFIG_SUFFIX_SET}"
for key in self.keys
]

@property
@override
def config_specs(self) -> list[ConfigurableFieldSpec]:
mapper_config_specs = [
s
Expand All @@ -281,6 +288,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]:
for id_ in self.ids
]

@override
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
Expand All @@ -293,6 +301,7 @@ def invoke(
configurable[id_](input)
return input

@override
async def ainvoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
Expand Down Expand Up @@ -361,6 +370,11 @@ def create_scope(scope: str, /) -> "PrefixContext":

@staticmethod
def getter(key: Union[str, list[str]], /) -> ContextGet:
"""Return a context getter.
Args:
key: The context getter key.
"""
return ContextGet(key=key)

@staticmethod
Expand All @@ -370,6 +384,13 @@ def setter(
/,
**kwargs: SetValue,
) -> ContextSet:
"""Return a context setter.
Args:
_key: The context setter key.
_value: The context setter value.
**kwargs: Additional context setter key-value pairs.
"""
return ContextSet(_key, _value, prefix="", **kwargs)


Expand All @@ -379,9 +400,19 @@ class PrefixContext:
prefix: str = ""

def __init__(self, prefix: str = ""):
"""Initialize the prefix context.
Args:
prefix: The prefix.
"""
self.prefix = prefix

def getter(self, key: Union[str, list[str]], /) -> ContextGet:
"""Return a prefixed context getter.
Args:
key: The context getter key.
"""
return ContextGet(key=key, prefix=self.prefix)

def setter(
Expand All @@ -391,6 +422,13 @@ def setter(
/,
**kwargs: SetValue,
) -> ContextSet:
"""Return a prefixed context setter.
Args:
_key: The context setter key.
_value: The context setter value.
**kwargs: Additional context setter key-value pairs.
"""
return ContextSet(_key, _value, prefix=self.prefix, **kwargs)


Expand Down
3 changes: 1 addition & 2 deletions libs/core/langchain_core/caches.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
.. warning::
""".. warning::
Beta Feature!
**Cache** provides an optional caching layer for LLMs.
Expand Down
5 changes: 4 additions & 1 deletion libs/core/langchain_core/document_loaders/langsmith.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Optional, Union

from langsmith import Client as LangSmithClient
from typing_extensions import override

from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document
Expand Down Expand Up @@ -53,7 +54,8 @@ def __init__(
client: Optional[LangSmithClient] = None,
**client_kwargs: Any,
) -> None:
"""
"""Initialize a LangSmith loader.
Args:
dataset_id: The ID of the dataset to filter by. Defaults to None.
dataset_name: The name of the dataset to filter by. Defaults to None.
Expand Down Expand Up @@ -95,6 +97,7 @@ def __init__(
self.metadata = metadata
self.filter = filter

@override
def lazy_load(self) -> Iterator[Document]:
for example in self._client.list_examples(
dataset_id=self.dataset_id,
Expand Down
5 changes: 5 additions & 0 deletions libs/core/langchain_core/documents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ class BaseMedia(Serializable):

@field_validator("id", mode="before")
def cast_id_to_str(cls, id_value: Any) -> Optional[str]:
"""Coerce the id field to a string.
Args:
id_value: The id value to coerce.
"""
if id_value is not None:
return str(id_value)
else:
Expand Down
5 changes: 5 additions & 0 deletions libs/core/langchain_core/embeddings/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hashlib

from pydantic import BaseModel
from typing_extensions import override

from langchain_core.embeddings import Embeddings

Expand Down Expand Up @@ -55,9 +56,11 @@ def _get_embedding(self) -> list[float]:

return list(np.random.default_rng().normal(size=self.size))

@override
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self._get_embedding() for _ in texts]

@override
def embed_query(self, text: str) -> list[float]:
return self._get_embedding()

Expand Down Expand Up @@ -116,8 +119,10 @@ def _get_seed(self, text: str) -> int:
"""Get a seed for the random generator, using the hash of the text."""
return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8

@override
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self._get_embedding(seed=self._get_seed(_)) for _ in texts]

@override
def embed_query(self, text: str) -> list[float]:
return self._get_embedding(seed=self._get_seed(text))
56 changes: 56 additions & 0 deletions libs/core/langchain_core/language_models/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def _convert_input(self, input: LanguageModelInput) -> PromptValue:
)
raise ValueError(msg) # noqa: TRY004

@override
def invoke(
self,
input: LanguageModelInput,
Expand All @@ -293,6 +294,7 @@ def invoke(
).generations[0][0],
).message

@override
async def ainvoke(
self,
input: LanguageModelInput,
Expand Down Expand Up @@ -349,6 +351,7 @@ def _should_stream(
handlers = run_manager.handlers if run_manager else []
return any(isinstance(h, _StreamingCallbackHandler) for h in handlers)

@override
def stream(
self,
input: LanguageModelInput,
Expand Down Expand Up @@ -423,6 +426,7 @@ def stream(

run_manager.on_llm_end(LLMResult(generations=[[generation]]))

@override
async def astream(
self,
input: LanguageModelInput,
Expand Down Expand Up @@ -780,6 +784,7 @@ async def agenerate(
]
return output

@override
def generate_prompt(
self,
prompts: list[PromptValue],
Expand All @@ -790,6 +795,7 @@ def generate_prompt(
prompt_messages = [p.to_messages() for p in prompts]
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

@override
async def agenerate_prompt(
self,
prompts: list[PromptValue],
Expand Down Expand Up @@ -1019,6 +1025,20 @@ def __call__(
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
"""Call the model.
Args:
messages: List of messages.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
The model output message.
"""
generation = self.generate(
[messages], stop=stop, callbacks=callbacks, **kwargs
).generations[0][0]
Expand Down Expand Up @@ -1049,12 +1069,37 @@ async def _call_async(
def call_as_llm(
self, message: str, stop: Optional[list[str]] = None, **kwargs: Any
) -> str:
"""Call the model.
Args:
message: The input message.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
The model output string.
"""
return self.predict(message, stop=stop, **kwargs)

@deprecated("0.1.7", alternative="invoke", removal="1.0")
@override
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
"""Predict the next message.
Args:
text: The input message.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
The predicted output string.
"""
_stop = None if stop is None else list(stop)
result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
if isinstance(result.content, str):
Expand All @@ -1064,6 +1109,7 @@ def predict(
raise ValueError(msg) # noqa: TRY004

@deprecated("0.1.7", alternative="invoke", removal="1.0")
@override
def predict_messages(
self,
messages: list[BaseMessage],
Expand All @@ -1075,6 +1121,7 @@ def predict_messages(
return self(messages, stop=_stop, **kwargs)

@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@override
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
Expand All @@ -1089,6 +1136,7 @@ async def apredict(
raise ValueError(msg) # noqa: TRY004

@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@override
async def apredict_messages(
self,
messages: list[BaseMessage],
Expand Down Expand Up @@ -1117,6 +1165,14 @@ def bind_tools(
],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the model.
Args:
tools: Sequence of tools to bind to the model.
Returns:
A Runnable that returns a message.
"""
raise NotImplementedError

def with_structured_output(
Expand Down
Loading

0 comments on commit d7d9134

Please sign in to comment.