
Import Policy and Value Function Representations

To create function approximators for reinforcement learning, you can import pretrained deep neural networks or deep neural network layer architectures using the Deep Learning Toolbox™ network import functionality. You can import:

  • Open Neural Network Exchange (ONNX™) models, which require the Deep Learning Toolbox Converter for ONNX Model Format support package software. For more information, see importONNXLayers.

  • TensorFlow™-Keras networks, which require the Deep Learning Toolbox Importer for TensorFlow-Keras Models support package software. For more information, see importKerasLayers.

  • Caffe convolutional networks, which require the Deep Learning Toolbox Importer for Caffe Models support package software. For more information, see importCaffeLayers.
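For illustration, the Keras and Caffe import functions are called in the same way as importONNXLayers. This is a minimal sketch: the file names myKerasModel.h5 and myCaffeModel.prototxt are hypothetical placeholders, and each call requires the corresponding support package listed above.

```matlab
% Hypothetical file names: replace them with the paths to your own model files.

% Keras: import the layer architecture from an HDF5 file and also
% import the stored weights.
kerasLayers = importKerasLayers('myKerasModel.h5','ImportWeights',true);

% Caffe: import the layer architecture from a .prototxt file.
caffeLayers = importCaffeLayers('myCaffeModel.prototxt');

% Inspect the imported layers before building a representation from them.
disp(kerasLayers)
disp(caffeLayers)
```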

After you import a deep neural network, you can use it to create a policy or value function representation object, such as rlValueRepresentation.

When you import deep neural network architectures, consider the following.

  • Imported architectures must have a single input layer and a single output layer. Therefore, importing entire critic networks with observation and action input layers is not supported.

  • The dimensions of the input and output layers of the imported network architecture must match the dimensions of the corresponding action, observation, or reward signals of your environment.

  • After importing the network architecture, you must set the names of the input and output layers to match the names of the corresponding action and observation specifications.

For more information on the deep neural network architectures supported for reinforcement learning, see Create Policy and Value Function Representations.

Import Actor and Critic for Image Observation Application

As an example, assume that you have an environment with a 50-by-50 grayscale image observation signal and a continuous action space. To train a policy gradient agent, you require the following function approximators, both of which must have a single 50-by-50 image input observation layer and a single scalar output value.

  • Actor — Selects an action value based on the current observation

  • Critic — Estimates the expected long-term reward based on the current observation
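If you want to experiment with this setup before an environment object exists, you can describe the assumed signals directly with rlNumericSpec. This is a sketch under stated assumptions: the observation is taken to be a single-channel 50-by-50 image with values in [0,1], the action is an unbounded scalar, and the names 'observation' and 'action' are hypothetical.

```matlab
% Observation: 50-by-50 grayscale image (one channel).
% The [0,1] pixel range is an assumption for this sketch.
obsInfo = rlNumericSpec([50 50 1],'LowerLimit',0,'UpperLimit',1);
obsInfo.Name = 'observation';   % hypothetical signal name

% Action: a single continuous value.
actInfo = rlNumericSpec([1 1]);
actInfo.Name = 'action';        % hypothetical signal name
```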

Also, assume that you have the following network architectures to import:

  • A deep neural network architecture for the actor with a 50-by-50 image input layer and a scalar output layer, which is saved in the ONNX format (actorNetwork.onnx).

  • A deep neural network architecture for the critic with a 50-by-50 image input layer and a scalar output layer, which is saved in the ONNX format (criticNetwork.onnx).

To import the critic and actor networks, use the importONNXLayers function without specifying an output layer.

criticNetwork = importONNXLayers('criticNetwork.onnx');
actorNetwork = importONNXLayers('actorNetwork.onnx');

These commands generate a warning, which states that the network is not trainable until an output layer is added. When you use an imported network to create an actor or critic representation, Reinforcement Learning Toolbox™ software automatically adds an output layer for you.

After you import the networks, create the actor and critic function approximator representations. To do so, first obtain the observation and action specifications from the environment.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
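Before creating the representations, you can check that the imported networks are compatible with these specifications. The following sketch assumes that the first layer of each imported network is its image input layer and the last layer is its output layer, as in the representation calls that follow. Note that if your environment reports the observation dimension as [50 50] rather than [50 50 1], the comparison flags a mismatch even though the sizes are equivalent.

```matlab
% The first layer of each imported network is assumed to be its input layer.
criticInputSize = criticNetwork.Layers(1).InputSize;
actorInputSize  = actorNetwork.Layers(1).InputSize;

% Both networks must accept observations with the dimensions in obsInfo.
if ~isequal(criticInputSize,obsInfo.Dimension) || ...
        ~isequal(actorInputSize,obsInfo.Dimension)
    warning('Imported input size does not match the observation specification.');
end

% Display the input and output layer names that are passed to the
% representation objects as observation and action names.
disp({criticNetwork.Layers(1).Name, criticNetwork.Layers(end).Name})
disp({actorNetwork.Layers(1).Name, actorNetwork.Layers(end).Name})
```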

Create the critic representation, specifying the name of the input layer of the critic network as the observation name. Since the critic network has a single observation input and a single scalar value output, use a value-function representation.

critic = rlValueRepresentation(criticNetwork,obsInfo,...
    'Observation',{criticNetwork.Layers(1).Name});
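As a quick sanity check, you can evaluate the critic on a sample observation. In this sketch, a random image with the dimensions from obsInfo stands in for a real observation; the returned value is meaningful only if the imported network was pretrained.

```matlab
% Random observation used only to exercise the critic.
obs = rand(obsInfo.Dimension);

% Observations are passed to the representation in a cell array.
v = getValue(critic,{obs});   % estimated long-term reward for obs
```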

Create the actor representation, specifying the name of the input layer of the actor network as the observation name and the name of the output layer of the actor network as the action name. Since the actor network has a single scalar output, use a deterministic actor representation.

actor = rlDeterministicActorRepresentation(actorNetwork,obsInfo,actInfo,...
    'Observation',{actorNetwork.Layers(1).Name},...
    'Action',{actorNetwork.Layers(end).Name});
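Similarly, you can query the actor for an action. This sketch again uses a random image as a stand-in observation. Depending on the release, getAction can return the action inside a cell array, so the code unwraps it only when needed.

```matlab
% Random observation used only to exercise the actor.
obs = rand(obsInfo.Dimension);

% Query the deterministic actor for the action at this observation.
act = getAction(actor,{obs});

% Unwrap the action if it is returned inside a cell array.
if iscell(act)
    a = act{1};
else
    a = act;
end
```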

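You can then use these representations to create an agent, or set them in an existing agent using the setActor and setCritic functions, respectively.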
