[Backbone] Use load_backbone instead of AutoBackbone.from_config #28661

Merged
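In short, models that wrap a vision backbone stop calling `AutoBackbone.from_config(config.backbone_config)` themselves and instead pass their whole config to `load_backbone`, which also reads the new `use_timm_backbone` flag. A minimal sketch of the resulting call pattern (the kwargs and defaults below are illustrative assumptions, not part of this diff):

```python
from transformers import DetrConfig
from transformers.utils.backbone_utils import load_backbone

# Force a randomly initialised transformers backbone rather than a timm one;
# any composite config that carries backbone fields works the same way.
config = DetrConfig(use_timm_backbone=False, use_pretrained_backbone=False, backbone=None)

# Before this PR, modeling code did:
#     backbone = AutoBackbone.from_config(config.backbone_config)
# After this PR, the full config is handed over, so the timm/transformers and
# pretrained/random decisions are resolved in one place:
backbone = load_backbone(config)
print(backbone.channels)  # per-stage feature dimensions consumed by the decoder/neck
```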
src/transformers/models/conditional_detr/modeling_conditional_detr.py
@@ -37,7 +37,7 @@
replace_return_docstrings,
requires_backends,
)
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
from .configuration_conditional_detr import ConditionalDetrConfig


@@ -363,7 +363,7 @@ def __init__(self, config):
**kwargs,
)
else:
-backbone = AutoBackbone.from_config(config.backbone_config)
+backbone = load_backbone(config)

# replace batch norm by frozen batch norm
with torch.no_grad():
src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -44,7 +44,7 @@
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_ninja_available, logging
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels

@@ -409,7 +409,7 @@ def __init__(self, config):
**kwargs,
)
else:
-backbone = AutoBackbone.from_config(config.backbone_config)
+backbone = load_backbone(config)

# replace batch norm by frozen batch norm
with torch.no_grad():
5 changes: 5 additions & 0 deletions src/transformers/models/deta/configuration_deta.py
@@ -46,6 +46,9 @@ class DetaConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, `False`):
Whether to use pretrained weights for the backbone.
+use_timm_backbone (`bool`, *optional*, `False`):
Comment from the PR author (a dispatch sketch follows this file's diff):
We need to have the use_timm_backbone included in all model configs that contain backbone information to differentiate between loading a transformers or timm pretrained backbone.

+Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+library.
num_queries (`int`, *optional*, defaults to 900):
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can
detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead.
@@ -146,6 +149,7 @@ def __init__(
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
+use_timm_backbone=False,
num_queries=900,
max_position_embeddings=2048,
encoder_layers=6,
@@ -203,6 +207,7 @@
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
+self.use_timm_backbone = use_timm_backbone
self.num_queries = num_queries
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
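As the author's comment above notes, `use_timm_backbone` is what lets `load_backbone` tell whether a named `backbone` checkpoint should come from timm or from transformers. A simplified, hypothetical sketch of that dispatch (the real logic lives in `src/transformers/utils/backbone_utils.py`; the branch ordering and details here are assumptions for illustration):

```python
from transformers import AutoBackbone


def load_backbone_sketch(config):
    """Illustrative only: resolve a backbone from a composite model config."""
    backbone_checkpoint = getattr(config, "backbone", None)
    backbone_config = getattr(config, "backbone_config", None)

    if getattr(config, "use_timm_backbone", False):
        # `backbone` names a timm architecture (e.g. "resnet50"); the real
        # function constructs a timm-backed backbone here. Omitted in this sketch.
        raise NotImplementedError("timm branch omitted in this sketch")
    if backbone_checkpoint is not None and getattr(config, "use_pretrained_backbone", False):
        # A transformers checkpoint loaded with pretrained weights.
        return AutoBackbone.from_pretrained(backbone_checkpoint)
    # Otherwise, initialise randomly from the nested backbone config.
    return AutoBackbone.from_config(backbone_config)
```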
4 changes: 2 additions & 2 deletions src/transformers/models/deta/modeling_deta.py
@@ -39,7 +39,7 @@
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
from .configuration_deta import DetaConfig


@@ -338,7 +338,7 @@ class DetaBackboneWithPositionalEncodings(nn.Module):
def __init__(self, config):
super().__init__()

-backbone = AutoBackbone.from_config(config.backbone_config)
+backbone = load_backbone(config)
with torch.no_grad():
replace_batch_norm(backbone)
self.model = backbone
4 changes: 2 additions & 2 deletions src/transformers/models/detr/modeling_detr.py
@@ -37,7 +37,7 @@
replace_return_docstrings,
requires_backends,
)
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
from .configuration_detr import DetrConfig


@@ -356,7 +356,7 @@ def __init__(self, config):
**kwargs,
)
else:
-backbone = AutoBackbone.from_config(config.backbone_config)
+backbone = load_backbone(config)

# replace batch norm by frozen batch norm
with torch.no_grad():
20 changes: 12 additions & 8 deletions src/transformers/models/dpt/configuration_dpt.py
@@ -117,6 +117,9 @@ class DPTConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
+use_timm_backbone (`bool`, *optional*, defaults to `False`):
+Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+library.

Example:

@@ -169,6 +172,7 @@ def __init__(
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
+use_timm_backbone=False,
**kwargs,
):
super().__init__(**kwargs)
@@ -179,9 +183,6 @@ def __init__(
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")

-if backbone_config is not None and backbone is not None:
-raise ValueError("You can't specify both `backbone` and `backbone_config`.")

use_autobackbone = False
if self.is_hybrid:
if backbone_config is None and backbone is None:
Expand All @@ -193,17 +194,17 @@ def __init__(
"out_features": ["stage1", "stage2", "stage3"],
"embedding_dynamic_padding": True,
}
-self.backbone_config = BitConfig(**backbone_config)
+backbone_config = BitConfig(**backbone_config)
elif isinstance(backbone_config, dict):
logger.info("Initializing the config with a `BiT` backbone.")
-self.backbone_config = BitConfig(**backbone_config)
+backbone_config = BitConfig(**backbone_config)
elif isinstance(backbone_config, PretrainedConfig):
-self.backbone_config = backbone_config
+backbone_config = backbone_config
else:
raise ValueError(
f"backbone_config must be a dictionary or a `PretrainedConfig`, got {backbone_config.__class__}."
)

+self.backbone_config = backbone_config
self.backbone_featmap_shape = backbone_featmap_shape
self.neck_ignore_stages = neck_ignore_stages

@@ -221,14 +222,17 @@ def __init__(
self.backbone_config = backbone_config
self.backbone_featmap_shape = None
self.neck_ignore_stages = []

else:
self.backbone_config = backbone_config
self.backbone_featmap_shape = None
self.neck_ignore_stages = []

+if use_autobackbone and backbone_config is not None and backbone is not None:
Comment from the PR author (an illustration follows this file's diff):
Moved to account for the recent DPT updates where backbone and backbone_config can be passed in, but aren't necessary.

+raise ValueError("You can't specify both `backbone` and `backbone_config`.")

self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
+self.use_timm_backbone = use_timm_backbone
self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
self.num_attention_heads = None if use_autobackbone else num_attention_heads
self.intermediate_size = None if use_autobackbone else intermediate_size
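As the author's comment above explains, the mutual-exclusion check now only applies on the autobackbone path, because the hybrid DPT path may fill in `backbone_config` defaults on its own. A small illustration of the intended behaviour, inferred from this diff (the exact argument combinations are assumptions, not taken from the PR):

```python
from transformers import BitConfig, DPTConfig

# Hybrid DPT: neither `backbone` nor `backbone_config` is required; defaults
# are filled in internally, so no error is raised.
hybrid = DPTConfig(is_hybrid=True)

# Non-hybrid DPT with an explicit backbone config is fine on its own...
non_hybrid = DPTConfig(is_hybrid=False, backbone_config=BitConfig())

# ...but naming a checkpoint *and* passing a config is rejected on the
# autobackbone path, where the moved check now lives.
try:
    DPTConfig(is_hybrid=False, backbone_config=BitConfig(), backbone="microsoft/bit-50")
except ValueError as err:
    print(err)  # "You can't specify both `backbone` and `backbone_config`."
```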
12 changes: 5 additions & 7 deletions src/transformers/models/dpt/modeling_dpt.py
@@ -41,7 +41,7 @@
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, logging
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
from .configuration_dpt import DPTConfig


@@ -131,12 +131,10 @@ def __init__(self, config, feature_size=None):
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

-self.backbone = AutoBackbone.from_config(config.backbone_config)
+self.backbone = load_backbone(config)
feature_dim = self.backbone.channels[-1]
-if len(config.backbone_config.out_features) != 3:
-raise ValueError(
-f"Expected backbone to have 3 output features, got {len(config.backbone_config.out_features)}"
-)
+if len(self.backbone.channels) != 3:
+raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}")
self.residual_feature_map_index = [0, 1] # Always take the output of the first and second backbone stage

if feature_size is None:
@@ -1082,7 +1080,7 @@ def __init__(self, config):

self.backbone = None
if config.backbone_config is not None and config.is_hybrid is False:
-self.backbone = AutoBackbone.from_config(config.backbone_config)
+self.backbone = load_backbone(config)
else:
self.dpt = DPTModel(config, add_pooling_layer=False)

src/transformers/models/mask2former/configuration_mask2former.py
@@ -53,6 +53,9 @@ class Mask2FormerConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, `False`):
Whether to use pretrained weights for the backbone.
+use_timm_backbone (`bool`, *optional*, `False`):
+Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+library.
feature_size (`int`, *optional*, defaults to 256):
The features (channels) of the resulting feature maps.
mask_feature_size (`int`, *optional*, defaults to 256):
@@ -162,6 +165,7 @@ def __init__(
output_auxiliary_logits: bool = None,
backbone=None,
use_pretrained_backbone=False,
+use_timm_backbone=False,
**kwargs,
):
if use_pretrained_backbone:
@@ -228,6 +232,7 @@ def __init__(
self.num_hidden_layers = decoder_layers
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
+self.use_timm_backbone = use_timm_backbone

super().__init__(**kwargs)

4 changes: 2 additions & 2 deletions src/transformers/models/mask2former/modeling_mask2former.py
@@ -23,7 +23,6 @@
import torch
from torch import Tensor, nn

-from ... import AutoBackbone
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
@@ -36,6 +35,7 @@
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import logging
+from ...utils.backbone_utils import load_backbone
from .configuration_mask2former import Mask2FormerConfig


@@ -1376,7 +1376,7 @@ def __init__(self, config: Mask2FormerConfig):
"""
super().__init__()

-self.encoder = AutoBackbone.from_config(config.backbone_config)
+self.encoder = load_backbone(config)
self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels)

def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput:
src/transformers/models/maskformer/configuration_maskformer.py
@@ -63,6 +63,9 @@ class MaskFormerConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, `False`):
Whether to use pretrained weights for the backbone.
+use_timm_backbone (`bool`, *optional*, `False`):
+Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+library.
decoder_config (`Dict`, *optional*):
The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
will be used.
@@ -122,6 +125,7 @@ def __init__(
output_auxiliary_logits: Optional[bool] = None,
backbone: Optional[str] = None,
use_pretrained_backbone: bool = False,
+use_timm_backbone: bool = False,
**kwargs,
):
if use_pretrained_backbone:
@@ -193,6 +197,7 @@ def __init__(
self.num_hidden_layers = self.decoder_config.num_hidden_layers
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
+self.use_timm_backbone = use_timm_backbone
super().__init__(**kwargs)

@classmethod
11 changes: 5 additions & 6 deletions src/transformers/models/maskformer/modeling_maskformer.py
@@ -23,7 +23,6 @@
import torch
from torch import Tensor, nn

-from ... import AutoBackbone
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
@@ -37,6 +36,7 @@
replace_return_docstrings,
requires_backends,
)
+from ...utils.backbone_utils import load_backbone
from ..detr import DetrConfig
from .configuration_maskformer import MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig
@@ -1428,14 +1428,13 @@ def __init__(self, config: MaskFormerConfig):
The configuration used to instantiate this model.
"""
super().__init__()

-# TODD: add method to load pretrained weights of backbone
-backbone_config = config.backbone_config
-if backbone_config.model_type == "swin":
+if hasattr(config, "backbone_config") and config.backbone_config.model_type == "swin":
# for backwards compatibility
+backbone_config = config.backbone_config
backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict())
backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"]
-self.encoder = AutoBackbone.from_config(backbone_config)
+config.backbone_config = backbone_config
+self.encoder = load_backbone(config)

feature_channels = self.encoder.channels
self.decoder = MaskFormerPixelDecoder(
5 changes: 5 additions & 0 deletions src/transformers/models/oneformer/configuration_oneformer.py
@@ -50,6 +50,9 @@ class OneFormerConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
+use_timm_backbone (`bool`, *optional*, defaults to `False`):
+Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+library.
ignore_value (`int`, *optional*, defaults to 255):
Values to be ignored in GT label while calculating loss.
num_queries (`int`, *optional*, defaults to 150):
@@ -152,6 +155,7 @@ def __init__(
backbone_config: Optional[Dict] = None,
backbone: Optional[str] = None,
use_pretrained_backbone: bool = False,
+use_timm_backbone: bool = False,
ignore_value: int = 255,
num_queries: int = 150,
no_object_weight: int = 0.1,
@@ -222,6 +226,7 @@ def __init__(
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
+self.use_timm_backbone = use_timm_backbone
self.ignore_value = ignore_value
self.num_queries = num_queries
self.no_object_weight = no_object_weight
5 changes: 2 additions & 3 deletions src/transformers/models/oneformer/modeling_oneformer.py
@@ -24,7 +24,6 @@
from torch import Tensor, nn
from torch.cuda.amp import autocast

-from ... import AutoBackbone
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
@@ -37,6 +36,7 @@
replace_return_docstrings,
requires_backends,
)
+from ...utils.backbone_utils import load_backbone
from .configuration_oneformer import OneFormerConfig


@@ -1478,8 +1478,7 @@ def __init__(self, config: OneFormerConfig):
The configuration used to instantiate this model.
"""
super().__init__()
-backbone_config = config.backbone_config
-self.encoder = AutoBackbone.from_config(backbone_config)
+self.encoder = load_backbone(config)
self.decoder = OneFormerPixelDecoder(config, feature_channels=self.encoder.channels)

def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput:
src/transformers/models/table_transformer/modeling_table_transformer.py
@@ -37,7 +37,7 @@
replace_return_docstrings,
requires_backends,
)
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
from .configuration_table_transformer import TableTransformerConfig


@@ -290,7 +290,7 @@ def __init__(self, config):
**kwargs,
)
else:
-backbone = AutoBackbone.from_config(config.backbone_config)
+backbone = load_backbone(config)

# replace batch norm by frozen batch norm
with torch.no_grad():