-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdice_postprocessor.py
69 lines (57 loc) · 2.5 KB
/
dice_postprocessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from .base_postprocessor import BasePostprocessor
normalizer = lambda x: x / np.linalg.norm(x, axis=-1, keepdims=True) + 1e-10
class DICEPostprocessor(BasePostprocessor):
    """OOD postprocessor implementing DICE (Directed Sparsification).

    Ranks each final-layer weight by its contribution (mean ID activation
    times weight), zeroes out everything below the ``p``-th percentile,
    and scores samples with the energy (logsumexp) of the masked logits.
    """

    def __init__(self, config):
        # Modern zero-argument super() replaces super(DICEPostprocessor, self).
        super().__init__(config)
        self.args = self.config.postprocessor.postprocessor_args
        # Sparsification percentile: weight contributions below this
        # percentile are masked out of the classifier weights.
        self.p = self.args.p
        self.mean_act = None   # mean penultimate-layer feature over ID data
        self.masked_w = None   # sparsified copy of the FC weight matrix
        self.thresh = None     # contribution threshold (set by calculate_mask)
        self.args_dict = self.config.postprocessor.postprocessor_sweep
        self.setup_flag = False
        self.has_data_based_setup = True

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict,
              id_loader_split="train"):
        """Estimate the mean ID-feature activation used to rank weights.

        Runs once; subsequent calls are no-ops (guarded by ``setup_flag``).
        """
        print(f"Setup on ID data - {id_loader_split} split")
        if self.setup_flag:
            return  # guard clause replaces the original empty else-branch
        activation_log = []
        net.eval()
        with torch.no_grad():
            for batch in tqdm(id_loader_dict[id_loader_split],
                              desc='Setup: ',
                              position=0,
                              leave=True):
                data = batch['data'].cuda().float()
                # NOTE(review): assumes net(..., return_feature=True) yields
                # (logits, penultimate_feature) — as also used in postprocess.
                _, feature = net(data, return_feature=True)
                activation_log.append(feature.data.cpu().numpy())
        self.mean_act = np.concatenate(activation_log, axis=0).mean(0)
        self.setup_flag = True

    def calculate_mask(self, w):
        """Sparsify FC weights *w*, keeping entries whose contribution
        (mean ID activation times weight) exceeds the p-th percentile."""
        contrib = self.mean_act[None, :] * w.data.squeeze().cpu().numpy()
        self.thresh = np.percentile(contrib, self.p)
        # Lift the boolean mask onto the same device/dtype as the weights
        # (the original hard-coded float32 + .cuda() via torch.Tensor).
        mask = torch.as_tensor(contrib > self.thresh,
                               dtype=w.dtype, device=w.device)
        self.masked_w = w * mask

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Return (predicted class, energy confidence) for a batch.

        The mask is deliberately recomputed on every call so that
        hyperparameter sweeps that change ``self.p`` between calls
        take effect (replaces the dead commented-out caching code).
        """
        fc_weight, fc_bias = net.get_fc()
        self.calculate_mask(torch.from_numpy(fc_weight).cuda())
        _, feature = net(data, return_feature=True)
        # Masked logits: per-class weighted vote summed over feature dims.
        vote = feature[:, None, :] * self.masked_w
        output = vote.sum(2) + torch.from_numpy(fc_bias).cuda()
        _, pred = torch.max(torch.softmax(output, dim=1), dim=1)
        # Energy score: logsumexp over the masked logits (higher = more ID).
        energyconf = torch.logsumexp(output.data.cpu(), dim=1)
        return pred, energyconf

    def set_hyperparam(self, hyperparam: list):
        # Sweep interface: first element is the sparsity percentile p.
        self.p = hyperparam[0]

    def get_hyperparam(self):
        return self.p