diff --git a/libs/core/kiln_ai/adapters/adapter_registry.py b/libs/core/kiln_ai/adapters/adapter_registry.py
index 873d4be7..f6dc6992 100644
--- a/libs/core/kiln_ai/adapters/adapter_registry.py
+++ b/libs/core/kiln_ai/adapters/adapter_registry.py
@@ -6,6 +6,7 @@
 from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.adapters.model_adapters.openai_model_adapter import (
     OpenAICompatibleAdapter,
+    OpenAICompatibleConfig,
 )
 from kiln_ai.adapters.prompt_builders import BasePromptBuilder
 from kiln_ai.utils.config import Config
@@ -19,20 +20,22 @@ def adapter_for_task(
     tags: list[str] | None = None,
 ) -> BaseAdapter:
     if provider == ModelProviderName.openrouter:
-        api_key = Config.shared().open_router_api_key
-        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
         return OpenAICompatibleAdapter(
-            base_url=base_url,
-            api_key=api_key,
             kiln_task=kiln_task,
-            model_name=model_name,
-            provider_name=provider,
+            config=OpenAICompatibleConfig(
+                base_url=getenv("OPENROUTER_BASE_URL")
+                or "https://openrouter.ai/api/v1",
+                api_key=Config.shared().open_router_api_key,
+                model_name=model_name,
+                provider_name=provider,
+                openrouter_style_reasoning=True,
+                default_headers={
+                    "HTTP-Referer": "https://getkiln.ai/openrouter",
+                    "X-Title": "KilnAI",
+                },
+            ),
             prompt_builder=prompt_builder,
             tags=tags,
-            default_headers={
-                "HTTP-Referer": "https://getkiln.ai/openrouter",
-                "X-Title": "KilnAI",
-            },
         )
 
     # We use langchain for all others right now, but moving off it as we touch anything.
diff --git a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py
index 1ec4ff8f..3fc7c692 100644
--- a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py
+++ b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py
@@ -1,5 +1,4 @@
 import os
-from os import getenv
 from typing import Any, Dict, NoReturn
 
 from langchain_aws import ChatBedrockConverse
diff --git a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
index eb39aeb7..d6a24661 100644
--- a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
+++ b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Any, Dict, NoReturn
 
 from openai import AsyncOpenAI
@@ -7,7 +8,6 @@
     ChatCompletionSystemMessageParam,
     ChatCompletionUserMessageParam,
 )
-from openai.types.chat.chat_completion import Choice
 
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.ml_model_list import StructuredOutputMode
@@ -20,32 +20,35 @@
 from kiln_ai.adapters.parsers.json_parser import parse_json_string
 
 
+@dataclass
+class OpenAICompatibleConfig:
+    api_key: str
+    model_name: str
+    provider_name: str
+    base_url: str | None = None
+    default_headers: dict[str, str] | None = None
+    openrouter_style_reasoning: bool = False
+
+
 class OpenAICompatibleAdapter(BaseAdapter):
     def __init__(
         self,
-        api_key: str,
+        config: OpenAICompatibleConfig,
         kiln_task: datamodel.Task,
-        model_name: str,
-        provider_name: str,
-        base_url: str | None = None,  # Client will default to OpenAI
-        default_headers: dict[str, str] | None = None,
         prompt_builder: BasePromptBuilder | None = None,
         tags: list[str] | None = None,
     ):
-        if not model_name or not provider_name:
-            raise ValueError(
-                "model_name and provider_name must be provided for OpenAI compatible adapter"
-            )
-
-        # Create an async OpenAI client instead
+        self.config = config
         self.client = AsyncOpenAI(
-            api_key=api_key, base_url=base_url, default_headers=default_headers
+            api_key=config.api_key,
+            base_url=config.base_url,
+            default_headers=config.default_headers,
         )
         super().__init__(
             kiln_task,
-            model_name=model_name,
-            model_provider_name=provider_name,
+            model_name=config.model_name,
+            model_provider_name=config.provider_name,
             prompt_builder=prompt_builder,
             tags=tags,
         )
@@ -95,15 +98,15 @@ async def _run(self, input: Dict | str) -> RunOutput:
         # Main completion call
         response_format_options = await self.response_format_options()
-        print(f"response_format_options: {response_format_options}")
         response = await self.client.chat.completions.create(
             model=provider.provider_options["model"],
             messages=messages,
-            # TODO P0: remove this
-            extra_body={"include_reasoning": True},
+            extra_body={"include_reasoning": True}
+            if self.config.openrouter_style_reasoning
+            else {},
             **response_format_options,
         )
-        print(f"response: {response}")
+
         if not isinstance(response, ChatCompletion):
             raise RuntimeError(
                 f"Expected ChatCompletion response, got {type(response)}."
@@ -121,7 +124,11 @@ async def _run(self, input: Dict | str) -> RunOutput:
         message = response.choices[0].message
 
         # Save reasoning if it exists
-        if hasattr(message, "reasoning") and message.reasoning:  # pyright: ignore
+        if (
+            self.config.openrouter_style_reasoning
+            and hasattr(message, "reasoning")
+            and message.reasoning  # pyright: ignore
+        ):
             intermediate_outputs["reasoning"] = message.reasoning  # pyright: ignore
 
         # the string content of the response
@@ -170,14 +177,11 @@ async def response_format_options(self) -> dict[str, Any]:
             return {}
 
         provider = await self.model_provider()
-        # TODO check these
         match provider.structured_output_mode:
             case StructuredOutputMode.json_mode:
                 return {"response_format": {"type": "json_object"}}
             case StructuredOutputMode.json_schema:
-                # TODO P0: use json_schema
                 output_schema = self.kiln_task.output_schema()
-                print(f"output_schema: {output_schema}")
                 return {
                     "response_format": {
                         "type": "json_schema",
@@ -188,17 +192,15 @@ async def response_format_options(self) -> dict[str, Any]:
                     }
                 }
             case StructuredOutputMode.function_calling:
-                # TODO P0
                 return self.tool_call_params()
             case StructuredOutputMode.json_instructions:
-                # JSON done via instructions in prompt, not the API response format
-                # TODO try json_object on stable API
+                # JSON requested via instructions in the prompt, not via the API's response_format. Do not ask for json_object (see the option below).
                 return {}
             case StructuredOutputMode.json_instruction_and_object:
-                # We set response_format to json_object, but we also need to set the instructions in the prompt
+                # We set response_format to json_object, and also add JSON instructions to the prompt.
                 return {"response_format": {"type": "json_object"}}
             case StructuredOutputMode.default:
-                # Default to function calling -- it's older than the other modes
+                # Default to function calling -- it's older than the other modes, so it has the broadest compatibility.
                 return self.tool_call_params()
             case _:
                 raise ValueError(
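
For reviewers, a minimal sketch of the new config-based construction path. The field names come straight from the `OpenAICompatibleConfig` dataclass in this diff; the task object, key, and model/provider ids are placeholder assumptions, not values from the codebase:

```python
from kiln_ai.adapters.model_adapters.openai_model_adapter import (
    OpenAICompatibleAdapter,
    OpenAICompatibleConfig,
)

# `task` is assumed to be an existing kiln_ai.datamodel.Task instance;
# the key, model id, and provider name below are illustrative placeholders.
config = OpenAICompatibleConfig(
    api_key="sk-...",
    model_name="gpt-4o",
    provider_name="openai",
    base_url=None,  # None -> the client defaults to OpenAI's endpoint
    default_headers=None,
    openrouter_style_reasoning=False,  # only the OpenRouter path sets this to True
)
adapter = OpenAICompatibleAdapter(config=config, kiln_task=task)
```

Bundling the connection settings into one dataclass keeps the constructor signature stable as provider-specific knobs like `openrouter_style_reasoning` accumulate, and the old runtime `ValueError` guard goes away in favor of required dataclass fields.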
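
One behavioral note: `openrouter_style_reasoning` now gates both sides of the reasoning flow. On the request, it adds OpenRouter's `include_reasoning` extension via `extra_body`; on the response, the non-standard `reasoning` field (absent from the OpenAI spec, hence the pyright ignores) is copied into the run's intermediate outputs. A sketch of reading it back, assuming `RunOutput` exposes the `intermediate_outputs` dict that `_run` builds (that attribute name is an assumption here, and callers would normally go through `BaseAdapter`'s public entry point rather than `_run` directly):

```python
# Sketch only: exercising the internal _run path shown in this diff.
run_output = await adapter._run({"question": "What is 6 * 7?"})

# "reasoning" is present only when openrouter_style_reasoning=True and the
# model actually returned OpenRouter's `reasoning` field.
reasoning = (run_output.intermediate_outputs or {}).get("reasoning")
if reasoning:
    print("model reasoning:", reasoning)
```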