-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathreplay_buffer.py
156 lines (118 loc) · 4.3 KB
/
replay_buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
''' https://github.com/pytorch/tutorials/blob/master/intermediate_source/reinforcement_q_learning.py
https://gist.github.com/Pocuston/13f1a7786648e1e2ff95bfad02a51521
'''
######################################################################
# Replay Memory
# -------------
#
# We'll be using experience replay memory for training our DQN. It stores
# the transitions that the agent observes, allowing us to reuse this data
# later. By sampling from it randomly, the transitions that build up a
# batch are decorrelated. It has been shown that this greatly stabilizes
# and improves the DQN training procedure.
#
# For this, we're going to need two classes:
#
# - ``Transition`` - a named tuple representing a single transition in
# our environment. It essentially maps (state, action) pairs
# to their (next_state, reward) result, with the state being the
# screen difference image as described later on.
# - ``ReplayMemory`` - a cyclic buffer of bounded size that holds the
# transitions observed recently. It also implements a ``.sample()``
# method for selecting a random batch of transitions for training.
#
import random
from collections import namedtuple
import numpy as np
import random
# One environment step: maps (state, action) to (reward, next_state, done).
# Field order matters for callers that unpack positionally.
Transition = namedtuple('Transition', 'state action reward next_state done')
class SumTree:
    """Array-backed binary sum tree for proportional prioritized replay.

    Leaves hold per-sample priorities; every internal node stores the sum of
    its children, so the root (``total()``) is the total priority mass and
    ``get(s)`` locates the leaf whose cumulative-priority interval contains
    ``s`` in O(log capacity).
    """

    def __init__(self, capacity):
        """capacity: maximum number of stored samples (number of leaves)."""
        self.capacity = capacity
        # Complete binary tree flattened into one array:
        # capacity - 1 internal nodes followed by capacity leaves.
        self.tree = np.zeros(2 * capacity - 1)
        # Payload slots, parallel to the leaves.
        self.data = np.zeros(capacity, dtype=object)
        self.n_entries = 0  # number of leaves holding real data (<= capacity)
        # Cyclic cursor: next leaf slot to (over)write. This was a class
        # attribute in the original, silently shared across instances until
        # the first `+=`; it is per-instance mutable state and belongs here.
        self.write = 0

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx`` up to the root."""
        parent = (idx - 1) // 2
        self.tree[parent] += change
        if parent != 0:
            self._propagate(parent, change)

    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf covering cumulative mass ``s``."""
        left = 2 * idx + 1
        right = left + 1
        if left >= len(self.tree):  # idx has no children: it is a leaf
            return idx
        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    def total(self):
        """Total priority mass (value at the root)."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` with priority ``p``, overwriting the oldest entry
        once the buffer is full (cyclic)."""
        idx = self.write + self.capacity - 1  # flat-array index of the leaf
        self.data[self.write] = data
        self.update(idx, p)
        self.write += 1
        if self.write >= self.capacity:
            self.write = 0  # wrap around
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def update(self, idx, p):
        """Set leaf ``idx`` (flat tree index) to priority ``p`` and repair
        all ancestor sums."""
        change = p - self.tree[idx]
        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Return ``(tree_idx, priority, data)`` for cumulative mass ``s``.

        ``s`` should lie in ``[0, total())``. NOTE(review): while the buffer
        is not yet full, an ``s`` at the upper boundary can land on an empty
        leaf whose data slot still holds the 0 placeholder.
        """
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1
        return (idx, self.tree[idx], self.data[dataIdx])
class ReplayMemory_Per(object):
    """Prioritized experience replay (proportional variant) backed by a SumTree.

    Each stored item is sampled with probability proportional to
    ``(|TD-error| + e) ** a``.
    """
    # stored as ( s, a, r, s_ ) in SumTree
    def __init__(self, capacity, a=0.6, e=0.01):
        """
        capacity: maximum number of stored transitions.
        a: priority exponent (0 = uniform sampling, 1 = fully proportional).
        e: small additive constant so zero-error items can still be drawn.
        """
        self.tree = SumTree(capacity)
        self.memory_size = capacity
        self.prio_max = 0.1  # running max |TD-error|; new items get max priority
        self.a = a
        self.e = e

    def push(self, batch):
        """Insert a transition at the current maximum priority so it is
        sampled at least once before its true TD-error is known."""
        data = batch
        p = (np.abs(self.prio_max) + self.e) ** self.a  # proportional priority
        self.tree.add(p, data)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions by stratified sampling: the total
        priority mass is split into equal segments and one point is drawn
        uniformly from each.

        Returns ``(idxs, sample_datas)``; pass ``idxs`` back to ``update()``.
        NOTE(review): while fewer than ``capacity`` items have been pushed,
        a draw can hit an empty leaf and yield the 0 placeholder.
        """
        idxs = []
        segment = self.tree.total() / batch_size
        sample_datas = []
        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = random.uniform(a, b)
            idx, p, data = self.tree.get(s)
            # (debug print of p removed: it flooded stdout on every draw)
            sample_datas.append(data)
            idxs.append(idx)
        return idxs, sample_datas

    def update(self, idxs, errors):
        """Re-prioritize previously sampled transitions with fresh |TD-errors|."""
        self.prio_max = max(self.prio_max, max(np.abs(errors)))
        for i, idx in enumerate(idxs):
            p = (np.abs(errors[i]) + self.e) ** self.a
            self.tree.update(idx, p)

    def size(self):
        """Number of transitions currently stored."""
        return self.tree.n_entries
class ReplayMemory(object):
    """Plain uniform-sampling replay buffer with FIFO eviction."""

    def __init__(self, capacity):
        """capacity: maximum number of transitions kept."""
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, batch):
        """Append one transition; drop the oldest once over capacity."""
        self.memory.append(batch)
        if len(self.memory) > self.capacity:
            self.memory.pop(0)

    def sample(self, batch_size):
        """Return ``batch_size`` transitions drawn uniformly, without
        replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Current number of stored transitions."""
        return len(self.memory)