PR feedback
scosman committed Jan 31, 2025
1 parent e5daed4 commit a253729
Showing 3 changed files with 86 additions and 100 deletions.
92 changes: 42 additions & 50 deletions libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py
@@ -31,7 +31,6 @@
ollama_base_url,
ollama_model_installed,
)
from kiln_ai.adapters.provider_tools import kiln_model_provider_from
from kiln_ai.utils.config import Config

LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
@@ -92,12 +91,11 @@ async def model(self) -> LangChainModelType:
if self._model:
return self._model

self._model = await langchain_model_from(
self.model_name, self.model_provider_name
)
self._model = await self.langchain_model_from()

# TODO better way: structured_output_mode = raw?
# Don't setup structured output unless task is structured, and we haven't said we want instruction based JSON
# Decide if we want to use Langchain's structured output:
# 1. Only for structured tasks
# 2. Only if the provider's mode isn't json_instructions (only mode that doesn't use an API option for structured output capabilities)
provider = await self.model_provider()
use_lc_structured_output = (
self.has_structured_output()
@@ -120,7 +118,7 @@ async def model(self) -> LangChainModelType:
)
output_schema["title"] = "task_response"
output_schema["description"] = "A response from the task"
with_structured_output_options = await get_structured_output_options(
with_structured_output_options = await self.get_structured_output_options(
self.model_name, self.model_provider_name
)
self._model = self._model.with_structured_output(
@@ -142,14 +140,12 @@ async def _run(self, input: Dict | str) -> RunOutput:
HumanMessage(content=user_msg),
]

# TODO: make this thinking native
# TODO: make this compatible with thinking models
# COT with structured output
cot_prompt = self.prompt_builder.chain_of_thought_prompt()
if cot_prompt and self.has_structured_output():
# Base model (without structured output) used for COT message
base_model = await langchain_model_from(
self.model_name, self.model_provider_name
)
base_model = await self.langchain_model_from()
messages.append(
SystemMessage(content=cot_prompt),
)
@@ -164,9 +160,7 @@ async def _run(self, input: Dict | str) -> RunOutput:
elif cot_prompt:
messages.append(SystemMessage(content=cot_prompt))

# TODO: make this thinking native
response = await chain.ainvoke(messages)
print(f"response: {response}")

# Langchain may have already parsed the response into structured output, so use that if available.
# However, a plain string may still be fixed at the parsing layer, so not being structured isn't a critical failure (yet)
@@ -184,9 +178,11 @@ async def _run(self, input: Dict | str) -> RunOutput:

if not isinstance(response, BaseMessage):
raise RuntimeError(f"response is not a BaseMessage: {response}")

text_content = response.content
if not isinstance(text_content, str):
raise RuntimeError(f"response is not a string: {text_content}")

return RunOutput(
output=text_content,
intermediate_outputs=intermediate_outputs,
@@ -211,44 +207,40 @@ def _munge_response(self, response: Dict) -> Dict:
return response["arguments"]
return response

async def get_structured_output_options(
self, model_name: str, model_provider_name: str
) -> Dict[str, Any]:
provider = await self.model_provider()
if not provider:
return {}

options = {}
# We may need to add some provider specific logic here if providers use different names for the same mode, but everyone is copying openai for now
match provider.structured_output_mode:
case StructuredOutputMode.function_calling:
options["method"] = "function_calling"
case StructuredOutputMode.json_mode:
options["method"] = "json_mode"
case StructuredOutputMode.json_schema:
options["method"] = "json_schema"
case StructuredOutputMode.json_instructions:
# JSON done via instructions in prompt, not via API
pass
case StructuredOutputMode.default:
# Let langchain decide the default
pass
case _:
raise ValueError(
f"Unhandled enum value: {provider.structured_output_mode}"
)
# triggers pyright warning if I miss a case
return NoReturn

async def get_structured_output_options(
model_name: str, model_provider_name: str
) -> Dict[str, Any]:
# TODO self cache
provider = await kiln_model_provider_from(model_name, model_provider_name)
if not provider:
return {}

options = {}
# We may need to add some provider specific logic here if providers use different names for the same mode, but everyone is copying openai for now
match provider.structured_output_mode:
case StructuredOutputMode.function_calling:
options["method"] = "function_calling"
case StructuredOutputMode.json_mode:
options["method"] = "json_mode"
case StructuredOutputMode.json_schema:
options["method"] = "json_schema"
case StructuredOutputMode.json_instructions:
# JSON done via instructions in prompt, not via API
pass
case StructuredOutputMode.default:
# Let langchain decide the default
pass
case _:
raise ValueError(f"Unhandled enum value: {provider.structured_output_mode}")
# triggers pyright warning if I miss a case
return NoReturn

return options


async def langchain_model_from(
name: str, provider_name: str | None = None
) -> BaseChatModel:
# TODO use self.model_provider() for caching
provider = await kiln_model_provider_from(name, provider_name)
return await langchain_model_from_provider(provider, name)
return options

async def langchain_model_from(self) -> BaseChatModel:
provider = await self.model_provider()
return await langchain_model_from_provider(provider, self.model_name)


async def langchain_model_from_provider(
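A note on the shape of this refactor: get_structured_output_options and langchain_model_from move from module scope onto LangchainAdapter, so both reuse the adapter's cached model_provider() lookup instead of re-resolving via kiln_model_provider_from. A minimal usage sketch — not part of the diff; the constructor values are borrowed from the tests below:

from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter

async def demo(task):
    # task is assumed to be an existing kiln_ai Task instance
    adapter = LangchainAdapter(
        kiln_task=task, model_name="llama_3_1_8b", provider="ollama"
    )
    # Both helpers are now instance methods sharing the cached model_provider():
    model = await adapter.langchain_model_from()
    options = await adapter.get_structured_output_options("llama_3_1_8b", "ollama")
    return model, options

The options dict feeds straight into with_structured_output(..., include_raw=True, **options), which is what the structured-output test below verifies.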
90 changes: 42 additions & 48 deletions libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py
@@ -16,16 +16,22 @@
)
from kiln_ai.adapters.model_adapters.langchain_adapters import (
LangchainAdapter,
get_structured_output_options,
langchain_model_from_provider,
)
from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
from kiln_ai.adapters.test_prompt_adaptors import build_test_task


def test_langchain_adapter_munge_response(tmp_path):
task = build_test_task(tmp_path)
lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")
@pytest.fixture
def mock_adapter(tmp_path):
return LangchainAdapter(
kiln_task=build_test_task(tmp_path),
model_name="llama_3_1_8b",
provider="ollama",
)


def test_langchain_adapter_munge_response(mock_adapter):
# Mistral Large tool calling format is a bit different
response = {
"name": "task_response",
@@ -34,12 +40,12 @@ def test_langchain_adapter_munge_response(tmp_path):
"punchline": "Because she wanted to be a moo-sician!",
},
}
munged = lca._munge_response(response)
munged = mock_adapter._munge_response(response)
assert munged["setup"] == "Why did the cow join a band?"
assert munged["punchline"] == "Because she wanted to be a moo-sician!"

# non mistral format should continue to work
munged = lca._munge_response(response["arguments"])
munged = mock_adapter._munge_response(response["arguments"])
assert munged["setup"] == "Why did the cow join a band?"
assert munged["punchline"] == "Because she wanted to be a moo-sician!"

@@ -93,10 +99,7 @@ async def test_langchain_adapter_with_cot(tmp_path):

# Patch both the langchain_model_from function and self.model()
with (
patch(
"kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from",
mock_model_from,
),
patch.object(LangchainAdapter, "langchain_model_from", mock_model_from),
patch.object(LangchainAdapter, "model", return_value=mock_model_instance),
):
response = await lca._run("test input")
@@ -145,18 +148,18 @@ async def test_langchain_adapter_with_cot(tmp_path):
(StructuredOutputMode.default, None),
],
)
async def test_get_structured_output_options(structured_output_mode, expected_method):
async def test_get_structured_output_options(
mock_adapter, structured_output_mode, expected_method
):
# Mock the provider response
mock_provider = MagicMock()
mock_provider.structured_output_mode = structured_output_mode

# Test with provider that has options
with patch(
"kiln_ai.adapters.model_adapters.langchain_adapters.kiln_model_provider_from",
AsyncMock(return_value=mock_provider),
):
options = await get_structured_output_options("model_name", "provider")
assert options.get("method") == expected_method
# Mock adapter.model_provider()
mock_adapter.model_provider = AsyncMock(return_value=mock_provider)

options = await mock_adapter.get_structured_output_options("model_name", "provider")
assert options.get("method") == expected_method


@pytest.mark.asyncio
@@ -292,31 +295,25 @@ async def test_langchain_adapter_model_structured_output(tmp_path):
adapter = LangchainAdapter(
kiln_task=task, model_name="test_model", provider="ollama"
)
adapter.get_structured_output_options = AsyncMock(
return_value={"option1": "value1"}
)
adapter.langchain_model_from = AsyncMock(return_value=mock_model)

with (
patch(
"kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from",
AsyncMock(return_value=mock_model),
),
patch(
"kiln_ai.adapters.model_adapters.langchain_adapters.get_structured_output_options",
AsyncMock(return_value={"option1": "value1"}),
),
):
model = await adapter.model()

# Verify the model was configured with structured output
mock_model.with_structured_output.assert_called_once_with(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
"title": "task_response",
"description": "A response from the task",
},
include_raw=True,
option1="value1",
)
assert model == "structured_model"
model = await adapter.model()

# Verify the model was configured with structured output
mock_model.with_structured_output.assert_called_once_with(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
"title": "task_response",
"description": "A response from the task",
},
include_raw=True,
option1="value1",
)
assert model == "structured_model"


@pytest.mark.asyncio
@@ -333,10 +330,7 @@ async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
adapter = LangchainAdapter(
kiln_task=task, model_name="test_model", provider="ollama"
)
adapter.langchain_model_from = AsyncMock(return_value=mock_model)

with patch(
"kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from",
AsyncMock(return_value=mock_model),
):
with pytest.raises(ValueError, match="does not support structured output"):
await adapter.model()
with pytest.raises(ValueError, match="does not support structured output"):
await adapter.model()
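One pattern in the test changes above is worth naming: helpers that used to be patched by module path are now mocked as methods on the class or the instance. A sketch of both idioms, assuming mock_model is any stand-in object:

from unittest.mock import AsyncMock, patch

from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter

mock_model = AsyncMock()  # stand-in; any object works here

# Old style (removed above): patch the module-level function by import path, e.g.
# patch("kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from", ...)

# New style: patch the method on the class...
with patch.object(LangchainAdapter, "langchain_model_from", AsyncMock(return_value=mock_model)):
    pass  # exercise code that constructs adapters here

# ...or assign directly on one instance, as the structured-output tests above do:
# adapter.langchain_model_from = AsyncMock(return_value=mock_model)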
4 changes: 2 additions & 2 deletions libs/core/kiln_ai/datamodel/__init__.py
@@ -283,8 +283,8 @@ class StructuredOutputMode(str, Enum):
- json_schema: request json using API capabilities for json_schema
- function_calling: request json using API capabilities for function calling
- json_mode: request json using API's JSON mode, which should return valid JSON, but isn't checking/passing the schema
- json_instructions: append instructions to the prompt to request json matching the schema. No API capabilities are used.
- json_instruction_and_object: append instructions to the prompt to request json matching the schema. Also request the response as json_mode via API capabilities.
- json_instructions: append instructions to the prompt to request json matching the schema. No API capabilities are used. You should have a custom parser on these models as they will be returning strings.
- json_instruction_and_object: append instructions to the prompt to request json matching the schema. Also request the response as json_mode via API capabilities (returning dictionaries).
"""

default = "default"
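The docstring distinction has a practical consequence for callers: json_instructions returns a raw string the caller must parse, while the API-backed modes return already-parsed data. A hedged sketch of that branch — json.loads stands in for whatever more forgiving parser the codebase actually uses:

import json

from kiln_ai.datamodel import StructuredOutputMode

def parse_output(mode: StructuredOutputMode, raw):
    if mode == StructuredOutputMode.json_instructions:
        # No API enforcement: the model returned a plain string, so parse it ourselves.
        return json.loads(raw)
    # json_mode / json_schema / function_calling / json_instruction_and_object:
    # the API layer already returned parsed data matching (or approximating) the schema.
    return raw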
