Skip to content

Commit

Permalink
Added feature normalization in eval_knn.py (its absence prevented normalized distance computation) + changed how labels are fetched to ensure they match feature order
Browse files Browse the repository at this point in the history
  • Loading branch information
clemsgrs committed Mar 12, 2024
1 parent 78cc76b commit e058a80
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
16 changes: 8 additions & 8 deletions dino/config/knn.yaml
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
data:
features_dir:
features_dir: ''
query_csv: ''
test_csv: ''

output_dir: 'output'
experiment_name: 'eval'

batch_size_per_gpu: 256
batch_size_per_gpu: 16

nb_knn: 20
nb_knn: [10,20,100,200]
temperature: 0.07
save_features: false
label_name: ''
label_name: 'label'

model:
arch: vit_small
patch_size: 16
pretrained_weights: ''
pretrained_weights:
checkpoint_key: 'teacher'

speed:
use_cuda: true
num_workers: 16
num_workers: 8

wandb:
enable: true
enable: false
project: 'vision'
username: 'vlfm'
exp_name: 'eval'
tags: ['${experiment_name}', 'knn', '${student.arch}']
dir: '/home/user'
group:
to_log:
resume_id:
resume_id:
15 changes: 9 additions & 6 deletions dino/eval_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,10 @@ def load_features_and_labels_from_disk(
df, features_dir, label_name: str = "label", header: str = "query"
):
all_feature_paths = [fp for fp in features_dir.glob("*.pt")]
feature_paths = [fp for fp in all_feature_paths if fp.stem in df.filename.values]
df["stem"] = df.filename.apply(lambda x: Path(x).stem)
feature_paths = [fp for fp in all_feature_paths if fp.stem in df.stem.values]

labels = df[label_name].values
labels = torch.tensor(labels).long()

features = []
features, labels = [], []
with tqdm.tqdm(
feature_paths,
desc=f"Loading {header} features from disk",
Expand All @@ -314,7 +312,12 @@ def load_features_and_labels_from_disk(
for fp in t:
f = torch.load(fp)
features.append(f)
label = df[df.stem == fp.stem][label_name].values[0]
labels.append(label)

features = torch.stack(features)
features = nn.functional.normalize(features, dim=1, p=2)
labels = torch.tensor(labels).long()

return features, labels

Expand Down Expand Up @@ -351,7 +354,7 @@ def main(cfg: DictConfig):
output_dir = Path(cfg.output_dir, cfg.experiment_name, run_id)
if is_main_process():
if output_dir.exists():
print(f"{output_dir} already exists! deleting it...")
print(f"{output_dir} already exists!")
output_dir.mkdir(parents=True, exist_ok=True)

query_df = pd.read_csv(cfg.data.query_csv)
Expand Down

0 comments on commit e058a80

Please sign in to comment.