Skip to content

Commit

Permalink
Use OpenAPICompatible adapter for openai compatible models.
Browse files Browse the repository at this point in the history
Also separates the concepts of KilnModelProvider (what a model supports) and adapter configuration (keys, URLs, etc.).
  • Loading branch information
scosman committed Feb 1, 2025
1 parent 55c6c2f commit a85036f
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 40 deletions.
12 changes: 9 additions & 3 deletions libs/core/kiln_ai/adapters/adapter_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
OpenAICompatibleConfig,
)
from kiln_ai.adapters.prompt_builders import BasePromptBuilder
from kiln_ai.adapters.provider_tools import core_provider
from kiln_ai.adapters.provider_tools import core_provider, openai_compatible_config
from kiln_ai.utils.config import Config


Expand Down Expand Up @@ -54,9 +54,15 @@ def adapter_for_task(
prompt_builder=prompt_builder,
tags=tags,
)
# Use LangchainAdapter for the rest
case ModelProviderName.openai_compatible:
pass
config = openai_compatible_config(model_name)
return OpenAICompatibleAdapter(
kiln_task=kiln_task,
config=config,
prompt_builder=prompt_builder,
tags=tags,
)
# Use LangchainAdapter for the rest
case ModelProviderName.groq:
pass
case ModelProviderName.amazon_bedrock:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from langchain_fireworks import ChatFireworks
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

import kiln_ai.datamodel as datamodel
Expand Down Expand Up @@ -256,8 +255,8 @@ async def langchain_model_from_provider(
# We use the OpenAICompatibleAdapter for OpenAI
raise ValueError("OpenAI is not supported in Langchain adapter")
elif provider.name == ModelProviderName.openai_compatible:
# See provider_tools.py for how base_url, key and other parameters are set
return ChatOpenAI(**provider.provider_options) # type: ignore[arg-type]
# We use the OpenAICompatibleAdapter for OpenAI compatible
raise ValueError("OpenAI compatible is not supported in Langchain adapter")
elif provider.name == ModelProviderName.groq:
api_key = Config.shared().groq_api_key
if api_key is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from dataclasses import dataclass


@dataclass
class OpenAICompatibleConfig:
    """Connection/credential settings for an OpenAI-compatible chat endpoint.

    This intentionally carries only adapter configuration (key, URL, headers),
    kept separate from KilnModelProvider, which describes what a model
    supports.
    """

    # API key sent to the endpoint; presumably some providers accept a
    # placeholder/empty key — verify against callers before relying on that.
    api_key: str
    # Model identifier as understood by the target endpoint.
    model_name: str
    # Name of the provider this config was built for.
    provider_name: str
    # Endpoint base URL; None falls back to the OpenAI default.
    base_url: str | None = None  # Defaults to OpenAI
    # Extra HTTP headers to attach to every request (e.g. attribution headers).
    default_headers: dict[str, str] | None = None
    # When True, request/parse reasoning output the way OpenRouter formats it.
    openrouter_style_reasoning: bool = False
14 changes: 3 additions & 11 deletions libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import dataclass
from typing import Any, Dict, NoReturn

from openai import AsyncOpenAI
Expand All @@ -17,19 +16,12 @@
BasePromptBuilder,
RunOutput,
)
from kiln_ai.adapters.model_adapters.openai_compatible_config import (
OpenAICompatibleConfig,
)
from kiln_ai.adapters.parsers.json_parser import parse_json_string


@dataclass
class OpenAICompatibleConfig:
    """Connection/credential settings for an OpenAI-compatible chat endpoint."""

    # API key sent to the endpoint; presumably optional for some providers —
    # TODO confirm against callers.
    api_key: str
    # Model identifier as understood by the target endpoint.
    model_name: str
    # Name of the provider this config was built for.
    provider_name: str
    # Endpoint base URL; None falls back to the OpenAI default.
    base_url: str | None = None  # Defaults to OpenAI
    # Extra HTTP headers to attach to every request.
    default_headers: dict[str, str] | None = None
    # When True, handle reasoning output the way OpenRouter formats it.
    openrouter_style_reasoning: bool = False


class OpenAICompatibleAdapter(BaseAdapter):
def __init__(
self,
Expand Down
20 changes: 16 additions & 4 deletions libs/core/kiln_ai/adapters/provider_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
StructuredOutputMode,
built_in_models,
)
from kiln_ai.adapters.model_adapters.openai_compatible_config import (
OpenAICompatibleConfig,
)
from kiln_ai.adapters.ollama_tools import (
get_ollama_connection,
)
Expand Down Expand Up @@ -175,9 +178,9 @@ async def kiln_model_provider_from(
)


def openai_compatible_provider_model(
def openai_compatible_config(
model_id: str,
) -> KilnModelProvider:
) -> OpenAICompatibleConfig:
try:
openai_provider_name, model_id = model_id.split("::")
except Exception:
Expand All @@ -201,12 +204,21 @@ def openai_compatible_provider_model(
f"OpenAI compatible provider {openai_provider_name} has no base URL"
)

return OpenAICompatibleConfig(
api_key=api_key,
model_name=model_id,
provider_name=ModelProviderName.openai_compatible,
base_url=base_url,
)


def openai_compatible_provider_model(
model_id: str,
) -> KilnModelProvider:
return KilnModelProvider(
name=ModelProviderName.openai_compatible,
provider_options={
"model": model_id,
"api_key": api_key,
"openai_api_base": base_url,
},
supports_structured_output=False,
supports_data_gen=False,
Expand Down
47 changes: 28 additions & 19 deletions libs/core/kiln_ai/adapters/test_provider_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
finetune_provider_model,
get_model_and_provider,
kiln_model_provider_from,
openai_compatible_config,
openai_compatible_provider_model,
parse_custom_model_id,
provider_enabled,
Expand Down Expand Up @@ -578,6 +579,18 @@ def test_finetune_provider_model_structured_mode(
assert provider.structured_output_mode == expected_mode


def test_openai_compatible_provider_config(mock_shared_config):
    """Test successful creation of an OpenAICompatibleConfig from a compound model ID.

    The ID format is "<provider_name>::<model_name>"; the provider's key and
    base URL come from the mock_shared_config fixture.
    """
    model_id = "test_provider::gpt-4"

    config = openai_compatible_config(model_id)

    # provider_name is the generic openai_compatible enum value, not the
    # user-configured provider label from the ID.
    assert config.provider_name == ModelProviderName.openai_compatible
    # model_name is the portion after "::", with the provider prefix stripped.
    assert config.model_name == "gpt-4"
    # Key and URL are looked up from the configured provider entry
    # (values supplied by the mock_shared_config fixture).
    assert config.api_key == "test-key"
    assert config.base_url == "https://api.test.com"


def test_openai_compatible_provider_model_success(mock_shared_config):
"""Test successful creation of an OpenAI compatible provider"""
model_id = "test_provider::gpt-4"
Expand All @@ -586,57 +599,53 @@ def test_openai_compatible_provider_model_success(mock_shared_config):

assert provider.name == ModelProviderName.openai_compatible
assert provider.provider_options == {
"model": "gpt-4",
"api_key": "test-key",
"openai_api_base": "https://api.test.com",
"model": model_id,
}
assert provider.supports_structured_output is False
assert provider.supports_data_gen is False
assert provider.untested_model is True


def test_openai_compatible_provider_model_no_api_key(mock_shared_config):
def test_openai_compatible_config_no_api_key(mock_shared_config):
"""Test provider creation without API key (should work as some providers don't require it)"""
model_id = "no_key_provider::gpt-4"

provider = openai_compatible_provider_model(model_id)
config = openai_compatible_config(model_id)

assert provider.name == ModelProviderName.openai_compatible
assert provider.provider_options == {
"model": "gpt-4",
"api_key": None,
"openai_api_base": "https://api.nokey.com",
}
assert config.provider_name == ModelProviderName.openai_compatible
assert config.model_name == "gpt-4"
assert config.api_key is None
assert config.base_url == "https://api.nokey.com"


def test_openai_compatible_provider_model_invalid_id():
def test_openai_compatible_config_invalid_id():
"""Test handling of invalid model ID format"""
with pytest.raises(ValueError) as exc_info:
openai_compatible_provider_model("invalid-id-format")
openai_compatible_config("invalid-id-format")
assert (
str(exc_info.value) == "Invalid openai compatible model ID: invalid-id-format"
)


def test_openai_compatible_provider_model_no_providers(mock_shared_config):
def test_openai_compatible_config_no_providers(mock_shared_config):
"""Test handling when no providers are configured"""
mock_shared_config.return_value.openai_compatible_providers = None

with pytest.raises(ValueError) as exc_info:
openai_compatible_provider_model("test_provider::gpt-4")
openai_compatible_config("test_provider::gpt-4")
assert str(exc_info.value) == "OpenAI compatible provider test_provider not found"


def test_openai_compatible_provider_model_provider_not_found(mock_shared_config):
def test_openai_compatible_config_provider_not_found(mock_shared_config):
"""Test handling of non-existent provider"""
with pytest.raises(ValueError) as exc_info:
openai_compatible_provider_model("unknown_provider::gpt-4")
openai_compatible_config("unknown_provider::gpt-4")
assert (
str(exc_info.value) == "OpenAI compatible provider unknown_provider not found"
)


def test_openai_compatible_provider_model_no_base_url(mock_shared_config):
def test_openai_compatible_config_no_base_url(mock_shared_config):
"""Test handling of provider without base URL"""
mock_shared_config.return_value.openai_compatible_providers = [
{
Expand All @@ -646,7 +655,7 @@ def test_openai_compatible_provider_model_no_base_url(mock_shared_config):
]

with pytest.raises(ValueError) as exc_info:
openai_compatible_provider_model("test_provider::gpt-4")
openai_compatible_config("test_provider::gpt-4")
assert (
str(exc_info.value)
== "OpenAI compatible provider test_provider has no base URL"
Expand Down

0 comments on commit a85036f

Please sign in to comment.