Weights of BlipModel are not initialized from the model checkpoint #25024

Closed

Vibhu04 opened this issue Jul 23, 2023 · 7 comments

Vibhu04 commented Jul 23, 2023

System Info

  • transformers version: 4.31.0.dev0
  • Platform: Linux-5.15.0-76-generic-x86_64-with-debian-bullseye-sid
  • Python version: 3.7.15
  • Huggingface_hub version: 0.15.1
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 1.13.1+cu117 (True)
  • Tensorflow version (GPU?): 2.11.0 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@younesbelkada @ArthurZucker @amyeroberts @ydshieh

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from PIL import Image
import requests
from transformers import AutoProcessor, BlipModel

model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
)

outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
probs = logits_per_image.softmax(dim=1)

Expected behavior

The code snippet is an example from https://huggingface.co/docs/transformers/model_doc/blip#transformers.BlipProcessor.

The warning that I get is:

Some weights of BlipModel were not initialized from the model checkpoint at Salesforce/blip-image-captioning-base and are newly initialized: ['text_model.encoder.layer.10.crossattention.output.dense.weight', 'text_model.encoder.layer.4.attention.output.LayerNorm.bias', 'text_model.encoder.layer.2.intermediate.dense.bias', 'text_model.encoder.layer.1.attention.self.value.bias', 'text_model.encoder.layer.5.attention.output.LayerNorm.bias', 'text_model.encoder.layer.2.attention.output.dense.bias', 'text_model.encoder.layer.1.crossattention.self.key.weight', 'text_model.encoder.layer.5.crossattention.self.key.bias', 'text_model.encoder.layer.11.crossattention.output.LayerNorm.bias', 'text_model.encoder.layer.1.attention.self.value.weight', 'text_model.encoder.layer.8.attention.self.key.weight', 'text_model.encoder.layer.9.crossattention.output.dense.bias', 'text_model.encoder.layer.7.crossattention.self.key.bias', 'text_model.encoder.layer.1.attention.output.dense.bias', 'text_model.encoder.layer.8.output.LayerNorm.bias', ...

It seems that the model weights are being initialised anew as there's some error with loading the pre-trained weights. Please guide me in solving this issue.

sgugger (Collaborator) commented Jul 24, 2023

Also cc @ydshieh who was just discussing this internally :-)

@younesbelkada (Contributor)

Hi @Vibhu04,
Thanks for the issue. Indeed, there is a problem with the BlipModel classes. BlipModel is meant to stand for the "pre-trained" version of BLIP that extracts raw logits / hidden states from text and vision inputs. That class was copied from the CLIPModel class and needs careful refactoring to reproduce the original pre-trained BLIP models: https://github.com/salesforce/BLIP/blob/main/models/blip_pretrain.py#L112-L136.
Even after the refactoring, one would need to convert the pre-trained BLIP weights, as they differ from the existing weights on the Hub and contain additional modules.
I can put that on my TODO list but cannot give an accurate ETA. For now, if you want to use BLIP to retrieve hidden states and logits, I would advise you to use BlipForConditionalGeneration.
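
A minimal sketch of how hidden states could be retrieved from BlipForConditionalGeneration in the meantime, reusing the checkpoint from the reproduction above. The vision_model submodule and the logits field of the output are assumptions based on the transformers BLIP implementation, so treat this as illustrative rather than an official recipe:

from PIL import Image
import requests
import torch
from transformers import AutoProcessor, BlipForConditionalGeneration

model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, text="a photo of", return_tensors="pt")

with torch.no_grad():
    # Hidden states from the image encoder, which is fully covered by the checkpoint.
    vision_outputs = model.vision_model(pixel_values=inputs["pixel_values"])
    image_hidden_states = vision_outputs.last_hidden_state  # (batch, num_patches + 1, hidden)

    # Captioning logits conditioned on the image.
    outputs = model(**inputs)
    caption_logits = outputs.logits  # (batch, text_seq_len, vocab_size)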

Vibhu04 (Author) commented Jul 25, 2023

Hi @younesbelkada, thanks a lot for your prompt reply. I actually want to compute the image-text similarity score for a given input image and text, and I was hoping I could use BlipModel for that. Would there be a way of achieving this with BlipForConditionalGeneration? If not, is there any other BLIP model class that I could use for this purpose?
Thanks a lot.

@younesbelkada (Contributor)

Thanks for your reply @Vibhu04!
For computing an image-text similarity score, I would advise you to use the ITM (image-text matching) models: https://huggingface.co/Salesforce/blip-itm-base-coco


import requests
from PIL import Image
from transformers import BlipProcessor, BlipForImageTextRetrieval

processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg' 
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

question = "A woman and a dog sitting together in a beach."
inputs = processor(raw_image, question, return_tensors="pt")

itm_scores = model(**inputs)[0]
cosine_score = model(**inputs, use_itm_head=False)[0]
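
As a possible post-processing step on top of the example above (a sketch, assuming the ITM head returns two logits per pair in the order "no match", "match", which is the usual BLIP ITM layout), the logits can be turned into a match probability with a softmax, while cosine_score is the raw image-text cosine similarity from the projection heads:

import torch

# Assumed shape (batch, 2): logits for ("no match", "match") from the ITM head.
itm_probs = torch.softmax(itm_scores, dim=1)
match_probability = itm_probs[:, 1].item()  # assumes a single image-text pair

print(match_probability)
print(cosine_score.item())  # raw cosine similarity between image and text embeddings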

Vibhu04 (Author) commented Jul 25, 2023

Hi @younesbelkada, thank you so much. If I may, I just have one last question: is there a lighter variant of the model you mentioned, i.e. one with fewer parameters? Thanks a lot.

@younesbelkada (Contributor)

Hi @Vibhu04,
Thanks a lot. Hmm, to the best of my knowledge the smallest model of that kind is https://huggingface.co/Salesforce/blip-itm-base-coco - however, you can run it in half precision to roughly halve its memory footprint:

import requests
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForImageTextRetrieval

processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco", torch_dtype=torch.bfloat16)

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg' 
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

question = "A woman and a dog sitting together in a beach."
inputs = processor(raw_image, question, return_tensors="pt").to(torch.bfloat16)

itm_scores = model(**inputs)[0]
cosine_score = model(**inputs, use_itm_head=False)[0]
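
One practical note on the half-precision snippet above (illustrative only, assuming the same itm_scores / cosine_score outputs): bfloat16 tensors cannot be converted to NumPy directly and some downstream libraries do not handle them, so it can be convenient to cast the outputs back to float32 before further processing:

# Cast back to float32 before printing or passing to libraries that expect float32.
print(itm_scores.float().softmax(dim=1))
print(cosine_score.float())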

Vibhu04 (Author) commented Jul 25, 2023

Thank you so much for your help @younesbelkada!
