Train DQN Agent to Balance Cart-Pole System

This example shows how to train a deep Q-learning network (DQN) agent to balance a cart-pole system modeled in MATLAB®.

For more information on DQN agents, see Deep Q-Network Agents. For an example that trains a DQN agent in Simulink®, see Train DQN Agent to Swing Up and Balance Pendulum.

Cart-Pole MATLAB Environment

The reinforcement learning environment for this example is a pole attached to an unactuated joint on a cart, which moves along a frictionless track. The training goal is to make the pole stand upright without falling over.

For this environment:

  • The upward balanced pole position is 0 radians, and the downward hanging position is pi radians.

  • The pole starts upright with an initial angle between –0.05 and 0.05 radians.

  • The force action signal from the agent to the environment is from –10 to 10 N.

  • The observations from the environment are the position and velocity of the cart, the pole angle, and the pole angle derivative.

  • The episode terminates if the pole is more than 12 degrees from vertical or if the cart moves more than 2.4 m from the original position.

  • A reward of +1 is provided for every time step that the pole remains upright. A penalty of –5 is applied when the pole falls.

For more information on this model, see Load Predefined Control System Environments.
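
The termination and reward rules listed above can be sketched in a few lines of plain MATLAB. This is only an illustration of the stated rules, not the environment's actual implementation; the variable names and the example state are hypothetical.

```matlab
% Hypothetical illustration of the termination and reward rules listed above.
thetaThreshold = 12*pi/180;   % 12 degrees expressed in radians (about 0.2094)
xThreshold     = 2.4;         % maximum cart displacement in meters

state = [0.5; 0; 0.25; 0];    % example values for [x; dx; theta; dtheta]
x     = state(1);
theta = state(3);

% The episode ends when either limit is exceeded.
isDone = abs(theta) > thetaThreshold || abs(x) > xThreshold;

% +1 while the pole stays within limits, -5 when it falls.
if isDone
    reward = -5;
else
    reward = 1;
end
```

In this example the pole angle of 0.25 rad exceeds the 12-degree limit, so the episode ends and the penalty applies.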

Create Environment Interface

Create a predefined environment interface for the system.

env = rlPredefinedEnv("CartPole-Discrete")
env = 
  CartPoleDiscreteAction with properties:

                  Gravity: 9.8000
                 MassCart: 1
                 MassPole: 0.1000
                   Length: 0.5000
                 MaxForce: 10
                       Ts: 0.0200
    ThetaThresholdRadians: 0.2094
               XThreshold: 2.4000
      RewardForNotFalling: 1
        PenaltyForFalling: -5
                    State: [4x1 double]

The interface has a discrete action space where the agent can apply one of two possible force values to the cart, –10 or 10 N.

Get the observation and action specification information.

obsInfo = getObservationInfo(env)
obsInfo = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "CartPole States"
    Description: "x, dx, theta, dtheta"
      Dimension: [4 1]
       DataType: "double"

actInfo = getActionInfo(env)
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-10 10]
           Name: "CartPole Action"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"
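
The specification objects can also be queried programmatically, for example to size the critic network. The following is a minimal sketch, assuming Reinforcement Learning Toolbox is installed; numObs and numActions are names introduced here for illustration.

```matlab
% Recreate the environment and read its specifications.
env     = rlPredefinedEnv("CartPole-Discrete");
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

numObs     = obsInfo.Dimension(1);      % number of observation elements
numActions = numel(actInfo.Elements);   % number of discrete force values
```

Given the properties displayed above, numObs should be 4 and numActions should be 2.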

Fix the random generator seed for reproducibility.
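
One way to do this is with the standard rng function. The seed value 0 below is an assumption; any fixed seed gives repeatable results.

```matlab
rng(0)   % assumed seed value; fixes the global random number stream
```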


Create DQN Agent

A DQN agent approximates the long-term reward, given observations and actions, using a value-function critic.

DQN agents can use multi-output Q-value critic approximators, which are generally more efficient. A multi-output approximator has observations as inputs and state-action values as outputs. Each output element represents the expected cumulative long-term reward for taking the corresponding discrete action from the state indicated by the observation inputs.
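
As a rough illustration of how a multi-output critic is used, the sketch below picks the greedy action from a vector of Q-values. The Q-values are hypothetical numbers, and the sketch assumes that the outputs are ordered like the elements of the action specification.

```matlab
actions = [-10 10];     % discrete force values in newtons
qValues = [3.2 4.7];    % hypothetical critic outputs, one per action

% The greedy action is the one with the largest estimated long-term reward.
[bestQ, idx] = max(qValues);
greedyAction = actions(idx);
```

One forward pass through the network yields the value of every action, which is why this form is generally cheaper than evaluating a single-output critic once per action.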

To create the critic, first create a deep neural network with one input (the 4-dimensional observed state) and one output vector with two elements (one for the 10 N action, another for the –10 N action). For more information on creating value-function representations based on a neural network, see Create Policy and Value Function Representations.

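dnn = [
    featureInputLayer(4,'Name','state')                 % assumed input layer for the 4-element observation
    fullyConnectedLayer(24, 'Name','CriticStateFC2')
    reluLayer('Name','CriticRelu')                      % assumed activation
    fullyConnectedLayer(2,'Name','output')];            % assumed output layer, one Q-value per action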

View the network configuration.
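
A possible way to do this, assuming dnn is the layer array from the previous step and Deep Learning Toolbox is available, is to wrap the layers in a layer graph and plot it. The variable name lgraph is introduced here for illustration.

```matlab
% Assumes dnn is the layer array defined in the previous step.
lgraph = layerGraph(dnn);   % connect the layers in sequence
plot(lgraph)                % draw the layers and their connections
```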


Specify some training options for the critic representation using rlRepresentationOptions.

criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);

Create the critic representation using the specified neural network and options. For more information, see rlQValueRepresentation.

critic = rlQValueRepresentation(dnn,obsInfo,actInfo,'Observation',{'state'},criticOpts);

To create the DQN agent, first specify the DQN agent options using rlDQNAgentOptions.

agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...
    'TargetSmoothFactor',1, ...
    'TargetUpdateFrequency',4, ...
    'ExperienceBufferLength',100000, ...
    'DiscountFactor',0.99);

Then, create the DQN agent using the specified critic representation and agent options. For more information, see rlDQNAgent.

agent = rlDQNAgent(critic,agentOpts);
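
To give a feel for what the discount factor and target options control, here is a simplified sketch in plain MATLAB. It is not the toolbox implementation: the parameter vectors, Q-values, and the stand-in "gradient step" are hypothetical, and only the standard DQN target and a periodic target update are shown.

```matlab
gamma       = 0.99;      % DiscountFactor
tau         = 1;         % TargetSmoothFactor
updateEvery = 4;         % TargetUpdateFrequency

% Standard DQN learning target for one transition.
reward      = 1;         % reward for the step
qNextTarget = [10 12];   % hypothetical target-critic outputs for the next state
isDone      = false;     % terminal transitions do not bootstrap
y = reward + gamma * max(qNextTarget) * ~isDone;

% Periodic target update: with tau = 1 the target parameters are
% overwritten by the critic parameters every updateEvery steps.
criticParams = [0.5 -0.2];   % hypothetical critic parameters
targetParams = [0 0];        % hypothetical target-critic parameters
for step = 1:8
    criticParams = criticParams + 0.1;   % stand-in for a gradient step
    if mod(step, updateEvery) == 0
        targetParams = tau*criticParams + (1 - tau)*targetParams;
    end
end
```

With a smoothing factor of 1, each update copies the critic parameters outright, so the combination used in this example amounts to a hard target update every four steps.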

Train Agent

To train the agent, first specify the training options. For this example, use the following options:

  • Run one training session containing at most 1000 episodes, with each episode lasting at most 500 time steps.

  • Display the training progress in the Episode Manager dialog box (set the Plots option) and disable the command line display (set the Verbose option to false).

  • Stop training when the agent receives a moving average cumulative reward greater than 480. At this point, the agent can balance the cart-pole system in the upright position.

For more information, see rlTrainingOptions.
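
The stopping rule described above can be pictured with a small stand-alone calculation. The episode rewards and the window length of 5 are assumptions chosen for illustration; this is a sketch of the idea, not the toolbox's internal logic.

```matlab
episodeRewards = [120 300 470 485 490 495 500 500];   % hypothetical rewards
windowLength   = 5;      % assumed averaging window, in episodes
stopValue      = 480;    % threshold from the description above

% Average the most recent episodes, or all of them if fewer are available.
n          = numel(episodeRewards);
recent     = episodeRewards(max(1, n - windowLength + 1):n);
avgReward  = mean(recent);

stopTraining = avgReward > stopValue;
```

Averaging over several episodes keeps a single lucky episode from ending training early.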

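trainOpts = rlTrainingOptions(...
    'MaxEpisodes',1000, ...
    'MaxStepsPerEpisode',500, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'StopTrainingCriteria','AverageReward', ...
    'StopTrainingValue',480);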

You can visualize the cart-pole system by using the plot function during training or simulation.

plot(env)

Train the agent using the train function. Training this agent is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load the pretrained agent for the example.
end

Simulate DQN Agent

To validate the performance of the trained agent, simulate it within the cart-pole environment. For more information on agent simulation, see rlSimulationOptions and sim. The agent can balance the cart-pole even when the simulation time increases to 500 steps.

simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);

totalReward = sum(experience.Reward)
totalReward = 500
