Skip to content

Commit

Permalink
Update imports and type annotations in vision.py and contrastive_pret…
Browse files Browse the repository at this point in the history
…raining.py
  • Loading branch information
fcogidi committed Nov 4, 2024
1 parent c8e6afb commit f1abf92
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions mmlearn/modules/encoders/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from transformers.modeling_outputs import BaseModelOutput

from mmlearn import hf_utils
from mmlearn.datasets.core.modalities import Modalities, Modality
from mmlearn.datasets.core.modalities import Modalities
from mmlearn.datasets.processors.masking import apply_masks
from mmlearn.datasets.processors.transforms import (
repeat_interleave_batch,
Expand Down Expand Up @@ -137,13 +137,13 @@ def forward(self, inputs: Dict[str, Any]) -> BaseModelOutput:
)

def get_intermediate_layers(
self, inputs: Dict[Union[str, Modality], Any], n: int = 1
self, inputs: Dict[str, Any], n: int = 1
) -> List[torch.Tensor]:
"""Get the output of the intermediate layers.
Parameters
----------
inputs : Dict[Union[str, Modality], Any]
inputs : Dict[str, Any]
The input data. The `image` will be expected under the `Modalities.RGB` key.
n : int, default=1
The number of intermediate layers to return.
Expand Down
6 changes: 3 additions & 3 deletions mmlearn/tasks/contrastive_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def validation_step(
Parameters
----------
batch : Dict[Union[str, Modality], Any]
batch : Dict[str, torch.Tensor]
The batch of data to process.
batch_idx : int
The index of the batch.
Expand All @@ -524,7 +524,7 @@ def test_step(
Parameters
----------
batch : Dict[Union[str, Modality], Any]
batch : Dict[str, torch.Tensor]
The batch of data to process.
batch_idx : int
The index of the batch.
Expand Down Expand Up @@ -644,7 +644,7 @@ def _shared_eval_step(
Parameters
----------
batch : Dict[Union[str, Modality], Any]
batch : Dict[str, torch.Tensor]
The batch of data to process.
batch_idx : int
The index of the batch.
Expand Down

0 comments on commit f1abf92

Please sign in to comment.