[Qwen2Audio] handle input ids expansion during processing #35534

Merged · 12 commits · Jan 7, 2025
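At a glance, this PR moves the expansion of the `<|AUDIO|>` placeholder out of the model's forward pass and into `Qwen2AudioProcessor`, so that `input_ids` leave the processor already containing one audio token per audio embedding. A rough usage sketch of that contract follows; the checkpoint name, dummy audio, and keyword arguments are illustrative assumptions, not taken from this diff.

```python
import numpy as np
from transformers import AutoProcessor

# Hypothetical example: the checkpoint and inputs below are placeholders.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
audio = np.zeros(4 * 16000, dtype=np.float32)  # ~4 s of silence at 16 kHz

prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>What's that sound?"
inputs = processor(text=prompt, audios=[audio], sampling_rate=16000, return_tensors="pt")

# With this change the single <|AUDIO|> placeholder is expanded by the processor,
# so input_ids already hold one audio token per audio embedding and the model no
# longer has to merge them via _merge_input_ids_with_audio_features at runtime.
audio_token_id = processor.tokenizer.convert_tokens_to_ids("<|AUDIO|>")
print((inputs["input_ids"] == audio_token_id).sum())
```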
36 changes: 33 additions & 3 deletions src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -1197,9 +1197,39 @@ def forward(
             selected_audio_feature = audio_outputs.last_hidden_state
             audio_features = self.multi_modal_projector(selected_audio_feature)
 
-            inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
-                audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
-            )
+            # if we have consecutive audio tokens, then it means we expanded input_ids in processing
+            audio_tokens = input_ids == self.config.audio_token_index
+            legacy_processing = (audio_tokens[:, :-1] & audio_tokens[:, 1:]).sum() == 0
+
+            if legacy_processing:
+                logger.warning_once(
+                    "Expanding inputs for audio tokens in Qwen2Audio should be done in processing."
+                )
+                inputs_embeds, attention_mask, labels, position_ids, _ = self._merge_input_ids_with_audio_features(
+                    audio_features, audio_output_lengths, inputs_embeds, input_ids, attention_mask, labels
+                )
+            else:
+                num_audios, max_audio_tokens, embed_dim = audio_features.shape
+                audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
+                    audio_output_lengths.device
+                ) < audio_output_lengths.unsqueeze(1)
+                audio_features = audio_features[audio_features_mask]
+
+                n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
+                n_audio_features = audio_features.shape[0]
+
+                if n_audio_tokens != n_audio_features:
+                    raise ValueError(
+                        f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
+                    )
+                special_audio_mask = (
+                    (input_ids == self.config.audio_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
+                )
+                audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
 
         outputs = self.language_model(
             attention_mask=attention_mask,
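For reference, a small self-contained sketch (not part of the PR) of the two pieces above: the consecutive-token heuristic that detects legacy inputs, and the `masked_scatter` merge used on the expanded path. The token id 151646 is the `<|AUDIO|>` id as it appears in the integration test further down.

```python
import torch

AUDIO_TOKEN_INDEX = 151646  # <|AUDIO|> id, taken from the expected-ids test below

# 1) Legacy detection: if no two adjacent positions are both audio tokens,
#    the processor did not expand the placeholder and the old merge path is used.
def is_legacy(input_ids: torch.Tensor) -> bool:
    audio_tokens = input_ids == AUDIO_TOKEN_INDEX
    return bool((audio_tokens[:, :-1] & audio_tokens[:, 1:]).sum() == 0)

print(is_legacy(torch.tensor([[1, AUDIO_TOKEN_INDEX, 2, 3]])))                  # True
print(is_legacy(torch.tensor([[1, AUDIO_TOKEN_INDEX, AUDIO_TOKEN_INDEX, 2]])))  # False

# 2) Expanded path: drop one audio feature row into each audio-token position.
input_ids = torch.tensor([[1, AUDIO_TOKEN_INDEX, AUDIO_TOKEN_INDEX, AUDIO_TOKEN_INDEX, 2]])
embed_dim = 4
inputs_embeds = torch.zeros(1, input_ids.shape[1], embed_dim)
audio_features = torch.randn(3, embed_dim)  # exactly one row per <|AUDIO|> token

special_audio_mask = (input_ids == AUDIO_TOKEN_INDEX).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
assert torch.equal(inputs_embeds[0, 1:4], audio_features)  # features landed in the token slots
```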
28 changes: 26 additions & 2 deletions src/transformers/models/qwen2_audio/processing_qwen2_audio.py
@@ -40,15 +40,18 @@ class Qwen2AudioProcessor(ProcessorMixin):
         chat_template (`Optional[str]`, *optional*):
             The Jinja template to use for formatting the conversation. If not provided, the default chat template
             is used.
+        audio_token (`str`, *optional*, defaults to `"<|AUDIO|>"`):
+            The token to use for audio tokens.
     """
 
     attributes = ["feature_extractor", "tokenizer"]
     feature_extractor_class = "WhisperFeatureExtractor"
     tokenizer_class = "AutoTokenizer"
 
-    def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None):
+    def __init__(self, feature_extractor=None, tokenizer=None, chat_template=None, audio_token="<|AUDIO|>"):
         if chat_template is None:
             chat_template = self.default_chat_template
+        self.audio_token = tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
         super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
 
     def __call__(
@@ -88,7 +91,8 @@ def __call__(

         if text is None:
             raise ValueError("You need to specify either a `text` input to process.")
-        inputs = self.tokenizer(text, padding=padding, **kwargs)
+        elif isinstance(text, str):
+            text = [text]
 
         if audios is not None:
             audio_inputs = self.feature_extractor(
@@ -97,6 +101,26 @@
             audio_inputs["feature_attention_mask"] = audio_inputs.pop(
                 "attention_mask"
             )  # rename attention_mask to prevent conflicts later on
+
+            expanded_text = []
+            audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist()
+            for sample in text:
+                replace_str = []
+                while self.audio_token in sample:
+                    audio_length = audio_lengths.pop(0)
+                    input_length = (audio_length - 1) // 2 + 1
+                    num_audio_tokens = (input_length - 2) // 2 + 1
+                    replace_str.append(self.audio_token * num_audio_tokens)
+                    sample = sample.replace(self.audio_token, "<placeholder>", 1)
+
+                while "<placeholder>" in sample:
+                    sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
+                expanded_text.append(sample)
+            text = expanded_text
+
+        inputs = self.tokenizer(text, padding=padding, **kwargs)
+
+        if audios is not None:
             inputs.update(audio_inputs)
 
         return BatchFeature(data={**inputs})
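The two integer divisions above mirror the audio encoder's two 2× downsampling stages (a strided convolution followed by a pooling step), so each audio clip is replaced by one `<|AUDIO|>` token per final audio embedding. A quick sanity check of the arithmetic, assuming the Whisper feature extractor's roughly 100 mel frames per second at 16 kHz:

```python
def num_audio_tokens(audio_length: int) -> int:
    """Number of <|AUDIO|> tokens the processor now inserts for one audio clip."""
    input_length = (audio_length - 1) // 2 + 1  # after the encoder's strided conv
    return (input_length - 2) // 2 + 1          # after the pooling on top of it

# An input of 403 valid mel frames (about 4 s of audio) yields 101 placeholder
# tokens, the same count as the 101 repeated audio ids in the integration test below.
print(num_audio_tokens(403))  # 101
```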
18 changes: 13 additions & 5 deletions tests/models/qwen2_audio/test_modeling_qwen2_audio.py
@@ -51,7 +51,7 @@ def __init__(
         parent,
         ignore_index=-100,
         audio_token_index=0,
-        seq_length=7,
+        seq_length=25,
         feat_seq_length=60,
         text_config={
             "model_type": "qwen2",
@@ -93,7 +93,7 @@ def __init__(
         self.is_training = is_training
 
         self.batch_size = 3
-        self.encoder_seq_length = audio_config["max_source_positions"] // 2 + seq_length - 1
+        self.encoder_seq_length = seq_length
 
     def get_config(self):
         return Qwen2AudioConfig(
@@ -118,11 +118,13 @@ def prepare_config_and_inputs(self):
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         config, input_features_values, feature_attention_mask = config_and_inputs
+        input_length = (input_features_values.shape[-1] - 1) // 2 + 1
+        num_audio_tokens = (input_length - 2) // 2 + 1
         input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
         attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
         attention_mask[:, :1] = 0
         # we are giving 3 audios let's make sure we pass in 3 audios tokens
-        input_ids[:, 1] = config.audio_token_index
+        input_ids[:, 1 : 1 + num_audio_tokens] = config.audio_token_index
         inputs_dict = {
             "input_features": input_features_values,
             "feature_attention_mask": feature_attention_mask,
@@ -262,7 +264,9 @@ def test_small_model_integration_test_single(self):
                 25,
                 220,
                 151647,
-                151646,
+            ]
+            + [151646] * 101
+            + [
                 151648,
                 198,
                 3838,
@@ -280,7 +284,11 @@
         )
         self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
 
-        EXPECTED_DECODED_TEXT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
+        EXPECTED_DECODED_TEXT = (
+            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
+            + "<|AUDIO|>" * 101
+            + "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
+        )
 
         self.assertEqual(
             self.processor.decode(output[0], skip_special_tokens=False),
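As a sanity check on the tester changes: with `feat_seq_length=60`, the same arithmetic the processor uses gives 30 encoder frames and 15 audio tokens per sample, which would not fit in the old `seq_length=7` but does fit in the new `seq_length=25`. A standalone sketch of that check (not part of the test suite):

```python
feat_seq_length = 60  # mel-frame dimension of the tester's dummy input_features
input_length = (feat_seq_length - 1) // 2 + 1   # 30, after the strided conv
num_audio_tokens = (input_length - 2) // 2 + 1  # 15, after the pooling step

# positions 1 .. 1 + num_audio_tokens must fit inside the text sequence
assert 1 + num_audio_tokens <= 25  # new seq_length works
assert 1 + num_audio_tokens > 7    # old seq_length was too short
print(num_audio_tokens)  # 15
```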