From 120d5df661b682f916337a730a1c0597ce7cc318 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 20 Nov 2024 16:01:49 +0100
Subject: [PATCH 1/2] fix device map

---
 src/transformers/models/blip/modeling_blip.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py
index b623d2a8adb1..88835f71a3e0 100644
--- a/src/transformers/models/blip/modeling_blip.py
+++ b/src/transformers/models/blip/modeling_blip.py
@@ -464,6 +464,7 @@ class BlipPreTrainedModel(PreTrainedModel):
     config_class = BlipConfig
     base_model_prefix = "blip"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["BlipEncoderLayer"]

     def _init_weights(self, module):
         """Initialize the weights"""

From ed9f05e3a6dc6e45b66264ee7172a49d0b1ba40f Mon Sep 17 00:00:00 2001
From: raushan
Date: Thu, 12 Dec 2024 16:05:48 +0100
Subject: [PATCH 2/2] fix offloading + model parallel test

---
 src/transformers/models/blip/modeling_blip.py      | 6 ++++--
 src/transformers/models/blip/modeling_blip_text.py | 1 -
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py
index 88835f71a3e0..599d35bd59ad 100644
--- a/src/transformers/models/blip/modeling_blip.py
+++ b/src/transformers/models/blip/modeling_blip.py
@@ -464,7 +464,8 @@ class BlipPreTrainedModel(PreTrainedModel):
     config_class = BlipConfig
     base_model_prefix = "blip"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["BlipEncoderLayer"]
+    _no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
+    _skip_keys_device_placement = ["past_key_value"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1010,7 +1011,8 @@ def forward(
         text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

         # cosine similarity as logits
-        logit_scale = self.logit_scale.exp()
+        logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
+        image_embeds = image_embeds.to(device=text_embeds.device, dtype=text_embeds.dtype)
         logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
         logits_per_image = logits_per_text.t()

diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py
index 97a4f523380b..db8ad939725a 100644
--- a/src/transformers/models/blip/modeling_blip_text.py
+++ b/src/transformers/models/blip/modeling_blip_text.py
@@ -82,7 +82,6 @@ def forward(
         position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

         if inputs_embeds is None:
-            input_ids = input_ids.to(self.word_embeddings.weight.device)
             inputs_embeds = self.word_embeddings(input_ids)

         embeddings = inputs_embeds
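
Note (not part of the patch): a minimal sketch of the device-mapped loading path these two commits target, using the standard transformers/accelerate device_map="auto" API. The checkpoint name, image path, and generation settings below are illustrative assumptions, not something prescribed by the patch.

# Sketch only: exercises the sharded/offloaded path the patch is meant to fix.
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

checkpoint = "Salesforce/blip-image-captioning-base"  # example BLIP checkpoint
processor = BlipProcessor.from_pretrained(checkpoint)
model = BlipForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map="auto",          # accelerate shards the model, honoring _no_split_modules
    torch_dtype=torch.float16,
)
print(model.hf_device_map)      # shows which device each submodule landed on

image = Image.open("photo.jpg").convert("RGB")  # illustrative image path
inputs = processor(images=image, return_tensors="pt").to(model.device)
generated = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(generated[0], skip_special_tokens=True))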