Skip to content

Commit

Permalink
Support a custom COT instruction for fine tuning, for building inference time-compute models.
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Feb 3, 2025
1 parent f1fddbc commit dd26e20
Show file tree
Hide file tree
Showing 14 changed files with 374 additions and 47 deletions.
37 changes: 35 additions & 2 deletions app/desktop/studio_server/finetune_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
ModelProviderName,
built_in_models,
)
from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name
from kiln_ai.adapters.prompt_builders import (
chain_of_thought_prompt,
prompt_builder_from_ui_name,
)
from kiln_ai.adapters.provider_tools import (
provider_enabled,
provider_name_from_id,
Expand Down Expand Up @@ -100,6 +103,7 @@ class CreateFinetuneRequest(BaseModel):
base_model_id: str
system_message_generator: str | None = None
custom_system_message: str | None = None
custom_thinking_instructions: str | None = None
data_strategy: FinetuneDataStrategy


Expand Down Expand Up @@ -245,13 +249,17 @@ async def create_finetune(
system_message = system_message_from_request(
task, request.custom_system_message, request.system_message_generator
)
thinking_instructions = thinking_instructions_from_request(
task, request.data_strategy, request.custom_thinking_instructions
)

_, finetune_model = await finetune_adapter_class.create_and_start(
dataset=dataset,
provider_id=request.provider,
provider_base_model_id=request.base_model_id,
train_split_name=request.train_split_name,
system_message=system_message,
thinking_instructions=thinking_instructions,
parameters=request.parameters,
name=request.name,
description=request.description,
Expand All @@ -271,6 +279,7 @@ async def download_dataset_jsonl(
data_strategy: str,
system_message_generator: str | None = None,
custom_system_message: str | None = None,
custom_thinking_instructions: str | None = None,
) -> StreamingResponse:
if format_type not in [format.value for format in DatasetFormat]:
raise HTTPException(
Expand Down Expand Up @@ -301,8 +310,15 @@ async def download_dataset_jsonl(
system_message = system_message_from_request(
task, custom_system_message, system_message_generator
)
thinking_instructions = thinking_instructions_from_request(
task, data_strategy_typed, custom_thinking_instructions
)

dataset_formatter = DatasetFormatter(dataset, system_message)
dataset_formatter = DatasetFormatter(
dataset=dataset,
system_message=system_message,
thinking_instructions=thinking_instructions,
)
path = dataset_formatter.dump_to_file(
split_name,
format_type_typed,
Expand Down Expand Up @@ -349,3 +365,20 @@ def system_message_from_request(
)

return system_message


def thinking_instructions_from_request(
    task: Task,
    data_strategy: FinetuneDataStrategy,
    custom_thinking_instructions: str | None,
) -> str | None:
    """Resolve the thinking instructions to use for a fine-tune request.

    Returns None for non-COT strategies. For the COT/thinking strategy,
    custom instructions take precedence; otherwise the task's default
    chain-of-thought prompt is used.
    """
    if data_strategy == FinetuneDataStrategy.final_and_intermediate:
        # Prefer the caller-supplied instructions; fall back to the
        # task's default chain-of-thought prompt.
        return custom_thinking_instructions or chain_of_thought_prompt(task)
    # Not a COT/thinking-style strategy: no thinking instructions apply.
    return None
66 changes: 62 additions & 4 deletions app/desktop/studio_server/test_finetune_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest.mock
from pathlib import Path
from unittest.mock import AsyncMock, Mock
from unittest.mock import AsyncMock, Mock, patch

import pytest
from fastapi import FastAPI
Expand Down Expand Up @@ -28,6 +28,7 @@
DatasetFilterType,
DatasetSplitType,
connect_fine_tune_api,
thinking_instructions_from_request,
)


Expand Down Expand Up @@ -424,15 +425,24 @@ def mock_finetune_adapter():
base_model_id="base_model_1",
dataset_split_id="split1",
system_message="Test system message",
thinking_instructions=None,
),
)
)
return adapter


@pytest.mark.parametrize(
"data_strategy",
[FinetuneDataStrategy.final_only, FinetuneDataStrategy.final_and_intermediate],
"data_strategy,custom_thinking_instructions,expected_thinking_instructions",
[
(FinetuneDataStrategy.final_only, None, None),
(
FinetuneDataStrategy.final_and_intermediate,
None,
"Think step by step, explaining your reasoning.",
), # Our default
(FinetuneDataStrategy.final_and_intermediate, "CTI", "CTI"),
],
)
async def test_create_finetune(
client,
Expand All @@ -441,6 +451,8 @@ async def test_create_finetune(
mock_finetune_registry,
mock_finetune_adapter,
data_strategy,
custom_thinking_instructions,
expected_thinking_instructions,
):
mock_finetune_registry["test_provider"] = mock_finetune_adapter

Expand All @@ -454,6 +466,7 @@ async def test_create_finetune(
"provider": "test_provider",
"base_model_id": "base_model_1",
"custom_system_message": "Test system message",
"custom_thinking_instructions": custom_thinking_instructions,
"data_strategy": data_strategy.value,
}

Expand All @@ -477,6 +490,7 @@ async def test_create_finetune(
provider_base_model_id="base_model_1",
train_split_name="train",
system_message="Test system message",
thinking_instructions=expected_thinking_instructions,
parameters={"learning_rate": 0.001, "epochs": 10},
name="New Finetune",
description="Test description",
Expand Down Expand Up @@ -867,6 +881,7 @@ def test_download_dataset_jsonl_with_prompt_builder(
"split_name": "train",
"format_type": "openai_chat_jsonl",
"system_message_generator": "test_prompt_builder",
"custom_thinking_instructions": "custom thinking instructions",
"data_strategy": "final_only",
},
)
Expand All @@ -879,7 +894,11 @@ def test_download_dataset_jsonl_with_prompt_builder(

split1 = next(split for split in test_task.dataset_splits() if split.id == "split1")
# Verify formatter was created with generated system message
mock_formatter_class.assert_called_once_with(split1, "Generated system message")
mock_formatter_class.assert_called_once_with(
dataset=split1,
system_message="Generated system message",
thinking_instructions=None,
)


async def test_get_finetune(client, mock_task_from_id_disk_backed):
Expand Down Expand Up @@ -960,3 +979,42 @@ def __class_getitem__(cls, key):
# Verify that status was only checked for the pending finetune
mock_adapter_class.assert_called_once_with(tune1)
mock_adapter.status.assert_called_once()


def test_thinking_instructions_non_cot_strategy():
    """Non-COT strategies must yield None even when custom instructions are given."""
    mock_task = Mock(spec=Task)
    assert (
        thinking_instructions_from_request(
            task=mock_task,
            data_strategy=FinetuneDataStrategy.final_only,
            custom_thinking_instructions="custom instructions",
        )
        is None
    )


def test_thinking_instructions_custom():
    """Custom instructions take precedence under the COT strategy."""
    mock_task = Mock(spec=Task)
    instructions = "My custom thinking instructions"
    assert (
        thinking_instructions_from_request(
            task=mock_task,
            data_strategy=FinetuneDataStrategy.final_and_intermediate,
            custom_thinking_instructions=instructions,
        )
        == instructions
    )


@patch("app.desktop.studio_server.finetune_api.chain_of_thought_prompt")
def test_thinking_instructions_default(mock_cot):
    """Without custom instructions, the task's chain-of-thought prompt is returned."""
    mock_cot.return_value = "Default COT instructions"
    mock_task = Mock(spec=Task)

    assert (
        thinking_instructions_from_request(
            task=mock_task,
            data_strategy=FinetuneDataStrategy.final_and_intermediate,
            custom_thinking_instructions=None,
        )
        == "Default COT instructions"
    )
    # The default prompt must be derived from the given task.
    mock_cot.assert_called_once_with(mock_task)
8 changes: 8 additions & 0 deletions app/web_ui/src/lib/api_schema.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,8 @@ export interface components {
system_message_generator?: string | null;
/** Custom System Message */
custom_system_message?: string | null;
/** Custom Thinking Instructions */
custom_thinking_instructions?: string | null;
data_strategy: components["schemas"]["FinetuneDataStrategy"];
};
/** DataGenCategoriesApiInput */
Expand Down Expand Up @@ -1037,6 +1039,11 @@ export interface components {
* @description The system message to use for this fine-tune.
*/
system_message: string;
/**
* Thinking Instructions
* @description The thinking instructions to use for this fine-tune. Only used when data_strategy is final_and_intermediate.
*/
thinking_instructions?: string | null;
/**
* @description The latest known status of this fine-tune. Not updated in real time.
* @default unknown
Expand Down Expand Up @@ -3181,6 +3188,7 @@ export interface operations {
data_strategy: string;
system_message_generator?: string | null;
custom_system_message?: string | null;
custom_thinking_instructions?: string | null;
};
header?: never;
path?: never;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
let automatic_validation = disabled_header
let data_strategy: FinetuneDataStrategy = "final_only"
let finetune_custom_system_prompt = ""
let finetune_custom_thinking_instructions =
"Think step by step, explaining your reasoning."
let system_prompt_method = "basic"
$: project_id = $page.params.project_id
Expand Down Expand Up @@ -336,6 +338,12 @@
? finetune_custom_system_prompt
: undefined
}
function get_custom_thinking_instructions_param(): string | undefined {
  // Thinking instructions only apply when a custom system prompt is
  // selected AND the COT training strategy is in use.
  if (system_prompt_method !== "custom") {
    return undefined
  }
  if (data_strategy !== "final_and_intermediate") {
    return undefined
  }
  return finetune_custom_thinking_instructions
}
let create_finetune_error: KilnError | null = null
let create_finetune_loading = false
Expand Down Expand Up @@ -369,6 +377,8 @@
: undefined,
system_message_generator: get_system_prompt_method_param(),
custom_system_message: get_custom_system_prompt_param(),
custom_thinking_instructions:
get_custom_thinking_instructions_param(),
parameters: hyperparameter_values,
data_strategy: data_strategy,
validation_split_name:
Expand Down Expand Up @@ -451,6 +461,7 @@
format_type: download_model_select_options[model_provider],
system_message_generator: get_system_prompt_method_param(),
custom_system_message: get_custom_system_prompt_param(),
custom_thinking_instructions: get_custom_thinking_instructions_param(),
}
// Format params as query string, including escaping values and filtering undefined
Expand Down Expand Up @@ -613,14 +624,27 @@
custom_prompt_name="Custom Fine Tuning Prompt"
/>
{#if system_prompt_method === "custom"}
<FormElement
label="Custom System Message"
description="Enter a custom system message to use during fine-tuning."
info_description="There are tradeoffs to consider when choosing a system prompt for fine-tuning. Read more: https://platform.openai.com/docs/guides/fine-tuning/#crafting-prompts"
inputType="textarea"
id="finetune_custom_system_prompt"
bind:value={finetune_custom_system_prompt}
/>
<div class="p-4 border-l-4 border-gray-300">
<FormElement
label="Custom System Prompt"
description="Enter a custom system prompt to use during fine-tuning."
info_description="There are tradeoffs to consider when choosing a system prompt for fine-tuning. Read more: https://platform.openai.com/docs/guides/fine-tuning/#crafting-prompts"
inputType="textarea"
id="finetune_custom_system_prompt"
bind:value={finetune_custom_system_prompt}
/>
{#if data_strategy === "final_and_intermediate"}
<div class="mt-4"></div>
<FormElement
label="Custom Thinking Instructions"
description="Instructions for the model's 'thinking' stage, before returning the final response."
info_description="When training with intermediate results (reasoning, chain of thought, etc.), this prompt will be used to ask the model to 'think' before returning the final response."
inputType="textarea"
id="finetune_custom_thinking_instructions"
bind:value={finetune_custom_thinking_instructions}
/>
{/if}
</div>
{/if}
<FormElement
label="Training Strategy"
Expand Down
2 changes: 2 additions & 0 deletions libs/core/kiln_ai/adapters/fine_tune/base_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def create_and_start(
provider_base_model_id: str,
train_split_name: str,
system_message: str,
thinking_instructions: str | None,
data_strategy: FinetuneDataStrategy,
parameters: dict[str, str | int | float | bool] = {},
name: str | None = None,
Expand Down Expand Up @@ -101,6 +102,7 @@ async def create_and_start(
validation_split_name=validation_split_name,
parameters=parameters,
system_message=system_message,
thinking_instructions=thinking_instructions,
parent=parent_task,
data_strategy=data_strategy,
)
Expand Down
Loading

0 comments on commit dd26e20

Please sign in to comment.