
Paligemma- fix devices and dtype assignments #31008

Merged: 2 commits into main from paligemma_fix_bf16_multigpu, May 24, 2024
Conversation

@molbap (Contributor) commented on May 24, 2024:

What does this PR do?

Moves tensors to the correct devices for multi-GPU training with accelerate and device_map="auto". Additionally ensures that bf16 training works.

Fixes #30997
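
For context, the failure shows up in a setup along these lines (a minimal sketch; the checkpoint name is illustrative, and device_map="auto" only actually shards the model when more than one GPU is visible):

    import torch
    from transformers import PaliGemmaForConditionalGeneration

    # With device_map="auto", accelerate shards the model across the available GPUs,
    # so tensors produced by different shards (e.g. inputs_embeds vs. token_type_ids)
    # can end up on different devices; torch_dtype=torch.bfloat16 exercises the dtype
    # path this PR also fixes.
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        "google/paligemma-3b-pt-224",  # illustrative checkpoint
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )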

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@pcuenca (Member) left a comment:

LGTM, thanks for the quick fix!

@molbap (Contributor, author) commented on May 24, 2024:

cc @ArthurZucker wdyt?

@ArthurZucker (Collaborator) left a comment:

Thanks! I'll ping our accelerate experts offline; I want to understand a bit better what's going on, and why our tests did not catch this!

Comment on lines 337 to 339:

      causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    -     token_type_ids[:, None, None, :] == 0, 0
    +     token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
      )
@ArthurZucker (Collaborator):

this one does not make sense to me

@molbap (Contributor, author):

with the masked_fill, you need both tensors to be on the same device, right?
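
For illustration, a standalone reproduction of that constraint (hypothetical tensors; with fewer than two GPUs both tensors land on the same device and the call simply succeeds):

    import torch

    # Hypothetical two-device layout, mirroring what device_map="auto" can produce:
    # the mask lives on one device, token_type_ids on another.
    dev_a = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dev_b = torch.device("cuda:1" if torch.cuda.device_count() > 1 else dev_a)

    token_type_ids = torch.tensor([[0, 0, 1, 1]], device=dev_a)
    causal_mask = torch.full((1, 1, 4, 4), float("-inf"), device=dev_b)

    # masked_fill requires the boolean mask and the target tensor to be on the same
    # device; without the explicit .to(), this raises a RuntimeError when dev_a != dev_b.
    causal_mask = causal_mask.masked_fill(
        token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
    )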

@ArthurZucker (Collaborator):

Ah, sorry, I read too fast.
So token_type_ids's device is not correctly inferred? Or are we not creating the causal mask on the correct device? It should be created on the input or attention mask's device for consistency; when it's used, accelerate will transfer it accordingly, I think.

@molbap (Contributor, author):

you're right! will update to move to device at creation time

@molbap (Contributor, author):

hmm, I think we are setting the causal_mask to the correct device. It's the token_type_ids device that is indeed not correctly inferred

@ArthurZucker (Collaborator):

But from @SunMarc's comment I would suspect both devices to be the same, no?

@molbap (Contributor, author):

hmm, doesn't seem like it; I tried on a multi-GPU env with device_map set to "auto" and got:
[screenshot: device-mismatch error output]

@molbap (Contributor, author):

@SunMarc if you have an idea here - token_type_ids is created by the processor along with input_ids and passed to the forward normally

@SunMarc (Member) commented on May 24, 2024:

So, from the code and the image you shared, I see that token_type_ids is indeed on the same device as input_ids. However, since you created the causal_mask on inputs_embeds.device, token_type_ids and the causal_mask might not be on the same device.

            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )

where

    dtype, device = inputs_embeds.dtype, inputs_embeds.device
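
Spelled out, under a hypothetical two-GPU shard layout with device_map="auto" (device assignments assumed for illustration):

    embed_tokens shard (cuda:0): input_ids and token_type_ids arrive here
    decoder shard      (cuda:1): inputs_embeds is handed off here by accelerate's hooks

    token_type_ids.device == input_ids.device     == cuda:0
    causal_mask.device    == inputs_embeds.device == cuda:1

so the masked_fill above fails without the explicit .to(causal_mask.device).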

@molbap (Contributor, author):

alright, thanks! In that case, can we keep it as it is? The alternative is to create the causal mask on input_ids.device; I'm not sure one is better than the other, since inputs_embeds is much larger in general.
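
For reference, the two options being weighed look roughly like this (a CPU-runnable sketch with illustrative shapes; the real modeling code expands the mask to 4D the same way before the masked_fill):

    import torch

    sequence_length, target_length, mask_length = 4, 4, 4
    inputs_embeds = torch.randn(1, 4, 16)           # stand-in for the embeddings
    token_type_ids = torch.tensor([[0, 0, 1, 1]])   # lives on input_ids' device
    min_dtype = torch.finfo(inputs_embeds.dtype).min

    # Option (a), what the PR does: build the mask where inputs_embeds lives, then
    # move only the small token_type_ids tensor at use time.
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype,
        dtype=inputs_embeds.dtype, device=inputs_embeds.device,
    )
    causal_mask = causal_mask[None, None, :, :].expand(1, 1, -1, -1).clone()
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
    )
    # Option (b) would pass device=input_ids.device to torch.full instead, skipping
    # the .to() here but transferring the whole mask later, when it meets the much
    # larger inputs_embeds.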

@grahamannett (Contributor) commented:

FWIW, most of the lines here are nearly identical to the changes I have made locally, besides the final_embedding-related one, which I believe can be done with only one cast, but I didn't think too deeply about it.

@molbap (Contributor, author) commented on May 24, 2024:

@grahamannett, good to know. For final_embedding, it's also there to fix the bf16 dtype mismatch.
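
For readers hitting the same mismatch, the dtype side of the fix boils down to this pattern (a minimal sketch with made-up shapes, not the actual merging code):

    import torch

    # Under bf16 training, vision features can come out in float32 while the text
    # embeddings are bfloat16; the fix is to pin the merged buffer to one dtype and
    # cast everything written into it at a single place.
    inputs_embeds = torch.randn(1, 8, 16, dtype=torch.bfloat16)
    image_features = torch.randn(1, 4, 16, dtype=torch.float32)

    final_embedding = torch.zeros(1, 8, 16, dtype=inputs_embeds.dtype)  # one dtype choice
    final_embedding[:, :4] = image_features.to(final_embedding.dtype)   # one explicit cast
    final_embedding[:, 4:] = inputs_embeds[:, 4:]                       # already bf16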

@molbap merged commit bdb9106 into main on May 24, 2024 (22 checks passed).
@molbap deleted the paligemma_fix_bf16_multigpu branch on May 24, 2024 at 17:02.
ArthurZucker pushed a commit that referenced this pull request on May 30, 2024:

* fix devices and dtype assignments
* [run-slow]paligemma

Successfully merging this pull request may close these issues:

* loss calculation for PaliGemmaForConditionalGeneration potentially not cast to correct device (#30997)