Skip to content

Commit

Permalink
added extract_features.py to master branch
Browse files Browse the repository at this point in the history
  • Loading branch information
clemsgrs committed Mar 7, 2024
1 parent 6f17006 commit 48a2332
Show file tree
Hide file tree
Showing 7 changed files with 312 additions and 25 deletions.
23 changes: 23 additions & 0 deletions dino/config/features.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Configuration for patch-level feature extraction (dino/extract_features.py).

# Root folder of downstream patch images, read by ImageFolderWithNameDataset.
data_dir: '/data/pathology/projects/pathology-vlfm/data/downstream_patches_panda_fold_0'

# Outputs land under <output_dir>/<experiment_name>/<run_id>/features.
output_dir: 'output'
experiment_name: 'feature_extraction'

# Input image size and ViT mini-patch size forwarded to PatchEmbedder.
img_size: 256
patch_size: 16

# DataLoader settings; num_workers is further capped at runtime by CPU count
# and, when present, the SLURM_JOB_CPUS_PER_NODE allocation.
num_workers: 4
batch_size: 1

# Pretrained patch-level ViT checkpoint to load into PatchEmbedder.
pretrain_vit_patch: '/data/pathology/projects/ais-cap/dataset/panda/hipt/dino/5-fold/vit_256_small_dino_fold_0.pt'
# Left empty here; presumably the image size the checkpoint was pretrained at.
# NOTE(review): semantics inferred from the name — confirm in PatchEmbedder.
img_size_pretrained:

# Weights & Biases logging (disabled by default; run id is derived from the
# timestamp when disabled).
wandb:
  enable: false
  project: 'dino'
  username: 'vlfm'
  exp_name: '${experiment_name}'
  tags:
  dir: '/home/user/'
  group:
  resume_id:
7 changes: 6 additions & 1 deletion dino/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from .dataset import ImagePretrainingDataset, HierarchicalPretrainingDataset
from .augmentations import PatchDataAugmentationDINO, RegionDataAugmentationDINO
from .datasets import ImageFolderWithNameDataset
from .augmentations import (
PatchDataAugmentationDINO,
RegionDataAugmentationDINO,
make_classification_eval_transform,
)
68 changes: 62 additions & 6 deletions dino/data/augmentations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,38 @@
import torch
import random

from torchvision import transforms
from PIL import ImageFilter, ImageOps
from typing import Sequence


# Use timm's names
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """Build a torchvision ``Normalize`` transform for the given channel stats."""
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize


class MaybeToTensor(transforms.ToTensor):
    """``ToTensor`` variant that is a no-op for inputs that are already tensors."""

    def __call__(self, pic):
        """Convert ``pic`` to a tensor unless it already is one.

        Args:
            pic (PIL Image, numpy.ndarray or torch.Tensor): image to convert.

        Returns:
            Tensor: the converted (or passed-through) image.
        """
        if not isinstance(pic, torch.Tensor):
            return super().__call__(pic)
        return pic


class GaussianBlur(object):
Expand Down Expand Up @@ -42,7 +73,15 @@ def __call__(self, img):


class PatchDataAugmentationDINO(object):
def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
def __init__(
self,
global_crops_scale,
local_crops_scale,
local_crops_number,
interpolation=transforms.InterpolationMode.BICUBIC,
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
std: Sequence[float] = IMAGENET_DEFAULT_STD,
):
flip_and_color_jitter = transforms.Compose(
[
transforms.RandomHorizontalFlip(p=0.5),
Expand All @@ -59,8 +98,8 @@ def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
)
normalize = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
MaybeToTensor(),
make_normalize_transform(mean=mean, std=std),
]
)

