From 5d00cf535fe234abf8519be94b7facd982dfeedb Mon Sep 17 00:00:00 2001 From: Steve Cosman Date: Fri, 31 Jan 2025 10:48:54 -0500 Subject: [PATCH] Thinking models Part 1: Parsers and OpenAI adapter (#137) Lots here: - OpenAI adapter for better control. - R1 support that really works. Saves reasoning into intermediate outputs. - Parser infra and R1 parser Still need some tests, but working. --- app/desktop/studio_server/test_prompt_api.py | 2 +- libs/core/kiln_ai/adapters/__init__.py | 14 +- .../core/kiln_ai/adapters/adapter_registry.py | 35 ++- libs/core/kiln_ai/adapters/ml_model_list.py | 47 ++-- .../adapters/model_adapters/__init__.py | 18 ++ .../{ => model_adapters}/base_adapter.py | 68 ++++-- .../langchain_adapters.py | 203 +++++++++------- .../model_adapters/openai_model_adapter.py | 228 ++++++++++++++++++ .../test_langchain_adapter.py | 123 +++++----- .../test_saving_adapter_results.py | 23 +- .../test_structured_output.py | 8 +- .../core/kiln_ai/adapters/parsers/__init__.py | 10 + .../kiln_ai/adapters/parsers/base_parser.py | 12 + .../kiln_ai/adapters/parsers/json_parser.py | 35 +++ .../adapters/parsers/parser_registry.py | 22 ++ .../kiln_ai/adapters/parsers/r1_parser.py | 69 ++++++ .../adapters/parsers/test_json_parser.py | 75 ++++++ .../adapters/parsers/test_r1_parser.py | 144 +++++++++++ libs/core/kiln_ai/adapters/prompt_builders.py | 24 +- .../adapters/repair/test_repair_task.py | 8 +- libs/core/kiln_ai/adapters/run_output.py | 8 + .../kiln_ai/adapters/test_prompt_adaptors.py | 12 +- .../kiln_ai/adapters/test_prompt_builders.py | 6 +- libs/core/kiln_ai/datamodel/__init__.py | 9 + libs/server/kiln_server/test_run_api.py | 2 +- 25 files changed, 984 insertions(+), 221 deletions(-) create mode 100644 libs/core/kiln_ai/adapters/model_adapters/__init__.py rename libs/core/kiln_ai/adapters/{ => model_adapters}/base_adapter.py (67%) rename libs/core/kiln_ai/adapters/{ => model_adapters}/langchain_adapters.py (61%) create mode 100644 libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py rename libs/core/kiln_ai/adapters/{ => model_adapters}/test_langchain_adapter.py (77%) rename libs/core/kiln_ai/adapters/{ => model_adapters}/test_saving_adapter_results.py (94%) rename libs/core/kiln_ai/adapters/{ => model_adapters}/test_structured_output.py (97%) create mode 100644 libs/core/kiln_ai/adapters/parsers/__init__.py create mode 100644 libs/core/kiln_ai/adapters/parsers/base_parser.py create mode 100644 libs/core/kiln_ai/adapters/parsers/json_parser.py create mode 100644 libs/core/kiln_ai/adapters/parsers/parser_registry.py create mode 100644 libs/core/kiln_ai/adapters/parsers/r1_parser.py create mode 100644 libs/core/kiln_ai/adapters/parsers/test_json_parser.py create mode 100644 libs/core/kiln_ai/adapters/parsers/test_r1_parser.py create mode 100644 libs/core/kiln_ai/adapters/run_output.py diff --git a/app/desktop/studio_server/test_prompt_api.py b/app/desktop/studio_server/test_prompt_api.py index 3138635f..35c0f17c 100644 --- a/app/desktop/studio_server/test_prompt_api.py +++ b/app/desktop/studio_server/test_prompt_api.py @@ -24,7 +24,7 @@ class MockPromptBuilder(BasePromptBuilder): def prompt_builder_name(cls): return "MockPromptBuilder" - def build_prompt(self): + def build_base_prompt(self): return "Mock prompt" def build_prompt_for_ui(self): diff --git a/libs/core/kiln_ai/adapters/__init__.py b/libs/core/kiln_ai/adapters/__init__.py index 57d31b62..1bbecd19 100644 --- a/libs/core/kiln_ai/adapters/__init__.py +++ b/libs/core/kiln_ai/adapters/__init__.py @@ -3,31 +3,31 @@ 
Adapters are used to connect Kiln to external systems, or to add new functionality to Kiln. -BaseAdapter is extensible, and used for adding adapters that provide AI functionality. There's currently a LangChain adapter which provides a bridge to LangChain. +Model adapters are used to call AI models, like Ollama, OpenAI, Anthropic, etc. The ml_model_list submodule contains a list of models that can be used for machine learning tasks. More can easily be added, but we keep a list here of models that are known to work well with Kiln's structured data and tool calling systems. The prompt_builders submodule contains classes that build prompts for use with the AI agents. The repair submodule contains an adapter for the repair task. + +The parser submodule contains parsers for the output of the AI models. """ from . import ( - base_adapter, data_gen, fine_tune, - langchain_adapters, ml_model_list, + model_adapters, prompt_builders, repair, ) __all__ = [ - "base_adapter", - "langchain_adapters", + "model_adapters", + "data_gen", + "fine_tune", "ml_model_list", "prompt_builders", "repair", - "data_gen", - "fine_tune", ] diff --git a/libs/core/kiln_ai/adapters/adapter_registry.py b/libs/core/kiln_ai/adapters/adapter_registry.py index 0e766eea..f6dc6992 100644 --- a/libs/core/kiln_ai/adapters/adapter_registry.py +++ b/libs/core/kiln_ai/adapters/adapter_registry.py @@ -1,17 +1,44 @@ +from os import getenv + from kiln_ai import datamodel -from kiln_ai.adapters.base_adapter import BaseAdapter -from kiln_ai.adapters.langchain_adapters import LangchainAdapter +from kiln_ai.adapters.ml_model_list import ModelProviderName +from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter +from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter +from kiln_ai.adapters.model_adapters.openai_model_adapter import ( + OpenAICompatibleAdapter, + OpenAICompatibleConfig, +) from kiln_ai.adapters.prompt_builders import BasePromptBuilder +from kiln_ai.utils.config import Config def adapter_for_task( kiln_task: datamodel.Task, - model_name: str | None = None, + model_name: str, provider: str | None = None, prompt_builder: BasePromptBuilder | None = None, tags: list[str] | None = None, ) -> BaseAdapter: - # We use langchain for everything right now, but can add any others here + if provider == ModelProviderName.openrouter: + return OpenAICompatibleAdapter( + kiln_task=kiln_task, + config=OpenAICompatibleConfig( + base_url=getenv("OPENROUTER_BASE_URL") + or "https://openrouter.ai/api/v1", + api_key=Config.shared().open_router_api_key, + model_name=model_name, + provider_name=provider, + openrouter_style_reasoning=True, + default_headers={ + "HTTP-Referer": "https://getkiln.ai/openrouter", + "X-Title": "KilnAI", + }, + ), + prompt_builder=prompt_builder, + tags=tags, + ) + + # We use langchain for all others right now, but moving off it as we touch anything. return LangchainAdapter( kiln_task, model_name=model_name, diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py index 5a85cbd3..13fab1ef 100644 --- a/libs/core/kiln_ai/adapters/ml_model_list.py +++ b/libs/core/kiln_ai/adapters/ml_model_list.py @@ -82,6 +82,14 @@ class ModelName(str, Enum): deepseek_r1 = "deepseek_r1" +class ModelParserID(str, Enum): + """ + Enumeration of supported model parsers. + """ + + r1_thinking = "r1_thinking" + + class KilnModelProvider(BaseModel): """ Configuration for a specific model provider. 
@@ -94,6 +102,7 @@ class KilnModelProvider(BaseModel): provider_finetune_id: The finetune ID for the provider, if applicable provider_options: Additional provider-specific configuration options structured_output_mode: The mode we should use to call the model for structured output, if it was trained with structured output. + parser: A parser to use for the model, if applicable """ name: ModelProviderName @@ -103,6 +112,7 @@ class KilnModelProvider(BaseModel): provider_finetune_id: str | None = None provider_options: Dict = {} structured_output_mode: StructuredOutputMode = StructuredOutputMode.default + parser: ModelParserID | None = None class KilnModel(BaseModel): @@ -170,6 +180,7 @@ class KilnModel(BaseModel): providers=[ KilnModelProvider( name=ModelProviderName.openrouter, + structured_output_mode=StructuredOutputMode.function_calling, provider_options={"model": "anthropic/claude-3-5-haiku"}, ), ], @@ -182,6 +193,7 @@ class KilnModel(BaseModel): providers=[ KilnModelProvider( name=ModelProviderName.openrouter, + structured_output_mode=StructuredOutputMode.function_calling, provider_options={"model": "anthropic/claude-3.5-sonnet"}, ), ], @@ -195,6 +207,7 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": "deepseek/deepseek-chat"}, + structured_output_mode=StructuredOutputMode.function_calling, ), ], ), @@ -207,21 +220,21 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": "deepseek/deepseek-r1"}, - structured_output_mode=StructuredOutputMode.json_schema, + # No custom parser -- openrouter implemented it themselves + structured_output_mode=StructuredOutputMode.json_instructions, ), KilnModelProvider( name=ModelProviderName.fireworks_ai, provider_options={"model": "accounts/fireworks/models/deepseek-r1"}, - # Truncates the thinking, but json_mode works - structured_output_mode=StructuredOutputMode.json_mode, - supports_structured_output=False, - supports_data_gen=False, + parser=ModelParserID.r1_thinking, + structured_output_mode=StructuredOutputMode.json_instructions, ), KilnModelProvider( # I want your RAM name=ModelProviderName.ollama, provider_options={"model": "deepseek-r1:671b"}, - structured_output_mode=StructuredOutputMode.json_schema, + parser=ModelParserID.r1_thinking, + structured_output_mode=StructuredOutputMode.json_instructions, ), ], ), @@ -306,7 +319,7 @@ class KilnModel(BaseModel): ), KilnModelProvider( name=ModelProviderName.openrouter, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.function_calling, provider_options={"model": "meta-llama/llama-3.1-8b-instruct"}, ), KilnModelProvider( @@ -340,7 +353,7 @@ class KilnModel(BaseModel): ), KilnModelProvider( name=ModelProviderName.openrouter, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.function_calling, provider_options={"model": "meta-llama/llama-3.1-70b-instruct"}, ), KilnModelProvider( @@ -381,7 +394,7 @@ class KilnModel(BaseModel): ), KilnModelProvider( name=ModelProviderName.openrouter, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.function_calling, provider_options={"model": "meta-llama/llama-3.1-405b-instruct"}, ), KilnModelProvider( @@ -403,6 +416,7 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": "mistralai/mistral-nemo"}, + 
structured_output_mode=StructuredOutputMode.json_instruction_and_object, ), ], ), @@ -442,6 +456,7 @@ class KilnModel(BaseModel): name=ModelProviderName.openrouter, supports_structured_output=False, supports_data_gen=False, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, provider_options={"model": "meta-llama/llama-3.2-1b-instruct"}, ), KilnModelProvider( @@ -462,6 +477,7 @@ class KilnModel(BaseModel): name=ModelProviderName.openrouter, supports_structured_output=False, supports_data_gen=False, + structured_output_mode=StructuredOutputMode.json_schema, provider_options={"model": "meta-llama/llama-3.2-3b-instruct"}, ), KilnModelProvider( @@ -587,6 +603,7 @@ class KilnModel(BaseModel): supports_structured_output=False, supports_data_gen=False, provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"}, + structured_output_mode=StructuredOutputMode.json_schema, ), KilnModelProvider( name=ModelProviderName.fireworks_ai, @@ -612,8 +629,7 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, # JSON mode not consistent enough to enable in UI - structured_output_mode=StructuredOutputMode.json_mode, - supports_structured_output=False, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, supports_data_gen=False, provider_options={"model": "microsoft/phi-4"}, ), @@ -651,7 +667,7 @@ class KilnModel(BaseModel): ), KilnModelProvider( name=ModelProviderName.openrouter, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, supports_data_gen=False, provider_options={"model": "google/gemma-2-9b-it"}, ), @@ -673,6 +689,7 @@ class KilnModel(BaseModel): ), KilnModelProvider( name=ModelProviderName.openrouter, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, supports_data_gen=False, provider_options={"model": "google/gemma-2-27b-it"}, ), @@ -687,7 +704,7 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": "mistralai/mixtral-8x7b-instruct"}, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, ), KilnModelProvider( name=ModelProviderName.ollama, @@ -705,7 +722,7 @@ class KilnModel(BaseModel): KilnModelProvider( name=ModelProviderName.openrouter, provider_options={"model": "qwen/qwen-2.5-7b-instruct"}, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, ), KilnModelProvider( name=ModelProviderName.ollama, @@ -726,7 +743,7 @@ class KilnModel(BaseModel): # Not consistent with structure data. Works sometimes but not often supports_structured_output=False, supports_data_gen=False, - structured_output_mode=StructuredOutputMode.json_schema, + structured_output_mode=StructuredOutputMode.json_instruction_and_object, ), KilnModelProvider( name=ModelProviderName.ollama, diff --git a/libs/core/kiln_ai/adapters/model_adapters/__init__.py b/libs/core/kiln_ai/adapters/model_adapters/__init__.py new file mode 100644 index 00000000..3be0e6bb --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/__init__.py @@ -0,0 +1,18 @@ +""" +# Model Adapters + +Model adapters are used to call AI models, like Ollama, OpenAI, Anthropic, etc. + +""" + +from . 
import ( + base_adapter, + langchain_adapters, + openai_model_adapter, +) + +__all__ = [ + "base_adapter", + "langchain_adapters", + "openai_model_adapter", +] diff --git a/libs/core/kiln_ai/adapters/base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py similarity index 67% rename from libs/core/kiln_ai/adapters/base_adapter.py rename to libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index 6282f699..7e816706 100644 --- a/libs/core/kiln_ai/adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -3,6 +3,11 @@ from dataclasses import dataclass from typing import Dict +from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode +from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id +from kiln_ai.adapters.prompt_builders import BasePromptBuilder, SimplePromptBuilder +from kiln_ai.adapters.provider_tools import kiln_model_provider_from +from kiln_ai.adapters.run_output import RunOutput from kiln_ai.datamodel import ( DataSource, DataSourceType, @@ -13,8 +18,6 @@ from kiln_ai.datamodel.json_schema import validate_schema from kiln_ai.utils.config import Config -from .prompt_builders import BasePromptBuilder, SimplePromptBuilder - @dataclass class AdapterInfo: @@ -25,12 +28,6 @@ class AdapterInfo: prompt_id: str | None = None -@dataclass -class RunOutput: - output: Dict | str - intermediate_outputs: Dict[str, str] | None - - class BaseAdapter(metaclass=ABCMeta): """Base class for AI model adapters that handle task execution. @@ -48,6 +45,8 @@ class BaseAdapter(metaclass=ABCMeta): def __init__( self, kiln_task: Task, + model_name: str, + model_provider_name: str, prompt_builder: BasePromptBuilder | None = None, tags: list[str] | None = None, ): @@ -56,6 +55,26 @@ def __init__( self.output_schema = self.kiln_task.output_json_schema self.input_schema = self.kiln_task.input_json_schema self.default_tags = tags + self.model_name = model_name + self.model_provider_name = model_provider_name + self._model_provider: KilnModelProvider | None = None + + async def model_provider(self) -> KilnModelProvider: + """ + Lazy load the model provider for this adapter. 
+ """ + if self._model_provider is not None: + return self._model_provider + if not self.model_name or not self.model_provider_name: + raise ValueError("model_name and model_provider_name must be provided") + self._model_provider = await kiln_model_provider_from( + self.model_name, self.model_provider_name + ) + if not self._model_provider: + raise ValueError( + f"model_provider_name {self.model_provider_name} not found for model {self.model_name}" + ) + return self._model_provider async def invoke_returning_raw( self, @@ -82,21 +101,28 @@ async def invoke( # Run run_output = await self._run(input) + # Parse + provider = await self.model_provider() + parser = model_parser_from_id(provider.parser)( + structured_output=self.has_structured_output() + ) + parsed_output = parser.parse_output(original_output=run_output) + # validate output if self.output_schema is not None: - if not isinstance(run_output.output, dict): + if not isinstance(parsed_output.output, dict): raise RuntimeError( - f"structured response is not a dict: {run_output.output}" + f"structured response is not a dict: {parsed_output.output}" ) - validate_schema(run_output.output, self.output_schema) + validate_schema(parsed_output.output, self.output_schema) else: - if not isinstance(run_output.output, str): + if not isinstance(parsed_output.output, str): raise RuntimeError( - f"response is not a string for non-structured task: {run_output.output}" + f"response is not a string for non-structured task: {parsed_output.output}" ) # Generate the run and output - run = self.generate_run(input, input_source, run_output) + run = self.generate_run(input, input_source, parsed_output) # Save the run if configured to do so, and we have a path to save to if Config.shared().autosave_runs and self.kiln_task.path is not None: @@ -118,8 +144,18 @@ def adapter_info(self) -> AdapterInfo: async def _run(self, input: Dict | str) -> RunOutput: pass - def build_prompt(self) -> str: - return self.prompt_builder.build_prompt() + async def build_prompt(self) -> str: + # The prompt builder needs to know if we want to inject formatting instructions + provider = await self.model_provider() + add_json_instructions = self.has_structured_output() and ( + provider.structured_output_mode == StructuredOutputMode.json_instructions + or provider.structured_output_mode + == StructuredOutputMode.json_instruction_and_object + ) + + return self.prompt_builder.build_prompt( + include_json_instructions=add_json_instructions + ) # create a run and task output def generate_run( diff --git a/libs/core/kiln_ai/adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py similarity index 61% rename from libs/core/kiln_ai/adapters/langchain_adapters.py rename to libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py index 9cf6be31..3fc7c692 100644 --- a/libs/core/kiln_ai/adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py @@ -1,5 +1,4 @@ import os -from os import getenv from typing import Any, Dict, NoReturn from langchain_aws import ChatBedrockConverse @@ -15,6 +14,17 @@ from pydantic import BaseModel import kiln_ai.datamodel as datamodel +from kiln_ai.adapters.ml_model_list import ( + KilnModelProvider, + ModelProviderName, + StructuredOutputMode, +) +from kiln_ai.adapters.model_adapters.base_adapter import ( + AdapterInfo, + BaseAdapter, + BasePromptBuilder, + RunOutput, +) from kiln_ai.adapters.ollama_tools import ( get_ollama_connection, ollama_base_url, @@ -22,10 +32,6 @@ ) from 
kiln_ai.utils.config import Config -from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput -from .ml_model_list import KilnModelProvider, ModelProviderName, StructuredOutputMode -from .provider_tools import kiln_model_provider_from - LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel] @@ -41,39 +47,62 @@ def __init__( prompt_builder: BasePromptBuilder | None = None, tags: list[str] | None = None, ): - super().__init__(kiln_task, prompt_builder=prompt_builder, tags=tags) if custom_model is not None: self._model = custom_model # Attempt to infer model provider and name from custom model - self.model_provider = "custom.langchain:" + custom_model.__class__.__name__ - self.model_name = "custom.langchain:unknown_model" - if hasattr(custom_model, "model_name") and isinstance( - getattr(custom_model, "model_name"), str - ): - self.model_name = "custom.langchain:" + getattr( - custom_model, "model_name" - ) - if hasattr(custom_model, "model") and isinstance( - getattr(custom_model, "model"), str - ): - self.model_name = "custom.langchain:" + getattr(custom_model, "model") + if provider is None: + provider = "custom.langchain:" + custom_model.__class__.__name__ + + if model_name is None: + model_name = "custom.langchain:unknown_model" + if hasattr(custom_model, "model_name") and isinstance( + getattr(custom_model, "model_name"), str + ): + model_name = "custom.langchain:" + getattr( + custom_model, "model_name" + ) + if hasattr(custom_model, "model") and isinstance( + getattr(custom_model, "model"), str + ): + model_name = "custom.langchain:" + getattr(custom_model, "model") elif model_name is not None: - self.model_name = model_name - self.model_provider = provider or "custom.langchain.default_provider" + # default provider name if not provided + provider = provider or "custom.langchain.default_provider" else: raise ValueError( "model_name and provider must be provided if custom_model is not provided" ) + if model_name is None: + raise ValueError("model_name must be provided") + + super().__init__( + kiln_task, + model_name=model_name, + model_provider_name=provider, + prompt_builder=prompt_builder, + tags=tags, + ) + async def model(self) -> LangChainModelType: # cached model if self._model: return self._model - self._model = await langchain_model_from(self.model_name, self.model_provider) + self._model = await self.langchain_model_from() + + # Decide if we want to use Langchain's structured output: + # 1. Only for structured tasks + # 2. 
Only if the provider's mode isn't json_instructions (only mode that doesn't use an API option for structured output capabilities) + provider = await self.model_provider() + use_lc_structured_output = ( + self.has_structured_output() + and provider.structured_output_mode + != StructuredOutputMode.json_instructions + ) - if self.has_structured_output(): + if use_lc_structured_output: if not hasattr(self._model, "with_structured_output") or not callable( getattr(self._model, "with_structured_output") ): @@ -88,8 +117,8 @@ async def model(self) -> LangChainModelType: ) output_schema["title"] = "task_response" output_schema["description"] = "A response from the task" - with_structured_output_options = await get_structured_output_options( - self.model_name, self.model_provider + with_structured_output_options = await self.get_structured_output_options( + self.model_name, self.model_provider_name ) self._model = self._model.with_structured_output( output_schema, @@ -103,20 +132,19 @@ async def _run(self, input: Dict | str) -> RunOutput: chain = model intermediate_outputs = {} - prompt = self.build_prompt() + prompt = await self.build_prompt() user_msg = self.prompt_builder.build_user_message(input) messages = [ SystemMessage(content=prompt), HumanMessage(content=user_msg), ] + # TODO: make this compatible with thinking models # COT with structured output cot_prompt = self.prompt_builder.chain_of_thought_prompt() if cot_prompt and self.has_structured_output(): # Base model (without structured output) used for COT message - base_model = await langchain_model_from( - self.model_name, self.model_provider - ) + base_model = await self.langchain_model_from() messages.append( SystemMessage(content=cot_prompt), ) @@ -133,33 +161,36 @@ async def _run(self, input: Dict | str) -> RunOutput: response = await chain.ainvoke(messages) - if self.has_structured_output(): - if ( - not isinstance(response, dict) - or "parsed" not in response - or not isinstance(response["parsed"], dict) - ): - raise RuntimeError(f"structured response not returned: {response}") + # Langchain may have already parsed the response into structured output, so use that if available. 
+ # However, a plain string may still be fixed at the parsing layer, so not being structured isn't a critical failure (yet) + if ( + self.has_structured_output() + and isinstance(response, dict) + and "parsed" in response + and isinstance(response["parsed"], dict) + ): structured_response = response["parsed"] return RunOutput( output=self._munge_response(structured_response), intermediate_outputs=intermediate_outputs, ) - else: - if not isinstance(response, BaseMessage): - raise RuntimeError(f"response is not a BaseMessage: {response}") - text_content = response.content - if not isinstance(text_content, str): - raise RuntimeError(f"response is not a string: {text_content}") - return RunOutput( - output=text_content, - intermediate_outputs=intermediate_outputs, - ) + + if not isinstance(response, BaseMessage): + raise RuntimeError(f"response is not a BaseMessage: {response}") + + text_content = response.content + if not isinstance(text_content, str): + raise RuntimeError(f"response is not a string: {text_content}") + + return RunOutput( + output=text_content, + intermediate_outputs=intermediate_outputs, + ) def adapter_info(self) -> AdapterInfo: return AdapterInfo( model_name=self.model_name, - model_provider=self.model_provider, + model_provider=self.model_provider_name, adapter_name="kiln_langchain_adapter", prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(), prompt_id=self.prompt_builder.prompt_id(), @@ -175,38 +206,40 @@ def _munge_response(self, response: Dict) -> Dict: return response["arguments"] return response + async def get_structured_output_options( + self, model_name: str, model_provider_name: str + ) -> Dict[str, Any]: + provider = await self.model_provider() + if not provider: + return {} + + options = {} + # We may need to add some provider specific logic here if providers use different names for the same mode, but everyone is copying openai for now + match provider.structured_output_mode: + case StructuredOutputMode.function_calling: + options["method"] = "function_calling" + case StructuredOutputMode.json_mode: + options["method"] = "json_mode" + case StructuredOutputMode.json_schema: + options["method"] = "json_schema" + case StructuredOutputMode.json_instructions: + # JSON done via instructions in prompt, not via API + pass + case StructuredOutputMode.default: + # Let langchain decide the default + pass + case _: + raise ValueError( + f"Unhandled enum value: {provider.structured_output_mode}" + ) + # triggers pyright warning if I miss a case + return NoReturn -async def get_structured_output_options( - model_name: str, model_provider: str -) -> Dict[str, Any]: - finetune_provider = await kiln_model_provider_from(model_name, model_provider) - if not finetune_provider: - return {} - - options = {} - # We may need to add some provider specific logic here if providers use different names for the same mode, but everyone is copying openai for now - match finetune_provider.structured_output_mode: - case StructuredOutputMode.function_calling: - options["method"] = "function_calling" - case StructuredOutputMode.json_mode: - options["method"] = "json_mode" - case StructuredOutputMode.json_schema: - options["method"] = "json_schema" - case StructuredOutputMode.default: - # Let langchain decide the default - pass - case _: - # triggers pyright warning if I miss a case - raise_exhaustive_error(finetune_provider.structured_output_mode) - - return options - - -async def langchain_model_from( - name: str, provider_name: str | None = None -) -> BaseChatModel: - 
provider = await kiln_model_provider_from(name, provider_name) - return await langchain_model_from_provider(provider, name) + return options + + async def langchain_model_from(self) -> BaseChatModel: + provider = await self.model_provider() + return await langchain_model_from_provider(provider, self.model_name) async def langchain_model_from_provider( @@ -257,20 +290,6 @@ async def langchain_model_from_provider( raise ValueError(f"Model {model_name} not installed on Ollama") elif provider.name == ModelProviderName.openrouter: - api_key = Config.shared().open_router_api_key - base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1" - return ChatOpenAI( - **provider.provider_options, - openai_api_key=api_key, # type: ignore[arg-type] - openai_api_base=base_url, # type: ignore[arg-type] - default_headers={ - "HTTP-Referer": "https://getkiln.ai/openrouter", - "X-Title": "KilnAI", - }, - ) + raise ValueError("OpenRouter is not supported in Langchain adapter") else: raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}") - - -def raise_exhaustive_error(value: NoReturn) -> NoReturn: - raise ValueError(f"Unhandled enum value: {value}") diff --git a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py new file mode 100644 index 00000000..d6a24661 --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py @@ -0,0 +1,228 @@ +from dataclasses import dataclass +from typing import Any, Dict, NoReturn + +from openai import AsyncOpenAI +from openai.types.chat import ( + ChatCompletion, + ChatCompletionAssistantMessageParam, + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, +) + +import kiln_ai.datamodel as datamodel +from kiln_ai.adapters.ml_model_list import StructuredOutputMode +from kiln_ai.adapters.model_adapters.base_adapter import ( + AdapterInfo, + BaseAdapter, + BasePromptBuilder, + RunOutput, +) +from kiln_ai.adapters.parsers.json_parser import parse_json_string + + +@dataclass +class OpenAICompatibleConfig: + api_key: str + model_name: str + provider_name: str + base_url: str | None = None + default_headers: dict[str, str] | None = None + openrouter_style_reasoning: bool = False + + +class OpenAICompatibleAdapter(BaseAdapter): + def __init__( + self, + config: OpenAICompatibleConfig, + kiln_task: datamodel.Task, + prompt_builder: BasePromptBuilder | None = None, + tags: list[str] | None = None, + ): + self.config = config + self.client = AsyncOpenAI( + api_key=config.api_key, + base_url=config.base_url, + default_headers=config.default_headers, + ) + + super().__init__( + kiln_task, + model_name=config.model_name, + model_provider_name=config.provider_name, + prompt_builder=prompt_builder, + tags=tags, + ) + + async def _run(self, input: Dict | str) -> RunOutput: + provider = await self.model_provider() + + intermediate_outputs: dict[str, str] = {} + + prompt = await self.build_prompt() + user_msg = self.prompt_builder.build_user_message(input) + messages = [ + ChatCompletionSystemMessageParam(role="system", content=prompt), + ChatCompletionUserMessageParam(role="user", content=user_msg), + ] + + # Handle chain of thought if enabled + cot_prompt = self.prompt_builder.chain_of_thought_prompt() + if cot_prompt and self.has_structured_output(): + # TODO P0: Fix COT + messages.append({"role": "system", "content": cot_prompt}) + + # First call for chain of thought + cot_response = await self.client.chat.completions.create( + 
model=self.model_name, + messages=messages, + ) + cot_content = cot_response.choices[0].message.content + if cot_content is not None: + intermediate_outputs["chain_of_thought"] = cot_content + + messages.extend( + [ + ChatCompletionAssistantMessageParam( + role="assistant", content=cot_content + ), + ChatCompletionSystemMessageParam( + role="system", + content="Considering the above, return a final result.", + ), + ] + ) + elif cot_prompt: + messages.append({"role": "system", "content": cot_prompt}) + else: + intermediate_outputs = {} + + # Main completion call + response_format_options = await self.response_format_options() + response = await self.client.chat.completions.create( + model=provider.provider_options["model"], + messages=messages, + extra_body={"include_reasoning": True} + if self.config.openrouter_style_reasoning + else {}, + **response_format_options, + ) + + if not isinstance(response, ChatCompletion): + raise RuntimeError( + f"Expected ChatCompletion response, got {type(response)}." + ) + + if hasattr(response, "error") and response.error: # pyright: ignore + raise RuntimeError( + f"OpenAI compatible API returned status code {response.error.get('code')}: {response.error.get('message') or 'Unknown error'}." # pyright: ignore + ) + if not response.choices or len(response.choices) == 0: + raise RuntimeError( + "No message content returned in the response from OpenAI compatible API" + ) + + message = response.choices[0].message + + # Save reasoning if it exists + if ( + self.config.openrouter_style_reasoning + and hasattr(message, "reasoning") + and message.reasoning # pyright: ignore + ): + intermediate_outputs["reasoning"] = message.reasoning # pyright: ignore + + # the string content of the response + response_content = message.content + + # Fallback: Use args of first tool call to task_response if it exists + if not response_content and message.tool_calls: + tool_call = next( + ( + tool_call + for tool_call in message.tool_calls + if tool_call.function.name == "task_response" + ), + None, + ) + if tool_call: + response_content = tool_call.function.arguments + + if not isinstance(response_content, str): + raise RuntimeError(f"response is not a string: {response_content}") + + if self.has_structured_output(): + structured_response = parse_json_string(response_content) + return RunOutput( + output=structured_response, + intermediate_outputs=intermediate_outputs, + ) + + return RunOutput( + output=response_content, + intermediate_outputs=intermediate_outputs, + ) + + def adapter_info(self) -> AdapterInfo: + return AdapterInfo( + model_name=self.model_name, + model_provider=self.model_provider_name, + adapter_name="kiln_openai_compatible_adapter", + prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(), + prompt_id=self.prompt_builder.prompt_id(), + ) + + async def response_format_options(self) -> dict[str, Any]: + # Unstructured if task isn't structured + if not self.has_structured_output(): + return {} + + provider = await self.model_provider() + match provider.structured_output_mode: + case StructuredOutputMode.json_mode: + return {"response_format": {"type": "json_object"}} + case StructuredOutputMode.json_schema: + output_schema = self.kiln_task.output_schema() + return { + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "task_response", + "schema": output_schema, + }, + } + } + case StructuredOutputMode.function_calling: + return self.tool_call_params() + case StructuredOutputMode.json_instructions: + # JSON done via 
instructions in prompt, not the API response format. Do not ask for json_object (see option below). + return {} + case StructuredOutputMode.json_instruction_and_object: + # We set response_format to json_object and also set json instructions in the prompt + return {"response_format": {"type": "json_object"}} + case StructuredOutputMode.default: + # Default to function calling -- it's older than the other modes. Higher compatibility. + return self.tool_call_params() + case _: + raise ValueError( + f"Unsupported structured output mode: {provider.structured_output_mode}" + ) + # pyright will detect missing cases with this + return NoReturn + + def tool_call_params(self) -> dict[str, Any]: + return { + "tools": [ + { + "type": "function", + "function": { + "name": "task_response", + "parameters": self.kiln_task.output_schema(), + "strict": True, + }, + } + ], + "tool_choice": { + "type": "function", + "function": {"name": "task_response"}, + }, + } diff --git a/libs/core/kiln_ai/adapters/test_langchain_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py similarity index 77% rename from libs/core/kiln_ai/adapters/test_langchain_adapter.py rename to libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py index 5720ca39..0adc9095 100644 --- a/libs/core/kiln_ai/adapters/test_langchain_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py @@ -9,23 +9,29 @@ from langchain_ollama import ChatOllama from langchain_openai import ChatOpenAI -from kiln_ai.adapters.langchain_adapters import ( - LangchainAdapter, - get_structured_output_options, - langchain_model_from_provider, -) from kiln_ai.adapters.ml_model_list import ( KilnModelProvider, ModelProviderName, StructuredOutputMode, ) +from kiln_ai.adapters.model_adapters.langchain_adapters import ( + LangchainAdapter, + langchain_model_from_provider, +) from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder from kiln_ai.adapters.test_prompt_adaptors import build_test_task -def test_langchain_adapter_munge_response(tmp_path): - task = build_test_task(tmp_path) - lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama") +@pytest.fixture +def mock_adapter(tmp_path): + return LangchainAdapter( + kiln_task=build_test_task(tmp_path), + model_name="llama_3_1_8b", + provider="ollama", + ) + + +def test_langchain_adapter_munge_response(mock_adapter): # Mistral Large tool calling format is a bit different response = { "name": "task_response", @@ -34,12 +40,12 @@ def test_langchain_adapter_munge_response(tmp_path): "punchline": "Because she wanted to be a moo-sician!", }, } - munged = lca._munge_response(response) + munged = mock_adapter._munge_response(response) assert munged["setup"] == "Why did the cow join a band?" assert munged["punchline"] == "Because she wanted to be a moo-sician!" # non mistral format should continue to work - munged = lca._munge_response(response["arguments"]) + munged = mock_adapter._munge_response(response["arguments"]) assert munged["setup"] == "Why did the cow join a band?" assert munged["punchline"] == "Because she wanted to be a moo-sician!" 
@@ -93,9 +99,7 @@ async def test_langchain_adapter_with_cot(tmp_path): # Patch both the langchain_model_from function and self.model() with ( - patch( - "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from - ), + patch.object(LangchainAdapter, "langchain_model_from", mock_model_from), patch.object(LangchainAdapter, "model", return_value=mock_model_instance), ): response = await lca._run("test input") @@ -144,18 +148,18 @@ async def test_langchain_adapter_with_cot(tmp_path): (StructuredOutputMode.default, None), ], ) -async def test_get_structured_output_options(structured_output_mode, expected_method): +async def test_get_structured_output_options( + mock_adapter, structured_output_mode, expected_method +): # Mock the provider response mock_provider = MagicMock() mock_provider.structured_output_mode = structured_output_mode - # Test with provider that has options - with patch( - "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from", - AsyncMock(return_value=mock_provider), - ): - options = await get_structured_output_options("model_name", "provider") - assert options.get("method") == expected_method + # Mock adapter.model_provider() + mock_adapter.model_provider = AsyncMock(return_value=mock_provider) + + options = await mock_adapter.get_structured_output_options("model_name", "provider") + assert options.get("method") == expected_method @pytest.mark.asyncio @@ -164,7 +168,9 @@ async def test_langchain_model_from_provider_openai(): name=ModelProviderName.openai, provider_options={"model": "gpt-4"} ) - with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config: + with patch( + "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared" + ) as mock_config: mock_config.return_value.open_ai_api_key = "test_key" model = await langchain_model_from_provider(provider, "gpt-4") assert isinstance(model, ChatOpenAI) @@ -177,7 +183,9 @@ async def test_langchain_model_from_provider_groq(): name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"} ) - with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config: + with patch( + "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared" + ) as mock_config: mock_config.return_value.groq_api_key = "test_key" model = await langchain_model_from_provider(provider, "mixtral-8x7b") assert isinstance(model, ChatGroq) @@ -191,7 +199,9 @@ async def test_langchain_model_from_provider_bedrock(): provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"}, ) - with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config: + with patch( + "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared" + ) as mock_config: mock_config.return_value.bedrock_access_key = "test_access" mock_config.return_value.bedrock_secret_key = "test_secret" model = await langchain_model_from_provider(provider, "anthropic.claude-v2") @@ -206,7 +216,9 @@ async def test_langchain_model_from_provider_fireworks(): name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"} ) - with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config: + with patch( + "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared" + ) as mock_config: mock_config.return_value.fireworks_api_key = "test_key" model = await langchain_model_from_provider(provider, "mixtral-8x7b") assert isinstance(model, ChatFireworks) @@ -222,15 +234,15 @@ async def test_langchain_model_from_provider_ollama(): mock_connection = MagicMock() with ( patch( 
- "kiln_ai.adapters.langchain_adapters.get_ollama_connection", + "kiln_ai.adapters.model_adapters.langchain_adapters.get_ollama_connection", return_value=AsyncMock(return_value=mock_connection), ), patch( - "kiln_ai.adapters.langchain_adapters.ollama_model_installed", + "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_model_installed", return_value=True, ), patch( - "kiln_ai.adapters.langchain_adapters.ollama_base_url", + "kiln_ai.adapters.model_adapters.langchain_adapters.ollama_base_url", return_value="http://localhost:11434", ), ): @@ -281,33 +293,27 @@ async def test_langchain_adapter_model_structured_output(tmp_path): mock_model.with_structured_output = MagicMock(return_value="structured_model") adapter = LangchainAdapter( - kiln_task=task, model_name="test_model", provider="test_provider" + kiln_task=task, model_name="test_model", provider="ollama" + ) + adapter.get_structured_output_options = AsyncMock( + return_value={"option1": "value1"} ) + adapter.langchain_model_from = AsyncMock(return_value=mock_model) - with ( - patch( - "kiln_ai.adapters.langchain_adapters.langchain_model_from", - AsyncMock(return_value=mock_model), - ), - patch( - "kiln_ai.adapters.langchain_adapters.get_structured_output_options", - AsyncMock(return_value={"option1": "value1"}), - ), - ): - model = await adapter.model() - - # Verify the model was configured with structured output - mock_model.with_structured_output.assert_called_once_with( - { - "type": "object", - "properties": {"count": {"type": "integer"}}, - "title": "task_response", - "description": "A response from the task", - }, - include_raw=True, - option1="value1", - ) - assert model == "structured_model" + model = await adapter.model() + + # Verify the model was configured with structured output + mock_model.with_structured_output.assert_called_once_with( + { + "type": "object", + "properties": {"count": {"type": "integer"}}, + "title": "task_response", + "description": "A response from the task", + }, + include_raw=True, + option1="value1", + ) + assert model == "structured_model" @pytest.mark.asyncio @@ -322,12 +328,9 @@ async def test_langchain_adapter_model_no_structured_output_support(tmp_path): del mock_model.with_structured_output adapter = LangchainAdapter( - kiln_task=task, model_name="test_model", provider="test_provider" + kiln_task=task, model_name="test_model", provider="ollama" ) + adapter.langchain_model_from = AsyncMock(return_value=mock_model) - with patch( - "kiln_ai.adapters.langchain_adapters.langchain_model_from", - AsyncMock(return_value=mock_model), - ): - with pytest.raises(ValueError, match="does not support structured output"): - await adapter.model() + with pytest.raises(ValueError, match="does not support structured output"): + await adapter.model() diff --git a/libs/core/kiln_ai/adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py similarity index 94% rename from libs/core/kiln_ai/adapters/test_saving_adapter_results.py rename to libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index 54a32f50..64a9b6fd 100644 --- a/libs/core/kiln_ai/adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -2,7 +2,11 @@ import pytest -from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput +from kiln_ai.adapters.model_adapters.base_adapter import ( + AdapterInfo, + BaseAdapter, + RunOutput, +) from kiln_ai.datamodel import ( DataSource, DataSourceType, 
@@ -39,8 +43,12 @@ def test_task(tmp_path): return task -def test_save_run_isolation(test_task): - adapter = MockAdapter(test_task) +@pytest.fixture +def adapter(test_task): + return MockAdapter(test_task, model_name="phi_3_5", model_provider_name="ollama") + + +def test_save_run_isolation(test_task, adapter): input_data = "Test input" output_data = "Test output" run_output = RunOutput( @@ -124,8 +132,7 @@ def test_save_run_isolation(test_task): assert output_data in set(run.output.output for run in test_task.runs()) -def test_generate_run_non_ascii(test_task): - adapter = MockAdapter(test_task) +def test_generate_run_non_ascii(test_task, adapter): input_data = {"key": "input with non-ascii character: 你好"} output_data = {"key": "output with non-ascii character: 你好"} run_output = RunOutput( @@ -154,13 +161,12 @@ def test_generate_run_non_ascii(test_task): @pytest.mark.asyncio -async def test_autosave_false(test_task): +async def test_autosave_false(test_task, adapter): with patch("kiln_ai.utils.config.Config.shared") as mock_shared: mock_config = mock_shared.return_value mock_config.autosave_runs = False mock_config.user_id = "test_user" - adapter = MockAdapter(test_task) input_data = "Test input" run = await adapter.invoke(input_data) @@ -173,13 +179,12 @@ async def test_autosave_false(test_task): @pytest.mark.asyncio -async def test_autosave_true(test_task): +async def test_autosave_true(test_task, adapter): with patch("kiln_ai.utils.config.Config.shared") as mock_shared: mock_config = mock_shared.return_value mock_config.autosave_runs = True mock_config.user_id = "test_user" - adapter = MockAdapter(test_task) input_data = "Test input" run = await adapter.invoke(input_data) diff --git a/libs/core/kiln_ai/adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py similarity index 97% rename from libs/core/kiln_ai/adapters/test_structured_output.py rename to libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py index 0ea56f54..04987b33 100644 --- a/libs/core/kiln_ai/adapters/test_structured_output.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_structured_output.py @@ -7,10 +7,14 @@ import kiln_ai.datamodel as datamodel from kiln_ai.adapters.adapter_registry import adapter_for_task -from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput from kiln_ai.adapters.ml_model_list import ( built_in_models, ) +from kiln_ai.adapters.model_adapters.base_adapter import ( + AdapterInfo, + BaseAdapter, + RunOutput, +) from kiln_ai.adapters.ollama_tools import ollama_online from kiln_ai.adapters.prompt_builders import ( BasePromptBuilder, @@ -44,7 +48,7 @@ async def test_structured_output_ollama_llama(tmp_path, model_name): class MockAdapter(BaseAdapter): def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None): - super().__init__(kiln_task) + super().__init__(kiln_task, model_name="phi_3_5", model_provider_name="ollama") self.response = response async def _run(self, input: str) -> RunOutput: diff --git a/libs/core/kiln_ai/adapters/parsers/__init__.py b/libs/core/kiln_ai/adapters/parsers/__init__.py new file mode 100644 index 00000000..87287284 --- /dev/null +++ b/libs/core/kiln_ai/adapters/parsers/__init__.py @@ -0,0 +1,10 @@ +""" +# Parsers + +Parsing utilities for JSON and models with custom output formats (R1, etc.) + +""" + +from . 
import base_parser, json_parser, r1_parser
+
+__all__ = ["r1_parser", "base_parser", "json_parser"]
diff --git a/libs/core/kiln_ai/adapters/parsers/base_parser.py b/libs/core/kiln_ai/adapters/parsers/base_parser.py
new file mode 100644
index 00000000..98c9c05d
--- /dev/null
+++ b/libs/core/kiln_ai/adapters/parsers/base_parser.py
@@ -0,0 +1,12 @@
+from kiln_ai.adapters.run_output import RunOutput
+
+
+class BaseParser:
+    def __init__(self, structured_output: bool = False):
+        self.structured_output = structured_output
+
+    def parse_output(self, original_output: RunOutput) -> RunOutput:
+        """
+        Method for parsing the output of a model. Typically overridden by subclasses.
+        """
+        return original_output
diff --git a/libs/core/kiln_ai/adapters/parsers/json_parser.py b/libs/core/kiln_ai/adapters/parsers/json_parser.py
new file mode 100644
index 00000000..0af17dac
--- /dev/null
+++ b/libs/core/kiln_ai/adapters/parsers/json_parser.py
@@ -0,0 +1,35 @@
+import json
+from typing import Any, Dict
+
+
+def parse_json_string(json_string: str) -> Dict[str, Any]:
+    """
+    Parse a JSON string into a dictionary. Handles multiple formats:
+    - Plain JSON
+    - JSON wrapped in ```json code blocks
+    - JSON wrapped in ``` code blocks
+
+    Args:
+        json_string: String containing JSON data, possibly wrapped in code blocks
+
+    Returns:
+        Dict containing parsed JSON data
+
+    Raises:
+        ValueError: If JSON parsing fails
+    """
+    # Remove code block markers if present
+    cleaned_string = json_string.strip()
+    if cleaned_string.startswith("```"):
+        # Split by newlines and remove first/last lines if they contain ```
+        lines = cleaned_string.split("\n")
+        if lines[0].startswith("```"):
+            lines = lines[1:]
+        if lines and lines[-1].strip() == "```":
+            lines = lines[:-1]
+        cleaned_string = "\n".join(lines)
+
+    try:
+        return json.loads(cleaned_string)
+    except json.JSONDecodeError as e:
+        raise ValueError(f"Failed to parse JSON: {str(e)}") from e
diff --git a/libs/core/kiln_ai/adapters/parsers/parser_registry.py b/libs/core/kiln_ai/adapters/parsers/parser_registry.py
new file mode 100644
index 00000000..b90fdfe0
--- /dev/null
+++ b/libs/core/kiln_ai/adapters/parsers/parser_registry.py
@@ -0,0 +1,22 @@
+from typing import NoReturn, Type
+
+from kiln_ai.adapters.ml_model_list import ModelParserID
+from kiln_ai.adapters.parsers.base_parser import BaseParser
+from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
+
+
+def model_parser_from_id(parser_id: ModelParserID | None) -> Type[BaseParser]:
+    """
+    Get a model parser from its ID.
+    """
+    match parser_id:
+        case None:
+            return BaseParser
+        case ModelParserID.r1_thinking:
+            return R1ThinkingParser
+        case _:
+            # triggers pyright warning if I miss a case
+            raise ValueError(
+                f"Unhandled enum value for parser ID. You may need to update Kiln to work with this project. Value: {parser_id}"
+            )
+            return NoReturn
diff --git a/libs/core/kiln_ai/adapters/parsers/r1_parser.py b/libs/core/kiln_ai/adapters/parsers/r1_parser.py
new file mode 100644
index 00000000..d32fb0f4
--- /dev/null
+++ b/libs/core/kiln_ai/adapters/parsers/r1_parser.py
@@ -0,0 +1,69 @@
+from kiln_ai.adapters.parsers.base_parser import BaseParser
+from kiln_ai.adapters.parsers.json_parser import parse_json_string
+from kiln_ai.adapters.run_output import RunOutput
+
+
+class R1ThinkingParser(BaseParser):
+    START_TAG = "<think>"
+    END_TAG = "</think>"
+
+    def parse_output(self, original_output: RunOutput) -> RunOutput:
+        """
+        Parse the <think> tags from the response into the intermediate and final outputs.
+
+        Args:
+            original_output: RunOutput containing the raw response string
+
+        Returns:
+            RunOutput containing the intermediate content (thinking content) and final result
+
+        Raises:
+            ValueError: If response format is invalid (missing tags, multiple tags, or no content after closing tag)
+        """
+        # This parser only works for strings
+        if not isinstance(original_output.output, str):
+            raise ValueError("Response must be a string for R1 parser")
+
+        # Strip whitespace and validate basic structure
+        cleaned_response = original_output.output.strip()
+        if not cleaned_response.startswith(self.START_TAG):
+            raise ValueError("Response must start with <think> tag")
+
+        # Find the thinking tags
+        think_start = cleaned_response.find(self.START_TAG)
+        think_end = cleaned_response.find(self.END_TAG)
+
+        if think_start == -1 or think_end == -1:
+            raise ValueError("Missing thinking tags")
+
+        # Check for multiple tags
+        if (
+            cleaned_response.count(self.START_TAG) > 1
+            or cleaned_response.count(self.END_TAG) > 1
+        ):
+            raise ValueError("Multiple thinking tags found")
+
+        # Extract thinking content
+        thinking_content = cleaned_response[
+            think_start + len(self.START_TAG) : think_end
+        ].strip()
+
+        # Extract result (everything after </think>)
+        result = cleaned_response[think_end + len(self.END_TAG) :].strip()
+
+        if not result or len(result) == 0:
+            raise ValueError("No content found after </think> tag")
+
+        # Parse JSON if needed
+        output = result
+        if self.structured_output:
+            output = parse_json_string(result)
+
+        # Add thinking content to intermediate outputs if it exists
+        intermediate_outputs = original_output.intermediate_outputs or {}
+        intermediate_outputs["reasoning"] = thinking_content
+
+        return RunOutput(
+            output=output,
+            intermediate_outputs=intermediate_outputs,
+        )
diff --git a/libs/core/kiln_ai/adapters/parsers/test_json_parser.py b/libs/core/kiln_ai/adapters/parsers/test_json_parser.py
new file mode 100644
index 00000000..16042950
--- /dev/null
+++ b/libs/core/kiln_ai/adapters/parsers/test_json_parser.py
@@ -0,0 +1,75 @@
+import pytest
+
+from kiln_ai.adapters.parsers.json_parser import parse_json_string
+
+
+def test_parse_plain_json():
+    json_str = '{"key": "value", "number": 42}'
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_code_block():
+    json_str = """```
+    {"key": "value", "number": 42}
+    ```"""
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_language_block():
+    json_str = """```json
+    {"key": "value", "number": 42}
+    ```"""
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_json_with_whitespace():
+    json_str = """
+    {
+        "key": "value",
+        "number": 42
+    }
+    """
+    result = parse_json_string(json_str)
+    assert result == {"key": "value", "number": 42}
+
+
+def test_parse_invalid_json():
+    json_str = '{"key": "value", invalid}'
+    with pytest.raises(ValueError) as exc_info:
+        parse_json_string(json_str)
+    assert "Failed to parse JSON" in str(exc_info.value)
+
+
+def test_parse_empty_code_block():
+    json_str = """```json
+    ```"""
+    with pytest.raises(ValueError) as exc_info:
+        parse_json_string(json_str)
+    assert "Failed to parse JSON" in str(exc_info.value)
+
+
+def test_parse_complex_json():
+    json_str = """```json
+    {
+        "string": "hello",
+        "number": 42,
+        "bool": true,
+        "null": null,
+        "array": [1, 2, 3],
+        "nested": {
+            "inner": "value"
+        }
+    }
+    ```"""
+    result = parse_json_string(json_str)
assert result == { + "string": "hello", + "number": 42, + "bool": True, + "null": None, + "array": [1, 2, 3], + "nested": {"inner": "value"}, + } diff --git a/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py b/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py new file mode 100644 index 00000000..bc1be410 --- /dev/null +++ b/libs/core/kiln_ai/adapters/parsers/test_r1_parser.py @@ -0,0 +1,144 @@ +import pytest + +from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser +from kiln_ai.adapters.run_output import RunOutput + + +@pytest.fixture +def parser(): + return R1ThinkingParser() + + +def test_valid_response(parser): + response = RunOutput( + output="This is thinking contentThis is the result", + intermediate_outputs=None, + ) + parsed = parser.parse_output(response) + assert parsed.intermediate_outputs["reasoning"] == "This is thinking content" + assert parsed.output == "This is the result" + + +def test_response_with_whitespace(parser): + response = RunOutput( + output=""" + + This is thinking content + + This is the result + """, + intermediate_outputs=None, + ) + parsed = parser.parse_output(response) + assert ( + parsed.intermediate_outputs["reasoning"].strip() == "This is thinking content" + ) + assert parsed.output.strip() == "This is the result" + + +def test_missing_start_tag(parser): + with pytest.raises(ValueError, match="Response must start with tag"): + parser.parse_output( + RunOutput(output="Some contentresult", intermediate_outputs=None) + ) + + +def test_missing_end_tag(parser): + with pytest.raises(ValueError, match="Missing thinking tags"): + parser.parse_output( + RunOutput(output="Some content", intermediate_outputs=None) + ) + + +def test_multiple_start_tags(parser): + with pytest.raises(ValueError, match="Multiple thinking tags found"): + parser.parse_output( + RunOutput( + output="content1content2result", + intermediate_outputs=None, + ) + ) + + +def test_multiple_end_tags(parser): + with pytest.raises(ValueError, match="Multiple thinking tags found"): + parser.parse_output( + RunOutput( + output="contentresult", intermediate_outputs=None + ) + ) + + +def test_empty_thinking_content(parser): + response = RunOutput( + output="This is the result", intermediate_outputs=None + ) + parsed = parser.parse_output(response) + assert parsed.intermediate_outputs == {"reasoning": ""} + assert parsed.output == "This is the result" + + +def test_missing_result(parser): + with pytest.raises(ValueError, match="No content found after tag"): + parser.parse_output( + RunOutput(output="Some content", intermediate_outputs=None) + ) + + +def test_multiline_content(parser): + response = RunOutput( + output="""Line 1 + Line 2 + Line 3Final result""", + intermediate_outputs=None, + ) + parsed = parser.parse_output(response) + assert "Line 1" in parsed.intermediate_outputs["reasoning"] + assert "Line 2" in parsed.intermediate_outputs["reasoning"] + assert "Line 3" in parsed.intermediate_outputs["reasoning"] + assert parsed.output == "Final result" + + +def test_special_characters(parser): + response = RunOutput( + output="Content with: !@#$%^&*思()Result with: !@#$%^&*思()", + intermediate_outputs=None, + ) + parsed = parser.parse_output(response) + assert parsed.intermediate_outputs["reasoning"] == "Content with: !@#$%^&*思()" + assert parsed.output == "Result with: !@#$%^&*思()" + + +def test_non_string_input(parser): + with pytest.raises(ValueError, match="Response must be a string for R1 parser"): + parser.parse_output(RunOutput(output={}, intermediate_outputs=None)) + + +def 
+def test_intermediate_outputs(parser):
+    # append to existing intermediate outputs
+    out = parser.parse_output(
+        RunOutput(
+            output="<think>Some content</think>result",
+            intermediate_outputs={"existing": "data"},
+        )
+    )
+    assert out.intermediate_outputs["reasoning"] == "Some content"
+    assert out.intermediate_outputs["existing"] == "data"
+
+    # empty dict is allowed
+    out = parser.parse_output(
+        RunOutput(
+            output="<think>Some content</think>result",
+            intermediate_outputs={},
+        )
+    )
+    assert out.intermediate_outputs["reasoning"] == "Some content"
+
+    # None is allowed
+    out = parser.parse_output(
+        RunOutput(
+            output="<think>Some content</think>result",
+            intermediate_outputs=None,
+        )
+    )
+    assert out.intermediate_outputs["reasoning"] == "Some content"
diff --git a/libs/core/kiln_ai/adapters/prompt_builders.py b/libs/core/kiln_ai/adapters/prompt_builders.py
index 79be35d2..bfcb2b6c 100644
--- a/libs/core/kiln_ai/adapters/prompt_builders.py
+++ b/libs/core/kiln_ai/adapters/prompt_builders.py
@@ -28,8 +28,24 @@ def prompt_id(self) -> str | None:
         """
         return None
 
+    def build_prompt(self, include_json_instructions: bool = False) -> str:
+        """Build and return the complete prompt string.
+
+        Returns:
+            str: The constructed prompt.
+        """
+        prompt = self.build_base_prompt()
+
+        if include_json_instructions and self.task.output_schema():
+            prompt = (
+                prompt
+                + f"\n\n# Format Instructions\n\nReturn a JSON object conforming to the following schema:\n```\n{self.task.output_schema()}\n```"
+            )
+
+        return prompt
+
     @abstractmethod
-    def build_prompt(self) -> str:
+    def build_base_prompt(self) -> str:
         """Build and return the complete prompt string.
 
         Returns:
@@ -88,7 +104,7 @@ def build_prompt_for_ui(self) -> str:
 class SimplePromptBuilder(BasePromptBuilder):
     """A basic prompt builder that combines task instruction with requirements."""
 
-    def build_prompt(self) -> str:
+    def build_base_prompt(self) -> str:
         """Build a simple prompt with instruction and requirements.
 
         Returns:
@@ -120,7 +136,7 @@ def example_count(cls) -> int:
         """
         return 25
 
-    def build_prompt(self) -> str:
+    def build_base_prompt(self) -> str:
         """Build a prompt with instruction, requirements, and multiple examples.
 
         Returns:
@@ -272,7 +288,7 @@ def __init__(self, task: Task, prompt_id: str):
     def prompt_id(self) -> str | None:
         return self.prompt_model.id
 
-    def build_prompt(self) -> str:
+    def build_base_prompt(self) -> str:
         """Returns a saved prompt.
 
Returns: diff --git a/libs/core/kiln_ai/adapters/repair/test_repair_task.py b/libs/core/kiln_ai/adapters/repair/test_repair_task.py index 50993429..9c63d974 100644 --- a/libs/core/kiln_ai/adapters/repair/test_repair_task.py +++ b/libs/core/kiln_ai/adapters/repair/test_repair_task.py @@ -6,8 +6,8 @@ from pydantic import ValidationError from kiln_ai.adapters.adapter_registry import adapter_for_task -from kiln_ai.adapters.base_adapter import RunOutput -from kiln_ai.adapters.langchain_adapters import LangchainAdapter +from kiln_ai.adapters.model_adapters.base_adapter import RunOutput +from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter from kiln_ai.adapters.repair.repair_task import ( RepairTaskInput, RepairTaskRun, @@ -223,7 +223,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai ) adapter = adapter_for_task( - repair_task, model_name="llama_3_1_8b", provider="groq" + repair_task, model_name="llama_3_1_8b", provider="ollama" ) run = await adapter.invoke(repair_task_input.model_dump()) @@ -237,7 +237,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai assert run.output.source.properties == { "adapter_name": "kiln_langchain_adapter", "model_name": "llama_3_1_8b", - "model_provider": "groq", + "model_provider": "ollama", "prompt_builder_name": "simple_prompt_builder", } assert run.input_source.type == DataSourceType.human diff --git a/libs/core/kiln_ai/adapters/run_output.py b/libs/core/kiln_ai/adapters/run_output.py new file mode 100644 index 00000000..7c34cae6 --- /dev/null +++ b/libs/core/kiln_ai/adapters/run_output.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass +from typing import Dict + + +@dataclass +class RunOutput: + output: Dict | str + intermediate_outputs: Dict[str, str] | None diff --git a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py index 2a64310e..e7b97f90 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py +++ b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py @@ -6,8 +6,8 @@ import kiln_ai.datamodel as datamodel from kiln_ai.adapters.adapter_registry import adapter_for_task -from kiln_ai.adapters.langchain_adapters import LangchainAdapter from kiln_ai.adapters.ml_model_list import built_in_models +from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter from kiln_ai.adapters.ollama_tools import ollama_online from kiln_ai.adapters.prompt_builders import ( BasePromptBuilder, @@ -108,7 +108,11 @@ async def test_amazon_bedrock(tmp_path): async def test_mock(tmp_path): task = build_test_task(tmp_path) mockChatModel = FakeListChatModel(responses=["mock response"]) - adapter = LangchainAdapter(task, custom_model=mockChatModel) + adapter = LangchainAdapter( + task, + custom_model=mockChatModel, + provider="ollama", + ) run = await adapter.invoke("You are a mock, send me the response!") assert "mock response" in run.output.output @@ -116,7 +120,7 @@ async def test_mock(tmp_path): async def test_mock_returning_run(tmp_path): task = build_test_task(tmp_path) mockChatModel = FakeListChatModel(responses=["mock response"]) - adapter = LangchainAdapter(task, custom_model=mockChatModel) + adapter = LangchainAdapter(task, custom_model=mockChatModel, provider="ollama") run = await adapter.invoke("You are a mock, send me the response!") assert run.output.output == "mock response" assert run is not None @@ -127,7 +131,7 @@ async def test_mock_returning_run(tmp_path): assert run.output.source.properties 
== {
         "adapter_name": "kiln_langchain_adapter",
         "model_name": "custom.langchain:unknown_model",
-        "model_provider": "custom.langchain:FakeListChatModel",
+        "model_provider": "ollama",
         "prompt_builder_name": "simple_prompt_builder",
     }
diff --git a/libs/core/kiln_ai/adapters/test_prompt_builders.py b/libs/core/kiln_ai/adapters/test_prompt_builders.py
index 96687131..8cb24795 100644
--- a/libs/core/kiln_ai/adapters/test_prompt_builders.py
+++ b/libs/core/kiln_ai/adapters/test_prompt_builders.py
@@ -2,7 +2,10 @@
 
 import pytest
 
-from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.adapters.model_adapters.test_structured_output import (
+    build_structured_output_test_task,
+)
 from kiln_ai.adapters.prompt_builders import (
     FewShotChainOfThoughtPromptBuilder,
     FewShotPromptBuilder,
@@ -16,7 +19,6 @@
     prompt_builder_from_ui_name,
 )
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
-from kiln_ai.adapters.test_structured_output import build_structured_output_test_task
 from kiln_ai.datamodel import (
     DataSource,
     DataSourceType,
diff --git a/libs/core/kiln_ai/datamodel/__init__.py b/libs/core/kiln_ai/datamodel/__init__.py
index 5751448f..69122d6e 100644
--- a/libs/core/kiln_ai/datamodel/__init__.py
+++ b/libs/core/kiln_ai/datamodel/__init__.py
@@ -278,12 +278,21 @@ class FineTuneStatusType(str, Enum):
 class StructuredOutputMode(str, Enum):
     """
     Enumeration of supported structured output modes.
+
+    - default: let the adapter decide
+    - json_schema: request JSON using the API's json_schema capabilities
+    - function_calling: request JSON using the API's function calling capabilities
+    - json_mode: request JSON using the API's JSON mode, which should return valid JSON but does not validate it against the schema
+    - json_instructions: append instructions to the prompt requesting JSON matching the schema. No API capabilities are used, so pair these models with a custom parser since they return plain strings.
+    - json_instruction_and_object: append instructions to the prompt requesting JSON matching the schema, and also request the API's JSON mode (so responses come back as dictionaries).
     """
 
     default = "default"
     json_schema = "json_schema"
     function_calling = "function_calling"
     json_mode = "json_mode"
+    json_instructions = "json_instructions"
+    json_instruction_and_object = "json_instruction_and_object"
 
 
 class Finetune(KilnParentedModel):
diff --git a/libs/server/kiln_server/test_run_api.py b/libs/server/kiln_server/test_run_api.py
index a9fafa19..191107fb 100644
--- a/libs/server/kiln_server/test_run_api.py
+++ b/libs/server/kiln_server/test_run_api.py
@@ -3,7 +3,7 @@
 import pytest
 from fastapi import FastAPI, HTTPException
 from fastapi.testclient import TestClient
-from kiln_ai.adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.datamodel import (
     DataSource,
     DataSourceType,
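
For reviewers who want to see the new pieces end to end, here is a minimal usage sketch. It is not part of the patch; it only uses the classes and signatures introduced above (R1ThinkingParser, RunOutput) and assumes a model response that wraps its reasoning in <think> tags, as the R1 tests do. The example response text is made up for illustration.

# Minimal sketch (assumptions noted above), using only APIs visible in this diff.
from kiln_ai.adapters.parsers.r1_parser import R1ThinkingParser
from kiln_ai.adapters.run_output import RunOutput

# Raw model output: reasoning inside <think>...</think>, final answer after the closing tag.
raw = RunOutput(
    output="<think>Compare both options before answering.</think>Option B is better.",
    intermediate_outputs=None,
)

# The parser moves the reasoning into intermediate_outputs["reasoning"]
# and keeps only the final answer in output.
parsed = R1ThinkingParser().parse_output(raw)
assert parsed.intermediate_outputs["reasoning"] == "Compare both options before answering."
assert parsed.output == "Option B is better."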