Speedup model init on CPU (by 10x+ for llama-3-8B as one example) #31771
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Note: we're seeing some failures with encoder/decoder models that don't have tied weights. Not fully sure what's up there but @SunMarc is investigating.
This can allegedly also increase throughput. Setup:

import time
from accelerate.utils import set_seed
from transformers import LlamaForCausalLM, AutoTokenizer

set_seed(42)
file_size = 132  # Size in GB of the weights

factory_model = LlamaForCausalLM.from_pretrained("/mnt/superfast/llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("/mnt/superfast/llama-3-8B")
inputs = tokenizer("Blue is my favorite color. What is my favorite color?", return_tensors="pt")

start_time = time.time()
output = factory_model.generate(**inputs, max_new_tokens=20, num_return_sequences=1)
end_time = time.time()
time_taken = end_time - start_time

print(f"inference time={time_taken:.3f} seconds")
print(f"speed={file_size/time_taken:.3f} GB/second")
new_tokens = len(output[0]) - inputs.input_ids.shape[1]
print(f'tok/s={new_tokens/time_taken:.3f}')

Current setup in HF:
New version:
Did some tests with
Wow - what an impactful PR 🔥 !
My only question is about deepspeed compatibility, and compatibility across pytorch versions.
cc @gante for reference for the generate speedups
src/transformers/modeling_utils.py
Outdated
if len(params_to_gather) > 0:
    # because zero3 puts placeholders in model params, this context
    # manager gathers (unpartitions) the params of the current layer, then loads from
    # the state dict and then re-partitions them again
    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            module._load_from_state_dict(*args)
Have we tested the new code with deepspeed too?
Not yet! To come next week :)
I don't think this can work if the model has already been sharded under deepspeed ZeRO-3, because it hijacks the param tensors and the loading will either fail (or, worse, silently remain random).
But I'd suggest to check in with the deepspeed team - perhaps they have some more recent tricks that will accomplish that faster.
But in general, under the zero.Init context the model is already spread out across the GPUs, so you can't just overwrite its shards without the machinery you propose to delete.
I think the just published Universal Checkpoint might be usable here to broadcast the updated tensor shards to each gpu, w/o needing to gather their content first.
@tjruwase, if possible could you please assist if there is a way to update the already sharded tensors in a faster way?
Of course, if this doesn't work, then you'd need to have two code branches. The users who use zero.Init won't mind waiting, because the huge models they want to load won't fit onto a single GPU, so the cost of slower loading is a trade-off here.
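A minimal sketch of that two-branch idea, under the assumption that is_deepspeed_zero3_enabled is importable from transformers.integrations and that load_module_weights is a hypothetical helper (not this PR's actual code):

from transformers.integrations import is_deepspeed_zero3_enabled

def load_module_weights(module, state_dict):
    # Hypothetical helper, only to illustrate the two branches discussed above.
    if is_deepspeed_zero3_enabled():
        import deepspeed
        import torch.distributed as dist

        # Slow path: under ZeRO-3 the params are placeholders, so gather them,
        # copy values on rank 0, and let the context manager re-partition on exit.
        params_to_gather = [p for _, p in module.named_parameters(recurse=False)]
        with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
            if dist.get_rank() == 0:
                module.load_state_dict(state_dict, strict=False)
    else:
        # Fast path from this PR: hand the state-dict tensors to the module
        # directly (the `assign` kwarg needs torch >= 2.1).
        module.load_state_dict(state_dict, strict=False, assign=True)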
@tjruwase, if possible could you please assist if there is a way to update the already sharded tensors in a faster way?
@stas00, I am just catching up on this, but my initial thought is wouldn't the following feature help for this case?
https://deepspeed.readthedocs.io/en/latest/zero3.html#modifying-partitioned-states
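For reference, a rough sketch of how those helpers from the linked "Modifying Partitioned States" docs might be used; this assumes the model is already wrapped by deepspeed.initialize (called engine here), that the script runs under the deepspeed launcher, and that new_state_dict is a hypothetical dict of replacement weights:

import torch
from deepspeed.utils import safe_get_full_fp32_param, safe_set_full_fp32_param

# Update each partitioned parameter in place from a (hypothetical) dict of new
# weights, without explicitly gathering/re-partitioning in user code.
for name, param in engine.module.named_parameters():
    if name in new_state_dict:
        safe_set_full_fp32_param(param, new_state_dict[name].to(torch.float32))

# Read a full (assembled) fp32 copy back out to sanity-check the update.
first_name, first_param = next(iter(engine.module.named_parameters()))
print(first_name, safe_get_full_fp32_param(first_param).norm())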
src/transformers/modeling_utils.py
Outdated
if is_deepspeed_zero3_enabled():
    import deepspeed

    # In sharded models, each shard has only part of the full state_dict, so only gather
    # parameters that are in the current state_dict.
    named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
    params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
    if len(params_to_gather) > 0:
        # because zero3 puts placeholders in model params, this context
        # manager gathers (unpartitions) the params of the current layer, then loads from
        # the state dict and then re-partitions them again
        with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
            if torch.distributed.get_rank() == 0:
I believe this should be as simple as migrating this chunk and then having the load_state_dict occur under this context manager.
@msaroufim if you have a moment, could you give this a look to check that everything makes sense here per my understanding of how we should be loading in model weights, etc? Would be very appreciative of your eyes/take on this.
For transparency, here is the script I'm using: https://gist.github.com/muellerzr/7239668f61baff5726f556d30d2af5f5
💛 💛 💛
Got confirmation from Mark S (thanks Mark for looking this over) and this is indeed correct 🔥
A very important caveat @SunMarc and I discovered today: this works for now as a small fix for users who load everything on CPU first instead / don't do
About the failing tests, we had the following ones:
It is a bit complicated, but basically these tests were not supposed to pass initially. However, they passed in the end because the weights were tied by default (even when weight tying was disabled in the config). For example, this is the architecture of one of the SeamlessM4Tv2 models:

def __init__(self, config, current_modality="text"):
    super().__init__(config)
    self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
    self.text_encoder = SeamlessM4Tv2Encoder(config, self.shared)
    self.speech_encoder = SeamlessM4Tv2SpeechEncoder(config)
    self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared)
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

With this PR, we've set assign=True, and the torch docs describe that flag as:

assign (bool, optional): When ``False``, the properties of the tensors
    in the current module are preserved while when ``True``, the
    properties of the Tensors in the state dict are preserved. The only
    exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
    for which the value from the module is preserved.
    Default: ``False``

So, the conclusion is that there should be no problem with the modification you did, I just need to skip/modify the tests. cc @muellerzr

In the future, we just need to make sure that, when we have shared weights by default, we skip the tests or add the possibility to remove these shared weights.
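A tiny standalone illustration of the assign semantics quoted above (requires torch >= 2.1; the module here is just a toy nn.Linear, not one of the affected models):

import torch
import torch.nn as nn

# A float32 module and a bfloat16 state dict for it.
lin = nn.Linear(4, 4)
bf16_state = {k: v.to(torch.bfloat16) for k, v in lin.state_dict().items()}

# Default load: values are copied (and cast) into the existing float32 parameters.
lin.load_state_dict(bf16_state)
print(lin.weight.dtype)   # torch.float32

# assign=True: the state-dict tensors themselves become the parameters,
# so their dtype (and other properties) win.
lin2 = nn.Linear(4, 4)
lin2.load_state_dict(bf16_state, assign=True)
print(lin2.weight.dtype)  # torch.bfloat16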
Okay, ran some traces and I think this makes sense to me now. Compare the following first calls into the model.
Baseline:
Fix:
What this hints at here, I believe, is that because we are loading the mmap-backed weights directly into the parameters, the weights are only mapped rather than fully read at load time; and because of that, the actual reads from disk happen during the first forward pass rather than during from_pretrained.
@muellerzr, I am curious about the I/O speeds in your OP. Can you please confirm that you are transferring weights from NVMe to HBM at 75-90GB/sec? Are you able to share PCIe and m.2 specs? Thanks
@tjruwase the answer I've come to (as in my last post) is that mmapping is covering the transition from the M.2 when we first pass an input through the model (I think). Weights are only allocated in space, but not fully loaded in. Hence why I'm seeing far above what my actual M.2 drive can bring in, but 8s of time to bring in said weights is reasonable if done quickly! (Because yes, I'd love to know what planet has a 75-90GB/s non-RAID M.2 as well!) My setup:
Let me know how much more specific I can get with this for you!
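A small sketch of the mmap behaviour being described, using safetensors directly (the filename is a placeholder): opening the file and listing keys only touches the header, while the bytes are actually read from disk when a tensor is materialized.

from safetensors import safe_open

# Placeholder shard name; substitute one of the checkpoint's .safetensors files.
with safe_open("model-00001-of-00004.safetensors", framework="pt") as f:
    names = list(f.keys())            # cheap: only the header is parsed
    first = f.get_tensor(names[0])    # this is where data is actually read from disk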
@muellerzr, thanks for the clarification. As you may have guessed, fast I/O is a passion, and I am also awaiting the above :).
@tjruwase do let me know if you see anything else odd about what I've done here too / if you have insights. I'll look into the DeepSpeed stuff in a few days!
@muellerzr, your NVMe is blazingly fast, ~14GB/sec reads. May I request your contribution to the following?
@muellerzr, nothing looks odd. This is truly amazing work that you have done here, kudos! Do let me know if my suggestion for updating sharded DeepSpeed weights above is insufficient or problematic.
Okay! After a ton of thorough testing I've proven that:
Users will not see much speedup if they do
When I eventually ripped everything out to test, here's my full code:

from transformers import LlamaForCausalLM, AutoConfig, AutoTokenizer
from accelerate.utils import set_seed
from accelerate.big_modeling import init_empty_weights
from safetensors.torch import load_file
from pathlib import Path
import json
from safetensors import safe_open
from accelerate.utils import retie_parameters
from transformers import GenerationConfig
from transformers.utils.hub import get_checkpoint_shard_files
import time

set_seed(42)
llama_path = Path("/mnt/superfast/llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained(llama_path)
inputs = tokenizer("Tell me about a girl that", return_tensors="pt")

config = AutoConfig.from_pretrained(llama_path)
use_keep_in_fp32_modules = False
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
    llama_path,
    llama_path / "model.safetensors.index.json"
)
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
config = LlamaForCausalLM._autoset_attn_implementation(
    config, use_flash_attention_2=False, torch_dtype=None, device_map=None
)

with init_empty_weights():
    factory_model = LlamaForCausalLM(config)

index_filename = llama_path / "model.safetensors.index.json"
with open(index_filename, "r") as f:
    index = json.load(f)
if "weight_map" in index:
    index = index["weight_map"]
checkpoint_files = sorted(list(set(index.values())))
checkpoint_files = [llama_path / f for f in checkpoint_files]

model_keys = set(factory_model.state_dict().keys())
new_state_dict = {}
for checkpoint_file in checkpoint_files:
    with safe_open(checkpoint_file, framework="pt") as f:
        metadata = f.metadata()
        weight_names = f.keys()
    file_state = load_file(checkpoint_file)
    new_state_dict.update(file_state)

factory_model.load_state_dict(new_state_dict, strict=True, assign=True)
retie_parameters(factory_model, [["lm_head.weight"]])
factory_model.eval()
factory_model.generation_config = GenerationConfig.from_pretrained(
    llama_path
)

start_time = time.time()
output = factory_model.generate(**inputs, max_new_tokens=20, num_return_sequences=1)
end_time = time.time()
time_taken = end_time - start_time
new_tokens = len(output[0]) - inputs.input_ids.shape[1]
print(f"{time_taken:.3f}s | {new_tokens/time_taken:.3f} tokens/second | {tokenizer.batch_decode(output, skip_special_tokens=True)} | ")
@SunMarc @LysandreJik @ArthurZucker I've adjusted this title to what is really happening here. See the new updated table: basically we "borrow" a little time later on, during the first forward pass, to load the weights in rather than doing so immediately, which lets models load much faster; after the 1st pass things are still quick. On CUDA I saw nearly no timing changes either, aside from loading the model in 0.185s rather than 2s for llama-3-8B, so that's safe too :)
So that we can merge this, for now I've kept the old deepspeed behavior in place.
Simple and efficient! I think we need a recent enough torch version (do we still support the 2-year-old one?), and can you rebase to make sure the failing tests are unrelated?
@amyeroberts reworked the PR description, let me know if everything makes sense now 🤗
Beautiful ❤️
@@ -894,32 +895,42 @@ def test_from_pretrained_low_cpu_mem_usage_functional(self):
@require_usr_bin_time
@require_accelerate
@mark.accelerate_tests
def test_from_pretrained_low_cpu_mem_usage_measured(self):
    # test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
def test_from_pretrained_low_cpu_mem_usage_slower(self):
Shouldn't it be faster?
def test_from_pretrained_low_cpu_mem_usage_slower(self):
def test_from_pretrained_low_cpu_mem_usage_faster(self):
Was running the tests; it's a bit slower due to the added hooks IMO (which is fine, as low_cpu_mem_usage=True is still needed 99% of the time, aka when the weights' precision != the model's precision).
Hi @muellerzr,
Thanks for your work! I'm really interested in understanding the changes made on the code level. Can you tell me what "model precision" and "weights precision" are? I don't get the difference. I am aware of weights being FP32, BF16, etc.
Model precision relates to the precision the model is initialized in. When using AutoModel, that is the torch_dtype param.
Weight precision is what the actual weights of the model are stored in.
For instance, with transformers, when doing fine-tuning one should ideally always train with a float32 architecture and then use autocast to train in bfloat16, so that users can make the best use of your weights and have stable fine-tuning. However, models like llama-3-8B only release bfloat16 weights.
So given this, if we choose to create the architecture in bfloat16 as well (again, in this case ideally just for inference), then we see this fast speedup, as the architecture's dtype is the same as the pretrained weights' dtype.
Does this make sense?
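As a concrete, hedged example of matching the two: create the architecture in the same dtype the checkpoint was stored in (using the local llama path from earlier in this thread as a placeholder).

import torch
from transformers import AutoModelForCausalLM

# llama-3-8B ships bfloat16 weights, so creating the architecture in bfloat16
# lets the fast assign-based loading path described in this PR apply.
model = AutoModelForCausalLM.from_pretrained(
    "/mnt/superfast/llama-3-8B",   # local checkpoint path used in this thread
    torch_dtype=torch.bfloat16,
)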
There's a bug on M1 macs with transformers >= 4.43.0 and torch >= 2.1.0, where if a model has tied embeddings, then the fast loading from #31771 causes a bus error when the model is actually run. This can be solved by disabling `_supports_param_buffer_assignment` for these models. More info in comments in #33357
What does this PR do?
This PR introduces utilizing _assign_to_params_buffers as a way to speed up weight loading when the dtypes of the models are the same. By doing so, we can lazily load in model weights on the fly when an input is passed in, decreasing the TTL of both training and inference wrt the speed of your disk.
The benefit is that low_cpu_mem_usage and this approach now have ~the same memory usage, if and only if the model weights' precision == the loaded model's precision. For example, this will only work if you load llama-3-8B in bfloat16, since the weights and architecture are both in bfloat16.
Unsupported models
Some models do not support buffer param assignments. I've added a new _supports_param_buffer_assignment attr to the specific models that do not (eventually it'd be good to investigate why). If any models fail the test_from_pretrained_no_checkpoint tests, they need to set this attribute on their model class (similar to how VisionEncoderDecoderModel has supports_gradient_checkpointing=False).
Example model init time:
Example model throughput:
First pass:
Afterwards, both lazy-loaded and non-lazy-loaded inference times are the same (since we no longer need to read the weights in).
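A hedged sketch of the opt-out described under "Unsupported models" (the subclass name is hypothetical; per the description above, setting the attribute makes from_pretrained fall back to the previous loading behavior):

from transformers import LlamaForCausalLM

class MyNoAssignModel(LlamaForCausalLM):
    # Opt this architecture out of the fast buffer/param assignment path,
    # analogous to how supports_gradient_checkpointing=False is used elsewhere.
    _supports_param_buffer_assignment = False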
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@LysandreJik @amyeroberts @SunMarc