diff --git a/app/desktop/studio_server/finetune_api.py b/app/desktop/studio_server/finetune_api.py
index 1de6a737..c8a0f222 100644
--- a/app/desktop/studio_server/finetune_api.py
+++ b/app/desktop/studio_server/finetune_api.py
@@ -9,7 +9,10 @@
ModelProviderName,
built_in_models,
)
-from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name
+from kiln_ai.adapters.prompt_builders import (
+ chain_of_thought_prompt,
+ prompt_builder_from_ui_name,
+)
from kiln_ai.adapters.provider_tools import (
provider_enabled,
provider_name_from_id,
@@ -100,6 +103,7 @@ class CreateFinetuneRequest(BaseModel):
base_model_id: str
system_message_generator: str | None = None
custom_system_message: str | None = None
+ custom_thinking_instructions: str | None = None
data_strategy: FinetuneDataStrategy
@@ -245,6 +249,9 @@ async def create_finetune(
system_message = system_message_from_request(
task, request.custom_system_message, request.system_message_generator
)
+ thinking_instructions = thinking_instructions_from_request(
+ task, request.data_strategy, request.custom_thinking_instructions
+ )
_, finetune_model = await finetune_adapter_class.create_and_start(
dataset=dataset,
@@ -252,6 +259,7 @@ async def create_finetune(
provider_base_model_id=request.base_model_id,
train_split_name=request.train_split_name,
system_message=system_message,
+ thinking_instructions=thinking_instructions,
parameters=request.parameters,
name=request.name,
description=request.description,
@@ -271,6 +279,7 @@ async def download_dataset_jsonl(
data_strategy: str,
system_message_generator: str | None = None,
custom_system_message: str | None = None,
+ custom_thinking_instructions: str | None = None,
) -> StreamingResponse:
if format_type not in [format.value for format in DatasetFormat]:
raise HTTPException(
@@ -301,8 +310,15 @@ async def download_dataset_jsonl(
system_message = system_message_from_request(
task, custom_system_message, system_message_generator
)
+ thinking_instructions = thinking_instructions_from_request(
+ task, data_strategy_typed, custom_thinking_instructions
+ )
- dataset_formatter = DatasetFormatter(dataset, system_message)
+ dataset_formatter = DatasetFormatter(
+ dataset=dataset,
+ system_message=system_message,
+ thinking_instructions=thinking_instructions,
+ )
path = dataset_formatter.dump_to_file(
split_name,
format_type_typed,
@@ -349,3 +365,20 @@ def system_message_from_request(
)
return system_message
+
+
+def thinking_instructions_from_request(
+ task: Task,
+ data_strategy: FinetuneDataStrategy,
+ custom_thinking_instructions: str | None,
+) -> str | None:
+ if data_strategy != FinetuneDataStrategy.final_and_intermediate:
+ # Not using COT/Thinking style
+ return None
+
+ if custom_thinking_instructions:
+        # Prefer caller-supplied thinking instructions when provided
+ return custom_thinking_instructions
+
+    # Fall back to the task's default chain-of-thought prompt
+ return chain_of_thought_prompt(task)
diff --git a/app/desktop/studio_server/test_finetune_api.py b/app/desktop/studio_server/test_finetune_api.py
index 1f1e8290..7de5d28b 100644
--- a/app/desktop/studio_server/test_finetune_api.py
+++ b/app/desktop/studio_server/test_finetune_api.py
@@ -1,6 +1,6 @@
import unittest.mock
from pathlib import Path
-from unittest.mock import AsyncMock, Mock
+from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi import FastAPI
@@ -28,6 +28,7 @@
DatasetFilterType,
DatasetSplitType,
connect_fine_tune_api,
+ thinking_instructions_from_request,
)
@@ -424,6 +425,7 @@ def mock_finetune_adapter():
base_model_id="base_model_1",
dataset_split_id="split1",
system_message="Test system message",
+ thinking_instructions=None,
),
)
)
@@ -431,8 +433,16 @@ def mock_finetune_adapter():
@pytest.mark.parametrize(
- "data_strategy",
- [FinetuneDataStrategy.final_only, FinetuneDataStrategy.final_and_intermediate],
+ "data_strategy,custom_thinking_instructions,expected_thinking_instructions",
+ [
+ (FinetuneDataStrategy.final_only, None, None),
+ (
+ FinetuneDataStrategy.final_and_intermediate,
+ None,
+ "Think step by step, explaining your reasoning.",
+ ), # Our default
+ (FinetuneDataStrategy.final_and_intermediate, "CTI", "CTI"),
+ ],
)
async def test_create_finetune(
client,
@@ -441,6 +451,8 @@ async def test_create_finetune(
mock_finetune_registry,
mock_finetune_adapter,
data_strategy,
+ custom_thinking_instructions,
+ expected_thinking_instructions,
):
mock_finetune_registry["test_provider"] = mock_finetune_adapter
@@ -454,6 +466,7 @@ async def test_create_finetune(
"provider": "test_provider",
"base_model_id": "base_model_1",
"custom_system_message": "Test system message",
+ "custom_thinking_instructions": custom_thinking_instructions,
"data_strategy": data_strategy.value,
}
@@ -477,6 +490,7 @@ async def test_create_finetune(
provider_base_model_id="base_model_1",
train_split_name="train",
system_message="Test system message",
+ thinking_instructions=expected_thinking_instructions,
parameters={"learning_rate": 0.001, "epochs": 10},
name="New Finetune",
description="Test description",
@@ -867,6 +881,7 @@ def test_download_dataset_jsonl_with_prompt_builder(
"split_name": "train",
"format_type": "openai_chat_jsonl",
"system_message_generator": "test_prompt_builder",
+ "custom_thinking_instructions": "custom thinking instructions",
"data_strategy": "final_only",
},
)
@@ -879,7 +894,11 @@ def test_download_dataset_jsonl_with_prompt_builder(
split1 = next(split for split in test_task.dataset_splits() if split.id == "split1")
# Verify formatter was created with generated system message
- mock_formatter_class.assert_called_once_with(split1, "Generated system message")
+ mock_formatter_class.assert_called_once_with(
+ dataset=split1,
+ system_message="Generated system message",
+ thinking_instructions=None,
+ )
async def test_get_finetune(client, mock_task_from_id_disk_backed):
@@ -960,3 +979,42 @@ def __class_getitem__(cls, key):
# Verify that status was only checked for the pending finetune
mock_adapter_class.assert_called_once_with(tune1)
mock_adapter.status.assert_called_once()
+
+
+def test_thinking_instructions_non_cot_strategy():
+ """Test that non-COT strategies return None regardless of other parameters"""
+ task = Mock(spec=Task)
+ result = thinking_instructions_from_request(
+ task=task,
+ data_strategy=FinetuneDataStrategy.final_only,
+ custom_thinking_instructions="custom instructions",
+ )
+ assert result is None
+
+
+def test_thinking_instructions_custom():
+ """Test that custom instructions are returned when provided"""
+ task = Mock(spec=Task)
+ custom_instructions = "My custom thinking instructions"
+ result = thinking_instructions_from_request(
+ task=task,
+ data_strategy=FinetuneDataStrategy.final_and_intermediate,
+ custom_thinking_instructions=custom_instructions,
+ )
+ assert result == custom_instructions
+
+
+@patch("app.desktop.studio_server.finetune_api.chain_of_thought_prompt")
+def test_thinking_instructions_default(mock_cot):
+ """Test that default chain of thought prompt is used when no custom instructions"""
+ task = Mock(spec=Task)
+ mock_cot.return_value = "Default COT instructions"
+
+ result = thinking_instructions_from_request(
+ task=task,
+ data_strategy=FinetuneDataStrategy.final_and_intermediate,
+ custom_thinking_instructions=None,
+ )
+
+ mock_cot.assert_called_once_with(task)
+ assert result == "Default COT instructions"
diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts
index 15bc16ab..10e4839b 100644
--- a/app/web_ui/src/lib/api_schema.d.ts
+++ b/app/web_ui/src/lib/api_schema.d.ts
@@ -701,6 +701,8 @@ export interface components {
system_message_generator?: string | null;
/** Custom System Message */
custom_system_message?: string | null;
+ /** Custom Thinking Instructions */
+ custom_thinking_instructions?: string | null;
data_strategy: components["schemas"]["FinetuneDataStrategy"];
};
/** DataGenCategoriesApiInput */
@@ -1037,6 +1039,11 @@ export interface components {
* @description The system message to use for this fine-tune.
*/
system_message: string;
+ /**
+ * Thinking Instructions
+ * @description The thinking instructions to use for this fine-tune. Only used when data_strategy is final_and_intermediate.
+ */
+ thinking_instructions?: string | null;
/**
* @description The latest known status of this fine-tune. Not updated in real time.
* @default unknown
@@ -3181,6 +3188,7 @@ export interface operations {
data_strategy: string;
system_message_generator?: string | null;
custom_system_message?: string | null;
+ custom_thinking_instructions?: string | null;
};
header?: never;
path?: never;
diff --git a/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte b/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte
index 0afdcc84..599e05ef 100644
--- a/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte
+++ b/app/web_ui/src/routes/(app)/fine_tune/[project_id]/[task_id]/create_finetune/+page.svelte
@@ -28,6 +28,8 @@
let automatic_validation = disabled_header
let data_strategy: FinetuneDataStrategy = "final_only"
let finetune_custom_system_prompt = ""
+ let finetune_custom_thinking_instructions =
+ "Think step by step, explaining your reasoning."
let system_prompt_method = "basic"
$: project_id = $page.params.project_id
@@ -336,6 +338,12 @@
? finetune_custom_system_prompt
: undefined
}
+ function get_custom_thinking_instructions_param(): string | undefined {
+ return system_prompt_method === "custom" &&
+ data_strategy === "final_and_intermediate"
+ ? finetune_custom_thinking_instructions
+ : undefined
+ }
let create_finetune_error: KilnError | null = null
let create_finetune_loading = false
@@ -369,6 +377,8 @@
: undefined,
system_message_generator: get_system_prompt_method_param(),
custom_system_message: get_custom_system_prompt_param(),
+ custom_thinking_instructions:
+ get_custom_thinking_instructions_param(),
parameters: hyperparameter_values,
data_strategy: data_strategy,
validation_split_name:
@@ -451,6 +461,7 @@
format_type: download_model_select_options[model_provider],
system_message_generator: get_system_prompt_method_param(),
custom_system_message: get_custom_system_prompt_param(),
+ custom_thinking_instructions: get_custom_thinking_instructions_param(),
}
// Format params as query string, including escaping values and filtering undefined
@@ -613,14 +624,27 @@
custom_prompt_name="Custom Fine Tuning Prompt"
/>
{#if system_prompt_method === "custom"}
-