-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathS_model.py
114 lines (87 loc) · 4.07 KB
/
S_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
from models.GCN import GCN
import numpy as np
import scipy.sparse as sp
from torch_geometric.utils import from_scipy_sparse_matrix
import utils
class NoiseAda(nn.Module):
    """Learnable noise-adaptation layer applied on top of class probabilities.

    The transition matrix is parameterised as ``softmax(B, dim=1)`` so that
    every row remains a probability distribution while ``B`` is trained
    freely. ``B`` is initialised to ``log(P0)`` with
    ``P0 = utils.build_uniform_P(class_size, noise)``; if the rows of ``P0``
    sum to one, ``softmax(B, dim=1)`` equals ``P0`` at the start of training.

    Args:
        class_size: number of classes; presumably ``P0`` is a
            ``(class_size, class_size)`` matrix (TODO confirm in utils).
        noise: second argument forwarded to ``utils.build_uniform_P``;
            presumably the uniform off-diagonal noise level of the initial
            matrix (TODO confirm). Defaults to 0.1, the value previously
            hard-coded here.
    """

    def __init__(self, class_size, noise=0.1):
        super().__init__()
        P = torch.FloatTensor(utils.build_uniform_P(class_size, noise))
        # Store logits of P0 so the softmax in forward() starts from P0.
        self.B = nn.Parameter(torch.log(P))

    def forward(self, pred):
        """Return ``pred @ softmax(B, dim=1)``.

        Args:
            pred: class-probability matrix whose last dimension matches the
                row count of the transition matrix (in this file it is the
                softmax of the GCN logits).

        Returns:
            The adapted probabilities; rows of a row-stochastic ``pred`` stay
            row-stochastic because the transition matrix is row-stochastic.
        """
        P = F.softmax(self.B, dim=1)
        return pred @ P
class S_model(GCN):
    """GCN trained through a learnable noise-adaptation layer (``NoiseAda``).

    During training the GCN's softmax probabilities are passed through
    ``self.noise_ada`` and cross-entropy is computed against the observed
    labels on those adapted probabilities. Validation accuracy and
    ``self.output`` use the plain GCN output, without the noise layer.
    """

    def __init__(self, nfeat, nhid, nclass, dropout=0.5, device=None):
        # ``dropout`` used to be accepted but silently dropped, so GCN's own
        # default was always used. Forward it to the base class.
        # NOTE(review): assumes GCN.__init__ accepts a ``dropout`` keyword —
        # confirm in models/GCN.py.
        super().__init__(nfeat, nhid, nclass, dropout=dropout, device=device)
        # Assigned as a submodule so its parameter B is part of
        # self.parameters() and is optimised together with the GCN.
        self.noise_ada = NoiseAda(nclass)

    def fit(self, features, adj, labels, idx_train, idx_val=None, train_iters=200, verbose=False):
        """Train the GCN and the noise-adaptation layer.

        Args:
            features: node features as a scipy sparse matrix, a torch tensor,
                or anything ``np.array`` accepts.
            adj: scipy sparse adjacency; converted with
                ``from_scipy_sparse_matrix`` into ``edge_index`` and
                ``edge_weight`` (cast to float).
            labels: observed node labels (array-like or torch tensor).
            idx_train: indices of the labelled training nodes.
            idx_val: validation node indices. When given, the weights with
                the best validation accuracy are restored after training.
            train_iters: number of optimisation epochs.
            verbose: print progress when True.
        """
        self.device = self.gc1.weight.device
        # NOTE(review): presumably GCN.initialize() resets only the GCN
        # layers, so noise_ada keeps its state across repeated fit() calls —
        # confirm.
        self.initialize()

        edge_index, edge_weight = from_scipy_sparse_matrix(adj)
        self.edge_index = edge_index.to(self.device)
        self.edge_weight = edge_weight.float().to(self.device)

        if sp.issparse(features):
            features = utils.sparse_mx_to_torch_sparse_tensor(features).to_dense().float()
        elif torch.is_tensor(features):
            # np.array() would fail on CUDA tensors; cast in place instead.
            features = features.float()
        else:
            features = torch.FloatTensor(np.array(features))
        self.features = features.to(self.device)

        if torch.is_tensor(labels):
            labels = labels.long()
        else:
            labels = torch.LongTensor(np.array(labels))
        self.labels = labels.to(self.device)

        if idx_val is None:
            self._train_without_val(self.labels, idx_train, train_iters, verbose)
        else:
            self._train_with_val(self.labels, idx_train, idx_val, train_iters, verbose)

    def _noise_adapted_loss(self, labels, idx_train):
        """Run one forward pass and return the training loss through noise_ada."""
        eps = 1e-8
        output = self.forward(self.features, self.edge_index, self.edge_weight)
        pred = F.softmax(output, dim=1)
        # Clamp away from exactly 0/1 so torch.log stays finite.
        score = self.noise_ada(pred).clamp(eps, 1 - eps)
        # cross_entropy re-applies log_softmax. When the rows of ``score`` sum
        # to one, that leaves log(score) unchanged, so this is effectively
        # NLL on the noise-adapted probabilities.
        return F.cross_entropy(torch.log(score[idx_train]), labels[idx_train])

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        """Train for ``train_iters`` epochs and keep the final weights.

        Afterwards ``self.output`` holds the eval-mode GCN output.
        """
        self.train()
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for i in range(train_iters):
            optimizer.zero_grad()
            loss_train = self._noise_adapted_loss(labels, idx_train)
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
        self.eval()
        self.output = self.forward(self.features, self.edge_index, self.edge_weight)

    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        """Train and restore the weights with the best validation accuracy.

        Afterwards ``self.output`` holds the eval-mode GCN output of the
        restored weights.
        """
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        # There must always be weights to restore. Previously ``weights`` was
        # unbound when train_iters == 0 or no epoch reached accuracy > 0.
        # Starting below any attainable accuracy guarantees the first epoch
        # is recorded.
        best_acc_val = -1.0
        weights = deepcopy(self.state_dict())
        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            loss_train = self._noise_adapted_loss(labels, idx_train)
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward(self.features, self.edge_index, self.edge_weight)
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])
            # Strict ">" keeps the earliest epoch among ties.
            if acc_val > best_acc_val:
                best_acc_val = acc_val
                weights = deepcopy(self.state_dict())
                if verbose:
                    print('=========save weights=========')
                    print("Epoch {}, val acc: {:.4f}".format(i, acc_val))
        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)
        # Recompute from the restored weights so self.output is defined even
        # when no epoch ran.
        self.eval()
        self.output = self.forward(self.features, self.edge_index, self.edge_weight)
# %%