diff --git a/app/desktop/studio_server/finetune_api.py b/app/desktop/studio_server/finetune_api.py index 1de6a737..c8a0f222 100644 --- a/app/desktop/studio_server/finetune_api.py +++ b/app/desktop/studio_server/finetune_api.py @@ -9,7 +9,10 @@ ModelProviderName, built_in_models, ) -from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name +from kiln_ai.adapters.prompt_builders import ( + chain_of_thought_prompt, + prompt_builder_from_ui_name, +) from kiln_ai.adapters.provider_tools import ( provider_enabled, provider_name_from_id, @@ -100,6 +103,7 @@ class CreateFinetuneRequest(BaseModel): base_model_id: str system_message_generator: str | None = None custom_system_message: str | None = None + custom_thinking_instructions: str | None = None data_strategy: FinetuneDataStrategy @@ -245,6 +249,9 @@ async def create_finetune( system_message = system_message_from_request( task, request.custom_system_message, request.system_message_generator ) + thinking_instructions = thinking_instructions_from_request( + task, request.data_strategy, request.custom_thinking_instructions + ) _, finetune_model = await finetune_adapter_class.create_and_start( dataset=dataset, @@ -252,6 +259,7 @@ async def create_finetune( provider_base_model_id=request.base_model_id, train_split_name=request.train_split_name, system_message=system_message, + thinking_instructions=thinking_instructions, parameters=request.parameters, name=request.name, description=request.description, @@ -271,6 +279,7 @@ async def download_dataset_jsonl( data_strategy: str, system_message_generator: str | None = None, custom_system_message: str | None = None, + custom_thinking_instructions: str | None = None, ) -> StreamingResponse: if format_type not in [format.value for format in DatasetFormat]: raise HTTPException( @@ -301,8 +310,15 @@ async def download_dataset_jsonl( system_message = system_message_from_request( task, custom_system_message, system_message_generator ) + thinking_instructions = 
thinking_instructions_from_request( + task, data_strategy_typed, custom_thinking_instructions + ) - dataset_formatter = DatasetFormatter(dataset, system_message) + dataset_formatter = DatasetFormatter( + dataset=dataset, + system_message=system_message, + thinking_instructions=thinking_instructions, + ) path = dataset_formatter.dump_to_file( split_name, format_type_typed, @@ -349,3 +365,20 @@ def system_message_from_request( ) return system_message + + +def thinking_instructions_from_request( + task: Task, + data_strategy: FinetuneDataStrategy, + custom_thinking_instructions: str | None, +) -> str | None: + if data_strategy != FinetuneDataStrategy.final_and_intermediate: + # Not using COT/Thinking style + return None + + if custom_thinking_instructions: + # prefer custom instructions + return custom_thinking_instructions + + # default for this task + return chain_of_thought_prompt(task) diff --git a/app/desktop/studio_server/test_finetune_api.py b/app/desktop/studio_server/test_finetune_api.py index 1f1e8290..7de5d28b 100644 --- a/app/desktop/studio_server/test_finetune_api.py +++ b/app/desktop/studio_server/test_finetune_api.py @@ -1,6 +1,6 @@ import unittest.mock from pathlib import Path -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock, Mock, patch import pytest from fastapi import FastAPI @@ -28,6 +28,7 @@ DatasetFilterType, DatasetSplitType, connect_fine_tune_api, + thinking_instructions_from_request, ) @@ -424,6 +425,7 @@ def mock_finetune_adapter(): base_model_id="base_model_1", dataset_split_id="split1", system_message="Test system message", + thinking_instructions=None, ), ) ) @@ -431,8 +433,16 @@ def mock_finetune_adapter(): @pytest.mark.parametrize( - "data_strategy", - [FinetuneDataStrategy.final_only, FinetuneDataStrategy.final_and_intermediate], + "data_strategy,custom_thinking_instructions,expected_thinking_instructions", + [ + (FinetuneDataStrategy.final_only, None, None), + ( + 
FinetuneDataStrategy.final_and_intermediate, + None, + "Think step by step, explaining your reasoning.", + ), # Our default + (FinetuneDataStrategy.final_and_intermediate, "CTI", "CTI"), + ], ) async def test_create_finetune( client, @@ -441,6 +451,8 @@ async def test_create_finetune( mock_finetune_registry, mock_finetune_adapter, data_strategy, + custom_thinking_instructions, + expected_thinking_instructions, ): mock_finetune_registry["test_provider"] = mock_finetune_adapter @@ -454,6 +466,7 @@ async def test_create_finetune( "provider": "test_provider", "base_model_id": "base_model_1", "custom_system_message": "Test system message", + "custom_thinking_instructions": custom_thinking_instructions, "data_strategy": data_strategy.value, } @@ -477,6 +490,7 @@ async def test_create_finetune( provider_base_model_id="base_model_1", train_split_name="train", system_message="Test system message", + thinking_instructions=expected_thinking_instructions, parameters={"learning_rate": 0.001, "epochs": 10}, name="New Finetune", description="Test description", @@ -867,6 +881,7 @@ def test_download_dataset_jsonl_with_prompt_builder( "split_name": "train", "format_type": "openai_chat_jsonl", "system_message_generator": "test_prompt_builder", + "custom_thinking_instructions": "custom thinking instructions", "data_strategy": "final_only", }, ) @@ -879,7 +894,11 @@ def test_download_dataset_jsonl_with_prompt_builder( split1 = next(split for split in test_task.dataset_splits() if split.id == "split1") # Verify formatter was created with generated system message - mock_formatter_class.assert_called_once_with(split1, "Generated system message") + mock_formatter_class.assert_called_once_with( + dataset=split1, + system_message="Generated system message", + thinking_instructions=None, + ) async def test_get_finetune(client, mock_task_from_id_disk_backed): @@ -960,3 +979,42 @@ def __class_getitem__(cls, key): # Verify that status was only checked for the pending finetune 
mock_adapter_class.assert_called_once_with(tune1) mock_adapter.status.assert_called_once() + + +def test_thinking_instructions_non_cot_strategy(): + """Test that non-COT strategies return None regardless of other parameters""" + task = Mock(spec=Task) + result = thinking_instructions_from_request( + task=task, + data_strategy=FinetuneDataStrategy.final_only, + custom_thinking_instructions="custom instructions", + ) + assert result is None + + +def test_thinking_instructions_custom(): + """Test that custom instructions are returned when provided""" + task = Mock(spec=Task) + custom_instructions = "My custom thinking instructions" + result = thinking_instructions_from_request( + task=task, + data_strategy=FinetuneDataStrategy.final_and_intermediate, + custom_thinking_instructions=custom_instructions, + ) + assert result == custom_instructions + + +@patch("app.desktop.studio_server.finetune_api.chain_of_thought_prompt") +def test_thinking_instructions_default(mock_cot): + """Test that default chain of thought prompt is used when no custom instructions""" + task = Mock(spec=Task) + mock_cot.return_value = "Default COT instructions" + + result = thinking_instructions_from_request( + task=task, + data_strategy=FinetuneDataStrategy.final_and_intermediate, + custom_thinking_instructions=None, + ) + + mock_cot.assert_called_once_with(task) + assert result == "Default COT instructions" diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts index 15bc16ab..10e4839b 100644 --- a/app/web_ui/src/lib/api_schema.d.ts +++ b/app/web_ui/src/lib/api_schema.d.ts @@ -701,6 +701,8 @@ export interface components { system_message_generator?: string | null; /** Custom System Message */ custom_system_message?: string | null; + /** Custom Thinking Instructions */ + custom_thinking_instructions?: string | null; data_strategy: components["schemas"]["FinetuneDataStrategy"]; }; /** DataGenCategoriesApiInput */ @@ -1037,6 +1039,11 @@ export interface components { * 
@description The system message to use for this fine-tune. */ system_message: string; + /** + * Thinking Instructions + * @description The thinking instructions to use for this fine-tune. Only used when data_strategy is final_and_intermediate. + */ + thinking_instructions?: string | null; /** * @description The latest known status of this fine-tune. Not updated in real time. * @default unknown @@ -3181,6 +3188,7 @@ export interface operations { data_strategy: string; system_message_generator?: string | null; custom_system_message?: string | null; + custom_thinking_instructions?: string | null; }; header?: never; path?: never; diff --git a/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte b/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte index 0afdcc84..599e05ef 100644 --- a/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte +++ b/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte @@ -28,6 +28,8 @@ let automatic_validation = disabled_header let data_strategy: FinetuneDataStrategy = "final_only" let finetune_custom_system_prompt = "" + let finetune_custom_thinking_instructions = + "Think step by step, explaining your reasoning." let system_prompt_method = "basic" $: project_id = $page.params.project_id @@ -336,6 +338,12 @@ ? finetune_custom_system_prompt : undefined } + function get_custom_thinking_instructions_param(): string | undefined { + return system_prompt_method === "custom" && + data_strategy === "final_and_intermediate" + ? 
finetune_custom_thinking_instructions + : undefined + } let create_finetune_error: KilnError | null = null let create_finetune_loading = false @@ -369,6 +377,8 @@ : undefined, system_message_generator: get_system_prompt_method_param(), custom_system_message: get_custom_system_prompt_param(), + custom_thinking_instructions: + get_custom_thinking_instructions_param(), parameters: hyperparameter_values, data_strategy: data_strategy, validation_split_name: @@ -451,6 +461,7 @@ format_type: download_model_select_options[model_provider], system_message_generator: get_system_prompt_method_param(), custom_system_message: get_custom_system_prompt_param(), + custom_thinking_instructions: get_custom_thinking_instructions_param(), } // Format params as query string, including escaping values and filtering undefined @@ -613,14 +624,27 @@ custom_prompt_name="Custom Fine Tuning Prompt" /> {#if system_prompt_method === "custom"} - +
+ + {#if data_strategy === "final_and_intermediate"} +
+ + {/if} +
{/if} ModelTrainingData: """ Generate data for training. @@ -77,7 +80,6 @@ def build_training_data( final_output = task_run.repaired_output.output thinking = None - thinking_instructions = None thinking_final_answer_prompt = None parent_task = task_run.parent_task() @@ -93,12 +95,20 @@ def build_training_data( raise ValueError( "TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task." ) + + # Prefer reasoning to cot if both are present thinking = task_run.intermediate_outputs.get( "reasoning" ) or task_run.intermediate_outputs.get("chain_of_thought") - thinking_instructions = chain_of_thought_prompt(parent_task) + thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT + # Always use the passed thinking instructions, but check they are present for COT + if not thinking_instructions: + raise ValueError( + "Thinking instructions are required when data_strategy is final_and_intermediate" + ) + return ModelTrainingData( input=task_run.input, system_message=system_message, @@ -350,9 +360,15 @@ def generate_vertex_gemini_1_5( class DatasetFormatter: """Handles formatting of datasets into various output formats""" - def __init__(self, dataset: DatasetSplit, system_message: str): + def __init__( + self, + dataset: DatasetSplit, + system_message: str, + thinking_instructions: str | None = None, + ): self.dataset = dataset self.system_message = system_message + self.thinking_instructions = thinking_instructions task = dataset.parent_task() if task is None: @@ -410,7 +426,10 @@ def dump_to_file( ) training_data = build_training_data( - task_run, self.system_message, include_cot + task_run=task_run, + system_message=self.system_message, + include_cot=include_cot, + thinking_instructions=self.thinking_instructions, ) example = generator(training_data) # Allow non-ascii characters in the dataset. 
diff --git a/libs/core/kiln_ai/adapters/fine_tune/fireworks_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/fireworks_finetune.py index f52fc274..05227f83 100644 --- a/libs/core/kiln_ai/adapters/fine_tune/fireworks_finetune.py +++ b/libs/core/kiln_ai/adapters/fine_tune/fireworks_finetune.py @@ -169,7 +169,11 @@ async def _start(self, dataset: DatasetSplit) -> None: async def generate_and_upload_jsonl( self, dataset: DatasetSplit, split_name: str, task: Task, format: DatasetFormat ) -> str: - formatter = DatasetFormatter(dataset, self.datamodel.system_message) + formatter = DatasetFormatter( + dataset=dataset, + system_message=self.datamodel.system_message, + thinking_instructions=self.datamodel.thinking_instructions, + ) path = formatter.dump_to_file(split_name, format, self.datamodel.data_strategy) # First call creates the dataset diff --git a/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py index 474a334c..bb0727a0 100644 --- a/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py +++ b/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py @@ -163,7 +163,9 @@ async def _start(self, dataset: DatasetSplit) -> None: async def generate_and_upload_jsonl( self, dataset: DatasetSplit, split_name: str, task: Task, format: DatasetFormat ) -> str: - formatter = DatasetFormatter(dataset, self.datamodel.system_message) + formatter = DatasetFormatter( + dataset, self.datamodel.system_message, self.datamodel.thinking_instructions + ) path = formatter.dump_to_file(split_name, format, self.datamodel.data_strategy) response = await oai_client.files.create( diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py index 329ef8c9..70e443e3 100644 --- a/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py +++ b/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py @@ -156,6 +156,7 @@ async def 
test_create_and_start_success(mock_dataset): parameters={"epochs": 10}, # Required parameter system_message="Test system message", data_strategy=FinetuneDataStrategy.final_only, + thinking_instructions=None, ) assert isinstance(adapter, MockFinetune) @@ -169,6 +170,7 @@ async def test_create_and_start_success(mock_dataset): assert datamodel.system_message == "Test system message" assert datamodel.path.exists() assert datamodel.data_strategy == FinetuneDataStrategy.final_only + assert datamodel.thinking_instructions is None async def test_create_and_start_with_all_params(mock_dataset): @@ -184,6 +186,7 @@ async def test_create_and_start_with_all_params(mock_dataset): validation_split_name="test", system_message="Test system message", data_strategy=FinetuneDataStrategy.final_and_intermediate, + thinking_instructions="Custom thinking instructions", ) assert datamodel.name == "Custom Name" @@ -193,6 +196,7 @@ async def test_create_and_start_with_all_params(mock_dataset): assert datamodel.system_message == "Test system message" assert adapter.datamodel == datamodel assert datamodel.data_strategy == FinetuneDataStrategy.final_and_intermediate + assert datamodel.thinking_instructions == "Custom thinking instructions" # load the datamodel from the file, confirm it's saved loaded_datamodel = FinetuneModel.load_from_file(datamodel.path) @@ -209,6 +213,7 @@ async def test_create_and_start_invalid_parameters(mock_dataset): train_split_name="train", parameters={"learning_rate": 0.001}, # Missing required 'epochs' system_message="Test system message", + thinking_instructions=None, data_strategy=FinetuneDataStrategy.final_only, ) @@ -229,6 +234,7 @@ async def test_create_and_start_no_parent_task(): parameters={"epochs": 10}, system_message="Test system message", data_strategy=FinetuneDataStrategy.final_only, + thinking_instructions=None, ) @@ -251,6 +257,7 @@ async def test_create_and_start_no_parent_task_path(): parameters={"epochs": 10}, system_message="Test system message", 
data_strategy=FinetuneDataStrategy.final_only, + thinking_instructions=None, ) @@ -278,6 +285,7 @@ async def test_create_and_start_invalid_train_split(mock_dataset): parameters={"epochs": 10}, system_message="Test system message", data_strategy=FinetuneDataStrategy.final_only, + thinking_instructions=None, ) @@ -297,4 +305,5 @@ async def test_create_and_start_invalid_validation_split(mock_dataset): parameters={"epochs": 10}, system_message="Test system message", data_strategy=FinetuneDataStrategy.final_only, + thinking_instructions=None, ) diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_dataset_formatter.py b/libs/core/kiln_ai/adapters/fine_tune/test_dataset_formatter.py index 1909e9c4..3c8f6025 100644 --- a/libs/core/kiln_ai/adapters/fine_tune/test_dataset_formatter.py +++ b/libs/core/kiln_ai/adapters/fine_tune/test_dataset_formatter.py @@ -337,7 +337,11 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path): def test_dataset_formatter_dump_with_intermediate_data( mock_dataset, mock_intermediate_outputs ): - formatter = DatasetFormatter(mock_dataset, "system message 你好") + formatter = DatasetFormatter( + mock_dataset, + "system message 你好", + thinking_instructions="thinking instructions", + ) result_path = formatter.dump_to_file( "train", @@ -361,6 +365,36 @@ def test_dataset_formatter_dump_with_intermediate_data( assert "thinking instructions" in line +def test_dataset_formatter_dump_with_intermediate_data_custom_instructions( + mock_dataset, mock_intermediate_outputs +): + formatter = DatasetFormatter( + mock_dataset, "custom system message 你好", "custom thinking instructions" + ) + + result_path = formatter.dump_to_file( + "train", + DatasetFormat.OPENAI_CHAT_JSONL, + data_strategy=FinetuneDataStrategy.final_and_intermediate, + ) + + assert result_path.exists() + assert result_path.parent == Path(tempfile.gettempdir()) + # Test our nice naming, with cot + assert ( + result_path.name + == "test_dataset -- split-train -- 
format-openai_chat_jsonl -- cot.jsonl" + ) + # Verify file contents + with open(result_path) as f: + lines = f.readlines() + assert len(lines) == 2 + for line in lines: + assert "custom system message 你好" in line + assert "custom thinking instructions" in line + assert "thinking output" in line + + def test_generate_huggingface_chat_template(): training_data = ModelTrainingData( input="test input", @@ -542,6 +576,7 @@ def test_build_training_data(mock_task): assert training_data_output.thinking_final_answer_prompt is None assert training_data_output.input == '{"test": "input 你好"}' assert training_data_output.system_message == "system message" + assert not training_data_output.supports_cot() def test_build_training_data_with_COT(mock_task): @@ -549,16 +584,20 @@ def test_build_training_data_with_COT(mock_task): mock_task_run = mock_task.runs()[0] assert mock_task_run.parent_task() == mock_task mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"} - mock_task.thinking_instruction = "thinking instructions" - assert mock_task.thinking_instruction == "thinking instructions" - training_data_output = build_training_data(mock_task_run, "system message", True) + training_data_output = build_training_data( + mock_task_run, + "system message", + True, + thinking_instructions="thinking instructions", + ) assert training_data_output.final_output == '{"test": "output 你好"}' assert training_data_output.thinking == "cot output" assert training_data_output.thinking_instructions == "thinking instructions" assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT assert training_data_output.input == '{"test": "input 你好"}' assert training_data_output.system_message == "system message" + assert training_data_output.supports_cot() def test_build_training_data_with_thinking(mock_task): @@ -573,13 +612,19 @@ def test_build_training_data_with_thinking(mock_task): mock_task.thinking_instruction = "thinking instructions" assert 
mock_task.thinking_instruction == "thinking instructions" - training_data_output = build_training_data(mock_task_run, "system message", True) + training_data_output = build_training_data( + mock_task_run, + "system message", + True, + thinking_instructions="thinking instructions", + ) assert training_data_output.final_output == '{"test": "output 你好"}' assert training_data_output.thinking == "thinking output" assert training_data_output.thinking_instructions == "thinking instructions" assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT assert training_data_output.input == '{"test": "input 你好"}' assert training_data_output.system_message == "system message" + assert training_data_output.supports_cot() def test_build_training_data_with_repaired_output(mock_task): diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py b/libs/core/kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py index 754e9c64..f1ed5786 100644 --- a/libs/core/kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +++ b/libs/core/kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py @@ -230,14 +230,19 @@ def mock_task(): @pytest.mark.parametrize( - "data_strategy", + "data_strategy,thinking_instructions", [ - FinetuneDataStrategy.final_and_intermediate, - FinetuneDataStrategy.final_only, + (FinetuneDataStrategy.final_and_intermediate, "thinking instructions"), + (FinetuneDataStrategy.final_only, None), ], ) async def test_generate_and_upload_jsonl_success( - fireworks_finetune, mock_dataset, mock_task, mock_api_key, data_strategy + mock_dataset, + mock_task, + mock_api_key, + data_strategy, + thinking_instructions, + tmp_path, ): mock_path = Path("mock_path.jsonl") mock_dataset_id = "dataset-123" @@ -258,13 +263,26 @@ async def test_generate_and_upload_jsonl_success( status_response.json.return_value = {"state": "READY"} # Set the data strategy on the finetune model - fireworks_finetune.datamodel.data_strategy = data_strategy + tmp_file = tmp_path 
/ "test-finetune.kiln" + fireworks_finetune = FireworksFinetune( + datamodel=FinetuneModel( + name="test-finetune", + provider="fireworks", + provider_id="fw-123", + base_model_id="llama-v2-7b", + train_split_name="train", + dataset_split_id="dataset-123", + system_message="Test system message", + path=tmp_file, + data_strategy=data_strategy, + thinking_instructions=thinking_instructions, + ), + ) with ( patch( "kiln_ai.adapters.fine_tune.fireworks_finetune.DatasetFormatter", - return_value=mock_formatter, - ), + ) as mock_formatter_constructor, patch("httpx.AsyncClient") as mock_client_class, patch("builtins.open"), patch( @@ -272,6 +290,7 @@ async def test_generate_and_upload_jsonl_success( return_value=mock_dataset_id, ), ): + mock_formatter_constructor.return_value = mock_formatter mock_client = AsyncMock() mock_client.post = AsyncMock(side_effect=[create_response, upload_response]) mock_client.get = AsyncMock(return_value=status_response) @@ -282,11 +301,19 @@ async def test_generate_and_upload_jsonl_success( ) # Verify formatter was created with correct parameters - mock_formatter.dump_to_file.assert_called_once_with( - "train", - DatasetFormat.OPENAI_CHAT_JSONL, - data_strategy, # Confirm we use correct data strategy - ) + assert mock_formatter_constructor.call_count == 1 + assert mock_formatter_constructor.call_args[1] == { + "dataset": mock_dataset, + "system_message": "Test system message", + "thinking_instructions": thinking_instructions, + } + + # Verify dump_to_file received the split name, format, and configured data strategy + mock_formatter.dump_to_file.assert_called_once_with( + "train", + DatasetFormat.OPENAI_CHAT_JSONL, + data_strategy, + ) assert result == mock_dataset_id assert mock_client.post.call_count == 2 diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py index 97fb1a62..34e38b3d 100644 --- 
a/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py +++ b/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py @@ -240,12 +240,14 @@ async def test_generate_and_upload_jsonl_success( # Verify formatter was created with correct parameters mock_formatter_class.assert_called_once_with( - mock_dataset, openai_finetune.datamodel.system_message + mock_dataset, openai_finetune.datamodel.system_message, None ) # Verify correct format was used mock_formatter.dump_to_file.assert_called_once_with( - "train", DatasetFormat.OPENAI_CHAT_JSONL, FinetuneDataStrategy.final_only + "train", + DatasetFormat.OPENAI_CHAT_JSONL, + FinetuneDataStrategy.final_only, ) # Verify file was opened and uploaded @@ -290,7 +292,7 @@ async def test_generate_and_upload_jsonl_schema_success( # Verify formatter was created with correct parameters mock_formatter_class.assert_called_once_with( - mock_dataset, openai_finetune.datamodel.system_message + mock_dataset, openai_finetune.datamodel.system_message, None ) # Verify correct format was used @@ -551,20 +553,33 @@ async def test_status_updates_latest_status(openai_finetune, mock_response): @pytest.mark.parametrize( - "data_strategy", + "data_strategy,thinking_instructions", [ - FinetuneDataStrategy.final_and_intermediate, - FinetuneDataStrategy.final_only, + (FinetuneDataStrategy.final_and_intermediate, "Custom thinking instructions"), + (FinetuneDataStrategy.final_only, None), ], ) async def test_generate_and_upload_jsonl_with_data_strategy( - openai_finetune, mock_dataset, mock_task, data_strategy + mock_dataset, mock_task, data_strategy, thinking_instructions, tmp_path ): mock_path = Path("mock_path.jsonl") mock_file_id = "file-123" - # Set a data strategy on the finetune model - openai_finetune.datamodel.data_strategy = data_strategy + openai_finetune = OpenAIFinetune( + datamodel=FinetuneModel( + name="test-finetune", + provider="openai", + provider_id="openai-123", + base_model_id="gpt-4o", + train_split_name="train", + 
dataset_split_id="dataset-123", + system_message="Test system message", + fine_tune_model_id="ft-123", + path=tmp_path / "test-finetune.kiln", + data_strategy=data_strategy, + thinking_instructions=thinking_instructions, + ), + ) # Mock the formatter mock_formatter = MagicMock(spec=DatasetFormatter) diff --git a/libs/core/kiln_ai/datamodel/__init__.py b/libs/core/kiln_ai/datamodel/__init__.py index 22bd5f79..a041f5e4 100644 --- a/libs/core/kiln_ai/datamodel/__init__.py +++ b/libs/core/kiln_ai/datamodel/__init__.py @@ -348,6 +348,10 @@ class Finetune(KilnParentedModel): system_message: str = Field( description="The system message to use for this fine-tune.", ) + thinking_instructions: str | None = Field( + default=None, + description="The thinking instructions to use for this fine-tune. Only used when data_strategy is final_and_intermediate.", + ) latest_status: FineTuneStatusType = Field( default=FineTuneStatusType.unknown, description="The latest known status of this fine-tune. Not updated in real time.", @@ -366,6 +370,24 @@ def parent_task(self) -> Task | None: return None return self.parent + @model_validator(mode="after") + def validate_thinking_instructions(self) -> Self: + if ( + self.thinking_instructions is not None + and self.data_strategy != FinetuneDataStrategy.final_and_intermediate + ): + raise ValueError( + "Thinking instructions can only be used when data_strategy is final_and_intermediate" + ) + if ( + self.thinking_instructions is None + and self.data_strategy == FinetuneDataStrategy.final_and_intermediate + ): + raise ValueError( + "Thinking instructions are required when data_strategy is final_and_intermediate" + ) + return self + class DataSourceType(str, Enum): """ diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py index 02107b16..046eac01 100644 --- a/libs/core/kiln_ai/datamodel/test_models.py +++ b/libs/core/kiln_ai/datamodel/test_models.py @@ -9,6 +9,7 @@ DataSource, DataSourceType, 
Finetune, + FinetuneDataStrategy, Project, Prompt, Task, @@ -527,3 +528,61 @@ def test_prompt_parent_task(): task = Task(name="Test Task", instruction="Test Instruction") prompt = Prompt(name="Test Prompt", prompt="Test Prompt", parent=task) assert prompt.parent == task + + +@pytest.mark.parametrize( + "thinking_instructions,data_strategy,should_raise,expected_message", + [ + # Test 1: Valid case - no thinking instructions with final_only + ( + None, + FinetuneDataStrategy.final_only, + False, + None, + ), + # Test 2: Valid case - thinking instructions with final_and_intermediate + ( + "Think step by step", + FinetuneDataStrategy.final_and_intermediate, + False, + None, + ), + # Test 3: Invalid case - thinking instructions with final_only + ( + "Think step by step", + FinetuneDataStrategy.final_only, + True, + "Thinking instructions can only be used when data_strategy is final_and_intermediate", + ), + # Test 4: Invalid case - no thinking instructions with final_and_intermediate + ( + None, + FinetuneDataStrategy.final_and_intermediate, + True, + "Thinking instructions are required when data_strategy is final_and_intermediate", + ), + ], +) +def test_finetune_thinking_instructions_validation( + thinking_instructions, data_strategy, should_raise, expected_message +): + base_params = { + "name": "test-finetune", + "provider": "openai", + "base_model_id": "gpt-3.5-turbo", + "dataset_split_id": "split1", + "system_message": "test message", + "data_strategy": data_strategy, + } + + if thinking_instructions is not None: + base_params["thinking_instructions"] = thinking_instructions + + if should_raise: + with pytest.raises(ValueError) as exc_info: + Finetune(**base_params) + assert expected_message in str(exc_info.value) + else: + finetune = Finetune(**base_params) + assert finetune.thinking_instructions == thinking_instructions + assert finetune.data_strategy == data_strategy