Skip to content

Commit

Permalink
Add TensorFlow implementation of ConvNeXTv2 (huggingface#25558)
Browse files Browse the repository at this point in the history
* Add type annotations to TFConvNextDropPath

* Use tf.debugging.assert_equal for TFConvNextEmbeddings shape check

* Add TensorFlow implementation of ConvNeXTV2

* check_docstrings: add TFConvNextV2Model to exclusions

TFConvNextV2Model and TFConvNextV2ForImageClassification have docstrings
which are equivalent to their PyTorch cousins, but a parsing issue prevents them
from passing the test.

Adding exclusions for these two classes as discussed in huggingface#25558.
  • Loading branch information
neggles authored Nov 1, 2023
1 parent 391d14e commit f8afb2b
Show file tree
Hide file tree
Showing 12 changed files with 1,012 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Conditional DETR](model_doc/conditional_detr) ||||
| [ConvBERT](model_doc/convbert) ||||
| [ConvNeXT](model_doc/convnext) ||||
| [ConvNeXTV2](model_doc/convnextv2) || ||
| [ConvNeXTV2](model_doc/convnextv2) || ||
| [CPM](model_doc/cpm) ||||
| [CPM-Ant](model_doc/cpmant) ||||
| [CTRL](model_doc/ctrl) ||||
Expand Down
13 changes: 12 additions & 1 deletion docs/source/en/model_doc/convnextv2.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,15 @@ If you're interested in submitting a resource to be included here, please feel f
## ConvNextV2ForImageClassification

[[autodoc]] ConvNextV2ForImageClassification
- forward
- forward

## TFConvNextV2Model

[[autodoc]] TFConvNextV2Model
- call


## TFConvNextV2ForImageClassification

[[autodoc]] TFConvNextV2ForImageClassification
- call
12 changes: 12 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3415,6 +3415,13 @@
"TFConvNextPreTrainedModel",
]
)
_import_structure["models.convnextv2"].extend(
[
"TFConvNextV2ForImageClassification",
"TFConvNextV2Model",
"TFConvNextV2PreTrainedModel",
]
)
_import_structure["models.ctrl"].extend(
[
"TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -7127,6 +7134,11 @@
TFConvBertPreTrainedModel,
)
from .models.convnext import TFConvNextForImageClassification, TFConvNextModel, TFConvNextPreTrainedModel
from .models.convnextv2 import (
TFConvNextV2ForImageClassification,
TFConvNextV2Model,
TFConvNextV2PreTrainedModel,
)
from .models.ctrl import (
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCTRLForSequenceClassification,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
("clip", "TFCLIPModel"),
("convbert", "TFConvBertModel"),
("convnext", "TFConvNextModel"),
("convnextv2", "TFConvNextV2Model"),
("ctrl", "TFCTRLModel"),
("cvt", "TFCvtModel"),
("data2vec-vision", "TFData2VecVisionModel"),
Expand Down Expand Up @@ -200,6 +201,7 @@
[
# Model for Image-classification
("convnext", "TFConvNextForImageClassification"),
("convnextv2", "TFConvNextV2ForImageClassification"),
("cvt", "TFCvtForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
Expand Down
49 changes: 31 additions & 18 deletions src/transformers/models/convnext/modeling_tf_convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import annotations

from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -50,11 +50,11 @@ class TFConvNextDropPath(tf.keras.layers.Layer):
(1) github.com:rwightman/pytorch-image-models
"""

def __init__(self, drop_path, **kwargs):
def __init__(self, drop_path: float, **kwargs):
super().__init__(**kwargs)
self.drop_path = drop_path

def call(self, x, training=None):
def call(self, x: tf.Tensor, training=None):
if training:
keep_prob = 1 - self.drop_path
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
Expand All @@ -69,15 +69,15 @@ class TFConvNextEmbeddings(tf.keras.layers.Layer):
found in src/transformers/models/swin/modeling_swin.py.
"""

def __init__(self, config, **kwargs):
def __init__(self, config: ConvNextConfig, **kwargs):
super().__init__(**kwargs)
self.patch_embeddings = tf.keras.layers.Conv2D(
filters=config.hidden_sizes[0],
kernel_size=config.patch_size,
strides=config.patch_size,
name="patch_embeddings",
kernel_initializer=get_initializer(config.initializer_range),
bias_initializer="zeros",
bias_initializer=tf.keras.initializers.Zeros(),
)
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
self.num_channels = config.num_channels
Expand All @@ -86,15 +86,15 @@ def call(self, pixel_values):
if isinstance(pixel_values, dict):
pixel_values = pixel_values["pixel_values"]

num_channels = shape_list(pixel_values)[1]
if tf.executing_eagerly() and num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
tf.debugging.assert_equal(
shape_list(pixel_values)[1],
self.num_channels,
message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
)

# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
# So change the input format from `NCHW` to `NHWC`.
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
# shape = (batch_size, in_height, in_width, in_channels)
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

embeddings = self.patch_embeddings(pixel_values)
Expand Down Expand Up @@ -188,15 +188,28 @@ class TFConvNextStage(tf.keras.layers.Layer):
"""ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks.
Args:
config ([`ConvNextConfig`]): Model configuration class.
in_channels (`int`): Number of input channels.
out_channels (`int`): Number of output channels.
depth (`int`): Number of residual blocks.
drop_path_rates(`List[float]`): Stochastic depth rates for each layer.
config (`ConvNextConfig`):
Model configuration class.
in_channels (`int`):
Number of input channels.
out_channels (`int`):
Number of output channels.
depth (`int`):
Number of residual blocks.
drop_path_rates (`List[float]`):
Stochastic depth rates for each layer.
"""

def __init__(
self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None, **kwargs
self,
config: ConvNextConfig,
in_channels: int,
out_channels: int,
kernel_size: int = 2,
stride: int = 2,
depth: int = 2,
drop_path_rates: Optional[List[float]] = None,
**kwargs,
):
super().__init__(**kwargs)
if in_channels != out_channels or stride > 1:
Expand All @@ -215,7 +228,7 @@ def __init__(
kernel_size=kernel_size,
strides=stride,
kernel_initializer=get_initializer(config.initializer_range),
bias_initializer="zeros",
bias_initializer=tf.keras.initializers.Zeros(),
name="downsampling_layer.1",
),
]
Expand Down
24 changes: 24 additions & 0 deletions src/transformers/models/convnextv2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_tf_available,
)


Expand All @@ -46,6 +47,17 @@
"ConvNextV2Backbone",
]

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_convnextv2"] = [
"TFConvNextV2ForImageClassification",
"TFConvNextV2Model",
"TFConvNextV2PreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_convnextv2 import (
Expand All @@ -67,6 +79,18 @@
ConvNextV2PreTrainedModel,
)

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_convnextv2 import (
TFConvNextV2ForImageClassification,
TFConvNextV2Model,
TFConvNextV2PreTrainedModel,
)

else:
import sys

Expand Down
Loading

0 comments on commit f8afb2b

Please sign in to comment.