fix onnx export of speech foundation models #34224

Merged

@nikosanto13 (Contributor) commented on Oct 17, 2024

What does this PR do?

The old issue #10004 described an error in the ONNX export of wav2vec2-base-960h.
The issue was partially fixed for some of the speech foundation models in PR #16004.
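
For context, the failure shows up when exporting one of these models together with an attention mask, roughly like this (a sketch; the checkpoint, output file name, opset, and axis names are illustrative):

import torch
from transformers import WavLMModel

model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").eval()
model.config.return_dict = False  # tuple outputs are simpler to trace for export

input_values = torch.randn(1, 16000)                              # 1 s of 16 kHz audio
attention_mask = torch.ones_like(input_values, dtype=torch.long)  # all samples valid

torch.onnx.export(
    model,
    (input_values, attention_mask),
    "wavlm.onnx",
    input_names=["input_values", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_values": {0: "batch", 1: "num_samples"},
        "attention_mask": {0: "batch", 1: "num_samples"},
    },
    opset_version=14,
)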

This PR:

  • applies the same fix of expanding the attention mask in all modeling scripts where this pattern occurs (e.g. I came across it when trying to export wavlm-base-plus (modeling_wavlm.py) to ONNX),
  • applies the same fix to the downsampled padding mask in the <Wav2Vec2, Hubert, ...>ForSequenceClassification modules' forward implementations, because their ONNX export fails in the same way, due to the same broadcasting-related error (a sketch of the pattern follows this list).
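
A minimal sketch of the kind of change described above (illustrative shapes and an illustrative "before" pattern, not the literal diff; the same idea applies to the downsampled padding mask in the ForSequenceClassification heads):

import torch

batch, frames, hidden = 2, 5, 4
hidden_states = torch.randn(batch, frames, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]], dtype=torch.bool)

# Pattern that trips the ONNX exporter: an in-place assignment where the 2D mask
# is implicitly broadcast against the 3D hidden states.
# hidden_states[~attention_mask] = 0

# Export-friendly pattern (the one adopted in #16004): expand the mask to the full
# hidden-state shape first, so no implicit broadcasting is needed.
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0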

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Tagging @xenova, @ylacombe, @eustlb (let me know if I should also add someone else)

@LysandreJik (Member):

Maybe also cc @xenova for ONNX :)

@nikosanto13 (Contributor, Author):

@LysandreJik nobody has responded so far; should I cc someone else?

@LysandreJik (Member):

The best suited to review is @ylacombe, I believe; let's give Yoach a few days to answer. Thanks for your PR @nikosanto13!

@nikosanto13 (Contributor, Author):

ok nice, thanks for the help @LysandreJik

@ylacombe (Contributor) left a comment:

LGTM @nikosanto13, thanks for opening the PR!

My only worry is that it seems like a costly operation. Are we sure this is the best / most elegant way of expanding the mask?

BTW, you should make sure tests pass. Could you rebase on main and check that the whole CI is green?

We also have to make sure this doesn't break our slow tests; let's do an empty commit with the message: [run-slow] data2vec, hubert, sew, sew_d, unispeech, unispeech_sat, wav2vec2, wav2vec2_bert, wav2vec2_conformer, wavlm
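
For reference, such an empty commit can be created and pushed with standard git:

git commit --allow-empty -m "[run-slow] data2vec, hubert, sew, sew_d, unispeech, unispeech_sat, wav2vec2, wav2vec2_bert, wav2vec2_conformer, wavlm"
git push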

@nikosanto13 (Contributor, Author):

thanks for mentioning @ylacombe - I'll rebase and add an empty commit for the slow tests.

Regarding your comment about this being a costly operation: I admit I pretty much adopted the same patch as the previous PR, without giving it much consideration. My initial thought is that it shouldn't be much of a concern (given that the masks are boolean in the majority of cases).

Let me conduct a mini benchmark and I'll return with actual numbers.

@ylacombe (Contributor):

> Let me conduct a mini benchmark and I'll return with actual numbers.

Great, looking forward to it

@nikosanto13 (Contributor, Author) commented on Nov 4, 2024

@ylacombe I conducted a mini benchmark using WavLM (microsoft/wavlm-base).

System info: A10 GPU (24GB VRAM)

Details on the benchmark: I run WavLM's forward pass in both train and eval modes. In the code below you can see the details: a batch (batch_size = 16) of 10-second audio clips is generated and fed through the model.

I created the following script:

from contextlib import nullcontext

import numpy as np
import torch
from torch.autograd import profiler

from transformers import AutoFeatureExtractor, WavLMModel


# taken from torch.autograd.profiler
def _format_time(time_us):
    """Define how to format time in FunctionEvent."""
    US_IN_SECOND = 1000.0 * 1000.0
    US_IN_MS = 1000.0
    if time_us >= US_IN_SECOND:
        return f"{time_us / US_IN_SECOND:.3f}s"
    if time_us >= US_IN_MS:
        return f"{time_us / US_IN_MS:.3f}ms"
    return f"{time_us:.3f}us"

# taken from torch.autograd.profiler
def _format_memory(nbytes):
    """Return a formatted memory size string."""
    KB = 1024
    MB = 1024 * KB
    GB = 1024 * MB
    if abs(nbytes) >= GB:
        return f"{nbytes * 1.0 / GB:.2f} Gb"
    elif abs(nbytes) >= MB:
        return f"{nbytes * 1.0 / MB:.2f} Mb"
    elif abs(nbytes) >= KB:
        return f"{nbytes * 1.0 / KB:.2f} Kb"
    else:
        return str(nbytes) + " b"


def run_forward(model, feat_extractor, train=False):

    if train:
        model.train()
    else:
        model.eval()

    ctx = torch.no_grad() if not train else nullcontext()

    input_values = torch.randn(16, 10 * feat_extractor.sampling_rate).to(device)  # 16 clips of 10 s; `device` is set in __main__
    attention_mask = torch.ones_like(input_values, dtype=torch.bool)

    with ctx, profiler.profile(use_cuda=True, profile_memory=True) as prof:
        output = model(input_values, attention_mask=attention_mask)

    events_list = prof.key_averages()
    mask_event = [elem for elem in events_list if elem.key == 'MASK_HIDDEN_STATES'][0]

    print(mask_event)
    return mask_event.device_time, mask_event.device_memory_usage


if __name__ == "__main__":

    device = 'cuda'
    model = WavLMModel.from_pretrained("microsoft/wavlm-base").to(device)
    feat = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base")

    times, mems = [], []
    for i in range(10):
        time, mem = run_forward(model, feat)
        # skipping first iteration due to additional overhead
        if i > 0:
            times.append(time)
            mems.append(mem)

    avg_time, std_time = np.mean(times), np.std(times)
    avg_mem, std_mem = np.mean(mems), np.std(mems)

    print(f"Average time: {_format_time(avg_time)} ± {_format_time(std_time)}")
    print(f"Average memory: {_format_memory(avg_mem)} ± {_format_memory(std_mem)}")

NOTE: In addition to that, you also have to wrap the masking operation with the profiler.record_function context manager. E.g. in modeling_wavlm.py, in the WavLMEncoder forward, I replaced this

if attention_mask is not None:
    expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
    hidden_states[~expand_attention_mask] = 0

with this:

# don't forget to import torch.autograd.profiler
# from torch.autograd import profiler
if attention_mask is not None:
    with profiler.record_function("MASK_HIDDEN_STATES"):
        expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
        hidden_states[~expand_attention_mask] = 0

The following measurements were obtained:

| branch / mode | Time | CUDA memory consumption |
| --- | --- | --- |
| main / train | 278.000us ± 7.333us | 8.00 Kb ± 0.0 b |
| onnx-export-bugfix-where-nodes / train | 435.222us ± 12.417us | 11.70 Mb ± 0.0 b |
| main / eval | 145.444us ± 1.066us | 0.0 b ± 0.0 b |
| onnx-export-bugfix-where-nodes / eval | 251.333us ± 3.127us | 5.85 Mb ± 0.0 b |
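
As a rough sanity check of the memory overhead (assumed shapes, not additional measurements): for 10 s of 16 kHz audio the convolutional feature extractor of these models downsamples by roughly 320x, giving about 499 frames, so the expanded boolean mask costs one byte per hidden-state element:

batch_size, num_frames, hidden_size = 16, 499, 768
mask_bytes = batch_size * num_frames * hidden_size   # one bool per hidden-state element
print(f"{mask_bytes / 2**20:.2f} MiB")               # ~5.85 MiB, matching the eval-mode figure

The train-mode figure is roughly twice that, plausibly because an extra boolean intermediate (e.g. the negated mask) also stays allocated during the profiled region.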

@nikosanto13 force-pushed the onnx-export-bugfix-where-nodes branch 5 times, most recently from 1b23d37 to b42f06e on November 12, 2024
@nikosanto13 force-pushed the onnx-export-bugfix-where-nodes branch from b42f06e to 22b18e4 on November 18, 2024
@nikosanto13 force-pushed the onnx-export-bugfix-where-nodes branch from 22b18e4 to 1966676 on November 25, 2024
@ylacombe (Contributor):

Hey @nikosanto13, thanks for conducting the benchmark! I'm not sure I understand what you benchmarked, though. Did you compare the speed of this new operation with how it was done previously?

@nikosanto13 (Contributor, Author):

Hey @ylacombe - well, based on your message I figured I should benchmark both speed and memory consumption.

As you can see in the attached code snippet I'm comparing those two measurements for a forward pass of the WavLM model, before (main) and after the changes (my branch).

Can you think of a more appropriate way to benchmark this? I'd be happy to extend it.

@ylacombe (Contributor) commented on Dec 2, 2024

Thanks @nikosanto13, it's much clearer now. Looks like it's not a costly operation. This PR looks great to me.

cc @ArthurZucker or @Rocketknight1 for a core maintainer review!

@nikosanto13 force-pushed the onnx-export-bugfix-where-nodes branch from 1966676 to 468dd8a on December 19, 2024
@ArthurZucker (Collaborator) left a comment:

Sounds good 🤗 let's merge!

@ArthurZucker merged commit ff9141b into huggingface:main on Dec 20, 2024. 23 checks passed.