# models.py
from typing import List

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def build_mlp(layers_dims: List[int]):
    """Build an MLP where every hidden layer is Linear -> BatchNorm1d -> ReLU
    and the final layer is a plain Linear projection."""
    layers = []
    for i in range(len(layers_dims) - 2):
        layers.append(nn.Linear(layers_dims[i], layers_dims[i + 1]))
        layers.append(nn.BatchNorm1d(layers_dims[i + 1]))
        layers.append(nn.ReLU(True))  # inplace ReLU
    layers.append(nn.Linear(layers_dims[-2], layers_dims[-1]))
    return nn.Sequential(*layers)
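

# Illustrative example (the dimensions below are hypothetical, not from the
# original file): build_mlp([784, 256, 10]) yields
#   Linear(784, 256) -> BatchNorm1d(256) -> ReLU -> Linear(256, 10)
# i.e. only hidden layers get batch norm and an activation; the output layer
# stays linear.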


class MockModel(torch.nn.Module):
    """
    Does nothing. Just for testing.
    """

    def __init__(self, device="cuda", output_dim=256):
        super().__init__()
        self.device = device
        self.repr_dim = output_dim

    def forward(self, states, actions):
        """
        Args:
            During training:
                states: [B, T, Ch, H, W]
            During inference:
                states: [B, 1, Ch, H, W]
            actions: [B, T-1, 2]

        Output:
            predictions: [B, T, D]
        """
        # actions holds T-1 steps, so T + 1 below recovers the full T states.
        B, T, _ = actions.shape
        return torch.randn((B, T + 1, self.repr_dim)).to(self.device)
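

# Shape sanity check (illustrative values; the default output_dim=256 is
# assumed): actions of shape [4, 16, 2] (B=4, T-1=16) produce random
# predictions of shape [4, 17, 256] -- one representation per state.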


class Prober(torch.nn.Module):
    def __init__(
        self,
        embedding: int,
        arch: str,
        output_shape: List[int],
    ):
        super().__init__()
        self.output_dim = np.prod(output_shape)
        self.output_shape = output_shape
        self.arch = arch

        # arch is a dash-separated list of hidden sizes, e.g. "256-128";
        # an empty string means a single linear probe with no hidden layers.
        arch_list = list(map(int, arch.split("-"))) if arch != "" else []
        f = [embedding] + arch_list + [self.output_dim]
        layers = []
        for i in range(len(f) - 2):
            layers.append(torch.nn.Linear(f[i], f[i + 1]))
            layers.append(torch.nn.ReLU(True))
        layers.append(torch.nn.Linear(f[-2], f[-1]))
        self.prober = torch.nn.Sequential(*layers)

    def forward(self, e):
        return self.prober(e)
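

# Minimal smoke test (a sketch with assumed, illustrative sizes -- not part of
# the original file). Runs both modules on CPU to verify the output shapes.
if __name__ == "__main__":
    mock = MockModel(device="cpu")
    actions = torch.zeros(4, 16, 2)            # B=4, T-1=16 action steps
    preds = mock(None, actions)
    assert preds.shape == (4, 17, 256)         # [B, T, repr_dim]

    prober = Prober(embedding=256, arch="128-64", output_shape=[2])
    e = torch.randn(8, 256)                    # batch of 8 embeddings
    out = prober(e)
    assert out.shape == (8, 2)                 # prod(output_shape) = 2
    print("MockModel:", tuple(preds.shape), "Prober:", tuple(out.shape))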