diff --git a/app/desktop/studio_server/finetune_api.py b/app/desktop/studio_server/finetune_api.py
index 1de6a737..c8a0f222 100644
--- a/app/desktop/studio_server/finetune_api.py
+++ b/app/desktop/studio_server/finetune_api.py
@@ -9,7 +9,10 @@
ModelProviderName,
built_in_models,
)
-from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name
+from kiln_ai.adapters.prompt_builders import (
+ chain_of_thought_prompt,
+ prompt_builder_from_ui_name,
+)
from kiln_ai.adapters.provider_tools import (
provider_enabled,
provider_name_from_id,
@@ -100,6 +103,7 @@ class CreateFinetuneRequest(BaseModel):
base_model_id: str
system_message_generator: str | None = None
custom_system_message: str | None = None
+ custom_thinking_instructions: str | None = None
data_strategy: FinetuneDataStrategy
@@ -245,6 +249,9 @@ async def create_finetune(
system_message = system_message_from_request(
task, request.custom_system_message, request.system_message_generator
)
+ thinking_instructions = thinking_instructions_from_request(
+ task, request.data_strategy, request.custom_thinking_instructions
+ )
_, finetune_model = await finetune_adapter_class.create_and_start(
dataset=dataset,
@@ -252,6 +259,7 @@ async def create_finetune(
provider_base_model_id=request.base_model_id,
train_split_name=request.train_split_name,
system_message=system_message,
+ thinking_instructions=thinking_instructions,
parameters=request.parameters,
name=request.name,
description=request.description,
@@ -271,6 +279,7 @@ async def download_dataset_jsonl(
data_strategy: str,
system_message_generator: str | None = None,
custom_system_message: str | None = None,
+ custom_thinking_instructions: str | None = None,
) -> StreamingResponse:
if format_type not in [format.value for format in DatasetFormat]:
raise HTTPException(
@@ -301,8 +310,15 @@ async def download_dataset_jsonl(
system_message = system_message_from_request(
task, custom_system_message, system_message_generator
)
+ thinking_instructions = thinking_instructions_from_request(
+ task, data_strategy_typed, custom_thinking_instructions
+ )
- dataset_formatter = DatasetFormatter(dataset, system_message)
+ dataset_formatter = DatasetFormatter(
+ dataset=dataset,
+ system_message=system_message,
+ thinking_instructions=thinking_instructions,
+ )
path = dataset_formatter.dump_to_file(
split_name,
format_type_typed,
@@ -349,3 +365,20 @@ def system_message_from_request(
)
return system_message
+
+
+def thinking_instructions_from_request(
+ task: Task,
+ data_strategy: FinetuneDataStrategy,
+ custom_thinking_instructions: str | None,
+) -> str | None:
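+ """Resolve the thinking instructions for a fine-tune request.
+
+ Returns None for non-COT strategies, the custom instructions when
+ provided, and the task's default chain-of-thought prompt otherwise.
+ """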
+ if data_strategy != FinetuneDataStrategy.final_and_intermediate:
+ # Not using a COT/thinking data strategy; no thinking instructions needed
+ return None
+
+ if custom_thinking_instructions:
+ # Prefer custom instructions when provided
+ return custom_thinking_instructions
+
+ # Fall back to the task's default chain-of-thought prompt
+ return chain_of_thought_prompt(task)
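A quick illustration of the resolution order implemented above (a sketch, not part of the diff; it assumes `kiln_ai` is importable and reuses the `Task` constructor and the default prompt string that appear in the tests below):

```python
from app.desktop.studio_server.finetune_api import thinking_instructions_from_request
from kiln_ai.datamodel import FinetuneDataStrategy, Task

task = Task(name="Test Task", instruction="Test Instruction")

# final_only never attaches thinking instructions, even when custom text is passed
assert (
    thinking_instructions_from_request(task, FinetuneDataStrategy.final_only, "ignored")
    is None
)

# final_and_intermediate prefers the custom instructions when provided...
assert (
    thinking_instructions_from_request(
        task, FinetuneDataStrategy.final_and_intermediate, "Show your work."
    )
    == "Show your work."
)

# ...and otherwise falls back to chain_of_thought_prompt(task), which for a task
# with no custom thinking instruction is the default seen in the tests below
assert (
    thinking_instructions_from_request(
        task, FinetuneDataStrategy.final_and_intermediate, None
    )
    == "Think step by step, explaining your reasoning."
)
```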
diff --git a/app/desktop/studio_server/test_finetune_api.py b/app/desktop/studio_server/test_finetune_api.py
index 1f1e8290..7de5d28b 100644
--- a/app/desktop/studio_server/test_finetune_api.py
+++ b/app/desktop/studio_server/test_finetune_api.py
@@ -1,6 +1,6 @@
import unittest.mock
from pathlib import Path
-from unittest.mock import AsyncMock, Mock
+from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi import FastAPI
@@ -28,6 +28,7 @@
DatasetFilterType,
DatasetSplitType,
connect_fine_tune_api,
+ thinking_instructions_from_request,
)
@@ -424,6 +425,7 @@ def mock_finetune_adapter():
base_model_id="base_model_1",
dataset_split_id="split1",
system_message="Test system message",
+ thinking_instructions=None,
),
)
)
@@ -431,8 +433,16 @@ def mock_finetune_adapter():
@pytest.mark.parametrize(
- "data_strategy",
- [FinetuneDataStrategy.final_only, FinetuneDataStrategy.final_and_intermediate],
+ "data_strategy,custom_thinking_instructions,expected_thinking_instructions",
+ [
+ (FinetuneDataStrategy.final_only, None, None),
+ (
+ FinetuneDataStrategy.final_and_intermediate,
+ None,
+ "Think step by step, explaining your reasoning.",
+ ), # Our default
+ (FinetuneDataStrategy.final_and_intermediate, "CTI", "CTI"),
+ ],
)
async def test_create_finetune(
client,
@@ -441,6 +451,8 @@ async def test_create_finetune(
mock_finetune_registry,
mock_finetune_adapter,
data_strategy,
+ custom_thinking_instructions,
+ expected_thinking_instructions,
):
mock_finetune_registry["test_provider"] = mock_finetune_adapter
@@ -454,6 +466,7 @@ async def test_create_finetune(
"provider": "test_provider",
"base_model_id": "base_model_1",
"custom_system_message": "Test system message",
+ "custom_thinking_instructions": custom_thinking_instructions,
"data_strategy": data_strategy.value,
}
@@ -477,6 +490,7 @@ async def test_create_finetune(
provider_base_model_id="base_model_1",
train_split_name="train",
system_message="Test system message",
+ thinking_instructions=expected_thinking_instructions,
parameters={"learning_rate": 0.001, "epochs": 10},
name="New Finetune",
description="Test description",
@@ -867,6 +881,7 @@ def test_download_dataset_jsonl_with_prompt_builder(
"split_name": "train",
"format_type": "openai_chat_jsonl",
"system_message_generator": "test_prompt_builder",
+ "custom_thinking_instructions": "custom thinking instructions",
"data_strategy": "final_only",
},
)
@@ -879,7 +894,11 @@ def test_download_dataset_jsonl_with_prompt_builder(
split1 = next(split for split in test_task.dataset_splits() if split.id == "split1")
# Verify formatter was created with generated system message
- mock_formatter_class.assert_called_once_with(split1, "Generated system message")
+ mock_formatter_class.assert_called_once_with(
+ dataset=split1,
+ system_message="Generated system message",
+ thinking_instructions=None,
+ )
async def test_get_finetune(client, mock_task_from_id_disk_backed):
@@ -960,3 +979,42 @@ def __class_getitem__(cls, key):
# Verify that status was only checked for the pending finetune
mock_adapter_class.assert_called_once_with(tune1)
mock_adapter.status.assert_called_once()
+
+
+def test_thinking_instructions_non_cot_strategy():
+ """Test that non-COT strategies return None regardless of other parameters"""
+ task = Mock(spec=Task)
+ result = thinking_instructions_from_request(
+ task=task,
+ data_strategy=FinetuneDataStrategy.final_only,
+ custom_thinking_instructions="custom instructions",
+ )
+ assert result is None
+
+
+def test_thinking_instructions_custom():
+ """Test that custom instructions are returned when provided"""
+ task = Mock(spec=Task)
+ custom_instructions = "My custom thinking instructions"
+ result = thinking_instructions_from_request(
+ task=task,
+ data_strategy=FinetuneDataStrategy.final_and_intermediate,
+ custom_thinking_instructions=custom_instructions,
+ )
+ assert result == custom_instructions
+
+
+@patch("app.desktop.studio_server.finetune_api.chain_of_thought_prompt")
+def test_thinking_instructions_default(mock_cot):
+ """Test that default chain of thought prompt is used when no custom instructions"""
+ task = Mock(spec=Task)
+ mock_cot.return_value = "Default COT instructions"
+
+ result = thinking_instructions_from_request(
+ task=task,
+ data_strategy=FinetuneDataStrategy.final_and_intermediate,
+ custom_thinking_instructions=None,
+ )
+
+ mock_cot.assert_called_once_with(task)
+ assert result == "Default COT instructions"
diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts
index 15bc16ab..10e4839b 100644
--- a/app/web_ui/src/lib/api_schema.d.ts
+++ b/app/web_ui/src/lib/api_schema.d.ts
@@ -701,6 +701,8 @@ export interface components {
system_message_generator?: string | null;
/** Custom System Message */
custom_system_message?: string | null;
+ /** Custom Thinking Instructions */
+ custom_thinking_instructions?: string | null;
data_strategy: components["schemas"]["FinetuneDataStrategy"];
};
/** DataGenCategoriesApiInput */
@@ -1037,6 +1039,11 @@ export interface components {
* @description The system message to use for this fine-tune.
*/
system_message: string;
+ /**
+ * Thinking Instructions
+ * @description The thinking instructions to use for this fine-tune. Only used when data_strategy is final_and_intermediate.
+ */
+ thinking_instructions?: string | null;
/**
* @description The latest known status of this fine-tune. Not updated in real time.
* @default unknown
@@ -3181,6 +3188,7 @@ export interface operations {
data_strategy: string;
system_message_generator?: string | null;
custom_system_message?: string | null;
+ custom_thinking_instructions?: string | null;
};
header?: never;
path?: never;
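For reference, a request body exercising the new field, mirroring the payload used in `test_create_finetune` earlier in this diff; any keys not visible in the diff (such as `dataset_id`) are assumptions:

```python
# Illustrative CreateFinetuneRequest body. "dataset_id" is an assumed field
# name; the remaining keys appear in the schema above or in the tests.
create_finetune_body = {
    "name": "New Finetune",
    "description": "Test description",
    "dataset_id": "split1",  # assumption: not shown in this diff
    "train_split_name": "train",
    "provider": "test_provider",
    "base_model_id": "base_model_1",
    "custom_system_message": "Test system message",
    "custom_thinking_instructions": "Think step by step, explaining your reasoning.",
    "data_strategy": "final_and_intermediate",
    "parameters": {"learning_rate": 0.001, "epochs": 10},
}
```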
diff --git a/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte b/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte
index 0afdcc84..599e05ef 100644
--- a/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte
+++ b/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte
@@ -28,6 +28,8 @@
let automatic_validation = disabled_header
let data_strategy: FinetuneDataStrategy = "final_only"
let finetune_custom_system_prompt = ""
+ let finetune_custom_thinking_instructions =
+ "Think step by step, explaining your reasoning."
let system_prompt_method = "basic"
$: project_id = $page.params.project_id
@@ -336,6 +338,12 @@
? finetune_custom_system_prompt
: undefined
}
+ function get_custom_thinking_instructions_param(): string | undefined {
+ return system_prompt_method === "custom" &&
+ data_strategy === "final_and_intermediate"
+ ? finetune_custom_thinking_instructions
+ : undefined
+ }
let create_finetune_error: KilnError | null = null
let create_finetune_loading = false
@@ -369,6 +377,8 @@
: undefined,
system_message_generator: get_system_prompt_method_param(),
custom_system_message: get_custom_system_prompt_param(),
+ custom_thinking_instructions:
+ get_custom_thinking_instructions_param(),
parameters: hyperparameter_values,
data_strategy: data_strategy,
validation_split_name:
@@ -451,6 +461,7 @@
format_type: download_model_select_options[model_provider],
system_message_generator: get_system_prompt_method_param(),
custom_system_message: get_custom_system_prompt_param(),
+ custom_thinking_instructions: get_custom_thinking_instructions_param(),
}
// Format params as query string, including escaping values and filtering undefined
@@ -613,14 +624,27 @@
custom_prompt_name="Custom Fine Tuning Prompt"
/>
{#if system_prompt_method === "custom"}
- <FormElement … bind:value={finetune_custom_system_prompt} … />
+ <FormElement … bind:value={finetune_custom_system_prompt} … />
+
+ {#if data_strategy === "final_and_intermediate"}
+ <FormElement … bind:value={finetune_custom_thinking_instructions} … />
+ {/if}
+
{/if}
diff --git a/libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py b/libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py
--- a/libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py
@@ ... @@
def build_training_data(
- task_run: TaskRun, system_message: str, include_cot: bool
+ task_run: TaskRun,
+ system_message: str,
+ include_cot: bool,
+ thinking_instructions: str | None = None,
) -> ModelTrainingData:
"""
Generate data for training.
@@ -77,7 +80,6 @@ def build_training_data(
final_output = task_run.repaired_output.output
thinking = None
- thinking_instructions = None
thinking_final_answer_prompt = None
parent_task = task_run.parent_task()
@@ -93,12 +95,20 @@ def build_training_data(
raise ValueError(
"TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
)
+
+ # Prefer reasoning to cot if both are present
thinking = task_run.intermediate_outputs.get(
"reasoning"
) or task_run.intermediate_outputs.get("chain_of_thought")
- thinking_instructions = chain_of_thought_prompt(parent_task)
+
thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT
+ # Thinking instructions now come from the caller; they are required for COT training data
+ if not thinking_instructions:
+ raise ValueError(
+ "Thinking instructions are required when data_strategy is final_and_intermediate"
+ )
+
return ModelTrainingData(
input=task_run.input,
system_message=system_message,
@@ -350,9 +360,15 @@ def generate_vertex_gemini_1_5(
class DatasetFormatter:
"""Handles formatting of datasets into various output formats"""
- def __init__(self, dataset: DatasetSplit, system_message: str):
+ def __init__(
+ self,
+ dataset: DatasetSplit,
+ system_message: str,
+ thinking_instructions: str | None = None,
+ ):
self.dataset = dataset
self.system_message = system_message
+ self.thinking_instructions = thinking_instructions
task = dataset.parent_task()
if task is None:
@@ -410,7 +426,10 @@ def dump_to_file(
)
training_data = build_training_data(
- task_run, self.system_message, include_cot
+ task_run=task_run,
+ system_message=self.system_message,
+ include_cot=include_cot,
+ thinking_instructions=self.thinking_instructions,
)
example = generator(training_data)
# Allow non-ascii characters in the dataset.
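Two consequences of the refactor, sketched with illustrative stand-ins (`dataset_split` and `task_run` are assumed to be real Kiln objects; the import locations are assumptions based on this file):

```python
from kiln_ai.adapters.fine_tune.dataset_formatter import (
    DatasetFormat,
    DatasetFormatter,
    build_training_data,
)
from kiln_ai.datamodel import FinetuneDataStrategy

# 1) DatasetFormatter now forwards caller-supplied thinking instructions...
formatter = DatasetFormatter(
    dataset=dataset_split,
    system_message="system message",
    thinking_instructions="thinking instructions",
)
path = formatter.dump_to_file(
    "train",
    DatasetFormat.OPENAI_CHAT_JSONL,
    FinetuneDataStrategy.final_and_intermediate,
)

# 2) ...and build_training_data no longer derives them from the parent task:
# requesting COT data without instructions (for a task_run that has
# intermediate reasoning outputs) is now an error.
try:
    build_training_data(
        task_run=task_run,
        system_message="system message",
        include_cot=True,
        thinking_instructions=None,
    )
except ValueError:
    pass  # "Thinking instructions are required when data_strategy is final_and_intermediate"
```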
diff --git a/libs/core/kiln_ai/adapters/fine_tune/fireworks_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/fireworks_finetune.py
index f52fc274..05227f83 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/fireworks_finetune.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/fireworks_finetune.py
@@ -169,7 +169,11 @@ async def _start(self, dataset: DatasetSplit) -> None:
async def generate_and_upload_jsonl(
self, dataset: DatasetSplit, split_name: str, task: Task, format: DatasetFormat
) -> str:
- formatter = DatasetFormatter(dataset, self.datamodel.system_message)
+ formatter = DatasetFormatter(
+ dataset=dataset,
+ system_message=self.datamodel.system_message,
+ thinking_instructions=self.datamodel.thinking_instructions,
+ )
path = formatter.dump_to_file(split_name, format, self.datamodel.data_strategy)
# First call creates the dataset
diff --git a/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py
index 474a334c..bb0727a0 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/openai_finetune.py
@@ -163,7 +163,9 @@ async def _start(self, dataset: DatasetSplit) -> None:
async def generate_and_upload_jsonl(
self, dataset: DatasetSplit, split_name: str, task: Task, format: DatasetFormat
) -> str:
- formatter = DatasetFormatter(dataset, self.datamodel.system_message)
+ formatter = DatasetFormatter(
+ dataset, self.datamodel.system_message, self.datamodel.thinking_instructions
+ )
path = formatter.dump_to_file(split_name, format, self.datamodel.data_strategy)
response = await oai_client.files.create(
diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py
index 329ef8c9..70e443e3 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/test_base_finetune.py
@@ -156,6 +156,7 @@ async def test_create_and_start_success(mock_dataset):
parameters={"epochs": 10}, # Required parameter
system_message="Test system message",
data_strategy=FinetuneDataStrategy.final_only,
+ thinking_instructions=None,
)
assert isinstance(adapter, MockFinetune)
@@ -169,6 +170,7 @@ async def test_create_and_start_success(mock_dataset):
assert datamodel.system_message == "Test system message"
assert datamodel.path.exists()
assert datamodel.data_strategy == FinetuneDataStrategy.final_only
+ assert datamodel.thinking_instructions is None
async def test_create_and_start_with_all_params(mock_dataset):
@@ -184,6 +186,7 @@ async def test_create_and_start_with_all_params(mock_dataset):
validation_split_name="test",
system_message="Test system message",
data_strategy=FinetuneDataStrategy.final_and_intermediate,
+ thinking_instructions="Custom thinking instructions",
)
assert datamodel.name == "Custom Name"
@@ -193,6 +196,7 @@ async def test_create_and_start_with_all_params(mock_dataset):
assert datamodel.system_message == "Test system message"
assert adapter.datamodel == datamodel
assert datamodel.data_strategy == FinetuneDataStrategy.final_and_intermediate
+ assert datamodel.thinking_instructions == "Custom thinking instructions"
# load the datamodel from the file, confirm it's saved
loaded_datamodel = FinetuneModel.load_from_file(datamodel.path)
@@ -209,6 +213,7 @@ async def test_create_and_start_invalid_parameters(mock_dataset):
train_split_name="train",
parameters={"learning_rate": 0.001}, # Missing required 'epochs'
system_message="Test system message",
+ thinking_instructions=None,
data_strategy=FinetuneDataStrategy.final_only,
)
@@ -229,6 +234,7 @@ async def test_create_and_start_no_parent_task():
parameters={"epochs": 10},
system_message="Test system message",
data_strategy=FinetuneDataStrategy.final_only,
+ thinking_instructions=None,
)
@@ -251,6 +257,7 @@ async def test_create_and_start_no_parent_task_path():
parameters={"epochs": 10},
system_message="Test system message",
data_strategy=FinetuneDataStrategy.final_only,
+ thinking_instructions=None,
)
@@ -278,6 +285,7 @@ async def test_create_and_start_invalid_train_split(mock_dataset):
parameters={"epochs": 10},
system_message="Test system message",
data_strategy=FinetuneDataStrategy.final_only,
+ thinking_instructions=None,
)
@@ -297,4 +305,5 @@ async def test_create_and_start_invalid_validation_split(mock_dataset):
parameters={"epochs": 10},
system_message="Test system message",
data_strategy=FinetuneDataStrategy.final_only,
+ thinking_instructions=None,
)
diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_dataset_formatter.py b/libs/core/kiln_ai/adapters/fine_tune/test_dataset_formatter.py
index 1909e9c4..3c8f6025 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/test_dataset_formatter.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/test_dataset_formatter.py
@@ -337,7 +337,11 @@ def test_dataset_formatter_dump_to_file_tool_format(mock_dataset, tmp_path):
def test_dataset_formatter_dump_with_intermediate_data(
mock_dataset, mock_intermediate_outputs
):
- formatter = DatasetFormatter(mock_dataset, "system message 你好")
+ formatter = DatasetFormatter(
+ mock_dataset,
+ "system message 你好",
+ thinking_instructions="thinking instructions",
+ )
result_path = formatter.dump_to_file(
"train",
@@ -361,6 +365,36 @@ def test_dataset_formatter_dump_with_intermediate_data(
assert "thinking instructions" in line
+def test_dataset_formatter_dump_with_intermediate_data_custom_instructions(
+ mock_dataset, mock_intermediate_outputs
+):
+ formatter = DatasetFormatter(
+ mock_dataset, "custom system message 你好", "custom thinking instructions"
+ )
+
+ result_path = formatter.dump_to_file(
+ "train",
+ DatasetFormat.OPENAI_CHAT_JSONL,
+ data_strategy=FinetuneDataStrategy.final_and_intermediate,
+ )
+
+ assert result_path.exists()
+ assert result_path.parent == Path(tempfile.gettempdir())
+ # Test our nice naming, with cot
+ assert (
+ result_path.name
+ == "test_dataset -- split-train -- format-openai_chat_jsonl -- cot.jsonl"
+ )
+ # Verify file contents
+ with open(result_path) as f:
+ lines = f.readlines()
+ assert len(lines) == 2
+ for line in lines:
+ assert "custom system message 你好" in line
+ assert "custom thinking instructions" in line
+ assert "thinking output" in line
+
+
def test_generate_huggingface_chat_template():
training_data = ModelTrainingData(
input="test input",
@@ -542,6 +576,7 @@ def test_build_training_data(mock_task):
assert training_data_output.thinking_final_answer_prompt is None
assert training_data_output.input == '{"test": "input 你好"}'
assert training_data_output.system_message == "system message"
+ assert not training_data_output.supports_cot()
def test_build_training_data_with_COT(mock_task):
@@ -549,16 +584,20 @@ def test_build_training_data_with_COT(mock_task):
mock_task_run = mock_task.runs()[0]
assert mock_task_run.parent_task() == mock_task
mock_task_run.intermediate_outputs = {"chain_of_thought": "cot output"}
- mock_task.thinking_instruction = "thinking instructions"
- assert mock_task.thinking_instruction == "thinking instructions"
- training_data_output = build_training_data(mock_task_run, "system message", True)
+ training_data_output = build_training_data(
+ mock_task_run,
+ "system message",
+ True,
+ thinking_instructions="thinking instructions",
+ )
assert training_data_output.final_output == '{"test": "output 你好"}'
assert training_data_output.thinking == "cot output"
assert training_data_output.thinking_instructions == "thinking instructions"
assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
assert training_data_output.input == '{"test": "input 你好"}'
assert training_data_output.system_message == "system message"
+ assert training_data_output.supports_cot()
def test_build_training_data_with_thinking(mock_task):
@@ -573,13 +612,19 @@ def test_build_training_data_with_thinking(mock_task):
mock_task.thinking_instruction = "thinking instructions"
assert mock_task.thinking_instruction == "thinking instructions"
- training_data_output = build_training_data(mock_task_run, "system message", True)
+ training_data_output = build_training_data(
+ mock_task_run,
+ "system message",
+ True,
+ thinking_instructions="thinking instructions",
+ )
assert training_data_output.final_output == '{"test": "output 你好"}'
assert training_data_output.thinking == "thinking output"
assert training_data_output.thinking_instructions == "thinking instructions"
assert training_data_output.thinking_final_answer_prompt == COT_FINAL_ANSWER_PROMPT
assert training_data_output.input == '{"test": "input 你好"}'
assert training_data_output.system_message == "system message"
+ assert training_data_output.supports_cot()
def test_build_training_data_with_repaired_output(mock_task):
diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py b/libs/core/kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py
index 754e9c64..f1ed5786 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py
@@ -230,14 +230,19 @@ def mock_task():
@pytest.mark.parametrize(
- "data_strategy",
+ "data_strategy,thinking_instructions",
[
- FinetuneDataStrategy.final_and_intermediate,
- FinetuneDataStrategy.final_only,
+ (FinetuneDataStrategy.final_and_intermediate, "thinking instructions"),
+ (FinetuneDataStrategy.final_only, None),
],
)
async def test_generate_and_upload_jsonl_success(
- fireworks_finetune, mock_dataset, mock_task, mock_api_key, data_strategy
+ mock_dataset,
+ mock_task,
+ mock_api_key,
+ data_strategy,
+ thinking_instructions,
+ tmp_path,
):
mock_path = Path("mock_path.jsonl")
mock_dataset_id = "dataset-123"
@@ -258,13 +263,26 @@ async def test_generate_and_upload_jsonl_success(
status_response.json.return_value = {"state": "READY"}
# Set the data strategy on the finetune model
- fireworks_finetune.datamodel.data_strategy = data_strategy
+ tmp_file = tmp_path / "test-finetune.kiln"
+ fireworks_finetune = FireworksFinetune(
+ datamodel=FinetuneModel(
+ name="test-finetune",
+ provider="fireworks",
+ provider_id="fw-123",
+ base_model_id="llama-v2-7b",
+ train_split_name="train",
+ dataset_split_id="dataset-123",
+ system_message="Test system message",
+ path=tmp_file,
+ data_strategy=data_strategy,
+ thinking_instructions=thinking_instructions,
+ ),
+ )
with (
patch(
"kiln_ai.adapters.fine_tune.fireworks_finetune.DatasetFormatter",
- return_value=mock_formatter,
- ),
+ ) as mock_formatter_constructor,
patch("httpx.AsyncClient") as mock_client_class,
patch("builtins.open"),
patch(
@@ -272,6 +290,7 @@ async def test_generate_and_upload_jsonl_success(
return_value=mock_dataset_id,
),
):
+ mock_formatter_constructor.return_value = mock_formatter
mock_client = AsyncMock()
mock_client.post = AsyncMock(side_effect=[create_response, upload_response])
mock_client.get = AsyncMock(return_value=status_response)
@@ -282,11 +301,19 @@ async def test_generate_and_upload_jsonl_success(
)
# Verify formatter was created with correct parameters
- mock_formatter.dump_to_file.assert_called_once_with(
- "train",
- DatasetFormat.OPENAI_CHAT_JSONL,
- data_strategy, # Confirm we use correct data strategy
- )
+ assert mock_formatter_constructor.call_count == 1
+ assert mock_formatter_constructor.call_args[1] == {
+ "dataset": mock_dataset,
+ "system_message": "Test system message",
+ "thinking_instructions": thinking_instructions,
+ }
+
+ # Verify dump_to_file was called with the correct arguments
+ assert mock_formatter.method_calls[0][0] == "dump_to_file"
+ assert mock_formatter.method_calls[0][1] == (
+ "train",
+ DatasetFormat.OPENAI_CHAT_JSONL,
+ data_strategy,
+ )
assert result == mock_dataset_id
assert mock_client.post.call_count == 2
diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py b/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py
index 97fb1a62..34e38b3d 100644
--- a/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py
+++ b/libs/core/kiln_ai/adapters/fine_tune/test_openai_finetune.py
@@ -240,12 +240,14 @@ async def test_generate_and_upload_jsonl_success(
# Verify formatter was created with correct parameters
mock_formatter_class.assert_called_once_with(
- mock_dataset, openai_finetune.datamodel.system_message
+ mock_dataset, openai_finetune.datamodel.system_message, None
)
# Verify correct format was used
mock_formatter.dump_to_file.assert_called_once_with(
- "train", DatasetFormat.OPENAI_CHAT_JSONL, FinetuneDataStrategy.final_only
+ "train",
+ DatasetFormat.OPENAI_CHAT_JSONL,
+ FinetuneDataStrategy.final_only,
)
# Verify file was opened and uploaded
@@ -290,7 +292,7 @@ async def test_generate_and_upload_jsonl_schema_success(
# Verify formatter was created with correct parameters
mock_formatter_class.assert_called_once_with(
- mock_dataset, openai_finetune.datamodel.system_message
+ mock_dataset, openai_finetune.datamodel.system_message, None
)
# Verify correct format was used
@@ -551,20 +553,33 @@ async def test_status_updates_latest_status(openai_finetune, mock_response):
@pytest.mark.parametrize(
- "data_strategy",
+ "data_strategy,thinking_instructions",
[
- FinetuneDataStrategy.final_and_intermediate,
- FinetuneDataStrategy.final_only,
+ (FinetuneDataStrategy.final_and_intermediate, "Custom thinking instructions"),
+ (FinetuneDataStrategy.final_only, None),
],
)
async def test_generate_and_upload_jsonl_with_data_strategy(
- openai_finetune, mock_dataset, mock_task, data_strategy
+ mock_dataset, mock_task, data_strategy, thinking_instructions, tmp_path
):
mock_path = Path("mock_path.jsonl")
mock_file_id = "file-123"
- # Set a data strategy on the finetune model
- openai_finetune.datamodel.data_strategy = data_strategy
+ openai_finetune = OpenAIFinetune(
+ datamodel=FinetuneModel(
+ name="test-finetune",
+ provider="openai",
+ provider_id="openai-123",
+ base_model_id="gpt-4o",
+ train_split_name="train",
+ dataset_split_id="dataset-123",
+ system_message="Test system message",
+ fine_tune_model_id="ft-123",
+ path=tmp_path / "test-finetune.kiln",
+ data_strategy=data_strategy,
+ thinking_instructions=thinking_instructions,
+ ),
+ )
# Mock the formatter
mock_formatter = MagicMock(spec=DatasetFormatter)
diff --git a/libs/core/kiln_ai/datamodel/__init__.py b/libs/core/kiln_ai/datamodel/__init__.py
index 22bd5f79..a041f5e4 100644
--- a/libs/core/kiln_ai/datamodel/__init__.py
+++ b/libs/core/kiln_ai/datamodel/__init__.py
@@ -348,6 +348,10 @@ class Finetune(KilnParentedModel):
system_message: str = Field(
description="The system message to use for this fine-tune.",
)
+ thinking_instructions: str | None = Field(
+ default=None,
+ description="The thinking instructions to use for this fine-tune. Only used when data_strategy is final_and_intermediate.",
+ )
latest_status: FineTuneStatusType = Field(
default=FineTuneStatusType.unknown,
description="The latest known status of this fine-tune. Not updated in real time.",
@@ -366,6 +370,24 @@ def parent_task(self) -> Task | None:
return None
return self.parent
+ @model_validator(mode="after")
+ def validate_thinking_instructions(self) -> Self:
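+ """Thinking instructions must be provided iff data_strategy is final_and_intermediate."""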
+ if (
+ self.thinking_instructions is not None
+ and self.data_strategy != FinetuneDataStrategy.final_and_intermediate
+ ):
+ raise ValueError(
+ "Thinking instructions can only be used when data_strategy is final_and_intermediate"
+ )
+ if (
+ self.thinking_instructions is None
+ and self.data_strategy == FinetuneDataStrategy.final_and_intermediate
+ ):
+ raise ValueError(
+ "Thinking instructions are required when data_strategy is final_and_intermediate"
+ )
+ return self
+
class DataSourceType(str, Enum):
"""
diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py
index 02107b16..046eac01 100644
--- a/libs/core/kiln_ai/datamodel/test_models.py
+++ b/libs/core/kiln_ai/datamodel/test_models.py
@@ -9,6 +9,7 @@
DataSource,
DataSourceType,
Finetune,
+ FinetuneDataStrategy,
Project,
Prompt,
Task,
@@ -527,3 +528,61 @@ def test_prompt_parent_task():
task = Task(name="Test Task", instruction="Test Instruction")
prompt = Prompt(name="Test Prompt", prompt="Test Prompt", parent=task)
assert prompt.parent == task
+
+
+@pytest.mark.parametrize(
+ "thinking_instructions,data_strategy,should_raise,expected_message",
+ [
+ # Test 1: Valid case - no thinking instructions with final_only
+ (
+ None,
+ FinetuneDataStrategy.final_only,
+ False,
+ None,
+ ),
+ # Test 2: Valid case - thinking instructions with final_and_intermediate
+ (
+ "Think step by step",
+ FinetuneDataStrategy.final_and_intermediate,
+ False,
+ None,
+ ),
+ # Test 3: Invalid case - thinking instructions with final_only
+ (
+ "Think step by step",
+ FinetuneDataStrategy.final_only,
+ True,
+ "Thinking instructions can only be used when data_strategy is final_and_intermediate",
+ ),
+ # Test 4: Invalid case - no thinking instructions with final_and_intermediate
+ (
+ None,
+ FinetuneDataStrategy.final_and_intermediate,
+ True,
+ "Thinking instructions are required when data_strategy is final_and_intermediate",
+ ),
+ ],
+)
+def test_finetune_thinking_instructions_validation(
+ thinking_instructions, data_strategy, should_raise, expected_message
+):
+ base_params = {
+ "name": "test-finetune",
+ "provider": "openai",
+ "base_model_id": "gpt-3.5-turbo",
+ "dataset_split_id": "split1",
+ "system_message": "test message",
+ "data_strategy": data_strategy,
+ }
+
+ if thinking_instructions is not None:
+ base_params["thinking_instructions"] = thinking_instructions
+
+ if should_raise:
+ with pytest.raises(ValueError) as exc_info:
+ Finetune(**base_params)
+ assert expected_message in str(exc_info.value)
+ else:
+ finetune = Finetune(**base_params)
+ assert finetune.thinking_instructions == thinking_instructions
+ assert finetune.data_strategy == data_strategy
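Taken together, a hedged end-to-end sketch of starting a COT fine-tune with the new parameter (to run inside an async context; identifiers are illustrative, and the keyword arguments mirror the `create_and_start` call in `finetune_api.py` above):

```python
# All names/ids are placeholders; only kwargs visible in this diff are used.
_, finetune_model = await finetune_adapter_class.create_and_start(
    dataset=dataset,
    provider_base_model_id="base_model_1",
    train_split_name="train",
    system_message="You are a helpful assistant.",
    thinking_instructions="Think step by step, explaining your reasoning.",
    parameters={"epochs": 10},
    name="My COT Finetune",
    description="Trained on final answers plus intermediate reasoning.",
    data_strategy=FinetuneDataStrategy.final_and_intermediate,
)
```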