
Commit a187250

brianwooridle authored and brian committed

double DQN batch form modified

1 parent: 3c24aa7

File tree

1 file changed (+17, -12)


Code 2. Cartpole/2. Double DQN/Cartpole_DoubleDQN.py

Lines changed: 17 additions & 12 deletions
@@ -78,29 +78,34 @@ def train_replay(self):
         mini_batch = random.sample(self.memory, batch_size)

         update_input = np.zeros((batch_size, self.state_size))
-        update_target = np.zeros((batch_size, self.action_size))
+        update_target = np.zeros((batch_size, self.state_size))
+        action, reward, done = [], [], []

         for i in range(batch_size):
-            state, action, reward, next_state, done = mini_batch[i]
-            target = self.model.predict(state)[0]
+            update_input[i] = mini_batch[i][0]
+            action.append(mini_batch[i][1])
+            reward.append(mini_batch[i][2])
+            update_target[i] = mini_batch[i][3]
+            done.append(mini_batch[i][4])

+        target = self.model.predict(update_input)
+        target_val = self.target_model.predict(update_target)
+
+        for i in range(self.batch_size):
             # like Q Learning, get maximum Q value at s'
             # But from target model
-            if done:
-                target[action] = reward
+            if done[i]:
+                target[i][action[i]] = reward[i]
             else:
                 # the key point of Double DQN
                 # selection of action is from model
                 # update is from target model
-                a = np.argmax(self.model.predict(next_state)[0])
-                target[action] = reward + self.discount_factor * \
-                                 (self.target_model.predict(next_state)[0][a])
-            update_input[i] = state
-            update_target[i] = target
+                a = np.argmax(target_val[i])
+                target[i][action[i]] = reward[i] + self.discount_factor * (target_val[i][a])

         # make minibatch which includes target q value and predicted q value
         # and do the model fit!
-        self.model.fit(update_input, update_target, batch_size=batch_size, epochs=1, verbose=0)
+        self.model.fit(update_input, target, batch_size=self.batch_size, epochs=1, verbose=0)

     # load the saved model
     def load_model(self, name):

@@ -112,7 +117,7 @@ def save_model(self, name):


 if __name__ == "__main__":
-    # in case of CartPole-v1, you can play until 500 time step
+    # In case of CartPole-v1, you can play until 500 time step
     env = gym.make('CartPole-v1')
     # get size of state and action from environment
     state_size = env.observation_space.shape[0]
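
For reference, here is a self-contained sketch of the batched target computation that the in-code comments describe ("selection of action is from model, update is from target model"): the online model selects the greedy action at s' and the target model evaluates that action. Note that the added line a = np.argmax(target_val[i]) above takes the argmax from the target network's output, which yields the ordinary DQN target rather than the Double DQN target the comment describes. The helper name double_dqn_targets and its signature are illustrative, not part of the commit; it assumes Keras-style models with a predict method and transitions stored as (state, action, reward, next_state, done) tuples:

    import numpy as np

    def double_dqn_targets(model, target_model, mini_batch, state_size, discount_factor):
        # Build (states, target_q) for one model.fit() call from sampled transitions.
        batch_size = len(mini_batch)
        update_input = np.zeros((batch_size, state_size))   # states s
        update_target = np.zeros((batch_size, state_size))  # next states s'
        action, reward, done = [], [], []

        for i in range(batch_size):
            update_input[i] = mini_batch[i][0]
            action.append(mini_batch[i][1])
            reward.append(mini_batch[i][2])
            update_target[i] = mini_batch[i][3]
            done.append(mini_batch[i][4])

        target = model.predict(update_input)              # Q(s, .): rows partially overwritten below
        next_online = model.predict(update_target)        # online net: used only to SELECT a'
        next_target = target_model.predict(update_target) # target net: used only to EVALUATE a'

        for i in range(batch_size):
            if done[i]:
                target[i][action[i]] = reward[i]
            else:
                a = np.argmax(next_online[i])  # selection of action is from model
                target[i][action[i]] = reward[i] + discount_factor * next_target[i][a]

        return update_input, target

The caller would then run model.fit(update_input, target, batch_size=batch_size, epochs=1, verbose=0) exactly as the diff does; replacing the per-sample predict calls of the old loop with three whole-batch predict calls is what this commit's "batch form" change accomplishes.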
