core.py
import random

import numpy as np
import torch

import utils
from policy import GreedyEpsilonPolicy


class Sample(object):
    """A single transition; states are stored as uint8 to save memory."""

    def __init__(self, state, action, reward, next_state, end):
        utils.assert_eq(type(state), type(next_state))
        # States arrive as float32 in [0, 1]; storing them as uint8 cuts
        # memory usage by 4x, and the properties below convert back.
        self._state = (state * 255.0).astype(np.uint8)
        self._next_state = (next_state * 255.0).astype(np.uint8)
        self.action = action
        self.reward = reward
        self.end = end

    @property
    def state(self):
        return self._state.astype(np.float32) / 255.0

    @property
    def next_state(self):
        return self._next_state.astype(np.float32) / 255.0

    def __repr__(self):
        info = ('S(mean): %3.4f, A: %s, R: %s, NS(mean): %3.4f, End: %s'
                % (self.state.mean(), self.action, self.reward,
                   self.next_state.mean(), self.end))
        return info


class ReplayMemory(object):
    """Fixed-size buffer of Samples with FIFO eviction."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.samples = []
        self.oldest_idx = 0

    def __len__(self):
        return len(self.samples)

    def _evict(self):
        """Simplest FIFO eviction scheme; returns the slot to overwrite."""
        to_evict = self.oldest_idx
        self.oldest_idx = (self.oldest_idx + 1) % self.max_size
        return to_evict

    def burn_in(self, env, agent, num_steps):
        """Fill the memory with at least num_steps transitions collected by a
        uniformly random policy, continuing until the current episode ends."""
        policy = GreedyEpsilonPolicy(1, agent)  # epsilon = 1, i.e. uniform policy
        i = 0
        while i < num_steps or not env.end:
            if env.end:
                state = env.reset()
            action = policy.get_action(None)
            next_state, reward = env.step(action)
            self.append(state, action, reward, next_state, env.end)
            state = next_state
            i += 1
            if i % 10000 == 0:
                print('%d frames burned in' % i)
        print('%d frames burned into the memory.' % i)
    def append(self, state, action, reward, next_state, end):
        assert len(self.samples) <= self.max_size
        new_sample = Sample(state, action, reward, next_state, end)
        if len(self.samples) == self.max_size:
            avail_slot = self._evict()
            self.samples[avail_slot] = new_sample
        else:
            self.samples.append(new_sample)

    def sample(self, batch_size):
        """Simplest uniform sampling (without replacement) to produce a batch."""
        assert batch_size < len(self.samples), 'not enough samples to sample from'
        return random.sample(self.samples, batch_size)

    def clear(self):
        self.samples = []
        self.oldest_idx = 0


def samples_to_tensors(samples):
    """Stack a list of Samples into batched CUDA tensors."""
    num_samples = len(samples)

    states_shape = (num_samples,) + samples[0].state.shape
    states = np.zeros(states_shape, dtype=np.float32)
    next_states = np.zeros(states_shape, dtype=np.float32)
    rewards = np.zeros(num_samples, dtype=np.float32)
    actions = np.zeros(num_samples, dtype=np.int64)
    non_ends = np.zeros(num_samples, dtype=np.float32)

    for i, s in enumerate(samples):
        states[i] = s.state
        next_states[i] = s.next_state
        rewards[i] = s.reward
        actions[i] = s.action
        non_ends[i] = 0.0 if s.end else 1.0  # 1.0 marks non-terminal transitions

    states = torch.from_numpy(states).cuda()
    actions = torch.from_numpy(actions).cuda()
    rewards = torch.from_numpy(rewards).cuda()
    next_states = torch.from_numpy(next_states).cuda()
    non_ends = torch.from_numpy(non_ends).cuda()

    return states, actions, rewards, next_states, non_ends