diff --git a/src/transformers/models/rt_detr/configuration_rt_detr.py b/src/transformers/models/rt_detr/configuration_rt_detr.py index 0e34d0376f9f..ca20cc584dfd 100644 --- a/src/transformers/models/rt_detr/configuration_rt_detr.py +++ b/src/transformers/models/rt_detr/configuration_rt_detr.py @@ -55,6 +55,8 @@ class RTDetrConfig(PretrainedConfig): use_timm_backbone (`bool`, *optional*, defaults to `False`): Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers library. + freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): + Whether to freeze the batch normalization layers in the backbone. backbone_kwargs (`dict`, *optional*): Keyword arguments to be passed to AutoBackbone when loading from a checkpoint e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. @@ -190,6 +192,7 @@ def __init__( backbone=None, use_pretrained_backbone=False, use_timm_backbone=False, + freeze_backbone_batch_norms=True, backbone_kwargs=None, # encoder HybridEncoder encoder_hidden_dim=256, @@ -280,6 +283,7 @@ def __init__( self.backbone = backbone self.use_pretrained_backbone = use_pretrained_backbone self.use_timm_backbone = use_timm_backbone + self.freeze_backbone_batch_norms = freeze_backbone_batch_norms self.backbone_kwargs = backbone_kwargs # encoder self.encoder_hidden_dim = encoder_hidden_dim diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 3f476725941e..ab83a81f5067 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -559,9 +559,10 @@ def __init__(self, config): backbone = load_backbone(config) - # replace batch norm by frozen batch norm - with torch.no_grad(): - replace_batch_norm(backbone) + if config.freeze_backbone_batch_norms: + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) self.model = backbone self.intermediate_channel_sizes = self.model.channels