Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement RunnablePassthrough.assign(...) #11222

Merged
merged 2 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 10 additions & 29 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
patch_config,
)
from langchain.schema.runnable.utils import (
AddableDict,
Input,
Output,
accepts_config,
Expand Down Expand Up @@ -1748,30 +1749,6 @@ async def input_aiter() -> AsyncIterator[Input]:
yield chunk


class RunnableMapChunk(Dict[str, Any]):
    """
    Partial output from a RunnableMap

    Two chunks can be combined with ``+``: keys present on only one side (or
    holding ``None`` on one side) take the other side's value, and keys with
    non-``None`` values on both sides are combined with ``+``.
    """

    # NOTE(review): per the review remark attached to this class in the PR,
    # it was moved to utils.

    def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
        """Return a new chunk equal to ``self`` merged with ``other``."""
        merged = RunnableMapChunk(self)
        for key in other:
            incoming = other[key]
            current = merged.get(key)
            if current is None:
                # Key missing on the left (or explicitly None): take the right.
                merged[key] = incoming
            elif incoming is not None:
                merged[key] = current + incoming
        return merged

    def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
        """Return a new chunk equal to ``other`` merged with ``self``."""
        merged = RunnableMapChunk(other)
        for key, mine in self.items():
            existing = merged.get(key)
            if existing is None:
                # Key missing on the left (or explicitly None): take ours.
                merged[key] = mine
            elif mine is not None:
                merged[key] = existing + mine
        return merged


class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
"""
A runnable that runs a mapping of runnables in parallel,
Expand Down Expand Up @@ -1814,14 +1791,18 @@ def InputType(self) -> Any:

@property
def input_schema(self) -> type[BaseModel]:
if all(not s.input_schema.__custom_root_type__ for s in self.steps.values()):
if all(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works in more cases

s.input_schema.schema().get("type", "object") == "object"
for s in self.steps.values()
):
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableMapInput",
**{
k: (v.type_, v.default)
for step in self.steps.values()
for k, v in step.input_schema.__fields__.items()
if k != "__root__"
},
)

Expand Down Expand Up @@ -1934,7 +1915,7 @@ def _transform(
input: Iterator[Input],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Iterator[RunnableMapChunk]:
) -> Iterator[AddableDict]:
# Shallow copy steps to ignore mutations while in progress
steps = dict(self.steps)
# Each step gets a copy of the input iterator,
Expand Down Expand Up @@ -1967,7 +1948,7 @@ def _transform(
for future in completed_futures:
(step_name, generator) = futures.pop(future)
try:
chunk = RunnableMapChunk({step_name: future.result()})
chunk = AddableDict({step_name: future.result()})
yield chunk
futures[executor.submit(next, generator)] = (
step_name,
Expand Down Expand Up @@ -1999,7 +1980,7 @@ async def _atransform(
input: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> AsyncIterator[RunnableMapChunk]:
) -> AsyncIterator[AddableDict]:
# Shallow copy steps to ignore mutations while in progress
steps = dict(self.steps)
# Each step gets a copy of the input iterator,
Expand Down Expand Up @@ -2038,7 +2019,7 @@ async def get_next_chunk(generator: AsyncIterator) -> Optional[Output]:
for task in completed_tasks:
(step_name, generator) = tasks.pop(task)
try:
chunk = RunnableMapChunk({step_name: task.result()})
chunk = AddableDict({step_name: task.result()})
yield chunk
new_task = asyncio.create_task(get_next_chunk(generator))
tasks[new_task] = (step_name, generator)
Expand Down
201 changes: 198 additions & 3 deletions libs/langchain/langchain/schema/runnable/passthrough.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
from __future__ import annotations

from typing import Any, AsyncIterator, Iterator, List, Optional, Type
import asyncio
import threading
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Type,
Union,
cast,
)

from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Runnable
from langchain.schema.runnable.config import RunnableConfig
from langchain.pydantic_v1 import BaseModel, create_model
from langchain.schema.runnable.base import Input, Runnable, RunnableMap
from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config
from langchain.schema.runnable.utils import AddableDict
from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee


def identity(x: Input) -> Input:
Expand Down Expand Up @@ -38,6 +56,30 @@ def InputType(self) -> Any:
def OutputType(self) -> Any:
    """Return ``self.input_type`` when it is truthy, otherwise ``Any``."""
    declared = self.input_type
    return declared if declared else Any

@classmethod
def assign(
    cls,
    **kwargs: Union[
        Runnable[Dict[str, Any], Any],
        Callable[[Dict[str, Any]], Any],
        Mapping[
            str,
            Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
        ],
    ],
) -> RunnableAssign:
    """
    Merge the Dict input with the output produced by the given steps.

    Args:
        **kwargs: One entry per output key; each value is a runnable, a
            callable, or a mapping of those. Together they are wrapped in a
            RunnableMap.

    Returns:
        A RunnableAssign built around that RunnableMap.
    """
    # NOTE(review): a reviewer asked whether RunnableMap should also get an
    # ``assign`` method; the thread shows no decision.
    mapper = RunnableMap(kwargs)
    return RunnableAssign(mapper)

def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
    """
    Run the passthrough on ``input``.

    Delegates to ``self._call_with_config`` with the module-level
    ``identity`` function as the callable and returns its result.
    """
    passthrough = identity
    return self._call_with_config(passthrough, input, config)

Expand Down Expand Up @@ -65,3 +107,156 @@ async def atransform(
) -> AsyncIterator[Input]:
async for chunk in self._atransform_stream_with_config(input, identity, config):
yield chunk


class RunnableAssign(Serializable, Runnable[Dict[str, Any], Dict[str, Any]]):
    """
    A runnable that assigns key-value pairs to Dict[str, Any] inputs.

    The wrapped ``mapper`` is run on the input dict and its output is merged
    on top of that input, so mapper keys are added to, or overwrite, the
    keys of the input.
    """

    # Map of runnables whose output is merged into the input dict.
    mapper: RunnableMap[Dict[str, Any]]

    def __init__(self, mapper: RunnableMap[Dict[str, Any]], **kwargs: Any) -> None:
        """Store ``mapper`` (and any extra fields) via the parent initializer."""
        super().__init__(mapper=mapper, **kwargs)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that this class is serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Return the module path of this class without its last component."""
        return cls.__module__.split(".")[:-1]

    @property
    def input_schema(self) -> type[BaseModel]:
        """
        Input schema of this runnable.

        Uses the mapper's input schema when that schema has no custom root
        type, and the parent class's schema otherwise.
        """
        map_input_schema = self.mapper.input_schema
        if not map_input_schema.__custom_root_type__:
            # ie. it's a dict
            return map_input_schema

        return super().input_schema

    @property
    def output_schema(self) -> type[BaseModel]:
        """
        Output schema of this runnable.

        When neither the mapper's input schema nor its output schema has a
        custom root type, a "RunnableAssignOutput" model is created from the
        fields of both; output fields are listed last, so they win on a name
        clash. Otherwise the parent class's schema is returned.
        """
        map_input_schema = self.mapper.input_schema
        map_output_schema = self.mapper.output_schema
        if (
            not map_input_schema.__custom_root_type__
            and not map_output_schema.__custom_root_type__
        ):
            # ie. both are dicts
            return create_model(  # type: ignore[call-overload]
                "RunnableAssignOutput",
                **{
                    k: (v.type_, v.default)
                    for s in (map_input_schema, map_output_schema)
                    for k, v in s.__fields__.items()
                },
            )

        return super().output_schema

    def invoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Return ``input`` merged with the mapper's output for ``input``.

        Raises:
            AssertionError: if ``input`` is not a dict.
        """
        assert isinstance(input, dict)
        return {
            **input,
            # Spread last on purpose: the mapper output may overwrite an
            # existing input key (same semantics as JS Object.assign), so a
            # key collision is not an error.
            **self.mapper.invoke(input, config, **kwargs),
        }

    async def ainvoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Async counterpart of ``invoke``.

        Raises:
            AssertionError: if ``input`` is not a dict.
        """
        assert isinstance(input, dict)
        return {
            **input,
            # Mapper output wins on key collisions, as in ``invoke``.
            **await self.mapper.ainvoke(input, config, **kwargs),
        }

    def transform(
        self,
        input: Iterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        """
        Stream the passthrough chunks, then the mapper's output chunks.

        Each input chunk is yielded with the mapper's keys removed (empty
        results are dropped); afterwards every chunk produced by the mapper
        is yielded. If the mapper produces no chunk at all, nothing more is
        yielded.

        Raises:
            AssertionError: if an input chunk is not a dict.
        """
        # collect mapper keys
        mapper_keys = set(self.mapper.steps.keys())
        # create two streams, one for the map and one for the passthrough
        for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())
        # create map output stream
        map_output = self.mapper.transform(for_map, config, **kwargs)
        # Default handed to next(): without it an empty map stream raises
        # StopIteration out of future.result(), which a generator body turns
        # into RuntimeError.
        exhausted = object()
        # get executor to start map output stream in background
        with get_executor_for_config(config or {}) as executor:
            # start map output stream
            first_map_chunk_future = executor.submit(
                next,
                map_output,  # type: ignore
                exhausted,
            )
            # consume passthrough stream
            for chunk in for_passthrough:
                assert isinstance(chunk, dict)
                # remove mapper keys from passthrough chunk, to be overwritten by map
                filtered = AddableDict(
                    {k: v for k, v in chunk.items() if k not in mapper_keys}
                )
                if filtered:
                    yield filtered
            # yield map output
            first_map_chunk = first_map_chunk_future.result()
            if first_map_chunk is not exhausted:
                yield cast(Dict[str, Any], first_map_chunk)
                for chunk in map_output:
                    yield chunk

    async def atransform(
        self,
        input: AsyncIterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Async counterpart of ``transform``.

        Yields the passthrough chunks stripped of the mapper's keys, then the
        mapper's output chunks. If the mapper produces no chunk at all,
        nothing more is yielded.

        Raises:
            AssertionError: if an input chunk is not a dict.
        """
        # collect mapper keys
        mapper_keys = set(self.mapper.steps.keys())
        # create two streams, one for the map and one for the passthrough
        for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
        # create map output stream
        map_output = self.mapper.atransform(for_map, config, **kwargs)
        exhausted = object()

        async def first_chunk_or_exhausted() -> Any:
            # An empty map stream must not leak StopAsyncIteration into this
            # async generator, where it would surface as RuntimeError.
            try:
                return await py_anext(map_output)
            except StopAsyncIteration:
                return exhausted

        # start map output stream
        # Annotated rather than cast(): the subscripted asyncio.Task would be
        # evaluated at runtime as a cast() argument and fails on interpreters
        # where asyncio.Task is not subscriptable.
        first_map_chunk_task: asyncio.Task[Any] = asyncio.create_task(
            first_chunk_or_exhausted()
        )
        # consume passthrough stream
        async for chunk in for_passthrough:
            assert isinstance(chunk, dict)
            # remove mapper keys from passthrough chunk, to be overwritten by map output
            filtered = AddableDict(
                {k: v for k, v in chunk.items() if k not in mapper_keys}
            )
            if filtered:
                yield filtered
        # yield map output
        first_map_chunk = await first_map_chunk_task
        if first_map_chunk is not exhausted:
            yield first_map_chunk
            async for chunk in map_output:
                yield chunk

    def stream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        """Stream the result for a single ``input`` by running ``transform`` on it."""
        return self.transform(iter([input]), config, **kwargs)

    async def astream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Stream the result for a single ``input`` by running ``atransform`` on it."""

        async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
            # Wrap the single input in a one-item async stream.
            yield input

        async for chunk in self.atransform(input_aiter(), config, **kwargs):
            yield chunk
71 changes: 70 additions & 1 deletion libs/langchain/langchain/schema/runnable/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,20 @@
import inspect
import textwrap
from inspect import signature
from typing import Any, Callable, Coroutine, List, Optional, Set, TypeVar, Union
from typing import (
Any,
AsyncIterable,
Callable,
Coroutine,
Dict,
Iterable,
List,
Optional,
Protocol,
Set,
TypeVar,
Union,
)

Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Expand Down Expand Up @@ -142,3 +155,59 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
spaces = " " * n_spaces
lines = text.splitlines()
return "\n".join([lines[0]] + [spaces + line for line in lines[1:]])


class AddableDict(Dict[str, Any]):
    """
    Dictionary that can be added to another dictionary.

    Adding produces a new AddableDict: keys present on only one side (or
    holding ``None`` on one side) take the other side's value, and keys with
    non-``None`` values on both sides are combined with ``+``.
    """

    def __add__(self, other: AddableDict) -> AddableDict:
        """Return a new dict equal to ``self`` merged with ``other``."""
        merged = AddableDict(self)
        for key in other:
            incoming = other[key]
            current = merged.get(key)
            if current is None:
                # Key missing on the left (or explicitly None): take the right.
                merged[key] = incoming
            elif incoming is not None:
                merged[key] = current + incoming
        return merged

    def __radd__(self, other: AddableDict) -> AddableDict:
        """Return a new dict equal to ``other`` merged with ``self``."""
        merged = AddableDict(other)
        for key, mine in self.items():
            existing = merged.get(key)
            if existing is None:
                # Key missing on the left (or explicitly None): take ours.
                merged[key] = mine
            elif mine is not None:
                merged[key] = existing + mine
        return merged


# Type variables for the operand (contravariant) and result (covariant) of
# the SupportsAdd protocol.
_T_contra = TypeVar("_T_contra", contravariant=True)
_T_co = TypeVar("_T_co", covariant=True)


class SupportsAdd(Protocol[_T_contra, _T_co]):
    """Structural type for objects that define ``__add__``."""

    def __add__(self, __x: _T_contra) -> _T_co:
        """Return the result of adding ``__x`` to ``self``."""


# Any type that supports ``+``; used to type the ``add``/``aadd`` helpers.
Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])


def add(addables: Iterable[Addable]) -> Optional[Addable]:
    """
    Fold ``addables`` left-to-right with ``+``.

    Returns:
        The accumulated value, or ``None`` when ``addables`` is empty.
    """
    total = None
    for item in addables:
        # The first item seeds the accumulator; later items are added to it.
        total = item if total is None else total + item
    return total


async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
    """
    Fold the items of an async iterable left-to-right with ``+``.

    Returns:
        The accumulated value, or ``None`` when ``addables`` yields nothing.
    """
    total = None
    async for item in addables:
        # The first item seeds the accumulator; later items are added to it.
        total = item if total is None else total + item
    return total
Loading