-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathPrioritizedReplayBuffer.py
130 lines (103 loc) · 4.84 KB
/
PrioritizedReplayBuffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python3
import gym
import ptan
import numpy as np
import argparse
import torch
import torch.optim as optim
from tensorboardX import SummaryWriter
from lib import dqn_model, common
# Exponent alpha applied to stored priorities when turning them into
# sampling probabilities (passed to PrioReplayBuffer as prob_alpha).
PRIO_REPLAY_ALPHA: float = 0.6
# Initial importance-sampling exponent beta; the training loop raises it
# linearly towards 1.0.
BETA_START: float = 0.4
# Number of frames over which beta goes from BETA_START to 1.0.
BETA_FRAMES: int = 100_000
class PrioReplayBuffer:
    """Fixed-capacity ring buffer with proportional prioritized sampling.

    Entries are pulled lazily from ``exp_source``. Every slot carries a
    priority, and ``sample`` draws slots with probability proportional to
    ``priority ** prob_alpha``. It also returns normalized importance-sampling
    weights.
    """

    def __init__(self, exp_source, buf_size, prob_alpha=0.6):
        self.capacity = buf_size
        self.prob_alpha = prob_alpha
        # Stored entries; grows up to `capacity`, then slots are overwritten.
        self.buffer = []
        # Next slot to write; wraps to 0 after reaching `capacity`.
        self.pos = 0
        # One priority per slot; unfilled slots stay at 0.
        self.priorities = np.zeros((buf_size, ), dtype=np.float32)
        # Iterator over experience entries, consumed by populate().
        self.exp_source_iter = iter(exp_source)

    def __len__(self):
        """Return how many entries are currently stored (at most ``capacity``)."""
        return len(self.buffer)

    def populate(self, count):
        """Pull ``count`` entries from the experience source into the buffer.

        Each new entry gets the largest priority stored before this call, or
        1.0 when the buffer is empty.
        """
        fresh_prio = 1.0 if not self.buffer else self.priorities.max()
        for _ in range(count):
            entry = next(self.exp_source_iter)
            if len(self.buffer) >= self.capacity:
                # Full: overwrite the oldest slot in ring-buffer order.
                self.buffer[self.pos] = entry
            else:
                self.buffer.append(entry)
            self.priorities[self.pos] = fresh_prio
            self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        """Draw ``batch_size`` entries (with replacement) proportionally to priority.

        Returns ``(samples, indices, weights)``. ``weights`` are the
        importance-sampling weights ``(N * P(i)) ** -beta``, divided by their
        maximum, as a float32 array.
        """
        # Only the filled prefix of `priorities` takes part in sampling.
        filled = (self.priorities if len(self.buffer) == self.capacity
                  else self.priorities[:self.pos])
        scaled = filled ** self.prob_alpha
        scaled /= scaled.sum()

        picked = np.random.choice(len(self.buffer), batch_size, p=scaled)
        batch = list(map(self.buffer.__getitem__, picked))

        n_entries = len(self.buffer)
        is_weights = (n_entries * scaled[picked]) ** (-beta)
        is_weights /= is_weights.max()
        return batch, picked, np.array(is_weights, dtype=np.float32)

    def update_priorities(self, batch_indices, batch_priorities):
        """Overwrite the priority of each slot listed in ``batch_indices``."""
        for slot, new_prio in zip(batch_indices, batch_priorities):
            self.priorities[slot] = new_prio
