diff --git a/app/desktop/studio_server/data_gen_api.py b/app/desktop/studio_server/data_gen_api.py
index d334102..a4f0531 100644
--- a/app/desktop/studio_server/data_gen_api.py
+++ b/app/desktop/studio_server/data_gen_api.py
@@ -8,6 +8,7 @@
 )
 from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name
 from kiln_ai.datamodel import DataSource, DataSourceType, TaskRun
+from kiln_server.run_api import model_provider_from_string
 from kiln_server.task_api import task_from_id
 from pydantic import BaseModel, ConfigDict, Field

@@ -83,7 +84,7 @@ async def generate_categories(
     adapter = adapter_for_task(
         categories_task,
         model_name=input.model_name,
-        provider=input.provider,
+        provider=model_provider_from_string(input.provider),
     )

     categories_run = await adapter.invoke(task_input.model_dump())
@@ -106,7 +107,7 @@ async def generate_samples(
     adapter = adapter_for_task(
         sample_task,
         model_name=input.model_name,
-        provider=input.provider,
+        provider=model_provider_from_string(input.provider),
     )

     samples_run = await adapter.invoke(task_input.model_dump())
@@ -130,7 +131,7 @@ async def save_sample(
     adapter = adapter_for_task(
         task,
         model_name=sample.output_model_name,
-        provider=sample.output_provider,
+        provider=model_provider_from_string(sample.output_provider),
         prompt_builder=prompt_builder,
         tags=tags,
     )
diff --git a/app/desktop/studio_server/repair_api.py b/app/desktop/studio_server/repair_api.py
index 007a1aa..2b600f5 100644
--- a/app/desktop/studio_server/repair_api.py
+++ b/app/desktop/studio_server/repair_api.py
@@ -2,7 +2,7 @@
 from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.repair.repair_task import RepairTaskRun
 from kiln_ai.datamodel import TaskRun
-from kiln_server.run_api import task_and_run_from_id
+from kiln_server.run_api import model_provider_from_string, task_and_run_from_id
 from pydantic import BaseModel, ConfigDict, Field

@@ -60,7 +60,9 @@ async def run_repair(
     )

     adapter = adapter_for_task(
-        repair_task, model_name=model_name, provider=provider
+        repair_task,
+        model_name=model_name,
+        provider=model_provider_from_string(provider),
     )

     repair_run = await adapter.invoke(repair_task_input.model_dump())
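
The endpoints above all convert the request's provider string at the API boundary, because adapter_for_task (next hunk) now takes a ModelProviderName enum instead of an optional string. A minimal sketch of the new calling convention, not code from this diff: `task` is assumed to be a loaded kiln_ai.datamodel.Task, and model_provider_from_string is the helper added to kiln_server/run_api.py later in this patch.

# Sketch of the new calling convention; `task` is an assumed, already-loaded Task.
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_server.run_api import model_provider_from_string

adapter = adapter_for_task(
    task,
    model_name="gpt_4o",
    provider=model_provider_from_string("openrouter"),  # raises ValueError on unknown names
)
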
"https://openrouter.ai/api/v1", + api_key=Config.shared().open_router_api_key, + model_name=model_name, + provider_name=provider, + openrouter_style_reasoning=True, + default_headers={ + "HTTP-Referer": "https://getkiln.ai/openrouter", + "X-Title": "KilnAI", + }, + ), + prompt_builder=prompt_builder, + tags=tags, + ) + case ModelProviderName.openai: + return OpenAICompatibleAdapter( + kiln_task=kiln_task, + config=OpenAICompatibleConfig( + api_key=Config.shared().open_ai_api_key, + model_name=model_name, + provider_name=provider, + ), + prompt_builder=prompt_builder, + tags=tags, + ) + # Use LangchainAdapter for the rest + case ModelProviderName.openai_compatible: + pass + case ModelProviderName.groq: + pass + case ModelProviderName.amazon_bedrock: + pass + case ModelProviderName.ollama: + pass + case ModelProviderName.fireworks_ai: + pass + case ModelProviderName.kiln_fine_tune: + pass + case ModelProviderName.kiln_custom_registry: + pass + case _: + raise ValueError(f"Unsupported provider: {provider}") + # Triggers typechecking if I miss a case + return NoReturn # We use langchain for all others right now, but moving off it as we touch anything. return LangchainAdapter( diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py index 13fab1e..c64840a 100644 --- a/libs/core/kiln_ai/adapters/ml_model_list.py +++ b/libs/core/kiln_ai/adapters/ml_model_list.py @@ -167,7 +167,7 @@ class KilnModel(BaseModel): ), KilnModelProvider( name=ModelProviderName.openrouter, - provider_options={"model": "openai/gpt-4o-2024-08-06"}, + provider_options={"model": "openai/gpt-4o"}, structured_output_mode=StructuredOutputMode.json_schema, ), ], diff --git a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py index 3fc7c69..01f1420 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py @@ -246,8 +246,8 @@ async def langchain_model_from_provider( provider: KilnModelProvider, model_name: str ) -> BaseChatModel: if provider.name == ModelProviderName.openai: - api_key = Config.shared().open_ai_api_key - return ChatOpenAI(**provider.provider_options, openai_api_key=api_key) # type: ignore[arg-type] + # We use the OpenAICompatibleAdapter for OpenAI + raise ValueError("OpenAI is not supported in Langchain adapter") elif provider.name == ModelProviderName.openai_compatible: # See provider_tools.py for how base_url, key and other parameters are set return ChatOpenAI(**provider.provider_options) # type: ignore[arg-type] diff --git a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py index d6a2466..32d3e78 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py @@ -25,7 +25,7 @@ class OpenAICompatibleConfig: api_key: str model_name: str provider_name: str - base_url: str | None = None + base_url: str | None = None # Defaults to OpenAI default_headers: dict[str, str] | None = None openrouter_style_reasoning: bool = False diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py index 0adc909..f08c0e7 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py +++ 
diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py
index 13fab1e..c64840a 100644
--- a/libs/core/kiln_ai/adapters/ml_model_list.py
+++ b/libs/core/kiln_ai/adapters/ml_model_list.py
@@ -167,7 +167,7 @@ class KilnModel(BaseModel):
         ),
         KilnModelProvider(
             name=ModelProviderName.openrouter,
-            provider_options={"model": "openai/gpt-4o-2024-08-06"},
+            provider_options={"model": "openai/gpt-4o"},
             structured_output_mode=StructuredOutputMode.json_schema,
         ),
     ],
diff --git a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py
index 3fc7c69..01f1420 100644
--- a/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py
+++ b/libs/core/kiln_ai/adapters/model_adapters/langchain_adapters.py
@@ -246,8 +246,8 @@ async def langchain_model_from_provider(
     provider: KilnModelProvider, model_name: str
 ) -> BaseChatModel:
     if provider.name == ModelProviderName.openai:
-        api_key = Config.shared().open_ai_api_key
-        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
+        # We use the OpenAICompatibleAdapter for OpenAI
+        raise ValueError("OpenAI is not supported in Langchain adapter")
     elif provider.name == ModelProviderName.openai_compatible:
         # See provider_tools.py for how base_url, key and other parameters are set
         return ChatOpenAI(**provider.provider_options)  # type: ignore[arg-type]
diff --git a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
index d6a2466..32d3e78 100644
--- a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
+++ b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
@@ -25,7 +25,7 @@ class OpenAICompatibleConfig:
     api_key: str
     model_name: str
     provider_name: str
-    base_url: str | None = None
+    base_url: str | None = None  # Defaults to OpenAI
     default_headers: dict[str, str] | None = None
     openrouter_style_reasoning: bool = False
diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py
index 0adc909..f08c0e7 100644
--- a/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py
+++ b/libs/core/kiln_ai/adapters/model_adapters/test_langchain_adapter.py
@@ -162,21 +162,6 @@ async def test_get_structured_output_options(
     assert options.get("method") == expected_method


-@pytest.mark.asyncio
-async def test_langchain_model_from_provider_openai():
-    provider = KilnModelProvider(
-        name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
-    )
-
-    with patch(
-        "kiln_ai.adapters.model_adapters.langchain_adapters.Config.shared"
-    ) as mock_config:
-        mock_config.return_value.open_ai_api_key = "test_key"
-        model = await langchain_model_from_provider(provider, "gpt-4")
-        assert isinstance(model, ChatOpenAI)
-        assert model.model_name == "gpt-4"
-
-
 @pytest.mark.asyncio
 async def test_langchain_model_from_provider_groq():
     provider = KilnModelProvider(
diff --git a/libs/server/kiln_server/run_api.py b/libs/server/kiln_server/run_api.py
index 1be19f8..bd43c15 100644
--- a/libs/server/kiln_server/run_api.py
+++ b/libs/server/kiln_server/run_api.py
@@ -4,6 +4,7 @@
 from fastapi import FastAPI, HTTPException
 from kiln_ai.adapters.adapter_registry import adapter_for_task
+from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name
 from kiln_ai.datamodel import Task, TaskOutputRating, TaskOutputRatingType, TaskRun
 from kiln_ai.datamodel.basemodel import ID_TYPE
@@ -199,7 +200,7 @@ async def run_task(
     adapter = adapter_for_task(
         task,
         model_name=request.model_name,
-        provider=request.provider,
+        provider=model_provider_from_string(request.provider),
         prompt_builder=prompt_builder,
         tags=request.tags,
     )
@@ -281,3 +282,9 @@ async def update_run_util(
     updated_run.path = run.path
     updated_run.save_to_file()
     return updated_run
+
+
+def model_provider_from_string(provider: str) -> ModelProviderName:
+    if not provider or provider not in ModelProviderName.__members__:
+        raise ValueError(f"Unsupported provider: {provider}")
+    return ModelProviderName(provider)
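
One consequence of raising a bare ValueError in model_provider_from_string is that an unknown provider surfaces as a 500 from FastAPI rather than a client error. If the endpoints wanted to report it as a request problem, a wrapper along these lines could translate it; the 422 status, detail text, and helper name are assumptions for illustration, not behavior this diff adds.

# Hypothetical wrapper, assumed to live alongside model_provider_from_string
# in run_api.py (which already imports HTTPException and ModelProviderName).
def model_provider_or_http_error(provider: str) -> ModelProviderName:
    try:
        return model_provider_from_string(provider)
    except ValueError as e:
        # Report the bad provider name as a client error instead of a 500.
        raise HTTPException(status_code=422, detail=str(e)) from e
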
model_provider_from_string("openai") == ModelProviderName.openai + assert model_provider_from_string("ollama") == ModelProviderName.ollama + + with pytest.raises(ValueError, match="Unsupported provider: unknown"): + model_provider_from_string("unknown")