Skip to content

Commit

Permalink
Merge pull request #155 from Kiln-AI/thinking_dataset_filter
Browse files Browse the repository at this point in the history
Add the ability to filter datasets to "has intermediate data"
  • Loading branch information
scosman authored Feb 4, 2025
2 parents f06c498 + 47ec76c commit 7eea98e
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 55 deletions.
24 changes: 7 additions & 17 deletions app/desktop/studio_server/finetune_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@
provider_name_from_id,
)
from kiln_ai.datamodel import (
AllDatasetFilter,
AllSplitDefinition,
DatasetFilterType,
DatasetSplit,
Finetune,
FinetuneDataStrategy,
FineTuneStatusType,
HighRatingDatasetFilter,
Task,
Train60Test20Val20SplitDefinition,
Train80Test10Val10SplitDefinition,
Train80Test20SplitDefinition,
dataset_filters,
)
from kiln_ai.utils.name_generator import generate_memorable_name
from kiln_server.task_api import task_from_id
Expand Down Expand Up @@ -68,19 +68,6 @@ class DatasetSplitType(Enum):
}


class DatasetFilterType(Enum):
"""Dataset filter types used in the API. Any filter style can be created in code."""

ALL = "all"
HIGH_RATING = "high_rating"


api_filter_types = {
DatasetFilterType.ALL: AllDatasetFilter,
DatasetFilterType.HIGH_RATING: HighRatingDatasetFilter,
}


class CreateDatasetSplitRequest(BaseModel):
"""Request to create a dataset split"""

Expand Down Expand Up @@ -209,14 +196,17 @@ async def create_dataset_split(
) -> DatasetSplit:
task = task_from_id(project_id, task_id)
split_definitions = api_split_types[request.dataset_split_type]
filter = api_filter_types[request.filter_type]

name = request.name
if not name:
name = generate_memorable_name()

dataset_split = DatasetSplit.from_task(
name, task, split_definitions, filter, request.description
name,
task,
split_definitions,
filter_type=request.filter_type,
description=request.description,
)
dataset_split.save_to_file()
return dataset_split
Expand Down
21 changes: 16 additions & 5 deletions app/desktop/studio_server/test_finetune_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
HighRatingDatasetFilter,
Project,
Task,
ThinkingModelDatasetFilter,
ThinkingModelHighRatedFilter,
Train60Test20Val20SplitDefinition,
Train80Test10Val10SplitDefinition,
Train80Test20SplitDefinition,
Expand Down Expand Up @@ -300,12 +302,19 @@ def test_api_split_types_mapping():


def test_api_filter_types_mapping():
from app.desktop.studio_server.finetune_api import api_filter_types
from kiln_ai.datamodel import dataset_filters

assert api_filter_types[DatasetFilterType.ALL] == AllDatasetFilter
assert api_filter_types[DatasetFilterType.HIGH_RATING] == HighRatingDatasetFilter
assert dataset_filters[DatasetFilterType.ALL] == AllDatasetFilter
assert dataset_filters[DatasetFilterType.HIGH_RATING] == HighRatingDatasetFilter
assert (
dataset_filters[DatasetFilterType.THINKING_MODEL] == ThinkingModelDatasetFilter
)
assert (
dataset_filters[DatasetFilterType.THINKING_MODEL_HIGH_RATED]
== ThinkingModelHighRatedFilter
)
for filter_type in DatasetFilterType:
assert filter_type in api_filter_types
assert filter_type in dataset_filters


