Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

incorporate tool calling #131

Merged
merged 8 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,947 changes: 1,147 additions & 800 deletions backend/poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fastapi = "^0.109.2"
langserve = "^0.0.45"
uvicorn = "^0.27.1"
pydantic = "^1.10"
langchain-openai = "^0.0.8"
langchain-openai = "^0.1.3"
jsonschema = "^4.21.1"
sse-starlette = "^2.0.0"
alembic = "^1.13.1"
Expand All @@ -26,6 +26,8 @@ lxml = "^5.1.0"
faiss-cpu = "^1.7.4"
python-multipart = "^0.0.9"
langchain-fireworks = "^0.1.1"
langchain-anthropic = "^0.1.11"
langchain-groq = "^0.1.3"

[tool.poetry.group.dev.dependencies]
jupyterlab = "^3.6.1"
Expand Down
25 changes: 12 additions & 13 deletions backend/server/extraction_runnable.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import json
import uuid
from typing import Any, Dict, List, Optional, Sequence

from fastapi import HTTPException
from jsonschema import Draft202012Validator, exceptions
from langchain.text_splitter import TokenTextSplitter
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import chain
from langserve import CustomUserType
Expand Down Expand Up @@ -97,19 +98,18 @@ def _make_prompt_template(
# TODO: We'll need to refactor this at some point to
# support other encoding strategies. The function calling logic here
# has some hard-coded assumptions (e.g., name of parameters like `data`).
function_call = {
"arguments": json.dumps(
{
"data": example.output,
}
),
_id = uuid.uuid4().hex[:]
tool_call = {
"args": {"data": example.output},
"name": function_name,
"id": _id,
}
few_shot_prompt.extend(
[
HumanMessage(content=example.text),
AIMessage(
content="", additional_kwargs={"function_call": function_call}
AIMessage(content="", tool_calls=[tool_call]),
ToolMessage(
content="You have correctly called this tool.", tool_call_id=_id
),
]
)
Expand Down Expand Up @@ -172,10 +172,9 @@ async def extraction_runnable(extraction_request: ExtractRequest) -> ExtractResp
schema["title"],
)
model = get_model(extraction_request.model_name)
# N.B. method must be consistent with examples in _make_prompt_template
runnable = (
prompt | model.with_structured_output(schema=schema, method="function_calling")
).with_config({"run_name": "extraction"})
runnable = (prompt | model.with_structured_output(schema=schema)).with_config(
{"run_name": "extraction"}
)

return await runnable.ainvoke({"text": extraction_request.text})

Expand Down
17 changes: 17 additions & 0 deletions backend/server/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
from typing import Optional

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI


Expand Down Expand Up @@ -37,6 +39,21 @@ def get_supported_models():
),
"description": "Mixtral 8x7B Instruct v0.1 (Together AI)",
}
if "ANTHROPIC_API_KEY" in os.environ:
models["claude-3-sonnet-20240229"] = {
"chat_model": ChatAnthropic(
model="claude-3-sonnet-20240229", temperature=0
),
"description": "Claude 3 Sonnet",
}
if "GROQ_API_KEY" in os.environ:
models["groq-llama3-8b-8192"] = {
"chat_model": ChatGroq(
model="llama3-8b-8192",
temperature=0,
),
"description": "GROQ Llama 3 8B",
}

return models

Expand Down
20 changes: 14 additions & 6 deletions backend/tests/unit_tests/fake/test_fake_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,33 @@
from tests.unit_tests.fake.chat_model import GenericFakeChatModel


class AnyStr(str):
    """Test helper: a placeholder string that compares equal to any ``str``.

    Use it in expected values where a field holds some string whose exact
    content is irrelevant or unpredictable, e.g.
    ``AIMessage(content="hello", id=AnyStr())``.

    The instance itself is always the empty string; only the comparison
    operators are customized. Because ``__eq__`` is overridden without a
    matching ``__hash__``, instances are unhashable.
    """

    def __init__(self) -> None:
        super().__init__()

    def __eq__(self, other: object) -> bool:
        """Return True for every ``str`` instance, False for anything else."""
        return isinstance(other, str)

    def __ne__(self, other: object) -> bool:
        """Return the exact negation of ``__eq__``.

        ``str`` defines its own ``__ne__``, so without this override ``!=``
        would compare the underlying empty string and disagree with ``==``.
        """
        return not isinstance(other, str)


def test_generic_fake_chat_model_invoke() -> None:
# Will alternate between responding with hello and goodbye
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = model.invoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())


async def test_generic_fake_chat_model_ainvoke() -> None:
# Will alternate between responding with hello and goodbye
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = await model.ainvoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
11 changes: 7 additions & 4 deletions backend/tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from langchain.pydantic_v1 import BaseModel, Field
from langchain_core.messages import AIMessage

from extraction.utils import update_json_schema
from server.extraction_runnable import ExtractionExample, _make_prompt_template
Expand Down Expand Up @@ -82,19 +83,21 @@ def test_make_prompt_template() -> None:
)
prompt = _make_prompt_template(instructions, examples, "name")
messages = prompt.messages
assert 4 == len(messages)
assert 5 == len(messages)
system = messages[0].prompt.template
assert system.startswith(prefix)
assert system.endswith(instructions)

example_input = messages[1]
assert example_input.content == "Test text."
example_output = messages[2]
assert "function_call" in example_output.additional_kwargs
assert example_output.additional_kwargs["function_call"]["name"] == "name"
assert isinstance(example_output, AIMessage)
assert example_output.tool_calls
assert len(example_output.tool_calls) == 1
assert example_output.tool_calls[0]["name"] == "name"

prompt = _make_prompt_template(instructions, None, "name")
assert 2 == len(prompt.messages)

prompt = _make_prompt_template(None, examples, "name")
assert 4 == len(prompt.messages)
assert 5 == len(prompt.messages)
Loading