From a85036fa2d468f7e9678b95f54318643328fdcf5 Mon Sep 17 00:00:00 2001 From: scosman Date: Sat, 1 Feb 2025 11:17:07 -0500 Subject: [PATCH] Use OpenAICompatible adapter for openai compatible models. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also separate concepts of KilnModelProviders (what’s supported by a model) and adapter config/set (keys, urls, etc). --- .../core/kiln_ai/adapters/adapter_registry.py | 12 +++-- .../model_adapters/langchain_adapters.py | 5 +- .../openai_compatible_config.py | 11 +++++ .../model_adapters/openai_model_adapter.py | 14 ++---- libs/core/kiln_ai/adapters/provider_tools.py | 20 ++++++-- .../kiln_ai/adapters/test_provider_tools.py | 47 +++++++++++-------- 6 files changed, 69 insertions(+), 40 deletions(-) create mode 100644 libs/core/kiln_ai/adapters/model_adapters/openai_compatible_config.py diff --git a/libs/core/kiln_ai/adapters/adapter_registry.py b/libs/core/kiln_ai/adapters/adapter_registry.py index 30dd2c45..88884658 100644 --- a/libs/core/kiln_ai/adapters/adapter_registry.py +++ b/libs/core/kiln_ai/adapters/adapter_registry.py @@ -10,7 +10,7 @@ OpenAICompatibleConfig, ) from kiln_ai.adapters.prompt_builders import BasePromptBuilder -from kiln_ai.adapters.provider_tools import core_provider +from kiln_ai.adapters.provider_tools import core_provider, openai_compatible_config from kiln_ai.utils.config import Config @@ -54,9 +54,15 @@ def adapter_for_task( prompt_builder=prompt_builder, tags=tags, ) - # Use LangchainAdapter for the rest case ModelProviderName.openai_compatible: - pass + config = openai_compatible_config(model_name) + return OpenAICompatibleAdapter( + kiln_task=kiln_task, + config=config, + prompt_builder=prompt_builder, + tags=tags, + ) + # Use LangchainAdapter for the rest case ModelProviderName.groq: pass case ModelProviderName.amazon_bedrock: diff --git a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py 
b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py index 4ef01a3e..1d876d03 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py @@ -10,7 +10,6 @@ from langchain_fireworks import ChatFireworks from langchain_groq import ChatGroq from langchain_ollama import ChatOllama -from langchain_openai import ChatOpenAI from pydantic import BaseModel import kiln_ai.datamodel as datamodel @@ -256,8 +255,8 @@ async def langchain_model_from_provider( # We use the OpenAICompatibleAdapter for OpenAI raise ValueError("OpenAI is not supported in Langchain adapter") elif provider.name == ModelProviderName.openai_compatible: - # See provider_tools.py for how base_url, key and other parameters are set - return ChatOpenAI(**provider.provider_options) # type: ignore[arg-type] + # We use the OpenAICompatibleAdapter for OpenAI compatible + raise ValueError("OpenAI compatible is not supported in Langchain adapter") elif provider.name == ModelProviderName.groq: api_key = Config.shared().groq_api_key if api_key is None: diff --git a/libs/core/kiln_ai/adapters/model_adapters/openai_compatible_config.py b/libs/core/kiln_ai/adapters/model_adapters/openai_compatible_config.py new file mode 100644 index 00000000..9c743649 --- /dev/null +++ b/libs/core/kiln_ai/adapters/model_adapters/openai_compatible_config.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + + +@dataclass +class OpenAICompatibleConfig: + api_key: str + model_name: str + provider_name: str + base_url: str | None = None # Defaults to OpenAI + default_headers: dict[str, str] | None = None + openrouter_style_reasoning: bool = False diff --git a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py index 815e49ed..1be44edb 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py +++ 
b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any, Dict, NoReturn from openai import AsyncOpenAI @@ -17,19 +16,12 @@ BasePromptBuilder, RunOutput, ) +from kiln_ai.adapters.model_adapters.openai_compatible_config import ( + OpenAICompatibleConfig, +) from kiln_ai.adapters.parsers.json_parser import parse_json_string -@dataclass -class OpenAICompatibleConfig: - api_key: str - model_name: str - provider_name: str - base_url: str | None = None # Defaults to OpenAI - default_headers: dict[str, str] | None = None - openrouter_style_reasoning: bool = False - - class OpenAICompatibleAdapter(BaseAdapter): def __init__( self, diff --git a/libs/core/kiln_ai/adapters/provider_tools.py b/libs/core/kiln_ai/adapters/provider_tools.py index 821cd29d..31d08736 100644 --- a/libs/core/kiln_ai/adapters/provider_tools.py +++ b/libs/core/kiln_ai/adapters/provider_tools.py @@ -9,6 +9,9 @@ StructuredOutputMode, built_in_models, ) +from kiln_ai.adapters.model_adapters.openai_compatible_config import ( + OpenAICompatibleConfig, +) from kiln_ai.adapters.ollama_tools import ( get_ollama_connection, ) @@ -175,9 +178,9 @@ async def kiln_model_provider_from( ) -def openai_compatible_provider_model( +def openai_compatible_config( model_id: str, -) -> KilnModelProvider: +) -> OpenAICompatibleConfig: try: openai_provider_name, model_id = model_id.split("::") except Exception: @@ -201,12 +204,21 @@ def openai_compatible_provider_model( f"OpenAI compatible provider {openai_provider_name} has no base URL" ) + return OpenAICompatibleConfig( + api_key=api_key, + model_name=model_id, + provider_name=ModelProviderName.openai_compatible, + base_url=base_url, + ) + + +def openai_compatible_provider_model( + model_id: str, +) -> KilnModelProvider: return KilnModelProvider( name=ModelProviderName.openai_compatible, provider_options={ "model": model_id, - "api_key": api_key, - "openai_api_base": base_url, }, 
supports_structured_output=False, supports_data_gen=False, diff --git a/libs/core/kiln_ai/adapters/test_provider_tools.py b/libs/core/kiln_ai/adapters/test_provider_tools.py index b7229f89..7d95d402 100644 --- a/libs/core/kiln_ai/adapters/test_provider_tools.py +++ b/libs/core/kiln_ai/adapters/test_provider_tools.py @@ -17,6 +17,7 @@ finetune_provider_model, get_model_and_provider, kiln_model_provider_from, + openai_compatible_config, openai_compatible_provider_model, parse_custom_model_id, provider_enabled, @@ -578,6 +579,18 @@ def test_finetune_provider_model_structured_mode( assert provider.structured_output_mode == expected_mode +def test_openai_compatible_provider_config(mock_shared_config): + """Test successful creation of an OpenAI compatible provider""" + model_id = "test_provider::gpt-4" + + config = openai_compatible_config(model_id) + + assert config.provider_name == ModelProviderName.openai_compatible + assert config.model_name == "gpt-4" + assert config.api_key == "test-key" + assert config.base_url == "https://api.test.com" + + def test_openai_compatible_provider_model_success(mock_shared_config): """Test successful creation of an OpenAI compatible provider""" model_id = "test_provider::gpt-4" @@ -586,57 +599,53 @@ def test_openai_compatible_provider_model_success(mock_shared_config): assert provider.name == ModelProviderName.openai_compatible assert provider.provider_options == { - "model": "gpt-4", - "api_key": "test-key", - "openai_api_base": "https://api.test.com", + "model": model_id, } assert provider.supports_structured_output is False assert provider.supports_data_gen is False assert provider.untested_model is True -def test_openai_compatible_provider_model_no_api_key(mock_shared_config): +def test_openai_compatible_config_no_api_key(mock_shared_config): """Test provider creation without API key (should work as some providers don't require it)""" model_id = "no_key_provider::gpt-4" - provider = openai_compatible_provider_model(model_id) + 
config = openai_compatible_config(model_id) - assert provider.name == ModelProviderName.openai_compatible - assert provider.provider_options == { - "model": "gpt-4", - "api_key": None, - "openai_api_base": "https://api.nokey.com", - } + assert config.provider_name == ModelProviderName.openai_compatible + assert config.model_name == "gpt-4" + assert config.api_key is None + assert config.base_url == "https://api.nokey.com" -def test_openai_compatible_provider_model_invalid_id(): +def test_openai_compatible_config_invalid_id(): """Test handling of invalid model ID format""" with pytest.raises(ValueError) as exc_info: - openai_compatible_provider_model("invalid-id-format") + openai_compatible_config("invalid-id-format") assert ( str(exc_info.value) == "Invalid openai compatible model ID: invalid-id-format" ) -def test_openai_compatible_provider_model_no_providers(mock_shared_config): +def test_openai_compatible_config_no_providers(mock_shared_config): """Test handling when no providers are configured""" mock_shared_config.return_value.openai_compatible_providers = None with pytest.raises(ValueError) as exc_info: - openai_compatible_provider_model("test_provider::gpt-4") + openai_compatible_config("test_provider::gpt-4") assert str(exc_info.value) == "OpenAI compatible provider test_provider not found" -def test_openai_compatible_provider_model_provider_not_found(mock_shared_config): +def test_openai_compatible_config_provider_not_found(mock_shared_config): """Test handling of non-existent provider""" with pytest.raises(ValueError) as exc_info: - openai_compatible_provider_model("unknown_provider::gpt-4") + openai_compatible_config("unknown_provider::gpt-4") assert ( str(exc_info.value) == "OpenAI compatible provider unknown_provider not found" ) -def test_openai_compatible_provider_model_no_base_url(mock_shared_config): +def test_openai_compatible_config_no_base_url(mock_shared_config): """Test handling of provider without base URL""" 
mock_shared_config.return_value.openai_compatible_providers = [ { @@ -646,7 +655,7 @@ def test_openai_compatible_provider_model_no_base_url(mock_shared_config): ] with pytest.raises(ValueError) as exc_info: - openai_compatible_provider_model("test_provider::gpt-4") + openai_compatible_config("test_provider::gpt-4") assert ( str(exc_info.value) == "OpenAI compatible provider test_provider has no base URL"