-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathknn_postprocessor.py
64 lines (53 loc) · 2.16 KB
/
knn_postprocessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import Any
import faiss
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from .base_postprocessor import BasePostprocessor
def normalizer(x):
    """Scale *x* to unit L2 norm along its last axis.

    The epsilon belongs in the denominator so that all-zero feature
    vectors do not produce NaN/inf. (The previous lambda had a precedence
    bug: ``x / norm + 1e-10`` added the epsilon to every component of the
    already-divided result instead of guarding the division.)
    """
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + 1e-10)
class KNNPostprocessor(BasePostprocessor):
    """Deep k-nearest-neighbor OOD postprocessor.

    At setup time, extracts features for the chosen ID split, L2-normalizes
    them, and builds an exact faiss L2 index over them. At inference, the
    negated distance to the K-th nearest ID feature is returned as the
    confidence score (larger = closer to the ID data).
    """

    def __init__(self, config):
        super().__init__(config)
        self.args = self.config.postprocessor.postprocessor_args
        # Number of neighbors to search; sweepable via set_hyperparam.
        self.K = self.args.K
        # Filled in setup(): (N, D) array of normalized ID features.
        self.activation_log = None
        self.args_dict = self.config.postprocessor.postprocessor_sweep
        # Guards against rebuilding the index on repeated setup() calls.
        self.setup_flag = False
        self.has_data_based_setup = True

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict,
              id_loader_split="train"):
        """Extract ID features once and build the faiss search index.

        Args:
            net: feature-extracting network; must support
                ``net(data, return_feature=True)``.
            id_loader_dict: dict of ID data loaders, keyed by split name.
            ood_loader_dict: unused here (kept for the common interface).
            id_loader_split: which ID split to index (default "train").
        """
        print(f"Setup on ID data - {id_loader_split} split")
        if self.setup_flag:
            return  # index already built

        activation_log = []
        net.eval()
        with torch.no_grad():
            for batch in tqdm(id_loader_dict[id_loader_split],
                              desc='Setup: ',
                              position=0,
                              leave=True):
                data = batch['data'].cuda().float()
                _, feature = net(data, return_feature=True)
                activation_log.append(
                    normalizer(feature.data.cpu().numpy()))

        self.activation_log = np.concatenate(activation_log, axis=0)
        # Exact (brute-force) L2 index over the feature dimension.
        self.index = faiss.IndexFlatL2(feature.shape[1])
        self.index.add(self.activation_log)
        self.setup_flag = True

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Return (predicted class, OOD confidence score) for a batch.

        The score is the negated distance to the K-th nearest ID feature,
        so larger values indicate samples closer to the ID distribution.
        """
        output, feature = net(data, return_feature=True)
        feature_normed = normalizer(feature.data.cpu().numpy())
        D, _ = self.index.search(feature_normed, self.K)
        # D is sorted ascending; the last column is the K-th-NN distance.
        kth_dist = -D[:, -1]
        # softmax is monotonic, so this matches argmax over raw logits.
        _, pred = torch.max(torch.softmax(output, dim=1), dim=1)
        return pred, torch.from_numpy(kth_dist)

    def set_hyperparam(self, hyperparam: list):
        """Sweep hook: the first element of *hyperparam* is K."""
        self.K = hyperparam[0]

    def get_hyperparam(self):
        """Sweep hook: return the current K."""
        return self.K