Skip to content

Commit

Permalink
anthropic[patch]: use core output parsers for structured output (#23776)
Browse files Browse the repository at this point in the history
Also add to standard tests for structured output.
  • Loading branch information
ccurme authored Jul 2, 2024
1 parent dc39683 commit 46cbf0e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
15 changes: 11 additions & 4 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import (
Expand All @@ -58,7 +63,7 @@
)
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_anthropic.output_parsers import ToolsOutputParser, extract_tool_calls
from langchain_anthropic.output_parsers import extract_tool_calls

_message_type_lookups = {
"human": "user",
Expand Down Expand Up @@ -990,11 +995,13 @@ class AnswerWithJustification(BaseModel):
tool_name = convert_to_anthropic_tool(schema)["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if isinstance(schema, type) and issubclass(schema, BaseModel):
output_parser = ToolsOutputParser(
first_tool_only=True, pydantic_schemas=[schema]
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
output_parser = ToolsOutputParser(first_tool_only=True, args_only=True)
output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)

if include_raw:
parser_assign = RunnablePassthrough.assign(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,25 @@ class Joke(BaseModel):
setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")

# Pydantic class
chat = model.with_structured_output(Joke)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke)

for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)

# Schema
chat = model.with_structured_output(Joke.schema())
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}

for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}

def test_tool_message_histories_string_content(
self,
model: BaseChatModel,
Expand Down

0 comments on commit 46cbf0e

Please sign in to comment.