Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Proposal, open for discussion] Better way of extracting hidden states #27873

Closed
wants to merge 2 commits into from

Conversation

NielsRogge
Copy link
Contributor

@NielsRogge NielsRogge commented Dec 6, 2023

What does this PR do?

Currently, our AutoBackbone classes allow to get specific feature maps out of a certain vision model. For example:

from transformers import ConvNextBackbone
import torch

model = ConvNextBackbone.from_pretrained("facebook/convnext-small-224", out_indices=[0,1,2,3])

pixel_values = torch.randn(1, 3, 224, 224)

feature_maps = model(pixel_values)
for i in feature_maps:
   print(i.shape)

However, they currently extract all intermediate hidden states, store them in memory, and return the ones required by the user. This is not efficient, we should store only activations required by the user in memory.

This current PR proposes to only return the hidden states specified by config.out_indices when the user sets output_hidden_states=True. However, this is not backwards compatible (as by default we do return all hidden states). So I'm open for suggestions on how we could improve this. Alternatively, we could make it backwards compatible by setting out_indices to all stages by default.

I think this could be an argument that is part of all configs, or at least vision encoders, which typically only require certain hidden states to be extracted.

Curious to hear opinions of @ArthurZucker @amyeroberts

@ArthurZucker
Copy link
Collaborator

Yep I like this optimisation, non breaking overall

@NielsRogge
Copy link
Contributor Author

NielsRogge commented Dec 7, 2023

@ArthurZucker it is a breaking change in its current state, since out_indices currently defaults to the last stage index if the user doesn't specify them (think @amyeroberts added that here). So if we were to add this with backwards compatibility, we would have to update the default out_indices to all stages in case they are not specified.

@ArthurZucker
Copy link
Collaborator

We can set it to -1 to return everything maybe but I mean we can make it BC!

@NielsRogge
Copy link
Contributor Author

I'd like to have @amyeroberts's opinion on this one

@amyeroberts
Copy link
Collaborator

it is a breaking change in its current state, since out_indices currently defaults to the last stage index if the user doesn't specify them (think @amyeroberts added that here).

This was just matching the logic that was originally implemented for the out_features (selecting the last layer). As you added this @NielsRogge you'll know the motivation for this better than me :)

As it stands this, I'm not in favour of this as this requires adding in backbone API / logic into standard model APIs. This is essentially making things leaky: why do I need to know about out_indices to get my hidden states if I'm not loading a backbone?

Moreover, this is going to break a tonne of stuff, as users who have created checkpoints which are not backbones will still have out_indices set in the model config. This isn't easy to rectify: how would we know if the values in the config are what the user wanted e.g. just the last hidden state, or it just happened to be the default when the config was created?

It introduces inconsistencies in our models forward passes, which makes the code harder to understand and is tying non-backbone logic to an API which still isn't 100% stable at the moment.

An alternative approach would be to have a different argument in the config which defaults to all the layers but then can be overridden by the config's out_indices when loading a backbone.

@amyeroberts
Copy link
Collaborator

Actually, not another config parameter - because then the source of truth isn't clear and behaviour for the user can be unexpected.

Copy link

github-actions bot commented Jan 6, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Jan 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants