-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmodels.py
99 lines (77 loc) · 2.98 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from typing import Any, Optional, Tuple
import flax.linen as nn
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
# Short aliases for the TensorFlow-Probability-on-JAX namespaces.
# `tfd` provides the Categorical policy distribution below; `tfb` is not
# referenced within this module.
tfd, tfb = tfp.distributions, tfp.bijectors
def default_conv_init(scale: Optional[float] = 2.0 ** 0.5):
    """Return the kernel initializer used by every convolution in this file.

    Args:
      scale: currently IGNORED; the returned initializer is always
        ``nn.initializers.xavier_uniform()``. The parameter is kept so that
        existing call sites remain valid.

    Returns:
      A Xavier/Glorot-uniform kernel initializer.
    """
    # The default is a plain Python float rather than ``jnp.sqrt(2)``. A jnp
    # call here would execute a JAX computation at import time, which
    # initializes the backend as a side effect of importing this module.
    # That breaks code that must configure JAX (e.g. jax.distributed.initialize)
    # before any computation runs. Since ``scale`` is unused, behavior is identical.
    return nn.initializers.xavier_uniform()
def default_mlp_init(scale: Optional[float] = 0.01):
    """Build the kernel initializer for this file's Dense (MLP) layers.

    Args:
      scale: gain forwarded to ``nn.initializers.orthogonal``.

    Returns:
      An orthogonal kernel initializer scaled by ``scale``.
    """
    orthogonal_init = nn.initializers.orthogonal(scale)
    return orthogonal_init
def default_logits_init(scale: Optional[float] = 0.01):
    """Build the kernel initializer for the policy-logits Dense layer.

    Args:
      scale: gain forwarded to ``nn.initializers.orthogonal``.

    Returns:
      An orthogonal kernel initializer scaled by ``scale``.
    """
    logits_initializer = nn.initializers.orthogonal(scale)
    return logits_initializer
class ResidualBlock(nn.Module):
    """Pre-activation residual block: ``x + conv2(relu(conv1(relu(x))))``.

    Both convolutions are 3x3 with stride 1 and 'SAME' padding, so spatial
    dimensions are preserved. The skip connection adds the unmodified input,
    so the input's channel count must equal ``num_channels``.

    Attributes:
      num_channels: output channels of each of the two convolutions.
      prefix: prefix for the convolution submodule names.
    """
    num_channels: int
    prefix: str

    @nn.compact
    def __call__(self, x):
        h = x
        # Two identical relu -> conv stages, created in the same order and
        # with the same explicit names as before.
        for suffix in ('/conv2d_1', '/conv2d_2'):
            h = nn.relu(h)
            h = nn.Conv(
                self.num_channels,
                kernel_size=[3, 3],
                strides=(1, 1),
                padding='SAME',
                kernel_init=default_conv_init(),
                name=self.prefix + suffix,
            )(h)
        return h + x
class Impala(nn.Module):
    """IMPALA-style convolutional encoder.

    Each stage applies:
      1. a 3x3 conv;
      2. a 3x3, stride-2, 'SAME' max-pool;
      3. ``num_blocks`` ``ResidualBlock``s.

    The result is flattened, passed through ReLU and a Dense projection, and
    a final ReLU is applied.

    Attributes:
      prefix: name prefix for the per-stage convs and the projection layer.
      stages: ``(num_channels, num_blocks)`` for each stage. The default
        reproduces the previously hard-coded architecture.
      hidden_dim: output width of the final Dense projection (default 256,
        the previously hard-coded value).
    """
    prefix: str
    stages: Tuple[Tuple[int, int], ...] = ((16, 2), (32, 2), (32, 2))
    hidden_dim: int = 256

    @nn.compact
    def __call__(self, x):
        out = x
        for i, (num_channels, num_blocks) in enumerate(self.stages):
            out = nn.Conv(num_channels,
                          kernel_size=[3, 3],
                          strides=(1, 1),
                          padding='SAME',
                          kernel_init=default_conv_init(),
                          name=self.prefix + '/conv2d_%d' % i)(out)
            out = nn.max_pool(out,
                              window_shape=(3, 3),
                              strides=(2, 2),
                              padding='SAME')
            # Residual blocks are auto-named by creation order, so keep
            # creating them inside this loop exactly as before.
            for j in range(num_blocks):
                block = ResidualBlock(num_channels,
                                      prefix='residual_{}_{}'.format(i, j))
                out = block(out)
        # Flatten all axes except the first. This assumes the leading axis is
        # the batch axis (i.e. a batched NHWC-style input) -- TODO confirm
        # with callers.
        out = out.reshape(out.shape[0], -1)
        out = nn.relu(out)
        out = nn.Dense(self.hidden_dim,
                       kernel_init=default_mlp_init(),
                       name=self.prefix + '/representation')(out)
        out = nn.relu(out)
        return out
class TwinHeadModel(nn.Module):
    """Actor-critic network for PPO with a shared IMPALA encoder.

    The encoder output feeds two linear heads:
      - a critic head producing a single value per example (output keeps a
        trailing dimension of size 1; it is not squeezed);
      - an actor head producing ``action_dim`` logits, wrapped in a
        ``tfd.Categorical`` distribution.

    Attributes:
      action_dim: number of discrete actions (width of the logits layer).
      prefix_critic: name prefix for the critic Dense layer.
      prefix_actor: name prefix for the actor Dense layer.

    Returns (from ``__call__``):
      ``(v, pi)`` -- the critic output and the Categorical policy.
    """
    action_dim: int
    prefix_critic: str = "critic"
    prefix_actor: str = "policy"

    @nn.compact
    def __call__(self, x):
        features = Impala(prefix='shared_encoder')(x)

        critic_head = nn.Dense(1,
                               kernel_init=default_mlp_init(),
                               name=self.prefix_critic + '/fc_v')
        value = critic_head(features)

        actor_head = nn.Dense(self.action_dim,
                              kernel_init=default_logits_init(),
                              name=self.prefix_actor + '/fc_pi')
        policy = tfd.Categorical(logits=actor_head(features))

        return value, policy