
Numerical inaccuracy in unpad_image (LlavaOnevision) #33531

Closed · 2 of 4 tasks
dom-dziela opened this issue Sep 17, 2024 · 6 comments · Fixed by #33564
Labels: bug

Comments

@dom-dziela
Contributor

dom-dziela commented Sep 17, 2024

System Info

  • transformers version: 4.45.0.dev0
  • Platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.35
  • Python version: 3.11.0
  • Huggingface_hub version: 0.24.7
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@amyeroberts, @qubvel

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

In unpad_image we found a numerical inaccuracy when original_aspect_ratio == current_aspect_ratio, which occurs in DocVQA on training sample 32673. See for example the snippet below:

import torch

original_size = torch.tensor([2136, 3212], device="cuda:0", dtype=torch.bfloat16)
original_height, original_width = original_size  # (H, W)
current_height, current_width = 108, 162

original_aspect_ratio = original_width / original_height  # tensor(1.5000)
current_aspect_ratio = current_width / current_height  # 1.5

scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)  # 163, expected 162

Testing showed that this inaccuracy does not occur if original_height and original_width are plain Python integers.
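
For example, the same numbers with plain Python ints (a quick sanity check, run on plain CPython):

original_height, original_width = 2136, 3212
current_height, current_width = 108, 162

original_aspect_ratio = original_width / original_height  # 1.5037453183520599, not rounded to 1.5

scale_factor = current_height / original_height
int(original_width * scale_factor)  # 162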

In the docstring, the unpad function asks for original_size to be a tuple (no type annotation though); however, it will always receive a torch.Tensor.

"""
Args:
            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensor, each contains all the visual feature of all patches.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W)."""
.
.
.
image_feature = unpad_image(image_feature, image_sizes[image_idx])

Expected behavior

The new_width value should be 162. You can see this by writing down the formula for both aspect ratios, setting them equal, and multiplying by current_height: then original_width * scale_factor = current_width (= new_width).
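
Spelled out (the same algebra, restated step by step):

  original_width / original_height = current_width / current_height    (the ratios are equal)
  original_width * (current_height / original_height) = current_width  (multiply both sides by current_height)
  original_width * scale_factor = current_width = 162                  (current_height / original_height is exactly scale_factor)

So whenever the comparison treats the two ratios as equal, new_width must reduce to current_width; the 163 only appears because the division is carried out at low precision.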

PS: My first issue ever, please have patience.

@dom-dziela dom-dziela added the bug label Sep 17, 2024
@amyeroberts
Collaborator

cc @zucchini-nlp

@zucchini-nlp
Member

Hey @dom-dziela !

Thanks for reporting the issue. Yes, we already had to introduce some changes to account for that, because padding/patching in the preprocessing stage uses integers, and that caused similar errors. For example, here we enforce the list format before getting the best resolution.

I think we should do the same conversion in unpad_image. Would you like to open a PR to fix this? 🤗

# ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate
if not isinstance(image_size, (list, tuple)):
    if not isinstance(image_size, (torch.Tensor, np.ndarray)):
        raise TypeError(
            f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
        )
    image_size = image_size.tolist()

[ I will later have to figure out a better way, especially to satisfy compile, but that is a long-term plan ]
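
For context, a minimal sketch of how that guard could slot in (the surrounding structure follows the unpad_image function for LLaVA-style models in transformers; treat it as illustrative, not as the exact merged fix):

import torch

def unpad_image(tensor, original_size):
    # Convert tensor/ndarray sizes to plain Python ints first, so the
    # aspect-ratio math below runs in float64 rather than a low-precision dtype.
    if not isinstance(original_size, (list, tuple)):
        original_size = original_size.tolist()
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Width-bound: remove the vertical padding.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        return tensor[:, padding : current_height - padding]
    else:
        # Height-bound: remove the horizontal padding.
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        return tensor[:, :, padding : current_width - padding]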

@hlky
Member

hlky commented Sep 17, 2024

Are you explicitly casting inputs['image_sizes'] to bfloat16 when calling generate/forward? LlavaOnevisionProcessor produces int64, and calling .to on a BatchFeature only applies the dtype to floating-point tensors:

inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(0, torch.bfloat16)
inputs['image_sizes'], inputs['image_sizes'].dtype
>>> (tensor([[2136, 3212]], device='cuda:0'), torch.int64)

original_size = torch.tensor([2136, 3212], device="cuda:0", dtype=torch.bfloat16)
original_height, original_width = original_size
current_height, current_width = 108, 162
original_width / original_height
>>> tensor(1.5000, device='cuda:0', dtype=torch.bfloat16)

original_size = torch.tensor([2136, 3212], device="cuda:0")
original_height, original_width = original_size
current_height, current_width = 108, 162
original_width / original_height
>>> tensor(1.5037, device='cuda:0')

@dom-dziela
Contributor Author

Thanks for the replies.
@hlky My bad, I don't know why I forced original_size to be bfloat16 in my example; it is indeed int64 by default, I probably mixed up tensors. If I test my example on my local machine with dtype=torch.int64, everything works out fine; on the GPU server, the error remains with the default dtype torch.int64:

(original_height,original_height.dtype), (original_width, original_width.dtype)
>>> ((tensor(2136, device='cuda:0'), torch.int64), (tensor(3212, device='cuda:0'), torch.int64))
original_aspect_ratio = original_width / original_height
original_aspect_ratio
>>> tensor(1.5000, device='cuda:0')

I would guess this is then more likely a version thing? However, the image_size.tolist() solution would work in this case too.

@zucchini-nlp
Member

@hlky Thanks for pointing to the issue!

@dom-dziela Am I correct that the suggested solution fixes the bug, but only on certain hardware/torch versions? I am okay with adding the tolist(), as it is already present in a similar function used to postprocess image embeddings. If you can confirm the bug exists and under which conditions, feel free to open a PR :)

@dom-dziela
Contributor Author

I would assume that your suggested solution would fix that particular bug on any hardware/torch-version combination, since it detaches the calculation from torch and only uses base Python.
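
A quick REPL check of that claim (plain CPython, no GPU needed):

import torch

size = torch.tensor([2136, 3212])
height, width = size.tolist()  # plain Python ints
type(width)
>>> <class 'int'>
width / height
>>> 1.5037453183520599

The division then happens in float64, independent of the tensor's original device or dtype.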
