Modify resize_token_embeddings to ensure output type is same as input #31979
Conversation
Great work @bayllama!
One thing to note is that the newly created Embedding will not have the same embed scale as the old one, because we don't pass embed_scale at creation and the default is 1.0.
What if we don't rely on what type of embedding class is being used, and instead modify the weights of the old_embeddings in-place and return them as new_embeddings? Something like this, added at the end before returning:
old_embeddings.weight.data = new_embeddings.weight.data
return old_embeddings
@amyeroberts WDYT of this idea? It doesn't break BC and makes ModelScaledEmbedding happy.
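A minimal sketch of that idea as a standalone function, assuming a plain PyTorch setting (the function name is made up here; the real _resize_token_embeddings also initializes the newly added rows via the model's weight-init scheme and handles distributed setups, which is omitted):

import torch.nn as nn

def resize_embeddings_in_place(old_embeddings: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    # build a plain nn.Embedding of the new size on the same device/dtype as the old one
    old_num_tokens, embedding_dim = old_embeddings.weight.shape
    new_embeddings = nn.Embedding(
        new_num_tokens,
        embedding_dim,
        device=old_embeddings.weight.device,
        dtype=old_embeddings.weight.dtype,
    )
    # copy over the rows that exist in both the old and the new table
    num_to_copy = min(old_num_tokens, new_num_tokens)
    new_embeddings.weight.data[:num_to_copy, :] = old_embeddings.weight.data[:num_to_copy, :]
    # transplant the resized weights back into the original object, so custom subclasses
    # (and attributes such as embed_scale) are preserved; bookkeeping attributes like
    # num_embeddings still need syncing, which comes up later in this thread
    old_embeddings.weight.data = new_embeddings.weight.data
    return old_embeddings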
Thanks for fixing and nice suggestion @zucchini-nlp!
@zucchini-nlp @amyeroberts Added the comment describing the change.
Perfect, thanks for working on this! I'll make sure that tests are passing and it can be merged later today
@bayllama the tests are failing because of how the device and dtype are passed for the new embeddings. You should be able to verify all tests are passing with this command :) pytest -k test_resize_token tests/models/
When I run the pytest that you suggested I don't see the error that you have mentioned above. We don't set the device and dtype for the Custom Embedding Class at all, right?
We only do it for the nn.Embedding and then replace the weight.data in the old embeddings, which is already on the correct device and has the right dtype. However, I do see another issue: in resize_token_embeddings, the "text_config" attribute is being modified and I haven't accounted for that in my previous commit. I will work on a fix for this. However, can you please give me more details on the dtype and device issue, because I don't see it in my set of tests.
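For context, this is the kind of bookkeeping being referred to; a rough sketch, assuming a composite config that nests its vocabulary size under text_config (the helper name is made up):

def update_vocab_size(config, new_num_tokens: int) -> None:
    # multimodal models keep the vocab size on a nested text config rather than on the
    # top-level config, so the resize path has to update that attribute as well
    if hasattr(config, "text_config"):
        config.text_config.vocab_size = new_num_tokens
    else:
        config.vocab_size = new_num_tokens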
@bayllama my bad, I was running on a different branch. You're right, some VLMs are failing; as I see it, this should be fixed with one line after we swap the weights! Let me know when it's fixed :)
@zucchini-nlp Changing the shape of the old_embeddings would fix this, however the shape attribute is not writable in torch, so we cannot simply assign to it.
I am therefore thinking about the best way to do this. If nothing works out, I may need to go back to the method I was following in my first PR, where I check the type of the input embedding, create a new object of that type and return it.
@bayllama I see, but we can also change the
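Assuming the suggestion here is to update the embedding's bookkeeping attributes (such as num_embeddings), which unlike weight.shape are plain writable values, a minimal sketch would be (function name made up):

import torch.nn as nn

def swap_and_sync(old_embeddings: nn.Embedding, new_embeddings: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    # weight.shape cannot be assigned, so swap the resized weight data in place...
    old_embeddings.weight.data = new_embeddings.weight.data
    # ...and keep the size bookkeeping consistent with the new weight
    old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
    # drop padding_idx if it now falls outside a shrunken vocabulary
    if old_embeddings.padding_idx is not None and old_embeddings.padding_idx >= new_num_tokens:
        old_embeddings.padding_idx = None
    return old_embeddings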
@zucchini-nlp In addition to the above, I found a couple more things.
Beyond what we already discussed, I have made these changes in this commit as well. All of the test cases are passing for me now.
@zucchini-nlp I am not sure what this tests_hub failure is. Could you please help out?
Great! Some final nits, and we need to merge main as a final step, then we're good to go.
The failing tests are not related to this PR and should be resolved by rerunning them.
def _resize_bias(self, new_num_tokens: int) -> None:
    old_num_tokens = self.bias.shape[0]
    if new_num_tokens <= old_num_tokens:
        # truncate the bias to the new vocabulary size
        new_bias = self.bias[:new_num_tokens]
    else:
        # pad the bias with zeros for the newly added tokens
        extra_bias = torch.zeros(new_num_tokens - old_num_tokens, device=self.bias.device)
        new_bias = torch.cat([self.bias, extra_bias])
    self.bias = nn.Parameter(new_bias)
Thanks, looks good. IMO this should be moved to LxmertForPreTraining, as this method is not going to be used by the head itself.
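A rough sketch of that placement, with the helper living on the pretraining model and called from its resize path (the class name is an illustrative stand-in and the exact LXMERT wiring may differ):

from transformers import PreTrainedModel

class SomeModelForPreTraining(PreTrainedModel):  # illustrative stand-in for LxmertForPreTraining
    def resize_token_embeddings(self, new_num_tokens: int):
        # resize the input embeddings via the shared machinery, then keep the
        # pretraining prediction bias in sync using the _resize_bias helper shown above
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self._resize_bias(new_num_tokens)
        return new_embeddings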
@zucchini-nlp Made the changes that you have recommended. Please take a look. Thanks!
@bayllama thanks, everything looks good! Rerunning the tests didn't help, so can you merge main?
@amyeroberts can you merge this plz? Unrelated hub failures; from internal Slack it seems like we're not the only ones seeing them.
@zucchini-nlp There's been a fix pushed to main. @bayllama could you try rebasing? I've also re-requested review as there have been several commits since my approval.
@zucchini-nlp @amyeroberts Seems like the tests went through after rebasing. Let me know if anything else is required here.
Looks great - thanks @bayllama!
One last thing I forgot to ask for before merge is to add a test. There are already tests for resizing embeddings, so extending those to check the returned type should be enough.
Could you also update the title so it's no longer truncated?
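A minimal sketch of the kind of check being asked for, meant to slot into the existing resize-embeddings tests (the model setup is whatever those tests already build):

# inside an existing resize test, after the model has been constructed:
old_embeddings = model.get_input_embeddings()
new_embeddings = model.resize_token_embeddings(model.config.vocab_size + 10)
# the returned module should keep the original class (e.g. a scaled-embedding subclass)
# rather than being silently replaced by a plain nn.Embedding
self.assertIs(type(new_embeddings), type(old_embeddings))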
@amyeroberts @zucchini-nlp Added the test to make sure the correct type is returned. Also made the other changes suggested.
This PR causes
@zucchini-nlp @amyeroberts I believe zucchini-nlp already pushed a fix for this. Let me know if I need to do something.
What does this PR do?
Modified resize_token_embeddings to return the same class that is passed as input to it. Today, even if a custom embedding class is passed, resize_token_embeddings converts it to an nn.Embedding, but this commit makes sure that does not happen and the custom embedding class is returned.
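To illustrate, a small hypothetical example (ScaledEmbedding below stands in for the ModelScaledEmbedding-style classes some models define; it is not a real transformers class):

import torch.nn as nn

class ScaledEmbedding(nn.Embedding):
    # hypothetical custom embedding that scales its output by embed_scale
    def __init__(self, num_embeddings: int, embedding_dim: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.embed_scale = embed_scale

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale

# Previously, model.resize_token_embeddings(new_size) could hand back a plain nn.Embedding,
# silently dropping embed_scale. With this change the original embedding object is resized
# in place, so the returned module is still a ScaledEmbedding with its embed_scale intact.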
Fixes #31835
Who can review?
@zucchini-nlp