Add support for FSDP+QLoRA and DeepSpeed ZeRO3+QLoRA #29587

Merged · 27 commits · Mar 13, 2024

Commits
78d6dcc
fsdp+qlora related changes
pacman100 Jan 19, 2024
2f51a20
fixes
pacman100 Jan 19, 2024
5df8a65
Update quantization_config.py
pacman100 Jan 19, 2024
595a16b
Merge branch 'main' into smangrul/fsdp-qlora-support
pacman100 Mar 6, 2024
00312cf
support fsdp+qlora and dsz3+qlora
pacman100 Mar 6, 2024
4511a2c
Update quantization_config.py
pacman100 Mar 6, 2024
574b371
Update modeling_utils.py
pacman100 Mar 6, 2024
b195b28
Update modeling_utils.py
pacman100 Mar 6, 2024
5081937
Update modeling_utils.py
pacman100 Mar 6, 2024
3da40ee
Update modeling_utils.py
pacman100 Mar 6, 2024
78e06e6
Update modeling_utils.py
pacman100 Mar 6, 2024
32f8c83
Update modeling_utils.py
pacman100 Mar 6, 2024
1401b73
handle fsdp+qlora and dsz3+qlora correctly while model loading
pacman100 Mar 11, 2024
3c56930
fix param count
pacman100 Mar 11, 2024
4840515
quality
pacman100 Mar 11, 2024
bef438b
fsdp related changes
pacman100 Mar 11, 2024
4a6596b
Merge branch 'main' into smangrul/fsdp-qlora-support
pacman100 Mar 11, 2024
ac6ddec
fsdp changes only when using LoRA/QLoRA
pacman100 Mar 11, 2024
e554934
add accelerate version check
pacman100 Mar 12, 2024
4c82852
refactor, update min accelerate version and add tests
pacman100 Mar 12, 2024
a43d49d
fix test
pacman100 Mar 12, 2024
f5fc519
Address comments
pacman100 Mar 13, 2024
6973569
fix the conditional flag
pacman100 Mar 13, 2024
7cde578
Merge branch 'main' into smangrul/fsdp-qlora-support
pacman100 Mar 13, 2024
c40c767
fix conditional flag
pacman100 Mar 13, 2024
73bda72
Merge branch 'main' into smangrul/fsdp-qlora-support
pacman100 Mar 13, 2024
6f1eb11
address comments
pacman100 Mar 13, 2024
7 changes: 7 additions & 0 deletions src/transformers/integrations/bitsandbytes.py
@@ -1,6 +1,7 @@
import importlib.metadata
import warnings
from copy import deepcopy
from inspect import signature

from packaging import version

@@ -179,13 +180,19 @@ def _replace_with_bnb_linear(
):
pass
else:
extra_kwargs = (
{"quant_storage": quantization_config.bnb_4bit_quant_storage}
if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters)
else {}
)
model._modules[name] = bnb.nn.Linear4bit(
in_features,
out_features,
module.bias is not None,
quantization_config.bnb_4bit_compute_dtype,
compress_statistics=quantization_config.bnb_4bit_use_double_quant,
quant_type=quantization_config.bnb_4bit_quant_type,
**extra_kwargs,
)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
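The quant_storage keyword only exists in newer bitsandbytes releases, which is why the hunk above inspects bnb.nn.Linear4bit's signature before forwarding it. A minimal, self-contained sketch of that backward-compatibility pattern (the class and function names below are illustrative, not part of the PR):

```python
# Only forward a keyword argument if the installed library's constructor accepts it.
from inspect import signature


def compatible_kwargs(layer_cls, **maybe_kwargs):
    """Keep only the kwargs that layer_cls's constructor declares."""
    accepted = signature(layer_cls).parameters
    return {k: v for k, v in maybe_kwargs.items() if k in accepted}


class OldLinear4bit:  # stands in for an older bitsandbytes Linear4bit
    def __init__(self, in_features, out_features):
        self.in_features, self.out_features = in_features, out_features


class NewLinear4bit(OldLinear4bit):  # stands in for a release that added quant_storage
    def __init__(self, in_features, out_features, quant_storage="uint8"):
        super().__init__(in_features, out_features)
        self.quant_storage = quant_storage


print(compatible_kwargs(OldLinear4bit, quant_storage="bfloat16"))  # {} -> kwarg silently dropped
print(compatible_kwargs(NewLinear4bit, quant_storage="bfloat16"))  # {'quant_storage': 'bfloat16'}
```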
60 changes: 38 additions & 22 deletions src/transformers/modeling_utils.py
@@ -54,6 +54,7 @@
prune_linear_layer,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
from .safetensors_conversion import auto_conversion
from .utils import (
ADAPTER_SAFE_WEIGHTS_NAME,
@@ -496,7 +497,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)


def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool = False):
"""
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
"""
@@ -512,8 +513,9 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
return safe_load_file(checkpoint_file)
try:
if (
is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0
) or (is_fsdp_enabled() and not is_local_dist_rank_0()):
(is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0)
or (is_fsdp_enabled() and not is_local_dist_rank_0())
) and not is_quantized:
map_location = "meta"
else:
map_location = "cpu"
@@ -718,6 +720,7 @@ def _load_state_dict_into_meta_model(

old_keys = []
new_keys = []
is_quantized = hf_quantizer is not None
for key in state_dict.keys():
new_key = None
if "gamma" in key:
@@ -797,14 +800,22 @@ def _load_state_dict_into_meta_model(
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
elif (
hf_quantizer is None
not is_quantized
or (not hf_quantizer.requires_parameters_quantization)
or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict))
):
# For backward compatibility with older versions of `accelerate` and for non-quantized params
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU
# in comparison to the sharded model across GPUs.
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, tensor_name = get_module_from_name(model, param_name)
value = getattr(module, tensor_name)
value = type(value)(value.data.to("cpu"), **value.__dict__)
setattr(module, tensor_name, value)
# TODO: consider removing used param_parts from state_dict before return

return error_msgs, offload_index, state_dict_index
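
The new else branch above quantizes on the accelerator and then rebuilds the parameter around CPU storage, so that under FSDP or ZeRO-3 each rank does not hold a full unsharded copy in device memory. A hedged sketch of that re-wrap idiom, using a plain nn.Parameter instead of bitsandbytes' Params4bit (the real code additionally forwards value.__dict__ so the quantization metadata survives):

```python
import torch

# Create a parameter on the accelerator if one is available, then re-wrap its data on CPU
# while preserving the parameter subclass, mirroring the type(value)(...) call above.
device = "cuda" if torch.cuda.is_available() else "cpu"
param = torch.nn.Parameter(torch.randn(4, 4, device=device), requires_grad=False)

cpu_param = type(param)(param.data.to("cpu"), requires_grad=param.requires_grad)
print(type(cpu_param).__name__, cpu_param.device)  # Parameter cpu
```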
@@ -1070,7 +1081,9 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
# For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
# used for the 4bit quantization (uint8 tensors are stored)
if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
total_numel.append(param.numel() * 2)
total_numel.append(
param.numel() * 2 * self.hf_quantizer.quantization_config.bnb_4bit_quant_storage.itemsize
)
else:
total_numel.append(param.numel())
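
To see why the 4-bit count is now scaled by the storage dtype's itemsize: each real parameter takes half a byte, so the number of stored elements shrinks as the storage dtype grows, and numel * 2 * itemsize recovers the original count for any storage type. A small worked check (assumes a torch version in which torch.dtype exposes itemsize, as the diff itself does):

```python
import torch

real_params = 1024                    # original (unquantized) parameter count
packed_bytes = real_params // 2       # 4 bits per parameter -> 512 bytes of packed data

for storage in (torch.uint8, torch.bfloat16, torch.float32):
    numel = packed_bytes // storage.itemsize       # elements the Params4bit tensor reports
    recovered = numel * 2 * storage.itemsize       # formula used in num_parameters()
    print(storage, numel, recovered)               # recovered == 1024 in every case
```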

@@ -1805,10 +1818,11 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
old_embeddings_requires_grad = old_embeddings.weight.requires_grad
new_embeddings.requires_grad_(old_embeddings_requires_grad)
self.set_input_embeddings(new_embeddings)
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None

# Update new_num_tokens with the actual size of new_embeddings
if pad_to_multiple_of is not None:
if is_deepspeed_zero3_enabled():
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
@@ -1882,7 +1896,8 @@ def _get_resized_embeddings(
if new_num_tokens is None:
return old_embeddings

if is_deepspeed_zero3_enabled():
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
@@ -1921,7 +1936,7 @@ def _get_resized_embeddings(
# numbers of tokens to copy
n = min(old_num_tokens, new_num_tokens)

if is_deepspeed_zero3_enabled():
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

params = [old_embeddings.weight, new_embeddings.weight]
@@ -1958,7 +1973,8 @@ def _get_resized_lm_head(
if new_num_tokens is None:
return old_lm_head

if is_deepspeed_zero3_enabled():
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
Collaborator:

Is there any reason we couldn't have an is_quantized property on the model, defaulting to False? Having to constantly define is_quantized within the methods isn't ideal, as it requires updating many different places if the criteria change.

pacman100 (Contributor, Author) · Mar 13, 2024:

We already have the hf_quantizer property on the model, so storing is_quantized would duplicate the same information. Earlier I was directly using self.hf_quantizer is None in the checks, but there were suggestions above to improve readability by using is_quantized, so I made the changes accordingly.

if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
@@ -2000,7 +2016,7 @@ def _get_resized_lm_head(

num_tokens_to_copy = min(old_num_tokens, new_num_tokens)

if is_deepspeed_zero3_enabled():
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
@@ -3036,6 +3052,7 @@ def from_pretrained(
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.")
is_quantized = hf_quantizer is not None

# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
@@ -3365,7 +3382,7 @@ def from_pretrained(
# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]

if is_deepspeed_zero3_enabled():
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
@@ -3564,7 +3581,8 @@ def from_pretrained(
}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
dispatch_model(model, **device_map_kwargs)
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)

if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
@@ -3610,6 +3628,7 @@ def _load_pretrained_model(
keep_in_fp32_modules=None,
):
is_safetensors = False
is_quantized = hf_quantizer is not None

if device_map is not None and "disk" in device_map.values():
archive_file = (
@@ -3735,7 +3754,7 @@ def _fix_key(key):
if param.device == torch.device("meta"):
value = torch.empty(*param.size(), dtype=target_dtype)
if (
hf_quantizer is None
not is_quantized
or getattr(hf_quantizer, "requires_parameters_quantization", False)
or not hf_quantizer.check_quantized_param(
model, param_value=value, param_name=key, state_dict={}
@@ -3765,7 +3784,7 @@ def _fix_key(key):
else:
not_initialized_submodules = dict(model.named_modules())
# This will only initialize submodules that are not marked as initialized by the line above.
if is_deepspeed_zero3_enabled():
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

not_initialized_parameters = list(
@@ -3909,7 +3928,7 @@ def _find_mismatched_keys(
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
if shard_file in disk_only_shard_files:
continue
state_dict = load_state_dict(shard_file)
state_dict = load_state_dict(shard_file, is_quantized=is_quantized)

# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
@@ -3922,15 +3941,12 @@ def _find_mismatched_keys(
ignore_mismatched_sizes,
)
if low_cpu_mem_usage:
if is_fsdp_enabled() and not is_local_dist_rank_0():
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
if hf_quantizer is None:
set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
hf_quantizer.create_quantized_param(model, param, key, "cpu", state_dict)
Collaborator:

Just to make sure I've understood: we don't need this anymore because creating the quantized params has moved to when the state dict is loaded?

pacman100 (Contributor, Author):

hf_quantizer is None (i.e., not quantized) will always be False here, because the outer conditional if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized already includes that check, so this inner conditional is no longer required. The quantized parameters are initialized in the else branch on line 3950.

set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
18 changes: 18 additions & 0 deletions src/transformers/trainer.py
@@ -1776,6 +1776,7 @@ def _inner_training_loop(

if delay_optimizer_creation:
if use_accelerator_prepare:
self._fsdp_qlora_plugin_updates()
self.model = self.accelerator.prepare(self.model)
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

@@ -4156,3 +4157,20 @@ def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)

def _fsdp_qlora_plugin_updates(self):
if self.is_fsdp_enabled and _is_peft_model(self.model):
from peft import LoraConfig
from peft.utils.other import fsdp_auto_wrap_policy

if isinstance(self.model.active_peft_config, LoraConfig):
fsdp_plugin = self.accelerator.state.fsdp_plugin
fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
if (
getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point
and version.parse(accelerate_version) > version.parse("0.27.0")
):
fsdp_plugin.set_mixed_precision(
self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
)
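
For context, a hedged end-to-end sketch of the workflow this method enables: set bnb_4bit_quant_storage to a floating-point dtype so FSDP can flat-shard the packed 4-bit weights together with the LoRA parameters, then train through Trainer under an FSDP-enabled accelerate launch. The model id, LoRA hyperparameters, and dataset are placeholders, not taken from the PR.

```python
# Sketch of FSDP+QLoRA fine-tuning with the new bnb_4bit_quant_storage option.
# Launch with something like: accelerate launch --use_fsdp train.py
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, Trainer, TrainingArguments

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,  # new in this PR; floating point so FSDP can shard it
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",             # placeholder model id
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,             # keep non-quantized weights in the storage dtype
)
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))

args = TrainingArguments(
    output_dir="out",
    bf16=True,
    fsdp="full_shard auto_wrap",            # Trainer calls _fsdp_qlora_plugin_updates() before prepare()
    fsdp_config={"transformer_layer_cls_to_wrap": "LlamaDecoderLayer"},
)
# Trainer(model=model, args=args, train_dataset=...).train()  # supply a tokenized dataset
```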
6 changes: 4 additions & 2 deletions src/transformers/training_args.py
@@ -1721,8 +1721,10 @@ def __post_init__(self):
for fsdp_option in self.fsdp:
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
# set environment variable for FSDP sharding strategy
os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
os.environ[f"{prefix}SHARDING_STRATEGY"] = (
str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
if is_accelerate_available("0.26.0")
else fsdp_option.upper()
)
elif fsdp_option == FSDPOption.OFFLOAD:
os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
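A small sketch of the environment-variable mapping above, assuming accelerate's FSDP_SHARDING_STRATEGY constant is the usual ordered list of strategy names (an assumption, not quoted from the PR): newer accelerate versions expect the 1-based index of the strategy, while older ones expect the name itself.

```python
# Illustration of the value written to {prefix}SHARDING_STRATEGY for each accelerate version.
FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "HYBRID_SHARD_ZERO2"]


def sharding_strategy_env_value(fsdp_option: str, accelerate_at_least_0_26: bool) -> str:
    option = fsdp_option.upper()
    return str(FSDP_SHARDING_STRATEGY.index(option) + 1) if accelerate_at_least_0_26 else option


print(sharding_strategy_env_value("full_shard", True))   # "1"
print(sharding_strategy_env_value("full_shard", False))  # "FULL_SHARD"
```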
13 changes: 13 additions & 0 deletions src/transformers/utils/quantization_config.py
@@ -225,6 +225,8 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
This flag is used for nested quantization where the quantization constants from the first quantization are
quantized again.
bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
This sets the storage type to pack the quantized 4-bit params.
kwargs (`Dict[str, Any]`, *optional*):
Additional parameters from which to initialize the configuration object.
"""
@@ -240,6 +242,7 @@ def __init__(
bnb_4bit_compute_dtype=None,
bnb_4bit_quant_type="fp4",
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_storage=None,
**kwargs,
):
self.quant_method = QuantizationMethod.BITS_AND_BYTES
@@ -265,6 +268,15 @@ def __init__(
else:
raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")

if bnb_4bit_quant_storage is None:
self.bnb_4bit_quant_storage = torch.uint8
elif isinstance(bnb_4bit_quant_storage, str):
self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage)
elif isinstance(bnb_4bit_quant_storage, torch.dtype):
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
else:
raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype")

self.post_init()

@property
@@ -345,6 +357,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
output = copy.deepcopy(self.__dict__)
output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]
output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1]
output["load_in_4bit"] = self.load_in_4bit
output["load_in_8bit"] = self.load_in_8bit

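A brief round-trip check of the new field, mirroring the existing bnb_4bit_compute_dtype handling shown above; it assumes only that a transformers build containing this change is installed.

```python
import torch
from transformers import BitsAndBytesConfig

# Strings and torch.dtype values are both accepted; to_dict() serializes back to the bare name.
cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="bfloat16")
assert cfg.bnb_4bit_quant_storage == torch.bfloat16
assert cfg.to_dict()["bnb_4bit_quant_storage"] == "bfloat16"

# Omitting the argument keeps the previous behaviour of packing into uint8.
assert BitsAndBytesConfig(load_in_4bit=True).bnb_4bit_quant_storage == torch.uint8
```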
39 changes: 39 additions & 0 deletions tests/fsdp/test_fsdp.py
@@ -15,6 +15,7 @@
import itertools
import os
import unittest
from copy import deepcopy
from functools import partial

from parameterized import parameterized
@@ -171,6 +172,44 @@ def test_fsdp_config(self, sharding_strategy, dtype):
self.assertEqual(v, self.fsdp_config[k])
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")

@parameterized.expand(params, name_func=_parameterized_custom_name_func)
def test_fsdp_config_transformers_auto_wrap(self, sharding_strategy, dtype):
output_dir = self.get_auto_remove_tmp_dir()
fsdp_config = deepcopy(self.fsdp_config)
del fsdp_config["min_num_params"]
fsdp_config["transformer_layer_cls_to_wrap"] = "BertLayer"
kwargs = {
"output_dir": output_dir,
"train_len": 128,
"save_steps": 5,
"learning_rate": 0.1,
"fsdp": f"{sharding_strategy} offload auto_wrap",
"fsdp_config": fsdp_config,
}
kwargs[dtype] = True
prefix = "FSDP_"
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(**kwargs)
self.assertEqual(trainer.args.fsdp[0], sharding_strategy)
self.assertEqual(trainer.args.fsdp[1], FSDPOption.OFFLOAD)
self.assertEqual(trainer.args.fsdp[2], FSDPOption.AUTO_WRAP)
fsdp_sharding_strategy = (
str(FSDP_SHARDING_STRATEGY.index(sharding_strategy.upper()) + 1)
if is_accelerate_available("0.26.0")
else sharding_strategy.upper()
)
self.assertEqual(os.environ[f"{prefix}SHARDING_STRATEGY"], fsdp_sharding_strategy)
self.assertEqual(os.environ[f"{prefix}OFFLOAD_PARAMS"], "true")
self.assertEqual(os.environ[f"{prefix}AUTO_WRAP_POLICY"], "TRANSFORMER_BASED_WRAP")
self.assertEqual(
os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"], ",".join(fsdp_config["transformer_layer_cls_to_wrap"])
)
self.assertEqual(os.environ[f"{prefix}BACKWARD_PREFETCH"], fsdp_config["backward_prefetch"].upper())
self.assertEqual(os.environ[f"{prefix}FORWARD_PREFETCH"], fsdp_config["forward_prefetch"])
self.assertEqual(os.environ[f"{prefix}USE_ORIG_PARAMS"], fsdp_config["use_orig_params"])
self.assertEqual(os.environ[f"{prefix}SYNC_MODULE_STATES"], fsdp_config["sync_module_states"])
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")

@parameterized.expand(params, name_func=_parameterized_custom_name_func)
@require_torch_multi_accelerator
@slow