From 55c6c2fa9be994098402e37483f558023fa6a3a2 Mon Sep 17 00:00:00 2001
From: scosman
Date: Sat, 1 Feb 2025 09:59:06 -0500
Subject: [PATCH] Make custom models and fine-tunes work in our new adapter
 registry. Now they use the correct adapter (OpenAI for OpenAI fine-tunes or
 custom models, etc). Improve fine-tune cache.

---
 .../core/kiln_ai/adapters/adapter_registry.py |   6 +-
 .../model_adapters/openai_model_adapter.py    |  10 +-
 libs/core/kiln_ai/adapters/provider_tools.py  |  77 +++++-
 .../kiln_ai/adapters/test_provider_tools.py   | 229 +++++++++++++++++-
 libs/core/kiln_ai/datamodel/model_cache.py    |   2 +-
 5 files changed, 297 insertions(+), 27 deletions(-)

diff --git a/libs/core/kiln_ai/adapters/adapter_registry.py b/libs/core/kiln_ai/adapters/adapter_registry.py
index c4a67c6b..30dd2c45 100644
--- a/libs/core/kiln_ai/adapters/adapter_registry.py
+++ b/libs/core/kiln_ai/adapters/adapter_registry.py
@@ -10,6 +10,7 @@
     OpenAICompatibleConfig,
 )
 from kiln_ai.adapters.prompt_builders import BasePromptBuilder
+from kiln_ai.adapters.provider_tools import core_provider
 from kiln_ai.utils.config import Config
 
 
@@ -20,7 +21,10 @@ def adapter_for_task(
     prompt_builder: BasePromptBuilder | None = None,
     tags: list[str] | None = None,
 ) -> BaseAdapter:
-    match provider:
+    # Get the provider to run. For things like the fine-tune provider, we want to run the underlying provider
+    core_provider_name = core_provider(model_name, provider)
+
+    match core_provider_name:
         case ModelProviderName.openrouter:
             return OpenAICompatibleAdapter(
                 kiln_task=kiln_task,
diff --git a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
index 66dadc16..815e49ed 100644
--- a/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
+++ b/libs/core/kiln_ai/adapters/model_adapters/openai_model_adapter.py
@@ -218,13 +218,21 @@ async def response_format_options(self) -> dict[str, Any]:
         return NoReturn
 
     def tool_call_params(self) -> dict[str, Any]:
+        # Add additionalProperties: false to the schema (OpenAI requires this for some models)
+        output_schema = self.kiln_task.output_schema()
+        if not isinstance(output_schema, dict):
+            raise ValueError(
+                "Invalid output schema for this task. Cannot use tool calls."
+            )
+        output_schema["additionalProperties"] = False
+
         return {
             "tools": [
                 {
                     "type": "function",
                     "function": {
                         "name": "task_response",
-                        "parameters": self.kiln_task.output_schema(),
+                        "parameters": output_schema,
                         "strict": True,
                     },
                 }
diff --git a/libs/core/kiln_ai/adapters/provider_tools.py b/libs/core/kiln_ai/adapters/provider_tools.py
index 05bf6c89..821cd29d 100644
--- a/libs/core/kiln_ai/adapters/provider_tools.py
+++ b/libs/core/kiln_ai/adapters/provider_tools.py
@@ -6,6 +6,7 @@
     KilnModelProvider,
     ModelName,
     ModelProviderName,
+    StructuredOutputMode,
     built_in_models,
 )
 from kiln_ai.adapters.ollama_tools import (
@@ -13,8 +14,7 @@
 )
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.registry import project_from_id
-
-from ..utils.config import Config
+from kiln_ai.utils.config import Config
 
 
 async def provider_enabled(provider_name: ModelProviderName) -> bool:
@@ -102,6 +102,46 @@ async def builtin_model_from(
     return provider
 
 
+def core_provider(model_id: str, provider_name: ModelProviderName) -> ModelProviderName:
+    """
+    Get the provider that should be run.
+
+    Some provider IDs are wrappers (fine-tunes, custom models). This maps these to runnable providers (openai, ollama, etc.).
+    """
+
+    # Custom models map to the underlying provider
+    if provider_name is ModelProviderName.kiln_custom_registry:
+        provider_name, _ = parse_custom_model_id(model_id)
+        return provider_name
+
+    # Fine-tune provider maps to an underlying provider
+    if provider_name is ModelProviderName.kiln_fine_tune:
+        finetune = finetune_from_id(model_id)
+        if finetune.provider not in ModelProviderName.__members__:
+            raise ValueError(
+                f"Finetune {model_id} has no underlying provider {finetune.provider}"
+            )
+        return ModelProviderName(finetune.provider)
+
+    return provider_name
+
+
+def parse_custom_model_id(
+    model_id: str,
+) -> tuple[ModelProviderName, str]:
+    if "::" not in model_id:
+        raise ValueError(f"Invalid custom model ID: {model_id}")
+
+    # For custom registry, get the provider name and model name from the model id
+    provider_name = model_id.split("::", 1)[0]
+    model_name = model_id.split("::", 1)[1]
+
+    if provider_name not in ModelProviderName.__members__:
+        raise ValueError(f"Invalid provider name: {provider_name}")
+
+    return ModelProviderName(provider_name), model_name
+
+
 async def kiln_model_provider_from(
     name: str, provider_name: str | None = None
 ) -> KilnModelProvider:
@@ -117,8 +157,7 @@ async def kiln_model_provider_from(
 
     # For custom registry, get the provider name and model name from the model id
     if provider_name == ModelProviderName.kiln_custom_registry:
-        provider_name = name.split("::", 1)[0]
-        name = name.split("::", 1)[1]
+        provider_name, name = parse_custom_model_id(name)
 
     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
@@ -136,9 +175,6 @@
     )
 
 
-finetune_cache: dict[str, KilnModelProvider] = {}
-
-
 def openai_compatible_provider_model(
     model_id: str,
 ) -> KilnModelProvider:
@@ -178,9 +214,10 @@ def openai_compatible_provider_model(
     )
 
 
-def finetune_provider_model(
-    model_id: str,
-) -> KilnModelProvider:
+finetune_cache: dict[str, Finetune] = {}
+
+
+def finetune_from_id(model_id: str) -> Finetune:
     if model_id in finetune_cache:
         return finetune_cache[model_id]
 
@@ -202,6 +239,15 @@
             f"Fine tune {fine_tune_id} not completed. Refresh it's status in the fine-tune tab."
         )
 
+    finetune_cache[model_id] = fine_tune
+    return fine_tune
+
+
+def finetune_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
+    fine_tune = finetune_from_id(model_id)
+
     provider = ModelProviderName[fine_tune.provider]
     model_provider = KilnModelProvider(
         name=provider,
@@ -210,11 +256,18 @@
         },
     )
 
-    # If we know the model was trained with specific output mode, set it
     if fine_tune.structured_output_mode is not None:
+        # If we know the model was trained with specific output mode, set it
         model_provider.structured_output_mode = fine_tune.structured_output_mode
+    else:
+        # Some early adopters won't have structured_output_mode set on their fine-tunes.
+        # We know that OpenAI uses json_schema, and Fireworks (the only other provider) uses json_mode.
+        # This can be removed in the future
+        if provider == ModelProviderName.openai:
+            model_provider.structured_output_mode = StructuredOutputMode.json_schema
+        else:
+            model_provider.structured_output_mode = StructuredOutputMode.json_mode
 
-    finetune_cache[model_id] = model_provider
     return model_provider
 
 
diff --git a/libs/core/kiln_ai/adapters/test_provider_tools.py b/libs/core/kiln_ai/adapters/test_provider_tools.py
index 4024db4d..b7229f89 100644
--- a/libs/core/kiln_ai/adapters/test_provider_tools.py
+++ b/libs/core/kiln_ai/adapters/test_provider_tools.py
@@ -11,11 +11,14 @@
 from kiln_ai.adapters.provider_tools import (
     builtin_model_from,
     check_provider_warnings,
+    core_provider,
     finetune_cache,
+    finetune_from_id,
     finetune_provider_model,
     get_model_and_provider,
     kiln_model_provider_from,
     openai_compatible_provider_model,
+    parse_custom_model_id,
     provider_enabled,
     provider_name_from_id,
     provider_options_for_custom_model,
@@ -478,10 +481,6 @@ def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune)
     assert provider.provider_options == {"model": "ft:gpt-3.5-turbo:custom:model-123"}
     assert provider.structured_output_mode == StructuredOutputMode.json_schema
 
-    # Test cache
-    cached_provider = finetune_provider_model(model_id)
-    assert cached_provider is provider
-
 
 def test_finetune_provider_model_invalid_id():
     """Test handling of invalid model ID format"""
@@ -536,27 +535,45 @@ def test_finetune_provider_model_incomplete_finetune(
 
 
 @pytest.mark.parametrize(
-    "structured_output_mode, expected_mode",
+    "structured_output_mode, provider_name, expected_mode",
     [
-        (StructuredOutputMode.json_mode, StructuredOutputMode.json_mode),
-        (StructuredOutputMode.json_schema, StructuredOutputMode.json_schema),
-        (StructuredOutputMode.function_calling, StructuredOutputMode.function_calling),
-        (None, StructuredOutputMode.default),
+        (
+            StructuredOutputMode.json_mode,
+            ModelProviderName.fireworks_ai,
+            StructuredOutputMode.json_mode,
+        ),
+        (
+            StructuredOutputMode.json_schema,
+            ModelProviderName.openai,
+            StructuredOutputMode.json_schema,
+        ),
+        (
+            StructuredOutputMode.function_calling,
+            ModelProviderName.openai,
+            StructuredOutputMode.function_calling,
+        ),
+        (None, ModelProviderName.fireworks_ai, StructuredOutputMode.json_mode),
+        (None, ModelProviderName.openai, StructuredOutputMode.json_schema),
     ],
 )
 def test_finetune_provider_model_structured_mode(
-    mock_project, mock_task, mock_finetune, structured_output_mode, expected_mode
+    mock_project,
+    mock_task,
+    mock_finetune,
+    structured_output_mode,
+    provider_name,
+    expected_mode,
 ):
     """Test creation of provider with different structured output modes"""
     finetune = Mock(spec=Finetune)
-    finetune.provider = ModelProviderName.fireworks_ai
+    finetune.provider = provider_name
     finetune.fine_tune_model_id = "fireworks-model-123"
     finetune.structured_output_mode = structured_output_mode
     mock_finetune.return_value = finetune
 
     provider = finetune_provider_model("project-123::task-456::finetune-789")
 
-    assert provider.name == ModelProviderName.fireworks_ai
+    assert provider.name == provider_name
     assert provider.provider_options == {"model": "fireworks-model-123"}
     assert provider.structured_output_mode == expected_mode
 
@@ -634,3 +651,191 @@ def test_openai_compatible_provider_model_no_base_url(mock_shared_config):
         str(exc_info.value)
         == "OpenAI compatible provider test_provider has no base URL"
     )
+
+
+def test_parse_custom_model_id_valid():
+    """Test parsing a valid custom model ID"""
+    provider_name, model_name = parse_custom_model_id(
+        "openai::gpt-4-turbo-elite-enterprise-edition"
+    )
+    assert provider_name == ModelProviderName.openai
+    assert model_name == "gpt-4-turbo-elite-enterprise-edition"
+
+
+def test_parse_custom_model_id_no_separator():
+    """Test parsing an invalid model ID without separator"""
+    with pytest.raises(ValueError) as exc_info:
+        parse_custom_model_id("invalid-model-id")
+    assert str(exc_info.value) == "Invalid custom model ID: invalid-model-id"
+
+
+def test_parse_custom_model_id_invalid_provider():
+    """Test parsing model ID with invalid provider"""
+    with pytest.raises(ValueError) as exc_info:
+        parse_custom_model_id("invalid_provider::model")
+    assert str(exc_info.value) == "Invalid provider name: invalid_provider"
+
+
+def test_parse_custom_model_id_empty_parts():
+    """Test parsing model ID with empty provider or model name"""
+    with pytest.raises(ValueError) as exc_info:
+        parse_custom_model_id("::model")
+    assert str(exc_info.value) == "Invalid provider name: "
+
+
+def test_core_provider_basic_provider():
+    """Test core_provider with a basic provider that doesn't need mapping"""
+    result = core_provider("gpt-4", ModelProviderName.openai)
+    assert result == ModelProviderName.openai
+
+
+def test_core_provider_custom_registry():
+    """Test core_provider with custom registry provider"""
+    result = core_provider("openai::gpt-4", ModelProviderName.kiln_custom_registry)
+    assert result == ModelProviderName.openai
+
+
+def test_core_provider_finetune():
+    """Test core_provider with fine-tune provider"""
+    model_id = "project-123::task-456::finetune-789"
+
+    with patch(
+        "kiln_ai.adapters.provider_tools.finetune_from_id"
+    ) as mock_finetune_from_id:
+        # Mock the finetune object
+        finetune = Mock(spec=Finetune)
+        finetune.provider = ModelProviderName.openai
+        mock_finetune_from_id.return_value = finetune
+
+        result = core_provider(model_id, ModelProviderName.kiln_fine_tune)
+        assert result == ModelProviderName.openai
+        mock_finetune_from_id.assert_called_once_with(model_id)
+
+
+def test_core_provider_finetune_invalid_provider():
+    """Test core_provider with fine-tune having invalid provider"""
+    model_id = "project-123::task-456::finetune-789"
+
+    with patch(
+        "kiln_ai.adapters.provider_tools.finetune_from_id"
+    ) as mock_finetune_from_id:
+        # Mock finetune with invalid provider
+        finetune = Mock(spec=Finetune)
+        finetune.provider = "invalid_provider"
+        mock_finetune_from_id.return_value = finetune
+
+        with pytest.raises(ValueError) as exc_info:
+            core_provider(model_id, ModelProviderName.kiln_fine_tune)
+        assert (
+            str(exc_info.value)
+            == f"Finetune {model_id} has no underlying provider invalid_provider"
+        )
+        mock_finetune_from_id.assert_called_once_with(model_id)
+
+
+def test_finetune_from_id_success(mock_project, mock_task, mock_finetune):
+    """Test successful retrieval of a finetune model"""
+    model_id = "project-123::task-456::finetune-789"
+
+    # First call should hit the database
+    finetune = finetune_from_id(model_id)
+
+    assert finetune.provider == ModelProviderName.openai
+    assert finetune.fine_tune_model_id == "ft:gpt-3.5-turbo:custom:model-123"
+
+    # Verify mocks were called correctly
+    mock_project.assert_called_once_with("project-123")
+    mock_task.assert_called_once_with("task-456", "/fake/path")
+    mock_finetune.assert_called_once_with("finetune-789", "/fake/path/task")
+
+    # Second call should use cache
+    cached_finetune = finetune_from_id(model_id)
+    assert cached_finetune is finetune
+
+    # Verify no additional disk calls were made
+    mock_project.assert_called_once()
+    mock_task.assert_called_once()
+    mock_finetune.assert_called_once()
+
+
+def test_finetune_from_id_invalid_id():
+    """Test handling of invalid model ID format"""
+    with pytest.raises(ValueError) as exc_info:
+        finetune_from_id("invalid-id-format")
+    assert str(exc_info.value) == "Invalid fine tune ID: invalid-id-format"
+
+
+def test_finetune_from_id_project_not_found(mock_project):
+    """Test handling of non-existent project"""
+    mock_project.return_value = None
+    model_id = "project-123::task-456::finetune-789"
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_from_id(model_id)
+    assert str(exc_info.value) == "Project project-123 not found"
+
+    # Verify cache was not populated
+    assert model_id not in finetune_cache
+
+
+def test_finetune_from_id_task_not_found(mock_project, mock_task):
+    """Test handling of non-existent task"""
+    mock_task.return_value = None
+    model_id = "project-123::task-456::finetune-789"
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_from_id(model_id)
+    assert str(exc_info.value) == "Task task-456 not found"
+
+    # Verify cache was not populated
+    assert model_id not in finetune_cache
+
+
+def test_finetune_from_id_finetune_not_found(mock_project, mock_task, mock_finetune):
+    """Test handling of non-existent finetune"""
+    mock_finetune.return_value = None
+    model_id = "project-123::task-456::finetune-789"
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_from_id(model_id)
+    assert str(exc_info.value) == "Fine tune finetune-789 not found"
+
+    # Verify cache was not populated
+    assert model_id not in finetune_cache
+
+
+def test_finetune_from_id_incomplete_finetune(mock_project, mock_task, mock_finetune):
+    """Test handling of incomplete finetune"""
+    finetune = Mock(spec=Finetune)
+    finetune.fine_tune_model_id = None
+    mock_finetune.return_value = finetune
+    model_id = "project-123::task-456::finetune-789"
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_from_id(model_id)
+    assert (
+        str(exc_info.value)
+        == "Fine tune finetune-789 not completed. Refresh it's status in the fine-tune tab."
+    )
+
+    # Verify cache was not populated with incomplete finetune
+    assert model_id not in finetune_cache
+
+
+def test_finetune_from_id_cache_hit(mock_project, mock_task, mock_finetune):
+    """Test that cached finetune is returned without database calls"""
+    model_id = "project-123::task-456::finetune-789"
+
+    # Pre-populate cache
+    finetune = Mock(spec=Finetune)
+    finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
+    finetune_cache[model_id] = finetune
+
+    # Get finetune from cache
+    result = finetune_from_id(model_id)
+
+    assert result == finetune
+    # Verify no database calls were made
+    mock_project.assert_not_called()
+    mock_task.assert_not_called()
+    mock_finetune.assert_not_called()
diff --git a/libs/core/kiln_ai/datamodel/model_cache.py b/libs/core/kiln_ai/datamodel/model_cache.py
index 127c702b..f93385af 100644
--- a/libs/core/kiln_ai/datamodel/model_cache.py
+++ b/libs/core/kiln_ai/datamodel/model_cache.py
@@ -65,7 +65,7 @@ def _get_model(self, path: Path, model_type: Type[T]) -> Optional[T]:
     def get_model(
         self, path: Path, model_type: Type[T], readonly: bool = False
     ) -> Optional[T]:
-        # We return a copy so in-memory edits don't impact the cache until they are saved
+        # We return a copy by default, so in-memory edits don't impact the cache until they are saved
         # Benchmark shows about 2x slower, but much more foolproof
         model = self._get_model(path, model_type)
         if model:
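
Usage sketch (illustrative, assuming the functions added above; model IDs reuse values from
the patch's own tests):

    from kiln_ai.adapters.ml_model_list import ModelProviderName
    from kiln_ai.adapters.provider_tools import core_provider, parse_custom_model_id

    # Custom-registry IDs embed the underlying provider as "<provider>::<model name>".
    provider, model = parse_custom_model_id("openai::gpt-4")
    assert provider == ModelProviderName.openai
    assert model == "gpt-4"

    # core_provider() maps wrapper providers (custom registry, fine-tunes) to the
    # provider that should actually run, so adapter_for_task() picks the matching adapter.
    assert (
        core_provider("openai::gpt-4", ModelProviderName.kiln_custom_registry)
        == ModelProviderName.openai
    )

    # Plain providers pass through unchanged.
    assert core_provider("gpt-4", ModelProviderName.openai) == ModelProviderName.openai

    # Fine-tune IDs ("<project>::<task>::<finetune>") resolve by loading the Finetune
    # record (now cached as a Finetune rather than a KilnModelProvider) and returning
    # its provider; that path reads project data from disk, so it is not exercised here.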