TF implementation of RegNets #17554

Merged
merged 42 commits into from
Jun 29, 2022

Conversation

ariG23498
Contributor

In this PR, we (w/ @sayakpaul) are porting the RegNet model to TensorFlow.

Copied the torch implementation of regnets and porting the code to tf step by step. Also introduced an output layer which was needed for regnets.
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 4, 2022

The documentation is not available anymore as the PR was closed or merged.

did not change the documentation yet, yet to try the playground on the model
sayakpaul added 2 commits June 6, 2022 21:45
* fix: code structure in few cases.

* fix: code structure to align tf models.

* fix: layer naming, bn layer still remains.

* chore: change default epsilon and momentum in bn.
@sayakpaul
Member

sayakpaul commented Jun 6, 2022

@amyeroberts

If we run the following:

from PIL import Image
import numpy as np

from src.transformers.models.regnet.modeling_tf_regnet import (
    TFRegNetForImageClassification
)
from transformers import AutoFeatureExtractor

def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/regnet-y-040")
model = TFRegNetForImageClassification.from_pretrained("facebook/regnet-y-040", from_pt=True)

image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf") 
outputs = model(**inputs, training=False)

print(outputs.logits.shape)

expected_slice = np.array([-0.4180, -1.5051, -3.4836])

np.testing.assert_allclose(outputs.logits[0, :3].numpy(), expected_slice, atol=1e-4)

First, it complains the moving_mean and moving_variance params are not loaded properly.

We tested your solution in #17571. With that, we're running into mismatches of num_batches_tracked and even moving_mean. It also complains about some of the mismatches stemming from the shortcut layer which wasn't the case for the earlier setup.

Do you have any thoughts?

@amyeroberts
Collaborator

Hi @sayakpaul

Could you give a bit more information about the mismatches i.e. the printouts you're currently getting?

Regarding num_batches_tracked, I don't believe this parameter will ever be cross-loaded into a tf.keras.layers.BatchNormalization layer as there isn't an equivalent parameter. This is only important if the corresponding PyTorch batch norm layer doesn't have its momentum set c.f. param updates, which you'll need to verify for this model. I suggest looking at the implementations of both the TF and PyTorch layer to see when/if these differences are important. If the parameter is necessary, then I think one approach might be subclassing to build a new layer and include the parameter as a registered weight + any necessary logic to use it, but I'm not sure at the moment.
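
As a starting point for that comparison, here is a minimal sketch of how the two layers' hyperparameters relate. It assumes the documented update rules: PyTorch's nn.BatchNorm2d computes running = (1 - momentum) * running + momentum * batch, while tf.keras.layers.BatchNormalization computes moving = momentum * moving + (1 - momentum) * batch. The helper name torch_bn_to_keras_kwargs is hypothetical and is not part of either library or of this PR.

```python
def torch_bn_to_keras_kwargs(eps=1e-5, momentum=0.1):
    """Translate PyTorch BatchNorm hyperparameters into Keras keyword arguments.

    Defaults mirror nn.BatchNorm2d (eps=1e-5, momentum=0.1). Hypothetical helper.
    """
    if momentum is None:
        # PyTorch falls back to a cumulative average here; Keras has no equivalent setting.
        raise ValueError("momentum=None has no direct Keras equivalent")
    # Keras weights the *old* statistic by `momentum`, PyTorch weights the *new* one.
    return {"epsilon": eps, "momentum": 1.0 - momentum}


keras_kwargs = torch_bn_to_keras_kwargs()
print(keras_kwargs)
```

The Keras defaults (momentum=0.99, epsilon=1e-3) differ from the PyTorch ones, so a port generally has to set both explicitly.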

@sayakpaul
Member

sayakpaul commented Jun 7, 2022

I tried debugging this today but no luck yet. Here's some information to help all of us work through this:

All these mismatches seem to be stemming from the layers.0 of RegNet stages. Mismatches stemming from other layers (layers.2 for example) are related to num_batches_tracked.

The test used to gather this information is the same one as mentioned in #17554 (comment).

@amyeroberts

@amyeroberts
Collaborator

@sayakpaul Thanks for your detailed update. Comments below:

  1. OK - thanks for posting that, it really helps!

  2. num_batches_tracked isn't trainable, but it is updated during training. As I mentioned above, if the layer has momentum set (it's not None) then you can ignore it. However, if momentum isn't set, then the layer uses num_batches_tracked to update the running_mean and running_var calculations, which are used during evaluation to normalize the batch. You can quickly check if the momentum is set for the batchnorm layers running something like all([x.momentum is not None for x in model.modules() if isinstance(x, nn.BatchNorm2d)]).

  3. Looking at the printout you pasted above, it says All the weights of TFRegNetForImageClassification were initialized from the PyTorch model.. If this is the case, and some of the PyTorch weights weren't used, it makes me think some layers might be missing in your implementation. I would look at the two architectures and see if they differ anywhere.
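
The role of num_batches_tracked described in point 2 can be sketched in plain Python. This mirrors the training-mode update rule of PyTorch's batch norm as I understand it (an assumption worth checking against the PyTorch source). The function name update_running_stat is hypothetical, and scalars stand in for per-channel tensors.

```python
def update_running_stat(running, batch_stat, num_batches_tracked, momentum=0.1):
    """Return (new_running, new_num_batches_tracked) after one training step."""
    num_batches_tracked += 1
    if momentum is None:
        # Cumulative moving average: every batch seen so far gets equal weight,
        # so the counter is required to reproduce the statistics.
        factor = 1.0 / num_batches_tracked
    else:
        # Exponential moving average: the counter is incremented but never used.
        factor = momentum
    return (1.0 - factor) * running + factor * batch_stat, num_batches_tracked


batch_means = [2.0, 4.0, 6.0]

cumulative, count = 0.0, 0
for batch_mean in batch_means:
    cumulative, count = update_running_stat(cumulative, batch_mean, count, momentum=None)

exponential, count_ema = 0.0, 0
for batch_mean in batch_means:
    exponential, count_ema = update_running_stat(exponential, batch_mean, count_ema, momentum=0.1)

print(cumulative, exponential)
```

With momentum=None the running value is the plain average of the batch statistics, which depends on the counter. With a numeric momentum the counter has no effect on the result, which is why it can be ignored when cross-loading.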

@sayakpaul
Member

@amyeroberts a quick update:

  • momentum is actually not set. This is why we also need to retrieve num_batches_tracked. We need to figure out a way to factor it in for use with layers.BatchNormalization in TensorFlow.
  • The TF model has fewer params than the PT model, so we'll look into why this is the case. One immediate reason would be the absence of num_batches_tracked, but that contributes only a very small difference. We currently have 629440 fewer parameters in the TF model than the PT one.

@amyeroberts
Collaborator

@sayakpaul Thanks for the update!

  • OK, this makes things a bit more difficult. Let me know if you want any help for this step. It's something that will likely need to be done in other PT -> TF ports so definitely valuable to the community if you added this!

  • It might be easier to print out the weight names instead of comparing number of parameters. The porting code works on the names, and so seeing where the two models differ can really help pinpoint what's happening. What I typically do is use the porting code to convert the tensorflow weight names and compare the two sets. For this model, it would look something like:

from transformers import RegNetForImageClassification
# import directly once __init__ files updated
from transformers.models.regnet.modeling_tf_regnet import TFRegNetForImageClassification 
from transformers.modeling_tf_pytorch_utils import convert_tf_weight_name_to_pt_weight_name

checkpoint = "facebook/regnet-y-040"
tf_model = TFRegNetForImageClassification.from_pretrained(checkpoint, from_pt=True)
pt_model = RegNetForImageClassification.from_pretrained(checkpoint)

tf_model_weights = set([convert_tf_weight_name_to_pt_weight_name(x.name)[0] for x in tf_model.trainable_variables])
pt_model_weights = set(pt_model.state_dict().keys())

print(tf_model_weights - pt_model_weights)
print(pt_model_weights - tf_model_weights)

@sayakpaul
Member

Thanks for the suggestions. Will try them out and update.

@sayakpaul
Member

@amyeroberts

I had to do a few minor modifications to your snippet in #17554 (comment):

tf_model_weights = set(
    [
        convert_tf_weight_name_to_pt_weight_name(x.name)[0]
        for x in tf_model.trainable_variables + tf_model.non_trainable_variables
    ]
)
pt_model_weights = set(pt_model.state_dict().keys())
tf_model_weights_new = set()

for name in tf_model_weights:
    if "moving_mean" in name:
        name = name.replace("moving_mean", "running_mean")
    elif "moving_variance" in name:
        name = name.replace("moving_variance", "running_var")
    tf_model_weights_new.add(name)


print(f"Differences in the TF model and PT model: {tf_model_weights_new - pt_model_weights}")
print(f"Differences in the PT model and TF model: {pt_model_weights - tf_model_weights_new}")
print(f"Total weights differing: {len(pt_model_weights - tf_model_weights_new)}")

convert_tf_weight_name_to_pt_weight_name() doesn't change the moving_mean and moving_variance to running_mean and running_var respectively. Instead, currently, it's handled here so that this query is successful.

With this change, the result of pt_model_weights - tf_model_weights_new exactly matches the complaint:

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFRegNetForImageClassification ...

(Full output here).

I have gone over the modeling_tf_regnet.py script a couple of times but I don't yet know what I can do here. Let me know what you usually do when you have these differences.
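
One way to make such a difference easier to read is to group the unmatched names by their parent module, so that a missing layer shows up as a single entry. The following is a minimal sketch in plain Python. group_by_module is a hypothetical helper, and the weight names in the example are made up for illustration rather than taken from the actual output.

```python
from collections import defaultdict


def group_by_module(weight_names):
    """Map each parent module path to the sorted parameter names found under it."""
    grouped = defaultdict(list)
    for name in weight_names:
        # "a.b.weight" -> module "a.b", parameter "weight"
        module, _, param = name.rpartition(".")
        grouped[module].append(param)
    return {module: sorted(params) for module, params in sorted(grouped.items())}


# Illustrative names only; substitute `pt_model_weights - tf_model_weights_new`.
unmatched = {
    "encoder.stages.0.layers.0.shortcut.convolution.weight",
    "encoder.stages.0.layers.0.shortcut.normalization.weight",
    "encoder.stages.0.layers.0.shortcut.normalization.bias",
    "encoder.stages.0.layers.2.layer.0.normalization.num_batches_tracked",
}

for module, params in group_by_module(unmatched).items():
    print(f"{module}: {params}")
```

A module that appears with all of its parameters (weight, bias, running statistics) suggests a layer that is absent or named differently in the TF model, whereas a module that appears only with num_batches_tracked points at the batch norm counter.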

@sayakpaul
Member

Also an oversight on my end in reporting momentum in #17554 (comment).

all([x.momentum is not None for x in model.modules() if isinstance(x, nn.BatchNorm2d)]) actually gives True which means it's okay to ignore num_batches_tracked.

@sayakpaul
Member

@amyeroberts we were able to rectify the model implementation and make it work. The integration test (mentioned in #17554 (comment)) is passing now.

The tests, however, are failing for a weird reason:

Parameter config in `TFRegNetModel(config)` should be an instance of class `PretrainedConfig`. To create a model from a pretrained model use `model = TFRegNetModel.from_pretrained(PRETRAINED_MODEL_NAME)`

Weird because we tested a couple of things in isolation:

from transformers import PretrainedConfig, RegNetConfig

config_class = RegNetConfig()

print(f"RegNet Config class type: {type(config_class)}.")
print(f"RegNet Config is an instance of PretrainedConfig: {isinstance(config_class, PretrainedConfig)}")

The final print statement gives True. But when we do the following:

from src.transformers.models.regnet.modeling_tf_regnet import TFRegNetForImageClassification, TFRegNetModel

class_from_config = TFRegNetModel(config_class)
print("Model class from config was initialized.")

it complains:

Parameter config in `TFRegNetModel(config)` should be an instance of class `PretrainedConfig`. To create a model from a pretrained model use `model = TFRegNetModel.from_pretrained(PRETRAINED_MODEL_NAME)`

Do you have any suggestions for this?
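
One possible explanation, offered here as an assumption since the thread does not state the eventual cause: the snippets import the config from transformers but the model from src.transformers. If Python loads the same source under two module names, each copy defines its own PretrainedConfig class, and isinstance across the two copies fails even though the class names are identical. The sketch below reproduces that effect with a stand-in class. The module names are labels for illustration and nothing here touches the real library.

```python
import types

# Stand-in source for a module that defines the base config class.
SOURCE = "class PretrainedConfig:\n    pass\n"


def load_copy(module_name):
    """Execute the same source in a fresh module object, as a second import path would."""
    module = types.ModuleType(module_name)
    exec(SOURCE, module.__dict__)
    return module


installed = load_copy("transformers.configuration_utils")
local = load_copy("src.transformers.configuration_utils")

config = installed.PretrainedConfig()

same_name = installed.PretrainedConfig.__name__ == local.PretrainedConfig.__name__
is_installed_instance = isinstance(config, installed.PretrainedConfig)
is_local_instance = isinstance(config, local.PretrainedConfig)

print(same_name, is_installed_instance, is_local_instance)
```

If this is what is happening, importing both the config and the model through the same package path would be the thing to try first.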

@ariG23498 ariG23498 marked this pull request as ready for review June 13, 2022 05:19
@sayakpaul
Member

@sgugger @Rocketknight1 the PR is now ready for review.

This particular model actually has the largest vision model checkpoint available to date: https://huggingface.co/facebook/regnet-y-10b-seer. It's still in PyTorch and the corresponding model makes use of the low_cpu_mem_usage argument.

I had a chat with @Rocketknight1 a few days back on the possibility of supporting this checkpoint in TensorFlow too. This will require tweaks and they will be contributed in a separate PR.

Comment on lines 103 to 110
def call(self, hidden_state):
    num_channels = shape_list(hidden_state)[-1]
    if tf.executing_eagerly() and num_channels != self.num_channels:
        raise ValueError(
            "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
        )
    hidden_state = self.embedder(hidden_state)
    return hidden_state
Contributor

Suggested change
-def call(self, hidden_state):
-    num_channels = shape_list(hidden_state)[-1]
-    if tf.executing_eagerly() and num_channels != self.num_channels:
-        raise ValueError(
-            "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-        )
-    hidden_state = self.embedder(hidden_state)
-    return hidden_state
+def call(self, pixel_values):
+    num_channels = shape_list(pixel_values)[-1]
+    if tf.executing_eagerly() and num_channels != self.num_channels:
+        raise ValueError(
+            "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+        )
+    hidden_state = self.embedder(pixel_values)
+    return hidden_state

Member

Done.

Comment on lines 306 to 309
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
# So change the input format from `NCHW` to `NHWC`.
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
Contributor

For consistency with other models, it's better to do this in the embedding (stem) class rather than here.

Member

Done.
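
The two permutations used in this PR, (0, 2, 3, 1) on the way in and (0, 3, 1, 2) on the way out, are inverses of each other. A minimal sketch in plain Python that tracks shapes only: permute_shape is a hypothetical helper that mimics how tf.transpose reorders dimensions, and the 224x224 input size is an assumption for illustration.

```python
def permute_shape(shape, perm):
    """Return the shape after a transpose: output axis i takes input axis perm[i]."""
    return tuple(shape[axis] for axis in perm)


nchw = (1, 3, 224, 224)                       # (batch, channels, height, width)
nhwc = permute_shape(nchw, (0, 2, 3, 1))      # channels-last, as needed by Conv2D on CPU
restored = permute_shape(nhwc, (0, 3, 1, 2))  # back to channels-first for the outputs

print(nhwc, restored)
```

Doing the first transpose once in the stem keeps every later layer in channels-last layout, so only the model outputs need to be converted back.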

Comment on lines 432 to 433
if pixel_values is None:
    raise ValueError("You have to specify pixel_values")
Contributor

This check doesn't seem useful as pixel_values is a required argument

Member

Removed.

Comment on lines 317 to 322
last_hidden_state = encoder_outputs[0]
pooled_output = self.pooler(last_hidden_state)
pooled_output = tf.transpose(pooled_output, perm=(0, 3, 1, 2))

# Change to NCHW output format have uniformity in the modules
last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
Contributor

Suggested change
-last_hidden_state = encoder_outputs[0]
-pooled_output = self.pooler(last_hidden_state)
-pooled_output = tf.transpose(pooled_output, perm=(0, 3, 1, 2))
-# Change to NCHW output format have uniformity in the modules
-last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
+last_hidden_state = encoder_outputs[0]
+pooled_output = self.pooler(last_hidden_state)
+# Change to NCHW output format have uniformity in the modules
+last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
+pooled_output = tf.transpose(pooled_output, perm=(0, 3, 1, 2))

Member

Refactored. Thanks.

Comment on lines +44 to +46
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
Contributor

Please also add TFRegNet to the doc tests, making sure all code examples are tested.

Details here: https://github.com/huggingface/transformers/tree/main/docs#testing-documentation-examples

Member

Done.

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
@sayakpaul
Member

@sgugger @Rocketknight1 anything pending on my end to move this ahead?

Collaborator

@sgugger sgugger left a comment

All good on my side!

@sayakpaul
Member

Thank you!

The TF weights need to be uploaded to the model repository on Hub. After that, I will remove the from_pt argument here and here.

This part is pending then.

@sayakpaul
Member

As soon as @ariG23498 performs a rebase the errors should go away (I can't do it myself, since I believe only the owner of the fork can).

@gante
Member

gante commented Jun 28, 2022

Hey folks! The hub PRs for most models are open (the exception is the 10b models, which I'm having a look at), you should be able to remove the from_pt soon :)

@sayakpaul
Member

Thanks, @gante! Keep us posted.

@sayakpaul
Member

@gante thanks for your hard work on getting the TF parameters of this model to Hub (there are a total of 32 checkpoints in case anyone's curious). I really appreciate this help!

I have removed the occurrences of from_pt=True from the TF test script. With this, I think the PR is ready to merge, with the exception of the 10B checkpoint of the model. It requires some functionality that is apparently missing but will likely be added soon.

@sayakpaul
Member

The failing tests seem to be unrelated to this PR?

@gante
Member

gante commented Jun 29, 2022

Yes, pytorch v1.12 was released a few hours ago and broke a few things here. We have pinned pytorch to <1.12 -- rebasing with main should fix the problems :)

@sayakpaul
Member

@gante we performed a rebase but the pipelines test for Torch seems to be still failing.

@Rocketknight1
Member

That error seems unrelated to this PR, so I think you could probably merge anyway.

@sayakpaul
Member

Over to you @gante then since I can't merge :D

@gante
Member

gante commented Jun 29, 2022

The error is on main as well -- merging 🤞

@gante gante merged commit a7eba83 into huggingface:main Jun 29, 2022
@sayakpaul sayakpaul deleted the aritra-regnets branch June 29, 2022 13:06
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
* chore: initial commit

Copied the torch implementation of regnets and porting the code to tf step by step. Also introduced an output layer which was needed for regnets.

* chore: porting the rest of the modules to tensorflow

did not change the documentation yet, yet to try the playground on the model

* Fix initilizations (huggingface#1)

* fix: code structure in few cases.

* fix: code structure to align tf models.

* fix: layer naming, bn layer still remains.

* chore: change default epsilon and momentum in bn.

* chore: styling nits.

* fix: cross-loading bn params.

* fix: regnet tf model, integration passing.

* add: tests for TF regnet.

* fix: code quality related issues.

* chore: added rest of the files.

* minor additions..

* fix: repo consistency.

* fix: regnet tf tests.

* chore: reorganize dummy_tf_objects for regnet.

* chore: remove checkpoint var.

* chore: remov unnecessary files.

* chore: run make style.

* Update docs/source/en/model_doc/regnet.mdx

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* chore: PR feedback I.

* fix: pt test. thanks to @ydshieh.

* New adaptive pooler (huggingface#3)

* feat: new adaptive pooler

Co-authored-by: @Rocketknight1

* chore: remove image_size argument.

Co-authored-by: matt <rocketknight1@gmail.com>

Co-authored-by: matt <rocketknight1@gmail.com>

* Empty-Commit

* chore: remove image_size comment.

* chore: remove playground_tf.py

* chore: minor changes related to spacing.

* chore: make style.

* Update src/transformers/models/regnet/modeling_tf_regnet.py

Co-authored-by: amyeroberts <aeroberts4444@gmail.com>

* Update src/transformers/models/regnet/modeling_tf_regnet.py

Co-authored-by: amyeroberts <aeroberts4444@gmail.com>

* chore: refactored __init__.

* chore: copied from -> taken from./g

* adaptive pool -> global avg pool, channel check.

* chore: move channel check to stem.

* pr comments - minor refactor and add regnets to doc tests.

* Update src/transformers/models/regnet/modeling_tf_regnet.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* minor fix in the xlayer.

* Empty-Commit

* chore: removed from_pt=True.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: matt <rocketknight1@gmail.com>
Co-authored-by: amyeroberts <aeroberts4444@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
@ariG23498 ariG23498 mentioned this pull request Dec 19, 2022
4 tasks
9 participants