Skip to content

Commit

Permalink
Use openai adapter for... openai
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Jan 31, 2025
1 parent 5d00cf5 commit 51b100e
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 47 deletions.
7 changes: 4 additions & 3 deletions app/desktop/studio_server/data_gen_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name
from kiln_ai.datamodel import DataSource, DataSourceType, TaskRun
from kiln_server.run_api import model_provider_from_string
from kiln_server.task_api import task_from_id
from pydantic import BaseModel, ConfigDict, Field

Expand Down Expand Up @@ -83,7 +84,7 @@ async def generate_categories(
adapter = adapter_for_task(
categories_task,
model_name=input.model_name,
provider=input.provider,
provider=model_provider_from_string(input.provider),
)

categories_run = await adapter.invoke(task_input.model_dump())
Expand All @@ -106,7 +107,7 @@ async def generate_samples(
adapter = adapter_for_task(
sample_task,
model_name=input.model_name,
provider=input.provider,
provider=model_provider_from_string(input.provider),
)

samples_run = await adapter.invoke(task_input.model_dump())
Expand All @@ -130,7 +131,7 @@ async def save_sample(
adapter = adapter_for_task(
task,
model_name=sample.output_model_name,
provider=sample.output_provider,
provider=model_provider_from_string(sample.output_provider),
prompt_builder=prompt_builder,
tags=tags,
)
Expand Down
6 changes: 4 additions & 2 deletions app/desktop/studio_server/repair_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.repair.repair_task import RepairTaskRun
from kiln_ai.datamodel import TaskRun
from kiln_server.run_api import task_and_run_from_id
from kiln_server.run_api import model_provider_from_string, task_and_run_from_id
from pydantic import BaseModel, ConfigDict, Field


Expand Down Expand Up @@ -60,7 +60,9 @@ async def run_repair(
)

adapter = adapter_for_task(
repair_task, model_name=model_name, provider=provider
repair_task,
model_name=model_name,
provider=model_provider_from_string(provider),
)

repair_run = await adapter.invoke(repair_task_input.model_dump())
Expand Down
70 changes: 51 additions & 19 deletions libs/core/kiln_ai/adapters/adapter_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from os import getenv
from typing import NoReturn

from kiln_ai import datamodel
from kiln_ai.adapters.ml_model_list import ModelProviderName
Expand All @@ -15,28 +16,59 @@
def adapter_for_task(
kiln_task: datamodel.Task,
model_name: str,
provider: str | None = None,
provider: ModelProviderName,
prompt_builder: BasePromptBuilder | None = None,
tags: list[str] | None = None,
) -> BaseAdapter:
if provider == ModelProviderName.openrouter:
return OpenAICompatibleAdapter(
kiln_task=kiln_task,
config=OpenAICompatibleConfig(
base_url=getenv("OPENROUTER_BASE_URL")
or "https://openrouter.ai/api/v1",
api_key=Config.shared().open_router_api_key,
model_name=model_name,
provider_name=provider,
openrouter_style_reasoning=True,
default_headers={
"HTTP-Referer": "https://getkiln.ai/openrouter",
"X-Title": "KilnAI",
},
),
prompt_builder=prompt_builder,
tags=tags,
)
match provider:
case ModelProviderName.openrouter:
return OpenAICompatibleAdapter(
kiln_task=kiln_task,
config=OpenAICompatibleConfig(
base_url=getenv("OPENROUTER_BASE_URL")
or "https://openrouter.ai/api/v1",
api_key=Config.shared().open_router_api_key,
model_name=model_name,
provider_name=provider,
openrouter_style_reasoning=True,
default_headers={
"HTTP-Referer": "https://getkiln.ai/openrouter",
"X-Title": "KilnAI",
},
),
prompt_builder=prompt_builder,
tags=tags,
)
case ModelProviderName.openai:
return OpenAICompatibleAdapter(
kiln_task=kiln_task,
config=OpenAICompatibleConfig(
api_key=Config.shared().open_ai_api_key,
model_name=model_name,
provider_name=provider,
),
prompt_builder=prompt_builder,
tags=tags,
)
# Use LangchainAdapter for the rest
case ModelProviderName.openai_compatible:
pass
case ModelProviderName.groq:
pass
case ModelProviderName.amazon_bedrock:
pass
case ModelProviderName.ollama:
pass
case ModelProviderName.fireworks_ai:
pass
case ModelProviderName.kiln_fine_tune:
pass
case ModelProviderName.kiln_custom_registry:
pass
case _:
raise ValueError(f"Unsupported provider: {provider}")
# Triggers typechecking if I miss a case
return NoReturn

# We use langchain for all others right now, but moving off it as we touch anything.
return LangchainAdapter(
Expand Down
2 changes: 1 addition & 1 deletion libs/core/kiln_ai/adapters/ml_model_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class KilnModel(BaseModel):
),
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "openai/gpt-4o-2024-08-06"},
provider_options={"model": "openai/gpt-4o"},
structured_output_mode=StructuredOutputMode.json_schema,
),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ async def langchain_model_from_provider(
provider: KilnModelProvider, model_name: str
) -> BaseChatModel:
if provider.name == ModelProviderName.openai:
api_key = Config.shared().open_ai_api_key
return ChatOpenAI(**provider.provider_options, openai_api_key=api_key) # type: ignore[arg-type]
# We use the OpenAICompatibleAdapter for OpenAI
raise ValueError("OpenAI is not supported in Langchain adapter")
elif provider.name == ModelProviderName.openai_compatible:
# See provider_tools.py for how base_url, key and other parameters are set
return ChatOpenAI(**provider.provider_options) # type: ignore[arg-type]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class OpenAICompatibleConfig:
api_key: str
model_name: str
provider_name: str
base_url: str | None = None
base_url: str | None = None # Defaults to OpenAI
default_headers: dict[str, str] | None = None
openrouter_style_reasoning: bool = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,21 +162,6 @@ async def test_get_structured_output_options(
assert options.get("method") == expected_method


@pytest.mark.asyncio
async def test_langchain_model_from_provider_openai():
provider = KilnModelProvider(
name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
)

with patch(
"kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
) as mock_config:
mock_config.return_value.open_ai_api_key = "test_key"
model = await langchain_model_from_provider(provider, "gpt-4")
assert isinstance(model, ChatOpenAI)
assert model.model_name == "gpt-4"


@pytest.mark.asyncio
async def test_langchain_model_from_provider_groq():
provider = KilnModelProvider(
Expand Down
9 changes: 8 additions & 1 deletion libs/server/kiln_server/run_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from fastapi import FastAPI, HTTPException
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name
from kiln_ai.datamodel import Task, TaskOutputRating, TaskOutputRatingType, TaskRun
from kiln_ai.datamodel.basemodel import ID_TYPE
Expand Down Expand Up @@ -199,7 +200,7 @@ async def run_task(
adapter = adapter_for_task(
task,
model_name=request.model_name,
provider=request.provider,
provider=model_provider_from_string(request.provider),
prompt_builder=prompt_builder,
tags=request.tags,
)
Expand Down Expand Up @@ -281,3 +282,9 @@ async def update_run_util(
updated_run.path = run.path
updated_run.save_to_file()
return updated_run


def model_provider_from_string(provider: str) -> ModelProviderName:
    """Parse a raw provider string into a ModelProviderName enum member.

    Raises:
        ValueError: if the string is empty or does not name a known provider.
    """
    # NOTE(review): membership is checked by enum *name* but construction is
    # by *value* — these agree only while names and values match; confirm if
    # ModelProviderName ever diverges.
    if provider in ModelProviderName.__members__:
        return ModelProviderName(provider)
    raise ValueError(f"Unsupported provider: {provider}")
16 changes: 13 additions & 3 deletions libs/server/kiln_server/test_run_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.model_adapters.langchain_adapters import LangchainAdapter
from kiln_ai.datamodel import (
DataSource,
Expand All @@ -20,6 +21,7 @@
RunSummary,
connect_run_api,
deep_update,
model_provider_from_string,
run_from_id,
)

Expand Down Expand Up @@ -64,7 +66,7 @@ def task_run_setup(tmp_path):

run_task_request = {
"model_name": "gpt_4o",
"provider": "openai",
"provider": "ollama",
"plaintext_input": "Test input",
}

Expand All @@ -80,7 +82,7 @@ def task_run_setup(tmp_path):
type=DataSourceType.synthetic,
properties={
"model_name": "gpt_4o",
"model_provider": "openai",
"model_provider": "ollama",
"adapter_name": "kiln_langchain_adapter",
"prompt_builder_name": "simple_prompt_builder",
},
Expand Down Expand Up @@ -188,7 +190,7 @@ async def test_run_task_structured_input(client, task_run_setup):
):
run_task_request = {
"model_name": "gpt_4o",
"provider": "openai",
"provider": "ollama",
"structured_input": {"key": "value"},
}

Expand Down Expand Up @@ -1215,3 +1217,11 @@ async def test_remove_tags_multiple_runs(client, task_run_setup):
updated_run2 = TaskRun.from_id_and_parent_path(second_run.id, task.path)
assert set(updated_run1.tags) == {"tag2"}
assert set(updated_run2.tags) == {"tag3"}


def test_model_provider_from_string():
    # Known provider strings round-trip to their enum members.
    assert model_provider_from_string("ollama") == ModelProviderName.ollama
    assert model_provider_from_string("openai") == ModelProviderName.openai

    # Unknown strings are rejected with a descriptive error message.
    with pytest.raises(ValueError, match="Unsupported provider: unknown"):
        model_provider_from_string("unknown")

0 comments on commit 51b100e

Please sign in to comment.