Add sdpa support for Albert #32092
Conversation
Thanks for adding @OmarManzoor!
All looks good to me! Could you push a commit containing the message [run_slow] albert, which will trigger a run of the slow integration (and now SDPA) tests?
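(For reference, one low-friction way to do this, assuming an empty commit is acceptable for triggering CI here, is `git commit --allow-empty -m "[run_slow] albert"` followed by a push.)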
```python
import torch
from transformers import AlbertModel

model = AlbertModel.from_pretrained("albert/albert-base-v1", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```
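For illustration, a minimal sketch of exercising this snippet end to end (the tokenizer usage and example input are illustrative additions, not part of the PR):

```python
import torch
from transformers import AlbertModel, AutoTokenizer

# illustrative: tokenize a sentence and run a forward pass through the SDPA attention path
tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v1")
model = AlbertModel.from_pretrained(
    "albert/albert-base-v1", torch_dtype=torch.float16, attn_implementation="sdpa"
).to("cuda")

inputs = tokenizer("SDPA support for ALBERT", return_tensors="pt").to("cuda")
with torch.inference_mode():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch, seq_len, hidden_size)
```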
We should also run a few benchmarks for the model to show the expected speed-ups when using SDPA.
Here are some benchmarks for training. The code is in the details below:

```python
import argparse
import random
from typing import Dict

import numpy as np
import torch
from tqdm.auto import tqdm

from transformers import AutoModelForMaskedLM

import gc


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_training_steps",
        type=int,
        default=100,
        help="",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="albert/albert-base-v2",
        help="",
    )
    parser.add_argument(
        "--use-half",
        action="store_true",
    )
    parser.add_argument(
        "--use-cuda",
        action="store_true",
    )
    return parser


def seed_init_fn(x):
    seed = 42
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    return


def benchmark_training(model, inputs: Dict, num_training_steps: int):
    progress_bar = tqdm(range(num_training_steps))
    model.train()

    # warmup
    for _ in range(10):
        model.zero_grad()
        outputs = model(**inputs)
        loss = outputs.logits.sum()
        loss.backward()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event.record()
    for _ in range(num_training_steps):
        model.zero_grad()
        outputs = model(**inputs)
        loss = outputs.logits.sum()
        loss.backward()
        progress_bar.update(1)
    end_event.record()
    torch.cuda.synchronize()
    max_memory = torch.cuda.max_memory_allocated(device)

    return (start_event.elapsed_time(end_event) * 1.0e-3) / num_training_steps, max_memory


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    BATCH_SIZES = [1, 2, 4]
    SEQ_LEN = [256, 512]

    device = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu")

    output_file = open("log_{}_train.csv".format(args.model_name.replace("/", "-")), "w")
    output_file.write(
        "num_training_steps, batch_size, seq_len, is cuda, Time per batch (eager - s), Time per batch (sdpa - s), "
        "Speedup (%), Eager peak mem (MB), sdpa peak mem (MB), Mem saving (%)\n"
    )

    all_eager_time_per_batch = {}
    all_eager_max_mem = {}
    all_sdpa_max_mem = {}
    all_sdpa_time_per_batch = {}

    with torch.device(device):
        hf_model = AutoModelForMaskedLM.from_pretrained(
            args.model_name, torch_dtype=torch.float16 if args.use_half else None, attn_implementation="sdpa"
        )
    hf_model = hf_model.to(device)

    for batch_size in BATCH_SIZES:
        for sequence_length in SEQ_LEN:
            print(f"Benchmark sdpa on: bs={batch_size}, seq_len={sequence_length}")
            vocab_size = hf_model.config.vocab_size
            inputs = {
                "input_ids": torch.randint(vocab_size - 1, (batch_size, sequence_length), dtype=torch.int64).to(
                    device
                ),
                # "attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64).to(device),
            }

            # raise error if no optimized kernel is available
            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True):
                sdpa_time_per_batch, sdpa_max_mem = benchmark_training(
                    hf_model, inputs=inputs, num_training_steps=args.num_training_steps
                )

            all_sdpa_max_mem[(batch_size, sequence_length)] = sdpa_max_mem
            all_sdpa_time_per_batch[(batch_size, sequence_length)] = sdpa_time_per_batch
            print(f"PT SDPA: {sdpa_time_per_batch:.3f} s, peak {sdpa_max_mem:.2f} MB")

    del hf_model
    gc.collect()

    with torch.device(device):
        hf_model = AutoModelForMaskedLM.from_pretrained(
            args.model_name, torch_dtype=torch.float16 if args.use_half else None, attn_implementation="eager"
        )
    hf_model = hf_model.to(device)

    for batch_size in BATCH_SIZES:
        for sequence_length in SEQ_LEN:
            print(f"Benchmark eager on: bs={batch_size}, seq_len={sequence_length}")
            vocab_size = hf_model.config.vocab_size
            inputs = {
                "input_ids": torch.randint(vocab_size - 1, (batch_size, sequence_length), dtype=torch.int64).to(
                    device
                ),
                # "attention_mask": torch.ones(batch_size, sequence_length, dtype=torch.int64).to(device),
            }

            eager_time_per_batch, eager_max_mem = benchmark_training(
                hf_model, inputs=inputs, num_training_steps=args.num_training_steps
            )

            all_eager_time_per_batch[(batch_size, sequence_length)] = eager_time_per_batch
            all_eager_max_mem[(batch_size, sequence_length)] = eager_max_mem

            eager_max_mem = all_eager_max_mem[(batch_size, sequence_length)] * 1e-6
            sdpa_max_mem = all_sdpa_max_mem[(batch_size, sequence_length)] * 1e-6

            eager_time_per_batch = all_eager_time_per_batch[(batch_size, sequence_length)]
            sdpa_time_per_batch = all_sdpa_time_per_batch[(batch_size, sequence_length)]

            print(f"PT eager: {eager_time_per_batch:.3f} s, peak {eager_max_mem:.2f} MB")
            print(f"PT SDPA: {sdpa_time_per_batch:.3f} s, peak {sdpa_max_mem:.2f} MB")

            speedup = (eager_time_per_batch / sdpa_time_per_batch - 1) * 100
            mem_saved = (eager_max_mem / sdpa_max_mem - 1) * 100

            output_file.write(
                "{},{},{},{},{},{},{},{},{},{}\n".format(
                    args.num_training_steps,
                    batch_size,
                    sequence_length,
                    args.use_cuda,
                    f"{eager_time_per_batch:.3f}",
                    f"{sdpa_time_per_batch:.3f}",
                    f"{speedup:.3f}",
                    f"{eager_max_mem:.3f}",
                    f"{sdpa_max_mem:.3f}",
                    f"{mem_saved:.3f}",
                )
            )

    output_file.close()
```
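For reference, the command used for the training numbers quoted later in the thread is `python benchmark_sdpa_training.py --use-half --use-cuda`; the script writes its summary to a `log_<model-name>_train.csv` file.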
For inference I used the code below:

```python
import argparse

import numpy as np
import pandas as pd
import torch
import gc
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForQuestionAnswering


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-batches",
        type=int,
        default=50,
        help="",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="",
    )
    parser.add_argument(
        "--seqlen",
        type=int,
        default=256,
        help="Input sequence length.",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="albert/albert-base-v2",
        help="",
    )
    parser.add_argument(
        "--use-cuda",
        action="store_true",
    )
    parser.add_argument(
        "--use-half",
        action="store_true",
    )
    parser.add_argument(
        "--use-mask",
        action="store_true",
    )
    parser.add_argument(
        "--sweep",
        action="store_true",
    )
    parser.add_argument(
        "--max_token",
        type=int,
        default=100,
        help="Number of new tokens, for autoregressive models using generate.",
    )
    return parser


def get_batch(batch_size, sequence_length):
    tokens = torch.randint(high=5, size=(batch_size, sequence_length))
    mask = torch.ones((batch_size, sequence_length))
    mask[0, 0] = 0  # real world case where we may mask
    return tokens, mask


def timing_cuda(model, num_batches, input_ids, masks):
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize()

    # We need NOT call torch.cuda.empty_cache() here as it appears to negate the warmup.

    latencies = []
    for _ in tqdm(range(num_batches)):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start_event.record()

        _ = model(input_ids, masks)
        end_event.record()
        torch.cuda.synchronize()

        latency_ms = start_event.elapsed_time(end_event)
        latencies.append(latency_ms)

    max_memory = torch.cuda.max_memory_allocated(device)

    return np.mean(latencies), max_memory


def benchmark(model, input_ids, masks, num_batches, max_token, pad_token_id):
    _ = model(input_ids, masks)
    torch.cuda.synchronize()

    total_time, max_mem = timing_cuda(model, num_batches, input_ids, masks)

    return total_time, max_mem


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    if args.sweep:
        BATCH_SIZES = [1, 2, 4]
        SEQ_LEN = [128, 265]
    else:
        BATCH_SIZES = [args.batch_size]
        SEQ_LEN = [args.seqlen]

    device = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    autoclass = AutoModelForQuestionAnswering

    if args.use_cuda:
        with torch.device("cuda:0"):
            hf_model = autoclass.from_pretrained(
                args.model_name, torch_dtype=torch.float16 if args.use_half else None, attn_implementation="eager"
            )
        hf_model = hf_model.to("cuda:0")
        hf_model = hf_model.to(torch.float16)
    else:
        hf_model = autoclass.from_pretrained(
            args.model_name, torch_dtype=torch.float16 if args.use_half else None, attn_implementation="eager"
        )

    output_name = "log_{}.csv".format(args.model_name.replace("/", "-"))
    output_file = open(output_name, "w")
    output_file.write(
        "num_batches, batch_size, seq_len, is cuda, is half, use mask, Per token latency eager (ms), Per token latency SDPA (ms), Speedup (%), Mem eager (MB), Mem BT (MB), Mem saved (%)\n"
    )

    all_max_mem_eager = {}
    total_eager_time = {}

    for bs in tqdm(BATCH_SIZES):
        for seq_len in tqdm(SEQ_LEN):
            print(f"-- Running: bs={bs}, seq_len={seq_len}")
            input_ids, masks = get_batch(bs, seq_len)

            if args.use_cuda:
                input_ids = input_ids.to(device)
                masks = masks.to(device)

            if args.use_mask is False and bs == 1:
                masks = None

            with torch.inference_mode():
                eager_time, max_mem_eager = benchmark(
                    hf_model,
                    input_ids,
                    masks,
                    args.num_batches,
                    args.max_token,
                    tokenizer.pad_token_id,
                )

            total_eager_time[(bs, seq_len)] = eager_time
            all_max_mem_eager[(bs, seq_len)] = max_mem_eager

    del hf_model
    gc.collect()

    total_sdpa_time = {}
    all_max_mem_sdpa = {}

    if args.use_cuda:
        with torch.device("cuda:0"):
            hf_model = autoclass.from_pretrained(
                args.model_name, torch_dtype=torch.float16 if args.use_half else None, attn_implementation="sdpa"
            )
        hf_model = hf_model.to("cuda:0")
        hf_model = hf_model.to(torch.float16)
    else:
        hf_model = autoclass.from_pretrained(
            args.model_name, torch_dtype=torch.float16 if args.use_half else None, attn_implementation="sdpa"
        )

    for bs in tqdm(BATCH_SIZES):
        for seq_len in tqdm(SEQ_LEN):
            print(f"-- Running: bs={bs}, seq_len={seq_len}")
            input_ids, masks = get_batch(bs, seq_len)

            if args.use_cuda:
                input_ids = input_ids.to(device)
                masks = masks.to(device)

            if args.use_mask is False and bs == 1:
                masks = None

            with torch.inference_mode():
                # raise error if no optimized kernel is available
                with torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_math=True, enable_mem_efficient=True
                ):
                    sdpa_time, max_mem_sdpa = benchmark(
                        hf_model,
                        input_ids,
                        masks,
                        args.num_batches,
                        args.max_token,
                        tokenizer.pad_token_id,
                    )

            total_sdpa_time[(bs, seq_len)] = sdpa_time
            all_max_mem_sdpa[(bs, seq_len)] = max_mem_sdpa

            per_token_latency_eager = total_eager_time[(bs, seq_len)] / args.max_token
            per_token_latency_sdpa = total_sdpa_time[(bs, seq_len)] / args.max_token

            max_mem_eager = all_max_mem_eager[(bs, seq_len)]
            max_mem_sdpa = all_max_mem_sdpa[(bs, seq_len)]
            speedup = (per_token_latency_eager / per_token_latency_sdpa - 1) * 100
            mem_saved = (max_mem_eager / max_mem_sdpa - 1) * 100

            max_mem_eager = max_mem_eager * 1e-6
            max_mem_sdpa = max_mem_sdpa * 1e-6

            print(f"PT eager: {per_token_latency_eager:.3f} ms, peak {max_mem_eager:.2f} MB")
            print(f"PT SDPA: {per_token_latency_sdpa:.3f} ms, peak {max_mem_sdpa:.2f} MB")

            output_file.write(
                "{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
                    args.num_batches,
                    bs,
                    seq_len,
                    args.use_cuda,
                    args.use_half,
                    args.use_mask,
                    f"{per_token_latency_eager:.3f}",
                    f"{per_token_latency_sdpa:.3f}",
                    f"{speedup:.3f}",
                    f"{max_mem_eager:.3f}",
                    f"{max_mem_sdpa:.3f}",
                    f"{mem_saved:.3f}",
                )
            )

    output_file.close()

    print("RESULTS:")
    df = pd.read_csv(output_name)
    print(df.to_markdown(index=False))
```
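For reference, the invocation used later in the thread is of the form `python benchmark_sdpa_inference.py --num-batches 100 --model-name <model> --use-half --use-cuda --use-mask --sweep`; the script writes `log_<model-name>.csv` and prints it as a markdown table at the end.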
@OmarManzoor, thanks for sharing! I'm surprised by these numbers - we typically would see speed-ups of ~30%, especially given the similarity to other models like BERT. Could you try running just on the
With the simple model (inference):
I don't think the simple model works for training.
With `AlbertForSequenceClassification`, using `python benchmark_sdpa_training.py --use-half --use-cuda`:
Training
Inference 1
Inference 2
@amyeroberts Maybe the inference code I am using needs to be modified? Or maybe, since this model does not have a decoder, the changes are not significant?
Some more inference benchmarks with `AlbertForSequenceClassification`:
@OmarManzoor, thanks for running some more numbers.
This shouldn't matter - BERT is encoder-only. Could you try running the script on BERT to see if you're able to replicate the same speedups reported in the docs?
`python benchmark_sdpa_inference.py --num-batches 100 --model-name bert-base-uncased --use-half --use-cuda --use-mask --sweep` with `AutoModelForSequenceClassification`:
@OmarManzoor OK, this indicates to me there might be something wrong with the script or your setup, as you should be seeing the same speedup numbers as in the docs.
Inference benchmarks using a GeForce RTX 2060 with 8 GB:
RESULTS:
Probably we can't set things up appropriately on Kaggle.
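As a side note, one way to sanity-check whether the fused SDPA kernels are even enabled in a given environment (a sketch, assuming PyTorch 2.x where these query functions exist) is:

```python
import torch

# report the runtime's SDPA backend switches and the visible device
print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)
print("flash SDP enabled:", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient SDP enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math SDP enabled:", torch.backends.cuda.math_sdp_enabled())
print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu")
```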
For training with `AutoModelForSequenceClassification`:
Hi @OmarManzoor, thanks for working on this! I got the following results on your branch.
Env:
Inference speedup:
@OmarManzoor Thanks for iterating on this! Could you rebase on main to include the recent upstream changes? This should resolve the code quality checks.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for all your work on this @OmarManzoor! Only thing remaining is a final slow model run (triggered with a commit message containing [run_slow] albert).
@amyeroberts Can this be merged now?
@OmarManzoor Yep! Thanks for all your work on this.
* Add sdpa support for Albert
* [run_slow] albert
* Add benchmarks and PR suggestion
* Fix quality
* Fix
* [run_slow] albert
What does this PR do?
Adds SDPA support for the Albert model.
Towards #28005
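As a side note, a quick way to confirm which attention implementation a loaded checkpoint resolved to (a sketch; `_attn_implementation` is an internal config attribute and may change):

```python
from transformers import AlbertModel

model = AlbertModel.from_pretrained("albert/albert-base-v2", attn_implementation="sdpa")
# internal flag set during loading; "sdpa" confirms the new attention path is active
print(model.config._attn_implementation)
```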
Who can review?
@amyeroberts @fxmarty