def calc_loss(batch, batch_weights, net, tgt_net, gamma, device="cpu"):
    """Compute the importance-weighted one-step DQN loss for a sampled batch.

    Args:
        batch: experience entries, unpacked via ``common.unpack_batch`` into
            ``(states, actions, rewards, dones, next_states)``.
        batch_weights: per-sample importance-sampling weights, e.g. the third
            value returned by ``PrioReplayBuffer.sample``.
        net: Q-network whose Q(s, a) values are regressed; gradients flow here.
        tgt_net: network providing the bootstrap value max_a' Q(s', a');
            evaluated without building an autograd graph.
        gamma: discount factor.
        device: torch device that all tensors are moved to.

    Returns:
        ``(loss, priorities)``:
        - ``loss`` is the scalar mean of the weighted squared TD errors.
        - ``priorities`` is the per-sample weighted squared TD error plus
          ``1e-5``, detached from the graph, intended for
          ``PrioReplayBuffer.update_priorities``.

    NOTE(review): the priorities include the importance weight and use the
    squared TD error, whereas proportional PER is usually defined on |TD error|.
    This is kept as-is to preserve existing training behavior.
    """
    states, actions, rewards, dones, next_states = common.unpack_batch(batch)

    states_v = torch.tensor(states).to(device)
    next_states_v = torch.tensor(next_states).to(device)
    actions_v = torch.tensor(actions).to(device)
    rewards_v = torch.tensor(rewards).to(device)
    # Bool mask: indexing with uint8 (ByteTensor) masks is deprecated in PyTorch.
    done_mask = torch.tensor(dones, dtype=torch.bool).to(device)
    batch_weights_v = torch.tensor(batch_weights).to(device)

    # Q(s, a) for the actions actually taken in each transition.
    state_action_values = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)

    # The bootstrap target must not receive gradients, so skip building its graph.
    with torch.no_grad():
        next_state_values = tgt_net(next_states_v).max(1)[0]
        # Terminal transitions contribute no future value.
        next_state_values[done_mask] = 0.0
    expected_state_action_values = next_state_values * gamma + rewards_v

    losses_v = batch_weights_v * (state_action_values - expected_state_action_values) ** 2
    # The epsilon keeps every priority strictly positive, so no entry ends up
    # with zero sampling probability.
    return losses_v.mean(), (losses_v + 1e-5).detach()
# Training entry point: DQN with prioritized replay on the 'pong'
# hyperparameter set from lib.common.
if __name__ == "__main__":
    params = common.HYPERPARAMS['pong']
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", default=False, action="store_true", help="Enable cuda")
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    env = gym.make(params['env_name'])
    env = ptan.common.wrappers.wrap_dqn(env)

    writer = SummaryWriter(comment="-" + params['run_name'] + "-prio-replay")
    net = dqn_model.DQN(env.observation_space.shape, env.action_space.n).to(device)
    tgt_net = ptan.agent.TargetNet(net)
    selector = ptan.actions.EpsilonGreedyActionSelector(epsilon=params['epsilon_start'])
    epsilon_tracker = common.EpsilonTracker(selector, params)
    agent = ptan.agent.DQNAgent(net, selector, device=device)

    # One-step transitions (steps_count=1) feed the prioritized buffer.
    exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=params['gamma'], steps_count=1)
    buffer = PrioReplayBuffer(exp_source, params['replay_size'], PRIO_REPLAY_ALPHA)
    optimizer = optim.Adam(net.parameters(), lr=params['learning_rate'])

    frame_idx = 0
    beta = BETA_START

    with common.RewardTracker(writer, params['stop_reward']) as reward_tracker:
        while True:
            frame_idx += 1
            # Pull one new experience entry into the buffer per loop iteration.
            buffer.populate(1)
            epsilon_tracker.frame(frame_idx)
            # Importance-sampling exponent: linear from BETA_START to 1.0 over BETA_FRAMES frames.
            beta = min(1.0, BETA_START + frame_idx * (1.0 - BETA_START) / BETA_FRAMES)

            # Presumably non-empty only when an episode has finished (ptan API).
            new_rewards = exp_source.pop_total_rewards()
            if new_rewards:
                writer.add_scalar("beta", beta, frame_idx)
                # A truthy result from the tracker ends training (stop criterion lives in lib.common).
                if reward_tracker.reward(new_rewards[0], frame_idx, selector.epsilon):
                    break

            # Warm-up: no gradient steps until the buffer holds replay_initial entries.
            if len(buffer) < params['replay_initial']:
                continue

            optimizer.zero_grad()
            batch, batch_indices, batch_weights = buffer.sample(params['batch_size'], beta)
            loss_v, sample_prios_v = calc_loss(batch, batch_weights, net, tgt_net.target_model,
                                               params['gamma'], device=device)
            loss_v.backward()
            optimizer.step()
            # Write the freshly computed per-sample priorities back for the sampled slots.
            buffer.update_priorities(batch_indices, sample_prios_v.data.cpu().numpy())

            # Periodically sync the target network (ptan TargetNet.sync).
            if frame_idx % params['target_net_sync'] == 0:
                tgt_net.sync()