diff --git a/2-cartpole/1-dqn/cartpole_dqn.py b/2-cartpole/1-dqn/cartpole_dqn.py index 8b2baaf0..1440c749 100644 --- a/2-cartpole/1-dqn/cartpole_dqn.py +++ b/2-cartpole/1-dqn/cartpole_dqn.py @@ -1,4 +1,3 @@ -import sys import gym import pylab import random @@ -121,6 +120,7 @@ def train_model(self): agent = DQNAgent(state_size, action_size) scores, episodes = [], [] + complete = False for e in range(EPISODES): done = False @@ -128,6 +128,8 @@ def train_model(self): state = env.reset() state = np.reshape(state, [1, state_size]) + if complete: break + while not done: if agent.render: env.render() @@ -162,7 +164,7 @@ def train_model(self): # if the mean of scores of last 10 episode is bigger than 490 # stop training if np.mean(scores[-min(10, len(scores)):]) > 490: - sys.exit() + complete = True # save the model if e % 50 == 0: