[Wav2Vec2 Conformer] Fix inference float16 #25985

Merged

sanchit-gandhi merged 4 commits into huggingface:main from w2v2-conformer on Sep 5, 2023

Conversation

sanchit-gandhi (Contributor) commented on Sep 5, 2023

What does this PR do?

Fixes #25964 - the Wav2Vec2 Conformer model with rotary embeddings now works when loaded with from_pretrained in float16. The issue originated in the rotary embedding layer, which always returned the positional embeddings in float32, regardless of the dtype of the hidden states.
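
For context, a minimal repro of the kind described in the issue looks roughly like this (the checkpoint name is taken from #25964; the dummy audio and the decode step are purely illustrative, not the exact snippet from the issue):

import torch
from transformers import AutoProcessor, Wav2Vec2ConformerForCTC

# Rotary-embedding checkpoint from the linked issue (#25964)
model_id = "facebook/wav2vec2-conformer-rope-large-960h-ft"

processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ConformerForCTC.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# One second of dummy audio at 16 kHz, just to exercise the forward pass
dummy_speech = torch.randn(16000).numpy()
inputs = processor(dummy_speech, sampling_rate=16000, return_tensors="pt")
input_values = inputs.input_values.to("cuda", dtype=torch.float16)

# Previously the rotary embeddings came back in float32 and clashed with the
# float16 hidden states; with the fix the whole forward pass stays in fp16
with torch.no_grad():
    logits = model(input_values).logits

print(processor.batch_decode(torch.argmax(logits, dim=-1)))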

@@ -901,6 +901,26 @@ def test_speech_to_text_leveraged(self):
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})

@slow
@require_torch_gpu
def test_wav2vec2_conformer_float16(self):
sanchit-gandhi (Contributor, Author) commented:

This is the error repro that was failing before, @Vaibhavs10 - I've added a slow integration test to make sure this works after the fix.
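
The hunk above only shows the test signature; a rough sketch of what such a slow fp16 pipeline test can look like is below (the placeholder test class, the dummy LibriSpeech split, and the weak assertion are illustrative - the committed test lives in the existing pipeline test class and may differ in its inputs and assertions):

import unittest

import torch
from datasets import load_dataset

from transformers import pipeline
from transformers.testing_utils import require_torch_gpu, slow


class Wav2Vec2ConformerFloat16Sketch(unittest.TestCase):
    # Placeholder class so the sketch is self-contained; in the PR the test is
    # added alongside test_speech_to_text_leveraged in the pipeline tests
    @slow
    @require_torch_gpu
    def test_wav2vec2_conformer_float16(self):
        speech_recognizer = pipeline(
            task="automatic-speech-recognition",
            model="facebook/wav2vec2-conformer-rope-large-960h-ft",
            torch_dtype=torch.float16,
            device="cuda:0",
        )

        # Dummy LibriSpeech split commonly used in transformers tests
        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        sample = dataset[0]["audio"]

        output = speech_recognizer(sample)
        # The expected transcription is not reproduced here; the point is that
        # the fp16 forward pass no longer raises a dtype mismatch
        self.assertIsInstance(output["text"], str)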

A Member replied:

Perfect! Thanks <3

@@ -406,13 +406,15 @@ def forward(self, hidden_states):
return self.cached_rotary_positional_embedding

self.cached_sequence_length = sequence_length
# Embeddings are computed in the dtype of the inv_freq constant
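
For readers without the full diff: the hunk above only shows the new comment line. Roughly, the change amounts to the pattern sketched below (a paraphrase rather than the literal patch, with an illustrative stand-in class and ad-hoc constructor arguments): compute the embeddings in the dtype of the inv_freq buffer, then cast the cached cos/sin back to the dtype of hidden_states before returning.

import torch
from torch import nn


class RotaryEmbeddingSketch(nn.Module):
    # Illustrative stand-in for Wav2Vec2ConformerRotaryPositionalEmbedding,
    # not the actual class (which takes the model config instead)
    def __init__(self, dim, base=10000):
        super().__init__()
        # inv_freq is built in float32, so the embeddings are computed in float32
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_sequence_length = None
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states):
        sequence_length = hidden_states.shape[1]

        if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
            return self.cached_rotary_positional_embedding

        self.cached_sequence_length = sequence_length
        # Embeddings are computed in the dtype of the inv_freq constant (float32) ...
        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        embeddings = torch.cat((freqs, freqs), dim=-1)

        # ... and the gist of the fix: cast the cached cos/sin to the dtype of
        # the hidden states (e.g. float16) instead of always returning float32
        self.cached_rotary_positional_embedding = torch.stack(
            [embeddings.cos(), embeddings.sin()]
        ).type_as(hidden_states)
        return self.cached_rotary_positional_embedding

With float16 hidden states the returned stack is float16 as well, so the attention layers that consume it no longer mix dtypes.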

ArthurZucker (Collaborator) commented:

This now looks a lot like:

class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

Wondering if we can add # Copied from and use this, and whether the dynamic scaling could also work for audio models?

sanchit-gandhi (Contributor, Author) replied:

Can't use # Copied from on the whole module, since Wav2Vec2ConformerRotaryPositionalEmbedding accepts the config as an argument whereas LlamaRotaryEmbedding takes various ad-hoc arguments. But we could do a similar dynamic slicing - I'll add this in a follow-up PR so as not to block @Vaibhavs10.

HuggingFaceDocBuilderDev commented Sep 5, 2023

The documentation is not available anymore as the PR was closed or merged.

ylacombe (Contributor) left a comment:

LGTM! Thanks for taking care of this!

ArthurZucker (Collaborator) left a comment:

Looks good to me! Left a nit, but I think we can use the LlamaRotary class now 😄

@sanchit-gandhi merged commit 8d51801 into huggingface:main on Sep 5, 2023
@sanchit-gandhi deleted the w2v2-conformer branch on September 5, 2023 at 17:26
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
* [Wav2Vec2 Conformer] Fix inference float16

* fix test

* fix test more

* clean pipe test
Development

Successfully merging this pull request may close these issues.

[bug] facebook/wav2vec2-conformer-rope-large-960h-ft refuses to work in fp16
6 participants