fix onnx export of speech foundation models #34224
Conversation
Maybe also cc @xenova for ONNX :)
@LysandreJik nobody responded so far, should I cc someone else?
The best suited to review is @ylacombe I believe; let's give Yoach a few days to answer, thanks for your PR @nikosanto13!
ok nice, thanks for the help @LysandreJik
LGTM @nikosanto13, thanks for opening the PR!
My only worry is that it seems like a costly operation; are we sure this is the best / most elegant way of expanding the mask?
BTW, you should make sure tests pass. Could you rebase on main and check that the whole CI is green?
We also have to make sure this doesn't break our slow tests; let's do an empty commit with the message [run-slow] data2vec, hubert, sew, sew_d, unispeech, unispeech_sat, wav2vec2, wav2vec2_bert, wav2vec2_conformer, wavlm
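(For readers following along: the rebase plus the empty trigger commit would look roughly like this, assuming a remote named `upstream` points at huggingface/transformers; the commit message is quoted from the comment above:)

```bash
git fetch upstream
git rebase upstream/main
git commit --allow-empty -m "[run-slow] data2vec, hubert, sew, sew_d, unispeech, unispeech_sat, wav2vec2, wav2vec2_bert, wav2vec2_conformer, wavlm"
git push --force-with-lease
```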
thanks for mentioning @ylacombe - I'll rebase and add an empty commit for the slow tests. For your comment about this being a costly operation, I admit I pretty much adopted the same patch (from the previous PR) without giving it much consideration. My initial thought is that this shouldn't be much of a concern (given that the masks are boolean in the majority of cases). Let me conduct a mini benchmark and I'll return with actual numbers.
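(For context, the operation under discussion, quoted later in this thread, materializes the expanded mask with `repeat`. Below is a minimal sketch of that pattern next to a broadcasting alternative; the shapes are illustrative assumptions, and the `masked_fill` variant is only an idea, not what the PR does:)

```python
import torch

# illustrative shapes, not taken from the PR
batch, frames, hidden = 16, 499, 768
hidden_states = torch.randn(batch, frames, hidden)
attention_mask = torch.ones(batch, frames, dtype=torch.bool)

# pattern used in the patch: repeat() materializes a (batch, frames, hidden) bool tensor
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0

# a possible lighter-weight alternative (not what the PR does): broadcast the
# (batch, frames, 1) mask instead of materializing the full expansion
hidden_states = hidden_states.masked_fill(~attention_mask.unsqueeze(-1), 0)
```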
Great, looking forward to it
@ylacombe I conducted a mini benchmark using WavLM (`microsoft/wavlm-base`).

System info: A10 GPU (24GB VRAM)

Details on the benchmark: I ran WavLM's forward pass for both train and eval modes. In the code below, you can see the details: a batch (batch_size = 16) of 10-sec audios is generated and fed through the model.

I created the following script:

```python
from contextlib import nullcontext

import numpy as np
import torch
from torch.autograd import profiler

from transformers import AutoFeatureExtractor, WavLMModel


# taken from torch.autograd.profiler
def _format_time(time_us):
    """Define how to format time in FunctionEvent."""
    US_IN_SECOND = 1000.0 * 1000.0
    US_IN_MS = 1000.0
    if time_us >= US_IN_SECOND:
        return f"{time_us / US_IN_SECOND:.3f}s"
    if time_us >= US_IN_MS:
        return f"{time_us / US_IN_MS:.3f}ms"
    return f"{time_us:.3f}us"


# taken from torch.autograd.profiler
def _format_memory(nbytes):
    """Return a formatted memory size string."""
    KB = 1024
    MB = 1024 * KB
    GB = 1024 * MB
    if abs(nbytes) >= GB:
        return f"{nbytes * 1.0 / GB:.2f} Gb"
    elif abs(nbytes) >= MB:
        return f"{nbytes * 1.0 / MB:.2f} Mb"
    elif abs(nbytes) >= KB:
        return f"{nbytes * 1.0 / KB:.2f} Kb"
    else:
        return str(nbytes) + " b"


def run_forward(model, feat_extractor, train=False):
    if train:
        model.train()
    else:
        model.eval()

    ctx = torch.no_grad() if not train else nullcontext()

    # batch of 16 ten-second audios at the feature extractor's sampling rate
    input_values = torch.randn(16, 10 * feat_extractor.sampling_rate).to(device)
    attention_mask = torch.ones_like(input_values, dtype=torch.bool)

    with ctx, profiler.profile(use_cuda=True, profile_memory=True) as prof:
        output = model(input_values, attention_mask=attention_mask)

    events_list = prof.key_averages()
    mask_event = [elem for elem in events_list if elem.key == "MASK_HIDDEN_STATES"][0]
    print(mask_event)

    return mask_event.device_time, mask_event.device_memory_usage


if __name__ == "__main__":
    device = "cuda"
    model = WavLMModel.from_pretrained("microsoft/wavlm-base").to(device)
    feat = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base")

    times, mems = [], []
    for i in range(10):
        time, mem = run_forward(model, feat)
        # skipping first iteration due to additional overhead
        if i > 0:
            times.append(time)
            mems.append(mem)

    avg_time, std_time = np.mean(times), np.std(times)
    avg_mem, std_mem = np.mean(mems), np.std(mems)
    print(f"Average time: {_format_time(avg_time)} ± {_format_time(std_time)}")
    print(f"Average memory: {_format_memory(avg_mem)} ± {_format_memory(std_mem)}")
```

NOTE: In addition to that, so the profiler can pick up the masking operation, you also have to wrap it, replacing this:

```python
if attention_mask is not None:
    expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
    hidden_states[~expand_attention_mask] = 0
```

with this:

```python
# don't forget to import torch.autograd.profiler
# from torch.autograd import profiler
if attention_mask is not None:
    with profiler.record_function("MASK_HIDDEN_STATES"):
        expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
        hidden_states[~expand_attention_mask] = 0
```

The following measurements were obtained:
Hey @nikosanto13, thanks for conducting the benchmark! I'm not sure I understand what you benchmarked, though. Did you compare the speed of this new operation against how it was done previously?
Hey @ylacombe - well, based on your message I figured I should benchmark both speed and memory consumption. As you can see in the attached code snippet, I'm comparing those two measurements for a forward pass of the WavLM model, before (main) and after the changes (my branch). Can you think of a more appropriate way to benchmark this? I'd be happy to extend it.
Thanks @nikosanto13, it's much clearer now. Looks like it's not a costly operation. This PR looks great to me. cc @ArthurZucker or @Rocketknight1 for a core maintainer review!
Sounds good 🤗 let's merge!
What does this PR do?
This old issue #10004 described an error in the onnx export of wav2vec2-base-960h.
The issue was partially fixed for some of the speech foundation models in this PR: #16004.
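(For context, the failure comes from the boolean masking of padded frames after the feature encoder. A minimal sketch of the before/after pattern: the "before" form is an assumption reconstructed from the linked issues, while the "after" form is the expanded-mask code quoted in the benchmark comment above:)

```python
import torch

# illustrative shapes: (batch, frames, hidden)
hidden_states = torch.randn(2, 100, 768)
attention_mask = torch.ones(2, 100, dtype=torch.bool)

# before (assumed from #10004 / #16004): 2-D boolean indexing of a 3-D tensor,
# which the ONNX exporter could not handle for these models
hidden_states[~attention_mask] = 0.0

# after (as quoted in the thread above): expand the mask to the hidden-state
# shape first, then index with a mask of matching rank
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
```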
This PR:
- extends that fix to the remaining speech foundation models
- verifies the export by converting wavlm-base-plus (modeling_wavlm.py) to onnx

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Tagging @xenova, @ylacombe, @eustlb (let me know if I should also add someone else)