BoolQ for Training and Eval (#30)
* set default to include_context=True

* boolq extended dataset for training

* improved evals: BoolQ True/False eval and text-only eval scenarios
farzadab authored Jun 21, 2024
1 parent 0e4ae3b commit 4202b56
Showing 6 changed files with 190 additions and 35 deletions.
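For orientation, here is a minimal usage sketch of the training path this commit enables. It is not part of the diff: it assumes the fixie-ai/boolq-audio dataset is reachable from your environment and that DatasetSplit.TRAIN exists alongside the VALIDATION split used during eval.

# Sketch only: select the new BoolQ dataset by its registry name.
from ultravox.data import datasets

args = datasets.VoiceDatasetArgs(
    split=datasets.DatasetSplit.TRAIN,  # assumed; BoolQ is no longer validation-only
    include_audio=True,                 # set False for the text-only variant
    # include_context now defaults to True, so the passage is prepended to the prompt
)
ds = datasets.create_dataset("boolq_extended", args)
sample = next(iter(ds))                 # one VoiceSample with prompt, answer, and (optionally) audio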
104 changes: 97 additions & 7 deletions ultravox/data/datasets.py
@@ -192,7 +192,7 @@ class VoiceDatasetArgs:
"""If `prompt` is not set, the number of canned prompts to use."""
include_audio: bool = True
"""Whether to include audio in the samples."""
include_context: bool = False
include_context: bool = True
"""Whether to include additional textual context from the dataset to the prompt."""
shuffle: bool = False
"""Whether to shuffle the dataset."""
@@ -434,7 +434,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:

def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
audio_transcript = row["chat"][0]["message"]
return VoiceSample(
return self._make_sample(
self._get_transcribe_messages(idx, audio_transcript),
self._load_anyinstruct_audio(row["chat"][0]["speech"]),
audio_transcript=audio_transcript,
@@ -447,7 +447,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:

def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
audio_transcript = row["chat"][1]["message"]
return VoiceSample(
return self._make_sample(
self._get_transcribe_messages(idx, audio_transcript),
self._load_anyinstruct_audio(row["chat"][1]["speech"]),
audio_transcript=audio_transcript,

class BoolQDataset(VoiceDataset):
def __init__(self, args: VoiceDatasetArgs) -> None:
assert (
args.split == DatasetSplit.VALIDATION
), f"BoolQ is only for validation, but got split={args.split}"
super().__init__(args)
dataset = self._load_audio_dataset("fixie-ai/boolq-audio", split="train")
dataset = self._load_audio_dataset(
"fixie-ai/boolq-audio", split=args.split.value
)
self._init_dataset(dataset)

def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
@@ -479,6 +478,96 @@ def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
return self._get_transcribe_sample(idx, row, tcol="question")


class BoolQWithExtendedAnswerDataset(BoolQDataset):
SEPARATORS = ["\n\n", "\n", "\n----\n"]
BOOLQ_PASSAGE_PROMPTS = [
"Provide a short explanation, then respond with True/False on the last line",
"Explain briefly, concluding with True/False on a new line."
"Write a quick explanation, and finish with True/False on the last line"
"Summarize in a few words, and end with True/False on a new line."
"Give a brief explanation first, then answer with True/False on the final line",
"Start with a concise explanation, and end with a True/False response on the last line.",
"Explain briefly and follow up with True/False at the end",
"Write a short explanation, then state True/False on a new line.",
"First, offer a brief explanation, and then reply with True/False at the end.",
"Present a concise explanation, ending with a True/False answer on the final line",
"Start with a brief explanation, and then answer with True/False at the end.",
]
QUERY_PROMPTS = ["Question: ", "Question:\n", "Q: ", "Q:\n", "Query: ", "Query:\n"]
CONTEXT_PROMPTS = [
"Passage: ",
"Passage:\n",
"Context: ",
"Context:\n",
"Background: ",
"Background:\n",
]
ANSWER_PROMPTS = [
"Answer: ",
"A: ",
"",
"The answer is: ",
"Result: ",
"Conclusion: ",
]

def _get_query_prompt(self, idx: int) -> str:
"""
Creates a random prompt for a BoolQ sample with a passage and question.
Example prompt:
Passage: {context}
Question: {question}
Provide a short explanation, then respond with True/False on the last line.
"""
if self._args.prompt:
return self._args.prompt
prompt_idx = idx % min(self._args.num_prompts, len(self.BOOLQ_PASSAGE_PROMPTS))
prompt = self.BOOLQ_PASSAGE_PROMPTS[prompt_idx]

# Choose a separator based on idx: one or two newlines, or a dashed rule
# 13, 17, 19 (and 23 below) are primes, so the prompt choices don't fall into a repeating pattern
separator = self.SEPARATORS[
idx % 13 % min(self._args.num_prompts, len(self.SEPARATORS))
]

query_prompt = self.QUERY_PROMPTS[
idx % 17 % min(self._args.num_prompts, len(self.QUERY_PROMPTS))
]
prompt = f"{query_prompt}{{question}}{separator}{prompt}"

if self._args.include_context:
context_prompt = self.CONTEXT_PROMPTS[
idx % 19 % min(self._args.num_prompts, len(self.CONTEXT_PROMPTS))
]
prompt = f"{context_prompt}{{context}}{separator}{prompt}"

return prompt

def _get_sample(self, idx: int, row: transformers.BatchFeature) -> VoiceSample:
answer = "True" if row["answer"] else "False"
answer_prompt = self.ANSWER_PROMPTS[
idx % 23 % min(self._args.num_prompts, len(self.ANSWER_PROMPTS))
]
query_prompt = self._get_query_prompt(idx)
user_content = query_prompt.format(
question="<|audio|>" if self._args.include_audio else row["question"],
context=row["passage"],
)
messages = [
{"role": "user", "content": user_content},
{
"role": "assistant",
"content": f"{row['explanation']}\n{answer_prompt}{answer}",
},
]

return self._make_sample(
messages, self._get_audio(row), audio_transcript=row["question"]
)


class LibriSpeechDataset(VoiceDataset):
"""
LibriSpeech is a corpus of approximately 1000 hours of 16kHz read
@@ -599,6 +688,7 @@ def create_dataset(name: str, args: VoiceDatasetArgs) -> data.IterableDataset:
"anyinstruct_out": AnyInstructOutputDataset,
"boolq": BoolQDataset,
"boolq_in": BoolQInputDataset,
"boolq_extended": BoolQWithExtendedAnswerDataset,
"gigaspeech": GigaSpeechDataset,
"librispeech": LibriSpeechDataset,
"voxpopuli": VoxPopuliDataset,
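As a worked example of the prompt rotation in _get_query_prompt, the sketch below reimplements the selection logic outside the class with abbreviated prompt lists. The list contents and modular indexing are copied from the diff above; the build_prompt helper itself is hypothetical.

# Standalone sketch of the prompt rotation (abbreviated lists; helper is illustrative).
SEPARATORS = ["\n\n", "\n", "\n----\n"]
QUERY_PROMPTS = ["Question: ", "Question:\n", "Q: ", "Q:\n", "Query: ", "Query:\n"]
CONTEXT_PROMPTS = ["Passage: ", "Passage:\n", "Context: ", "Context:\n"]
INSTRUCTIONS = [
    "Provide a short explanation, then respond with True/False on the last line",
    "Explain briefly, concluding with True/False on a new line.",
]

def build_prompt(idx: int, num_prompts: int = 11, include_context: bool = True) -> str:
    instruction = INSTRUCTIONS[idx % min(num_prompts, len(INSTRUCTIONS))]
    separator = SEPARATORS[idx % 13 % min(num_prompts, len(SEPARATORS))]
    query = QUERY_PROMPTS[idx % 17 % min(num_prompts, len(QUERY_PROMPTS))]
    prompt = f"{query}{{question}}{separator}{instruction}"
    if include_context:
        context = CONTEXT_PROMPTS[idx % 19 % min(num_prompts, len(CONTEXT_PROMPTS))]
        prompt = f"{context}{{context}}{separator}{prompt}"
    return prompt

print(build_prompt(0).format(question="<|audio|>", context="..."))
# Passage: ...
#
# Question: <|audio|>
#
# Provide a short explanation, then respond with True/False on the last line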
3 changes: 3 additions & 0 deletions ultravox/evaluation/eval.py
@@ -1,5 +1,6 @@
from ultravox.evaluation import eval_types
from ultravox.evaluation import gpt_eval
from ultravox.evaluation import string_based
from ultravox.evaluation import wer


@@ -10,5 +11,7 @@ def evaluate_answer(sample: eval_types.Sample, metric: str) -> eval_types.Result
return gpt_eval.evaluate_answer_boolq(sample)
elif metric == "instruct":
return gpt_eval.evaluate_answer_instruct(sample)
elif metric == "exact_match_last_word":
return string_based.match_last_word(sample)
else:
raise ValueError(f"Unknown metric: {metric}")
10 changes: 9 additions & 1 deletion ultravox/evaluation/eval_types.py
@@ -26,4 +26,12 @@ class WerResult:
score: float


Result = Union[InstructResult, WerResult]
@dataclasses.dataclass
class ExactMatchResult:
"""Score is the 0-1 evaluation of the accuracy of the generated answer being equal to expected answer."""

score: float
reason: str


Result = Union[InstructResult, WerResult, ExactMatchResult]
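A small note on the float-typed score: the string matcher below returns a Python bool for the exact-match case, which still aggregates cleanly because np.mean in evaluation.py coerces booleans to 0/1. A quick check (sketch, not part of the diff):

# Quick check that bool scores average as expected.
import numpy as np

scores = [True, False, True, 0]   # mixed bool/int scores, as the matchers may produce
print(float(np.mean(scores)))     # 0.5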
23 changes: 23 additions & 0 deletions ultravox/evaluation/string_based.py
@@ -0,0 +1,23 @@
import re

from ultravox.evaluation import eval_types


def match_last_word(sample: eval_types.Sample) -> eval_types.ExactMatchResult:
last_words = re.findall(r"\b\w+\b(?=\W*$)", sample.generated_answer.lower())
expected_tf = re.findall(r"\b\w+\b(?=\W*$)", sample.expected_answer.lower())[-1]

if not last_words:
return eval_types.ExactMatchResult(score=0, reason="No last word found")

last_word: str = last_words[-1]
if last_word in ["yes", "true"]:
last_word = "true"
elif last_word in ["no", "false"]:
last_word = "false"
else:
return eval_types.ExactMatchResult(score=0, reason="Last word not true/false")

return eval_types.ExactMatchResult(
score=last_word == expected_tf, reason="exact_match check"
)
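A quick way to exercise the matcher, using types.SimpleNamespace as a stand-in for eval_types.Sample (only the two fields read here are provided; the real Sample may carry more):

import types

from ultravox.evaluation import string_based

sample = types.SimpleNamespace(
    generated_answer="The passage supports the claim.\nAnswer: True.",
    expected_answer="True",
)
result = string_based.match_last_word(sample)
print(result.score, result.reason)  # True exact_match check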
11 changes: 9 additions & 2 deletions ultravox/training/configs/stage2_lora.yaml
@@ -1,16 +1,23 @@
exp_name: stage2_lora__gs_ai_bq

text_model_lora_config:
r: 64 # no/little change in the range [16, 64]
target_modules: ['mlp.gate_proj', 'mlp.up_proj', 'mlp.down_proj', 'v_proj', 'o_proj', 'k_proj', 'q_proj']

data_sets: ["commonvoice", "peoplespeech", "anyinstruct"]
data_sets: ["gigaspeech", "anyinstruct", "boolq_extended"]

# disable_layer_drop: True
# audio_model_lora_config:
# r: 64
# target_modules: ['k_proj', 'q_proj', 'v_proj', 'out_proj', 'intermediate_dense', 'output_dense']

num_prompts: 6
num_prompts: 11

lr: 1.e-4 # need a lower LR for LLM fine-tuning
lr_scheduler: constant_with_warmup
lr_warmup_steps: 250
max_steps: 5_000
save_steps: 0

batch_size: 2
grad_accum_steps: 2
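The bump to num_prompts: 11 lines up with the 11 entries in BOOLQ_PASSAGE_PROMPTS: prompt variety is capped by the smaller of the config value and the list length. A rough standalone illustration of the cap used in _get_query_prompt (the helper is hypothetical):

# How num_prompts caps prompt rotation; 11 matches len(BOOLQ_PASSAGE_PROMPTS).
def prompt_index(idx: int, num_prompts: int, n_available: int = 11) -> int:
    return idx % min(num_prompts, n_available)

print(sorted({prompt_index(i, 6) for i in range(100)}))   # [0, 1, 2, 3, 4, 5]
print(sorted({prompt_index(i, 11) for i in range(100)}))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]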
74 changes: 49 additions & 25 deletions ultravox/training/evaluation.py
@@ -1,4 +1,5 @@
import concurrent.futures
import dataclasses
import functools
import os
from typing import List, Optional
@@ -44,14 +45,37 @@ def dataset_infer(
return ddp_utils.all_gather_list(eval_samples)


def get_metric_name(ds_name: str, metric: str) -> str:
if ds_name == "boolq_in" and metric == "asr":
return "boolq__wer"
if ds_name == "boolq" and metric == "boolq":
return "boolq__correctness"
if metric == "instruct":
return f"{ds_name}__instruct_follow"
return f"{ds_name}__{metric}"
@dataclasses.dataclass
class EvalScenario:
name: str
dataset: str
metric: str
include_audio: bool = True
include_context: bool = True
new_tokens: Optional[int] = None


EVAL_SCENARIOS = [
EvalScenario("anyinstruct__instruct_follow", "anyinstruct", "instruct"),
EvalScenario(
"boolq__binary", "boolq_extended", "exact_match_last_word", new_tokens=128
),
EvalScenario("boolq__wer", "boolq_in", "asr"),
# Text-only scenarios: tests for catastrophic forgetting.
EvalScenario(
"anyinstruct__instruct_follow__text_only",
"anyinstruct",
"instruct",
include_audio=False,
),
EvalScenario(
"boolq__binary__text_only",
"boolq_extended",
"exact_match_last_word",
new_tokens=128,
include_audio=False,
),
]


def evaluate(
@@ -68,21 +92,20 @@ def evaluate(
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))

ds_args = datasets.VoiceDatasetArgs(
data_dir=data_dir, split=datasets.DatasetSplit.VALIDATION
)
for task in EVAL_SCENARIOS:
ds_args = datasets.VoiceDatasetArgs(
data_dir=data_dir,
split=datasets.DatasetSplit.VALIDATION,
include_audio=task.include_audio,
include_context=task.include_context,
)

for ds_name, metric in [
("boolq_in", "asr"),
("boolq", "boolq"),
("anyinstruct", "instruct"),
]:
ds = datasets.Range(datasets.create_dataset(ds_name, ds_args), num_samples)
ds = datasets.Range(datasets.create_dataset(task.dataset, ds_args), num_samples)

output_samples = dataset_infer(
inference,
ds=ds,
max_new_tokens=max_new_tokens,
max_new_tokens=task.new_tokens or max_new_tokens,
temperature=temperature,
world_size=world_size,
local_rank=local_rank,
# Only the master process should evaluate the samples.
continue

eval_per_sample = functools.partial(eval.evaluate_answer, metric=metric)
eval_per_sample = functools.partial(eval.evaluate_answer, metric=task.metric)

with concurrent.futures.ThreadPoolExecutor(max_workers=num_procs) as executor:
possibly_non_scores = [
x.score for x in executor.map(eval_per_sample, output_samples)
]

if None in possibly_non_scores:
print(f"Failed to evaluate {metric} for {ds_name}")
print(f"Failed to evaluate {task.metric} for {task.dataset}")
continue

scores = [x for x in possibly_non_scores if x is not None]

if verbose:
print(f"Eval for {ds_name}:")
print(f"Eval for {task.dataset}:")
for sample, score in zip(output_samples, scores):
print("-" * 20)
print(f"Q: {sample.question}")

average = np.mean(scores)
std = np.std(scores) / np.sqrt(len(scores))
metric_name = get_metric_name(ds_name, metric)
metrics[f"eval_{metric_name}"] = average
metrics[f"eval_{metric_name}_std"] = std
metrics[f"eval_{task.name}"] = average
metrics[f"eval_{task.name}_std"] = std

print(f"Aggregate {metric} score for {ds_name}: {average:.2f} ± {std:.2f}")
print(
f"Aggregate {task.metric} score for {task.dataset}: {average:.2f} ± {std:.2f}"
)

return metrics
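With the scenario-driven loop, the keys in the returned metrics dict are derived directly from each scenario name. A short sketch of what to expect (key names only, no results; the import path follows the file location above):

# Metric keys produced by evaluate() for the scenarios defined in this file.
from ultravox.training.evaluation import EVAL_SCENARIOS

expected_keys = []
for task in EVAL_SCENARIOS:
    expected_keys += [f"eval_{task.name}", f"eval_{task.name}_std"]
# e.g. eval_anyinstruct__instruct_follow, eval_anyinstruct__instruct_follow_std,
#      eval_boolq__binary, eval_boolq__binary_std, eval_boolq__wer, eval_boolq__wer_std, ...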
