-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsparse_model.py
60 lines (51 loc) · 1.74 KB
/
sparse_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
################################################
# sparse_model.py
#
# Author: Aude Forcione-Lambert
#
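# Date: June 16, 2019
#
# Description: two-layer linear PyTorch model whose hidden-layer weights
#   are stored as a sparse tensor, with a built-in training loop.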
################################################
from pathlib import Path
import requests
from itertools import chain
import pickle
import gzip
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import math
class sparse_model(nn.Module):
    """Two-layer linear model whose hidden-layer weights are a sparse COO tensor.

    The network computes ``wout @ (w @ x.T)`` (no bias, no non-linearity)
    and owns its optimizer and training loop.

    Args:
        n_inputs: number of input features (columns of ``x``).
        n_neurons: width of the hidden layer.
        n_outputs: number of output units.
        loss_func: callable ``loss_func(output, target)`` returning a scalar
            tensor; ``output`` has shape ``(n_outputs, batch)``.
        opt_func: optimizer class/factory, called as ``opt_func(params, lr=lr)``.
            It must accept sparse gradients for ``w`` (e.g. ``torch.optim.SGD``;
            ``torch.optim.Adam`` rejects sparse gradients).
        lr: learning rate forwarded to ``opt_func``.
    """

    def __init__(self, n_inputs, n_neurons, n_outputs, loss_func, opt_func, lr):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_neurons = n_neurons
        self.n_outputs = n_outputs
        # Hidden-layer weights, shape (n_neurons, n_inputs). randn yields no
        # exact zeros in practice, so the sparse tensor initially stores
        # every entry.
        self.w = torch.randn(n_neurons, n_inputs).to_sparse().requires_grad_(True)
        # Output weights, shape (n_outputs, n_neurons), kept dense.
        self.wout = torch.randn(n_outputs, n_neurons).requires_grad_(True)
        # Plain tensors have no .parameters() method: hand the leaf tensors
        # to the optimizer directly.
        # NOTE(review): w and wout are not registered nn.Parameters, so
        # self.parameters(), state_dict() and .to(device) do not see them.
        self.opt = opt_func([self.w, self.wout], lr=lr)
        self.loss_func = loss_func

    def forward(self, x):
        """Return the model output for a batch.

        Args:
            x: dense tensor of shape ``(batch, n_inputs)``.

        Returns:
            Tensor of shape ``(n_outputs, batch)``: one column per sample.
        """
        hidden = torch.sparse.mm(self.w, x.t())  # (n_neurons, batch)
        return torch.mm(self.wout, hidden)

    def fit(self, epochs, train_dl):
        """Train for ``epochs`` passes over ``train_dl``.

        Args:
            epochs: number of passes over the data.
            train_dl: iterable of ``(xb, yb)`` batches, e.g. a ``DataLoader``.

        Returns:
            numpy array of length ``epochs``; entry ``i`` is the sum of the
            batch losses observed during epoch ``i``.
        """
        loss_array = np.zeros(epochs)
        for epoch in range(epochs):
            for xb, yb in train_dl:
                # NOTE(review): forward() returns (n_outputs, batch); losses such
                # as F.cross_entropy expect (batch, n_outputs). Confirm that
                # loss_func accounts for this orientation.
                loss = self.loss_func(self.forward(xb), yb)
                # Store a plain float: numpy cannot convert a tensor that
                # requires grad.
                loss_array[epoch] += loss.item()
                loss.backward()
                self.opt.step()
                self.opt.zero_grad()
            print(epoch)
        return loss_array

    def accuracy(self, valid_dl):
        """Return classification accuracy on the first batch of ``valid_dl``.

        Only ``next(iter(valid_dl))`` is evaluated. For full coverage, pass a
        loader whose single batch holds the whole validation set.

        Returns:
            0-d float tensor: the fraction of samples whose arg-max output
            (over the ``n_outputs`` axis) equals the target. Targets are
            presumably integer class indices of shape ``(batch,)``.
        """
        x_valid, y_valid = next(iter(valid_dl))
        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            preds = torch.argmax(self.forward(x_valid), dim=0)
        return (preds == y_valid).float().mean()