From be0696425acb903013892a1665417f38d19b4f35 Mon Sep 17 00:00:00 2001
From: Amrit Krishnan
Date: Fri, 13 Sep 2024 17:09:31 -0400
Subject: [PATCH] Move timm vit encoder from bioscan project into modules (#18)

Co-authored-by: GitHub Actions
---
Review notes (text in this section is ignored by `git am`):
- The `@store` default is keyed `model_name` so that it matches the
  `TimmViT.__init__` signature; `model_name_or_path` is not a parameter of
  the class and cannot be used to build the config.
- `TimmViT.forward` applies the timm projection head to the final features,
  so `last_hidden_state` matches what the removed project-level
  implementation returned from `self.model(...)`.

 mmlearn/modules/encoders/vision.py        | 158 ++++++++++++++++++++++
 projects/bioscan_clip/configs/__init__.py |   3 +-
 projects/bioscan_clip/encoders.py         |  60 --------
 3 files changed, 160 insertions(+), 61 deletions(-)
 create mode 100644 mmlearn/modules/encoders/vision.py

diff --git a/mmlearn/modules/encoders/vision.py b/mmlearn/modules/encoders/vision.py
new file mode 100644
index 0000000..6b06a60
--- /dev/null
+++ b/mmlearn/modules/encoders/vision.py
@@ -0,0 +1,158 @@
+"""Vision encoder implementations."""
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import timm
+import torch
+from hydra_zen import store
+from peft import PeftConfig
+from timm.models.vision_transformer import VisionTransformer
+from torch import nn
+from transformers.modeling_outputs import BaseModelOutput
+
+from mmlearn import hf_utils
+from mmlearn.datasets.core.modalities import Modalities, Modality
+
+
+@store(
+    group="modules/encoders",
+    provider="mmlearn",
+    model_name="vit_base_patch16_224",
+    hydra_convert="object",
+)
+class TimmViT(nn.Module):
+    """Vision Transformer model from timm.
+
+    Parameters
+    ----------
+    model_name : str
+        The name of the model to use.
+    projection_dim : int, default=768
+        Output dimension of the projection head (passed to timm as `num_classes`).
+    pretrained : bool, default=True
+        Whether to use the pretrained weights.
+    freeze_layers : Union[int, float, List[int], bool], default=False
+        Layers to freeze: all (True), first N (int), a fraction (float) or indices.
+    freeze_layer_norm : bool, default=True
+        Whether to also freeze parameters whose name contains "norm".
+    peft_config : Optional[PeftConfig], default=None
+        The PEFT configuration.
+    model_kwargs : Optional[Dict[str, Any]], default=None
+        Additional keyword arguments for the model.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        projection_dim: int = 768,
+        pretrained: bool = True,
+        freeze_layers: Union[int, float, List[int], bool] = False,
+        freeze_layer_norm: bool = True,
+        peft_config: Optional[PeftConfig] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Initialize the Vision Transformer model."""
+        super().__init__()
+        if model_kwargs is None:
+            model_kwargs = {}
+
+        self.model: VisionTransformer = timm.create_model(
+            model_name,
+            pretrained=pretrained,
+            num_classes=projection_dim,
+            **model_kwargs,
+        )
+        assert isinstance(self.model, VisionTransformer), (
+            f"Model {model_name} is not a Vision Transformer. "
+            "Please provide a model name that corresponds to a Vision Transformer."
+        )
+
+        self._freeze_layers(freeze_layers, freeze_layer_norm)
+
+        if peft_config is not None:
+            self.model = hf_utils._wrap_peft_model(self.model, peft_config)
+
+    def _freeze_layers(
+        self, freeze_layers: Union[int, float, List[int], bool], freeze_layer_norm: bool
+    ) -> None:
+        """Freeze the layers of the model.
+
+        Parameters
+        ----------
+        freeze_layers : Union[int, float, List[int], bool]
+            Layers to freeze: all (True), first N (int), a fraction (float) or indices.
+        freeze_layer_norm : bool
+            Whether to also freeze parameters whose name contains "norm".
+        """
+        if isinstance(freeze_layers, bool) and freeze_layers:
+            for name, param in self.model.named_parameters():
+                param.requires_grad = (
+                    (not freeze_layer_norm) if "norm" in name else False
+                )
+
+        modules = [self.model.patch_embed, *self.model.blocks, self.model.norm]
+        if isinstance(freeze_layers, float):
+            freeze_layers = int(freeze_layers * len(modules))
+        if isinstance(freeze_layers, int):
+            freeze_layers = list(range(freeze_layers))
+
+        if isinstance(freeze_layers, list):
+            for idx, module in enumerate(modules):
+                if idx in freeze_layers:
+                    for name, param in module.named_parameters():
+                        param.requires_grad = (
+                            (not freeze_layer_norm) if "norm" in name else False
+                        )
+
+    def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
+        """Run the forward pass.
+
+        Parameters
+        ----------
+        inputs : Dict[str | Modality, Any]
+            The input data. The `image` will be expected under the `Modalities.RGB` key.
+
+        Returns
+        -------
+        BaseModelOutput
+            Head output in `last_hidden_state`, intermediates in `hidden_states`.
+        """
+        x = inputs[Modalities.RGB]
+        features, intermediates = self.model.forward_intermediates(x)
+
+        return BaseModelOutput(
+            last_hidden_state=self.model.forward_head(features),
+            hidden_states=intermediates,
+            attentions=None,
+        )
+
+    def get_intermediate_layers(
+        self, inputs: Dict[Union[str, Modality], Any], n: int = 1
+    ) -> List[torch.Tensor]:
+        """Get the output of the intermediate layers.
+
+        Parameters
+        ----------
+        inputs : Dict[Union[str, Modality], Any]
+            The input data. The `image` will be expected under the `Modalities.RGB` key.
+        n : int, default=1
+            The number of intermediate layers to return.
+
+        Returns
+        -------
+        List[torch.Tensor]
+            The outputs of the last n intermediate layers.
+        """
+        return self.model.get_intermediate_layers(inputs[Modalities.RGB], n)  # type: ignore
+
+    def get_patch_info(self) -> Tuple[int, int]:
+        """Get patch size and number of patches.
+
+        Returns
+        -------
+        Tuple[int, int]
+            Patch size and number of patches.
+        """
+        patch_size = self.model.patch_embed.patch_size[0]
+        num_patches = self.model.patch_embed.num_patches
+        return patch_size, num_patches
diff --git a/projects/bioscan_clip/configs/__init__.py b/projects/bioscan_clip/configs/__init__.py
index 2162684..4324caf 100644
--- a/projects/bioscan_clip/configs/__init__.py
+++ b/projects/bioscan_clip/configs/__init__.py
@@ -7,8 +7,9 @@
 
 from mmlearn.conf import external_store
 from mmlearn.modules.encoders.hf_text_encoders import HFTextEncoder
+from mmlearn.modules.encoders.vision import TimmViT
 
-from projects.bioscan_clip.encoders import TimmViT, BarcodeBERT
+from projects.bioscan_clip.encoders import BarcodeBERT
 from projects.bioscan_clip.dataset import BIOSCANInsectDataset
 from projects.bioscan_clip.eval_task import TaxonomicClassification
 from projects.bioscan_clip.dna_tokenizer import DNAProcessor
diff --git a/projects/bioscan_clip/encoders.py b/projects/bioscan_clip/encoders.py
index 366463b..6b83fee 100644
--- a/projects/bioscan_clip/encoders.py
+++ b/projects/bioscan_clip/encoders.py
@@ -2,8 +2,6 @@
 import warnings
 
 from peft import PeftConfig
-import timm
-from timm.models.vision_transformer import VisionTransformer
 import torch
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutput
@@ -13,64 +11,6 @@
 from mmlearn.datasets.core.modalities import Modalities, Modality
 
 
-class TimmViT(nn.Module):
-    def __init__(
-        self,
-        model_name: str,
-        projection_dim: int = 768,
-        pretrained: bool = True,
-        freeze_layers: Union[int, float, List[int], bool] = False,
-        freeze_layer_norm: bool = True,
-        peft_config: Optional[PeftConfig] = None,
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> None:
-        super().__init__()
-        if model_kwargs is None:
-            model_kwargs = {}
-
-        model: nn.Module = timm.create_model(
-            model_name,
-            pretrained=pretrained,
-            num_classes=projection_dim,
-            **model_kwargs,
-        )
-        assert isinstance(model, VisionTransformer), (
-            f"Model {model_name} is not a Vision Transformer. "
-            "Please provide a model name that corresponds to a Vision Transformer."
-        )
-
-        if isinstance(freeze_layers, bool) and freeze_layers:
-            for name, param in model.named_parameters():
-                param.requires_grad = (
-                    (not freeze_layer_norm) if "norm" in name else False
-                )
-
-        modules = [model.patch_embed, *model.blocks, model.norm]
-        if isinstance(freeze_layers, float):
-            freeze_layers = int(freeze_layers * len(modules))
-        if isinstance(freeze_layers, int):
-            freeze_layers = list(range(freeze_layers))
-
-        if isinstance(freeze_layers, list):
-            for idx, module in enumerate(modules):
-                if idx in freeze_layers:
-                    for name, param in module.named_parameters():
-                        param.requires_grad = (
-                            (not freeze_layer_norm) if "norm" in name else False
-                        )
-
-        if peft_config is not None:
-            model = hf_utils._wrap_peft_model(model, peft_config)
-
-        self.model = model
-
-    def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
-        """Run the forward pass."""
-        outputs = self.model(inputs[Modalities.RGB])
-
-        return BaseModelOutput(last_hidden_state=outputs)
-
-
 class BarcodeBERT(nn.Module):
     def __init__(
         self,