diff --git a/mmlearn/modules/encoders/vision.py b/mmlearn/modules/encoders/vision.py
index 94bb8d5..669d7f0 100644
--- a/mmlearn/modules/encoders/vision.py
+++ b/mmlearn/modules/encoders/vision.py
@@ -13,7 +13,7 @@
 from transformers.modeling_outputs import BaseModelOutput
 
 from mmlearn import hf_utils
-from mmlearn.datasets.core.modalities import Modalities, Modality
+from mmlearn.datasets.core.modalities import Modalities
 from mmlearn.datasets.processors.masking import apply_masks
 from mmlearn.datasets.processors.transforms import (
     repeat_interleave_batch,
@@ -137,13 +137,13 @@ def forward(self, inputs: Dict[str, Any]) -> BaseModelOutput:
         )
 
     def get_intermediate_layers(
-        self, inputs: Dict[Union[str, Modality], Any], n: int = 1
+        self, inputs: Dict[str, Any], n: int = 1
     ) -> List[torch.Tensor]:
         """Get the output of the intermediate layers.
 
         Parameters
         ----------
-        inputs : Dict[Union[str, Modality], Any]
+        inputs : Dict[str, Any]
             The input data. The `image` will be expected under the `Modalities.RGB` key.
         n : int, default=1
             The number of intermediate layers to return.
diff --git a/mmlearn/tasks/contrastive_pretraining.py b/mmlearn/tasks/contrastive_pretraining.py
index b91296f..04146c7 100644
--- a/mmlearn/tasks/contrastive_pretraining.py
+++ b/mmlearn/tasks/contrastive_pretraining.py
@@ -497,7 +497,7 @@ def validation_step(
 
         Parameters
         ----------
-        batch : Dict[Union[str, Modality], Any]
+        batch : Dict[str, torch.Tensor]
             The batch of data to process.
         batch_idx : int
             The index of the batch.
@@ -524,7 +524,7 @@ def test_step(
 
         Parameters
         ----------
-        batch : Dict[Union[str, Modality], Any]
+        batch : Dict[str, torch.Tensor]
             The batch of data to process.
         batch_idx : int
             The index of the batch.
@@ -644,7 +644,7 @@ def _shared_eval_step(
 
         Parameters
         ----------
-        batch : Dict[Union[str, Modality], Any]
+        batch : Dict[str, torch.Tensor]
             The batch of data to process.
         batch_idx : int
             The index of the batch.
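
The diff narrows input/batch dictionary keys from `Union[str, Modality]` to plain `str`. A minimal usage sketch of the updated `get_intermediate_layers` signature, under two assumptions: the `encoder` instance and `build_vision_encoder` helper are hypothetical stand-ins, and `Modalities.RGB.name` is assumed to resolve to the string key the docstring's `Modalities.RGB` reference points at.

    import torch

    from mmlearn.datasets.core.modalities import Modalities

    # Hypothetical: any vision encoder exposing get_intermediate_layers.
    encoder = build_vision_encoder()  # assumed helper, not part of the diff

    # Keys are now plain strings; Modalities.RGB.name is assumed to
    # yield the registered string key for the RGB modality.
    inputs = {Modalities.RGB.name: torch.randn(2, 3, 224, 224)}

    # Outputs of the last two transformer blocks.
    layers = encoder.get_intermediate_layers(inputs, n=2)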