Commit

core[patch]: fix runnable history and add docs (#22283)
hwchase17 authored May 30, 2024
1 parent dcec133 commit ee32369
Showing 7 changed files with 487 additions and 349 deletions.
795 changes: 455 additions & 340 deletions docs/docs/how_to/message_history.ipynb

Large diffs are not rendered by default.

Binary file added docs/static/img/message_history.png
21 changes: 19 additions & 2 deletions libs/core/langchain_core/runnables/history.py
@@ -372,6 +372,7 @@ def _get_input_messages(
     ) -> List[BaseMessage]:
         from langchain_core.messages import BaseMessage

+        # If dictionary, try to pluck the single key representing messages
         if isinstance(input_val, dict):
             if self.input_messages_key:
                 key = self.input_messages_key
@@ -381,13 +382,25 @@ def _get_input_messages(
                 key = "input"
             input_val = input_val[key]

+        # If value is a string, convert to a human message
         if isinstance(input_val, str):
             from langchain_core.messages import HumanMessage

             return [HumanMessage(content=input_val)]
+        # If value is a single message, convert to a list
         elif isinstance(input_val, BaseMessage):
             return [input_val]
+        # If value is a list or tuple...
         elif isinstance(input_val, (list, tuple)):
+            # Handle the empty case
+            if len(input_val) == 0:
+                return list(input_val)
+            # If it is a list of lists, return the first list.
+            # This occurs for chat models, since we batch inputs.
+            if isinstance(input_val[0], list):
+                if len(input_val) != 1:
+                    raise ValueError("Expected a single list of messages.")
+                return input_val[0]
             return list(input_val)
         else:
             raise ValueError(
@@ -400,6 +413,7 @@ def _get_output_messages(
     ) -> List[BaseMessage]:
         from langchain_core.messages import BaseMessage

+        # If dictionary, try to pluck the single key representing messages
         if isinstance(output_val, dict):
             if self.output_messages_key:
                 key = self.output_messages_key
@@ -418,6 +432,7 @@ def _get_output_messages(
             from langchain_core.messages import AIMessage

             return [AIMessage(content=output_val)]
+        # If value is a single message, convert to a list
         elif isinstance(output_val, BaseMessage):
             return [output_val]
         elif isinstance(output_val, (list, tuple)):
@@ -431,7 +446,10 @@ def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:

         if not self.history_messages_key:
             # return all messages
-            messages += self._get_input_messages(input)
+            input_val = (
+                input if not self.input_messages_key else input[self.input_messages_key]
+            )
+            messages += self._get_input_messages(input_val)
         return messages

     async def _aenter_history(
@@ -454,7 +472,6 @@ def _exit_history(self, run: Run, config: RunnableConfig) -> None:
         # Get the input messages
         inputs = load(run.inputs)
         input_messages = self._get_input_messages(inputs)
-
         # If historic messages were prepended to the input messages, remove them to
         # avoid adding duplicate messages to history.
         if not self.history_messages_key:
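The heart of the history.py change is the normalization in _get_input_messages: strings, single messages, message lists, and batched lists of lists all collapse to one List[BaseMessage]. A condensed standalone sketch of those rules (the function name and error messages are illustrative, not the library's own):

from typing import Any, List

from langchain_core.messages import BaseMessage, HumanMessage


def normalize_input_messages(input_val: Any) -> List[BaseMessage]:
    # A bare string becomes a single human message.
    if isinstance(input_val, str):
        return [HumanMessage(content=input_val)]
    # A single message becomes a one-element list.
    if isinstance(input_val, BaseMessage):
        return [input_val]
    if isinstance(input_val, (list, tuple)):
        # An empty sequence passes through as an empty list.
        if len(input_val) == 0:
            return list(input_val)
        # A list of lists arises from batched chat model inputs;
        # only a single batch element can be stored as history.
        if isinstance(input_val[0], list):
            if len(input_val) != 1:
                raise ValueError("Expected a single list of messages.")
            return input_val[0]
        return list(input_val)
    raise ValueError(f"Unexpected input type: {type(input_val)}")


normalize_input_messages("hi")                            # [HumanMessage("hi")]
normalize_input_messages([[HumanMessage(content="hi")]])  # unwraps the batch
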
14 changes: 9 additions & 5 deletions libs/core/langchain_core/tracers/base.py
@@ -48,7 +48,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
     def __init__(
         self,
         *,
-        _schema_format: Literal["original", "streaming_events"] = "original",
+        _schema_format: Literal[
+            "original", "streaming_events", "original+chat"
+        ] = "original",
         **kwargs: Any,
     ) -> None:
         """Initialize the tracer.
@@ -63,6 +65,8 @@ def __init__(
                 for internal usage. It will likely change in the future, or
                 be deprecated entirely in favor of a dedicated async tracer
                 for streaming events.
+            - 'original+chat' is a format that is the same as 'original'
+              except it does NOT raise an attribute error on on_chat_model_start.
         kwargs: Additional keyword arguments that will be passed to
             the super class.
         """
@@ -163,7 +167,7 @@ def on_chat_model_start(
         **kwargs: Any,
     ) -> Run:
         """Start a trace for an LLM run."""
-        if self._schema_format != "streaming_events":
+        if self._schema_format not in ("streaming_events", "original+chat"):
             # Please keep this un-implemented for backwards compatibility.
             # When it's unimplemented old tracers that use the "original" format
             # fallback on the on_llm_start method implementation if they
@@ -360,7 +364,7 @@ def on_chain_start(

     def _get_chain_inputs(self, inputs: Any) -> Any:
         """Get the inputs for a chain run."""
-        if self._schema_format == "original":
+        if self._schema_format in ("original", "original+chat"):
             return inputs if isinstance(inputs, dict) else {"input": inputs}
         elif self._schema_format == "streaming_events":
             return {
@@ -371,7 +375,7 @@ def _get_chain_inputs(self, inputs: Any) -> Any:

     def _get_chain_outputs(self, outputs: Any) -> Any:
         """Get the outputs for a chain run."""
-        if self._schema_format == "original":
+        if self._schema_format in ("original", "original+chat"):
             return outputs if isinstance(outputs, dict) else {"output": outputs}
         elif self._schema_format == "streaming_events":
             return {
@@ -436,7 +440,7 @@ def on_tool_start(
         if metadata:
             kwargs.update({"metadata": metadata})

-        if self._schema_format == "original":
+        if self._schema_format in ("original", "original+chat"):
             inputs = {"input": input_str}
         elif self._schema_format == "streaming_events":
             inputs = {"input": inputs}
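The new 'original+chat' value threads through every _schema_format check above: a tracer opting into it keeps the 'original' input/output shapes but also receives chat model start events instead of falling back to on_llm_start. A minimal sketch of such a subclass (the class name and print statement are illustrative):

from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run


class ChatAwareTracer(BaseTracer):
    def __init__(self) -> None:
        # Keep the 'original' run schema but accept chat model events.
        super().__init__(_schema_format="original+chat")

    def _persist_run(self, run: Run) -> None:
        # BaseTracer's one required hook; just log the finished run.
        print(f"run finished: {run.name} ({run.run_type})")


# Would be passed via config, e.g.:
# chain.invoke(inputs, config={"callbacks": [ChatAwareTracer()]})
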
2 changes: 1 addition & 1 deletion libs/core/langchain_core/tracers/log_stream.py
@@ -482,7 +482,7 @@ def _get_standardized_inputs(


 def _get_standardized_outputs(
-    run: Run, schema_format: Literal["original", "streaming_events"]
+    run: Run, schema_format: Literal["original", "streaming_events", "original+chat"]
 ) -> Optional[Any]:
     """Extract standardized output from a run.
2 changes: 1 addition & 1 deletion libs/core/langchain_core/tracers/root_listeners.py
@@ -22,7 +22,7 @@ def __init__(
         on_end: Optional[Listener],
         on_error: Optional[Listener],
     ) -> None:
-        super().__init__()
+        super().__init__(_schema_format="original+chat")

         self.config = config
         self._arg_on_start = on_start
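RootListenersTracer backs Runnable.with_listeners, so this one-line change is what lets those listeners fire for chat model runs as well. A quick usage sketch (the toy chain is illustrative):

from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.schemas import Run


def log_start(run: Run) -> None:
    print("started:", run.name)


def log_end(run: Run) -> None:
    print("ended:", run.name)


chain = RunnableLambda(lambda x: x + 1).with_listeners(
    on_start=log_start, on_end=log_end
)
chain.invoke(1)  # listeners receive the Run before and after execution
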
2 changes: 2 additions & 0 deletions libs/core/tests/unit_tests/fake/memory.py
@@ -17,6 +17,8 @@ class ChatMessageHistory(BaseChatMessageHistory, BaseModel):

     def add_message(self, message: BaseMessage) -> None:
         """Add a self-created message to the store"""
+        if not isinstance(message, BaseMessage):
+            raise ValueError("Expected a BaseMessage instance.")
         self.messages.append(message)

     def clear(self) -> None:
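Put together, the pieces above compose as follows; a minimal end-to-end sketch of RunnableWithMessageHistory (the echo runnable and session store are illustrative, not part of the patch):

from typing import Dict, List

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory


class InMemoryHistory(BaseChatMessageHistory):
    def __init__(self) -> None:
        self.messages: List[BaseMessage] = []

    def add_message(self, message: BaseMessage) -> None:
        self.messages.append(message)

    def clear(self) -> None:
        self.messages = []


store: Dict[str, InMemoryHistory] = {}


def get_session_history(session_id: str) -> InMemoryHistory:
    return store.setdefault(session_id, InMemoryHistory())


# An "echo" model: replies with how many messages it received.
echo = RunnableLambda(lambda msgs: AIMessage(content=f"{len(msgs)} messages"))

chain = RunnableWithMessageHistory(echo, get_session_history)
config = {"configurable": {"session_id": "abc"}}
chain.invoke("hello", config=config)  # a bare string is accepted as input
chain.invoke("again", config=config)  # prior turns are prepended from history
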
