diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index b74f6accadfc..6c2cbb859c8e 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -37,7 +37,7 @@ replace_return_docstrings, requires_backends, ) -from ..auto import AutoBackbone +from ...utils.backbone_utils import load_backbone from .configuration_conditional_detr import ConditionalDetrConfig @@ -363,7 +363,7 @@ def __init__(self, config): **kwargs, ) else: - backbone = AutoBackbone.from_config(config.backbone_config) + backbone = load_backbone(config) # replace batch norm by frozen batch norm with torch.no_grad(): diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 3767eef0392f..9eed0b8ba45c 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -44,7 +44,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import meshgrid from ...utils import is_ninja_available, logging -from ..auto import AutoBackbone +from ...utils.backbone_utils import load_backbone from .configuration_deformable_detr import DeformableDetrConfig from .load_custom import load_cuda_kernels @@ -409,7 +409,7 @@ def __init__(self, config): **kwargs, ) else: - backbone = AutoBackbone.from_config(config.backbone_config) + backbone = load_backbone(config) # replace batch norm by frozen batch norm with torch.no_grad(): diff --git a/src/transformers/models/deta/configuration_deta.py b/src/transformers/models/deta/configuration_deta.py index 1ade9465a9f3..633d6267ef3d 100644 --- a/src/transformers/models/deta/configuration_deta.py +++ b/src/transformers/models/deta/configuration_deta.py @@ -46,6 +46,9 @@ class DetaConfig(PretrainedConfig): is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. use_pretrained_backbone (`bool`, *optional*, `False`): Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. num_queries (`int`, *optional*, defaults to 900): Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead. @@ -146,6 +149,7 @@ def __init__( backbone_config=None, backbone=None, use_pretrained_backbone=False, + use_timm_backbone=False, num_queries=900, max_position_embeddings=2048, encoder_layers=6, @@ -203,6 +207,7 @@ def __init__( self.backbone_config = backbone_config self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone self.num_queries = num_queries self.max_position_embeddings = max_position_embeddings self.d_model = d_model diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 330ccfe3f0c3..b6c65a3c301b 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -39,7 +39,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import meshgrid from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends -from ..auto import AutoBackbone +from ...utils.backbone_utils import load_backbone from .configuration_deta import DetaConfig @@ -338,7 +338,7 @@ class DetaBackboneWithPositionalEncodings(nn.Module): def __init__(self, config): super().__init__() - backbone = AutoBackbone.from_config(config.backbone_config) + backbone = load_backbone(config) with torch.no_grad(): replace_batch_norm(backbone) self.model = backbone diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index a3078cd2d0ae..026100b24506 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -37,7 +37,7 @@ replace_return_docstrings, requires_backends, ) -from ..auto import AutoBackbone +from ...utils.backbone_utils import load_backbone from .configuration_detr import DetrConfig @@ -356,7 +356,7 @@ def __init__(self, config): **kwargs, ) else: - backbone = AutoBackbone.from_config(config.backbone_config) + backbone = load_backbone(config) # replace batch norm by frozen batch norm with torch.no_grad(): diff --git a/src/transformers/models/dpt/configuration_dpt.py b/src/transformers/models/dpt/configuration_dpt.py index 0b6366659bc1..5bb48ad9780a 100644 --- a/src/transformers/models/dpt/configuration_dpt.py +++ b/src/transformers/models/dpt/configuration_dpt.py @@ -117,6 +117,9 @@ class DPTConfig(PretrainedConfig): is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. use_pretrained_backbone (`bool`, *optional*, defaults to `False`): Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. Example: @@ -169,6 +172,7 @@ def __init__( backbone_config=None, backbone=None, use_pretrained_backbone=False, + use_timm_backbone=False, **kwargs, ): super().__init__(**kwargs) @@ -179,9 +183,6 @@ def __init__( if use_pretrained_backbone: raise ValueError("Pretrained backbones are not supported yet.") - if backbone_config is not None and backbone is not None: - raise ValueError("You can't specify both `backbone` and `backbone_config`.") - use_autobackbone = False if self.is_hybrid: if backbone_config is None and backbone is None: @@ -193,17 +194,17 @@ def __init__( "out_features": ["stage1", "stage2", "stage3"], "embedding_dynamic_padding": True, } - self.backbone_config = BitConfig(**backbone_config) + backbone_config = BitConfig(**backbone_config) elif isinstance(backbone_config, dict): logger.info("Initializing the config with a `BiT` backbone.") - self.backbone_config = BitConfig(**backbone_config) + backbone_config = BitConfig(**backbone_config) elif isinstance(backbone_config, PretrainedConfig): - self.backbone_config = backbone_config + backbone_config = backbone_config else: raise ValueError( f"backbone_config must be a dictionary or a `PretrainedConfig`, got {backbone_config.__class__}." ) - + self.backbone_config = backbone_config self.backbone_featmap_shape = backbone_featmap_shape self.neck_ignore_stages = neck_ignore_stages @@ -221,14 +222,17 @@ def __init__( self.backbone_config = backbone_config self.backbone_featmap_shape = None self.neck_ignore_stages = [] - else: self.backbone_config = backbone_config self.backbone_featmap_shape = None self.neck_ignore_stages = [] + if use_autobackbone and backbone_config is not None and backbone is not None: + raise ValueError("You can't specify both `backbone` and `backbone_config`.") + self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone self.num_hidden_layers = None if use_autobackbone else num_hidden_layers self.num_attention_heads = None if use_autobackbone else num_attention_heads self.intermediate_size = None if use_autobackbone else intermediate_size diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 09fc6406fd85..e986e71d4851 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -41,7 +41,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, logging -from ..auto import AutoBackbone +from ...utils.backbone_utils import load_backbone from .configuration_dpt import DPTConfig @@ -131,12 +131,10 @@ def __init__(self, config, feature_size=None): patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.backbone = AutoBackbone.from_config(config.backbone_config) + self.backbone = load_backbone(config) feature_dim = self.backbone.channels[-1] - if len(config.backbone_config.out_features) != 3: - raise ValueError( - f"Expected backbone to have 3 output features, got {len(config.backbone_config.out_features)}" - ) + if len(self.backbone.channels) != 3: + raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}") self.residual_feature_map_index = [0, 1] # Always take the output of the first and second backbone stage if feature_size is None: @@ -1082,7 +1080,7 @@ def __init__(self, config): self.backbone = None if config.backbone_config is not None and config.is_hybrid is False: - self.backbone = AutoBackbone.from_config(config.backbone_config) + self.backbone = load_backbone(config) else: self.dpt = DPTModel(config, add_pooling_layer=False) diff --git a/src/transformers/models/mask2former/configuration_mask2former.py b/src/transformers/models/mask2former/configuration_mask2former.py index 7202e551a0cb..0d27ba39cbde 100644 --- a/src/transformers/models/mask2former/configuration_mask2former.py +++ b/src/transformers/models/mask2former/configuration_mask2former.py @@ -53,6 +53,9 @@ class Mask2FormerConfig(PretrainedConfig): is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. use_pretrained_backbone (`bool`, *optional*, `False`): Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. feature_size (`int`, *optional*, defaults to 256): The features (channels) of the resulting feature maps. mask_feature_size (`int`, *optional*, defaults to 256): @@ -162,6 +165,7 @@ def __init__( output_auxiliary_logits: bool = None, backbone=None, use_pretrained_backbone=False, + use_timm_backbone=False, **kwargs, ): if use_pretrained_backbone: @@ -228,6 +232,7 @@ def __init__( self.num_hidden_layers = decoder_layers self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone super().__init__(**kwargs) diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index eeee25967e4f..a88028a80717 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -23,7 +23,6 @@ import torch from torch import Tensor, nn -from ... import AutoBackbone from ...activations import ACT2FN from ...file_utils import ( ModelOutput, @@ -36,6 +35,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import logging +from ...utils.backbone_utils import load_backbone from .configuration_mask2former import Mask2FormerConfig @@ -1376,7 +1376,7 @@ def __init__(self, config: Mask2FormerConfig): """ super().__init__() - self.encoder = AutoBackbone.from_config(config.backbone_config) + self.encoder = load_backbone(config) self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels) def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput: diff --git a/src/transformers/models/maskformer/configuration_maskformer.py b/src/transformers/models/maskformer/configuration_maskformer.py index 3d2814dbfdc1..e906ceb2b39f 100644 --- a/src/transformers/models/maskformer/configuration_maskformer.py +++ b/src/transformers/models/maskformer/configuration_maskformer.py @@ -63,6 +63,9 @@ class MaskFormerConfig(PretrainedConfig): is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. use_pretrained_backbone (`bool`, *optional*, `False`): Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. decoder_config (`Dict`, *optional*): The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50` will be used. @@ -122,6 +125,7 @@ def __init__( output_auxiliary_logits: Optional[bool] = None, backbone: Optional[str] = None, use_pretrained_backbone: bool = False, + use_timm_backbone: bool = False, **kwargs, ): if use_pretrained_backbone: @@ -193,6 +197,7 @@ def __init__( self.num_hidden_layers = self.decoder_config.num_hidden_layers self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone super().__init__(**kwargs) @classmethod diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index dc46a6e87988..026ea15d4439 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -23,7 +23,6 @@ import torch from torch import Tensor, nn -from ... import AutoBackbone from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutputWithCrossAttentions @@ -37,6 +36,7 @@ replace_return_docstrings, requires_backends, ) +from ...utils.backbone_utils import load_backbone from ..detr import DetrConfig from .configuration_maskformer import MaskFormerConfig from .configuration_maskformer_swin import MaskFormerSwinConfig @@ -1428,14 +1428,13 @@ def __init__(self, config: MaskFormerConfig): The configuration used to instantiate this model. """ super().__init__() - - # TODD: add method to load pretrained weights of backbone - backbone_config = config.backbone_config - if backbone_config.model_type == "swin": + if hasattr(config, "backbone_config") and config.backbone_config.model_type == "swin": # for backwards compatibility + backbone_config = config.backbone_config backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict()) backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"] - self.encoder = AutoBackbone.from_config(backbone_config) + config.backbone_config = backbone_config + self.encoder = load_backbone(config) feature_channels = self.encoder.channels self.decoder = MaskFormerPixelDecoder( diff --git a/src/transformers/models/oneformer/configuration_oneformer.py b/src/transformers/models/oneformer/configuration_oneformer.py index 6cf54947de6b..b88e2c559098 100644 --- a/src/transformers/models/oneformer/configuration_oneformer.py +++ b/src/transformers/models/oneformer/configuration_oneformer.py @@ -50,6 +50,9 @@ class OneFormerConfig(PretrainedConfig): is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. use_pretrained_backbone (`bool`, *optional*, defaults to `False`): Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. ignore_value (`int`, *optional*, defaults to 255): Values to be ignored in GT label while calculating loss. num_queries (`int`, *optional*, defaults to 150): @@ -152,6 +155,7 @@ def __init__( backbone_config: Optional[Dict] = None, backbone: Optional[str] = None, use_pretrained_backbone: bool = False, + use_timm_backbone: bool = False, ignore_value: int = 255, num_queries: int = 150, no_object_weight: int = 0.1, @@ -222,6 +226,7 @@ def __init__( self.backbone_config = backbone_config self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone self.ignore_value = ignore_value self.num_queries = num_queries self.no_object_weight = no_object_weight diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index d0c0d405502e..894dac10f7ea 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -24,7 +24,6 @@ from torch import Tensor, nn from torch.cuda.amp import autocast -from ... import AutoBackbone from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel @@ -37,6 +36,7 @@ replace_return_docstrings, requires_backends, ) +from ...utils.backbone_utils import load_backbone from .configuration_oneformer import OneFormerConfig @@ -1478,8 +1478,7 @@ def __init__(self, config: OneFormerConfig): The configuration used to instantiate this model. """ super().__init__() - backbone_config = config.backbone_config - self.encoder = AutoBackbone.from_config(backbone_config) + self.encoder = load_backbone(config) self.decoder = OneFormerPixelDecoder(config, feature_channels=self.encoder.channels) def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput: diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 81afcdc9c18f..19aa680ad038 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -37,7 +37,7 @@ replace_return_docstrings, requires_backends, ) -from ..auto import AutoBackbone +from ...utils.backbone_utils import load_backbone from .configuration_table_transformer import TableTransformerConfig @@ -290,7 +290,7 @@ def __init__(self, config): **kwargs, ) else: - backbone = AutoBackbone.from_config(config.backbone_config) + backbone = load_backbone(config) # replace batch norm by frozen batch norm with torch.no_grad(): diff --git a/src/transformers/models/tvp/configuration_tvp.py b/src/transformers/models/tvp/configuration_tvp.py index 954ee4e90cb1..7e985ab84e30 100644 --- a/src/transformers/models/tvp/configuration_tvp.py +++ b/src/transformers/models/tvp/configuration_tvp.py @@ -49,6 +49,9 @@ class TvpConfig(PretrainedConfig): is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. use_pretrained_backbone (`bool`, *optional*, defaults to `False`): Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. distance_loss_weight (`float`, *optional*, defaults to 1.0): The weight of distance loss. duration_loss_weight (`float`, *optional*, defaults to 0.1): @@ -103,6 +106,7 @@ def __init__( backbone_config=None, backbone=None, use_pretrained_backbone=False, + use_timm_backbone=False, distance_loss_weight=1.0, duration_loss_weight=0.1, visual_prompter_type="framepad", @@ -143,6 +147,7 @@ def __init__( self.backbone_config = backbone_config self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone self.distance_loss_weight = distance_loss_weight self.duration_loss_weight = duration_loss_weight self.visual_prompter_type = visual_prompter_type diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index 04192630eebd..c80cc9df0b35 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -28,7 +28,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import prune_linear_layer from ...utils import logging -from ..auto import AutoBackbone +from ...utils.backbone_utils import load_backbone from .configuration_tvp import TvpConfig @@ -148,7 +148,7 @@ def forward(self, logits, labels): class TvpVisionModel(nn.Module): def __init__(self, config): super().__init__() - self.backbone = AutoBackbone.from_config(config.backbone_config) + self.backbone = load_backbone(config) self.grid_encoder_conv = nn.Conv2d( config.backbone_config.hidden_sizes[-1], config.hidden_size, diff --git a/src/transformers/models/upernet/configuration_upernet.py b/src/transformers/models/upernet/configuration_upernet.py index c4e6f8168f55..9288bd67b610 100644 --- a/src/transformers/models/upernet/configuration_upernet.py +++ b/src/transformers/models/upernet/configuration_upernet.py @@ -42,6 +42,9 @@ class UperNetConfig(PretrainedConfig): is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. use_pretrained_backbone (`bool`, *optional*, `False`): Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. hidden_size (`int`, *optional*, defaults to 512): The number of hidden units in the convolutional layers. initializer_range (`float`, *optional*, defaults to 0.02): @@ -83,6 +86,7 @@ def __init__( backbone_config=None, backbone=None, use_pretrained_backbone=False, + use_timm_backbone=False, hidden_size=512, initializer_range=0.02, pool_scales=[1, 2, 3, 6], @@ -113,6 +117,7 @@ def __init__( self.backbone_config = backbone_config self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone self.hidden_size = hidden_size self.initializer_range = initializer_range self.pool_scales = pool_scales diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index 2ad8e8c372f1..b889ae4eb4ce 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -20,10 +20,10 @@ from torch import nn from torch.nn import CrossEntropyLoss -from ... import AutoBackbone from ...modeling_outputs import SemanticSegmenterOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...utils.backbone_utils import load_backbone from .configuration_upernet import UperNetConfig @@ -348,7 +348,7 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel): def __init__(self, config): super().__init__(config) - self.backbone = AutoBackbone.from_config(config.backbone_config) + self.backbone = load_backbone(config) # Semantic segmentation head(s) self.decode_head = UperNetHead(config, in_channels=self.backbone.channels) diff --git a/src/transformers/models/vit_hybrid/configuration_vit_hybrid.py b/src/transformers/models/vit_hybrid/configuration_vit_hybrid.py index b0a37617dc1e..30ebe4fba659 100644 --- a/src/transformers/models/vit_hybrid/configuration_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/configuration_vit_hybrid.py @@ -48,6 +48,9 @@ class ViTHybridConfig(PretrainedConfig): is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. use_pretrained_backbone (`bool`, *optional*, defaults to `False`): Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. num_hidden_layers (`int`, *optional*, defaults to 12): @@ -100,6 +103,7 @@ def __init__( backbone_config=None, backbone=None, use_pretrained_backbone=False, + use_timm_backbone=False, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, @@ -147,6 +151,7 @@ def __init__( self.backbone_config = backbone_config self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 24b133e27af0..3dc715af511c 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -29,7 +29,7 @@ from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging -from ..auto import AutoBackbone +from ...utils.backbone_utils import load_backbone from .configuration_vit_hybrid import ViTHybridConfig @@ -150,7 +150,7 @@ def __init__(self, config, feature_size=None): image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - self.backbone = AutoBackbone.from_config(config.backbone_config) + self.backbone = load_backbone(config) if self.backbone.config.model_type != "bit": raise ValueError(f"Backbone model type {self.backbone.model_type} is not supported.") feature_dim = self.backbone.channels[-1] diff --git a/src/transformers/models/vitmatte/configuration_vitmatte.py b/src/transformers/models/vitmatte/configuration_vitmatte.py index 608b606c9bcb..4d2bcc612fe9 100644 --- a/src/transformers/models/vitmatte/configuration_vitmatte.py +++ b/src/transformers/models/vitmatte/configuration_vitmatte.py @@ -48,6 +48,9 @@ class VitMatteConfig(PretrainedConfig): is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. use_pretrained_backbone (`bool`, *optional*, defaults to `False`): Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. hidden_size (`int`, *optional*, defaults to 384): The number of input channels of the decoder. batch_norm_eps (`float`, *optional*, defaults to 1e-05): @@ -81,6 +84,7 @@ def __init__( backbone_config: PretrainedConfig = None, backbone=None, use_pretrained_backbone=False, + use_timm_backbone=False, hidden_size: int = 384, batch_norm_eps: float = 1e-5, initializer_range: float = 0.02, @@ -107,6 +111,7 @@ def __init__( self.backbone_config = backbone_config self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone self.batch_norm_eps = batch_norm_eps self.hidden_size = hidden_size self.initializer_range = initializer_range diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index 01e6ed5aa0a3..465f5da6adf5 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -20,7 +20,6 @@ import torch from torch import nn -from ... import AutoBackbone from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -28,6 +27,7 @@ add_start_docstrings_to_model_forward, replace_return_docstrings, ) +from ...utils.backbone_utils import load_backbone from .configuration_vitmatte import VitMatteConfig @@ -259,7 +259,7 @@ def __init__(self, config): super().__init__(config) self.config = config - self.backbone = AutoBackbone.from_config(config.backbone_config) + self.backbone = load_backbone(config) self.decoder = VitMatteDetailCaptureModule(config) # Initialize weights and apply final processing diff --git a/tests/models/conditional_detr/test_modeling_conditional_detr.py b/tests/models/conditional_detr/test_modeling_conditional_detr.py index 0bb9388d593f..aa0318f241aa 100644 --- a/tests/models/conditional_detr/test_modeling_conditional_detr.py +++ b/tests/models/conditional_detr/test_modeling_conditional_detr.py @@ -443,6 +443,7 @@ def test_different_timm_backbone(self): # let's pick a random timm backbone config.backbone = "tf_mobilenetv3_small_075" + config.use_timm_backbone = True for model_class in self.all_model_classes: model = model_class(config) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index f1d8b741411f..10ba5d187206 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -219,7 +219,11 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s "out_features", "out_indices", "sampling_rate", + # backbone related arguments passed to load_backbone "use_pretrained_backbone", + "backbone", + "backbone_config", + "use_timm_backbone", ] attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]