Add projects package #3

Merged: 1 commit, Aug 19, 2024
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -58,3 +58,5 @@ repos:
entry: python3 -m pytest -m "not integration_test"
pass_filenames: false
always_run: true

exclude: "projects"
10 changes: 0 additions & 10 deletions mmlearn/datasets/__init__.py
@@ -5,13 +5,8 @@
from mmlearn.datasets.imagenet import ImageNet
from mmlearn.datasets.librispeech import LibriSpeech
from mmlearn.datasets.llvip import LLVIPDataset
from mmlearn.datasets.medvqa import MedVQA
from mmlearn.datasets.mimiciv_cxr import MIMICIVCXR
from mmlearn.datasets.nihcxr import NIHCXR
from mmlearn.datasets.nyuv2 import NYUv2Dataset
from mmlearn.datasets.pmcoa import PMCOA
from mmlearn.datasets.quilt import Quilt
from mmlearn.datasets.roco import ROCO
from mmlearn.datasets.sunrgbd import SUNRGBDDataset


@@ -21,12 +16,7 @@
"ImageNet",
"LibriSpeech",
"LLVIPDataset",
"MedVQA",
"MIMICIVCXR",
"NIHCXR",
"NYUv2Dataset",
"PMCOA",
"Quilt",
"ROCO",
"SUNRGBDDataset",
]
8 changes: 1 addition & 7 deletions mmlearn/datasets/processors/__init__.py
@@ -5,18 +5,12 @@
RandomMaskGenerator,
)
from mmlearn.datasets.processors.tokenizers import HFTokenizer
from mmlearn.datasets.processors.transforms import (
MedVQAProcessor,
TrimText,
med_clip_vision_transform,
)
from mmlearn.datasets.processors.transforms import TrimText


__all__ = [
"BlockwiseImagePatchMaskGenerator",
"HFTokenizer",
"MedVQAProcessor",
"RandomMaskGenerator",
"TrimText",
"med_clip_vision_transform",
]
73 changes: 1 addition & 72 deletions mmlearn/datasets/processors/transforms.py
@@ -1,10 +1,8 @@
"""Custom transforms for datasets."""

from typing import List, Literal, Union
from typing import List, Union

from hydra_zen import store
from timm.data.transforms import ResizeKeepRatio
from torchvision import transforms


@store(group="datasets/transforms", provider="mmlearn")
@@ -30,72 +28,3 @@ def __call__(self, sentence: Union[str, List[str]]) -> Union[str, List[str]]:
sentence[i] = s[: self.trim_size]

return sentence


class MedVQAProcessor:
"""Preprocessor for textual reports of MedVQA datasets."""

def __call__(self, sentence: Union[str, List[str]]) -> Union[str, List[str]]:
"""Process the textual captions."""
if not isinstance(sentence, (list, str)):
raise TypeError(
f"Expected sentence to be a string or list of strings, got {type(sentence)}"
)

def _preprocess_sentence(sentence: str) -> str:
sentence = sentence.lower()
if "? -yes/no" in sentence:
sentence = sentence.replace("? -yes/no", "")
if "? -open" in sentence:
sentence = sentence.replace("? -open", "")
if "? - open" in sentence:
sentence = sentence.replace("? - open", "")
return (
sentence.replace(",", "")
.replace("?", "")
.replace("'s", " 's")
.replace("...", "")
.replace("x ray", "x-ray")
.replace(".", "")
)

if isinstance(sentence, str):
return _preprocess_sentence(sentence)

for i, s in enumerate(sentence):
sentence[i] = _preprocess_sentence(s)

return sentence


@store(group="datasets/transforms", provider="mmlearn") # type: ignore[misc]
def med_clip_vision_transform(
image_crop_size: int = 224, job_type: Literal["train", "eval"] = "train"
) -> transforms.Compose:
"""Return transforms for training/evaluating CLIP with medical images.

Parameters
----------
image_crop_size : int, default=224
Size of the image crop.
job_type : {"train", "eval"}, default="train"
Type of the job (training or evaluation) for which the transforms are needed.

Returns
-------
transforms.Compose
Composed transforms for training CLIP with medical images.
"""
return transforms.Compose(
[
ResizeKeepRatio(512, interpolation="bicubic"),
transforms.RandomCrop(image_crop_size)
if job_type == "train"
else transforms.CenterCrop(image_crop_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
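For reference, a minimal usage sketch of the `TrimText` transform that remains in this module. The `trim_size` constructor argument is an assumption inferred from the `self.trim_size` attribute visible in `__call__` above; the constructor itself is not shown in this diff.

```python
from mmlearn.datasets.processors.transforms import TrimText

# Assumed constructor signature: TrimText(trim_size=...), inferred from the
# self.trim_size attribute used in __call__ above.
trim = TrimText(trim_size=77)

# Strings (and each element of a list of strings) are truncated to trim_size characters.
print(trim("a caption that may be much longer than the text encoder's context window"))
print(trim(["first caption", "second caption"]))
```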
Empty file added projects/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions projects/med_benchmarking/README.md
@@ -0,0 +1,33 @@
## Benchmarking CLIP-style Methods on Medical Data
Prior to running any experiments under this project, please install the required dependencies by running the following command:
```bash
pip install -r requirements.txt
```
**NOTE**: It is assumed that the requirements for the `mmlearn` package have already been installed in a virtual environment.
If not, please refer to the README file in the `mmlearn` package for installation instructions.

Also, please make sure to set the following environment variables:
```bash
export MIMICIVCXR_ROOT_DIR=/path/to/mimic-cxr/data
export PMCOA_ROOT_DIR=/path/to/pmc_oa/data
export QUILT_ROOT_DIR=/path/to/quilt/data
export ROCO_ROOT_DIR=/path/to/roco/data
```

If you are running an experiment with the MedVQA dataset, please also set the following environment variables:
```bash
export PATHVQA_ROOT_DIR=/path/to/pathvqa/data
export VQARAD_ROOT_DIR=/path/to/vqarad/data
```

To run an experiment, use one of the following commands:

**To Run Locally**:
```bash
mmlearn_run 'hydra.searchpath=[pkg://projects.med_benchmarking.configs]' +experiment=baseline experiment_name=test
```

**To Run on a SLURM Cluster**:
```bash
mmlearn_run --multirun \
  hydra.launcher.mem_gb=32 \
  hydra.launcher.qos=your_qos \
  hydra.launcher.partition=your_partition \
  hydra.launcher.gres=gpu:4 \
  hydra.launcher.cpus_per_task=8 \
  hydra.launcher.tasks_per_node=4 \
  hydra.launcher.nodes=1 \
  hydra.launcher.stderr_to_stdout=true \
  hydra.launcher.timeout_min=60 \
  '+hydra.launcher.additional_parameters={export: ALL}' \
  'hydra.searchpath=[pkg://projects.med_benchmarking.configs]' \
  +experiment=baseline experiment_name=test
```
82 changes: 82 additions & 0 deletions projects/med_benchmarking/configs/__init__.py
@@ -0,0 +1,82 @@
import os
from typing import Literal

from hydra_zen import builds, store
from omegaconf import MISSING
from timm.data.transforms import ResizeKeepRatio
from torchvision import transforms

from mmlearn.conf import external_store
from projects.med_benchmarking.datasets.medvqa import MedVQA, MedVQAProcessor
from projects.med_benchmarking.datasets.mimiciv_cxr import MIMICIVCXR
from projects.med_benchmarking.datasets.pmcoa import PMCOA
from projects.med_benchmarking.datasets.quilt import Quilt
from projects.med_benchmarking.datasets.roco import ROCO


_MedVQAConf = builds(
MedVQA,
split="train",
encoder={"image_size": 224, "feat_dim": 512, "images_filename": "images_clip.pkl"},
autoencoder={
"available": True,
"image_size": 128,
"feat_dim": 64,
"images_filename": "images128x128.pkl",
},
num_ans_candidates=MISSING,
)
_PathVQAConf = builds(
MedVQA,
root_dir=os.getenv("PATHVQA_ROOT_DIR", MISSING),
num_ans_candidates=3974,
autoencoder={"available": False},
builds_bases=(_MedVQAConf,),
)
_VQARADConf = builds(
MedVQA,
root_dir=os.getenv("VQARAD_ROOT_DIR", MISSING),
num_ans_candidates=458,
autoencoder={"available": False},
builds_bases=(_MedVQAConf,),
)
external_store(_MedVQAConf, name="MedVQA", group="datasets")
external_store(_PathVQAConf, name="PathVQA", group="datasets")
external_store(_VQARADConf, name="VQARAD", group="datasets")

external_store(MedVQAProcessor, name="MedVQAProcessor", group="datasets/transforms")


@external_store(group="datasets/transforms")
def med_clip_vision_transform(
image_crop_size: int = 224, job_type: Literal["train", "eval"] = "train"
) -> transforms.Compose:
"""Return transforms for training/evaluating CLIP with medical images.

Parameters
----------
image_crop_size : int, default=224
Size of the image crop.
job_type : {"train", "eval"}, default="train"
Type of the job (training or evaluation) for which the transforms are needed.

Returns
-------
transforms.Compose
Composed transforms for training CLIP with medical images.
"""
return transforms.Compose(
[
ResizeKeepRatio(
512 if job_type == "train" else image_crop_size, interpolation="bicubic"
),
transforms.RandomCrop(image_crop_size)
if job_type == "train"
else transforms.CenterCrop(image_crop_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
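A short usage sketch for the transform registered above; the image path is hypothetical, and any RGB PIL image would do.

```python
from PIL import Image

from projects.med_benchmarking.configs import med_clip_vision_transform

# Evaluation pipeline: resize the image to 224 while keeping its aspect ratio,
# center-crop to 224x224, convert to a tensor, and normalize with CLIP statistics.
eval_transform = med_clip_vision_transform(image_crop_size=224, job_type="eval")

image = Image.open("chest_xray.png").convert("RGB")  # hypothetical input file
tensor = eval_transform(image)  # tensor of shape (3, 224, 224)
```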
@@ -69,7 +69,7 @@ task:
eps: 1.0e-6
lr_scheduler:
scheduler:
T_max: 537_775
T_max: 107_555 # make sure to change this if max_epochs or accumulate_grad_batches is changed
extras:
interval: step
loss:
@@ -80,15 +80,15 @@
task_specs:
- query_modality: text
target_modality: rgb
top_k: [200]
top_k: [10, 200]
- query_modality: rgb
target_modality: text
top_k: [200]
top_k: [10, 200]
run_on_validation: false
run_on_test: true

trainer:
max_epochs: 100
max_epochs: 20
precision: 16-mixed
deterministic: False
benchmark: True
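The `T_max` comment above ties the scheduler horizon to the rest of the training schedule. As a rough sanity check (an assumption that `T_max` counts optimizer steps and scales linearly with `max_epochs`; dataset size, batch size, and device count are not shown here), the new value matches scaling the previous 537_775-step horizon by the change of `max_epochs` from 100 to 20:

```python
# Back-of-the-envelope check, assuming T_max scales linearly with max_epochs.
old_t_max, old_epochs, new_epochs = 537_775, 100, 20
new_t_max = old_t_max * new_epochs // old_epochs
print(new_t_max)  # 107555, the value set in the updated config
```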
Empty file.
projects/med_benchmarking/datasets/medvqa.py
@@ -3,11 +3,10 @@
import json
import os
import warnings
from typing import Any, Callable, Dict, Literal, Optional
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import numpy as np
import torch
from hydra_zen import MISSING, builds, store
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import CenterCrop, Compose, Grayscale, Resize, ToTensor
@@ -205,32 +204,37 @@ def __len__(self) -> int:
return len(self.entries)


_MedVQAConf = builds(
MedVQA,
split="train",
encoder={"image_size": 224, "feat_dim": 512, "images_filename": "images_clip.pkl"},
autoencoder={
"available": True,
"image_size": 128,
"feat_dim": 64,
"images_filename": "images128x128.pkl",
},
num_ans_candidates=MISSING,
)
_PathVQAConf = builds(
MedVQA,
root_dir=os.getenv("PATHVQA_ROOT_DIR", MISSING),
num_ans_candidates=3974,
autoencoder={"available": False},
builds_bases=(_MedVQAConf,),
)
_VQARADConf = builds(
MedVQA,
root_dir=os.getenv("VQARAD_ROOT_DIR", MISSING),
num_ans_candidates=458,
autoencoder={"available": False},
builds_bases=(_MedVQAConf,),
)
store(_MedVQAConf, name="MedVQA", group="datasets", provider="mmlearn")
store(_PathVQAConf, name="PathVQA", group="datasets", provider="mmlearn")
store(_VQARADConf, name="VQARAD", group="datasets", provider="mmlearn")
class MedVQAProcessor:
"""Preprocessor for textual reports of MedVQA datasets."""

def __call__(self, sentence: Union[str, List[str]]) -> Union[str, List[str]]:
"""Process the textual captions."""
if not isinstance(sentence, (list, str)):
raise TypeError(
f"Expected sentence to be a string or list of strings, got {type(sentence)}"
)

def _preprocess_sentence(sentence: str) -> str:
sentence = sentence.lower()
if "? -yes/no" in sentence:
sentence = sentence.replace("? -yes/no", "")
if "? -open" in sentence:
sentence = sentence.replace("? -open", "")
if "? - open" in sentence:
sentence = sentence.replace("? - open", "")
return (
sentence.replace(",", "")
.replace("?", "")
.replace("'s", " 's")
.replace("...", "")
.replace("x ray", "x-ray")
.replace(".", "")
)

if isinstance(sentence, str):
return _preprocess_sentence(sentence)

for i, s in enumerate(sentence):
sentence[i] = _preprocess_sentence(s)

return sentence
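A quick usage sketch of the processor defined above, with made-up questions, showing how answer-type suffixes and punctuation are stripped:

```python
from projects.med_benchmarking.datasets.medvqa import MedVQAProcessor

processor = MedVQAProcessor()

# Suffixes such as "? -yes/no" and "? -open" mark the answer type in MedVQA-style
# datasets; they are removed along with commas, question marks, and periods.
print(processor("Is there a fracture in the patient's x ray? -yes/no"))
# is there a fracture in the patient 's x-ray

print(processor(["What does the x ray show? -open", "Any effusion?"]))
# ['what does the x-ray show', 'any effusion']
```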
projects/med_benchmarking/datasets/mimiciv_cxr.py
@@ -8,12 +8,13 @@
import numpy as np
import pandas as pd
import torch
from hydra_zen import MISSING, store
from omegaconf import MISSING
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor
from tqdm import tqdm

from mmlearn.conf import external_store
from mmlearn.constants import EXAMPLE_INDEX_KEY
from mmlearn.datasets.core import Modalities
from mmlearn.datasets.core.example import Example
@@ -22,9 +23,8 @@
logger = logging.getLogger(__name__)


@store(
@external_store(
group="datasets",
provider="mmlearn",
root_dir=os.getenv("MIMICIVCXR_ROOT_DIR", MISSING),
split="train",
labeler="double_image",
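The change above swaps hydra_zen's `store` for mmlearn's `external_store` and drops the explicit `provider` argument. One plausible reading (an assumption, since `mmlearn.conf` is not part of this diff) is that `external_store` is a pre-configured hydra_zen store with the provider already applied, roughly along these lines:

```python
from hydra_zen import store

# Assumed definition, for illustration only: hydra_zen stores support
# self-partialing, so this would bake provider="mmlearn" into every entry
# registered through external_store.
external_store = store(provider="mmlearn")
```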