
Modify resize_token_embeddings to ensure output type is same as input #31979

Merged: 4 commits merged into huggingface:main on Jul 23, 2024

Conversation

@bayllama (Contributor)

What does this PR do?

Modified resize_token_embeddings to return the same class that is passed as input. Currently, even if a custom embedding class is passed, resize_token_embeddings converts it to an nn.Embedding; this commit ensures that does not happen and that the custom embedding class is returned.

Fixes #31835
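
For illustration, a minimal sketch of the behavior this PR is about, using a hypothetical ScaledEmbedding class as a stand-in for the model-specific embedding wrappers in transformers (the class name and numbers are illustrative, not taken from the library):

```python
import torch.nn as nn

# Hypothetical stand-in for a model-specific embedding wrapper (e.g. an mBART-style
# scaled embedding). The class name and embed_scale value are illustrative.
class ScaledEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.embed_scale = embed_scale

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale

emb = ScaledEmbedding(num_embeddings=100, embedding_dim=16, embed_scale=4.0)
# Before this PR, resizing would hand back a plain nn.Embedding and silently drop
# embed_scale; with this PR, the resized embedding keeps its original class, so
# isinstance(model.get_input_embeddings(), ScaledEmbedding) still holds after resizing.
```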

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp

@zucchini-nlp (Member) left a comment:

Great work @bayllama !

One thing to note is that the newly created Embedding will not have the same embed_scale as the old one, because we don't pass embed_scale at creation and the default is 1.0.

What if we don't rely on which embedding class is being used, and instead modify the weights of the old_embeddings in place and return it as new_embeddings? Something like this, added at the end before returning:

old_embeddings.weight.data = new_embeddings.weight.data
return old_embeddings

@amyeroberts WDYT of this idea? It doesn't break BC and makes ModelScaledEmbedding happy
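
For illustration, a minimal sketch of this in-place idea as a standalone helper (hypothetical function name; simplified compared to the actual _get_resized_embeddings, e.g. newly added rows keep nn.Embedding's default initialization here):

```python
import torch.nn as nn

def resize_in_place(old_embeddings: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    """Resize an embedding module without replacing it, so its class (and attributes
    such as embed_scale) are preserved."""
    old_num_tokens, embedding_dim = old_embeddings.weight.shape
    # A plain nn.Embedding is used only as a container for the resized weight matrix.
    new_embeddings = nn.Embedding(
        new_num_tokens,
        embedding_dim,
        device=old_embeddings.weight.device,
        dtype=old_embeddings.weight.dtype,
    )
    # Copy the overlapping rows from the old weights.
    n = min(old_num_tokens, new_num_tokens)
    new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
    # Swap the resized weights into the original module and return it otherwise unchanged.
    old_embeddings.weight.data = new_embeddings.weight.data
    return old_embeddings
```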

@bayllama (Contributor, Author)

@zucchini-nlp
That's a good idea. I have made the change you recommended.

@amyeroberts (Collaborator) left a comment:

Thanks for fixing and nice suggestion @zucchini-nlp!

@bayllama (Contributor, Author)

@zucchini-nlp @amyeroberts Added the comment describing the change

@zucchini-nlp (Member) left a comment:

Perfect, thanks for working on this! I'll make sure that tests are passing and it can be merged later today

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member)

@bayllama the tests are failing because we pass device and dtype to the Embedding class when resizing, but the custom classes don't accept any kwargs. Can you please add **kwargs to __init__ so the embedding is created on the correct device and with the correct dtype?

You should be able to verify that all tests pass with this command :)

pytest -k test_resize_token tests/models/
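
For illustration, a sketch of what forwarding those kwargs could look like in a custom embedding class (hypothetical class name; as the discussion below shows, this particular change turned out not to be needed once the weights are swapped in place):

```python
import torch.nn as nn

# Hypothetical custom embedding whose __init__ forwards extra kwargs (device, dtype, ...)
# to nn.Embedding, so it can be constructed the same way a plain nn.Embedding is.
class CustomScaledEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, embed_scale: float = 1.0, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        self.embed_scale = embed_scale

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale
```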

@bayllama (Contributor, Author)

@zucchini-nlp

When I run the pytest command you suggested, I don't see the error you mentioned above. We don't set the device and dtype for the custom embedding class at all, right?

new_embeddings = nn.Embedding(
    new_num_tokens,
    old_embedding_dim,
    device=old_embeddings.weight.device,
    dtype=old_embeddings.weight.dtype,
)

We only do it for the nn.Embedding and then replace the weight.data in the old embeddings, which is already on the correct device and has the right dtype.

However, I do see another issue: in resize_token_embeddings, the "text_config" attribute is being modified, and I haven't accounted for that in my previous commit. I will work on a fix for this. Could you please give me more details on the dtype and device issue, though, because I don't see it in my set of tests?

@zucchini-nlp (Member)

@bayllama my bad, I was running on a different branch. You're right, some VLMs are failing; as far as I can see, it should be fixable with one line after we swap the weights! Let me know when it's fixed :)

@bayllama (Contributor, Author)

@zucchini-nlp Changing the shape of old_embeddings would fix this; however, the shape attribute is not writable in torch, so we cannot do something like this:

old_embeddings.weight.shape = new_embeddings.weight.shape

So I am thinking about the best way to do this. If nothing works out, I may need to go back to the method I was following in my first PR, where I check the type of the input embedding, create a new object of that type, and return it.

@zucchini-nlp (Member)

@bayllama I see, but we can also change the num_embeddings attribute, which is writable, e.g.:

old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]

@bayllama force-pushed the progress branch 3 times, most recently from 240bc5e to c023ebb (July 20, 2024)
@bayllama (Contributor, Author)

@zucchini-nlp In addition to the above, I found a couple more things:

  1. The padding_idx must be updated to None if the number of tokens in the new embeddings is smaller than the padding_idx.

  2. The Lxmert model has a bias term that has to be updated, similar to final_logits_bias in mBART. If it is not updated, it causes some test-case failures. This problem existed before and, I believe, was not addressed.

In addition to what we already discussed, I have made these changes in this commit as well. All of the test cases are passing for me now.
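
For illustration, a sketch that combines the bookkeeping discussed in this thread: swapping the weights in place, updating the writable num_embeddings attribute (since .shape is not assignable), and resetting padding_idx when the vocabulary shrinks past it. The function name is hypothetical and simplified relative to the actual change:

```python
import torch.nn as nn

def finalize_in_place_resize(old_embeddings: nn.Embedding, new_embeddings: nn.Embedding) -> nn.Embedding:
    new_num_tokens = new_embeddings.weight.data.shape[0]
    # Swap the resized weights into the original module so its class is preserved.
    old_embeddings.weight.data = new_embeddings.weight.data
    # torch does not allow assigning to .shape, but num_embeddings is a plain attribute.
    old_embeddings.num_embeddings = new_num_tokens
    # If the vocabulary shrank past padding_idx, it no longer points at a valid row.
    if old_embeddings.padding_idx is not None and old_embeddings.padding_idx >= new_num_tokens:
        old_embeddings.padding_idx = None
    return old_embeddings
```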

@bayllama force-pushed the progress branch 2 times, most recently from 799299e to 9ebcd03 (July 21, 2024)
@bayllama (Contributor, Author)

@zucchini-nlp I am not sure what this tests_hub failure is. Could you please help out?

@zucchini-nlp (Member) left a comment:

Great! Some final nits, and we need to merge main as a final step; then we're good to go.

The failing tests are not related to this PR and should be resolved by rerunning them.

Comment on lines 703 to 711
def _resize_bias(self, new_num_tokens: int) -> None:
    old_num_tokens = self.bias.shape[0]
    if new_num_tokens <= old_num_tokens:
        new_bias = self.bias[:new_num_tokens]
    else:
        extra_bias = torch.zeros(new_num_tokens - old_num_tokens, device=self.bias.device)
        new_bias = torch.cat([self.bias, extra_bias])
    self.bias = nn.Parameter(new_bias)

Reply (Member):

Thanks, looks good. IMO this should be moved to LxmertForPreTraining, as this method is not going to be used by the head itself.

@bayllama (Contributor, Author)

@zucchini-nlp Made the changes that you have recommended. Please take a look. Thanks!

@zucchini-nlp (Member)

@bayllama thanks, everything looks good! Rerunning the tests didn't help, so could you merge main in case it was already fixed there?

@bayllama force-pushed the progress branch 2 times, most recently from 9330df2 to 98819cb (July 22, 2024)
@zucchini-nlp (Member)

@amyeroberts can you merge this please? The hub failures are unrelated; from internal Slack, it seems we're not the only ones seeing them.

@amyeroberts self-requested a review on July 22, 2024
@amyeroberts (Collaborator)

@zucchini-nlp There's been a fix pushed to main. @bayllama could you try rebasing?

I've also re-requested review, as there have been several commits since my approval.

@bayllama (Contributor, Author)

@zucchini-nlp @amyeroberts Seems like the tests went through after rebasing. Let me know if anything else is required here.

@amyeroberts (Collaborator) left a comment:

Looks great - thanks @bayllama!

One last thing I forgot to ask for before merging is to add a test. There are already tests for resizing embeddings, so extending those to check the returned type should be enough.

Could you also update the title so it's no longer truncated?
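
For illustration, the kind of assertion such a test could add (a hedged sketch, not the exact test added in this PR):

```python
def check_resize_preserves_embedding_class(model, tokenizer):
    # The resized input embeddings should keep their original class, not become nn.Embedding.
    old_type = type(model.get_input_embeddings())
    model.resize_token_embeddings(len(tokenizer) + 10)
    assert type(model.get_input_embeddings()) is old_type
```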

@bayllama changed the title from "Change resize_token_embeddings to make it return same Class that is p…" to "Modify resize_token_embeddings to ensure output type is same as input" on Jul 23, 2024
@bayllama (Contributor, Author)

@amyeroberts @zucchini-nlp Added the test to make sure the correct type is returned. Also made the other changes suggested.

@amyeroberts merged commit 5a4a76e into huggingface:main on Jul 23, 2024
23 checks passed
@seokhyunan

This PR causes model.resize_token_embeddings to set vocab_size to zero (check this thread). Reverting this PR resolved the issue. Could you help with this?

@bayllama (Contributor, Author)

@zucchini-nlp @amyeroberts I believe zucchini-nlp has already pushed a fix for this. Let me know if I need to do anything.
