-
Notifications
You must be signed in to change notification settings - Fork 2
/
critics.py
28 lines (24 loc) · 829 Bytes
/
critics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Critic classes for 3 different flavors (K, K+1, K+1 plogp)
# Arguments (vphi, Sphi)
# Returns (f, fpos, fneg)
import torch
import torch.nn as nn
import torch.nn.functional as F
class K(nn.Module):
    """Plain K critic: the score ``f`` is ``vphi`` passed through unchanged.

    This flavor has no positive/negative decomposition, so ``fpos`` and
    ``fneg`` are reported as ``None``. Any caller that tries to place
    constraints on them will fail.
    """

    def forward(self, vphi, Sphi):
        # Sphi is accepted only so every critic shares the (vphi, Sphi)
        # call signature; this flavor does not read it.
        score = vphi
        positive_part = negative_part = None
        return score, positive_part, negative_part
class Kp1(nn.Module):
    """K+1 critic: expected score under softmax(S_phi), offset by ``vphi``.

    Returns ``(f, fpos, fneg)`` where

    * ``fpos = sum_k softmax(S_phi)_k * S_phi_k`` (reduced over dim 1),
    * ``fneg = vphi`` (returned as-is),
    * ``f = fpos - fneg``.
    """

    def forward(self, vphi, S_phi):
        # NOTE(review): assumes S_phi carries its class axis on dim 1
        # (e.g. (batch, K)) and that vphi broadcasts against the dim-1-reduced
        # shape -- confirm with callers.
        # Normalize over dim 1 explicitly. Calling softmax without ``dim`` is
        # deprecated (it warns), and PyTorch's implicit choice is dim 0 for
        # 3-D inputs, which would disagree with the sum(dim=1) below.
        p_y = F.softmax(S_phi, dim=1)
        fpos = (p_y * S_phi).sum(dim=1)
        fneg = vphi
        return fpos - fneg, fpos, fneg
class Kp1_plogp(nn.Module):
    """K+1 critic using ``sum p log p`` of softmax(S_phi), offset by ``vphi``.

    Returns ``(f, fpos, fneg)`` where

    * ``fpos = sum_k p_k * log p_k`` with ``p = softmax(S_phi)`` along dim 1,
      i.e. the negative Shannon entropy of that distribution,
    * ``fneg = vphi`` (returned as-is),
    * ``f = fpos - fneg``.
    """

    def forward(self, vphi, S_phi):
        # NOTE(review): assumes S_phi carries its class axis on dim 1
        # (e.g. (batch, K)) and that vphi broadcasts against the dim-1-reduced
        # shape -- confirm with callers.
        # Explicit dim=1: log_softmax without ``dim`` is deprecated and picks
        # dim 0 for 3-D inputs, inconsistent with the sum(dim=1) below.
        # log_softmax is used (rather than log(softmax)) for numerical stability;
        # p is recovered from it by exponentiating.
        log_p_y = F.log_softmax(S_phi, dim=1)
        p_y = log_p_y.exp()
        fpos = (p_y * log_p_y).sum(dim=1)
        fneg = vphi
        return fpos - fneg, fpos, fneg