rlDQNAgent
Description
The deep Q-network (DQN) algorithm is a model-free, online, off-policy, discrete action-space reinforcement learning method. A DQN agent is a value-based reinforcement learning agent that trains a critic to estimate the expected discounted cumulative long-term reward when following the optimal policy. DQN is a variant of Q-learning.
For more information, see Deep Q-Network (DQN) Agents. For more information on the different types of reinforcement learning agents, see Reinforcement Learning Agents.
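The Q-learning target that a DQN critic is trained toward can be sketched with plain MATLAB arrays. This is a minimal illustration, not the toolbox implementation: the mini-batch values, the variable names, and the precomputed target-critic outputs are all hypothetical. The sketch also shows the double DQN variant (the behavior selected by the UseDoubleDQN agent option), in which the online critic chooses the next action and the target critic evaluates it.

```matlab
% Hypothetical mini-batch: 3 experiences, 2 discrete actions
discount = 0.99;              % discount factor
R = [1; 0; -1];               % rewards
D = [0; 0; 1];                % is-done flags (1 = terminal step)

% Target-critic Q-values for the next observations
% (rows = experiences, columns = actions), hypothetical numbers
QnextTarget = [0.5 2.0; 1.0 0.2; 3.0 4.0];

% Standard DQN target: bootstrap from the largest target Q-value,
% except after a terminal step, where the target is just the reward
y = R + discount .* (1 - D) .* max(QnextTarget, [], 2);

% Double DQN variant: the online critic picks the next action,
% the target critic evaluates it (online values are hypothetical)
QnextOnline = [0.9 0.1; 0.3 0.8; 0.6 0.4];
[~, aStar] = max(QnextOnline, [], 2);
idx = sub2ind(size(QnextTarget), (1:numel(R))', aStar);
yDouble = R + discount .* (1 - D) .* QnextTarget(idx);

disp([y yDouble])
```

For this mini-batch, y is [2.98; 0.99; -1] and yDouble is [1.495; 0.198; -1]. The third experience is terminal, so both targets reduce to its reward.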
Creation
Syntax
Description
Create Agent from Observation and Action Specifications
agent = rlDQNAgent(observationInfo,actionInfo) creates a DQN agent for an environment with the given observation and action specifications, using default initialization options. The critic in the agent uses a default vector (that is, multi-output) Q-value deep neural network built from the observation specification observationInfo and the action specification actionInfo. The ObservationInfo and ActionInfo properties of agent are set to the observationInfo and actionInfo input arguments, respectively.
agent = rlDQNAgent(observationInfo,actionInfo,initOpts) creates a DQN agent for an environment with the given observation and action specifications. The agent uses a default network configured using options specified in the initOpts object. For more information on the initialization options, see rlAgentInitializationOptions.
Create Agent from Critic
agent = rlDQNAgent(critic) creates a DQN agent with the specified critic network using a default option set for a DQN agent.
Specify Agent Options
agent = rlDQNAgent(critic,agentOptions) creates a DQN agent with the specified critic network and sets the AgentOptions property to the agentOptions input argument. You can also specify agentOptions after the input arguments of any of the previous syntaxes.
Input Arguments
initOpts
— Agent initialization options
rlAgentInitializationOptions object
Agent initialization options, specified as an rlAgentInitializationOptions object.
critic
— Critic
rlQValueFunction object | rlVectorQValueFunction object
Critic, specified as an rlQValueFunction object or as the generally more efficient rlVectorQValueFunction object. For more information on creating critics, see Create Policies and Value Functions.
Your critic can use a recurrent neural network as its function approximator. However, only rlVectorQValueFunction supports recurrent neural networks. For an example, see Create DQN Agent with Recurrent Neural Network.
Properties
ObservationInfo
— Observation specifications
specification object | array of specification objects
This property is read-only.
Observation specifications, specified as an rlFiniteSetSpec or rlNumericSpec object or an array containing a mix of such objects. Each element in the array defines the properties of an environment observation channel, such as its dimensions, data type, and name.
If you create the agent by specifying the critic, the value of ObservationInfo matches the corresponding value specified in critic.
You can extract observationInfo from an existing environment or agent using getObservationInfo. You can also construct the specifications manually using rlFiniteSetSpec or rlNumericSpec.
ActionInfo
— Action specification
rlFiniteSetSpec object
Action specifications, specified as an rlFiniteSetSpec object. This object defines the properties of the environment action channel, such as its dimensions, data type, and name.
Note
Only one action channel is allowed.
If you create the agent by specifying a critic object, the value of ActionInfo matches the value specified in critic.
You can extract actionInfo from an existing environment or agent using getActionInfo. You can also construct the specification manually using rlFiniteSetSpec.
AgentOptions
— Agent options
rlDQNAgentOptions object
Agent options, specified as an rlDQNAgentOptions object.
If you create a DQN agent with a default critic that uses a recurrent neural network, the default value of AgentOptions.SequenceLength is 32.
ExperienceBuffer
— Experience buffer
rlReplayMemory object | rlPrioritizedReplayMemory object | rlHindsightReplayMemory object | rlHindsightPrioritizedReplayMemory object
Experience buffer, specified as an rlReplayMemory, rlPrioritizedReplayMemory, rlHindsightReplayMemory, or rlHindsightPrioritizedReplayMemory object.
Note
Agents with recurrent neural networks support only rlReplayMemory and rlHindsightReplayMemory buffers.
During training, the agent stores each of its experiences (S,A,R,S',D) in the buffer. Here:
S is the current observation of the environment.
A is the action taken by the agent.
R is the reward for taking action A.
S' is the next observation after taking action A.
D is the is-done signal after taking action A.
The agent then samples mini-batches of experiences from the buffer and uses these mini-batches to update its critic function approximator.
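To make the store-and-sample cycle concrete, here is a minimal sketch of a fixed-capacity circular buffer with uniform mini-batch sampling, written with plain MATLAB arrays. It is an illustration under stated assumptions, not how rlReplayMemory is implemented: the tiny capacity, the synthetic experiences (the reward equals the step index), and names such as capacity, nextIdx, and numStored are hypothetical.

```matlab
rng(0)                      % reproducible synthetic data
obsDim    = 4;              % e.g., cart-pole observation size
capacity  = 5;              % tiny capacity so the wrap-around is visible
batchSize = 3;

% Preallocated storage for (S,A,R,S',D)
S     = zeros(obsDim, capacity);
A     = zeros(1, capacity);
R     = zeros(1, capacity);
Snext = zeros(obsDim, capacity);
D     = false(1, capacity);
nextIdx   = 1;              % slot that receives the next experience
numStored = 0;              % number of valid experiences in the buffer

% Store 8 synthetic experiences; the 3 oldest are overwritten
for t = 1:8
    S(:, nextIdx)     = rand(obsDim, 1);
    A(nextIdx)        = randi(2);       % index of the action taken
    R(nextIdx)        = t;              % reward = step index (illustrative)
    Snext(:, nextIdx) = rand(obsDim, 1);
    D(nextIdx)        = (t == 8);       % last step ends the episode
    nextIdx   = mod(nextIdx, capacity) + 1;   % wrap around to slot 1
    numStored = min(numStored + 1, capacity);
end

% Uniformly sample a mini-batch (without replacement) of stored experiences
idx = randperm(numStored, batchSize);
batchR = R(idx);
disp(sort(R))               % rewards still in the buffer
```

Once the buffer is full, each new experience overwrites the oldest one, so after eight steps only the rewards 4 through 8 remain.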
UseExplorationPolicy
— Option to use exploration policy for simulation and deployment
false (default) | true
Option to use exploration policy when selecting actions during simulation or after deployment, specified as one of the following logical values.
true — Use the base agent exploration policy when selecting actions in sim and generatePolicyFunction. Specifically, in this case the agent uses the rlEpsilonGreedyPolicy. Since the action selection has a random component, the agent explores its action and observation spaces.
false — Force the agent to use the base agent greedy policy (the action with the maximum Q-value) when selecting actions in sim and generatePolicyFunction. Specifically, in this case the agent uses the rlMaxQPolicy policy. Since the action selection is greedy, the policy behaves deterministically and the agent does not explore its action and observation spaces.
Note
This option affects only simulation and deployment; it does not affect training. When you train an agent using train, the agent always uses its exploration policy independently of the value of this property.
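The difference between the two settings can be sketched with a vector of Q-values for a single observation. This is only an illustration of the idea behind rlEpsilonGreedyPolicy and rlMaxQPolicy, not their implementation. The Q-values, the epsilon value, and the number of draws are assumptions, and in this sketch the exploratory action is drawn uniformly from all elements, including the greedy one.

```matlab
rng(0)                         % reproducible random draws
elements = [-10 10];           % possible actions, as in the cart-pole examples
q = [0.2 0.7];                 % hypothetical Q-values for one observation
epsilon = 0.1;                 % hypothetical exploration probability

% Greedy selection (rlMaxQPolicy idea): action with the maximum Q-value
[~, iGreedy] = max(q);
greedyAction = elements(iGreedy);

% Epsilon-greedy selection (rlEpsilonGreedyPolicy idea), repeated many times
numDraws = 10000;
actions = zeros(1, numDraws);
for k = 1:numDraws
    if rand < epsilon
        % explore: pick any action uniformly at random
        actions(k) = elements(randi(numel(elements)));
    else
        % exploit: pick the greedy action
        actions(k) = greedyAction;
    end
end
fracGreedy = mean(actions == greedyAction);
fprintf("Greedy action: %d, chosen %.3f of the time\n", greedyAction, fracGreedy)
```

With two actions and epsilon = 0.1, the greedy action is returned about 95% of the time, because a random draw also picks it half of the time. The greedy branch alone always returns the same action, which is why the false setting behaves deterministically.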
SampleTime
— Sample time of agent
1 (default) | positive scalar | -1
Sample time of agent, specified as a positive scalar or as -1. Setting this parameter to -1 allows for event-based simulations.
Within a Simulink® environment, the RL Agent block in which the agent is specified executes every SampleTime seconds of simulation time. If SampleTime is -1, the block inherits the sample time from its parent subsystem.
Within a MATLAB® environment, the agent is executed every time the environment advances. In this case, SampleTime is the time interval between consecutive elements in the output experience returned by sim or train. If SampleTime is -1, the time interval between consecutive elements in the returned output experience reflects the timing of the event that triggers the agent execution.
Example: SampleTime=-1
Object Functions
train | Train reinforcement learning agents within a specified environment |
sim | Simulate trained reinforcement learning agents within specified environment |
getAction | Obtain action from agent, actor, or policy object given environment observations |
getActor | Extract actor from reinforcement learning agent |
setActor | Set actor of reinforcement learning agent |
getCritic | Extract critic from reinforcement learning agent |
setCritic | Set critic of reinforcement learning agent |
generatePolicyFunction | Generate MATLAB function that evaluates policy of an agent or policy object |
Examples
Create DQN Agent from Observation and Action Specifications
Create an environment with a discrete action space, and obtain its observation and action specifications. For this example, load the environment used in the example Create DQN Agent Using Deep Network Designer and Train Using Image Observations. This environment has two observations: a 50-by-50 grayscale image and a scalar (the angular velocity of the pendulum). The action is a scalar with five possible elements (a torque of -2, -1, 0, 1, or 2 Nm applied to a swinging pole).
% Load predefined environment
env = rlPredefinedEnv("SimplePendulumWithImage-Discrete");
% Obtain observation and action specifications
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
The agent creation function initializes the critic network randomly. You can ensure reproducibility by fixing the seed of the random generator.
rng(0)
Create a deep Q-network agent from the environment observation and action specifications.
agent = rlDQNAgent(obsInfo,actInfo);
To check your agent, use getAction to return the action from a random observation.
getAction(agent,{rand(obsInfo(1).Dimension),rand(obsInfo(2).Dimension)})
ans = 1x1 cell array
{[1]}
You can now test and train the agent within the environment.
Create DQN Agent Using Initialization Options
Create an environment with a discrete action space, and obtain its observation and action specifications. For this example, load the environment used in the example Create DQN Agent Using Deep Network Designer and Train Using Image Observations. This environment has two observations: a 50-by-50 grayscale image and a scalar (the angular velocity of the pendulum). The action is a scalar with five possible elements (a torque of either -2, -1, 0, 1, or 2 Nm applied to a swinging pole).
% Load predefined environment
env = rlPredefinedEnv("SimplePendulumWithImage-Discrete");
% Obtain observation and action specifications
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
Create an agent initialization option object, specifying that each hidden fully connected layer in the network must have 128 neurons (instead of the default number, 256).
initOpts = rlAgentInitializationOptions(NumHiddenUnit=128);
The agent creation function initializes the critic network randomly. Ensure reproducibility by fixing the seed of the random generator.
rng(0)
Create a DQN agent from the environment observation and action specifications.
agent = rlDQNAgent(obsInfo,actInfo,initOpts);
Extract the deep neural network from the critic.
criticNet = getModel(getCritic(agent));
To verify that each hidden fully connected layer has 128 neurons, you can display the layers in the MATLAB® Command Window:
criticNet.Layers
or visualize the structure interactively using analyzeNetwork.
analyzeNetwork(criticNet)
Plot the critic network.
plot(layerGraph(criticNet))
To check your agent, use getAction to return the action from random observations.
getAction(agent,{rand(obsInfo(1).Dimension),rand(obsInfo(2).Dimension)})
ans = 1x1 cell array
{[0]}
You can now test and train the agent within the environment.
Create DQN Agent Using a Multi-Output Critic
Create an environment interface and obtain its observation and action specifications. For this example, load the predefined environment used for the Train DQN Agent to Balance Cart-Pole System example. This environment has a continuous four-dimensional observation space (the positions and velocities of both cart and pole) and a discrete one-dimensional action space consisting of the application of two possible forces, -10 N or 10 N.
Create the predefined environment.
env = rlPredefinedEnv("CartPole-Discrete");
Get the observation and action specification objects.
obsInfo = getObservationInfo(env); actInfo = getActionInfo(env);
A DQN agent approximates the long-term reward, given observations and actions, using a parametrized Q-value function critic.
Since DQN agents have a discrete action space, you have the option to create a vector (that is, a multi-output) Q-value function critic, which is generally more efficient than a comparable single-output critic. A vector Q-value function is a mapping from an environment observation to a vector in which each element represents the expected discounted cumulative long-term reward when an agent starts from the state corresponding to the given observation and executes the action corresponding to the element number (and follows a given policy afterwards).
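A minimal sketch of the vector Q-value idea: a single evaluation returns one value per action, and the greedy action is the element with the largest value. A random linear map stands in for the deep neural network here; the weights, biases, and observation are hypothetical and unrelated to any trained critic.

```matlab
rng(0)
obsDim = 4;                          % cart-pole observation size
elements = [-10 10];                 % possible actions
W = randn(numel(elements), obsDim);  % hypothetical weights
b = zeros(numel(elements), 1);       % hypothetical biases
vectorQ = @(obs) W*obs + b;          % observation -> one Q-value per action

obs = rand(obsDim, 1);               % random observation
q = vectorQ(obs);                    % q(i) estimates the value of action elements(i)
[~, iBest] = max(q);
bestAction = elements(iBest);        % greedy action from a single evaluation
disp(q.'), disp(bestAction)
```

Because all action values come out of one forward pass, the greedy action costs one evaluation regardless of how many discrete actions there are.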
To model the Q-value function within the critic, use a deep neural network. The network must have one input layer (which receives the content of the observation channel, as specified by obsInfo) and one output layer (which returns the vector of values for all the possible actions).
Define the network as an array of layer objects, and get the dimensions of the observation space (that is, prod(obsInfo.Dimension)) and the number of possible actions (that is, numel(actInfo.Elements)) directly from the environment specification objects.
dnn = [
    featureInputLayer(prod(obsInfo.Dimension))
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(numel(actInfo.Elements))
    ];
Convert the network to a dlnetwork object and display the number of weights.
dnn = dlnetwork(dnn); summary(dnn)
   Initialized: true

   Number of learnables: 770

   Inputs:
      1   'input'   4 features
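The learnable count reported by summary can be checked by hand. Assuming the usual fully connected layer parameterization (an nOut-by-nIn weight matrix plus nOut biases, with no learnables in featureInputLayer or reluLayer), the three fully connected layers account for all 770 parameters. The helper fcParams is a hypothetical name used only for this check.

```matlab
% Learnables of one fully connected layer: weights plus biases
fcParams = @(nIn, nOut) nOut*nIn + nOut;

numObs = 4;   % prod(obsInfo.Dimension) for the cart-pole environment
numAct = 2;   % numel(actInfo.Elements): -10 N or 10 N

total = fcParams(numObs, 24) ...   % 4*24 + 24  = 120
      + fcParams(24, 24) ...       % 24*24 + 24 = 600
      + fcParams(24, numAct);      % 24*2 + 2   = 50
disp(total)                        % 770
```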
Create the critic using rlVectorQValueFunction, the network dnn, and the observation and action specifications.
critic = rlVectorQValueFunction(dnn,obsInfo,actInfo);
Check that the critic works with a random observation input.
getValue(critic,{rand(obsInfo.Dimension)})
ans = 2x1 single column vector
-0.0361
0.0913
Create the DQN agent using the critic.
agent = rlDQNAgent(critic)
agent = 
  rlDQNAgent with properties:

        ExperienceBuffer: [1x1 rl.replay.rlReplayMemory]
            AgentOptions: [1x1 rl.option.rlDQNAgentOptions]
    UseExplorationPolicy: 0
         ObservationInfo: [1x1 rl.util.rlNumericSpec]
              ActionInfo: [1x1 rl.util.rlFiniteSetSpec]
              SampleTime: 1
Specify agent options, including training options for the critic.
agent.AgentOptions.UseDoubleDQN=false;
agent.AgentOptions.TargetUpdateMethod="periodic";
agent.AgentOptions.TargetUpdateFrequency=4;
agent.AgentOptions.ExperienceBufferLength=100000;
agent.AgentOptions.DiscountFactor=0.99;
agent.AgentOptions.MiniBatchSize=256;
agent.AgentOptions.CriticOptimizerOptions.LearnRate=1e-2;
agent.AgentOptions.CriticOptimizerOptions.GradientThreshold=1;
To check your agent, use getAction to return the action from a random observation.
getAction(agent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
{[10]}
You can now test and train the agent within the environment.
Create a DQN Agent Using a Single-Output Critic
Create an environment interface and obtain its observation and action specifications. For this example, load the predefined environment used for the Train DQN Agent to Balance Cart-Pole System example. This environment has a continuous four-dimensional observation space (the positions and velocities of both cart and pole) and a discrete one-dimensional action space consisting of the application of two possible forces, -10 N or 10 N.
Create the predefined environment.
env = rlPredefinedEnv("CartPole-Discrete");
Get the observation and action specification objects.
obsInfo = getObservationInfo(env); actInfo = getActionInfo(env);
For DQN agents, you can use multi-output Q-value function critics, which are generally more efficient than comparable single-output critics. However, for this example, create a single-output Q-value function critic instead.
A Q-value function critic takes the current observation and an action as inputs and returns a single scalar as output (the estimated discounted cumulative long-term reward for taking the action from the state corresponding to the current observation, and following the policy thereafter).
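One practical consequence of a single-output critic is that selecting the greedy action requires one critic evaluation per possible action, instead of a single evaluation overall. The following minimal sketch shows that loop; the function handle singleQ is a hypothetical stand-in for the network, not a toolbox function, and the observation values are arbitrary.

```matlab
elements = [-10 10];                              % possible actions
singleQ = @(obs, act) -abs(sum(obs) - act/10);    % hypothetical Q(s,a)
obs = [0.5; 0.2; 0.1; 0.2];                       % example observation

q = zeros(1, numel(elements));
for i = 1:numel(elements)
    q(i) = singleQ(obs, elements(i));             % one critic call per action
end
[~, iBest] = max(q);
bestAction = elements(iBest);                     % greedy action
disp(q), disp(bestAction)
```

For a handful of discrete actions the extra evaluations are cheap, but the cost grows linearly with the number of actions, which is one reason the vector critic is generally more efficient.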
To model the parametrized Q-value function within the critic, use a neural network with two input layers (one for the observation channel, as specified by obsInfo, and the other for the action channel, as specified by actInfo) and one output layer (which returns the scalar value).
Note that prod(obsInfo.Dimension) and prod(actInfo.Dimension) return the number of dimensions of the observation and action spaces, respectively, regardless of whether they are arranged as row vectors, column vectors, or matrices.
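For example, prod collapses any dimension vector into an element count. The dimension vectors below are assumed examples (a cart-pole style column vector, a 50-by-50 image, and a row vector):

```matlab
dimColumn = [4 1];      % column-vector observation (cart-pole style)
dimImage  = [50 50];    % 50-by-50 grayscale image
dimRow    = [1 4];      % row-vector observation

nColumn = prod(dimColumn);   % 4
nImage  = prod(dimImage);    % 2500
nRow    = prod(dimRow);      % 4
fprintf("%d %d %d\n", nColumn, nImage, nRow)
```

The column and row vectors give the same count, which is why the same featureInputLayer size works for either layout.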
Define each network path as an array of layer objects, and assign names to the input and output layers of each path, so you can connect the paths.
% Observation path
obsPath = [
    featureInputLayer(prod(obsInfo.Dimension),Name="netOin")
    fullyConnectedLayer(24)
    reluLayer
    fullyConnectedLayer(24,Name="fcObsPath")
    ];

% Action path
actPath = [
    featureInputLayer(prod(actInfo.Dimension),Name="netAin")
    fullyConnectedLayer(24,Name="fcActPath")
    ];

% Common path (concatenate inputs along dim #1)
commonPath = [
    concatenationLayer(1,2,Name="cat")
    reluLayer
    fullyConnectedLayer(1,Name="out")
    ];

% Add paths to network
net = layerGraph;
net = addLayers(net,obsPath);
net = addLayers(net,actPath);
net = addLayers(net,commonPath);

% Connect layers
net = connectLayers(net,"fcObsPath","cat/in1");
net = connectLayers(net,"fcActPath","cat/in2");

% Plot network
plot(net)
% Convert to dlnetwork object
net = dlnetwork(net);

% Display the number of weights
summary(net)
   Initialized: true

   Number of learnables: 817

   Inputs:
      1   'netOin'   4 features
      2   'netAin'   1 features
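As with the multi-output critic, the reported count can be checked by hand, assuming each fully connected layer has nOut*nIn weights plus nOut biases and that the input, concatenation, and ReLU layers have no learnables. The concatenation layer joins the two 24-element path outputs, so the final fully connected layer sees 48 inputs. The helper fcParams is a hypothetical name used only for this check.

```matlab
% Learnables of one fully connected layer: weights plus biases
fcParams = @(nIn, nOut) nOut*nIn + nOut;

numObs = 4;   % prod(obsInfo.Dimension)
numAct = 1;   % prod(actInfo.Dimension)

obsPathParams    = fcParams(numObs, 24) + fcParams(24, 24);  % 120 + 600
actPathParams    = fcParams(numAct, 24);                     % 48
commonPathParams = fcParams(24 + 24, 1);                     % 49
total = obsPathParams + actPathParams + commonPathParams;
disp(total)                                                  % 817
```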
Create the critic approximator object using net, the environment observation and action specifications, and the names of the network input layers to be connected with the environment observation and action channels. For more information, see rlQValueFunction.
critic = rlQValueFunction(net, ...
    obsInfo, ...
    actInfo, ...
    ObservationInputNames="netOin", ...
    ActionInputNames="netAin");
Check the critic with a random observation and action input.
getValue(critic,{rand(obsInfo.Dimension)},{rand(actInfo.Dimension)})
ans = single
-0.0232
Create the DQN agent using the critic.
agent = rlDQNAgent(critic)
agent = 
  rlDQNAgent with properties:

        ExperienceBuffer: [1x1 rl.replay.rlReplayMemory]
            AgentOptions: [1x1 rl.option.rlDQNAgentOptions]
    UseExplorationPolicy: 0
         ObservationInfo: [1x1 rl.util.rlNumericSpec]
              ActionInfo: [1x1 rl.util.rlFiniteSetSpec]
              SampleTime: 1
Specify agent options, including training options for the critic.
agent.AgentOptions.UseDoubleDQN=false;
agent.AgentOptions.TargetUpdateMethod="periodic";
agent.AgentOptions.TargetUpdateFrequency=4;
agent.AgentOptions.ExperienceBufferLength=100000;
agent.AgentOptions.DiscountFactor=0.99;
agent.AgentOptions.MiniBatchSize=256;
agent.AgentOptions.CriticOptimizerOptions.LearnRate=1e-2;
agent.AgentOptions.CriticOptimizerOptions.GradientThreshold=1;
To check your agent, use getAction to return the action from a random observation.
getAction(agent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
{[10]}
You can now test and train the agent within the environment.
Create DQN Agent with Recurrent Neural Network
For this example, load the predefined environment used for the Train DQN Agent to Balance Cart-Pole System example. This environment has a continuous four-dimensional observation space (the positions and velocities of both cart and pole) and a discrete one-dimensional action space consisting of the application of two possible forces, -10 N or 10 N.
env = rlPredefinedEnv("CartPole-Discrete");
Get the observation and action specification objects.
obsInfo = getObservationInfo(env); actInfo = getActionInfo(env);
For DQN agents, only the vector function approximator, rlVectorQValueFunction, supports recurrent neural network models. The network must have one input layer (taking the content of the observation channel) and one output layer (returning the vector of values for all the possible actions).
Define the network as an array of layer objects. To create a recurrent neural network, use a sequenceInputLayer as the input layer and include at least one lstmLayer.
net = [
sequenceInputLayer(prod(obsInfo.Dimension))
fullyConnectedLayer(50)
reluLayer
lstmLayer(20,OutputMode="sequence");
fullyConnectedLayer(20)
reluLayer
fullyConnectedLayer(numel(actInfo.Elements))
];
Convert to a dlnetwork object and display the number of weights.
net = dlnetwork(net); summary(net);
   Initialized: true

   Number of learnables: 6.3k

   Inputs:
      1   'sequenceinput'   Sequence input with 4 dimensions
Create the critic approximator object using net and the environment specifications.
critic = rlVectorQValueFunction(net,obsInfo,actInfo);
Check your critic with a random input observation.
getValue(critic,{rand(obsInfo.Dimension)})
ans = 2x1 single column vector
0.0136
0.0067
Define some training options for the critic.
criticOptions = rlOptimizerOptions( ...
    LearnRate=1e-3, ...
    GradientThreshold=1);
Specify options for creating the DQN agent. To use a recurrent neural network, you must specify a SequenceLength greater than 1.
agentOptions = rlDQNAgentOptions( ...
    UseDoubleDQN=false, ...
    TargetSmoothFactor=5e-3, ...
    ExperienceBufferLength=1e6, ...
    SequenceLength=32, ...
    CriticOptimizerOptions=criticOptions);
agentOptions.EpsilonGreedyExploration.EpsilonDecay = 1e-4;
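TargetSmoothFactor controls how the target critic tracks the critic being trained. Below is a minimal sketch of the standard soft-update rule thetaTarget = tau*theta + (1 - tau)*thetaTarget, applied to hypothetical parameter vectors while the online parameters are held fixed (during real training they change at every step). With tau = 1, the same rule reduces to copying the critic parameters.

```matlab
tau = 5e-3;                  % smoothing factor, as in TargetSmoothFactor=5e-3
theta = ones(3, 1);          % hypothetical critic parameters (held fixed)
thetaTarget = zeros(3, 1);   % hypothetical target-critic parameters
numUpdates = 1000;
for k = 1:numUpdates
    % soft update: move the target a fraction tau toward the critic
    thetaTarget = tau*theta + (1 - tau)*thetaTarget;
end
% Starting from zero, the closed form is 1 - (1 - tau)^numUpdates
disp(thetaTarget.')          % each element is about 0.9933
```

A small tau makes the regression targets change slowly, which is the usual motivation for keeping a separate target critic.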
Create the agent using the critic and the agent options.
agent = rlDQNAgent(critic,agentOptions)
agent = 
  rlDQNAgent with properties:

        ExperienceBuffer: [1x1 rl.replay.rlReplayMemory]
            AgentOptions: [1x1 rl.option.rlDQNAgentOptions]
    UseExplorationPolicy: 0
         ObservationInfo: [1x1 rl.util.rlNumericSpec]
              ActionInfo: [1x1 rl.util.rlFiniteSetSpec]
              SampleTime: 1
Check your agent using getAction to return the action from a random observation.
getAction(agent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
{[-10]}
To evaluate the agent using sequential observations, use the sequence length (time) dimension. For example, obtain actions for a sequence of 9 observations.
[action,state] = getAction(agent, ...
{rand([obsInfo.Dimension 1 9])});
Display the action corresponding to the seventh element of the observation sequence.
action = action{1}; action(1,1,1,7)
ans = -10
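The sequence input used above follows the layout [obsInfo.Dimension 1 9], as in the getAction call: the observation dimensions first, then a batch dimension of size 1, then the time dimension. Below is a minimal sketch of assembling such an array from individual per-step observations; the values are synthetic (step k is filled with k) and the variable names are illustrative.

```matlab
obsDim = [4 1];                     % observation size, as for cart-pole
T = 9;                              % number of time steps
obsSeq = zeros([obsDim 1 T]);       % [4 1 1 9]: observation, batch, time
for k = 1:T
    obsSeq(:, :, 1, k) = k*ones(obsDim);   % synthetic observation for step k
end
x7 = obsSeq(:, :, 1, 7);            % observation at the seventh time step
disp(size(obsSeq)), disp(x7.')
```

The same indexing pattern, with the time step in the fourth position, is what action(1,1,1,7) uses to read the seventh action.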
You can now test and train the agent within the environment.
Version History
Introduced in R2019a
See Also
Functions
getAction | getActor | getCritic | getModel | generatePolicyFunction | generatePolicyBlock | getActionInfo | getObservationInfo
Objects
rlDQNAgentOptions | rlAgentInitializationOptions | rlVectorQValueFunction | rlQValueFunction | rlQAgent | rlSARSAAgent