diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py
index d0277560..1b9e2e6b 100644
--- a/ultravox/data/datasets.py
+++ b/ultravox/data/datasets.py
@@ -192,7 +192,7 @@ class VoiceDatasetArgs:
     """If `prompt` is not set, the number of canned prompts to use."""
     include_audio: bool = True
     """Whether to include audio in the samples."""
-    include_context: bool = False
+    include_context: bool = True
     """Whether to include additional textual context from the dataset to the prompt."""
     shuffle: bool = False
     """Whether to shuffle the dataset."""
@@ -434,7 +434,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
 
     def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
         audio_transcript = row["chat"][0]["message"]
-        return VoiceSample(
+        return self._make_sample(
             self._get_transcribe_messages(idx, audio_transcript),
             self._load_anyinstruct_audio(row["chat"][0]["speech"]),
             audio_transcript=audio_transcript,
@@ -447,7 +447,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
 
     def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
         audio_transcript = row["chat"][1]["message"]
-        return VoiceSample(
+        return self._make_sample(
             self._get_transcribe_messages(idx, audio_transcript),
             self._load_anyinstruct_audio(row["chat"][1]["speech"]),
             audio_transcript=audio_transcript,
@@ -456,11 +456,10 @@ def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
 
 class BoolQDataset(VoiceDataset):
     def __init__(self, args: VoiceDatasetArgs) -> None:
-        assert (
-            args.split == DatasetSplit.VALIDATION
-        ), f"BoolQ is only for validation, but got split={args.split}"
         super().__init__(args)
-        dataset = self._load_audio_dataset("fixie-ai/boolq-audio", split="train")
+        dataset = self._load_audio_dataset(
+            "fixie-ai/boolq-audio", split=args.split.value
+        )
         self._init_dataset(dataset)
 
     def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
@@ -479,6 +478,96 @@ def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
         return self._get_transcribe_sample(idx, row, tcol="question")
 
 
+class BoolQWithExtendedAnswerDataset(BoolQDataset):
+    SEPARATORS = ["\n\n", "\n", "\n----\n"]
+    BOOLQ_PASSAGE_PROMPTS = [
+        "Provide a short explanation, then respond with True/False on the last line",
+        "Explain briefly, concluding with True/False on a new line.",
+        "Write a quick explanation, and finish with True/False on the last line",
+        "Summarize in a few words, and end with True/False on a new line.",
+        "Give a brief explanation first, then answer with True/False on the final line",
+        "Start with a concise explanation, and end with a True/False response on the last line.",
+        "Explain briefly and follow up with True/False at the end",
+        "Write a short explanation, then state True/False on a new line.",
+        "First, offer a brief explanation, and then reply with True/False at the end.",
+        "Present a concise explanation, ending with a True/False answer on the final line",
+        "Start with a brief explanation, and then answer with True/False at the end.",
+    ]
+    QUERY_PROMPTS = ["Question: ", "Question:\n", "Q: ", "Q:\n", "Query: ", "Query:\n"]
+    CONTEXT_PROMPTS = [
+        "Passage: ",
+        "Passage:\n",
+        "Context: ",
+        "Context:\n",
+        "Background: ",
+        "Background:\n",
+    ]
+    ANSWER_PROMPTS = [
+        "Answer: ",
+        "A: ",
+        "",
+        "The answer is: ",
+        "Result: ",
+        "Conclusion: ",
+    ]
+
+    def _get_query_prompt(self, idx: int) -> str:
+        """
+        Creates a random prompt for a BoolQ sample with a passage and question.
+        Example prompt:
+            Passage: {context}
+
+            Question: {question}
+
+            Provide a short explanation, then respond with True/False on the last line.
+        """
+        if self._args.prompt:
+            return self._args.prompt
+        prompt_idx = idx % min(self._args.num_prompts, len(self.BOOLQ_PASSAGE_PROMPTS))
+        prompt = self.BOOLQ_PASSAGE_PROMPTS[prompt_idx]
+
+        # Pick the separator, query, and context prefixes based on idx.
+        # 13, 17, 19 are prime numbers (to avoid a repeating pattern across pieces).
+        separator = self.SEPARATORS[
+            idx % 13 % min(self._args.num_prompts, len(self.SEPARATORS))
+        ]
+
+        query_prompt = self.QUERY_PROMPTS[
+            idx % 17 % min(self._args.num_prompts, len(self.QUERY_PROMPTS))
+        ]
+        prompt = f"{query_prompt}{{question}}{separator}{prompt}"
+
+        if self._args.include_context:
+            context_prompt = self.CONTEXT_PROMPTS[
+                idx % 19 % min(self._args.num_prompts, len(self.CONTEXT_PROMPTS))
+            ]
+            prompt = f"{context_prompt}{{context}}{separator}{prompt}"
+
+        return prompt
+
+    def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
+        answer = "True" if row["answer"] else "False"
+        answer_prompt = self.ANSWER_PROMPTS[
+            idx % 23 % min(self._args.num_prompts, len(self.ANSWER_PROMPTS))
+        ]
+        query_prompt = self._get_query_prompt(idx)
+        user_content = query_prompt.format(
+            question="<|audio|>" if self._args.include_audio else row["question"],
+            context=row["passage"],
+        )
+        messages = [
+            {"role": "user", "content": user_content},
+            {
+                "role": "assistant",
+                "content": f"{row['explanation']}\n{answer_prompt}{answer}",
+            },
+        ]
+
+        return self._make_sample(
+            messages, self._get_audio(row), audio_transcript=row["question"]
+        )
+
+
 class LibriSpeechDataset(VoiceDataset):
     """
     LibriSpeech is a corpus of approximately 1000 hours of 16kHz read
@@ -599,6 +688,7 @@ def create_dataset(name: str, args: VoiceDatasetArgs) -> data.IterableDataset:
         "anyinstruct_out": AnyInstructOutputDataset,
         "boolq": BoolQDataset,
        "boolq_in": BoolQInputDataset,
+        "boolq_extended": BoolQWithExtendedAnswerDataset,
         "gigaspeech": GigaSpeechDataset,
         "librispeech": LibriSpeechDataset,
         "voxpopuli": VoxPopuliDataset,
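For reference, the prompt assembly in `_get_query_prompt` is deterministic in `idx`: the instruction, separator, query prefix, and context prefix are each indexed with a different prime modulus so the combinations do not repeat in lockstep. The snippet below is a minimal standalone sketch of that same selection logic; the constants are abbreviated copies of the class attributes above, `build_prompt` is an illustrative name that is not part of this diff, and `num_prompts=11` mirrors the config change further down.

# Standalone sketch of the prompt rotation in BoolQWithExtendedAnswerDataset.
# PASSAGE_PROMPTS is truncated to one entry for brevity.
SEPARATORS = ["\n\n", "\n", "\n----\n"]
PASSAGE_PROMPTS = ["Provide a short explanation, then respond with True/False on the last line"]
QUERY_PROMPTS = ["Question: ", "Question:\n", "Q: ", "Q:\n", "Query: ", "Query:\n"]
CONTEXT_PROMPTS = ["Passage: ", "Passage:\n", "Context: ", "Context:\n", "Background: ", "Background:\n"]

def build_prompt(idx: int, num_prompts: int = 11, include_context: bool = True) -> str:
    instruction = PASSAGE_PROMPTS[idx % min(num_prompts, len(PASSAGE_PROMPTS))]
    separator = SEPARATORS[idx % 13 % min(num_prompts, len(SEPARATORS))]
    query = QUERY_PROMPTS[idx % 17 % min(num_prompts, len(QUERY_PROMPTS))]
    prompt = f"{query}{{question}}{separator}{instruction}"
    if include_context:
        context = CONTEXT_PROMPTS[idx % 19 % min(num_prompts, len(CONTEXT_PROMPTS))]
        prompt = f"{context}{{context}}{separator}{prompt}"
    return prompt

# build_prompt(0) yields "Passage: {context}\n\nQuestion: {question}\n\nProvide a short ..."
print(build_prompt(0).format(context="<passage text>", question="<|audio|>"))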
diff --git a/ultravox/evaluation/eval.py b/ultravox/evaluation/eval.py
index 7e6f2eda..f574454d 100644
--- a/ultravox/evaluation/eval.py
+++ b/ultravox/evaluation/eval.py
@@ -1,5 +1,6 @@
 from ultravox.evaluation import eval_types
 from ultravox.evaluation import gpt_eval
+from ultravox.evaluation import string_based
 from ultravox.evaluation import wer
 
 
@@ -10,5 +11,7 @@ def evaluate_answer(sample: eval_types.Sample, metric: str) -> eval_types.Result
         return gpt_eval.evaluate_answer_boolq(sample)
     elif metric == "instruct":
         return gpt_eval.evaluate_answer_instruct(sample)
+    elif metric == "exact_match_last_word":
+        return string_based.match_last_word(sample)
     else:
         raise ValueError(f"Unknown metric: {metric}")
diff --git a/ultravox/evaluation/eval_types.py b/ultravox/evaluation/eval_types.py
index 63788ce0..c473e55a 100644
--- a/ultravox/evaluation/eval_types.py
+++ b/ultravox/evaluation/eval_types.py
@@ -26,4 +26,12 @@ class WerResult:
     score: float
 
 
-Result = Union[InstructResult, WerResult]
+@dataclasses.dataclass
+class ExactMatchResult:
+    """Score is 1 when the generated answer matches the expected answer exactly, 0 otherwise."""
+
+    score: float
+    reason: str
+
+
+Result = Union[InstructResult, WerResult, ExactMatchResult]
diff --git a/ultravox/evaluation/string_based.py b/ultravox/evaluation/string_based.py
new file mode 100644
index 00000000..c2d1ab03
--- /dev/null
+++ b/ultravox/evaluation/string_based.py
@@ -0,0 +1,23 @@
+import re
+
+from ultravox.evaluation import eval_types
+
+
+def match_last_word(sample: eval_types.Sample) -> eval_types.ExactMatchResult:
+    last_words = re.findall(r"\b\w+\b(?=\W*$)", sample.generated_answer.lower())
+    expected_tf = re.findall(r"\b\w+\b(?=\W*$)", sample.expected_answer.lower())[-1]
+
+    if not last_words:
+        return eval_types.ExactMatchResult(score=0, reason="No last word found")
+
+    last_word: str = last_words[-1]
+    if last_word in ["yes", "true"]:
+        last_word = "true"
+    elif last_word in ["no", "false"]:
+        last_word = "false"
+    else:
+        return eval_types.ExactMatchResult(score=0, reason="Last word not true/false")
+
+    return eval_types.ExactMatchResult(
+        score=last_word == expected_tf, reason="exact_match check"
+    )
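As a quick sanity check, the scorer above only reads the `generated_answer` and `expected_answer` attributes, so a duck-typed stand-in is enough to exercise it; the `SimpleNamespace` below is illustrative and not part of this diff, and in real use the sample arrives as an `eval_types.Sample` via `eval.evaluate_answer(sample, metric="exact_match_last_word")`.

import types

from ultravox.evaluation import string_based

# Duck-typed stand-in for eval_types.Sample; only the two attributes read by
# match_last_word are provided here.
sample = types.SimpleNamespace(
    generated_answer="The passage describes a successful trial.\nAnswer: True",
    expected_answer="True",
)
result = string_based.match_last_word(sample)
# The last word "true" matches the expected "true", so score compares equal to 1;
# an answer that does not end in yes/no/true/false would score 0 instead.
print(result.score, result.reason)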
diff --git a/ultravox/training/configs/stage2_lora.yaml b/ultravox/training/configs/stage2_lora.yaml
index 538b18a7..c8f957ba 100644
--- a/ultravox/training/configs/stage2_lora.yaml
+++ b/ultravox/training/configs/stage2_lora.yaml
@@ -1,16 +1,23 @@
+exp_name: stage2_lora__gs_ai_bq
+
 text_model_lora_config:
   r: 64 # no/little change in the range [16, 64]
   target_modules: ['mlp.gate_proj', 'mlp.up_proj', 'mlp.down_proj', 'v_proj', 'o_proj', 'k_proj', 'q_proj']
 
-data_sets: ["commonvoice", "peoplespeech", "anyinstruct"]
+data_sets: ["gigaspeech", "anyinstruct", "boolq_extended"]
 
 # disable_layer_drop: True
 # audio_model_lora_config:
 #   r: 64
 #   target_modules: ['k_proj', 'q_proj', 'v_proj', 'out_proj', 'intermediate_dense', 'output_dense']
 
-num_prompts: 6
+num_prompts: 11
 
 lr: 1.e-4 # need a lower LR for LLM fine-tuning
+lr_scheduler: constant_with_warmup
 lr_warmup_steps: 250
 max_steps: 5_000
+save_steps: 0
+
+batch_size: 2
+grad_accum_steps: 2
diff --git a/ultravox/training/evaluation.py b/ultravox/training/evaluation.py
index 7c4468b0..9785ac92 100644
--- a/ultravox/training/evaluation.py
+++ b/ultravox/training/evaluation.py
@@ -1,4 +1,5 @@
 import concurrent.futures
+import dataclasses
 import functools
 import os
 from typing import List, Optional
@@ -44,14 +45,37 @@ def dataset_infer(
     return ddp_utils.all_gather_list(eval_samples)
 
 
-def get_metric_name(ds_name: str, metric: str) -> str:
-    if ds_name == "boolq_in" and metric == "asr":
-        return "boolq__wer"
-    if ds_name == "boolq" and metric == "boolq":
-        return "boolq__correctness"
-    if metric == "instruct":
-        return f"{ds_name}__instruct_follow"
-    return f"{ds_name}__{metric}"
+@dataclasses.dataclass
+class EvalScenario:
+    name: str
+    dataset: str
+    metric: str
+    include_audio: bool = True
+    include_context: bool = True
+    new_tokens: Optional[int] = None
+
+
+EVAL_SCENARIOS = [
+    EvalScenario("anyinstruct__instruct_follow", "anyinstruct", "instruct"),
+    EvalScenario(
+        "boolq__binary", "boolq_extended", "exact_match_last_word", new_tokens=128
+    ),
+    EvalScenario("boolq__wer", "boolq_in", "asr"),
+    # Text-only scenarios: tests for catastrophic forgetting.
+    EvalScenario(
+        "anyinstruct__instruct_follow__text_only",
+        "anyinstruct",
+        "instruct",
+        include_audio=False,
+    ),
+    EvalScenario(
+        "boolq__binary__text_only",
+        "boolq_extended",
+        "exact_match_last_word",
+        new_tokens=128,
+        include_audio=False,
+    ),
+]
 
 
 def evaluate(
@@ -68,21 +92,20 @@
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
 
-    ds_args = datasets.VoiceDatasetArgs(
-        data_dir=data_dir, split=datasets.DatasetSplit.VALIDATION
-    )
+    for task in EVAL_SCENARIOS:
+        ds_args = datasets.VoiceDatasetArgs(
+            data_dir=data_dir,
+            split=datasets.DatasetSplit.VALIDATION,
+            include_audio=task.include_audio,
+            include_context=task.include_context,
+        )
 
-    for ds_name, metric in [
-        ("boolq_in", "asr"),
-        ("boolq", "boolq"),
-        ("anyinstruct", "instruct"),
-    ]:
-        ds = datasets.Range(datasets.create_dataset(ds_name, ds_args), num_samples)
+        ds = datasets.Range(datasets.create_dataset(task.dataset, ds_args), num_samples)
 
         output_samples = dataset_infer(
             inference,
             ds=ds,
-            max_new_tokens=max_new_tokens,
+            max_new_tokens=task.new_tokens or max_new_tokens,
             temperature=temperature,
             world_size=world_size,
             local_rank=local_rank,
@@ -92,7 +115,7 @@
             # Only the master process should evaluate the samples.
             continue
 
-        eval_per_sample = functools.partial(eval.evaluate_answer, metric=metric)
+        eval_per_sample = functools.partial(eval.evaluate_answer, metric=task.metric)
 
         with concurrent.futures.ThreadPoolExecutor(max_workers=num_procs) as executor:
             possibly_non_scores = [
@@ -100,13 +123,13 @@
             ]
 
         if None in possibly_non_scores:
-            print(f"Failed to evaluate {metric} for {ds_name}")
+            print(f"Failed to evaluate {task.metric} for {task.dataset}")
            continue
 
         scores = [x for x in possibly_non_scores if x is not None]
 
         if verbose:
-            print(f"Eval for {ds_name}:")
+            print(f"Eval for {task.dataset}:")
             for sample, score in zip(output_samples, scores):
                 print("-" * 20)
                 print(f"Q: {sample.question}")
@@ -115,10 +138,11 @@
         average = np.mean(scores)
         std = np.std(scores) / np.sqrt(len(scores))
 
-        metric_name = get_metric_name(ds_name, metric)
-        metrics[f"eval_{metric_name}"] = average
-        metrics[f"eval_{metric_name}_std"] = std
+        metrics[f"eval_{task.name}"] = average
+        metrics[f"eval_{task.name}_std"] = std
 
-        print(f"Aggregate {metric} score for {ds_name}: {average:.2f} ± {std:.2f}")
+        print(
+            f"Aggregate {task.metric} score for {task.dataset}: {average:.2f} ± {std:.2f}"
+        )
 
     return metrics
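With this change the keys reported by `evaluate()` are derived directly from each scenario's `name` (plus a `_std` companion) rather than from `get_metric_name`. The small sketch below only enumerates the keys implied by `EVAL_SCENARIOS`; `expected_keys` is an illustrative name, no scores are implied, and it assumes the `ultravox.training.evaluation` module imports cleanly in your environment.

from ultravox.training.evaluation import EVAL_SCENARIOS

# One mean and one standard-error key per scenario, e.g. "eval_boolq__binary"
# and "eval_boolq__binary_std"; the values come from the eval run itself.
expected_keys = [
    key
    for scenario in EVAL_SCENARIOS
    for key in (f"eval_{scenario.name}", f"eval_{scenario.name}_std")
]
print(expected_keys)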