Skip to content

Commit

Permalink
return load_result when load_adapter (#481)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkqkxx authored Jun 1, 2023
1 parent 38f48dd commit 668f045
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def load_adapter(self, model_id, adapter_name, is_trainable=False, **kwargs):
filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
# load the weights into the model
set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name)
load_result = set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name)
if (
(getattr(self, "hf_device_map", None) is not None)
and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
Expand Down Expand Up @@ -442,6 +442,7 @@ def load_adapter(self, model_id, adapter_name, is_trainable=False, **kwargs):

# Set model in evaluation mode to deactivate Dropout modules by default
self.eval()
return load_result

def set_adapter(self, adapter_name):
"""
Expand Down
3 changes: 2 additions & 1 deletion src/peft/utils/save_and_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="defaul
else:
raise NotImplementedError

model.load_state_dict(peft_model_state_dict, strict=False)
load_result = model.load_state_dict(peft_model_state_dict, strict=False)
if isinstance(config, PromptLearningConfig):
model.prompt_encoder[adapter_name].embedding.load_state_dict(
{"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
)
return load_result

0 comments on commit 668f045

Please sign in to comment.