
Add TensorFlow implementation of ConvNeXTv2 #25558

Merged: 4 commits into huggingface:main from tf-convnextv2, Nov 1, 2023

Conversation

@neggles (Contributor) commented Aug 17, 2023

What does this PR do?

This adds TensorFlow support for ConvNeXTV2, following the pattern of the existing PyTorch ConvNeXTV2 implementation and the existing ConvNeXT(V1) TensorFlow implementation.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Tests are in, make fixup and make quality are happy, and NVIDIA_TF32_OVERRIDE=0 RUN_SLOW=1 RUN_PT_TF_CROSS_TESTS=1 py.test -vv tests/models/convnextv2/test_modeling_tf_convnextv2.py passes everything except from_pretrained (unsurprisingly, "facebook/convnextv2-tiny-1k-224" lacks TensorFlow weights, but from_pt=True works swimmingly).

Getting this one to pass tests was a little tricky; the outputs from the model were quite variable run-to-run. I'm still not entirely sure exactly what I did to fix it, but it looks like TensorFlow doesn't like it when you do this:

x = input + self.drop_path(x, training=training)

and wants you to do this instead:

x = self.drop_path(x, training=training)
x = input + x

🤷 who knows what cursed machinations the XLA compiler gets up to while nobody's looking.
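For reference, the drop path layer itself applies stochastic depth per sample, along these lines (a minimal sketch in the style of the ConvNeXT V1 TF port's TFConvNextDropPath; not necessarily the exact merged code):

import tensorflow as tf

class TFDropPath(tf.keras.layers.Layer):
    """Stochastic depth: randomly drop the residual branch per sample."""

    def __init__(self, drop_path: float = 0.0, **kwargs):
        super().__init__(**kwargs)
        self.drop_path = drop_path

    def call(self, x, training=None):
        if training and self.drop_path > 0.0:
            keep_prob = 1.0 - self.drop_path
            # one Bernoulli draw per sample, broadcast over all remaining dims
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=x.dtype)
            return (x / keep_prob) * tf.floor(random_tensor)
        return x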

There was a prior (seemingly abandoned) port attempt in #23155 which I referenced a little while building this; just to address one of the review comments on that PR, config.layer_norm_eps only applies to the TFConvNextV2MainLayer.layernorm layer, not the other norm layers or the GRN layer, which are fixed at 1e-6 epsilon (see the existing PyTorch implementation & original code). Using the (typically 1e-12) value from config.layer_norm_eps in those layers will produce aggressively incorrect outputs 😭
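To make the epsilon split concrete, a minimal sketch (the variable names here are illustrative, not the exact Transformers attributes):

import tensorflow as tf
from transformers import ConvNextV2Config

config = ConvNextV2Config()  # config.layer_norm_eps defaults to 1e-12

# only the final layernorm on the main layer reads the config value
final_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps)

# the per-block norms (and the GRN epsilon) are hard-coded to 1e-6,
# matching the PyTorch implementation and the original ConvNeXt V2 code
block_layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6)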

Based on the PR template, it looks like I should tag @amyeroberts and maybe @alaradirik (majority holder of git blame for the PyTorch implementation)?

@ArthurZucker (Collaborator)

cc @amyeroberts

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@amyeroberts (Collaborator) left a comment

Really nice PR - thanks for adding this model!

Just a few small nits and comments. Once those are resolved we're good to merge! 🤗

@neggles neggles force-pushed the tf-convnextv2 branch 2 times, most recently from 0468796 to cedf3ac on August 20, 2023 at 11:02
@neggles (Contributor, Author) commented Aug 20, 2023

OK, pushed some updates! I think I've addressed everything.

To get the nice docstring fixes and type annotation changes I had to modify the ConvNeXT v1 model files (since make fix-copies overwrites the comment-annotated classes otherwise), so I broke those changes out into a separate commit just for clarity (happy to squash them back together if you prefer).

Of note, I'm getting some weird results when I run the transformers-cli pt-to-tf command. Initially that was just caused by the PyTorch model running on CPU while TensorFlow "helpfully" ran on GPU automatically, but even with my CUDA devices forcibly hidden the outputs seem to be a little out of whack:

$ CUDA_VISIBLE_DEVICES="" NVIDIA_TF32_OVERRIDE=0 transformers-cli pt-to-tf --no-pr \
    --model-name facebook/convnextv2-tiny-1k-224 --local-dir temp/convnextv2-tiny-1k-224
# [...]
ValueError: The cross-loaded TensorFlow model has different outputs, something went wrong!

List of maximum output differences above the threshold (5e-05):


List of maximum hidden layer differences above the threshold (5e-05):
hidden_states[2]: 6.256e-04
hidden_states[3]: 2.838e-03
hidden_states[4]: 1.793e-04

So only hidden states are out of spec, not actual logits.

The strange thing is, if I run it against ConvNeXT v1, I get a similar result:

$ CUDA_VISIBLE_DEVICES="" NVIDIA_TF32_OVERRIDE=0 transformers-cli pt-to-tf --no-pr \
    --model-name facebook/convnext-tiny-224 --local-dir temp/convnext-tiny-224
# [...]
ValueError: The cross-loaded TensorFlow model has different outputs, something went wrong!

List of maximum output differences above the threshold (5e-05):


List of maximum hidden layer differences above the threshold (5e-05):
hidden_states[1]: 1.583e-04
hidden_states[2]: 1.480e-03
hidden_states[3]: 2.380e-03
hidden_states[4]: 1.595e-04

The hidden states are actually further out of range for the v1 model. Not sure what's going on there; I tried loading the model with fewer stages (as suggested in the earlier PR), but it runs into layer input/output shape issues and fails to load, and attempting to inspect layers via breakpoint() makes TensorFlow throw a hissy fit and crash the entire debugger server 😅 pain.

The atol for logits in the PyTorch test script is only 1e-4, so maybe this is just Expected™ for these models?

@amyeroberts (Collaborator) left a comment

@neggles Awesome work - thanks again for adding this model and for the detailed comments!

Regarding the difference in hidden states, this is indeed Expected™️ 😅. Often we see small differences creep in for certain layers, e.g. batch norm, which we're not able to reduce any further (without a lot more work). These then get amplified during the forward pass as the inputs travel through the layers. Some models even have differences on the order of 1e-2 for the hidden layers! As long as the difference between the PT and TF models' output logits is small, ~1e-5, then it's OK.
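A quick way to eyeball the logit difference yourself (a rough sketch of the comparison pt-to-tf performs, not the actual conversion script; it relies on from_pt=True as mentioned above):

import numpy as np
import tensorflow as tf
import torch
from transformers import ConvNextV2ForImageClassification, TFConvNextV2ForImageClassification

ckpt = "facebook/convnextv2-tiny-1k-224"
pt_model = ConvNextV2ForImageClassification.from_pretrained(ckpt)
tf_model = TFConvNextV2ForImageClassification.from_pretrained(ckpt, from_pt=True)

# run the same random input through both models (channels-first input;
# the TF port transposes to NHWC internally)
pixel_values = np.random.default_rng(0).standard_normal((1, 3, 224, 224)).astype("float32")
pt_logits = pt_model(pixel_values=torch.from_numpy(pixel_values)).logits.detach().numpy()
tf_logits = tf_model(pixel_values=tf.constant(pixel_values)).logits.numpy()

print("max |pt - tf| logit difference:", np.abs(pt_logits - tf_logits).max())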

@amyeroberts (Collaborator)

@neggles For uploading the checkpoints, as long as the output logits have a small difference, you can override the checks. Let us know when you've pushed the TF checkpoints to the hub, or if you need help/permission to do so and then we can merge.

@neggles (Contributor, Author) commented Aug 30, 2023

@amyeroberts thanks for the help/feedback! TensorFlow may be a little old hat these days but, well, Cloud TPUs 🤷

OK, cool, I figured that was probably the case. I'm seeing about a 1.75e-5 difference in output logits, which seems reasonable enough to me. I suspect most of the difference comes down to LayerNorm epsilon settings; there's a bit of variation there depending on who created the model (e.g. the WD Tagger ConvNeXTv2 model uses the default TF layernorm eps of 1e-3 for everything), but Transformers sets all of them to 1e-6 except for the final one (set by config). That's a tiny bit out of scope for this PR, though 😆 [edit: see below 😅]

Anyway! Have opened PRs for conversion on the smaller models:

facebook/convnextv2-atto-1k-224
facebook/convnextv2-femto-1k-224
facebook/convnextv2-pico-1k-224
facebook/convnextv2-nano-1k-224
facebook/convnextv2-nano-22k-224
facebook/convnextv2-nano-22k-384
facebook/convnextv2-tiny-1k-224
facebook/convnextv2-tiny-22k-224
facebook/convnextv2-tiny-22k-384

Something looks a bit screwy with the smaller models? The output differences are pretty big, e.g. 3.719e-05 for atto-1k-224, and nano-1k-224 is at 2.9e-5. On that one we have a comparison point from the last conversion attempt, where the output difference was 1.371e-06, so this is a pretty major regression. Hmm. [edit: Fixed, everything's within ~1.5e-5 (usually less) now]

@neggles (Contributor, Author) commented Aug 30, 2023

Found it 🤦 I missed a pair of () in the GRN calc. With that fixed, atto is down to 1.001e-05 😅 I'll go add PR comments with corrected values.
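For context, GRN computes roughly the following (a sketch of ConvNeXt V2's global response normalization, not the exact merged code); the epsilon in the divisor is exactly the kind of spot where a misplaced pair of parentheses noticeably shifts outputs:

import tensorflow as tf

class GRN(tf.keras.layers.Layer):
    """Global Response Normalization (ConvNeXt V2), NHWC input."""

    def build(self, input_shape):
        dim = input_shape[-1]
        self.gamma = self.add_weight(name="gamma", shape=(1, 1, 1, dim), initializer="zeros")
        self.beta = self.add_weight(name="beta", shape=(1, 1, 1, dim), initializer="zeros")

    def call(self, x):
        # per-channel L2 norm over the spatial dims
        gx = tf.norm(x, ord="euclidean", axis=(1, 2), keepdims=True)
        # divisive normalization across channels; writing gx / mean + 1e-6
        # instead of gx / (mean + 1e-6) silently produces wrong outputs
        nx = gx / (tf.reduce_mean(gx, axis=-1, keepdims=True) + 1e-6)
        return self.gamma * (x * nx) + self.beta + x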

@neggles (Contributor, Author) commented Sep 21, 2023

I see most of the weight PRs have been merged, yay! I'm not sure what happened with the CI pipelines, though - looks like an unrelated error, but I don't have permissions to hit the rerun button 😢

@amyeroberts should I rebase this and push to make CI re-run?

@ArthurZucker (Collaborator)

Sure, rebasing is always a good thing! Amy is out for a while, ping me whenever.

@ArthurZucker (Collaborator)

OK, the TF weights are missing on the Hub; I asked for a merge of your PRs! 😉

@Rocketknight1 (Member)

TF weights are in!

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@neggles (Contributor, Author) commented Oct 21, 2023

Oops, forgot about this! Yes, TF weights are all in now - thanks much! - lemme resolve that conflict...

@neggles (Contributor, Author) commented Oct 21, 2023

OK! Rebased one more time and did one other thing: I've dropped the add_pooling_layer argument from TFConvNextV2Model and TFConvNextV2MainLayer, since utils/check_docstrings.py didn't like that it was present but undocumented, and it isn't present in the PyTorch version of the code either.

Should be good to merge now, I think? Sorry for the delay on this end! It's been a busy few weeks.

(ping @ArthurZucker 😄)

@ArthurZucker (Collaborator)

Sure! I'll let @Rocketknight1 handle this!

@Rocketknight1 (Member) commented Oct 23, 2023

Hey @neggles, this looks good, but needs make fix-copies to get the tests passing. Try running that in the root of the repo and then committing/pushing!

@neggles (Contributor, Author) commented Oct 25, 2023

@Rocketknight1 So this is... weird. When I run make repo-consistency on this end, it complains about TFRegNetModel / TFRegNetForImageClassification:

Traceback (most recent call last):
  File "/tank/ml/huggingface/transformers/utils/check_docstrings.py", line 1232, in <module>
    check_docstrings(overwrite=args.fix_and_overwrite)
  File "/tank/ml/huggingface/transformers/utils/check_docstrings.py", line 1224, in check_docstrings
    raise ValueError(error_message)
ValueError: There was at least one problem when checking docstrings of public objects.
The following objects docstrings do not match their signature. Run `make fix-copies` to fix this.
- TFRegNetForImageClassification
- TFRegNetModel
make: *** [Makefile:46: repo-consistency] Error 1

Meanwhile, CircleCI shows it complaining about these new models, despite the fact that (as far as I can tell) the docstrings do match for both TFConvNextV2Model and TFRegNetModel 🤔. Running make fix-copies results in no action:

aholmes@hyperion:/tank/ml/huggingface/transformers ❯ make fix-copies
python utils/check_copies.py --fix_and_overwrite
python utils/check_table.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite
python utils/check_doctest_list.py --fix_and_overwrite
python utils/check_task_guides.py --fix_and_overwrite
python utils/check_docstrings.py --fix_and_overwrite
Using /home/aholmes/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/aholmes/.cache/torch_extensions/py310_cu118/cuda_kernel/build.ninja...
Building extension module cuda_kernel...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cuda_kernel...
aholmes@hyperion:/tank/ml/huggingface/transformers ❯ 

Maybe something's screwy with TensorFlow model init signature parsing? If I edit utils/check_docstrings.py like so, to make it print what it thinks the docstrings should be:

--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -1195,7 +1195,7 @@ def check_docstrings(overwrite: bool = False):
             if overwrite:
                 fix_docstring(obj, old_doc, new_doc)
             else:
-                failures.append(name)
+                failures.append(name + "\n  Corrected docstring:\n  " + new_doc)
         elif not overwrite and new_doc is not None and ("<fill_type>" in new_doc or "<fill_docstring>" in new_doc):
             to_clean.append(name)

I get this output:

ValueError: There was at least one problem when checking docstrings of public objects.
The following objects docstrings do not match their signature. Run `make fix-copies` to fix this.
- TFRegNetForImageClassification
  Corrected docstring:
          config (`RegNetConfig`): <fill_docstring>
- TFRegNetModel
  Corrected docstring:
          config (`RegNetConfig`): <fill_docstring>

Which is the same as what's currently in there, but without the [] to turn it into a link. Is that not necessary anymore? From recent commit history I suspect not, but I'm not sure.

I also note that both TFConvNextModel and TFConvNextForImageClassification are listed as exceptions in check_docstrings.py (as is ConvNextV2Model)... I'm not entirely sure what to do here. Nevertheless, I've rebased and pushed again, so who knows, maybe it'll pass this time!

@Rocketknight1 (Member)

Ah, ugh, this might indeed indicate an issue with the docstring parsing for that file; I didn't realize ConvNext was one of the exceptions! If this current run fails a repo-consistency check, I would just add the TFConvNextV2 model classes to the exceptions list in check_docstrings.py, and maybe leave a comment that they're equivalent to the PT docstrings, so whenever we get around to properly fixing those we should be able to fix the TF ones too.
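The exclusion would look something like this in utils/check_docstrings.py (a sketch; the list name follows the script's convention at the time, surrounding entries elided):

OBJECTS_TO_IGNORE = [
    # ... existing entries ...
    # TFConvNextV2 docstrings match their PyTorch cousins, but a parsing issue
    # prevents them from passing the check; fix alongside the PT exclusions.
    "TFConvNextV2ForImageClassification",
    "TFConvNextV2Model",
    # ... existing entries ...
]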

@neggles (Contributor, Author) commented Oct 30, 2023

@Rocketknight1 OK, no problem - I've added the exclusions and a comment in check_docstrings.py (let me know if the comment should be in the modeling_tf_convnextv2.py file instead), rebased again, etc.

C'mon, tests! please? 🤞

@neggles (Contributor, Author) commented Oct 31, 2023

Yay, tests passed!

@Rocketknight1 (Member) left a comment

Just reviewed - overall this is super-clean and it's basically ready to merge now! I left two nits, but neither is blocking.

@neggles (Contributor, Author) commented Nov 1, 2023

Done and done! I've elected not to rebase again just to reduce the likelihood of tests getting angry 😅

@Rocketknight1 (Member)

This looks good now and I'm happy to merge! cc @amyeroberts - you don't need to do another review, but let me know if there's anything you think is unresolved before we merge this.

@amyeroberts (Collaborator)

Did a quick scan over the changes. All looks good to me. Thanks again @neggles for adding this model and for such a clean PR!

@amyeroberts amyeroberts merged commit f8afb2b into huggingface:main Nov 1, 2023
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* Add type annotations to TFConvNextDropPath

* Use tf.debugging.assert_equal for TFConvNextEmbeddings shape check

* Add TensorFlow implementation of ConvNeXTV2

* check_docstrings: add TFConvNextV2Model to exclusions

TFConvNextV2Model and TFConvNextV2ForImageClassification have docstrings
which are equivalent to their PyTorch cousins, but a parsing issue prevents them
from passing the test.

Adding exclusions for these two classes as discussed in huggingface#25558.