Expand All @@ -73,7 +112,7 @@ def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
transforms.RandomResizedCrop(
global_crop_size,
scale=global_crops_scale,
interpolation=transforms.InterpolationMode.BICUBIC,
interpolation=interpolation,
),
flip_and_color_jitter,
GaussianBlur(1.0),
Expand All @@ -86,7 +125,7 @@ def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
transforms.RandomResizedCrop(
global_crop_size,
scale=global_crops_scale,
interpolation=transforms.InterpolationMode.BICUBIC,
interpolation=interpolation,
),
flip_and_color_jitter,
GaussianBlur(0.1),
Expand All @@ -101,7 +140,7 @@ def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
transforms.RandomResizedCrop(
local_crop_size,
scale=local_crops_scale,
interpolation=transforms.InterpolationMode.BICUBIC,
interpolation=interpolation,
),
flip_and_color_jitter,
GaussianBlur(p=0.5),
Expand Down Expand Up @@ -170,3 +209,20 @@ def __call__(self, x):
for _ in range(self.local_crops_number):
crops.append(self.local_transfo(x))
return crops


def make_classification_eval_transform(
    *,
    resize_size: int = 256,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 224,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """Standard eval-time pipeline: resize, center-crop, to-tensor, normalize."""
    return transforms.Compose(
        [
            transforms.Resize(resize_size, interpolation=interpolation),
            transforms.CenterCrop(crop_size),
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    )
19 changes: 1 addition & 18 deletions dino/eval_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,7 @@
import dino.models.vision_transformer as vits
from dino.data import ImagePretrainingDataset
from dino.log import initialize_wandb


def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True


def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()


def is_main_process():
return get_rank() == 0
from dino.distributed import is_main_process


class ReturnIndexDataset(ImagePretrainingDataset):
Expand Down
155 changes: 155 additions & 0 deletions dino/extract_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import os
import tqdm
import wandb
import torch
import hydra
import datetime
import pandas as pd
import multiprocessing as mp

from pathlib import Path
from omegaconf import DictConfig

from dino.models import PatchEmbedder
from dino.log import initialize_wandb
from dino.distributed import is_main_process
from dino.data import ImageFolderWithNameDataset, make_classification_eval_transform


@hydra.main(
    version_base="1.2.0", config_path="config/feature_extraction", config_name="patch"
)
def main(cfg: DictConfig) -> None:
    """Extract patch-level features with a pretrained ViT and save them to disk.

    Runs single-GPU or multi-GPU (torchrun/NCCL). Each rank writes one ``.pt``
    feature file per image plus a per-rank CSV manifest; rank 0 then merges the
    per-rank CSVs into a single ``features.csv``.
    """
    # More than one visible GPU implies a torchrun launch with LOCAL_RANK set.
    run_distributed = torch.cuda.device_count() > 1
    if run_distributed:
        torch.distributed.init_process_group(backend="nccl")
        gpu_id = int(os.environ["LOCAL_RANK"])
        if gpu_id == 0:
            print("Distributed session successfully initialized")
    else:
        # Sentinel meaning "single process, use the default cuda device".
        gpu_id = -1

    if is_main_process():
        print(f"torch.cuda.device_count(): {torch.cuda.device_count()}")
        # Timestamp run id is the fallback when wandb is disabled.
        run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
        # set up wandb
        if cfg.wandb.enable:
            key = os.environ.get("WANDB_API_KEY")
            wandb_run = initialize_wandb(cfg, key=key)
            wandb_run.define_metric("processed", summary="max")
            run_id = wandb_run.id
    else:
        # Non-main ranks get the real id via the broadcast below.
        run_id = ""

    if run_distributed:
        # Share rank 0's run_id with every rank so all write to the same dirs.
        obj = [run_id]
        torch.distributed.broadcast_object_list(
            obj, 0, device=torch.device(f"cuda:{gpu_id}")
        )
        run_id = obj[0]

    output_dir = Path(cfg.output_dir, cfg.experiment_name, run_id)
    features_dir = Path(output_dir, "features")
    if is_main_process():
        if output_dir.exists():
            # NOTE(review): message says "deleting" but nothing is deleted here —
            # the directory is reused as-is. Confirm intent (delete vs. reuse).
            print(f"{output_dir} already exists! deleting it...")
        output_dir.mkdir(parents=True, exist_ok=True)
        features_dir.mkdir(exist_ok=True)

    model = PatchEmbedder(
        img_size=cfg.img_size,
        mini_patch_size=cfg.patch_size,
        pretrain_vit_patch=cfg.pretrain_vit_patch,
        # Only rank 0 (or the single process) prints checkpoint-loading info.
        verbose=(gpu_id in [-1, 0]),
        img_size_pretrained=cfg.img_size_pretrained,
    )

    transform = make_classification_eval_transform()
    dataset = ImageFolderWithNameDataset(cfg.data_dir, transform)

    if run_distributed:
        # DistributedSampler shards the dataset so each rank sees a disjoint subset.
        # NOTE(review): shuffle=True for deterministic feature extraction looks
        # unnecessary — order does not matter for saved files, but confirm.
        sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    else:
        sampler = torch.utils.data.RandomSampler(dataset)

    # Cap workers by physical CPUs and, on SLURM, by the job's CPU allocation.
    num_workers = min(mp.cpu_count(), cfg.num_workers)
    if "SLURM_JOB_CPUS_PER_NODE" in os.environ:
        num_workers = min(num_workers, int(os.environ["SLURM_JOB_CPUS_PER_NODE"]))

    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=cfg.batch_size,
        num_workers=num_workers,
        shuffle=False,  # shuffling (if any) is handled by the sampler
        drop_last=False,
    )

    if gpu_id == -1:
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{gpu_id}")
    model = model.to(device, non_blocking=True)

    if is_main_process():
        print()

    # Per-rank manifest of saved feature files.
    filenames, feature_paths = [], []

    with tqdm.tqdm(
        loader,
        desc="Feature Extraction",
        unit=" img",
        ncols=80,
        unit_scale=cfg.batch_size,
        position=0,
        leave=True,
        # Only the main rank (or single process) renders the progress bar.
        disable=not (gpu_id in [-1, 0]),
    ) as t1:
        with torch.no_grad():
            for i, batch in enumerate(t1):
                imgs, fnames = batch
                imgs = imgs.to(device, non_blocking=True)
                features = model(imgs)
                # Save one feature tensor per image, keyed by its filename.
                for k, f in enumerate(features):
                    fname = fnames[k]
                    feature_path = Path(features_dir, f"{fname}.pt")
                    torch.save(f, feature_path)
                    filenames.append(fname)
                    feature_paths.append(feature_path)
                if cfg.wandb.enable and not run_distributed:
                    # NOTE(review): `i` is the batch index, so this is not a
                    # running sample count (looks like it should be something
                    # like i * cfg.batch_size + imgs.shape[0]) — confirm.
                    wandb.log({"processed": i + imgs.shape[0]})

    features_df = pd.DataFrame.from_dict(
        {
            "filename": filenames,
            "feature_path": feature_paths,
        }
    )

    # Each rank writes its own CSV; the single-process run writes the final name directly.
    if run_distributed:
        features_csv_path = Path(output_dir, f"features_{gpu_id}.csv")
    else:
        features_csv_path = Path(output_dir, "features.csv")
    features_df.to_csv(features_csv_path, index=False)

    if run_distributed:
        # Wait for every rank to finish writing before rank 0 merges the CSVs.
        torch.distributed.barrier()
        if is_main_process():
            dfs = []
            # NOTE(review): this loop variable shadows the outer `gpu_id`
            # (rank id) — harmless here since it is last-used, but fragile.
            for gpu_id in range(torch.cuda.device_count()):
                fp = Path(output_dir, f"features_{gpu_id}.csv")
                df = pd.read_csv(fp)
                dfs.append(df)
                os.remove(fp)
            features_df = pd.concat(dfs, ignore_index=True)
            features_df = features_df.drop_duplicates()
            features_df.to_csv(Path(output_dir, "features.csv"), index=False)

    if cfg.wandb.enable and is_main_process() and run_distributed:
        wandb.log({"processed": len(features_df)})


if __name__ == "__main__":
    # Multi-GPU launch example:
    # python3 -m torch.distributed.run --standalone --nproc_per_node=gpu dino/extract_features.py --config-name 'patch'
    main()
Loading

0 comments on commit 48a2332

Please sign in to comment.