BoolQ for Training and Eval #30

Merged · 7 commits · Jun 21, 2024
104 changes: 97 additions & 7 deletions ultravox/data/datasets.py
@@ -192,7 +192,7 @@ class VoiceDatasetArgs:
"""If `prompt` is not set, the number of canned prompts to use."""
include_audio: bool = True
"""Whether to include audio in the samples."""
- include_context: bool = False
include_context: bool = True
"""Whether to include additional textual context from the dataset to the prompt."""
shuffle: bool = False
"""Whether to shuffle the dataset."""
@@ -434,7 +434,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:

def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
audio_transcript = row["chat"][0]["message"]
- return VoiceSample(
return self._make_sample(
self._get_transcribe_messages(idx, audio_transcript),
self._load_anyinstruct_audio(row["chat"][0]["speech"]),
audio_transcript=audio_transcript,
@@ -447,7 +447,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:

def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
audio_transcript = row["chat"][1]["message"]
- return VoiceSample(
return self._make_sample(
self._get_transcribe_messages(idx, audio_transcript),
self._load_anyinstruct_audio(row["chat"][1]["speech"]),
audio_transcript=audio_transcript,
@@ -456,11 +456,10 @@ def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:

class BoolQDataset(VoiceDataset):
def __init__(self, args: VoiceDatasetArgs) -> None:
- assert (
- args.split == DatasetSplit.VALIDATION
- ), f"BoolQ is only for validation, but got split={args.split}"
super().__init__(args)
- dataset = self._load_audio_dataset("fixie-ai/boolq-audio", split="train")
dataset = self._load_audio_dataset(
"fixie-ai/boolq-audio", split=args.split.value
)
self._init_dataset(dataset)

def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
@@ -479,6 +478,96 @@ def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
return self._get_transcribe_sample(idx, row, tcol="question")


class BoolQWithExtendedAnswerDataset(BoolQDataset):
SEPARATORS = ["\n\n", "\n", "\n----\n"]
BOOLQ_PASSAGE_PROMPTS = [
"Provide a short explanation, then respond with True/False on the last line",
"Explain briefly, concluding with True/False on a new line."
"Write a quick explanation, and finish with True/False on the last line"
"Summarize in a few words, and end with True/False on a new line."
"Give a brief explanation first, then answer with True/False on the final line",
"Start with a concise explanation, and end with a True/False response on the last line.",
"Explain briefly and follow up with True/False at the end",
"Write a short explanation, then state True/False on a new line.",
"First, offer a brief explanation, and then reply with True/False at the end.",
"Present a concise explanation, ending with a True/False answer on the final line",
"Start with a brief explanation, and then answer with True/False at the end.",
]
QUERY_PROMPTS = ["Question: ", "Question:\n", "Q: ", "Q:\n", "Query: ", "Query:\n"]
CONTEXT_PROMPTS = [
"Passage: ",
"Passage:\n",
"Context: ",
"Context:\n",
"Background: ",
"Background:\n",
]
ANSWER_PROMPTS = [
"Answer: ",
"A: ",
"",
"The answer is: ",
"Result: ",
"Conclusion: ",
]

def _get_query_prompt(self, idx: int) -> str:
"""
Creates a pseudo-random prompt (keyed off the sample index) for a BoolQ sample with a passage and question.
Example prompt:
Passage: {context}

Question: {question}

Provide a short explanation, then respond with True/False on the last line.
"""
if self._args.prompt:
return self._args.prompt
prompt_idx = idx % min(self._args.num_prompts, len(self.BOOLQ_PASSAGE_PROMPTS))
prompt = self.BOOLQ_PASSAGE_PROMPTS[prompt_idx]

# Vary the separator between sections based on idx.
# 13, 17, 19 (and 23 below) are primes, so the separator/query/context/answer
# choices don't fall into a repeating pattern across samples.
separator = self.SEPARATORS[
idx % 13 % min(self._args.num_prompts, len(self.SEPARATORS))
]

query_prompt = self.QUERY_PROMPTS[
idx % 17 % min(self._args.num_prompts, len(self.QUERY_PROMPTS))
]
prompt = f"{query_prompt}{{question}}{separator}{prompt}"

if self._args.include_context:
context_prompt = self.CONTEXT_PROMPTS[
idx % 19 % min(self._args.num_prompts, len(self.CONTEXT_PROMPTS))
]
prompt = f"{context_prompt}{{context}}{separator}{prompt}"

return prompt

def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
answer = "True" if row["answer"] else "False"
answer_prompt = self.ANSWER_PROMPTS[
idx % 23 % min(self._args.num_prompts, len(self.ANSWER_PROMPTS))
]
query_prompt = self._get_query_prompt(idx)
user_content = query_prompt.format(
question="<|audio|>" if self._args.include_audio else row["question"],
context=row["passage"],
)
messages = [
{"role": "user", "content": user_content},
{
"role": "assistant",
"content": f"{row['explanation']}\n{answer_prompt}{answer}",
},
]

return self._make_sample(
messages, self._get_audio(row), audio_transcript=row["question"]
)


class LibriSpeechDataset(VoiceDataset):
"""
LibriSpeech is a corpus of approximately 1000 hours of 16kHz read
@@ -599,6 +688,7 @@ def create_dataset(name: str, args: VoiceDatasetArgs) -> data.IterableDataset:
"anyinstruct_out": AnyInstructOutputDataset,
"boolq": BoolQDataset,
"boolq_in": BoolQInputDataset,
"boolq_extended": BoolQWithExtendedAnswerDataset,
"gigaspeech": GigaSpeechDataset,
"librispeech": LibriSpeechDataset,
"voxpopuli": VoxPopuliDataset,
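To make the prompt-construction logic above easier to review, here is a standalone sketch that reproduces `_get_query_prompt` outside the class. The `build_prompt` helper is purely illustrative and the prompt lists are abbreviated; it is not part of this PR.

```python
# Illustrative re-creation of the prompt assembly in BoolQWithExtendedAnswerDataset.
# The moduli 13, 17, 19 are primes, so the separator/query/context choices don't
# cycle in lockstep with the main prompt choice.
SEPARATORS = ["\n\n", "\n", "\n----\n"]
PASSAGE_PROMPTS = [
    "Provide a short explanation, then respond with True/False on the last line",
    "Explain briefly, concluding with True/False on a new line.",
]  # abbreviated; the PR defines 11 of these
QUERY_PROMPTS = ["Question: ", "Question:\n", "Q: ", "Q:\n"]
CONTEXT_PROMPTS = ["Passage: ", "Passage:\n", "Context: ", "Context:\n"]


def build_prompt(idx: int, num_prompts: int, include_context: bool = True) -> str:
    prompt = PASSAGE_PROMPTS[idx % min(num_prompts, len(PASSAGE_PROMPTS))]
    separator = SEPARATORS[idx % 13 % min(num_prompts, len(SEPARATORS))]
    query = QUERY_PROMPTS[idx % 17 % min(num_prompts, len(QUERY_PROMPTS))]
    prompt = f"{query}{{question}}{separator}{prompt}"
    if include_context:
        context = CONTEXT_PROMPTS[idx % 19 % min(num_prompts, len(CONTEXT_PROMPTS))]
        prompt = f"{context}{{context}}{separator}{prompt}"
    return prompt


# The training target pairs this user prompt with
# f"{explanation}\n{answer_prompt}{answer}" as the assistant message.
print(build_prompt(idx=5, num_prompts=11).format(question="<|audio|>", context="..."))
```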
3 changes: 3 additions & 0 deletions ultravox/evaluation/eval.py
@@ -1,5 +1,6 @@
from ultravox.evaluation import eval_types
from ultravox.evaluation import gpt_eval
from ultravox.evaluation import string_based
from ultravox.evaluation import wer


@@ -10,5 +11,7 @@ def evaluate_answer(sample: eval_types.Sample, metric: str) -> eval_types.Result
return gpt_eval.evaluate_answer_boolq(sample)
elif metric == "instruct":
return gpt_eval.evaluate_answer_instruct(sample)
elif metric == "exact_match_last_word":
return string_based.match_last_word(sample)
else:
raise ValueError(f"Unknown metric: {metric}")
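Callers reach the new string-based matcher through the same dispatch as the GPT-based metrics. A rough usage sketch follows; the `Sample` constructor fields are assumed from how `string_based.py` reads them.

```python
from ultravox.evaluation import eval, eval_types

# Hypothetical sample; field names follow their usage in string_based.match_last_word.
sample = eval_types.Sample(
    question="Is the sky blue?",
    generated_answer="Blue light is scattered more strongly.\nAnswer: True",
    expected_answer="True",
)
result = eval.evaluate_answer(sample, metric="exact_match_last_word")
print(result.score, result.reason)  # -> True exact_match check
```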
10 changes: 9 additions & 1 deletion ultravox/evaluation/eval_types.py
@@ -26,4 +26,12 @@ class WerResult:
score: float


- Result = Union[InstructResult, WerResult]
@dataclasses.dataclass
class ExactMatchResult:
"""Score is the 0-1 evaluation of the accuracy of the generated answer being equal to expected answer."""

score: float
reason: str


Result = Union[InstructResult, WerResult, ExactMatchResult]
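The widened `Result` union means downstream code can branch on the concrete result type. A small sketch of that pattern (hypothetical consumer code, not from this PR):

```python
from ultravox.evaluation import eval_types


def describe(result: eval_types.Result) -> str:
    # ExactMatchResult carries a reason string alongside the 0/1 score.
    if isinstance(result, eval_types.ExactMatchResult):
        return f"exact match: {result.score} ({result.reason})"
    if isinstance(result, eval_types.WerResult):
        return f"WER: {result.score:.3f}"
    return f"instruct: {result.score}"
```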
23 changes: 23 additions & 0 deletions ultravox/evaluation/string_based.py
@@ -0,0 +1,23 @@
import re

from ultravox.evaluation import eval_types


def match_last_word(sample: eval_types.Sample) -> eval_types.ExactMatchResult:
last_words = re.findall(r"\b\w+\b(?=\W*$)", sample.generated_answer.lower())
expected_tf = re.findall(r"\b\w+\b(?=\W*$)", sample.expected_answer.lower())[-1]

if not last_words:
return eval_types.ExactMatchResult(score=0, reason="No last word found")

last_word: str = last_words[-1]
if last_word in ["yes", "true"]:
last_word = "true"
elif last_word in ["no", "false"]:
last_word = "false"
else:
return eval_types.ExactMatchResult(score=0, reason="Last word not true/false")

return eval_types.ExactMatchResult(
score=last_word == expected_tf, reason="exact_match check"
)
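A few illustrative inputs for the regex above; the lookahead `(?=\W*$)` pins the match to the final word, so trailing punctuation or newlines after "True"/"False" don't break the check.

```python
import re

LAST_WORD = r"\b\w+\b(?=\W*$)"

for text in ["Answer: True.", "...so the claim is false!\n", "It depends."]:
    words = re.findall(LAST_WORD, text.lower())
    print(repr(text), "->", words[-1] if words else None)
# 'Answer: True.'                -> true
# '...so the claim is false!\n' -> false
# 'It depends.'                  -> depends  (scored 0: "Last word not true/false")
```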
11 changes: 9 additions & 2 deletions ultravox/training/configs/stage2_lora.yaml
@@ -1,16 +1,23 @@
exp_name: stage2_lora__gs_ai_bq

text_model_lora_config:
r: 64 # no/little change in the range [16, 64]
target_modules: ['mlp.gate_proj', 'mlp.up_proj', 'mlp.down_proj', 'v_proj', 'o_proj', 'k_proj', 'q_proj']

data_sets: ["commonvoice", "peoplespeech", "anyinstruct"]
data_sets: ["gigaspeech", "anyinstruct", "boolq_extended"]

# disable_layer_drop: True
# audio_model_lora_config:
# r: 64
# target_modules: ['k_proj', 'q_proj', 'v_proj', 'out_proj', 'intermediate_dense', 'output_dense']

- num_prompts: 6
num_prompts: 11

lr: 1.e-4 # need a lower LR for LLM fine-tuning
lr_scheduler: constant_with_warmup
lr_warmup_steps: 250
max_steps: 5_000
save_steps: 0

batch_size: 2
grad_accum_steps: 2
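For context on the LoRA block at the top: `r` and `target_modules` presumably map onto a PEFT-style adapter config applied to the text model. The sketch below shows that assumed mapping (it is not the project's actual wiring) plus the effective batch size implied by the last two lines.

```python
from peft import LoraConfig  # assumption: the trainer builds something equivalent

# Assumed equivalent of text_model_lora_config above.
text_model_lora = LoraConfig(
    r=64,  # adapter rank; the YAML comment notes little change across [16, 64]
    target_modules=[
        "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
        "v_proj", "o_proj", "k_proj", "q_proj",
    ],
)

# batch_size * grad_accum_steps = 2 * 2 = 4 samples per optimizer step (per device).
EFFECTIVE_BATCH_SIZE = 2 * 2
```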
74 changes: 49 additions & 25 deletions ultravox/training/evaluation.py
@@ -1,4 +1,5 @@
import concurrent.futures
import dataclasses
import functools
import os
from typing import List, Optional
@@ -44,14 +45,37 @@ def dataset_infer(
return ddp_utils.all_gather_list(eval_samples)


- def get_metric_name(ds_name: str, metric: str) -> str:
- if ds_name == "boolq_in" and metric == "asr":
- return "boolq__wer"
- if ds_name == "boolq" and metric == "boolq":
- return "boolq__correctness"
- if metric == "instruct":
- return f"{ds_name}__instruct_follow"
- return f"{ds_name}__{metric}"
@dataclasses.dataclass
class EvalScenario:
name: str
dataset: str
metric: str
include_audio: bool = True
include_context: bool = True
new_tokens: Optional[int] = None


EVAL_SCENARIOS = [
EvalScenario("anyinstruct__instruct_follow", "anyinstruct", "instruct"),
EvalScenario(
"boolq__binary", "boolq_extended", "exact_match_last_word", new_tokens=128
),
EvalScenario("boolq__wer", "boolq_in", "asr"),
# Text-only scenarios: tests for catastrophic forgetting.
EvalScenario(
"anyinstruct__instruct_follow__text_only",
"anyinstruct",
"instruct",
include_audio=False,
),
EvalScenario(
"boolq__binary__text_only",
"boolq_extended",
"exact_match_last_word",
new_tokens=128,
include_audio=False,
),
]


def evaluate(
@@ -68,21 +92,20 @@ def evaluate(
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))

- ds_args = datasets.VoiceDatasetArgs(
- data_dir=data_dir, split=datasets.DatasetSplit.VALIDATION
- )
for task in EVAL_SCENARIOS:
ds_args = datasets.VoiceDatasetArgs(
data_dir=data_dir,
split=datasets.DatasetSplit.VALIDATION,
include_audio=task.include_audio,
include_context=task.include_context,
)

- for ds_name, metric in [
- ("boolq_in", "asr"),
- ("boolq", "boolq"),
- ("anyinstruct", "instruct"),
- ]:
- ds = datasets.Range(datasets.create_dataset(ds_name, ds_args), num_samples)
ds = datasets.Range(datasets.create_dataset(task.dataset, ds_args), num_samples)

output_samples = dataset_infer(
inference,
ds=ds,
- max_new_tokens=max_new_tokens,
max_new_tokens=task.new_tokens or max_new_tokens,
temperature=temperature,
world_size=world_size,
local_rank=local_rank,
# Only the master process should evaluate the samples.
continue

- eval_per_sample = functools.partial(eval.evaluate_answer, metric=metric)
eval_per_sample = functools.partial(eval.evaluate_answer, metric=task.metric)

with concurrent.futures.ThreadPoolExecutor(max_workers=num_procs) as executor:
possibly_non_scores = [
x.score for x in executor.map(eval_per_sample, output_samples)
]

if None in possibly_non_scores:
print(f"Failed to evaluate {metric} for {ds_name}")
print(f"Failed to evaluate {task.metric} for {task.dataset}")
continue

scores = [x for x in possibly_non_scores if x is not None]

if verbose:
print(f"Eval for {ds_name}:")
print(f"Eval for {task.dataset}:")
for sample, score in zip(output_samples, scores):
print("-" * 20)
print(f"Q: {sample.question}")

average = np.mean(scores)
std = np.std(scores) / np.sqrt(len(scores))
- metric_name = get_metric_name(ds_name, metric)
- metrics[f"eval_{metric_name}"] = average
- metrics[f"eval_{metric_name}_std"] = std
metrics[f"eval_{task.name}"] = average
metrics[f"eval_{task.name}_std"] = std

print(f"Aggregate {metric} score for {ds_name}: {average:.2f} ± {std:.2f}")
print(
f"Aggregate {task.metric} score for {task.dataset}: {average:.2f} ± {std:.2f}"
)

return metrics
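Net effect of the reporting change: metric keys now come straight from `EvalScenario.name` rather than `get_metric_name`, and `_std` is the standard error of the mean (`np.std(scores) / sqrt(n)`). With the scenarios defined above, `evaluate()` returns a dict shaped like the sketch below (the values are placeholders, not measured results).

```python
# Illustrative shape only; the 0.0 values are placeholders.
metrics = {
    "eval_anyinstruct__instruct_follow": 0.0,
    "eval_anyinstruct__instruct_follow_std": 0.0,
    "eval_boolq__binary": 0.0,
    "eval_boolq__binary_std": 0.0,
    "eval_boolq__wer": 0.0,
    "eval_boolq__wer_std": 0.0,
    "eval_anyinstruct__instruct_follow__text_only": 0.0,
    "eval_anyinstruct__instruct_follow__text_only_std": 0.0,
    "eval_boolq__binary__text_only": 0.0,
    "eval_boolq__binary__text_only_std": 0.0,
}
```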