Create Custom MATLAB Environment from Template

You can create simpler custom reinforcement learning environments using custom functions, as described in Create MATLAB Environment using Custom Functions.

Alternatively, you can define a custom reinforcement learning environment by creating and modifying a template environment class. You can use a custom template environment to:

  • Implement more complex environment dynamics.

  • Add custom visualizations to your environment.

  • Create an interface to third-party libraries defined in languages such as C++, Java®, or Python®. For more information, see External Language Interfaces (MATLAB).

For more information about creating MATLAB® classes, see User-Defined Classes (MATLAB).

Create Template Class

To define your custom environment, first create the template class file, specifying the name of the class. For this example, name the class MyEnvironment.

rlCreateEnvTemplate("MyEnvironment")

The software creates and opens the template class file. The template class is a subclass of the rl.env.MATLABEnvironment abstract class, as shown in the class definition at the start of the template file. This abstract class is the same one used by the other MATLAB reinforcement learning environment objects.

classdef MyEnvironment < rl.env.MATLABEnvironment

By default, the template class implements a simple cart-pole balancing model similar to the cart-pole predefined environments described in Load Predefined Control System Environments.

To define your environment dynamics, modify the template class by specifying the following:

  • Environment properties

  • Required environment methods

  • Optional environment methods

Environment Properties

In the properties section of the template, specify any parameters necessary for creating and simulating the environment. These parameters can include:

  • Physical constants — The sample environment defines the acceleration due to gravity (Gravity).

  • Environment geometry — The sample environment defines the cart and pole masses (CartMass and PoleMass) and the half-length of the pole (HalfPoleLength).

  • Environment constraints — The sample environment defines the pole angle and cart distance thresholds (AngleThreshold and DisplacementThreshold). The environment uses these values to detect when a training episode is finished.

  • Variables required for evaluating the environment — The sample environment defines the state vector (State) and a flag indicating when an episode is finished (IsDone).

  • Constants for defining the action or observation spaces — The sample environment defines the maximum force for the action space (MaxForce).

  • Constants for calculating the reward signal — The sample environment defines the constants RewardForNotFalling and PenaltyForFalling.

properties
    % Specify and initialize environment's necessary properties    
    % Acceleration due to gravity in m/s^2
    Gravity = 9.8
    
    % Mass of the cart
    CartMass = 1.0
    
    % Mass of the pole
    PoleMass = 0.1
    
    % Half the length of the pole
    HalfPoleLength = 0.5
    
    % Max Force the input can apply
    MaxForce = 10
           
    % Sample time
    Ts = 0.02
    
    % Angle at which to fail the episode (radians)
    AngleThreshold = 12 * pi/180
        
    % Distance at which to fail the episode
    DisplacementThreshold = 2.4
        
    % Reward each time step the cart-pole is balanced
    RewardForNotFalling = 1
    
    % Penalty when the cart-pole fails to balance
    PenaltyForFalling = -10 
end
    
properties
    % Initialize system state [x,dx,theta,dtheta]'
    State = zeros(4,1)
end

properties(Access = protected)
    % Initialize internal flag to indicate episode termination
    IsDone = false        
end

Required Functions

A reinforcement learning environment requires the following functions to be defined. The getObservationInfo, getActionInfo, sim, and validateEnvironment functions are already defined in the base abstract class. To create your environment, you must define the constructor, reset, and step functions.

Function               Description
getObservationInfo     Returns information about the environment observations
getActionInfo          Returns information about the environment actions
sim                    Simulates the environment with an agent
validateEnvironment    Validates the environment by calling the reset function and simulating the environment for one time step using step
reset                  Initializes the environment state and cleans up any visualization
step                   Applies an action, simulates the environment for one step, and outputs the observations and reward; also sets a flag indicating whether the episode is complete
Constructor function   A function with the same name as the class that creates an instance of the class
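
Because getObservationInfo and getActionInfo are inherited from the base abstract class, you can call them directly on any environment instance to query its specifications. A minimal sketch, assuming the MyEnvironment class from this example:

```matlab
env = MyEnvironment;

% Query the inherited specification functions
obsInfo = getObservationInfo(env);   % rlNumericSpec with Dimension [4 1]
actInfo = getActionInfo(env);        % rlFiniteSetSpec of discrete actions
```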

Sample Constructor Function

The sample cart-pole constructor function creates the environment by:

  • Defining the action and observation specifications. For more information about creating these specifications, see rlNumericSpec and rlFiniteSetSpec.

  • Calling the constructor of the base abstract class.

function this = MyEnvironment()
    % Initialize Observation settings
    ObservationInfo = rlNumericSpec([4 1]);
    ObservationInfo.Name = 'CartPole States';
    ObservationInfo.Description = 'x, dx, theta, dtheta';

    % Initialize Action settings   
    ActionInfo = rlFiniteSetSpec([-1 1]);
    ActionInfo.Name = 'CartPole Action';

    % The following line implements built-in functions of RL env
    this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);

    % Initialize property values and pre-compute necessary values
    updateActionInfo(this);
end

This sample constructor function does not include any input arguments. However, you can add input arguments for your custom constructor.
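
For instance, a constructor could accept physical parameters and assign them to properties after calling the base class constructor. A sketch, using hypothetical cartMass and poleMass arguments (not part of the template):

```matlab
function this = MyEnvironment(cartMass,poleMass)
    % Initialize observation and action settings as before
    ObservationInfo = rlNumericSpec([4 1]);
    ObservationInfo.Name = 'CartPole States';
    ActionInfo = rlFiniteSetSpec([-1 1]);
    ActionInfo.Name = 'CartPole Action';

    % Call the base class constructor before setting properties
    this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);

    % Set properties from the input arguments
    this.CartMass = cartMass;
    this.PoleMass = poleMass;
end
```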

Sample reset Function

The sample cart-pole reset function sets the initial condition of the model and returns the initial values of the observations. It also generates a notification that the environment has been updated by calling the notifyEnvUpdated function, which is useful for updating the environment visualization.

% Reset environment to initial state and output initial observation
function InitialObservation = reset(this)
    % Theta (+- .05 rad)
    T0 = 2 * 0.05 * rand - 0.05;  
    % Thetadot
    Td0 = 0;
    % X 
    X0 = 0;
    % Xdot
    Xd0 = 0;

    InitialObservation = [X0;Xd0;T0;Td0];
    this.State = InitialObservation;

    % (optional) use notifyEnvUpdated to signal that the 
    % environment has been updated (e.g. to update visualization)
    notifyEnvUpdated(this);
end

Sample step Function

The sample cart-pole step function:

  • Processes the input action

  • Evaluates the environment dynamic equations for one time step

  • Computes and returns the updated observations

  • Computes and returns the reward signal

  • Checks if the episode has ended and returns the IsDone signal as appropriate

  • Generates a notification that the environment has been updated

function [Observation,Reward,IsDone,LoggedSignals] = step(this,Action)
    LoggedSignals = [];

    % Get action
    Force = getForce(this,Action);            

    % Unpack state vector
    XDot = this.State(2);
    Theta = this.State(3);
    ThetaDot = this.State(4);

    % Cache to avoid recomputation
    CosTheta = cos(Theta);
    SinTheta = sin(Theta);            
    SystemMass = this.CartMass + this.PoleMass;
    temp = (Force + this.PoleMass*this.HalfPoleLength*ThetaDot^2*SinTheta)...
        /SystemMass;

    % Apply motion equations            
    ThetaDotDot = (this.Gravity*SinTheta - CosTheta*temp)...
        / (this.HalfPoleLength*(4.0/3.0 - this.PoleMass*CosTheta*CosTheta/SystemMass));
    XDotDot  = temp - this.PoleMass*this.HalfPoleLength*ThetaDotDot*CosTheta/SystemMass;

    % Euler integration
    Observation = this.State + this.Ts.*[XDot;XDotDot;ThetaDot;ThetaDotDot];

    % Update system states
    this.State = Observation;

    % Check terminal condition
    X = Observation(1);
    Theta = Observation(3);
    IsDone = abs(X) > this.DisplacementThreshold || abs(Theta) > this.AngleThreshold;
    this.IsDone = IsDone;

    % Get reward
    Reward = getReward(this);

    % (optional) use notifyEnvUpdated to signal that the 
    % environment has been updated (e.g. to update visualization)
    notifyEnvUpdated(this);
end

Optional Functions

You can define any other functions in your template class as required. For example, you can create helper functions that are called by either step or reset. The cart-pole template model implements a getReward function that computes the reward at each time step.

function Reward = getReward(this)
    if ~this.IsDone
        Reward = this.RewardForNotFalling;
    else
        Reward = this.PenaltyForFalling;
    end          
end
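
The sample step function also calls a getForce helper, and the constructor calls updateActionInfo; neither is shown above. One possible implementation, sketched here to be consistent with the MaxForce property and the action set defined in the constructor (the exact bodies in your generated template may differ):

```matlab
function Force = getForce(this,Action)
    % Validate the action and return the applied force
    if ~ismember(Action,this.ActionInfo.Elements)
        error('Action must be %g or %g.',-this.MaxForce,this.MaxForce);
    end
    Force = Action;
end

function updateActionInfo(this)
    % Scale the discrete action elements by the maximum force
    this.ActionInfo.Elements = this.MaxForce*[-1 1];
end
```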

Environment Visualization

You can add a visualization to your custom environment by implementing the plot function. In the plot function:

  • Create a figure or an instance of a visualizer class of your own implementation. For this example, create a figure and save the handle as a property of the environment.

  • Call the envUpdatedCallback function.

function plot(this)
    % Initiate the visualization
    this.h = figure;
          
    % Update the visualization
    envUpdatedCallback(this)
end
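
The plot function above stores the figure handle in a property named h, so the class must declare that property. A minimal sketch, assuming the name h used in this example:

```matlab
properties(Access = protected)
    % Handle to the visualization figure
    h
end
```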

In the envUpdatedCallback function, draw the visualization in the figure or update your custom visualizer object. For example, check whether the figure handle has been set. If it has, plot the visualization.

function envUpdatedCallback(this)
    % Set the visualization figure as the current figure
    figure(this.h)
    clf

    % Extract the cart position and pole angle
    X = this.State(1);
    theta = this.State(3);

    % Plot the cart
    cartpoly = polyshape([-0.25 -0.25 0.25 0.25],[-0.125 0.125 0.125 -0.125]);
    cartpoly = translate(cartpoly,[X 0]);
    plot(cartpoly,'FaceColor',[0.8500 0.3250 0.0980])
    hold on

    % Plot the pole
    L = this.HalfPoleLength*2;
    polepoly = polyshape([-0.1 -0.1 0.1 0.1],[0 L L 0]);
    polepoly = translate(polepoly,[X,0]);
    polepoly = rotate(polepoly,rad2deg(theta),[X,0]);
    plot(polepoly,'FaceColor',[0 0.4470 0.7410])

    hold off
    xlim([-3 3])
    ylim([-1 2])
end

The environment calls the envUpdatedCallback function, and therefore updates the visualization, whenever the environment is updated.

Create Custom Environment

Once you have defined your custom environment class, create an instance of it in the MATLAB workspace. At the command line, type:

env = MyEnvironment;

If your constructor has input arguments, specify them after the class name. For example, MyEnvironment(arg1,arg2).

After creating your environment, it is best practice to validate the environment dynamics. To do so, use the validateEnvironment function, which prints an error to the command window if there are any issues with your environment implementation.

validateEnvironment(env)

After validating the environment object, you can use it to train a reinforcement learning agent. For more information on training agents, see Train Reinforcement Learning Agents.
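
For example, you might create and train a default DQN agent from the environment specifications. This is an illustrative sketch; the agent type and training options are not part of the template:

```matlab
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

% Create a default DQN agent from the specifications
agent = rlDQNAgent(obsInfo,actInfo);

% Train for a limited number of episodes
trainOpts = rlTrainingOptions(MaxEpisodes=200,MaxStepsPerEpisode=500);
trainStats = train(agent,env,trainOpts);
```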
