Merge pull request #151 from Kiln-AI/thinking_sft_ui
Add tuning COT/thinking to the UI
scosman authored Feb 3, 2025
2 parents acbce3e + f1fddbc commit 27fbccd
Showing 14 changed files with 326 additions and 55 deletions.
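For context, the new data_strategy field is typed as FinetuneDataStrategy, imported from kiln_ai.datamodel. Its definition is not part of this diff; a minimal sketch consistent with the two values exercised in the tests below (final_only and final_and_intermediate) could look like the following, where the str base class and the comments are assumptions rather than the actual Kiln definition:

from enum import Enum

class FinetuneDataStrategy(str, Enum):
    # Fine-tune on the final answer only.
    final_only = "final_only"
    # Fine-tune on intermediate chain-of-thought/reasoning output plus the final answer.
    final_and_intermediate = "final_and_intermediate"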
18 changes: 17 additions & 1 deletion app/desktop/studio_server/finetune_api.py
@@ -19,6 +19,7 @@
AllSplitDefinition,
DatasetSplit,
Finetune,
FinetuneDataStrategy,
FineTuneStatusType,
HighRatingDatasetFilter,
Task,
@@ -99,6 +100,7 @@ class CreateFinetuneRequest(BaseModel):
base_model_id: str
system_message_generator: str | None = None
custom_system_message: str | None = None
data_strategy: FinetuneDataStrategy


class FinetuneWithStatus(BaseModel):
@@ -254,6 +256,7 @@ async def create_finetune(
name=request.name,
description=request.description,
validation_split_name=request.validation_split_name,
data_strategy=request.data_strategy,
)

return finetune_model
@@ -265,6 +268,7 @@ async def download_dataset_jsonl(
dataset_id: str,
split_name: str,
format_type: str,
data_strategy: str,
system_message_generator: str | None = None,
custom_system_message: str | None = None,
) -> StreamingResponse:
@@ -273,6 +277,14 @@
status_code=400,
detail=f"Dataset format '{format_type}' not found",
)
format_type_typed = DatasetFormat(format_type)
if data_strategy not in [strategy.value for strategy in FinetuneDataStrategy]:
raise HTTPException(
status_code=400,
detail=f"Data strategy '{data_strategy}' not found",
)
data_strategy_typed = FinetuneDataStrategy(data_strategy)

task = task_from_id(project_id, task_id)
dataset = DatasetSplit.from_id_and_parent_path(dataset_id, task.path)
if dataset is None:
@@ -291,7 +303,11 @@
)

dataset_formatter = DatasetFormatter(dataset, system_message)
path = dataset_formatter.dump_to_file(split_name, format_type) # type: ignore
path = dataset_formatter.dump_to_file(
split_name,
format_type_typed,
data_strategy_typed,
)

# set headers to force download in a browser
headers = {
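With this change, the dataset download endpoint requires a data_strategy query parameter in addition to format_type and rejects unknown values with a 400. A hedged example of calling the updated endpoint from Python follows; the base URL, port, and resource IDs are placeholders, while the route and parameter names are taken from the handler above and the tests below:

import requests  # assumed to be available in the calling environment

BASE_URL = "http://localhost:8757"  # placeholder; point at your running Kiln studio server

response = requests.get(
    f"{BASE_URL}/api/download_dataset_jsonl",
    params={
        "project_id": "project1",    # placeholder project ID
        "task_id": "task1",          # placeholder task ID
        "dataset_id": "split1",      # placeholder dataset split ID
        "split_name": "train",
        "format_type": "openai_chat_jsonl",
        "data_strategy": "final_and_intermediate",  # or "final_only"
        "custom_system_message": "You are a helpful assistant.",
    },
)
response.raise_for_status()

# The endpoint returns the dataset as a downloadable JSONL file; save it locally.
with open("train.jsonl", "wb") as f:
    f.write(response.content)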
123 changes: 95 additions & 28 deletions app/desktop/studio_server/test_finetune_api.py
@@ -6,12 +6,14 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
from kiln_ai.adapters.fine_tune.base_finetune import FineTuneParameter
from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat
from kiln_ai.adapters.ml_model_list import KilnModel, KilnModelProvider
from kiln_ai.datamodel import (
AllDatasetFilter,
AllSplitDefinition,
DatasetSplit,
Finetune,
FinetuneDataStrategy,
HighRatingDatasetFilter,
Project,
Task,
@@ -428,12 +430,17 @@ def mock_finetune_adapter():
return adapter


@pytest.mark.parametrize(
"data_strategy",
[FinetuneDataStrategy.final_only, FinetuneDataStrategy.final_and_intermediate],
)
async def test_create_finetune(
client,
mock_task_from_id_disk_backed,
test_task,
mock_finetune_registry,
mock_finetune_adapter,
data_strategy,
):
mock_finetune_registry["test_provider"] = mock_finetune_adapter

@@ -447,6 +454,7 @@ async def test_create_finetune(
"provider": "test_provider",
"base_model_id": "base_model_1",
"custom_system_message": "Test system message",
"data_strategy": data_strategy.value,
}

response = client.post(
@@ -473,6 +481,7 @@ async def test_create_finetune(
name="New Finetune",
description="Test description",
validation_split_name="validation",
data_strategy=data_strategy,
)


@@ -484,6 +493,7 @@ def test_create_finetune_invalid_provider(client, mock_task_from_id_disk_backed)
"provider": "invalid_provider",
"base_model_id": "base_model_1",
"custom_system_message": "Test system message",
"data_strategy": "final_only",
}

response = client.post(
@@ -511,6 +521,7 @@ def test_create_finetune_invalid_dataset(
"provider": "test_provider",
"base_model_id": "base_model_1",
"custom_system_message": "Test system message",
"data_strategy": "final_only",
}

response = client.post(
@@ -536,6 +547,7 @@ def test_create_finetune_request_validation():
provider="test_provider",
base_model_id="base_model_1",
custom_system_message="Test system message",
data_strategy=FinetuneDataStrategy.final_only,
)
assert request.name == "Test Finetune"
assert request.description == "Test description"
@@ -550,6 +562,7 @@ def test_create_finetune_request_validation():
provider="test_provider",
base_model_id="base_model_1",
custom_system_message="Test system message",
data_strategy=FinetuneDataStrategy.final_only,
)
assert request.name is None
assert request.description is None
@@ -576,6 +589,7 @@ def test_create_finetune_no_system_message(
"parameters": {},
"provider": "test_provider",
"base_model_id": "base_model_1",
"data_strategy": "final_only",
}

response = client.post(
@@ -589,6 +603,30 @@ def test_create_finetune_no_system_message(
)


def test_create_finetune_no_data_strategy(
client,
mock_task_from_id_disk_backed,
mock_finetune_registry,
mock_finetune_adapter,
):
mock_finetune_registry["test_provider"] = mock_finetune_adapter

request_data = {
"dataset_id": "split1",
"train_split_name": "train",
"parameters": {},
"provider": "test_provider",
"base_model_id": "base_model_1",
"custom_system_message": "Test system message",
}

response = client.post(
"/api/projects/project1/tasks/task1/finetunes", json=request_data
)

assert response.status_code == 422


@pytest.fixture
def mock_prompt_builder():
builder = Mock()
@@ -618,6 +656,7 @@ async def test_create_finetune_with_prompt_builder(
"provider": "test_provider",
"base_model_id": "base_model_1",
"system_message_generator": "test_prompt_builder",
"data_strategy": "final_only",
}

response = client.post(
@@ -658,6 +697,7 @@ def test_create_finetune_prompt_builder_error(
"provider": "test_provider",
"base_model_id": "base_model_1",
"system_message_generator": "test_prompt_builder",
"data_strategy": "final_only",
}

response = client.post(
@@ -683,11 +723,16 @@ def mock_dataset_formatter():
yield mock_class, formatter


@pytest.mark.parametrize(
"data_strategy",
[FinetuneDataStrategy.final_only, FinetuneDataStrategy.final_and_intermediate],
)
def test_download_dataset_jsonl(
client,
mock_task_from_id_disk_backed,
mock_dataset_formatter,
tmp_path,
data_strategy,
):
mock_formatter_class, mock_formatter = mock_dataset_formatter

@@ -705,6 +750,7 @@ def test_download_dataset_jsonl(
"split_name": "train",
"format_type": "openai_chat_jsonl",
"custom_system_message": "Test system message",
"data_strategy": data_strategy.value,
},
)

@@ -718,37 +764,61 @@

# Verify the formatter was created and used correctly
mock_formatter_class.assert_called_once()
mock_formatter.dump_to_file.assert_called_once_with("train", "openai_chat_jsonl")
mock_formatter.dump_to_file.assert_called_once_with(
"train",
DatasetFormat.OPENAI_CHAT_JSONL,
data_strategy,
)


def test_download_dataset_jsonl_invalid_format(client, mock_task_from_id_disk_backed):
@pytest.fixture
def valid_download_params():
return {
"project_id": "project1",
"task_id": "task1",
"dataset_id": "split1",
"split_name": "train",
"format_type": "openai_chat_jsonl",
"custom_system_message": "Test system message",
"data_strategy": "final_only",
}


def test_download_dataset_jsonl_invalid_format(
client, mock_task_from_id_disk_backed, valid_download_params
):
valid_download_params["format_type"] = "invalid_format"
response = client.get(
"/api/download_dataset_jsonl",
params={
"project_id": "project1",
"task_id": "task1",
"dataset_id": "split1",
"split_name": "train",
"format_type": "invalid_format",
"custom_system_message": "Test system message",
},
params=valid_download_params,
)

assert response.status_code == 400
assert response.json()["detail"] == "Dataset format 'invalid_format' not found"


def test_download_dataset_jsonl_invalid_dataset(client, mock_task_from_id_disk_backed):
def test_download_dataset_jsonl_data_strategy_invalid(
client, mock_task_from_id_disk_backed, valid_download_params
):
valid_download_params["data_strategy"] = "invalid_data_strategy"
response = client.get(
"/api/download_dataset_jsonl",
params={
"project_id": "project1",
"task_id": "task1",
"dataset_id": "invalid_split",
"split_name": "train",
"format_type": "openai_chat_jsonl",
"custom_system_message": "Test system message",
},
params=valid_download_params,
)

assert response.status_code == 400
assert (
response.json()["detail"] == "Data strategy 'invalid_data_strategy' not found"
)


def test_download_dataset_jsonl_invalid_dataset(
client, mock_task_from_id_disk_backed, valid_download_params
):
valid_download_params["dataset_id"] = "invalid_split"
response = client.get(
"/api/download_dataset_jsonl",
params=valid_download_params,
)

assert response.status_code == 404
Expand All @@ -757,17 +827,13 @@ def test_download_dataset_jsonl_invalid_dataset(client, mock_task_from_id_disk_b
)


def test_download_dataset_jsonl_invalid_split(client, mock_task_from_id_disk_backed):
def test_download_dataset_jsonl_invalid_split(
client, mock_task_from_id_disk_backed, valid_download_params
):
valid_download_params["split_name"] = "invalid_split"
response = client.get(
"/api/download_dataset_jsonl",
params={
"project_id": "project1",
"task_id": "task1",
"dataset_id": "split1",
"split_name": "invalid_split",
"format_type": "openai_chat_jsonl",
"custom_system_message": "Test system message",
},
params=valid_download_params,
)

assert response.status_code == 404
@@ -801,6 +867,7 @@ def test_download_dataset_jsonl_with_prompt_builder(
"split_name": "train",
"format_type": "openai_chat_jsonl",
"system_message_generator": "test_prompt_builder",
"data_strategy": "final_only",
},
)

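For completeness, a sketch of creating a fine-tune through the API with the new required field. The payload mirrors the request_data dictionaries in the tests above; the base URL and the project/task IDs are placeholders. Omitting data_strategy now yields a 422 validation error, as covered by test_create_finetune_no_data_strategy:

import requests  # assumed to be available in the calling environment

BASE_URL = "http://localhost:8757"  # placeholder; point at your running Kiln studio server

payload = {
    "dataset_id": "split1",
    "train_split_name": "train",
    "parameters": {},
    "provider": "test_provider",      # placeholder provider ID
    "base_model_id": "base_model_1",  # placeholder base model ID
    "custom_system_message": "You are a helpful assistant.",
    "data_strategy": "final_and_intermediate",  # or "final_only"
}

response = requests.post(
    f"{BASE_URL}/api/projects/project1/tasks/task1/finetunes",
    json=payload,
)
response.raise_for_status()
print(response.json())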