
Why could my DDQN get significantly worse after beating the game repeatedly?

Data Science Asked by Danny Tuppeny on March 28, 2021

I’ve been trying to train a DDQN to play OpenAI Gym’s CartPole-v1, but found that although it starts off well and repeatedly achieves the full score of 500 (at around 600 episodes in the graph below), it then seems to go off the rails and does worse the more it plays.

[Graph: episode scores climb to the full 500 at around episode 600, then steadily deteriorate as training continues]

I’m pretty new to ML, so I’m not really sure what could cause this or how to start debugging it (I’ve tried tweaking some of the hyper-parameters, but nothing seems to stop this trend).

If it helps, here are the (probably) relevant parts of my agent:

def _build_model(self):
    model = Sequential()
    model.add(Dense(24, input_dim=self.state_size, activation="relu"))
    model.add(Dense(24, activation="relu"))
    model.add(Dense(self.action_size, activation="linear"))
    model.compile(optimizer=Adam(lr=self.learning_rate), loss="mse")
    return model

def get_action(self, state):
    # Explore: with probability epsilon, take a random action.
    if np.random.rand() < self.epsilon:
        return random.randrange(self.action_size)
    
    # Otherwise predict Q-values with the online model and take the greedy action.
    q_values = self.model.predict(state)
    return np.argmax(q_values[0])

def replay(self, batch_size):
    if len(self.memory) < batch_size:
        return
    
    # Decay the exploration rate.
    self.epsilon *= self.epsilon_decay
    self.epsilon = max(self.epsilon_min, self.epsilon)
        
    minibatch = random.sample(self.memory, batch_size)
    
    state_batch, q_values_batch = [], []
    for state, action, reward, next_state, done in minibatch:
        # Get predictions for all actions for the current state.
        q_values = self.model.predict(state)
        
        # If the episode ended here, the target is just the reward; otherwise add
        # the discounted future reward predicted by the target network.
        if done:
            q_values[0][action] = reward
        else:
            future_reward = max(self.target_model.predict(next_state)[0])
            q_values[0][action] = reward + self.gamma * future_reward
        
        state_batch.append(state[0])
        q_values_batch.append(q_values[0])
        
    # Re-fit the model to move it closer to this newly calculated reward.
    self.model.fit(np.array(state_batch), np.array(q_values_batch), batch_size=batch_size, epochs=1, verbose=0)
       
    self.update_weights()
        
def update_weights(self):
    weights = self.model.get_weights()
    target_weights = self.target_model.get_weights()
    
    # Soft-update the target network towards the online network (Polyak averaging).
    for i in range(len(target_weights)):
        target_weights[i] = weights[i] * self.tau + target_weights[i] * (1 - self.tau)
        
    self.target_model.set_weights(target_weights)

And the full notebook is here.

One Answer

This is called "catastrophic forgetting" and can be a serious problem in many RL scenarios.

If you trained a neural network to recognise cats and dogs and did the following:

  • Train it for many epochs on a full dataset until you got a high accuracy.

  • Continue to train it, but remove all the cat pictures.

Then in a relatively short space of time, the NN would start to lose accuracy on cats. It would forget what a cat looks like, and learn that its task was simply to push the dog prediction as high as possible, because on average everything in the training population was a dog (a minimal sketch of this effect is shown below).
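To make that concrete, here is a minimal, hypothetical sketch of the same effect in Keras, using synthetic data as a stand-in for cat/dog images; the class sizes, network shape and epoch counts are all arbitrary choices for illustration:

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

rng = np.random.default_rng(0)

# Two synthetic "classes" drawn from different Gaussians (stand-ins for cats and dogs).
cats = rng.normal(loc=-1.0, size=(1000, 8))
dogs = rng.normal(loc=+1.0, size=(1000, 8))
X = np.vstack([cats, dogs])
y = np.concatenate([np.zeros(1000), np.ones(1000)])  # 0 = cat, 1 = dog

model = Sequential()
model.add(Dense(24, input_dim=8, activation="relu"))
model.add(Dense(1, activation="sigmoid"))
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Phase 1: train on the full dataset; cat accuracy ends up high.
model.fit(X, y, epochs=20, verbose=0)
print("cat accuracy after phase 1:", model.evaluate(cats, np.zeros(1000), verbose=0)[1])

# Phase 2: keep training, but only on dogs. Cat accuracy collapses, because
# predicting "dog" for every input now minimises the training loss.
model.fit(dogs, np.ones(1000), epochs=50, verbose=0)
print("cat accuracy after phase 2:", model.evaluate(cats, np.zeros(1000), verbose=0)[1])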

Something very similar happens in your DQN experience replay memory. Once it gets good at a task, it may only experience success. Eventually, only successful examples are in its memory. The NN forgets what failure looks like (what the states are, and what it should predict for their values), and predicts high values for everything.
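As a rough way to see this drift, you could log the composition of the replay memory as training progresses. The helper below is a hypothetical diagnostic (not part of the question’s code), assuming the memory stores the same (state, action, reward, next_state, done) tuples used in the question’s replay() method:

def failure_fraction(memory):
    # Fraction of stored transitions that terminate an episode. In CartPole the
    # done flag also fires at the 500-step limit, so this is only a rough proxy
    # for "failure", but it still shrinks once every episode runs long.
    if not memory:
        return 0.0
    terminal = sum(1 for (_, _, _, _, done) in memory if done)
    return terminal / len(memory)

# Logged once per episode, a value drifting down towards roughly 1/500 means the
# network is training almost exclusively on long, successful trajectories.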

When something bad does happen and the NN’s high predicted value is completely wrong, the error can be large, and the NN may have incorrectly "linked" its state representations so that it cannot distinguish which parts of the feature space are the cause. This creates odd effects in what it learns about the values of all states. Often the NN will behave incorrectly for a few episodes but then re-learn optimal behaviour; it is also possible, though, that it completely breaks and never recovers.

There is a lot of active research into catastrophic forgetting, and I suggest you search for that term to find some of the many mitigation techniques you could use.

For CartPole, I found a very simple hack made the learning very stable: keep aside some percentage of the replay memory, stocked with the initial poorly-performing random exploration. Reserving, say, 10% for this long-term memory is enough to make learning in CartPole rock solid, as the NN always has a few examples of what not to do (a sketch of the idea follows below). The idea unfortunately does not scale well to more complex environments, but it is a nice demonstration. For a more sophisticated look at similar solutions, see the paper "The importance of experience replay database composition in deep reinforcement learning".
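Here is a minimal sketch of that idea, assuming the same (state, action, reward, next_state, done) tuples as in the question; the class and parameter names are invented for illustration:

import random
from collections import deque

class SplitReplayMemory:
    def __init__(self, max_size=10000, long_term_fraction=0.1):
        self.long_term_size = int(max_size * long_term_fraction)
        self.long_term = []  # frozen early (mostly random, mostly failing) experience
        self.recent = deque(maxlen=max_size - self.long_term_size)  # rolling buffer

    def add(self, transition):
        # The first ~10% of experience is kept forever; everything after that
        # goes into the normal rolling buffer and is eventually overwritten.
        if len(self.long_term) < self.long_term_size:
            self.long_term.append(transition)
        else:
            self.recent.append(transition)

    def sample(self, batch_size):
        # Sample from the combined pool so old failures keep appearing in minibatches.
        pool = self.long_term + list(self.recent)
        return random.sample(pool, min(batch_size, len(pool)))

    def __len__(self):
        return len(self.long_term) + len(self.recent)

The agent’s memory deque could then be swapped for an instance of this class, with random.sample(self.memory, batch_size) in replay() becoming self.memory.sample(batch_size).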

Correct answer by Neil Slater on March 28, 2021
