diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py
index 1a190df9..1b0a5753 100644
--- a/ultravox/model/ultravox_model.py
+++ b/ultravox/model/ultravox_model.py
@@ -190,7 +190,8 @@ def forward(
 
         # B x A/3200 x D
         audio_tower_output = self.audio_tower.forward(
-            audio_values.to(self.audio_tower.dtype), audio_len=audio_len
+            audio_values.to(self.audio_tower.dtype),
+            audio_len=audio_len,
         ).last_hidden_state
         audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
 
@@ -286,14 +287,28 @@ def _create_audio_tower(
                audio_tower = ModifiedWhisperEncoder.from_pretrained(
                    config.audio_model_id, torch_dtype=config.torch_dtype
                )
+               audio_tower.init_latency_mask(
+                   config.audio_latency_block_size, dtype=config.torch_dtype
+               )
            else:
+               assert config.audio_latency_block_size in (
+                   None,
+                   0,
+               ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
                audio_tower = transformers.AutoModel.from_pretrained(
                    config.audio_model_id, torch_dtype=config.torch_dtype
                )
        else:
            if "whisper" in config.audio_config._name_or_path:
                audio_tower = ModifiedWhisperEncoder(config.audio_config)
+               audio_tower.init_latency_mask(
+                   config.audio_latency_block_size, dtype=config.torch_dtype
+               )
            else:
+               assert config.audio_latency_block_size in (
+                   None,
+                   0,
+               ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
                with transformers.modeling_utils.no_init_weights():
                    # we only ever use from_config if the weights are retrained, hence initializing is not
                    # required. This makes model creation quite a bit faster since init on CPU is quite slow.
@@ -529,6 +544,39 @@ class ModifiedWhisperEncoder(
     base_model_prefix = "model.encoder"
     _no_split_modules = ["WhisperEncoderLayer"]
 
+    def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
+        if audio_latency_block_size is None:
+            self.audio_streaming_mask = None
+            return
+
+        # maximum sequence length
+        max_seqlen = (
+            self.config.max_source_positions
+            * self.conv1.stride[0]
+            * self.conv2.stride[0]
+        )
+        assert (
+            max_seqlen > 0
+        ), f"maximum sequence length must be positive, got {max_seqlen}"
+        assert (
+            max_seqlen % audio_latency_block_size == 0
+        ), f"audio_latency_block_size {audio_latency_block_size} must divide {max_seqlen} evenly."
+        # Given the block size, calculate the number of blocks.
+        audio_latency_nblocks = max_seqlen // audio_latency_block_size
+        audio_streaming_mask = (
+            torch.tril(
+                torch.ones(audio_latency_nblocks, audio_latency_nblocks),
+                diagonal=0,
+            )
+            .repeat_interleave(audio_latency_block_size, dim=0)
+            .repeat_interleave(audio_latency_block_size, dim=1)
+        )
+        audio_streaming_mask = (1.0 - audio_streaming_mask) * torch.finfo(dtype).min
+        audio_streaming_mask = audio_streaming_mask[None, None, :, :]
+        self.register_buffer(
+            "audio_streaming_mask", audio_streaming_mask, persistent=False
+        )
+
     def forward(
         self,
         input_features,
@@ -586,13 +634,10 @@ def forward(
         attention_mask = None
         if audio_len != None:
             audio_feature_len = self._get_feat_extract_output_lengths(audio_len)
-            batch_size = hidden_states.shape[0]
             max_seq_len = hidden_states.shape[1]
-            attention_mask = (
-                torch.arange(max_seq_len, device=hidden_states.device)[None, :]
-                .expand(batch_size, -1)
-                .lt(audio_feature_len.view(batch_size, 1))
-            )
+            attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
+                None, :
+            ].lt(audio_feature_len.view(-1, 1))
             attention_mask = self.get_extended_attention_mask(
                 attention_mask,
                 None,
@@ -600,6 +645,16 @@
                 dtype=hidden_states.dtype,
             )
 
+        if self.audio_streaming_mask is not None:
+            seqlen = hidden_states.size(-2)
+            if attention_mask is not None:
+                attention_mask = torch.minimum(
+                    self.audio_streaming_mask[:, :, :seqlen, :seqlen], attention_mask
+                )  # merge
+            else:
+                attention_mask = self.audio_streaming_mask[:, :, :seqlen, :seqlen]
+            attention_mask = attention_mask.to(hidden_states.dtype)
+
         # check if head_mask has a correct number of layers specified if desired
         if head_mask is not None:
             assert head_mask.size()[0] == (
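Reviewer note: a minimal standalone sketch (not part of the patch) of the block-causal pattern that `init_latency_mask` builds, at toy scale. Positions in block i can attend to every position in blocks 0..i, and `(1.0 - mask) * torch.finfo(dtype).min` then turns the zeros into additive negative-infinity attention biases:

```python
import torch

# Toy scale: max_seqlen=6, audio_latency_block_size=2 -> 3 blocks.
nblocks, block_size = 3, 2
mask = (
    torch.tril(torch.ones(nblocks, nblocks), diagonal=0)
    .repeat_interleave(block_size, dim=0)
    .repeat_interleave(block_size, dim=1)
)
print(mask)
# tensor([[1., 1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1., 1.]])
```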
diff --git a/ultravox/model/ultravox_model_test.py b/ultravox/model/ultravox_model_test.py
new file mode 100644
index 00000000..a0b95b20
--- /dev/null
+++ b/ultravox/model/ultravox_model_test.py
@@ -0,0 +1,61 @@
+import pytest
+import torch
+from transformers import WhisperConfig
+
+from ultravox.model import ultravox_model
+
+
+@pytest.fixture
+def encoder():
+    config = WhisperConfig(
+        max_source_positions=1500,
+        d_model=256,
+        encoder_attention_heads=4,
+        encoder_layers=4,
+    )
+    return ultravox_model.ModifiedWhisperEncoder(config)
+
+
+def test_init_latency_mask_none(encoder):
+    encoder.init_latency_mask(None, torch.float32)
+    assert encoder.audio_streaming_mask is None
+
+
+def test_init_latency_mask_valid(encoder):
+    block_size = 100
+    encoder.init_latency_mask(block_size, torch.float32)
+    assert encoder.audio_streaming_mask is not None
+
+    assert len(encoder.audio_streaming_mask.shape) == 4
+    assert encoder.audio_streaming_mask.shape[0] == 1
+    assert encoder.audio_streaming_mask.shape[1] == 1
+
+    mask = encoder.audio_streaming_mask[0, 0]
+    # 30 blocks of size 100 -> a 3000 x 3000 mask
+    source_mask = (
+        torch.tril(torch.ones(30, 30), diagonal=0)
+        .repeat_interleave(block_size, dim=0)
+        .repeat_interleave(block_size, dim=1)
+    )
+    source_mask = (1.0 - source_mask) * torch.finfo(torch.float32).min
+    assert torch.allclose(mask, source_mask)
+
+
+def test_init_latency_mask_invalid_block_size(encoder):
+    invalid_block_size = 13
+
+    with pytest.raises(AssertionError, match="must divide .* evenly"):
+        encoder.init_latency_mask(invalid_block_size, torch.float32)
+
+
+def test_init_latency_mask_different_dtypes(encoder):
+    block_size = 50
+    for dtype in (torch.float32, torch.float16):
+        encoder.init_latency_mask(block_size, dtype)
+        assert encoder.audio_streaming_mask.min() == torch.finfo(dtype).min
+
+
+def test_init_latency_mask_persistence(encoder):
+    block_size = 50
+    encoder.init_latency_mask(block_size, torch.float32)
+    assert "audio_streaming_mask" in encoder._buffers
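The tests above only exercise mask construction. Below is a hedged smoke-run sketch (not part of the patch) that pushes a batch through the modified encoder with both the per-sample length mask and the streaming mask active. It assumes, per the hunks above, that `ModifiedWhisperEncoder.forward` accepts a full 30 s mel window plus an `audio_len` keyword, and that it returns a standard encoder output with `last_hidden_state`:

```python
import torch
from transformers import WhisperConfig

from ultravox.model import ultravox_model

config = WhisperConfig(
    max_source_positions=1500,
    d_model=256,
    encoder_attention_heads=4,
    encoder_layers=4,
)
encoder = ultravox_model.ModifiedWhisperEncoder(config)
encoder.init_latency_mask(100, dtype=torch.float32)  # 100-frame (~1 s) blocks

# Two 30 s clips of log-mel features; the second clip is half padding.
mel = torch.randn(2, config.num_mel_bins, 3000)
with torch.no_grad():
    out = encoder(mel, audio_len=torch.tensor([3000, 1500]))
print(out.last_hidden_state.shape)  # torch.Size([2, 1500, 256])
```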
diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py
index 22d84f95..6a66b40f 100644
--- a/ultravox/training/config_base.py
+++ b/ultravox/training/config_base.py
@@ -106,6 +106,9 @@ def get_val_sets(self) -> List[DatasetOptions]:
     # loss function to use
     loss_config: Optional[ultravox_config.LossConfig] = None
 
+    # Simulates audio streaming with masking. None for non-causal; 100 for 1s, 200 for 2s, and so on.
+    audio_latency_block_size: Optional[int] = None
+
     def __post_init__(self):
         assert self.data_type in ["bfloat16", "float16", "float32"]
         if self.device == "cuda" and not torch.cuda.is_available():
diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml
index bfe3626b..e7807bd4 100644
--- a/ultravox/training/configs/meta_config.yaml
+++ b/ultravox/training/configs/meta_config.yaml
@@ -31,3 +31,5 @@
 batch_size: 4
 data_type: "bfloat16"
 report_logs_to: ["tensorboard", "wandb"]
+
+audio_latency_block_size: null # null for non-causal, 100 for 1s, 200 for 2s, and so on.
diff --git a/ultravox/training/configs/streaming_tinyllama.yaml b/ultravox/training/configs/streaming_tinyllama.yaml
new file mode 100644
index 00000000..46d36dbc
--- /dev/null
+++ b/ultravox/training/configs/streaming_tinyllama.yaml
@@ -0,0 +1,26 @@
+
+exp_name: "ultravox-streaming-experiments-1s"
+# Make sure to accept the license agreement on huggingface hub
+text_model: "meta-llama/Llama-3.2-1B-Instruct"
+audio_model: "openai/whisper-small"
+loss_config:
+  # Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence"
+  loss_function: "KL_Divergence"
+train_sets:
+  - name: librispeech-clean-continuation
+  - name: librispeech-other-continuation
+  - name: peoplespeech-clean-continuation
+    weight: 4
+  - name: commonvoice-en-continuation
+    weight: 4
+  - name: librispeech-clean-transcription
+    weight: 4
+  - name: librispeech-other-transcription
+  - name: peoplespeech-clean-transcription
+  - name: commonvoice-en-transcription
+# Temporarily remove heysquad_human from val_sets as it causes training to fail.
+val_sets:
+  - name: peoplespeech
+batch_size: 24
+max_steps: 10000 # x8x24 = 2,764,800
+audio_latency_block_size: 100 # null for non-causal, 100 for 1s, 200 for 2s, and so on.
\ No newline at end of file
diff --git a/ultravox/training/train.py b/ultravox/training/train.py
index 5281d80d..9bc33390 100644
--- a/ultravox/training/train.py
+++ b/ultravox/training/train.py
@@ -123,6 +123,7 @@ def train(args: config_base.TrainConfig):
         audio_model_lora_config=args.audio_model_lora_config,
         torch_dtype=args.data_type,
         pad_token_id=text_tokenizer.eos_token_id,
+        audio_latency_block_size=args.audio_latency_block_size,
     )
 
     logging.info("Instantiating model...")
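A final note on the "100 for 1s, 200 for 2s" comments: Whisper's log-mel features use a 10 ms hop, i.e. 100 frames per second, over a 30 s (3000-frame) window, so the block size is simply the target latency in seconds times 100, constrained to divide 3000 evenly. A hypothetical helper (not part of the patch) that makes the conversion explicit:

```python
# Whisper log-mel features: 10 ms hop -> 100 frames per second of audio.
FRAMES_PER_SECOND = 100
MAX_FRAMES = 3000  # 30 s window; the block size must divide this evenly

def latency_to_block_size(latency_s: float) -> int:
    block_size = round(latency_s * FRAMES_PER_SECOND)
    assert (
        block_size > 0 and MAX_FRAMES % block_size == 0
    ), f"block size {block_size} must divide {MAX_FRAMES} evenly"
    return block_size

print(latency_to_block_size(1.0))  # 100, matching the YAML comments above
print(latency_to_block_size(2.0))  # 200
```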