From 88f069aae32adb0712251cd42878cc6eeae8f3a2 Mon Sep 17 00:00:00 2001 From: scosman Date: Thu, 30 Jan 2025 17:17:35 -0500 Subject: [PATCH 01/18] Checkpoint --- .../core/kiln_ai/adapters/adapter_registry.py | 27 +++- libs/core/kiln_ai/adapters/base_adapter.py | 54 +++++-- .../kiln_ai/adapters/langchain_adapters.py | 144 ++++++++++++------ libs/core/kiln_ai/adapters/ml_model_list.py | 20 ++- .../kiln_ai/adapters/test_prompt_adaptors.py | 10 +- .../adapters/test_saving_adapter_results.py | 17 ++- .../adapters/test_structured_output.py | 2 +- 7 files changed, 195 insertions(+), 79 deletions(-) diff --git a/libs/core/kiln_ai/adapters/adapter_registry.py b/libs/core/kiln_ai/adapters/adapter_registry.py index 0e766eea..28cd0686 100644 --- a/libs/core/kiln_ai/adapters/adapter_registry.py +++ b/libs/core/kiln_ai/adapters/adapter_registry.py @@ -1,7 +1,14 @@ +from os import getenv + from kiln_ai import datamodel from kiln_ai.adapters.base_adapter import BaseAdapter from kiln_ai.adapters.langchain_adapters import LangchainAdapter +from kiln_ai.adapters.ml_model_list import ModelProviderName +from kiln_ai.adapters.model_adapters.open_ai_model_adapter import ( + OpenAICompatibleAdapter, +) from kiln_ai.adapters.prompt_builders import BasePromptBuilder +from kiln_ai.utils.config import Config def adapter_for_task( @@ -11,7 +18,25 @@ def adapter_for_task( prompt_builder: BasePromptBuilder | None = None, tags: list[str] | None = None, ) -> BaseAdapter: - # We use langchain for everything right now, but can add any others here + if provider == ModelProviderName.openrouter: + api_key = Config.shared().open_router_api_key + base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1" + print(f"base_url: {base_url} provider: {provider} model_name: {model_name}") + return OpenAICompatibleAdapter( + base_url=base_url, + api_key=api_key, + kiln_task=kiln_task, + model_name=model_name, + provider_name=provider, + prompt_builder=prompt_builder, + tags=tags, + default_headers={ + "HTTP-Referer": "https://getkiln.ai/openrouter", + "X-Title": "KilnAI", + }, + ) + + # We use langchain for all others right now, but moving off it as we touch anything. return LangchainAdapter( kiln_task, model_name=model_name, diff --git a/libs/core/kiln_ai/adapters/base_adapter.py b/libs/core/kiln_ai/adapters/base_adapter.py index 6282f699..d01c4f42 100644 --- a/libs/core/kiln_ai/adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/base_adapter.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from typing import Dict +from kiln_ai.adapters.provider_tools import kiln_model_provider_from +from kiln_ai.adapters.run_output import RunOutput from kiln_ai.datamodel import ( DataSource, DataSourceType, @@ -13,6 +15,8 @@ from kiln_ai.datamodel.json_schema import validate_schema from kiln_ai.utils.config import Config +from .ml_model_list import KilnModelProvider +from .parsers.parser_registry import model_parser_from_id from .prompt_builders import BasePromptBuilder, SimplePromptBuilder @@ -25,12 +29,6 @@ class AdapterInfo: prompt_id: str | None = None -@dataclass -class RunOutput: - output: Dict | str - intermediate_outputs: Dict[str, str] | None - - class BaseAdapter(metaclass=ABCMeta): """Base class for AI model adapters that handle task execution. 
@@ -48,6 +46,8 @@ class BaseAdapter(metaclass=ABCMeta): def __init__( self, kiln_task: Task, + model_name: str, + model_provider_name: str, prompt_builder: BasePromptBuilder | None = None, tags: list[str] | None = None, ): @@ -56,6 +56,26 @@ def __init__( self.output_schema = self.kiln_task.output_json_schema self.input_schema = self.kiln_task.input_json_schema self.default_tags = tags + self.model_name = model_name + self.model_provider_name = model_provider_name + self._model_provider: KilnModelProvider | None = None + + async def model_provider(self) -> KilnModelProvider: + """ + Lazy load the model provider for this adapter. + """ + if self._model_provider is not None: + return self._model_provider + if not self.model_name or not self.model_provider_name: + raise ValueError("model_name and model_provider_name must be provided") + self._model_provider = await kiln_model_provider_from( + self.model_name, self.model_provider_name + ) + if not self._model_provider: + raise ValueError( + f"model_provider_name {self.model_provider_name} not found for model {self.model_name}" + ) + return self._model_provider async def invoke_returning_raw( self, @@ -82,21 +102,31 @@ async def invoke( # Run run_output = await self._run(input) + # Parse + provider = await self.model_provider() + parser = model_parser_from_id(provider.parser)( + structured_output=self.has_structured_output() + ) + parsed_output = parser.parse_output(original_output=run_output) + print( + f"parsed_output: {parsed_output.output} \nIntermediate outputs: {parsed_output.intermediate_outputs}" + ) + # validate output if self.output_schema is not None: - if not isinstance(run_output.output, dict): + if not isinstance(parsed_output.output, dict): raise RuntimeError( - f"structured response is not a dict: {run_output.output}" + f"structured response is not a dict: {parsed_output.output}" ) - validate_schema(run_output.output, self.output_schema) + validate_schema(parsed_output.output, self.output_schema) else: - if not isinstance(run_output.output, str): + if not isinstance(parsed_output.output, str): raise RuntimeError( - f"response is not a string for non-structured task: {run_output.output}" + f"response is not a string for non-structured task: {parsed_output.output}" ) # Generate the run and output - run = self.generate_run(input, input_source, run_output) + run = self.generate_run(input, input_source, parsed_output) # Save the run if configured to do so, and we have a path to save to if Config.shared().autosave_runs and self.kiln_task.path is not None: diff --git a/libs/core/kiln_ai/adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/langchain_adapters.py index 9cf6be31..8c7db7d0 100644 --- a/libs/core/kiln_ai/adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/langchain_adapters.py @@ -23,7 +23,12 @@ from kiln_ai.utils.config import Config from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput -from .ml_model_list import KilnModelProvider, ModelProviderName, StructuredOutputMode +from .ml_model_list import ( + KilnModelProvider, + ModelParserID, + ModelProviderName, + StructuredOutputMode, +) from .provider_tools import kiln_model_provider_from LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel] @@ -41,37 +46,55 @@ def __init__( prompt_builder: BasePromptBuilder | None = None, tags: list[str] | None = None, ): - super().__init__(kiln_task, prompt_builder=prompt_builder, tags=tags) if custom_model is not None: self._model = custom_model # Attempt to infer 
model provider and name from custom model - self.model_provider = "custom.langchain:" + custom_model.__class__.__name__ - self.model_name = "custom.langchain:unknown_model" - if hasattr(custom_model, "model_name") and isinstance( - getattr(custom_model, "model_name"), str - ): - self.model_name = "custom.langchain:" + getattr( - custom_model, "model_name" - ) - if hasattr(custom_model, "model") and isinstance( - getattr(custom_model, "model"), str - ): - self.model_name = "custom.langchain:" + getattr(custom_model, "model") + if provider is None: + provider = "custom.langchain:" + custom_model.__class__.__name__ + + if model_name is None: + model_name = "custom.langchain:unknown_model" + if hasattr(custom_model, "model_name") and isinstance( + getattr(custom_model, "model_name"), str + ): + model_name = "custom.langchain:" + getattr( + custom_model, "model_name" + ) + if hasattr(custom_model, "model") and isinstance( + getattr(custom_model, "model"), str + ): + model_name = "custom.langchain:" + getattr(custom_model, "model") elif model_name is not None: - self.model_name = model_name - self.model_provider = provider or "custom.langchain.default_provider" + # default provider name if not provided + provider = provider or "custom.langchain.default_provider" else: raise ValueError( "model_name and provider must be provided if custom_model is not provided" ) + super().__init__( + kiln_task, + model_name=model_name, + model_provider_name=provider, + prompt_builder=prompt_builder, + tags=tags, + ) + async def model(self) -> LangChainModelType: # cached model if self._model: return self._model - self._model = await langchain_model_from(self.model_name, self.model_provider) + self._model = await langchain_model_from( + self.model_name, self.model_provider_name + ) + + # TODO better way: structured_output_mode = raw? 
+ # If we're going to parse the output, don't use LC's structured output + provider = await self.model_provider() + custom_output_parser = provider.parser == ModelParserID.r1_thinking + print(f"custom_output_parser: {custom_output_parser}") if self.has_structured_output(): if not hasattr(self._model, "with_structured_output") or not callable( @@ -89,7 +112,7 @@ async def model(self) -> LangChainModelType: output_schema["title"] = "task_response" output_schema["description"] = "A response from the task" with_structured_output_options = await get_structured_output_options( - self.model_name, self.model_provider + self.model_name, self.model_provider_name ) self._model = self._model.with_structured_output( output_schema, @@ -110,12 +133,13 @@ async def _run(self, input: Dict | str) -> RunOutput: HumanMessage(content=user_msg), ] + # TODO: make this thinking native # COT with structured output cot_prompt = self.prompt_builder.chain_of_thought_prompt() if cot_prompt and self.has_structured_output(): # Base model (without structured output) used for COT message base_model = await langchain_model_from( - self.model_name, self.model_provider + self.model_name, self.model_provider_name ) messages.append( SystemMessage(content=cot_prompt), @@ -131,35 +155,38 @@ async def _run(self, input: Dict | str) -> RunOutput: elif cot_prompt: messages.append(SystemMessage(content=cot_prompt)) - response = await chain.ainvoke(messages) + # TODO: make this thinking native + response = await chain.ainvoke(messages, include_reasoning=True) + print(f"response: {response}") - if self.has_structured_output(): - if ( - not isinstance(response, dict) - or "parsed" not in response - or not isinstance(response["parsed"], dict) - ): - raise RuntimeError(f"structured response not returned: {response}") + # Langchain may have already parsed the response into structured output, so use that if available. 
+ # However, a plain string may still be fixed at the parsing layer, so not being structured isn't a critical failure (yet) + if ( + self.has_structured_output() + and isinstance(response, dict) + and "parsed" in response + and isinstance(response["parsed"], dict) + ): structured_response = response["parsed"] return RunOutput( output=self._munge_response(structured_response), intermediate_outputs=intermediate_outputs, ) - else: - if not isinstance(response, BaseMessage): - raise RuntimeError(f"response is not a BaseMessage: {response}") - text_content = response.content - if not isinstance(text_content, str): - raise RuntimeError(f"response is not a string: {text_content}") - return RunOutput( - output=text_content, - intermediate_outputs=intermediate_outputs, - ) + + if not isinstance(response, BaseMessage): + raise RuntimeError(f"response is not a BaseMessage: {response}") + text_content = response.content + if not isinstance(text_content, str): + raise RuntimeError(f"response is not a string: {text_content}") + return RunOutput( + output=text_content, + intermediate_outputs=intermediate_outputs, + ) def adapter_info(self) -> AdapterInfo: return AdapterInfo( model_name=self.model_name, - model_provider=self.model_provider, + model_provider=self.model_provider_name, adapter_name="kiln_langchain_adapter", prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(), prompt_id=self.prompt_builder.prompt_id(), @@ -177,15 +204,16 @@ def _munge_response(self, response: Dict) -> Dict: async def get_structured_output_options( - model_name: str, model_provider: str + model_name: str, model_provider_name: str ) -> Dict[str, Any]: - finetune_provider = await kiln_model_provider_from(model_name, model_provider) - if not finetune_provider: + # TODO self cache + provider = await kiln_model_provider_from(model_name, model_provider_name) + if not provider: return {} options = {} # We may need to add some provider specific logic here if providers use different names for the same mode, but everyone is copying openai for now - match finetune_provider.structured_output_mode: + match provider.structured_output_mode: case StructuredOutputMode.function_calling: options["method"] = "function_calling" case StructuredOutputMode.json_mode: @@ -196,15 +224,35 @@ async def get_structured_output_options( # Let langchain decide the default pass case _: + raise ValueError(f"Unhandled enum value: {provider.structured_output_mode}") # triggers pyright warning if I miss a case - raise_exhaustive_error(finetune_provider.structured_output_mode) + return NoReturn return options +# class that wraps another class and logs all function calls being executed +class Wrapper: + def __init__(self, wrapped_class): + self.wrapped_class = wrapped_class + + def __getattr__(self, attr): + original_func = getattr(self.wrapped_class, attr) + + def wrapper(*args, **kwargs): + print(f"Calling function: {attr}") + print(f"Arguments: {args}, {kwargs}") + result = original_func(*args, **kwargs) + print(f"Response: {result}") + return result + + return wrapper + + async def langchain_model_from( name: str, provider_name: str | None = None ) -> BaseChatModel: + # TODO use self.model_provider() for caching provider = await kiln_model_provider_from(name, provider_name) return await langchain_model_from_provider(provider, name) @@ -257,9 +305,11 @@ async def langchain_model_from_provider( raise ValueError(f"Model {model_name} not installed on Ollama") elif provider.name == ModelProviderName.openrouter: + # TODO raise error + raise 
ValueError("OpenRouter is not supported in Langchain") api_key = Config.shared().open_router_api_key base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1" - return ChatOpenAI( + cor = ChatOpenAI( **provider.provider_options, openai_api_key=api_key, # type: ignore[arg-type] openai_api_base=base_url, # type: ignore[arg-type] @@ -268,9 +318,7 @@ async def langchain_model_from_provider( "X-Title": "KilnAI", }, ) + cor.client = Wrapper(cor.client) + return cor else: raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}") - - -def raise_exhaustive_error(value: NoReturn) -> NoReturn: - raise ValueError(f"Unhandled enum value: {value}") diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py index 5a85cbd3..16f91579 100644 --- a/libs/core/kiln_ai/adapters/ml_model_list.py +++ b/libs/core/kiln_ai/adapters/ml_model_list.py @@ -82,6 +82,14 @@ class ModelName(str, Enum): deepseek_r1 = "deepseek_r1" +class ModelParserID(str, Enum): + """ + Enumeration of supported model parsers. + """ + + r1_thinking = "r1_thinking" + + class KilnModelProvider(BaseModel): """ Configuration for a specific model provider. @@ -94,6 +102,7 @@ class KilnModelProvider(BaseModel): provider_finetune_id: The finetune ID for the provider, if applicable provider_options: Additional provider-specific configuration options structured_output_mode: The mode we should use to call the model for structured output, if it was trained with structured output. + parser: A parser to use for the model, if applicable """ name: ModelProviderName @@ -103,6 +112,7 @@ class KilnModelProvider(BaseModel): provider_finetune_id: str | None = None provider_options: Dict = {} structured_output_mode: StructuredOutputMode = StructuredOutputMode.default + parser: ModelParserID | None = None class KilnModel(BaseModel): @@ -207,21 +217,19 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": "deepseek/deepseek-r1"}, - structured_output_mode=StructuredOutputMode.json_schema, + parser=ModelParserID.r1_thinking, + structured_output_mode=StructuredOutputMode.json_mode, ), KilnModelProvider( name=ModelProviderName.fireworks_ai, provider_options={"model": "accounts/fireworks/models/deepseek-r1"}, - # Truncates the thinking, but json_mode works - structured_output_mode=StructuredOutputMode.json_mode, - supports_structured_output=False, - supports_data_gen=False, + parser=ModelParserID.r1_thinking, ), KilnModelProvider( # I want your RAM name=ModelProviderName.ollama, provider_options={"model": "deepseek-r1:671b"}, - structured_output_mode=StructuredOutputMode.json_schema, + parser=ModelParserID.r1_thinking, ), ], ), diff --git a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py index 2a64310e..6acef2fc 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py +++ b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py @@ -108,7 +108,11 @@ async def test_amazon_bedrock(tmp_path): async def test_mock(tmp_path): task = build_test_task(tmp_path) mockChatModel = FakeListChatModel(responses=["mock response"]) - adapter = LangchainAdapter(task, custom_model=mockChatModel) + adapter = LangchainAdapter( + task, + custom_model=mockChatModel, + provider="ollama", + ) run = await adapter.invoke("You are a mock, send me the response!") assert "mock response" in run.output.output @@ -116,7 +120,7 @@ async def test_mock(tmp_path): async def test_mock_returning_run(tmp_path): 
task = build_test_task(tmp_path) mockChatModel = FakeListChatModel(responses=["mock response"]) - adapter = LangchainAdapter(task, custom_model=mockChatModel) + adapter = LangchainAdapter(task, custom_model=mockChatModel, provider="ollama") run = await adapter.invoke("You are a mock, send me the response!") assert run.output.output == "mock response" assert run is not None @@ -127,7 +131,7 @@ async def test_mock_returning_run(tmp_path): assert run.output.source.properties == { "adapter_name": "kiln_langchain_adapter", "model_name": "custom.langchain:unknown_model", - "model_provider": "custom.langchain:FakeListChatModel", + "model_provider": "ollama", "prompt_builder_name": "simple_prompt_builder", } diff --git a/libs/core/kiln_ai/adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/test_saving_adapter_results.py index 54a32f50..62867bdd 100644 --- a/libs/core/kiln_ai/adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/test_saving_adapter_results.py @@ -39,8 +39,12 @@ def test_task(tmp_path): return task -def test_save_run_isolation(test_task): - adapter = MockAdapter(test_task) +@pytest.fixture +def adapter(test_task): + return MockAdapter(test_task, model_name="phi_3_5", model_provider_name="ollama") + + +def test_save_run_isolation(test_task, adapter): input_data = "Test input" output_data = "Test output" run_output = RunOutput( @@ -124,8 +128,7 @@ def test_save_run_isolation(test_task): assert output_data in set(run.output.output for run in test_task.runs()) -def test_generate_run_non_ascii(test_task): - adapter = MockAdapter(test_task) +def test_generate_run_non_ascii(test_task, adapter): input_data = {"key": "input with non-ascii character: 你好"} output_data = {"key": "output with non-ascii character: 你好"} run_output = RunOutput( @@ -154,13 +157,12 @@ def test_generate_run_non_ascii(test_task): @pytest.mark.asyncio -async def test_autosave_false(test_task): +async def test_autosave_false(test_task, adapter): with patch("kiln_ai.utils.config.Config.shared") as mock_shared: mock_config = mock_shared.return_value mock_config.autosave_runs = False mock_config.user_id = "test_user" - adapter = MockAdapter(test_task) input_data = "Test input" run = await adapter.invoke(input_data) @@ -173,13 +175,12 @@ async def test_autosave_false(test_task): @pytest.mark.asyncio -async def test_autosave_true(test_task): +async def test_autosave_true(test_task, adapter): with patch("kiln_ai.utils.config.Config.shared") as mock_shared: mock_config = mock_shared.return_value mock_config.autosave_runs = True mock_config.user_id = "test_user" - adapter = MockAdapter(test_task) input_data = "Test input" run = await adapter.invoke(input_data) diff --git a/libs/core/kiln_ai/adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/test_structured_output.py index 0ea56f54..9b44527d 100644 --- a/libs/core/kiln_ai/adapters/test_structured_output.py +++ b/libs/core/kiln_ai/adapters/test_structured_output.py @@ -44,7 +44,7 @@ async def test_structured_output_ollama_llama(tmp_path, model_name): class MockAdapter(BaseAdapter): def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None): - super().__init__(kiln_task) + super().__init__(kiln_task, model_name="phi_3_5", model_provider_name="ollama") self.response = response async def _run(self, input: str) -> RunOutput: From 13d7cc9da475cb0171eb671d2ed68e9e0a279d5c Mon Sep 17 00:00:00 2001 From: scosman Date: Thu, 30 Jan 2025 17:18:05 -0500 Subject: [PATCH 02/18] Checkpoint missing files --- 
.../model_adapters/open_ai_model_adapter.py | 178 ++++++++++++++++++ .../kiln_ai/adapters/parsers/base_parser.py | 47 +++++ .../adapters/parsers/parser_registry.py | 22 +++ .../kiln_ai/adapters/parsers/r1_parser.py | 74 ++++++++ .../adapters/parsers/test_base_parser.py | 79 ++++++++ .../adapters/parsers/test_r1_parser.py | 139 ++++++++++++++ libs/core/kiln_ai/adapters/run_output.py | 8 + 7 files changed, 547 insertions(+) create mode 100644 libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py create mode 100644 libs/core/kiln_ai/adapters/parsers/base_parser.py create mode 100644 libs/core/kiln_ai/adapters/parsers/parser_registry.py create mode 100644 libs/core/kiln_ai/adapters/parsers/r1_parser.py create mode 100644 libs/core/kiln_ai/adapters/parsers/test_base_parser.py create mode 100644 libs/core/kiln_ai/adapters/parsers/test_r1_parser.py create mode 100644 libs/core/kiln_ai/adapters/run_output.py diff --git a/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py new file mode 100644 index 00000000..bade1813 --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py @@ -0,0 +1,178 @@ +import json +from typing import Any, Dict, NoReturn + +import kiln_ai.datamodel as datamodel +from kiln_ai.adapters.base_adapter import ( + AdapterInfo, + BaseAdapter, + BasePromptBuilder, + RunOutput, +) +from kiln_ai.adapters.ml_model_list import StructuredOutputMode +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice + + +class OpenAICompatibleAdapter(BaseAdapter): + def __init__( + self, + api_key: str, + kiln_task: datamodel.Task, + model_name: str, + provider_name: str, + base_url: str | None = None, # Client will default to OpenAI + default_headers: dict[str, str] | None = None, + prompt_builder: BasePromptBuilder | None = None, + tags: list[str] | None = None, + ): + if not model_name or not provider_name: + raise ValueError( + "model_name and provider_name must be provided for OpenAI compatible adapter" + ) + + # Create an async OpenAI client instead + self.client = AsyncOpenAI( + api_key=api_key, base_url=base_url, default_headers=default_headers + ) + + super().__init__( + kiln_task, + model_name=model_name, + model_provider_name=provider_name, + prompt_builder=prompt_builder, + tags=tags, + ) + + async def _run(self, input: Dict | str) -> RunOutput: + provider = await self.model_provider() + + intermediate_outputs = {} + + prompt = self.build_prompt() + user_msg = self.prompt_builder.build_user_message(input) + messages = [ + {"role": "system", "content": prompt}, + {"role": "user", "content": user_msg}, + ] + + # Handle chain of thought if enabled + cot_prompt = self.prompt_builder.chain_of_thought_prompt() + if cot_prompt and self.has_structured_output(): + # TODO P0: Fix COT + messages.append({"role": "system", "content": cot_prompt}) + + # First call for chain of thought + cot_response = await self.client.chat.completions.create( + model=self.model_name, + messages=messages, + ) + cot_content = cot_response.choices[0].message.content + intermediate_outputs = {"chain_of_thought": cot_content} + + messages.extend( + [ + {"role": "assistant", "content": cot_content}, + { + "role": "system", + "content": "Considering the above, return a final result.", + }, + ] + ) + elif cot_prompt: + messages.append({"role": "system", "content": cot_prompt}) + else: + 
intermediate_outputs = {} + + # Main completion call + response_format_options = await self.response_format_options() + print(f"response_format_options: {response_format_options}") + response = await self.client.chat.completions.create( + model=provider.provider_options["model"], + messages=messages, + # TODO P0: remove this + extra_body={"include_reasoning": True}, + **response_format_options, + ) + print(f"response: {response}") + if not isinstance(response, ChatCompletion): + raise RuntimeError( + f"Expected ChatCompletion response, got {type(response)}." + ) + + if response.error: + raise RuntimeError( + f"OpenAI compatible API returned status code {response.error.get('code')}: {response.error.get('message') or 'Unknown error'}." + ) + if not response.choices or len(response.choices) == 0: + raise RuntimeError( + "No message content returned in the response from OpenAI compatible API" + ) + + response_content = response.choices[0].message.content + if not isinstance(response_content, str): + raise RuntimeError(f"response is not a string: {response_content}") + + # reasoning = response.choices[0].message.get("reasoning") + # print(f"reasoning: {reasoning}") + + if self.has_structured_output(): + try: + structured_response = json.loads(response_content) + return RunOutput( + output=structured_response, + intermediate_outputs=intermediate_outputs, + ) + except json.JSONDecodeError as e: + raise RuntimeError( + f"Failed to parse JSON response: {response_content}" + ) from e + + return RunOutput( + output=response_content, + intermediate_outputs=intermediate_outputs, + ) + + def adapter_info(self) -> AdapterInfo: + return AdapterInfo( + model_name=self.model_name, + model_provider=self.model_provider_name, + adapter_name="kiln_openai_compatible_adapter", + prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(), + prompt_id=self.prompt_builder.prompt_id(), + ) + + async def response_format_options(self) -> dict[str, Any]: + # Unstructured if task isn't structured + if not self.has_structured_output(): + return {} + + provider = await self.model_provider() + # TODO check these + match provider.structured_output_mode: + case StructuredOutputMode.json_mode: + return {"response_format": {"type": "json_object"}} + case StructuredOutputMode.json_schema: + # TODO P0: use json_schema + output_schema = self.kiln_task.output_schema() + print(f"output_schema: {output_schema}") + return { + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "task_response", + "schema": output_schema, + }, + } + } + case StructuredOutputMode.function_calling: + # TODO P0 + return {"response_format": {"type": "function_calling"}} + case StructuredOutputMode.default: + return {} + case _: + raise ValueError( + f"Unsupported structured output mode: {provider.structured_output_mode}" + ) + # pyright will detect missing cases with this + return NoReturn diff --git a/libs/core/kiln_ai/adapters/parsers/base_parser.py b/libs/core/kiln_ai/adapters/parsers/base_parser.py new file mode 100644 index 00000000..69b738d4 --- /dev/null +++ b/libs/core/kiln_ai/adapters/parsers/base_parser.py @@ -0,0 +1,47 @@ +import json +from typing import Any, Dict + +from kiln_ai.adapters.run_output import RunOutput + + +class BaseParser: + def __init__(self, structured_output: bool = False): + self.structured_output = structured_output + + def parse_output(self, original_output: RunOutput) -> RunOutput: + """ + Method for parsing the output of a model. Typically overridden by subclasses. 
+ """ + return original_output + + def parse_json_string(self, json_string: str) -> Dict[str, Any]: + """ + Parse a JSON string into a dictionary. Handles multiple formats: + - Plain JSON + - JSON wrapped in ```json code blocks + - JSON wrapped in ``` code blocks + + Args: + json_string: String containing JSON data, possibly wrapped in code blocks + + Returns: + Dict containing parsed JSON data + + Raises: + ValueError: If JSON parsing fails + """ + # Remove code block markers if present + cleaned_string = json_string.strip() + if cleaned_string.startswith("```"): + # Split by newlines and remove first/last lines if they contain ``` + lines = cleaned_string.split("\n") + if lines[0].startswith("```"): + lines = lines[1:] + if lines and lines[-1].strip() == "```": + lines = lines[:-1] + cleaned_string = "\n".join(lines) + + try: + return json.loads(cleaned_string) + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse JSON: {str(e)}") from e diff --git a/libs/core/kiln_ai/adapters/parsers/parser_registry.py b/libs/core/kiln_ai/adapters/parsers/parser_registry.py new file mode 100644 index 00000000..b90fdfe0 --- /dev/null +++ b/libs/core/kiln_ai/adapters/parsers/parser_registry.py @@ -0,0 +1,22 @@ +from typing import NoReturn, Type + +from kiln_ai.adapters.ml_model_list import ModelParserID +from kiln_ai.adapters.parsers.base_parser import BaseParser +from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser + + +def model_parser_from_id(parser_id: ModelParserID | None) -> Type[BaseParser]: + """ + Get a model parser from its ID. + """ + match parser_id: + case None: + return BaseParser + case ModelParserID.r1_thinking: + return R1ThinkingParser + case _: + # triggers pyright warning if I miss a case + raise ValueError( + f"Unhandled enum value for parser ID. You may need to update Kiln to work with this project. Value: {parser_id}" + ) + return NoReturn diff --git a/libs/core/kiln_ai/adapters/parsers/r1_parser.py b/libs/core/kiln_ai/adapters/parsers/r1_parser.py new file mode 100644 index 00000000..55f00ba2 --- /dev/null +++ b/libs/core/kiln_ai/adapters/parsers/r1_parser.py @@ -0,0 +1,74 @@ +from kiln_ai.adapters.parsers.base_parser import BaseParser +from kiln_ai.adapters.run_output import RunOutput + + +class R1ThinkingParser(BaseParser): + START_TAG = "" + END_TAG = "" + + def parse_output(self, original_output: RunOutput) -> RunOutput: + """ + Parse the tags from the response into the intermediate and final outputs. 
+
+        Args:
+            original_output: RunOutput containing the raw response string
+
+        Returns:
+            ParsedOutput containing the intermediate content (thinking content) and final result
+
+        Raises:
+            ValueError: If response format is invalid (missing tags, multiple tags, or no content after closing tag)
+        """
+        # TODO
+        print(f"original_output: {original_output.output}")
+
+        # This parser only works for strings
+        if not isinstance(original_output.output, str):
+            raise ValueError("Response must be a string for R1 parser")
+        if (
+            original_output.intermediate_outputs
+            and len(original_output.intermediate_outputs) > 0
+        ):
+            raise ValueError("Intermediate outputs must be empty for R1 parser")
+
+        # Strip whitespace and validate basic structure
+        cleaned_response = original_output.output.strip()
+        if not cleaned_response.startswith(self.START_TAG):
+            raise ValueError("Response must start with <think> tag")
+
+        # Find the thinking tags
+        think_start = cleaned_response.find(self.START_TAG)
+        think_end = cleaned_response.find(self.END_TAG)
+
+        if think_start == -1 or think_end == -1:
+            raise ValueError("Missing thinking tags")
+
+        # Check for multiple tags
+        if (
+            cleaned_response.count(self.START_TAG) > 1
+            or cleaned_response.count(self.END_TAG) > 1
+        ):
+            raise ValueError("Multiple thinking tags found")
+
+        # Extract thinking content
+        thinking_content = cleaned_response[
+            think_start + len(self.START_TAG) : think_end
+        ].strip()
+        if len(thinking_content) == 0:
+            thinking_content = None
+
+        # Extract result (everything after </think>)
+        result = cleaned_response[think_end + len(self.END_TAG) :].strip()
+
+        if not result or len(result) == 0:
+            raise ValueError("No content found after </think> tag")
+
+        # Parse JSON if needed
+        output = result
+        if self.structured_output:
+            output = self.parse_json_string(result)
+
+        return RunOutput(
+            output=output,
+            intermediate_outputs={"reasoning": thinking_content},
+        )
diff --git a/libs/core/kiln_ai/adapters/parsers/test_base_parser.py b/libs/core/kiln_ai/adapters/parsers/test_base_parser.py
new file mode 100644
index 00000000..49726868
--- /dev/null
+++ b/libs/core/kiln_ai/adapters/parsers/test_base_parser.py
@@ -0,0 +1,79 @@
+import pytest
+from kiln_ai.adapters.parsers.base_parser import BaseParser
+
+
+@pytest.fixture
+def parser():
+    return BaseParser()
+
+
+def test_parse_plain_json(parser):
+    json_str = '{"key": "value", "number": 42}'
+    result = parser.parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_code_block(parser):
+    json_str = """```
+    {"key": "value", "number": 42}
+    ```"""
+    result = parser.parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_language_block(parser):
+    json_str = """```json
+    {"key": "value", "number": 42}
+    ```"""
+    result = parser.parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_whitespace(parser):
+    json_str = """
+    {
+        "key": "value",
+        "number": 42
+    }
+    """
+    result = parser.parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_invalid_json(parser):
+    json_str = '{"key": "value", invalid}'
+    with pytest.raises(ValueError) as exc_info:
+        parser.parse_json_string(json_str)
+    assert "Failed to parse JSON" in str(exc_info.value)
+
+
+def test_parse_empty_code_block(parser):
+    json_str = """```json
+    ```"""
+    with pytest.raises(ValueError) as exc_info:
+        parser.parse_json_string(json_str)
+    assert "Failed to parse JSON" in
str(exc_info.value)
+
+
+def test_parse_complex_json(parser):
+    json_str = """```json
+    {
+        "string": "hello",
+        "number": 42,
+        "bool": true,
+        "null": null,
+        "array": [1, 2, 3],
+        "nested": {
+            "inner": "value"
+        }
+    }
+    ```"""
+    result = parser.parse_json_string(json_str)
+    assert result == {
+        "string": "hello",
+        "number": 42,
+        "bool": True,
+        "null": None,
+        "array": [1, 2, 3],
+        "nested": {"inner": "value"},
+    }
diff --git a/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py b/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py
new file mode 100644
index 00000000..821d186e
--- /dev/null
+++ b/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py
@@ -0,0 +1,139 @@
+import pytest
+from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
+from kiln_ai.adapters.run_output import RunOutput
+
+
+@pytest.fixture
+def parser():
+    return R1ThinkingParser()
+
+
+def test_valid_response(parser):
+    response = RunOutput(
+        output="<think>This is thinking content</think>This is the result",
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs["reasoning"] == "This is thinking content"
+    assert parsed.output == "This is the result"
+
+
+def test_response_with_whitespace(parser):
+    response = RunOutput(
+        output="""
+        <think>
+        This is thinking content
+        </think>
+        This is the result
+    """,
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert (
+        parsed.intermediate_outputs["reasoning"].strip() == "This is thinking content"
+    )
+    assert parsed.output.strip() == "This is the result"
+
+
+def test_missing_start_tag(parser):
+    with pytest.raises(ValueError, match="Response must start with <think> tag"):
+        parser.parse_output(
+            RunOutput(output="Some content</think>result", intermediate_outputs=None)
+        )
+
+
+def test_missing_end_tag(parser):
+    with pytest.raises(ValueError, match="Missing thinking tags"):
+        parser.parse_output(
+            RunOutput(output="<think>Some content", intermediate_outputs=None)
+        )
+
+
+def test_multiple_start_tags(parser):
+    with pytest.raises(ValueError, match="Multiple thinking tags found"):
+        parser.parse_output(
+            RunOutput(
+                output="<think>content1<think>content2</think>result",
+                intermediate_outputs=None,
+            )
+        )
+
+
+def test_multiple_end_tags(parser):
+    with pytest.raises(ValueError, match="Multiple thinking tags found"):
+        parser.parse_output(
+            RunOutput(
+                output="<think>content</think></think>result", intermediate_outputs=None
+            )
+        )
+
+
+def test_empty_thinking_content(parser):
+    response = RunOutput(
+        output="<think></think>This is the result", intermediate_outputs=None
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs["reasoning"] is None
+    assert parsed.output == "This is the result"
+
+
+def test_missing_result(parser):
+    with pytest.raises(ValueError, match="No content found after </think> tag"):
+        parser.parse_output(
+            RunOutput(output="<think>Some content</think>", intermediate_outputs=None)
+        )
+
+
+def test_multiline_content(parser):
+    response = RunOutput(
+        output="""<think>Line 1
+    Line 2
+    Line 3</think>Final result""",
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert "Line 1" in parsed.intermediate_outputs["reasoning"]
+    assert "Line 2" in parsed.intermediate_outputs["reasoning"]
+    assert "Line 3" in parsed.intermediate_outputs["reasoning"]
+    assert parsed.output == "Final result"
+
+
+def test_special_characters(parser):
+    response = RunOutput(
+        output="<think>Content with: !@#$%^&*思()</think>Result with: !@#$%^&*思()",
+        intermediate_outputs=None,
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs["reasoning"] == "Content
with: !@#$%^&*思()"
+    assert parsed.output == "Result with: !@#$%^&*思()"
+
+
+def test_non_string_input(parser):
+    with pytest.raises(ValueError, match="Response must be a string for R1 parser"):
+        parser.parse_output(RunOutput(output={}, intermediate_outputs=None))
+
+
+def test_non_none_intermediate_outputs(parser):
+    with pytest.raises(
+        ValueError, match="Intermediate outputs must be empty for R1 parser"
+    ):
+        parser.parse_output(
+            RunOutput(
+                output="<think>Some content</think>result",
+                intermediate_outputs={"some": "data"},
+            )
+        )
+
+    # empty dict and None are allowed
+    parser.parse_output(
+        RunOutput(
+            output="<think>Some content</think>result",
+            intermediate_outputs={},
+        )
+    )
+    parser.parse_output(
+        RunOutput(
+            output="<think>Some content</think>result",
+            intermediate_outputs=None,
+        )
+    )
diff --git a/libs/core/kiln_ai/adapters/run_output.py b/libs/core/kiln_ai/adapters/run_output.py
new file mode 100644
index 00000000..7c34cae6
--- /dev/null
+++ b/libs/core/kiln_ai/adapters/run_output.py
@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+from typing import Dict
+
+
+@dataclass
+class RunOutput:
+    output: Dict | str
+    intermediate_outputs: Dict[str, str] | None

From 60f02034bf311e75b348780e7d04b6821d854885 Mon Sep 17 00:00:00 2001
From: scosman
Date: Thu, 30 Jan 2025 18:05:24 -0500
Subject: [PATCH 03/18] Checkpoint

---
 libs/core/kiln_ai/adapters/base_adapter.py    |  1 +
 libs/core/kiln_ai/adapters/ml_model_list.py   |  4 +-
 .../model_adapters/open_ai_model_adapter.py   | 41 ++++++++++++-------
 .../kiln_ai/adapters/parsers/base_parser.py   | 32 ---------------
 .../kiln_ai/adapters/parsers/json_parser.py   | 35 ++++++++++++++++
 .../kiln_ai/adapters/parsers/r1_parser.py     |  3 +-
 .../adapters/parsers/test_base_parser.py      | 35 +++++++---------
 libs/core/kiln_ai/datamodel/__init__.py       |  1 +
 8 files changed, 82 insertions(+), 70 deletions(-)
 create mode 100644 libs/core/kiln_ai/adapters/parsers/json_parser.py

diff --git a/libs/core/kiln_ai/adapters/base_adapter.py b/libs/core/kiln_ai/adapters/base_adapter.py
index d01c4f42..38e3e0a5 100644
--- a/libs/core/kiln_ai/adapters/base_adapter.py
+++ b/libs/core/kiln_ai/adapters/base_adapter.py
@@ -103,6 +103,7 @@ async def invoke(
         run_output = await self._run(input)
 
         # Parse
+        # TODO P0
         provider = await self.model_provider()
         parser = model_parser_from_id(provider.parser)(
             structured_output=self.has_structured_output()
diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py
index 16f91579..64779085 100644
--- a/libs/core/kiln_ai/adapters/ml_model_list.py
+++ b/libs/core/kiln_ai/adapters/ml_model_list.py
@@ -217,8 +217,8 @@ class KilnModel(BaseModel):
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "deepseek/deepseek-r1"},
-                parser=ModelParserID.r1_thinking,
-                structured_output_mode=StructuredOutputMode.json_mode,
+                # parser=ModelParserID.r1_thinking,
+                structured_output_mode=StructuredOutputMode.json_instructions,
             ),
             KilnModelProvider(
                 name=ModelProviderName.fireworks_ai,
diff --git a/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py
index bade1813..2ab2ce10 100644
--- a/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py
+++ b/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py
@@ -8,7 +8,8 @@
     BasePromptBuilder,
     RunOutput,
 )
-from kiln_ai.adapters.ml_model_list import StructuredOutputMode
+from kiln_ai.adapters.ml_model_list import ModelParserID, StructuredOutputMode
+from
kiln_ai.adapters.parsers.json_parser import parse_json_string from openai import AsyncOpenAI from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice @@ -50,6 +51,15 @@ async def _run(self, input: Dict | str) -> RunOutput: intermediate_outputs = {} prompt = self.build_prompt() + + # TODO P0: move this to prompt builder + if provider.structured_output_mode == StructuredOutputMode.json_instructions: + prompt = ( + prompt + + f"\n\n### Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.kiln_task.output_schema()}\n```" + ) + print(f"prompt: {prompt}") + user_msg = self.prompt_builder.build_user_message(input) messages = [ {"role": "system", "content": prompt}, @@ -100,7 +110,7 @@ async def _run(self, input: Dict | str) -> RunOutput: f"Expected ChatCompletion response, got {type(response)}." ) - if response.error: + if hasattr(response, "error") and response.error: raise RuntimeError( f"OpenAI compatible API returned status code {response.error.get('code')}: {response.error.get('message') or 'Unknown error'}." ) @@ -109,24 +119,21 @@ async def _run(self, input: Dict | str) -> RunOutput: "No message content returned in the response from OpenAI compatible API" ) - response_content = response.choices[0].message.content + message = response.choices[0].message + response_content = message.content if not isinstance(response_content, str): raise RuntimeError(f"response is not a string: {response_content}") - # reasoning = response.choices[0].message.get("reasoning") - # print(f"reasoning: {reasoning}") + # Save reasoning if it exists + if hasattr(message, "reasoning") and message.reasoning: + intermediate_outputs["reasoning"] = message.reasoning if self.has_structured_output(): - try: - structured_response = json.loads(response_content) - return RunOutput( - output=structured_response, - intermediate_outputs=intermediate_outputs, - ) - except json.JSONDecodeError as e: - raise RuntimeError( - f"Failed to parse JSON response: {response_content}" - ) from e + structured_response = parse_json_string(response_content) + return RunOutput( + output=structured_response, + intermediate_outputs=intermediate_outputs, + ) return RunOutput( output=response_content, @@ -168,6 +175,10 @@ async def response_format_options(self) -> dict[str, Any]: case StructuredOutputMode.function_calling: # TODO P0 return {"response_format": {"type": "function_calling"}} + case StructuredOutputMode.json_instructions: + # JSON done via instructions in prompt, not the API response format + # TODO try json_object on stable API + return {} case StructuredOutputMode.default: return {} case _: diff --git a/libs/core/kiln_ai/adapters/parsers/base_parser.py b/libs/core/kiln_ai/adapters/parsers/base_parser.py index 69b738d4..153bc539 100644 --- a/libs/core/kiln_ai/adapters/parsers/base_parser.py +++ b/libs/core/kiln_ai/adapters/parsers/base_parser.py @@ -13,35 +13,3 @@ def parse_output(self, original_output: RunOutput) -> RunOutput: Method for parsing the output of a model. Typically overridden by subclasses. """ return original_output - - def parse_json_string(self, json_string: str) -> Dict[str, Any]: - """ - Parse a JSON string into a dictionary. 
Handles multiple formats: - - Plain JSON - - JSON wrapped in ```json code blocks - - JSON wrapped in ``` code blocks - - Args: - json_string: String containing JSON data, possibly wrapped in code blocks - - Returns: - Dict containing parsed JSON data - - Raises: - ValueError: If JSON parsing fails - """ - # Remove code block markers if present - cleaned_string = json_string.strip() - if cleaned_string.startswith("```"): - # Split by newlines and remove first/last lines if they contain ``` - lines = cleaned_string.split("\n") - if lines[0].startswith("```"): - lines = lines[1:] - if lines and lines[-1].strip() == "```": - lines = lines[:-1] - cleaned_string = "\n".join(lines) - - try: - return json.loads(cleaned_string) - except json.JSONDecodeError as e: - raise ValueError(f"Failed to parse JSON: {str(e)}") from e diff --git a/libs/core/kiln_ai/adapters/parsers/json_parser.py b/libs/core/kiln_ai/adapters/parsers/json_parser.py new file mode 100644 index 00000000..0af17dac --- /dev/null +++ b/libs/core/kiln_ai/adapters/parsers/json_parser.py @@ -0,0 +1,35 @@ +import json +from typing import Any, Dict + + +def parse_json_string(json_string: str) -> Dict[str, Any]: + """ + Parse a JSON string into a dictionary. Handles multiple formats: + - Plain JSON + - JSON wrapped in ```json code blocks + - JSON wrapped in ``` code blocks + + Args: + json_string: String containing JSON data, possibly wrapped in code blocks + + Returns: + Dict containing parsed JSON data + + Raises: + ValueError: If JSON parsing fails + """ + # Remove code block markers if present + cleaned_string = json_string.strip() + if cleaned_string.startswith("```"): + # Split by newlines and remove first/last lines if they contain ``` + lines = cleaned_string.split("\n") + if lines[0].startswith("```"): + lines = lines[1:] + if lines and lines[-1].strip() == "```": + lines = lines[:-1] + cleaned_string = "\n".join(lines) + + try: + return json.loads(cleaned_string) + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse JSON: {str(e)}") from e diff --git a/libs/core/kiln_ai/adapters/parsers/r1_parser.py b/libs/core/kiln_ai/adapters/parsers/r1_parser.py index 55f00ba2..7629f148 100644 --- a/libs/core/kiln_ai/adapters/parsers/r1_parser.py +++ b/libs/core/kiln_ai/adapters/parsers/r1_parser.py @@ -1,4 +1,5 @@ from kiln_ai.adapters.parsers.base_parser import BaseParser +from kiln_ai.adapters.parsers.json_parser import parse_json_string from kiln_ai.adapters.run_output import RunOutput @@ -66,7 +67,7 @@ def parse_output(self, original_output: RunOutput) -> RunOutput: # Parse JSON if needed output = result if self.structured_output: - output = self.parse_json_string(result) + output = parse_json_string(result) return RunOutput( output=output, diff --git a/libs/core/kiln_ai/adapters/parsers/test_base_parser.py b/libs/core/kiln_ai/adapters/parsers/test_base_parser.py index 49726868..46f61e82 100644 --- a/libs/core/kiln_ai/adapters/parsers/test_base_parser.py +++ b/libs/core/kiln_ai/adapters/parsers/test_base_parser.py @@ -1,61 +1,56 @@ import pytest -from kiln_ai.adapters.parsers.base_parser import BaseParser +from kiln_ai.adapters.parsers.json_parser import parse_json_string -@pytest.fixture -def parser(): - return BaseParser() - - -def test_parse_plain_json(parser): +def test_parse_plain_json(): json_str = '{"key": "value", "number": 42}' - result = parser.parse_json_string(json_str) + result = parse_json_string(json_str) assert result == {"key": "value", "number": 42} -def test_parse_json_with_code_block(parser): +def 
test_parse_json_with_code_block(): json_str = """``` {"key": "value", "number": 42} ```""" - result = parser.parse_json_string(json_str) + result = parse_json_string(json_str) assert result == {"key": "value", "number": 42} -def test_parse_json_with_language_block(parser): +def test_parse_json_with_language_block(): json_str = """```json {"key": "value", "number": 42} ```""" - result = parser.parse_json_string(json_str) + result = parse_json_string(json_str) assert result == {"key": "value", "number": 42} -def test_parse_json_with_whitespace(parser): +def test_parse_json_with_whitespace(): json_str = """ { "key": "value", "number": 42 } """ - result = parser.parse_json_string(json_str) + result = parse_json_string(json_str) assert result == {"key": "value", "number": 42} -def test_parse_invalid_json(parser): +def test_parse_invalid_json(): json_str = '{"key": "value", invalid}' with pytest.raises(ValueError) as exc_info: - parser.parse_json_string(json_str) + parse_json_string(json_str) assert "Failed to parse JSON" in str(exc_info.value) -def test_parse_empty_code_block(parser): +def test_parse_empty_code_block(): json_str = """```json ```""" with pytest.raises(ValueError) as exc_info: - parser.parse_json_string(json_str) + parse_json_string(json_str) assert "Failed to parse JSON" in str(exc_info.value) -def test_parse_complex_json(parser): +def test_parse_complex_json(): json_str = """```json { "string": "hello", @@ -68,7 +63,7 @@ def test_parse_complex_json(parser): } } ```""" - result = parser.parse_json_string(json_str) + result = parse_json_string(json_str) assert result == { "string": "hello", "number": 42, diff --git a/libs/core/kiln_ai/datamodel/__init__.py b/libs/core/kiln_ai/datamodel/__init__.py index 5751448f..2306b59c 100644 --- a/libs/core/kiln_ai/datamodel/__init__.py +++ b/libs/core/kiln_ai/datamodel/__init__.py @@ -284,6 +284,7 @@ class StructuredOutputMode(str, Enum): json_schema = "json_schema" function_calling = "function_calling" json_mode = "json_mode" + json_instructions = "json_instructions" class Finetune(KilnParentedModel): From 15e4a59b38c0b63807a2a4a704c404c94e889631 Mon Sep 17 00:00:00 2001 From: scosman Date: Thu, 30 Jan 2025 18:21:41 -0500 Subject: [PATCH 04/18] Checkpoint: everything working?? --- .../kiln_ai/adapters/langchain_adapters.py | 25 +++++++++++++++---- libs/core/kiln_ai/adapters/ml_model_list.py | 2 ++ .../model_adapters/open_ai_model_adapter.py | 3 +-- .../kiln_ai/adapters/parsers/base_parser.py | 3 --- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/libs/core/kiln_ai/adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/langchain_adapters.py index 8c7db7d0..15a99a3f 100644 --- a/libs/core/kiln_ai/adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/langchain_adapters.py @@ -91,12 +91,15 @@ async def model(self) -> LangChainModelType: ) # TODO better way: structured_output_mode = raw? 
- # If we're going to parse the output, don't use LC's structured output + # Don't setup structured output unless task is structured, and we haven't said we want instruction based JSON provider = await self.model_provider() - custom_output_parser = provider.parser == ModelParserID.r1_thinking - print(f"custom_output_parser: {custom_output_parser}") + use_lc_structured_output = ( + self.has_structured_output() + and provider.structured_output_mode + != StructuredOutputMode.json_instructions + ) - if self.has_structured_output(): + if use_lc_structured_output: if not hasattr(self._model, "with_structured_output") or not callable( getattr(self._model, "with_structured_output") ): @@ -127,6 +130,15 @@ async def _run(self, input: Dict | str) -> RunOutput: intermediate_outputs = {} prompt = self.build_prompt() + # TODO P0: move this to prompt builder + provider = await self.model_provider() + if provider.structured_output_mode == StructuredOutputMode.json_instructions: + prompt = ( + prompt + + f"\n\n### Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.kiln_task.output_schema()}\n```" + ) + print(f"prompt: {prompt}") + user_msg = self.prompt_builder.build_user_message(input) messages = [ SystemMessage(content=prompt), @@ -156,7 +168,7 @@ async def _run(self, input: Dict | str) -> RunOutput: messages.append(SystemMessage(content=cot_prompt)) # TODO: make this thinking native - response = await chain.ainvoke(messages, include_reasoning=True) + response = await chain.ainvoke(messages) print(f"response: {response}") # Langchain may have already parsed the response into structured output, so use that if available. @@ -220,6 +232,9 @@ async def get_structured_output_options( options["method"] = "json_mode" case StructuredOutputMode.json_schema: options["method"] = "json_schema" + case StructuredOutputMode.json_instructions: + # JSON done via instructions in prompt, not via API + pass case StructuredOutputMode.default: # Let langchain decide the default pass diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py index 64779085..0aae8036 100644 --- a/libs/core/kiln_ai/adapters/ml_model_list.py +++ b/libs/core/kiln_ai/adapters/ml_model_list.py @@ -224,12 +224,14 @@ class KilnModel(BaseModel): name=ModelProviderName.fireworks_ai, provider_options={"model": "accounts/fireworks/models/deepseek-r1"}, parser=ModelParserID.r1_thinking, + structured_output_mode=StructuredOutputMode.json_instructions, ), KilnModelProvider( # I want your RAM name=ModelProviderName.ollama, provider_options={"model": "deepseek-r1:671b"}, parser=ModelParserID.r1_thinking, + structured_output_mode=StructuredOutputMode.json_instructions, ), ], ), diff --git a/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py index 2ab2ce10..76a344c9 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py @@ -1,4 +1,3 @@ -import json from typing import Any, Dict, NoReturn import kiln_ai.datamodel as datamodel @@ -8,7 +7,7 @@ BasePromptBuilder, RunOutput, ) -from kiln_ai.adapters.ml_model_list import ModelParserID, StructuredOutputMode +from kiln_ai.adapters.ml_model_list import StructuredOutputMode from kiln_ai.adapters.parsers.json_parser import parse_json_string from openai import AsyncOpenAI from openai.types.chat import ChatCompletion, ChatCompletionMessage diff --git 
a/libs/core/kiln_ai/adapters/parsers/base_parser.py b/libs/core/kiln_ai/adapters/parsers/base_parser.py index 153bc539..98c9c05d 100644 --- a/libs/core/kiln_ai/adapters/parsers/base_parser.py +++ b/libs/core/kiln_ai/adapters/parsers/base_parser.py @@ -1,6 +1,3 @@ -import json -from typing import Any, Dict - from kiln_ai.adapters.run_output import RunOutput From 72fcb72a0ef59bd9d45f166f78c62a10ed5de72e Mon Sep 17 00:00:00 2001 From: scosman Date: Thu, 30 Jan 2025 18:24:42 -0500 Subject: [PATCH 05/18] Fix unstructured output --- libs/core/kiln_ai/adapters/langchain_adapters.py | 7 +++++-- .../adapters/model_adapters/open_ai_model_adapter.py | 6 +++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/libs/core/kiln_ai/adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/langchain_adapters.py index 15a99a3f..49085ceb 100644 --- a/libs/core/kiln_ai/adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/langchain_adapters.py @@ -25,7 +25,6 @@ from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput from .ml_model_list import ( KilnModelProvider, - ModelParserID, ModelProviderName, StructuredOutputMode, ) @@ -132,7 +131,11 @@ async def _run(self, input: Dict | str) -> RunOutput: prompt = self.build_prompt() # TODO P0: move this to prompt builder provider = await self.model_provider() - if provider.structured_output_mode == StructuredOutputMode.json_instructions: + if ( + self.has_structured_output() + and provider.structured_output_mode + == StructuredOutputMode.json_instructions + ): prompt = ( prompt + f"\n\n### Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.kiln_task.output_schema()}\n```" diff --git a/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py index 76a344c9..f9d40a3f 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py @@ -52,7 +52,11 @@ async def _run(self, input: Dict | str) -> RunOutput: prompt = self.build_prompt() # TODO P0: move this to prompt builder - if provider.structured_output_mode == StructuredOutputMode.json_instructions: + if ( + self.has_structured_output() + and provider.structured_output_mode + == StructuredOutputMode.json_instructions + ): prompt = ( prompt + f"\n\n### Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.kiln_task.output_schema()}\n```" From eb3b75d1a617357ccaa82281db0b2fe5309bb40d Mon Sep 17 00:00:00 2001 From: scosman Date: Thu, 30 Jan 2025 20:43:58 -0500 Subject: [PATCH 06/18] Checkpoint --- app/desktop/studio_server/test_prompt_api.py | 2 +- libs/core/kiln_ai/adapters/base_adapter.py | 14 ++++- .../kiln_ai/adapters/langchain_adapters.py | 15 +---- libs/core/kiln_ai/adapters/ml_model_list.py | 28 ++++++--- .../model_adapters/open_ai_model_adapter.py | 63 +++++++++++++------ libs/core/kiln_ai/adapters/prompt_builders.py | 24 +++++-- libs/core/kiln_ai/datamodel/__init__.py | 1 + 7 files changed, 96 insertions(+), 51 deletions(-) diff --git a/app/desktop/studio_server/test_prompt_api.py b/app/desktop/studio_server/test_prompt_api.py index 3138635f..35c0f17c 100644 --- a/app/desktop/studio_server/test_prompt_api.py +++ b/app/desktop/studio_server/test_prompt_api.py @@ -24,7 +24,7 @@ class MockPromptBuilder(BasePromptBuilder): def prompt_builder_name(cls): return "MockPromptBuilder" - def build_prompt(self): + def 
build_base_prompt(self): return "Mock prompt" def build_prompt_for_ui(self): diff --git a/libs/core/kiln_ai/adapters/base_adapter.py b/libs/core/kiln_ai/adapters/base_adapter.py index 38e3e0a5..caa16bf7 100644 --- a/libs/core/kiln_ai/adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/base_adapter.py @@ -15,7 +15,7 @@ from kiln_ai.datamodel.json_schema import validate_schema from kiln_ai.utils.config import Config -from .ml_model_list import KilnModelProvider +from .ml_model_list import KilnModelProvider, StructuredOutputMode from .parsers.parser_registry import model_parser_from_id from .prompt_builders import BasePromptBuilder, SimplePromptBuilder @@ -149,8 +149,16 @@ def adapter_info(self) -> AdapterInfo: async def _run(self, input: Dict | str) -> RunOutput: pass - def build_prompt(self) -> str: - return self.prompt_builder.build_prompt() + async def build_prompt(self) -> str: + provider = await self.model_provider() + json_instructions = self.has_structured_output() and ( + provider.structured_output_mode == StructuredOutputMode.json_instructions + or provider.structured_output_mode + == StructuredOutputMode.json_instruction_and_object + ) + return self.prompt_builder.build_prompt( + include_json_instructions=json_instructions + ) # create a run and task output def generate_run( diff --git a/libs/core/kiln_ai/adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/langchain_adapters.py index 49085ceb..13b698ea 100644 --- a/libs/core/kiln_ai/adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/langchain_adapters.py @@ -128,20 +128,7 @@ async def _run(self, input: Dict | str) -> RunOutput: chain = model intermediate_outputs = {} - prompt = self.build_prompt() - # TODO P0: move this to prompt builder - provider = await self.model_provider() - if ( - self.has_structured_output() - and provider.structured_output_mode - == StructuredOutputMode.json_instructions - ): - prompt = ( - prompt - + f"\n\n### Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.kiln_task.output_schema()}\n```" - ) - print(f"prompt: {prompt}") - + prompt = await self.build_prompt() user_msg = self.prompt_builder.build_user_message(input) messages = [ SystemMessage(content=prompt), diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py index 0aae8036..88707060 100644 --- a/libs/core/kiln_ai/adapters/ml_model_list.py +++ b/libs/core/kiln_ai/adapters/ml_model_list.py @@ -180,6 +180,7 @@ class KilnModel(BaseModel): providers=[ KilnModelProvider( name=ModelProviderName.openrouter, + structured_output_mode=StructuredOutputMode.function_calling, provider_options={"model": "anthropic/claude-3-5-haiku"}, ), ], @@ -192,6 +193,7 @@ class KilnModel(BaseModel): providers=[ KilnModelProvider( name=ModelProviderName.openrouter, + structured_output_mode=StructuredOutputMode.function_calling, provider_options={"model": "anthropic/claude-3.5-sonnet"}, ), ], @@ -205,6 +207,8 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": "deepseek/deepseek-chat"}, + # TODO test this + structured_output_mode=StructuredOutputMode.function_calling, ), ], ), @@ -217,7 +221,7 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": "deepseek/deepseek-r1"}, - # parser=ModelParserID.r1_thinking, + # No custom parser -- openrouter implemented it themselves structured_output_mode=StructuredOutputMode.json_instructions, ), 
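+            # R1-style models emit a <think> block before the final answer, so
+            # structured output is requested via prompt instructions rather than
+            # an API mode; the Fireworks and Ollama entries below also keep the
+            # r1_thinking parser because they return the raw tags.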
KilnModelProvider( @@ -316,7 +320,7 @@ class KilnModel(BaseModel): ), KilnModelProvider( name=ModelProviderName.openrouter, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.function_calling, provider_options={"model": "meta-llama/llama-3.1-8b-instruct"}, ), KilnModelProvider( @@ -350,7 +354,7 @@ class KilnModel(BaseModel): ), KilnModelProvider( name=ModelProviderName.openrouter, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.function_calling, provider_options={"model": "meta-llama/llama-3.1-70b-instruct"}, ), KilnModelProvider( @@ -391,7 +395,7 @@ class KilnModel(BaseModel): ), KilnModelProvider( name=ModelProviderName.openrouter, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.function_calling, provider_options={"model": "meta-llama/llama-3.1-405b-instruct"}, ), KilnModelProvider( @@ -413,6 +417,7 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": "mistralai/mistral-nemo"}, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, ), ], ), @@ -452,6 +457,7 @@ class KilnModel(BaseModel): name=ModelProviderName.openrouter, supports_structured_output=False, supports_data_gen=False, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, provider_options={"model": "meta-llama/llama-3.2-1b-instruct"}, ), KilnModelProvider( @@ -472,6 +478,7 @@ class KilnModel(BaseModel): name=ModelProviderName.openrouter, supports_structured_output=False, supports_data_gen=False, + structured_output_mode=StructuredOutputMode.json_schema, provider_options={"model": "meta-llama/llama-3.2-3b-instruct"}, ), KilnModelProvider( @@ -597,6 +604,7 @@ class KilnModel(BaseModel): supports_structured_output=False, supports_data_gen=False, provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"}, + structured_output_mode=StructuredOutputMode.json_schema, ), KilnModelProvider( name=ModelProviderName.fireworks_ai, @@ -622,8 +630,7 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, # JSON mode not consistent enough to enable in UI - structured_output_mode=StructuredOutputMode.json_mode, - supports_structured_output=False, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, supports_data_gen=False, provider_options={"model": "microsoft/phi-4"}, ), @@ -661,7 +668,7 @@ class KilnModel(BaseModel): ), KilnModelProvider( name=ModelProviderName.openrouter, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, supports_data_gen=False, provider_options={"model": "google/gemma-2-9b-it"}, ), @@ -683,6 +690,7 @@ class KilnModel(BaseModel): ), KilnModelProvider( name=ModelProviderName.openrouter, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, supports_data_gen=False, provider_options={"model": "google/gemma-2-27b-it"}, ), @@ -697,7 +705,7 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": "mistralai/mixtral-8x7b-instruct"}, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, ), KilnModelProvider( name=ModelProviderName.ollama, @@ -715,7 +723,7 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": 
"qwen/qwen-2.5-7b-instruct"}, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, ), KilnModelProvider( name=ModelProviderName.ollama, @@ -736,7 +744,7 @@ class KilnModel(BaseModel): # Not consistent with structure data. Works sometimes but not often supports_structured_output=False, supports_data_gen=False, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, ), KilnModelProvider( name=ModelProviderName.ollama, diff --git a/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py index f9d40a3f..acbced92 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py @@ -49,20 +49,7 @@ async def _run(self, input: Dict | str) -> RunOutput: intermediate_outputs = {} - prompt = self.build_prompt() - - # TODO P0: move this to prompt builder - if ( - self.has_structured_output() - and provider.structured_output_mode - == StructuredOutputMode.json_instructions - ): - prompt = ( - prompt - + f"\n\n### Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.kiln_task.output_schema()}\n```" - ) - print(f"prompt: {prompt}") - + prompt = await self.build_prompt() user_msg = self.prompt_builder.build_user_message(input) messages = [ {"role": "system", "content": prompt}, @@ -123,14 +110,30 @@ async def _run(self, input: Dict | str) -> RunOutput: ) message = response.choices[0].message - response_content = message.content - if not isinstance(response_content, str): - raise RuntimeError(f"response is not a string: {response_content}") # Save reasoning if it exists if hasattr(message, "reasoning") and message.reasoning: intermediate_outputs["reasoning"] = message.reasoning + # the string content of the response + response_content = message.content + + # Fallback: Use args of first tool call to task_response if it exists + if not response_content and message.tool_calls: + tool_call = next( + ( + tool_call + for tool_call in message.tool_calls + if tool_call.function.name == "task_response" + ), + None, + ) + if tool_call: + response_content = tool_call.function.arguments + + if not isinstance(response_content, str): + raise RuntimeError(f"response is not a string: {response_content}") + if self.has_structured_output(): structured_response = parse_json_string(response_content) return RunOutput( @@ -177,16 +180,38 @@ async def response_format_options(self) -> dict[str, Any]: } case StructuredOutputMode.function_calling: # TODO P0 - return {"response_format": {"type": "function_calling"}} + return self.tool_call_params() case StructuredOutputMode.json_instructions: # JSON done via instructions in prompt, not the API response format # TODO try json_object on stable API return {} + case StructuredOutputMode.json_instruction_and_object: + # We set response_format to json_object, but we also need to set the instructions in the prompt + return {"response_format": {"type": "json_object"}} case StructuredOutputMode.default: - return {} + # Default to function calling -- it's older than the other modes + return self.tool_call_params() case _: raise ValueError( f"Unsupported structured output mode: {provider.structured_output_mode}" ) # pyright will detect missing cases with this return NoReturn + + def tool_call_params(self) -> dict[str, Any]: + return { 
+ "tools": [ + { + "type": "function", + "function": { + "name": "task_response", + "parameters": self.kiln_task.output_schema(), + "strict": True, + }, + } + ], + "tool_choice": { + "type": "function", + "function": {"name": "task_response"}, + }, + } diff --git a/libs/core/kiln_ai/adapters/prompt_builders.py b/libs/core/kiln_ai/adapters/prompt_builders.py index 79be35d2..bfcb2b6c 100644 --- a/libs/core/kiln_ai/adapters/prompt_builders.py +++ b/libs/core/kiln_ai/adapters/prompt_builders.py @@ -28,8 +28,24 @@ def prompt_id(self) -> str | None: """ return None + def build_prompt(self, include_json_instructions: bool = False) -> str: + """Build and return the complete prompt string. + + Returns: + str: The constructed prompt. + """ + prompt = self.build_base_prompt() + + if include_json_instructions and self.task.output_schema(): + prompt = ( + prompt + + f"\n\n# Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.task.output_schema()}\n```" + ) + + return prompt + @abstractmethod - def build_prompt(self) -> str: + def build_base_prompt(self) -> str: """Build and return the complete prompt string. Returns: @@ -88,7 +104,7 @@ def build_prompt_for_ui(self) -> str: class SimplePromptBuilder(BasePromptBuilder): """A basic prompt builder that combines task instruction with requirements.""" - def build_prompt(self) -> str: + def build_base_prompt(self) -> str: """Build a simple prompt with instruction and requirements. Returns: @@ -120,7 +136,7 @@ def example_count(cls) -> int: """ return 25 - def build_prompt(self) -> str: + def build_base_prompt(self) -> str: """Build a prompt with instruction, requirements, and multiple examples. Returns: @@ -272,7 +288,7 @@ def __init__(self, task: Task, prompt_id: str): def prompt_id(self) -> str | None: return self.prompt_model.id - def build_prompt(self) -> str: + def build_base_prompt(self) -> str: """Returns a saved prompt. 
Returns: diff --git a/libs/core/kiln_ai/datamodel/__init__.py b/libs/core/kiln_ai/datamodel/__init__.py index 2306b59c..91090e9a 100644 --- a/libs/core/kiln_ai/datamodel/__init__.py +++ b/libs/core/kiln_ai/datamodel/__init__.py @@ -285,6 +285,7 @@ class StructuredOutputMode(str, Enum): function_calling = "function_calling" json_mode = "json_mode" json_instructions = "json_instructions" + json_instruction_and_object = "json_instruction_and_object" class Finetune(KilnParentedModel): From f38450ed9c2a63f419f96ec443d8dcc03ff8ebcb Mon Sep 17 00:00:00 2001 From: scosman Date: Thu, 30 Jan 2025 21:32:03 -0500 Subject: [PATCH 07/18] Move model adapters into their own folder --- libs/core/kiln_ai/adapters/__init__.py | 14 +++--- .../core/kiln_ai/adapters/adapter_registry.py | 6 +-- .../adapters/model_adapters/__init__.py | 18 +++++++ .../{ => model_adapters}/base_adapter.py | 7 ++- .../langchain_adapters.py | 20 ++++---- ...del_adapter.py => openai_model_adapter.py} | 4 +- .../test_langchain_adapter.py | 47 +++++++++++-------- .../test_saving_adapter_results.py | 6 ++- .../test_structured_output.py | 6 ++- .../core/kiln_ai/adapters/parsers/__init__.py | 0 .../adapters/repair/test_repair_task.py | 4 +- .../kiln_ai/adapters/test_prompt_adaptors.py | 2 +- .../kiln_ai/adapters/test_prompt_builders.py | 2 +- libs/server/kiln_server/test_run_api.py | 2 +- 14 files changed, 88 insertions(+), 50 deletions(-) create mode 100644 libs/core/kiln_ai/adapters/model_adapters/__init__.py rename libs/core/kiln_ai/adapters/{ => model_adapters}/base_adapter.py (96%) rename libs/core/kiln_ai/adapters/{ => model_adapters}/langchain_adapters.py (98%) rename libs/core/kiln_ai/adapters/model_adapters/{open_ai_model_adapter.py => openai_model_adapter.py} (99%) rename libs/core/kiln_ai/adapters/{ => model_adapters}/test_langchain_adapter.py (88%) rename libs/core/kiln_ai/adapters/{ => model_adapters}/test_saving_adapter_results.py (98%) rename libs/core/kiln_ai/adapters/{ => model_adapters}/test_structured_output.py (98%) create mode 100644 libs/core/kiln_ai/adapters/parsers/__init__.py diff --git a/libs/core/kiln_ai/adapters/__init__.py b/libs/core/kiln_ai/adapters/__init__.py index 57d31b62..1bbecd19 100644 --- a/libs/core/kiln_ai/adapters/__init__.py +++ b/libs/core/kiln_ai/adapters/__init__.py @@ -3,31 +3,31 @@ Adapters are used to connect Kiln to external systems, or to add new functionality to Kiln. -BaseAdapter is extensible, and used for adding adapters that provide AI functionality. There's currently a LangChain adapter which provides a bridge to LangChain. +Model adapters are used to call AI models, like Ollama, OpenAI, Anthropic, etc. The ml_model_list submodule contains a list of models that can be used for machine learning tasks. More can easily be added, but we keep a list here of models that are known to work well with Kiln's structured data and tool calling systems. The prompt_builders submodule contains classes that build prompts for use with the AI agents. The repair submodule contains an adapter for the repair task. + +The parser submodule contains parsers for the output of the AI models. """ from . 
import ( - base_adapter, data_gen, fine_tune, - langchain_adapters, ml_model_list, + model_adapters, prompt_builders, repair, ) __all__ = [ - "base_adapter", - "langchain_adapters", + "model_adapters", + "data_gen", + "fine_tune", "ml_model_list", "prompt_builders", "repair", - "data_gen", - "fine_tune", ] diff --git a/libs/core/kiln_ai/adapters/adapter_registry.py b/libs/core/kiln_ai/adapters/adapter_registry.py index 28cd0686..68af2708 100644 --- a/libs/core/kiln_ai/adapters/adapter_registry.py +++ b/libs/core/kiln_ai/adapters/adapter_registry.py @@ -1,10 +1,10 @@ from os import getenv from kiln_ai import datamodel -from kiln_ai.adapters.base_adapter import BaseAdapter -from kiln_ai.adapters.langchain_adapters import LangchainAdapter from kiln_ai.adapters.ml_model_list import ModelProviderName -from kiln_ai.adapters.model_adapters.open_ai_model_adapter import ( +from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter +from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter +from kiln_ai.adapters.model_adapters.openai_model_adapter import ( OpenAICompatibleAdapter, ) from kiln_ai.adapters.prompt_builders import BasePromptBuilder diff --git a/libs/core/kiln_ai/adapters/model_adapters/__init__.py b/libs/core/kiln_ai/adapters/model_adapters/__init__.py new file mode 100644 index 00000000..3be0e6bb --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/__init__.py @@ -0,0 +1,18 @@ +""" +# Model Adapters + +Model adapters are used to call AI models, like Ollama, OpenAI, Anthropic, etc. + +""" + +from . import ( + base_adapter, + langchain_adapters, + openai_model_adapter, +) + +__all__ = [ + "base_adapter", + "langchain_adapters", + "openai_model_adapter", +] diff --git a/libs/core/kiln_ai/adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py similarity index 96% rename from libs/core/kiln_ai/adapters/base_adapter.py rename to libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index caa16bf7..05162235 100644 --- a/libs/core/kiln_ai/adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -3,6 +3,9 @@ from dataclasses import dataclass from typing import Dict +from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode +from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id +from kiln_ai.adapters.prompt_builders import BasePromptBuilder, SimplePromptBuilder from kiln_ai.adapters.provider_tools import kiln_model_provider_from from kiln_ai.adapters.run_output import RunOutput from kiln_ai.datamodel import ( @@ -15,10 +18,6 @@ from kiln_ai.datamodel.json_schema import validate_schema from kiln_ai.utils.config import Config -from .ml_model_list import KilnModelProvider, StructuredOutputMode -from .parsers.parser_registry import model_parser_from_id -from .prompt_builders import BasePromptBuilder, SimplePromptBuilder - @dataclass class AdapterInfo: diff --git a/libs/core/kiln_ai/adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py similarity index 98% rename from libs/core/kiln_ai/adapters/langchain_adapters.py rename to libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py index 13b698ea..c0fedd88 100644 --- a/libs/core/kiln_ai/adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py @@ -15,21 +15,25 @@ from pydantic import BaseModel import kiln_ai.datamodel as datamodel +from kiln_ai.adapters.ml_model_list import ( + KilnModelProvider, + 
ModelProviderName, + StructuredOutputMode, +) +from kiln_ai.adapters.model_adapters.base_adapter import ( + AdapterInfo, + BaseAdapter, + BasePromptBuilder, + RunOutput, +) from kiln_ai.adapters.ollama_tools import ( get_ollama_connection, ollama_base_url, ollama_model_installed, ) +from kiln_ai.adapters.provider_tools import kiln_model_provider_from from kiln_ai.utils.config import Config -from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput -from .ml_model_list import ( - KilnModelProvider, - ModelProviderName, - StructuredOutputMode, -) -from .provider_tools import kiln_model_provider_from - LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel] diff --git a/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py similarity index 99% rename from libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py rename to libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py index acbced92..86c6cd1c 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/open_ai_model_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py @@ -1,13 +1,13 @@ from typing import Any, Dict, NoReturn import kiln_ai.datamodel as datamodel -from kiln_ai.adapters.base_adapter import ( +from kiln_ai.adapters.ml_model_list import StructuredOutputMode +from kiln_ai.adapters.model_adapters.base_adapter import ( AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput, ) -from kiln_ai.adapters.ml_model_list import StructuredOutputMode from kiln_ai.adapters.parsers.json_parser import parse_json_string from openai import AsyncOpenAI from openai.types.chat import ChatCompletion, ChatCompletionMessage diff --git a/libs/core/kiln_ai/adapters/test_langchain_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py similarity index 88% rename from libs/core/kiln_ai/adapters/test_langchain_adapter.py rename to libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py index 5720ca39..de548369 100644 --- a/libs/core/kiln_ai/adapters/test_langchain_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py @@ -9,16 +9,16 @@ from langchain_ollama import ChatOllama from langchain_openai import ChatOpenAI -from kiln_ai.adapters.langchain_adapters import ( - LangchainAdapter, - get_structured_output_options, - langchain_model_from_provider, -) from kiln_ai.adapters.ml_model_list import ( KilnModelProvider, ModelProviderName, StructuredOutputMode, ) +from kiln_ai.adapters.model_adapters.langchain_adapters import ( + LangchainAdapter, + get_structured_output_options, + langchain_model_from_provider, +) from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder from kiln_ai.adapters.test_prompt_adaptors import build_test_task @@ -94,7 +94,8 @@ async def test_langchain_adapter_with_cot(tmp_path): # Patch both the langchain_model_from function and self.model() with ( patch( - "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from + "kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from", + mock_model_from, ), patch.object(LangchainAdapter, "model", return_value=mock_model_instance), ): @@ -151,7 +152,7 @@ async def test_get_structured_output_options(structured_output_mode, expected_me # Test with provider that has options with patch( - "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from", + 
"kiln_ai.adapters.model_adapters.langchain_adapters.kiln_model_provider_from", AsyncMock(return_value=mock_provider), ): options = await get_structured_output_options("model_name", "provider") @@ -164,7 +165,9 @@ async def test_langchain_model_from_provider_openai(): name=ModelProviderName.openai, provider_options={"model": "gpt-4"} ) - with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config: + with patch( + "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared" + ) as mock_config: mock_config.return_value.open_ai_api_key = "test_key" model = await langchain_model_from_provider(provider, "gpt-4") assert isinstance(model, ChatOpenAI) @@ -177,7 +180,9 @@ async def test_langchain_model_from_provider_groq(): name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"} ) - with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config: + with patch( + "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared" + ) as mock_config: mock_config.return_value.groq_api_key = "test_key" model = await langchain_model_from_provider(provider, "mixtral-8x7b") assert isinstance(model, ChatGroq) @@ -191,7 +196,9 @@ async def test_langchain_model_from_provider_bedrock(): provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"}, ) - with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config: + with patch( + "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared" + ) as mock_config: mock_config.return_value.bedrock_access_key = "test_access" mock_config.return_value.bedrock_secret_key = "test_secret" model = await langchain_model_from_provider(provider, "anthropic.claude-v2") @@ -206,7 +213,9 @@ async def test_langchain_model_from_provider_fireworks(): name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"} ) - with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config: + with patch( + "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared" + ) as mock_config: mock_config.return_value.fireworks_api_key = "test_key" model = await langchain_model_from_provider(provider, "mixtral-8x7b") assert isinstance(model, ChatFireworks) @@ -222,15 +231,15 @@ async def test_langchain_model_from_provider_ollama(): mock_connection = MagicMock() with ( patch( - "kiln_ai.adapters.langchain_adapters.get_ollama_connection", + "kiln_ai.adapters.model_adapters.langchain_adapters.get_ollama_connection", return_value=AsyncMock(return_value=mock_connection), ), patch( - "kiln_ai.adapters.langchain_adapters.ollama_model_installed", + "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_model_installed", return_value=True, ), patch( - "kiln_ai.adapters.langchain_adapters.ollama_base_url", + "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_base_url", return_value="http://localhost:11434", ), ): @@ -281,16 +290,16 @@ async def test_langchain_adapter_model_structured_output(tmp_path): mock_model.with_structured_output = MagicMock(return_value="structured_model") adapter = LangchainAdapter( - kiln_task=task, model_name="test_model", provider="test_provider" + kiln_task=task, model_name="test_model", provider="ollama" ) with ( patch( - "kiln_ai.adapters.langchain_adapters.langchain_model_from", + "kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from", AsyncMock(return_value=mock_model), ), patch( - "kiln_ai.adapters.langchain_adapters.get_structured_output_options", + 
"kiln_ai.adapters.model_adapters.langchain_adapters.get_structured_output_options", AsyncMock(return_value={"option1": "value1"}), ), ): @@ -322,11 +331,11 @@ async def test_langchain_adapter_model_no_structured_output_support(tmp_path): del mock_model.with_structured_output adapter = LangchainAdapter( - kiln_task=task, model_name="test_model", provider="test_provider" + kiln_task=task, model_name="test_model", provider="ollama" ) with patch( - "kiln_ai.adapters.langchain_adapters.langchain_model_from", + "kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from", AsyncMock(return_value=mock_model), ): with pytest.raises(ValueError, match="does not support structured output"): diff --git a/libs/core/kiln_ai/adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py similarity index 98% rename from libs/core/kiln_ai/adapters/test_saving_adapter_results.py rename to libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index 62867bdd..64a9b6fd 100644 --- a/libs/core/kiln_ai/adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -2,7 +2,11 @@ import pytest -from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput +from kiln_ai.adapters.model_adapters.base_adapter import ( + AdapterInfo, + BaseAdapter, + RunOutput, +) from kiln_ai.datamodel import ( DataSource, DataSourceType, diff --git a/libs/core/kiln_ai/adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py similarity index 98% rename from libs/core/kiln_ai/adapters/test_structured_output.py rename to libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py index 9b44527d..04987b33 100644 --- a/libs/core/kiln_ai/adapters/test_structured_output.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py @@ -7,10 +7,14 @@ import kiln_ai.datamodel as datamodel from kiln_ai.adapters.adapter_registry import adapter_for_task -from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput from kiln_ai.adapters.ml_model_list import ( built_in_models, ) +from kiln_ai.adapters.model_adapters.base_adapter import ( + AdapterInfo, + BaseAdapter, + RunOutput, +) from kiln_ai.adapters.ollama_tools import ollama_online from kiln_ai.adapters.prompt_builders import ( BasePromptBuilder, diff --git a/libs/core/kiln_ai/adapters/parsers/__init__.py b/libs/core/kiln_ai/adapters/parsers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/core/kiln_ai/adapters/repair/test_repair_task.py b/libs/core/kiln_ai/adapters/repair/test_repair_task.py index 50993429..ace3eb84 100644 --- a/libs/core/kiln_ai/adapters/repair/test_repair_task.py +++ b/libs/core/kiln_ai/adapters/repair/test_repair_task.py @@ -6,8 +6,8 @@ from pydantic import ValidationError from kiln_ai.adapters.adapter_registry import adapter_for_task -from kiln_ai.adapters.base_adapter import RunOutput -from kiln_ai.adapters.langchain_adapters import LangchainAdapter +from kiln_ai.adapters.model_adapters.base_adapter import RunOutput +from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter from kiln_ai.adapters.repair.repair_task import ( RepairTaskInput, RepairTaskRun, diff --git a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py index 6acef2fc..e7b97f90 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py +++ 
b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py @@ -6,8 +6,8 @@ import kiln_ai.datamodel as datamodel from kiln_ai.adapters.adapter_registry import adapter_for_task -from kiln_ai.adapters.langchain_adapters import LangchainAdapter from kiln_ai.adapters.ml_model_list import built_in_models +from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter from kiln_ai.adapters.ollama_tools import ollama_online from kiln_ai.adapters.prompt_builders import ( BasePromptBuilder, diff --git a/libs/core/kiln_ai/adapters/test_prompt_builders.py b/libs/core/kiln_ai/adapters/test_prompt_builders.py index 96687131..2c2fd556 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_builders.py +++ b/libs/core/kiln_ai/adapters/test_prompt_builders.py @@ -2,7 +2,7 @@ import pytest -from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter +from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter from kiln_ai.adapters.prompt_builders import ( FewShotChainOfThoughtPromptBuilder, FewShotPromptBuilder, diff --git a/libs/server/kiln_server/test_run_api.py b/libs/server/kiln_server/test_run_api.py index a9fafa19..191107fb 100644 --- a/libs/server/kiln_server/test_run_api.py +++ b/libs/server/kiln_server/test_run_api.py @@ -3,7 +3,7 @@ import pytest from fastapi import FastAPI, HTTPException from fastapi.testclient import TestClient -from kiln_ai.adapters.langchain_adapters import LangchainAdapter +from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter from kiln_ai.datamodel import ( DataSource, DataSourceType, From 6e4c693a32f28737785cd1761cd13ad8fc28e717 Mon Sep 17 00:00:00 2001 From: scosman Date: Thu, 30 Jan 2025 21:42:02 -0500 Subject: [PATCH 08/18] Fix path --- .../adapters/model_adapters/openai_model_adapter.py | 7 ++++--- libs/core/kiln_ai/adapters/parsers/test_base_parser.py | 1 + libs/core/kiln_ai/adapters/parsers/test_r1_parser.py | 1 + libs/core/kiln_ai/adapters/test_prompt_builders.py | 4 +++- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py index 86c6cd1c..04406e44 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py @@ -1,5 +1,9 @@ from typing import Any, Dict, NoReturn +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice + import kiln_ai.datamodel as datamodel from kiln_ai.adapters.ml_model_list import StructuredOutputMode from kiln_ai.adapters.model_adapters.base_adapter import ( @@ -9,9 +13,6 @@ RunOutput, ) from kiln_ai.adapters.parsers.json_parser import parse_json_string -from openai import AsyncOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionMessage -from openai.types.chat.chat_completion import Choice class OpenAICompatibleAdapter(BaseAdapter): diff --git a/libs/core/kiln_ai/adapters/parsers/test_base_parser.py b/libs/core/kiln_ai/adapters/parsers/test_base_parser.py index 46f61e82..16042950 100644 --- a/libs/core/kiln_ai/adapters/parsers/test_base_parser.py +++ b/libs/core/kiln_ai/adapters/parsers/test_base_parser.py @@ -1,4 +1,5 @@ import pytest + from kiln_ai.adapters.parsers.json_parser import parse_json_string diff --git a/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py b/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py index 
821d186e..19130039 100644 --- a/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py +++ b/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py @@ -1,4 +1,5 @@ import pytest + from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser from kiln_ai.adapters.run_output import RunOutput diff --git a/libs/core/kiln_ai/adapters/test_prompt_builders.py b/libs/core/kiln_ai/adapters/test_prompt_builders.py index 2c2fd556..8cb24795 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_builders.py +++ b/libs/core/kiln_ai/adapters/test_prompt_builders.py @@ -3,6 +3,9 @@ import pytest from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter +from kiln_ai.adapters.model_adapters.test_structured_output import ( + build_structured_output_test_task, +) from kiln_ai.adapters.prompt_builders import ( FewShotChainOfThoughtPromptBuilder, FewShotPromptBuilder, @@ -16,7 +19,6 @@ prompt_builder_from_ui_name, ) from kiln_ai.adapters.test_prompt_adaptors import build_test_task -from kiln_ai.adapters.test_structured_output import build_structured_output_test_task from kiln_ai.datamodel import ( DataSource, DataSourceType, From e67c8ddea28d8e691f72a10fbcae07b469f42ce5 Mon Sep 17 00:00:00 2001 From: scosman Date: Thu, 30 Jan 2025 21:51:35 -0500 Subject: [PATCH 09/18] Test fix --- libs/core/kiln_ai/adapters/adapter_registry.py | 2 +- libs/core/kiln_ai/adapters/repair/test_repair_task.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/core/kiln_ai/adapters/adapter_registry.py b/libs/core/kiln_ai/adapters/adapter_registry.py index 68af2708..9fccbc8c 100644 --- a/libs/core/kiln_ai/adapters/adapter_registry.py +++ b/libs/core/kiln_ai/adapters/adapter_registry.py @@ -13,7 +13,7 @@ def adapter_for_task( kiln_task: datamodel.Task, - model_name: str | None = None, + model_name: str, provider: str | None = None, prompt_builder: BasePromptBuilder | None = None, tags: list[str] | None = None, diff --git a/libs/core/kiln_ai/adapters/repair/test_repair_task.py b/libs/core/kiln_ai/adapters/repair/test_repair_task.py index ace3eb84..9c63d974 100644 --- a/libs/core/kiln_ai/adapters/repair/test_repair_task.py +++ b/libs/core/kiln_ai/adapters/repair/test_repair_task.py @@ -223,7 +223,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai ) adapter = adapter_for_task( - repair_task, model_name="llama_3_1_8b", provider="groq" + repair_task, model_name="llama_3_1_8b", provider="ollama" ) run = await adapter.invoke(repair_task_input.model_dump()) @@ -237,7 +237,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai assert run.output.source.properties == { "adapter_name": "kiln_langchain_adapter", "model_name": "llama_3_1_8b", - "model_provider": "groq", + "model_provider": "ollama", "prompt_builder_name": "simple_prompt_builder", } assert run.input_source.type == DataSourceType.human From e7c0df5299ff44102d6f0e8439f2f55f74e7fd86 Mon Sep 17 00:00:00 2001 From: scosman Date: Thu, 30 Jan 2025 22:10:08 -0500 Subject: [PATCH 10/18] Fix all type errors and tests --- .../model_adapters/langchain_adapters.py | 3 ++ .../model_adapters/openai_model_adapter.py | 36 +++++++++++-------- .../kiln_ai/adapters/parsers/r1_parser.py | 6 +++- .../adapters/parsers/test_r1_parser.py | 2 +- 4 files changed, 31 insertions(+), 16 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py index c0fedd88..68026b91 100644 --- 
a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py @@ -76,6 +76,9 @@ def __init__( "model_name and provider must be provided if custom_model is not provided" ) + if model_name is None: + raise ValueError("model_name must be provided") + super().__init__( kiln_task, model_name=model_name, diff --git a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py index 04406e44..eb39aeb7 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py @@ -1,7 +1,12 @@ from typing import Any, Dict, NoReturn from openai import AsyncOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat import ( + ChatCompletion, + ChatCompletionAssistantMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, +) from openai.types.chat.chat_completion import Choice import kiln_ai.datamodel as datamodel @@ -48,13 +53,13 @@ def __init__( async def _run(self, input: Dict | str) -> RunOutput: provider = await self.model_provider() - intermediate_outputs = {} + intermediate_outputs: dict[str, str] = {} prompt = await self.build_prompt() user_msg = self.prompt_builder.build_user_message(input) messages = [ - {"role": "system", "content": prompt}, - {"role": "user", "content": user_msg}, + ChatCompletionSystemMessageParam(role="system", content=prompt), + ChatCompletionUserMessageParam(role="user", content=user_msg), ] # Handle chain of thought if enabled @@ -69,15 +74,18 @@ async def _run(self, input: Dict | str) -> RunOutput: messages=messages, ) cot_content = cot_response.choices[0].message.content - intermediate_outputs = {"chain_of_thought": cot_content} + if cot_content is not None: + intermediate_outputs["chain_of_thought"] = cot_content messages.extend( [ - {"role": "assistant", "content": cot_content}, - { - "role": "system", - "content": "Considering the above, return a final result.", - }, + ChatCompletionAssistantMessageParam( + role="assistant", content=cot_content + ), + ChatCompletionSystemMessageParam( + role="system", + content="Considering the above, return a final result.", + ), ] ) elif cot_prompt: @@ -101,9 +109,9 @@ async def _run(self, input: Dict | str) -> RunOutput: f"Expected ChatCompletion response, got {type(response)}." ) - if hasattr(response, "error") and response.error: + if hasattr(response, "error") and response.error: # pyright: ignore raise RuntimeError( - f"OpenAI compatible API returned status code {response.error.get('code')}: {response.error.get('message') or 'Unknown error'}." + f"OpenAI compatible API returned status code {response.error.get('code')}: {response.error.get('message') or 'Unknown error'}." 
# pyright: ignore
             )
         if not response.choices or len(response.choices) == 0:
             raise RuntimeError(
@@ -113,8 +121,8 @@ async def _run(self, input: Dict | str) -> RunOutput:
         message = response.choices[0].message

         # Save reasoning if it exists
-        if hasattr(message, "reasoning") and message.reasoning:
-            intermediate_outputs["reasoning"] = message.reasoning
+        if hasattr(message, "reasoning") and message.reasoning:  # pyright: ignore
+            intermediate_outputs["reasoning"] = message.reasoning  # pyright: ignore

         # the string content of the response
         response_content = message.content
diff --git a/libs/core/kiln_ai/adapters/parsers/r1_parser.py b/libs/core/kiln_ai/adapters/parsers/r1_parser.py
index 7629f148..f435a71f 100644
--- a/libs/core/kiln_ai/adapters/parsers/r1_parser.py
+++ b/libs/core/kiln_ai/adapters/parsers/r1_parser.py
@@ -69,7 +69,11 @@ def parse_output(self, original_output: RunOutput) -> RunOutput:
         if self.structured_output:
             output = parse_json_string(result)

+        intermediate_outputs = {}
+        if thinking_content:
+            intermediate_outputs = {"reasoning": thinking_content}
+
         return RunOutput(
             output=output,
-            intermediate_outputs={"reasoning": thinking_content},
+            intermediate_outputs=intermediate_outputs,
         )
diff --git a/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py b/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py
index 19130039..859b0ca5 100644
--- a/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py
+++ b/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py
@@ -74,7 +74,7 @@ def test_empty_thinking_content(parser):
         output="<think></think>This is the result", intermediate_outputs=None
     )
     parsed = parser.parse_output(response)
-    assert parsed.intermediate_outputs["reasoning"] is None
+    assert parsed.intermediate_outputs == {}
     assert parsed.output == "This is the result"


From 743a83307b3dd27f0c533189a124846d529560a8 Mon Sep 17 00:00:00 2001
From: scosman
Date: Fri, 31 Jan 2025 09:16:17 -0500
Subject: [PATCH 11/18] Remove dead code

---
 .../model_adapters/langchain_adapters.py      | 34 +------------------
 1 file changed, 1 insertion(+), 33 deletions(-)

diff --git a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py
index 68026b91..bc7af80a 100644
--- a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py
+++ b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py
@@ -243,24 +243,6 @@ async def get_structured_output_options(
     return options


-# class that wraps another class and logs all function calls being executed
-class Wrapper:
-    def __init__(self, wrapped_class):
-        self.wrapped_class = wrapped_class
-
-    def __getattr__(self, attr):
-        original_func = getattr(self.wrapped_class, attr)
-
-        def wrapper(*args, **kwargs):
-            print(f"Calling function: {attr}")
-            print(f"Arguments: {args}, {kwargs}")
-            result = original_func(*args, **kwargs)
-            print(f"Response: {result}")
-            return result
-
-        return wrapper
-
-
 async def langchain_model_from(
     name: str, provider_name: str | None = None
 ) -> BaseChatModel:
@@ -317,20 +299,6 @@ async def langchain_model_from_provider(
             raise ValueError(f"Model {model_name} not installed on Ollama")
     elif provider.name == ModelProviderName.openrouter:
-        # TODO raise error
-        raise ValueError("OpenRouter is not supported in Langchain")
-        api_key = Config.shared().open_router_api_key
-        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
-        cor = ChatOpenAI(
-            **provider.provider_options,
-            openai_api_key=api_key,  # type: ignore[arg-type]
-            openai_api_base=base_url,  #
type: ignore[arg-type] - default_headers={ - "HTTP-Referer": "https://getkiln.ai/openrouter", - "X-Title": "KilnAI", - }, - ) - cor.client = Wrapper(cor.client) - return cor + raise ValueError("OpenRouter is not supported in Langchain adapter") else: raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}") From 1378d3afcba64ce667978b266dc00d07b302965d Mon Sep 17 00:00:00 2001 From: scosman Date: Fri, 31 Jan 2025 09:23:41 -0500 Subject: [PATCH 12/18] PR feedback --- libs/core/kiln_ai/adapters/adapter_registry.py | 1 - libs/core/kiln_ai/adapters/ml_model_list.py | 1 - libs/core/kiln_ai/datamodel/__init__.py | 7 +++++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/libs/core/kiln_ai/adapters/adapter_registry.py b/libs/core/kiln_ai/adapters/adapter_registry.py index 9fccbc8c..873d4be7 100644 --- a/libs/core/kiln_ai/adapters/adapter_registry.py +++ b/libs/core/kiln_ai/adapters/adapter_registry.py @@ -21,7 +21,6 @@ def adapter_for_task( if provider == ModelProviderName.openrouter: api_key = Config.shared().open_router_api_key base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1" - print(f"base_url: {base_url} provider: {provider} model_name: {model_name}") return OpenAICompatibleAdapter( base_url=base_url, api_key=api_key, diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py index 88707060..13fab1ef 100644 --- a/libs/core/kiln_ai/adapters/ml_model_list.py +++ b/libs/core/kiln_ai/adapters/ml_model_list.py @@ -207,7 +207,6 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": "deepseek/deepseek-chat"}, - # TODO test this structured_output_mode=StructuredOutputMode.function_calling, ), ], diff --git a/libs/core/kiln_ai/datamodel/__init__.py b/libs/core/kiln_ai/datamodel/__init__.py index 91090e9a..a4ba5ebd 100644 --- a/libs/core/kiln_ai/datamodel/__init__.py +++ b/libs/core/kiln_ai/datamodel/__init__.py @@ -278,6 +278,13 @@ class FineTuneStatusType(str, Enum): class StructuredOutputMode(str, Enum): """ Enumeration of supported structured output modes. + + - default: let the adapter decide + - json_schema: request json using API capabilities for json_schema + - function_calling: request json using API capabilities for function calling + - json_mode: request json using API's JSON mode, which should return valid JSON, but isn't checking/passing the schema + - json_instructions: append instructions to the prompt to request json matching the schema. No API capabilities are used. + - json_instruction_and_object: append instructions to the prompt to request json matching the schema. Also request the response as json_mode via API capabilities. 
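+
+    For example, json_instruction_and_object appends "Return a JSON object
+    conforming to the following schema: ..." to the prompt and also sends
+    response_format={"type": "json_object"} with the API request.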
""" default = "default" From e5daed4342090718beeff888e091c5792721f6b3 Mon Sep 17 00:00:00 2001 From: scosman Date: Fri, 31 Jan 2025 09:27:58 -0500 Subject: [PATCH 13/18] PR feedback --- .../kiln_ai/adapters/model_adapters/base_adapter.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index 05162235..7e816706 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -102,15 +102,11 @@ async def invoke( run_output = await self._run(input) # Parse - # TODO P0 provider = await self.model_provider() parser = model_parser_from_id(provider.parser)( structured_output=self.has_structured_output() ) parsed_output = parser.parse_output(original_output=run_output) - print( - f"parsed_output: {parsed_output.output} \nIntermediate outputs: {parsed_output.intermediate_outputs}" - ) # validate output if self.output_schema is not None: @@ -149,14 +145,16 @@ async def _run(self, input: Dict | str) -> RunOutput: pass async def build_prompt(self) -> str: + # The prompt builder needs to know if we want to inject formatting instructions provider = await self.model_provider() - json_instructions = self.has_structured_output() and ( + add_json_instructions = self.has_structured_output() and ( provider.structured_output_mode == StructuredOutputMode.json_instructions or provider.structured_output_mode == StructuredOutputMode.json_instruction_and_object ) + return self.prompt_builder.build_prompt( - include_json_instructions=json_instructions + include_json_instructions=add_json_instructions ) # create a run and task output From a253729272cfcee2006d5cbcc2b23f7ad2029ed8 Mon Sep 17 00:00:00 2001 From: scosman Date: Fri, 31 Jan 2025 10:11:14 -0500 Subject: [PATCH 14/18] PR feedback --- .../model_adapters/langchain_adapters.py | 92 +++++++++---------- .../model_adapters/test_langchain_adapter.py | 90 +++++++++--------- libs/core/kiln_ai/datamodel/__init__.py | 4 +- 3 files changed, 86 insertions(+), 100 deletions(-) diff --git a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py index bc7af80a..1ec4ff8f 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py @@ -31,7 +31,6 @@ ollama_base_url, ollama_model_installed, ) -from kiln_ai.adapters.provider_tools import kiln_model_provider_from from kiln_ai.utils.config import Config LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel] @@ -92,12 +91,11 @@ async def model(self) -> LangChainModelType: if self._model: return self._model - self._model = await langchain_model_from( - self.model_name, self.model_provider_name - ) + self._model = await self.langchain_model_from() - # TODO better way: structured_output_mode = raw? - # Don't setup structured output unless task is structured, and we haven't said we want instruction based JSON + # Decide if we want to use Langchain's structured output: + # 1. Only for structured tasks + # 2. 
Only if the provider's mode isn't json_instructions (the only mode that doesn't use the API's structured output capabilities)
         provider = await self.model_provider()
         use_lc_structured_output = (
             self.has_structured_output()
@@ -120,7 +118,7 @@ async def model(self) -> LangChainModelType:
             )
             output_schema["title"] = "task_response"
             output_schema["description"] = "A response from the task"
-            with_structured_output_options = await get_structured_output_options(
+            with_structured_output_options = await self.get_structured_output_options(
                 self.model_name, self.model_provider_name
             )
             self._model = self._model.with_structured_output(
@@ -142,14 +140,12 @@ async def _run(self, input: Dict | str) -> RunOutput:
             HumanMessage(content=user_msg),
         ]

-        # TODO: make this thinking native
+        # TODO: make this compatible with thinking models
         # COT with structured output
         cot_prompt = self.prompt_builder.chain_of_thought_prompt()
         if cot_prompt and self.has_structured_output():
             # Base model (without structured output) used for COT message
-            base_model = await langchain_model_from(
-                self.model_name, self.model_provider_name
-            )
+            base_model = await self.langchain_model_from()
             messages.append(
                 SystemMessage(content=cot_prompt),
             )
@@ -164,9 +160,7 @@ async def _run(self, input: Dict | str) -> RunOutput:
         elif cot_prompt:
             messages.append(SystemMessage(content=cot_prompt))

-        # TODO: make this thinking native
         response = await chain.ainvoke(messages)
-        print(f"response: {response}")

         # Langchain may have already parsed the response into structured output, so use that if available.
         # However, a plain string may still be fixed at the parsing layer, so not being structured isn't a critical failure (yet)
@@ -184,9 +178,11 @@ async def _run(self, input: Dict | str) -> RunOutput:

         if not isinstance(response, BaseMessage):
             raise RuntimeError(f"response is not a BaseMessage: {response}")
+
         text_content = response.content
         if not isinstance(text_content, str):
             raise RuntimeError(f"response is not a string: {text_content}")
+
         return RunOutput(
             output=text_content,
             intermediate_outputs=intermediate_outputs,
@@ -211,44 +207,40 @@ def _munge_response(self, response: Dict) -> Dict:
             return response["arguments"]
         return response

+    async def get_structured_output_options(
+        self, model_name: str, model_provider_name: str
+    ) -> Dict[str, Any]:
+        provider = await self.model_provider()
+        if not provider:
+            return {}
+
+        options = {}
+        # We may need to add some provider specific logic here if providers use different names for the same
mode, but everyone is copying openai for now - match provider.structured_output_mode: - case StructuredOutputMode.function_calling: - options["method"] = "function_calling" - case StructuredOutputMode.json_mode: - options["method"] = "json_mode" - case StructuredOutputMode.json_schema: - options["method"] = "json_schema" - case StructuredOutputMode.json_instructions: - # JSON done via instructions in prompt, not via API - pass - case StructuredOutputMode.default: - # Let langchain decide the default - pass - case _: - raise ValueError(f"Unhandled enum value: {provider.structured_output_mode}") - # triggers pyright warning if I miss a case - return NoReturn - - return options - - -async def langchain_model_from( - name: str, provider_name: str | None = None -) -> BaseChatModel: - # TODO use self.model_provider() for caching - provider = await kiln_model_provider_from(name, provider_name) - return await langchain_model_from_provider(provider, name) + return options + + async def langchain_model_from(self) -> BaseChatModel: + provider = await self.model_provider() + return await langchain_model_from_provider(provider, self.model_name) async def langchain_model_from_provider( diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py index de548369..0adc9095 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py @@ -16,16 +16,22 @@ ) from kiln_ai.adapters.model_adapters.langchain_adapters import ( LangchainAdapter, - get_structured_output_options, langchain_model_from_provider, ) from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder from kiln_ai.adapters.test_prompt_adaptors import build_test_task -def test_langchain_adapter_munge_response(tmp_path): - task = build_test_task(tmp_path) - lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama") +@pytest.fixture +def mock_adapter(tmp_path): + return LangchainAdapter( + kiln_task=build_test_task(tmp_path), + model_name="llama_3_1_8b", + provider="ollama", + ) + + +def test_langchain_adapter_munge_response(mock_adapter): # Mistral Large tool calling format is a bit different response = { "name": "task_response", @@ -34,12 +40,12 @@ def test_langchain_adapter_munge_response(tmp_path): "punchline": "Because she wanted to be a moo-sician!", }, } - munged = lca._munge_response(response) + munged = mock_adapter._munge_response(response) assert munged["setup"] == "Why did the cow join a band?" assert munged["punchline"] == "Because she wanted to be a moo-sician!" # non mistral format should continue to work - munged = lca._munge_response(response["arguments"]) + munged = mock_adapter._munge_response(response["arguments"]) assert munged["setup"] == "Why did the cow join a band?" assert munged["punchline"] == "Because she wanted to be a moo-sician!" 
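
# Illustrative sketch of the munging behavior exercised by the asserts above
# (a sketch only -- the full _munge_response implementation isn't shown in
# these hunks, so the exact guard may differ):
def munge_response_sketch(response: dict) -> dict:
    # Mistral-style tool calls nest the real payload under "arguments"
    if response.get("name") == "task_response" and "arguments" in response:
        return response["arguments"]
    # Anything else passes through unchanged
    return response

# munge_response_sketch({"name": "task_response", "arguments": {"a": 1}})  # -> {"a": 1}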
@@ -93,10 +99,7 @@ async def test_langchain_adapter_with_cot(tmp_path): # Patch both the langchain_model_from function and self.model() with ( - patch( - "kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from", - mock_model_from, - ), + patch.object(LangchainAdapter, "langchain_model_from", mock_model_from), patch.object(LangchainAdapter, "model", return_value=mock_model_instance), ): response = await lca._run("test input") @@ -145,18 +148,18 @@ async def test_langchain_adapter_with_cot(tmp_path): (StructuredOutputMode.default, None), ], ) -async def test_get_structured_output_options(structured_output_mode, expected_method): +async def test_get_structured_output_options( + mock_adapter, structured_output_mode, expected_method +): # Mock the provider response mock_provider = MagicMock() mock_provider.structured_output_mode = structured_output_mode - # Test with provider that has options - with patch( - "kiln_ai.adapters.model_adapters.langchain_adapters.kiln_model_provider_from", - AsyncMock(return_value=mock_provider), - ): - options = await get_structured_output_options("model_name", "provider") - assert options.get("method") == expected_method + # Mock adapter.model_provider() + mock_adapter.model_provider = AsyncMock(return_value=mock_provider) + + options = await mock_adapter.get_structured_output_options("model_name", "provider") + assert options.get("method") == expected_method @pytest.mark.asyncio @@ -292,31 +295,25 @@ async def test_langchain_adapter_model_structured_output(tmp_path): adapter = LangchainAdapter( kiln_task=task, model_name="test_model", provider="ollama" ) + adapter.get_structured_output_options = AsyncMock( + return_value={"option1": "value1"} + ) + adapter.langchain_model_from = AsyncMock(return_value=mock_model) - with ( - patch( - "kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from", - AsyncMock(return_value=mock_model), - ), - patch( - "kiln_ai.adapters.model_adapters.langchain_adapters.get_structured_output_options", - AsyncMock(return_value={"option1": "value1"}), - ), - ): - model = await adapter.model() - - # Verify the model was configured with structured output - mock_model.with_structured_output.assert_called_once_with( - { - "type": "object", - "properties": {"count": {"type": "integer"}}, - "title": "task_response", - "description": "A response from the task", - }, - include_raw=True, - option1="value1", - ) - assert model == "structured_model" + model = await adapter.model() + + # Verify the model was configured with structured output + mock_model.with_structured_output.assert_called_once_with( + { + "type": "object", + "properties": {"count": {"type": "integer"}}, + "title": "task_response", + "description": "A response from the task", + }, + include_raw=True, + option1="value1", + ) + assert model == "structured_model" @pytest.mark.asyncio @@ -333,10 +330,7 @@ async def test_langchain_adapter_model_no_structured_output_support(tmp_path): adapter = LangchainAdapter( kiln_task=task, model_name="test_model", provider="ollama" ) + adapter.langchain_model_from = AsyncMock(return_value=mock_model) - with patch( - "kiln_ai.adapters.model_adapters.langchain_adapters.langchain_model_from", - AsyncMock(return_value=mock_model), - ): - with pytest.raises(ValueError, match="does not support structured output"): - await adapter.model() + with pytest.raises(ValueError, match="does not support structured output"): + await adapter.model() diff --git a/libs/core/kiln_ai/datamodel/__init__.py 
b/libs/core/kiln_ai/datamodel/__init__.py index a4ba5ebd..69122d6e 100644 --- a/libs/core/kiln_ai/datamodel/__init__.py +++ b/libs/core/kiln_ai/datamodel/__init__.py @@ -283,8 +283,8 @@ class StructuredOutputMode(str, Enum): - json_schema: request json using API capabilities for json_schema - function_calling: request json using API capabilities for function calling - json_mode: request json using API's JSON mode, which should return valid JSON, but isn't checking/passing the schema - - json_instructions: append instructions to the prompt to request json matching the schema. No API capabilities are used. - - json_instruction_and_object: append instructions to the prompt to request json matching the schema. Also request the response as json_mode via API capabilities. + - json_instructions: append instructions to the prompt to request json matching the schema. No API capabilities are used. Use a custom parser with these models, as they return strings rather than dictionaries. + - json_instruction_and_object: append instructions to the prompt to request json matching the schema. Also request the response as json_mode via API capabilities (returning dictionaries). """ default = "default" From a55d9e91fb3166f11888a74b801a29b91358d7ed Mon Sep 17 00:00:00 2001 From: scosman Date: Fri, 31 Jan 2025 10:20:46 -0500 Subject: [PATCH 15/18] PR feedback --- .../kiln_ai/adapters/parsers/r1_parser.py | 16 ++-------- .../adapters/parsers/test_r1_parser.py | 30 +++++++++++-------- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/libs/core/kiln_ai/adapters/parsers/r1_parser.py b/libs/core/kiln_ai/adapters/parsers/r1_parser.py index f435a71f..d32fb0f4 100644 --- a/libs/core/kiln_ai/adapters/parsers/r1_parser.py +++ b/libs/core/kiln_ai/adapters/parsers/r1_parser.py @@ -20,17 +20,9 @@ def parse_output(self, original_output: RunOutput) -> RunOutput: Raises: ValueError: If response format is invalid (missing tags, multiple tags, or no content after closing tag) """ - # TODO - print(f"original_output: {original_output.output}") - # This parser only works for strings if not isinstance(original_output.output, str): raise ValueError("Response must be a string for R1 parser") - if ( - original_output.intermediate_outputs - and len(original_output.intermediate_outputs) > 0 - ): - raise ValueError("Intermediate outputs must be empty for R1 parser") # Strip whitespace and validate basic structure cleaned_response = original_output.output.strip() @@ -55,8 +47,6 @@ def parse_output(self, original_output: RunOutput) -> RunOutput: thinking_content = cleaned_response[ think_start + len(self.START_TAG) : think_end ].strip() - if len(thinking_content) == 0: - thinking_content = None # Extract result (everything after </think>) result = cleaned_response[think_end + len(self.END_TAG) :].strip() @@ -69,9 +59,9 @@ def parse_output(self, original_output: RunOutput) -> RunOutput: if self.structured_output: output = parse_json_string(result) - intermediate_outputs = {} - if thinking_content: - intermediate_outputs = {"reasoning": thinking_content} + # Add thinking content to intermediate outputs, preserving any existing entries + intermediate_outputs = original_output.intermediate_outputs or {} + intermediate_outputs["reasoning"] = thinking_content return RunOutput( output=output, diff --git a/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py b/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py index 859b0ca5..bc1be410 100644 --- a/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py +++ b/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py @@
-74,7 +74,7 @@ def test_empty_thinking_content(parser): output="<think></think>This is the result", intermediate_outputs=None ) parsed = parser.parse_output(response) - assert parsed.intermediate_outputs == {} + assert parsed.intermediate_outputs == {"reasoning": ""} assert parsed.output == "This is the result" @@ -114,27 +114,31 @@ def test_non_string_input(parser): parser.parse_output(RunOutput(output={}, intermediate_outputs=None)) -def test_non_none_intermediate_outputs(parser): - with pytest.raises( - ValueError, match="Intermediate outputs must be empty for R1 parser" - ): - parser.parse_output( - RunOutput( - output="<think>Some content</think>result", - intermediate_outputs={"some": "data"}, - ) +def test_intermediate_outputs(parser): + # append to existing intermediate outputs + out = parser.parse_output( + RunOutput( + output="<think>Some content</think>result", + intermediate_outputs={"existing": "data"}, ) + ) + assert out.intermediate_outputs["reasoning"] == "Some content" + assert out.intermediate_outputs["existing"] == "data" - # empty dict and None are allowed - parser.parse_output( + # empty dict is allowed + out = parser.parse_output( RunOutput( output="<think>Some content</think>result", intermediate_outputs={}, ) ) + assert out.intermediate_outputs["reasoning"] == "Some content" + + # None is allowed + out = parser.parse_output( RunOutput( output="<think>Some content</think>result", intermediate_outputs=None, ) ) + assert out.intermediate_outputs["reasoning"] == "Some content" From 67536c612d9f794f3dfcf658b905b9252729f297 Mon Sep 17 00:00:00 2001 From: scosman Date: Fri, 31 Jan 2025 10:24:09 -0500 Subject: [PATCH 16/18] Rename test file --- .../adapters/parsers/{test_base_parser.py => test_json_parser.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename libs/core/kiln_ai/adapters/parsers/{test_base_parser.py => test_json_parser.py} (100%) diff --git a/libs/core/kiln_ai/adapters/parsers/test_base_parser.py b/libs/core/kiln_ai/adapters/parsers/test_json_parser.py similarity index 100% rename from libs/core/kiln_ai/adapters/parsers/test_base_parser.py rename to libs/core/kiln_ai/adapters/parsers/test_json_parser.py From f91a594bbc66fc22528df3716b2a1ded920cb0c3 Mon Sep 17 00:00:00 2001 From: scosman Date: Fri, 31 Jan 2025 10:26:09 -0500 Subject: [PATCH 17/18] Setup init file --- libs/core/kiln_ai/adapters/parsers/__init__.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/libs/core/kiln_ai/adapters/parsers/__init__.py b/libs/core/kiln_ai/adapters/parsers/__init__.py index e69de29b..87287284 100644 --- a/libs/core/kiln_ai/adapters/parsers/__init__.py +++ b/libs/core/kiln_ai/adapters/parsers/__init__.py @@ -0,0 +1,10 @@ +""" +# Parsers + +Parsing utilities for JSON and models with custom output formats (R1, etc.) + +""" + +from .
import base_parser, json_parser, r1_parser + +__all__ = ["r1_parser", "base_parser", "json_parser"] From 1965f89fc92add39f41fc07adcd10e26010193dd Mon Sep 17 00:00:00 2001 From: scosman Date: Fri, 31 Jan 2025 10:46:53 -0500 Subject: [PATCH 18/18] CR feedback --- .../core/kiln_ai/adapters/adapter_registry.py | 23 ++++---- .../model_adapters/langchain_adapters.py | 1 - .../model_adapters/openai_model_adapter.py | 58 ++++++++++--------- 3 files changed, 43 insertions(+), 39 deletions(-) diff --git a/libs/core/kiln_ai/adapters/adapter_registry.py b/libs/core/kiln_ai/adapters/adapter_registry.py index 873d4be7..f6dc6992 100644 --- a/libs/core/kiln_ai/adapters/adapter_registry.py +++ b/libs/core/kiln_ai/adapters/adapter_registry.py @@ -6,6 +6,7 @@ from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter from kiln_ai.adapters.model_adapters.openai_model_adapter import ( OpenAICompatibleAdapter, + OpenAICompatibleConfig, ) from kiln_ai.adapters.prompt_builders import BasePromptBuilder from kiln_ai.utils.config import Config @@ -19,20 +20,22 @@ def adapter_for_task( tags: list[str] | None = None, ) -> BaseAdapter: if provider == ModelProviderName.openrouter: - api_key = Config.shared().open_router_api_key - base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1" return OpenAICompatibleAdapter( - base_url=base_url, - api_key=api_key, kiln_task=kiln_task, - model_name=model_name, - provider_name=provider, + config=OpenAICompatibleConfig( + base_url=getenv("OPENROUTER_BASE_URL") + or "https://openrouter.ai/api/v1", + api_key=Config.shared().open_router_api_key, + model_name=model_name, + provider_name=provider, + openrouter_style_reasoning=True, + default_headers={ + "HTTP-Referer": "https://getkiln.ai/openrouter", + "X-Title": "KilnAI", + }, + ), prompt_builder=prompt_builder, tags=tags, - default_headers={ - "HTTP-Referer": "https://getkiln.ai/openrouter", - "X-Title": "KilnAI", - }, ) # We use langchain for all others right now, but moving off it as we touch anything. 
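The registry call above now funnels all OpenRouter wiring through a single OpenAICompatibleConfig rather than loose keyword arguments. A minimal sketch of the resulting shape, assuming the dataclass defined later in this patch (the API key is a placeholder; the real one comes from Config.shared()):

from dataclasses import dataclass

@dataclass
class OpenAICompatibleConfig:
    api_key: str
    model_name: str
    provider_name: str
    base_url: str | None = None
    default_headers: dict[str, str] | None = None
    openrouter_style_reasoning: bool = False

config = OpenAICompatibleConfig(
    api_key="sk-placeholder",  # placeholder key, for illustration only
    model_name="llama_3_1_8b",
    provider_name="openrouter",
    base_url="https://openrouter.ai/api/v1",
    default_headers={
        "HTTP-Referer": "https://getkiln.ai/openrouter",
        "X-Title": "KilnAI",
    },
    openrouter_style_reasoning=True,
)
assert config.openrouter_style_reasoning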
diff --git a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py index 1ec4ff8f..3fc7c692 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py @@ -1,5 +1,4 @@ import os -from os import getenv from typing import Any, Dict, NoReturn from langchain_aws import ChatBedrockConverse diff --git a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py index eb39aeb7..d6a24661 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Any, Dict, NoReturn from openai import AsyncOpenAI @@ -7,7 +8,6 @@ ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam, ) -from openai.types.chat.chat_completion import Choice import kiln_ai.datamodel as datamodel from kiln_ai.adapters.ml_model_list import StructuredOutputMode @@ -20,32 +20,35 @@ from kiln_ai.adapters.parsers.json_parser import parse_json_string +@dataclass +class OpenAICompatibleConfig: + api_key: str + model_name: str + provider_name: str + base_url: str | None = None + default_headers: dict[str, str] | None = None + openrouter_style_reasoning: bool = False + + class OpenAICompatibleAdapter(BaseAdapter): def __init__( self, - api_key: str, + config: OpenAICompatibleConfig, kiln_task: datamodel.Task, - model_name: str, - provider_name: str, - base_url: str | None = None, # Client will default to OpenAI - default_headers: dict[str, str] | None = None, prompt_builder: BasePromptBuilder | None = None, tags: list[str] | None = None, ): - if not model_name or not provider_name: - raise ValueError( - "model_name and provider_name must be provided for OpenAI compatible adapter" - ) - - # Create an async OpenAI client instead + self.config = config self.client = AsyncOpenAI( - api_key=api_key, base_url=base_url, default_headers=default_headers + api_key=config.api_key, + base_url=config.base_url, + default_headers=config.default_headers, ) super().__init__( kiln_task, - model_name=model_name, - model_provider_name=provider_name, + model_name=config.model_name, + model_provider_name=config.provider_name, prompt_builder=prompt_builder, tags=tags, ) @@ -95,15 +98,15 @@ async def _run(self, input: Dict | str) -> RunOutput: # Main completion call response_format_options = await self.response_format_options() - print(f"response_format_options: {response_format_options}") response = await self.client.chat.completions.create( model=provider.provider_options["model"], messages=messages, - # TODO P0: remove this - extra_body={"include_reasoning": True}, + extra_body={"include_reasoning": True} + if self.config.openrouter_style_reasoning + else {}, **response_format_options, ) - print(f"response: {response}") + if not isinstance(response, ChatCompletion): raise RuntimeError( f"Expected ChatCompletion response, got {type(response)}." 
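The conditional extra_body above is the only OpenRouter-specific piece of the completion request. A rough sketch of how the kwargs come together (the helper name and model id are illustrative, not from this patch):

from typing import Any

def completion_kwargs(
    model: str,
    messages: list[dict[str, Any]],
    response_format_options: dict[str, Any],
    openrouter_style_reasoning: bool,
) -> dict[str, Any]:
    # Only OpenRouter-style providers are asked for reasoning tokens;
    # all other OpenAI-compatible providers get an empty extra_body.
    return {
        "model": model,
        "messages": messages,
        "extra_body": {"include_reasoning": True}
        if openrouter_style_reasoning
        else {},
        **response_format_options,
    }

kwargs = completion_kwargs(
    "deepseek/deepseek-r1",  # hypothetical model id
    [{"role": "user", "content": "Tell me a joke."}],
    {"response_format": {"type": "json_object"}},
    openrouter_style_reasoning=True,
)
assert kwargs["extra_body"] == {"include_reasoning": True}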
@@ -121,7 +124,11 @@ async def _run(self, input: Dict | str) -> RunOutput: message = response.choices[0].message # Save reasoning if it exists - if hasattr(message, "reasoning") and message.reasoning: # pyright: ignore + if ( + self.config.openrouter_style_reasoning + and hasattr(message, "reasoning") + and message.reasoning # pyright: ignore + ): intermediate_outputs["reasoning"] = message.reasoning # pyright: ignore # the string content of the response @@ -170,14 +177,11 @@ async def response_format_options(self) -> dict[str, Any]: return {} provider = await self.model_provider() - # TODO check these match provider.structured_output_mode: case StructuredOutputMode.json_mode: return {"response_format": {"type": "json_object"}} case StructuredOutputMode.json_schema: - # TODO P0: use json_schema output_schema = self.kiln_task.output_schema() - print(f"output_schema: {output_schema}") return { "response_format": { "type": "json_schema", "json_schema": { "name": "task_response", "schema": output_schema, }, } } case StructuredOutputMode.function_calling: - # TODO P0 return self.tool_call_params() case StructuredOutputMode.json_instructions: - # JSON done via instructions in prompt, not the API response format - # TODO try json_object on stable API + # JSON done via instructions in prompt, not the API response format. Do not ask for json_object (see option below). return {} case StructuredOutputMode.json_instruction_and_object: - # We set response_format to json_object, but we also need to set the instructions in the prompt + # We set response_format to json_object and also set json instructions in the prompt return {"response_format": {"type": "json_object"}} case StructuredOutputMode.default: - # Default to function calling -- it's older than the other modes + # Default to function calling -- it's the oldest of these modes, so it has the broadest compatibility. return self.tool_call_params() case _: raise ValueError(