-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added extract_features.py to master branch
- Loading branch information
Showing
7 changed files
with
312 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
data_dir: '/data/pathology/projects/pathology-vlfm/data/downstream_patches_panda_fold_0' | ||
|
||
output_dir: 'output' | ||
experiment_name: 'feature_extraction' | ||
|
||
img_size: 256 | ||
patch_size: 16 | ||
|
||
num_workers: 4 | ||
batch_size: 1 | ||
|
||
pretrain_vit_patch: '/data/pathology/projects/ais-cap/dataset/panda/hipt/dino/5-fold/vit_256_small_dino_fold_0.pt' | ||
img_size_pretrained: | ||
|
||
wandb: | ||
enable: false | ||
project: 'dino' | ||
username: 'vlfm' | ||
exp_name: '${experiment_name}' | ||
tags: | ||
dir: '/home/user/' | ||
group: | ||
resume_id: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,7 @@ | ||
from .dataset import ImagePretrainingDataset, HierarchicalPretrainingDataset | ||
from .augmentations import PatchDataAugmentationDINO, RegionDataAugmentationDINO | ||
from .datasets import ImageFolderWithNameDataset | ||
from .augmentations import ( | ||
PatchDataAugmentationDINO, | ||
RegionDataAugmentationDINO, | ||
make_classification_eval_transform, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import os | ||
import tqdm | ||
import wandb | ||
import torch | ||
import hydra | ||
import datetime | ||
import pandas as pd | ||
import multiprocessing as mp | ||
|
||
from pathlib import Path | ||
from omegaconf import DictConfig | ||
|
||
from dino.models import PatchEmbedder | ||
from dino.log import initialize_wandb | ||
from dino.distributed import is_main_process | ||
from dino.data import ImageFolderWithNameDataset, make_classification_eval_transform | ||
|
||
|
||
@hydra.main(
    version_base="1.2.0", config_path="config/feature_extraction", config_name="patch"
)
def main(cfg: DictConfig):
    """Extract one feature per image with a pretrained patch embedder.

    Every image found under ``cfg.data_dir`` is passed through ``PatchEmbedder``
    and the resulting feature is saved as ``<output_dir>/features/<name>.pt``.
    A ``features.csv`` index (columns ``filename`` and ``feature_path``) is
    written to ``<output_dir>``, where ``<output_dir>`` is
    ``cfg.output_dir / cfg.experiment_name / run_id`` and ``run_id`` is the
    wandb run id (or empty when wandb is disabled).

    When more than one CUDA device is visible the script expects to be started
    through a distributed launcher (it reads ``LOCAL_RANK``); each process then
    writes its own ``features_<local_rank>.csv`` and the main process merges
    them into ``features.csv``.

    Args:
        cfg: Hydra configuration. Keys read here: ``data_dir``, ``output_dir``,
            ``experiment_name``, ``img_size``, ``patch_size``, ``num_workers``,
            ``batch_size``, ``pretrain_vit_patch``, ``img_size_pretrained`` and
            ``wandb.enable``.
    """
    run_distributed = torch.cuda.device_count() > 1
    if run_distributed:
        torch.distributed.init_process_group(backend="nccl")
        gpu_id = int(os.environ["LOCAL_RANK"])
        if gpu_id == 0:
            print("Distributed session successfully initialized")
    else:
        gpu_id = -1

    # run_id must exist on every rank: all ranks build the broadcast list below,
    # only the main process knows the real value.
    run_id = ""
    if is_main_process():
        print(f"torch.cuda.device_count(): {torch.cuda.device_count()}")
        # set up wandb
        if cfg.wandb.enable:
            key = os.environ.get("WANDB_API_KEY")
            wandb_run = initialize_wandb(cfg, key=key)
            wandb_run.define_metric("processed", summary="max")
            run_id = wandb_run.id

    if run_distributed:
        # share the main process' run_id so all ranks agree on the output dir
        obj = [run_id]
        torch.distributed.broadcast_object_list(
            obj, 0, device=torch.device(f"cuda:{gpu_id}")
        )
        run_id = obj[0]

    output_dir = Path(cfg.output_dir, cfg.experiment_name, run_id)
    features_dir = Path(output_dir, "features")
    if is_main_process():
        if output_dir.exists():
            # nothing is deleted here; files with the same name are overwritten
            print(f"{output_dir} already exists! existing files may be overwritten")
        output_dir.mkdir(parents=True, exist_ok=True)
        features_dir.mkdir(exist_ok=True)
    if run_distributed:
        # other ranks must not start saving before the directories exist
        torch.distributed.barrier()

    model = PatchEmbedder(
        img_size=cfg.img_size,
        mini_patch_size=cfg.patch_size,
        pretrain_vit_patch=cfg.pretrain_vit_patch,
        verbose=(gpu_id in (-1, 0)),
        img_size_pretrained=cfg.img_size_pretrained,
    )

    transform = make_classification_eval_transform()
    dataset = ImageFolderWithNameDataset(cfg.data_dir, transform)

    if run_distributed:
        sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    else:
        sampler = torch.utils.data.RandomSampler(dataset)

    # never ask for more workers than the machine / SLURM allocation provides
    num_workers = min(mp.cpu_count(), cfg.num_workers)
    if "SLURM_JOB_CPUS_PER_NODE" in os.environ:
        num_workers = min(num_workers, int(os.environ["SLURM_JOB_CPUS_PER_NODE"]))

    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=cfg.batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False,
    )

    if gpu_id == -1:
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{gpu_id}")
    model = model.to(device, non_blocking=True)

    if is_main_process():
        print()

    filenames, feature_paths = [], []

    with tqdm.tqdm(
        loader,
        desc="Feature Extraction",
        unit=" img",
        ncols=80,
        unit_scale=cfg.batch_size,
        position=0,
        leave=True,
        disable=gpu_id not in (-1, 0),
    ) as t1:
        with torch.no_grad():
            for batch in t1:
                imgs, fnames = batch
                imgs = imgs.to(device, non_blocking=True)
                features = model(imgs)
                for k, f in enumerate(features):
                    fname = fnames[k]
                    feature_path = Path(features_dir, f"{fname}.pt")
                    # A row of a batched tensor is a view; torch.save would
                    # serialize the whole batch storage into every file.
                    if isinstance(f, torch.Tensor):
                        f = f.clone()
                    torch.save(f, feature_path)
                    filenames.append(fname)
                    feature_paths.append(feature_path)
                if cfg.wandb.enable and not run_distributed:
                    # cumulative number of images processed so far
                    wandb.log({"processed": len(filenames)})

    features_df = pd.DataFrame.from_dict(
        {
            "filename": filenames,
            "feature_path": feature_paths,
        }
    )

    if run_distributed:
        features_csv_path = Path(output_dir, f"features_{gpu_id}.csv")
    else:
        features_csv_path = Path(output_dir, "features.csv")
    features_df.to_csv(features_csv_path, index=False)

    if run_distributed:
        # wait until every rank has written its partial csv
        torch.distributed.barrier()
        if is_main_process():
            dfs = []
            # One partial csv per process. Files are named after LOCAL_RANK, so
            # this matches the single-node launch shown at the bottom of the file.
            for rank in range(torch.distributed.get_world_size()):
                fp = Path(output_dir, f"features_{rank}.csv")
                df = pd.read_csv(fp)
                dfs.append(df)
                os.remove(fp)
            features_df = pd.concat(dfs, ignore_index=True)
            # the same image can show up on several ranks; keep a single row
            features_df = features_df.drop_duplicates()
            features_df.to_csv(Path(output_dir, "features.csv"), index=False)

    if cfg.wandb.enable and is_main_process() and run_distributed:
        wandb.log({"processed": len(features_df)})


if __name__ == "__main__":
    # python3 -m torch.distributed.run --standalone --nproc_per_node=gpu extract_features_patch.py --config-name 'patch'
    main()
Oops, something went wrong.