-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix Low CPU Memory Mode Issues for Quantized Peft (#90)
* address issue 2 in #83 Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * properly handle broadcast of adapters Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * handle param_init_fn_tied_param Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * trl version error Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * tied weights fix and meta fix for autogptq Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * update readme Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * fmt + lint Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> * upgrade granite benches Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com> --------- Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
- Loading branch information
Showing
8 changed files
with
354 additions
and
151 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Standard | ||
from collections import defaultdict | ||
|
||
# Third Party | ||
import torch | ||
|
||
# Copyright The IBM Tuning Team | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
# https://spdx.dev/learn/handling-license-info/ | ||
|
||
|
||
def ensure_weights_retied(
    param_init_fn, model: torch.nn.Module, device: torch.cuda.device
):
    """Wrap ``param_init_fn`` so weight tying survives parameter re-init.

    When a per-module ``param_init_fn`` (e.g. the one FSDP calls to
    materialize parameters) re-allocates a module's parameters, tied
    modules end up holding *distinct* parameter objects. This wrapper
    records which parameters are tied (via the model's
    ``_tied_weights_keys``, an HF-transformers convention) and, after
    ``param_init_fn`` runs on each module, re-points every tied attribute
    at the first re-materialized instance so they share storage again.

    Args:
        param_init_fn: callable that takes a ``torch.nn.Module``, possibly
            re-allocates its (non-recursive) parameters, and returns it.
        model: root model whose tied weights should be preserved.
        device: unused here; kept for interface compatibility with callers.

    Returns:
        ``param_init_fn`` unchanged if the model declares no tied weights,
        otherwise a wrapped init fn that restores the tying.
    """
    # NOTE(fix): use getattr with a default — plain nn.Modules do not have
    # `_tied_weights_keys`, and direct attribute access raised
    # AttributeError instead of taking the passthrough path below.
    _tied_names = getattr(model, "_tied_weights_keys", None)
    if not _tied_names:
        # if no tied names just passthrough
        return param_init_fn

    # get map of parameter instances to params.
    # - needed for replacement later
    # - keyed by id() because tied attributes share one parameter object
    _tied_params = {}
    for name in _tied_names:
        name = name.split(".")
        name, param_name = ".".join(name[:-1]), name[-1]
        mod = model.get_submodule(name)
        param = getattr(mod, param_name)

        _tied_params[id(param)] = None  # placeholder for the param first

    # build param_init_fn for the case with tied params
    def param_init_fn_tied_param(module: torch.nn.Module):
        # track which params to tie
        # - usually only 1, but for completeness consider > 1
        params_to_tie = defaultdict(list)
        for n, param in module.named_parameters(recurse=False):
            if id(param) in _tied_params:
                params_to_tie[id(param)].append(n)

        # call the param init fn, which potentially re-allocates the
        # parameters
        module = param_init_fn(module)

        # search the parameters again and tie them up again
        for id_key, _param_names in params_to_tie.items():
            for param_name in _param_names:
                param = _tied_params[id_key]
                if param is None:
                    # everything will be tied to the first time the
                    # param is observed
                    _tied_params[id_key] = getattr(module, param_name)
                else:
                    setattr(module, param_name, param)  # tie

        return module

    return param_init_fn_tied_param
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.