@pytest.fixture
Expand All @@ -331,7 +340,7 @@ def test_create_dataset_split(
with mock_from_task as from_task_mock, mock_save as save_mock:
request_data = {
"dataset_split_type": "train_test",
"filter_type": "all",
"filter_type": "high_rating",
"name": "Test Split",
"description": "Test description",
}
Expand All @@ -348,6 +357,8 @@ def test_create_dataset_split(
# Verify the mocks were called correctly
mock_task_from_id_disk_backed.assert_called_once_with("project1", "task1")
from_task_mock.assert_called_once()
args, kwargs = from_task_mock.call_args
assert kwargs["filter_type"] == DatasetFilterType.HIGH_RATING
save_mock.assert_called_once()


Expand Down
6 changes: 4 additions & 2 deletions app/web_ui/src/lib/api_schema.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -837,10 +837,10 @@ export interface components {
DataSourceType: "human" | "synthetic";
/**
* DatasetFilterType
* @description Dataset filter types used in the API. Any filter style can be created in code.
* @description Dataset filter names.
* @enum {string}
*/
DatasetFilterType: "all" | "high_rating";
DatasetFilterType: "all" | "high_rating" | "thinking_model" | "thinking_model_high_rated";
/**
* DatasetSplit
* @description A collection of task runs, with optional splits (train, test, validation).
Expand Down Expand Up @@ -888,6 +888,8 @@ export interface components {
split_contents: {
[key: string]: string[];
};
/** @description The filter used to build the dataset. */
filter?: components["schemas"]["DatasetFilterType"] | null;
/** Model Type */
readonly model_type: string;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import { onMount } from "svelte"
import { formatDate } from "$lib/utils/formatters"
import type { FinetuneDataStrategy } from "$lib/types"
import Warning from "$lib/ui/warning.svelte"
import PromptTypeSelector from "../../../../run/prompt_type_selector.svelte"
Expand Down Expand Up @@ -41,6 +42,9 @@
let available_models_loading = true
$: selected_dataset = datasets?.find((d) => d.id === dataset_id)
$: selecting_thinking_dataset =
selected_dataset?.filter === "thinking_model" ||
selected_dataset?.filter === "thinking_model_high_rated"
$: selected_dataset_has_val = selected_dataset?.splits?.find(
(s) => s.name === "val",
)
Expand Down Expand Up @@ -646,21 +650,28 @@
{/if}
</div>
{/if}
<FormElement
label="Training Strategy"
description="Should the model be trained on the final response only, or also include intermediate thinking?"
info_description="If you select 'Final Response and Intermediate Thinking', the model will also be trained on the intermediate thinking such as reasoning or chain of thought. Use this if you want to call the tuned model with a chain-of-thought prompt for additional inference time compute."
inputType="select"
id="data_strategy"
select_options={[
["final_only", "Final Response Only"],
[
"final_and_intermediate",
"Final Response and Intermediate Thinking",
],
]}
bind:value={data_strategy}
/>
<div>
<FormElement
label="Training Strategy"
description="Should the model be trained on the final response only, or also include intermediate thinking?"
info_description="If you select 'Final Response and Intermediate Thinking', the model will also be trained on the intermediate thinking such as reasoning or chain of thought. Use this if you want to call the tuned model with a chain-of-thought prompt for additional inference time compute."
inputType="select"
id="data_strategy"
select_options={[
["final_only", "Final Response Only"],
[
"final_and_intermediate",
"Final Response and Intermediate Thinking",
],
]}
bind:value={data_strategy}
/>
{#if data_strategy === "final_and_intermediate" && !selecting_thinking_dataset}
<Warning
warning_message="You are training a model for inference-time thinking, but are not using a dataset filtered to samples with reasoning or chain-of-thought training data. This is not recommended, as it may lead to poor performance. We suggest creating a new dataset with a thinking filter."
/>
{/if}
</div>
{#if !is_download}
<div class="collapse collapse-arrow bg-base-200">
<input type="checkbox" class="peer" />
Expand Down Expand Up @@ -774,14 +785,22 @@
<FormElement
label="Dataset Filter"
description="Select a filter for your dataset. Typically you want to filter out examples that are not rated 4+ stars."
info_description="A 'High Rating' filter will include only examples that are rated 4+ stars. The 'All' filter will include all examples."
info_description="A 'High Rating' filter will include only examples that are rated 4+ stars. The 'All' filter will include all examples. Thinking filters will also check the sample has reasoning or chain-of-thought data for training thinking models."
inputType="select"
optional={false}
id="dataset_filter"
select_options={[
[disabled_header, "Select a dataset filter"],
["high_rating", "High Rating (4+ stars)"],
["all", "All (no filter)"],
[
"thinking_model",
"Thinking (items with reasoning/chain-of-thought)",
],
[
"thinking_model_high_rated",
"Thinking + High Rated (4+ stars and thinking)",
],
]}
bind:value={new_dataset_filter}
/>
Expand Down
17 changes: 5 additions & 12 deletions libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from uuid import uuid4

from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
from kiln_ai.adapters.prompt_builders import chain_of_thought_prompt
from kiln_ai.datamodel import DatasetSplit, FinetuneDataStrategy, TaskRun


Expand Down Expand Up @@ -83,23 +82,17 @@ def build_training_data(
thinking_final_answer_prompt = None
parent_task = task_run.parent_task()

if (
include_cot
and task_run.intermediate_outputs is not None
and (
"reasoning" in task_run.intermediate_outputs
or "chain_of_thought" in task_run.intermediate_outputs
)
):
if include_cot and task_run.has_thinking_training_data():
if not parent_task:
raise ValueError(
"TaskRuns for training required a parent Task for building a chain of thought prompts. Train without COT, or save this TaskRun to a parent Task."
)

# Prefer reasoning to cot if both are present
thinking = task_run.intermediate_outputs.get(
"reasoning"
) or task_run.intermediate_outputs.get("chain_of_thought")
intermediate_outputs = task_run.intermediate_outputs or {}
thinking = intermediate_outputs.get("reasoning") or intermediate_outputs.get(
"chain_of_thought"
)

thinking_final_answer_prompt = COT_FINAL_ANSWER_PROMPT

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def test_build_training_data_with_repaired_output(mock_task):
),
)

training_data_output = build_training_data(mock_task_run, "system message", True)
training_data_output = build_training_data(mock_task_run, "system message", False)
assert training_data_output.final_output == '{"test": "repaired output"}'
assert training_data_output.thinking is None
assert training_data_output.thinking_instructions is None
Expand Down
50 changes: 49 additions & 1 deletion libs/core/kiln_ai/datamodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,17 @@ class TaskRun(KilnParentedModel):
description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.",
)

def has_thinking_training_data(self) -> bool:
    """
    Return True when this run recorded intermediate 'reasoning' or
    'chain_of_thought' output — data usable to train a thinking model.
    """
    outputs = self.intermediate_outputs
    if outputs is None:
        return False
    # Either key qualifies; presence is all that matters, not the content.
    return any(key in outputs for key in ("chain_of_thought", "reasoning"))

def parent_task(self) -> Task | None:
if not isinstance(self.parent, Task):
return None
Expand Down Expand Up @@ -663,6 +674,37 @@ def HighRatingDatasetFilter(task_run: TaskRun) -> bool:
return task_run.output.rating.is_high_quality()


def ThinkingModelDatasetFilter(task_run: TaskRun) -> bool:
    """
    Dataset filter: keep task runs that have intermediate outputs (reasoning
    or chain of thought) we can use to train a 'thinking' model.
    """
    has_thinking = task_run.has_thinking_training_data()
    return has_thinking


def ThinkingModelHighRatedFilter(task_run: TaskRun) -> bool:
    """
    Dataset filter: keep task runs that both have thinking training data and
    a high-quality output rating.
    """
    # Same short-circuit order as `a and b`: cheap thinking-data check first.
    if not ThinkingModelDatasetFilter(task_run):
        return False
    return HighRatingDatasetFilter(task_run)


class DatasetFilterType(str, Enum):
    """Dataset filter names.

    API-stable identifiers for the dataset filter callables; resolved to the
    actual filter functions via the `dataset_filters` mapping in this module.
    """

    ALL = "all"  # keep every task run (no filtering)
    HIGH_RATING = "high_rating"  # keep only runs with a high-quality output rating
    THINKING_MODEL = "thinking_model"  # keep only runs with reasoning/chain-of-thought intermediate outputs
    THINKING_MODEL_HIGH_RATED = "thinking_model_high_rated"  # thinking data AND high rating


# Registry mapping each DatasetFilterType to its filter callable.
# DatasetSplit.from_task uses this to resolve a filter_type into the
# predicate applied when building split contents.
dataset_filters = {
    DatasetFilterType.ALL: AllDatasetFilter,
    DatasetFilterType.HIGH_RATING: HighRatingDatasetFilter,
    DatasetFilterType.THINKING_MODEL: ThinkingModelDatasetFilter,
    DatasetFilterType.THINKING_MODEL_HIGH_RATED: ThinkingModelHighRatedFilter,
}


class DatasetSplitDefinition(BaseModel):
"""
A definition of a split in a dataset.
Expand Down Expand Up @@ -722,6 +764,10 @@ class DatasetSplit(KilnParentedModel):
split_contents: dict[str, list[str]] = Field(
description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
)
filter: DatasetFilterType | None = Field(
default=None,
description="The filter used to build the dataset.",
)

@model_validator(mode="after")
def validate_split_percentages(self) -> "DatasetSplit":
Expand All @@ -736,19 +782,21 @@ def from_task(
name: str,
task: "Task",
splits: list[DatasetSplitDefinition],
filter: DatasetFilter = AllDatasetFilter,
filter_type: DatasetFilterType = DatasetFilterType.ALL,
description: str | None = None,
):
"""
Build a dataset split from a task.
"""
filter = dataset_filters[filter_type]
split_contents = cls.build_split_contents(task, splits, filter)
return cls(
parent=task,
name=name,
description=description,
splits=splits,
split_contents=split_contents,
filter=filter_type,
)

@classmethod
Expand Down
Loading

0 comments on commit 7eea98e

Please sign in to comment.