Skip to content

Commit

Permalink
FIX / PEFT: Pass device correctly to peft (#30397)
Browse files Browse the repository at this point in the history
pass device correctly to peft
  • Loading branch information
younesbelkada authored Apr 22, 2024
1 parent 13b3b90 commit 367a0db
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import inspect
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from ..utils import (
check_peft_version,
Expand All @@ -25,17 +25,16 @@
)


if is_torch_available():
import torch

if is_accelerate_available():
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map

# Minimum PEFT version supported for the integration
MIN_PEFT_VERSION = "0.5.0"

if TYPE_CHECKING:
if is_torch_available():
import torch


logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -151,6 +150,15 @@ def load_adapter(
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
)

if "device" not in adapter_kwargs:
device = self.device if not hasattr(self, "hf_device_map") else list(self.hf_device_map.values())[0]
else:
device = adapter_kwargs.pop("device")

# To avoid PEFT errors later on with safetensors.
if isinstance(device, torch.device):
device = str(device)

# We keep `revision` in the signature for backward compatibility
if revision is not None and "revision" not in adapter_kwargs:
adapter_kwargs["revision"] = revision
Expand Down Expand Up @@ -190,7 +198,7 @@ def load_adapter(
self._hf_peft_config_loaded = True

if peft_model_id is not None:
adapter_state_dict = load_peft_weights(peft_model_id, token=token, **adapter_kwargs)
adapter_state_dict = load_peft_weights(peft_model_id, token=token, device=device, **adapter_kwargs)

# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
processed_adapter_state_dict = {}
Expand Down

0 comments on commit 367a0db

Please sign in to comment.