@@ -78,29 +78,35 @@ def train_replay(self):
         mini_batch = random.sample(self.memory, batch_size)

         update_input = np.zeros((batch_size, self.state_size))
-        update_target = np.zeros((batch_size, self.action_size))
+        update_target = np.zeros((batch_size, self.state_size))
+        action, reward, done = [], [], []

         for i in range(batch_size):
-            state, action, reward, next_state, done = mini_batch[i]
-            target = self.model.predict(state)[0]
+            update_input[i] = mini_batch[i][0]
+            action.append(mini_batch[i][1])
+            reward.append(mini_batch[i][2])
+            update_target[i] = mini_batch[i][3]
+            done.append(mini_batch[i][4])

+        target = self.model.predict(update_input)
+        target_next = self.model.predict(update_target)
+        target_val = self.target_model.predict(update_target)
+
+        for i in range(self.batch_size):
             # like Q Learning, get maximum Q value at s'
             # But from target model
-            if done:
-                target[action] = reward
+            if done[i]:
+                target[i][action[i]] = reward[i]
             else:
                 # the key point of Double DQN
                 # selection of action is from model
                 # update is from target model
-                a = np.argmax(self.model.predict(next_state)[0])
-                target[action] = reward + self.discount_factor * \
-                    (self.target_model.predict(next_state)[0][a])
-                update_input[i] = state
-                update_target[i] = target
+                a = np.argmax(target_next[i])
+                target[i][action[i]] = reward[i] + self.discount_factor * (target_val[i][a])

         # make minibatch which includes target q value and predicted q value
         # and do the model fit!
-        self.model.fit(update_input, update_target, batch_size=batch_size, epochs=1, verbose=0)
+        self.model.fit(update_input, target, batch_size=self.batch_size, epochs=1, verbose=0)

     # load the saved model
     def load_model(self, name):
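Reviewer note on the hunk above: the old loop ran three predict calls per sampled transition; the refactor batches all states and next states into three forward passes and then fills in the Double DQN target row by row, y_i = r_i + discount_factor * Q_target(s'_i, argmax_a Q_online(s'_i, a)), where the online model selects the next action and the target model scores it. Below is a minimal standalone sketch of the same computation, not the repo's exact class; the function name, its signature, and the assumption of Keras-style models with .predict() and a memory of (state, action, reward, next_state, done) tuples are illustrative.

import random
import numpy as np

def double_dqn_targets(model, target_model, memory, batch_size,
                       state_size, discount_factor):
    # sample a minibatch of (s, a, r, s', done) transitions
    mini_batch = random.sample(memory, batch_size)

    states = np.zeros((batch_size, state_size))
    next_states = np.zeros((batch_size, state_size))
    actions, rewards, dones = [], [], []
    for i, (s, a, r, s_next, d) in enumerate(mini_batch):
        states[i] = s
        next_states[i] = s_next
        actions.append(a)
        rewards.append(r)
        dones.append(d)

    # three batched forward passes instead of three single-sample calls per transition
    target = model.predict(states)                   # Q(s, .); only taken actions get overwritten
    online_next = model.predict(next_states)         # online net SELECTS a'
    target_next = target_model.predict(next_states)  # target net EVALUATES a'

    for i in range(batch_size):
        if dones[i]:
            target[i][actions[i]] = rewards[i]
        else:
            a = np.argmax(online_next[i])  # argmax from the online model: the Double DQN step
            target[i][actions[i]] = rewards[i] + discount_factor * target_next[i][a]

    return states, target

If the argmax were taken from the target network's predictions instead, the update would collapse to the vanilla DQN target; keeping action selection (online model) separate from action evaluation (target model) is the whole point of the change.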
@@ -112,7 +118,7 @@ def save_model(self, name):


 if __name__ == "__main__":
-    # in case of CartPole-v1, you can play until 500 time step
+    # In case of CartPole-v1, you can play until 500 time steps
     env = gym.make('CartPole-v1')
     # get size of state and action from environment
     state_size = env.observation_space.shape[0]
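For context, the sizes read from the environment here determine the network's dimensions: CartPole-v1's observation is a 4-vector and its action space has 2 discrete actions, so the model takes 4 inputs and emits 2 Q-values. A minimal sketch, assuming the classic gym API used here:

import gym

env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]  # 4: cart position/velocity, pole angle/angular velocity
action_size = env.action_space.n             # 2: push the cart left or right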