From e67c8ddea28d8e691f72a10fbcae07b469f42ce5 Mon Sep 17 00:00:00 2001 From: scosman Date: Thu, 30 Jan 2025 21:51:35 -0500 Subject: [PATCH] Test fix --- libs/core/kiln_ai/adapters/adapter_registry.py | 2 +- libs/core/kiln_ai/adapters/repair/test_repair_task.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/core/kiln_ai/adapters/adapter_registry.py b/libs/core/kiln_ai/adapters/adapter_registry.py index 68af2708..9fccbc8c 100644 --- a/libs/core/kiln_ai/adapters/adapter_registry.py +++ b/libs/core/kiln_ai/adapters/adapter_registry.py @@ -13,7 +13,7 @@ def adapter_for_task( kiln_task: datamodel.Task, - model_name: str | None = None, + model_name: str, provider: str | None = None, prompt_builder: BasePromptBuilder | None = None, tags: list[str] | None = None, diff --git a/libs/core/kiln_ai/adapters/repair/test_repair_task.py b/libs/core/kiln_ai/adapters/repair/test_repair_task.py index ace3eb84..9c63d974 100644 --- a/libs/core/kiln_ai/adapters/repair/test_repair_task.py +++ b/libs/core/kiln_ai/adapters/repair/test_repair_task.py @@ -223,7 +223,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai ) adapter = adapter_for_task( - repair_task, model_name="llama_3_1_8b", provider="groq" + repair_task, model_name="llama_3_1_8b", provider="ollama" ) run = await adapter.invoke(repair_task_input.model_dump()) @@ -237,7 +237,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai assert run.output.source.properties == { "adapter_name": "kiln_langchain_adapter", "model_name": "llama_3_1_8b", - "model_provider": "groq", + "model_provider": "ollama", "prompt_builder_name": "simple_prompt_builder", } assert run.input_source.type == DataSourceType.human