Train Agent to Play Turn-Based Game
This example shows you how to train a deep Q-network (DQN) reinforcement learning agent to play a turn-based multi-player game.
Overview
Environments in which agent execution is controlled by turn-based logic are known as turn-based environments. Turns are analogous to environment steps, and one or more agents may take actions at a given turn. For example, in a two-player game such as chess, each turn requires one player (or agent) to take an action.
In this example, you create a turn-based multi-agent environment using MATLAB® functions and train an agent to play against a policy that takes random actions.
Environment
The environment in this example is a simple turn-based two-player game. For this environment:
The game contains a 3-by-3 grid with nine cells.
Two adversarial players (agents) play the game. At each turn, a player marks a location in the grid. Player one has red squares and player two has blue circles.
The observation of each agent is a vector obtained from the 3-by-3 grid by flattening it in column-major fashion (see the sketch after this list). The vector contains zeros in unmarked locations, ones in locations marked by the agent, and negative ones in locations marked by the opponent agent.
The action of an agent is the grid location to be marked (an integer between one and nine, inclusive).
The game is won if an agent has marked three consecutive cells (a horizontal row, vertical column, or any diagonal).
The reward for the active agent is 10 if the agent wins the game, or 1 if the agent marks two consecutive grid cells with the third cell in that line still unmarked.
The opponent agent receives the negative reward of the active agent.
If an agent takes an illegal action, such as marking an already occupied cell, it receives a reward of -10 and the simulation terminates.
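The following minimal sketch illustrates how a 3-by-3 board is flattened column-major into a nine-element observation vector, with the two agents seeing the same board with opposite signs, as in the resetGame and stepGame functions at the end of this example. The variable names here are only illustrative.
% A board with three marked cells: nonzero entries are marks,
% and the two players use opposite signs.
board = [ 1  0  0
          0 -1  0
          0  0  1 ];
obsAgentA = board(:)      % column-major flattening, 9-by-1 vector
obsAgentB = -board(:)     % the other agent sees the board with signs flipped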
First, create the observation specification for each agent. The observation is a vector with nine elements.
oinfo = rlNumericSpec([9,1]);
Next, create the action specification for each agent, which is a set of discrete cell indices.
ainfo = rlFiniteSetSpec({1;2;3;4;5;6;7;8;9});
For this multi-agent environment, the observation and action specifications are cell arrays with one element per agent. Create the arrays.
obsInfo = {oinfo, oinfo};
actInfo = {ainfo, ainfo};
Create a turn-based environment using rlTurnBasedFunctionEnv. This command takes as input arguments the observation and action specifications of the agents and the function handles for the step and reset operations. Create the environment object. The functions stepGame and resetGame are provided at the end of this example.
env = rlTurnBasedFunctionEnv(obsInfo, actInfo, @stepGame, @resetGame)
env = 
  rlTurnBasedFunctionEnv with properties:

     StepFcn: @stepGame
    ResetFcn: @resetGame
        Info: [1x1 struct]
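As an optional sanity check, you can query the specifications back from the environment object. This is a minimal sketch; getObservationInfo and getActionInfo are standard Reinforcement Learning Toolbox queries and are assumed here to apply to turn-based environments as well.
% Retrieve the observation and action specifications stored in the environment.
getObservationInfo(env)
getActionInfo(env)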
Add a listener for changes to the Info property of the environment. When the Info property is updated during simulation, the function hPlotGame plots the updated environment. To turn off the visualization, set doPlot to false.
doPlot = true;
if doPlot
    addlistener(env,"Info","PostSet",@(~,~) hPlotGame(env));
end
Agents
Fix the random seed for reproducibility.
rng(0);
Two adversarial agents play the game.
An agent generating random allowed actions controls the blue circles in the game. This agent is not trained.
A deep Q-network (DQN) agent controls the red squares in the game. You train this agent to win against the random policy.
Create a custom agent object using the AgentWithRandomActions class. For more information, see Create Custom Reinforcement Learning Agents.
randomAgent = AgentWithRandomActions(oinfo,ainfo);
View the implementation of AgentWithRandomActions.
type("AgentWithRandomActions.m")
classdef AgentWithRandomActions < rl.agent.CustomAgent
    % AgentWithRandomActions models a custom agent with random actions.

    % Copyright 2023 The MathWorks, Inc.

    properties (Access = private)
        SampleTime_
    end

    methods
        function this = AgentWithRandomActions(obsinfo, actinfo)
            setObservationInfo_(this, obsinfo);
            setActionInfo_(this,actinfo);
            this.SampleTime = 1;
        end
    end

    methods (Access = protected)
        function learnImpl(~,~)
            % no op because the agent does not learn
        end
        function action = getActionWithExplorationImpl(this, obs)
            action = getActionImpl(this, obs);
        end
        function action = getActionImpl(~, obs)
            % Generate random actions.
            x = obs{1};
            % legal moves are positions which are unmarked (value is 0)
            legalmoves = find(x==0);
            % choose a random legal move
            randidx = randperm(numel(legalmoves),1);
            action = legalmoves(randidx);
        end
        function ts = getSampleTime_(this)
            ts = this.SampleTime_;
        end
        function this = setSampleTime_(this,ts)
            this.SampleTime_ = ts;
        end
    end
end
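As a quick check of the custom agent, you can request an action for a sample observation. The all-zeros vector below is only an illustration and represents an empty board.
% Sample observation: an empty board (all cells unmarked).
sampleObs = {zeros(9,1)};
% The random agent should return a legal (unmarked) cell index.
sampleAct = getAction(randomAgent, sampleObs)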
Next, specify options for training the DQN agent. The agent learns using the standard deep Q-network algorithm (by setting UseDoubleDQN to false) with a mini-batch size of 64 and a discount factor of 0.99, which favors long-term rewards. For more information, see rlDQNAgentOptions.
agentOpts = rlDQNAgentOptions( ...
    UseDoubleDQN=false, ...
    MiniBatchSize=64, ...
    DiscountFactor=0.99);
Specify options for the critic optimization algorithm. The critic learns using the adam algorithm with a learning rate of 1e-4.
agentOpts.CriticOptimizerOptions.Algorithm = "adam";
agentOpts.CriticOptimizerOptions.LearnRate = 1e-4;
Specify exploration options for training. The agent uses an epsilon-greedy exploration strategy with an initial Epsilon value of 0.9, which decays exponentially at a rate of 1e-4 until it reaches a minimum value of 0.01.
agentOpts.EpsilonGreedyExploration.Epsilon = 0.9;
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-4;
agentOpts.EpsilonGreedyExploration.EpsilonMin = 0.01;
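To get a rough sense of how quickly exploration anneals, you can sketch the epsilon schedule. This assumes a multiplicative update of the form Epsilon = Epsilon*(1-EpsilonDecay), applied once per learning step until EpsilonMin is reached; the variable names below are illustrative.
% Approximate epsilon schedule under the assumed multiplicative decay.
eps0 = 0.9;                % initial epsilon
decay = 1e-4;              % epsilon decay rate
epsMin = 0.01;             % minimum epsilon
k = 0:1e5;                 % learning steps
epsSchedule = max(epsMin, eps0*(1-decay).^k);
figure
plot(k, epsSchedule)
xlabel("Learning step")
ylabel("Epsilon")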
Set the length of the experience buffer to 1e6. This ensures that the agent learns from a large set of experiences.
agentOpts.ExperienceBufferLength = 1e6;
Create initialization options for the agent. The agent's critic uses a neural network model with a hidden layer size of 256.
initOpts = rlAgentInitializationOptions(NumHiddenUnit=256);
Create a default DQN agent object. For more information, see rlDQNAgent.
agent = rlDQNAgent(oinfo, ainfo, initOpts, agentOpts);
View the critic network of the agent.
critic = getCritic(agent);
criticNet = getModel(critic);
plot(layerGraph(criticNet));
View a summary of the critic network.
summary(criticNet);
   Initialized: true

   Number of learnables: 70.6k

   Inputs:
      1   'input_1'   9 features
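Optionally, you can query the untrained critic for the Q-values of all nine actions at a sample observation. This sketch assumes the default critic is a vector Q-value function, which takes an observation and returns one value per action; since the agent is untrained, the values are essentially arbitrary at this point.
% Q-values of the nine actions for an empty board.
sampleQ = getValue(critic, {zeros(9,1)})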
Training
In this example, only the agent controlling the red squares is trained.
To evaluate the simulation behavior of the agent during training, configure an evaluator object that periodically computes an evaluation score. In this example, the evaluator object runs 25 evaluation simulations every 100 training episodes with different random seeds and computes the minimum episode reward as the evaluation score. You can use the evaluation score as a criterion for saving agents during training. For more information on evaluation objects, see rlEvaluator.
evaluator = rlEvaluator( ...
    EvaluationFrequency=100, ...
    EvaluationStatisticType="MinEpisodeReward", ...
    NumEpisodes=25, ...
    RandomSeeds=0:24);
For this training session:
Run the training for 5000 episodes.
Specify AgentGroups as a cell array of agent indices. Omit the index of the agent that is not trained. Note that the index must correspond to the order of agents specified in the train function.
Compute the average episodic rewards using a moving window of size 20.
Save the agents from the episodes where the evaluation score is 9 or higher. This score is close to the reward received (10) when the game is won and may indicate that the policy has learned sufficiently.
trainOpts = rlMultiAgentTrainingOptions( ...
    AgentGroups={1}, ...
    MaxEpisodes=5000, ...
    ScoreAveragingWindowLength=20, ...
    StopTrainingCriteria="none", ...
    SaveAgentCriteria="EvaluationStatistic", ...
    SaveAgentValue=9);
Train the agents using the train function. Training can take several hours to complete depending on the available computational power. To save time, load the MAT-file twoPlayerGameAgent.mat, which contains a set of pretrained agents. To train the agents yourself, set doTraining to true.
doTraining = false;
if doTraining
    trainResults = train([agent,randomAgent], env, ...
        trainOpts, Evaluator=evaluator);
else
    load("twoPlayerGameAgent.mat");
end
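If you run the training yourself, you can save the resulting agents for later use. This is a minimal sketch; the file name below is an assumption, chosen so that the pretrained MAT-file is not overwritten.
% Save the trained agent and its random opponent after training.
if doTraining
    save("myTwoPlayerGameAgent.mat","agent","randomAgent","trainResults");
end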
A snapshot of the training progress is shown in the following figure. You may see different results from your training process.
Simulation
Reset the random number generator for reproducibility.
rng(0);
Simulate the trained agent with the environment. For more information on agent simulation, see rlSimulationOptions and sim.
simOptions = rlSimulationOptions();
experience = sim(env,[agent,agent],simOptions);
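You can inspect the output of sim to verify the outcome of the game. The sketch below assumes that sim returns one experience structure per agent and that each Reward field is stored as timeseries data; the variable names are illustrative.
% Cumulative reward collected by each player during the simulated game.
totalRewardPlayer1 = sum(experience(1).Reward.Data)
totalRewardPlayer2 = sum(experience(2).Reward.Data)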
The visualization shows the sequence of turns for the game. The trained agent has learned the actions to win the game.
Local Functions
Custom reset function for the two-player game. The reset function initializes the environment, randomly selects an agent for the first turn, and returns the initial observations of the agents along with the info variable that passes information between simulation steps.
function [initialObs, info] = resetGame()
% Reset the two-player game.

% The state is a 3x3 matrix with 0s in unmarked cells,
% -1s in cells marked with squares and
% 1s in cells marked with circles.
info.State = zeros(3,3);

% The initial turn is randomly selected.
info.ActiveAgentIndex = randperm(2,1);

% The current environment step count.
info.StepCount = 0;

% Flag to keep track of invalid action.
info.IsInvalidAction = false;

% The initial observation is the state
% of the environment as seen by each agent.
initialObs = {info.State(:), -info.State(:)};
end
Custom step function for the two-player game. This function advances the environment dynamics to the next state.
function [nextobs, reward, isdone, info] = stepGame(action, info)
% Get the active agent for the current turn.
agentIdx = info.ActiveAgentIndex;

% Player cell values:
% Player 1 (red square) is identified by the value -1.
% Player 2 (blue circle) is identified by the value 1.
playerVals = [-1,1];

% Get current player value.
currentVal = playerVals(agentIdx);

% Index of cell for the current player.
% action{1} is the move of the current player,
% that is, the index (between 1 and 9)
% of new grid location to be marked.
playerCell = action{1};

% Advance environment to the next state.
state = info.State;
if state(playerCell) == 0
    % Move to next state.
    state(playerCell) = currentVal;
    % Compute reward and terminal condition.
    [rwd, isdone] = hComputeRewardAndIsDone( ...
        state, playerCell, currentVal);
else
    % For an illegal action,
    % terminate with a large penalty.
    rwd = -10;
    isdone = true;
    info.IsInvalidAction = true;
end

% Compute the next agent index.
nextIdx = mod(agentIdx,2) + 1;

% The active agent receives the reward rwd,
% and the other agent receives the reward -rwd.
reward = [0,0];
reward(agentIdx) = rwd;
reward(nextIdx) = -rwd;

% Next observation.
nextobs = {state(:), -state(:)};

% Update the info structure with:
% 1. The state for the next step.
% 2. The agent turn for the next step.
% 3. The step count.
info.State = state;
info.ActiveAgentIndex = nextIdx;
info.StepCount = info.StepCount + 1;
end
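To exercise the environment dynamics by hand, you can call the reset and step functions directly. This is a minimal sketch with illustrative variable names; it assumes resetGame and stepGame are available on the path (for example, saved in their own files), because script code cannot follow local function definitions.
% Play a single turn manually: reset the game, then mark the center cell.
[obs0, gameInfo] = resetGame();
disp(reshape(obs0{1},3,3))                 % empty board as seen by agent 1
[obs1, rwd, done, gameInfo] = stepGame({5}, gameInfo);
disp(reshape(obs1{1},3,3))                 % board after the first move
disp(rwd)                                  % zero-sum reward pair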
Helper function to compute reward and terminal conditions.
function [reward, isover] = hComputeRewardAndIsDone( ...
    state, index, player)
% Advance to the next state of the game by marking a cell.
%
% - If the agent marks three adjacent cells, reward is +10.
% - If the agent marks two adjacent cells with the third
%   cell in line unmarked then reward is +1.
% - Otherwise, reward is 0.

sz = size(state);
reward = 0;
isover = false;

% Current row and column of active agent.
[r,c] = ind2sub(sz, index);

    function nestedComputeReward_(arr)
        % arr is a row or column or a diagonal.
        if all(arr==player)
            % Row/column/diagonal complete, game over.
            reward = reward + 10;
            isover = true;
        elseif sum(arr==player)==2 && any(arr==0)
            % Player marks adjacent cells,
            % possible win next turn.
            reward = reward + 1;
        end
    end

% Horizontal cells.
hcells = state(r,:);
nestedComputeReward_(hcells);

% Vertical cells.
vcells = state(:,c);
nestedComputeReward_(vcells);

% Check the two diagonals:
% r==c is the main diagonal,
% r+c==4 is the other diagonal.
if r==c || r+c==4
    % Main diagonal cells.
    mdcells = [state(1,1) state(2,2) state(3,3)];
    nestedComputeReward_(mdcells);
    % Other diagonal cells.
    odcells = [state(3,1) state(2,2) state(1,3)];
    nestedComputeReward_(odcells);
end

% If all cells are marked the game is over.
if ~any(state==0)
    isover = true;
end
end
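As a quick check of the reward helper, consider a board in which the player using the value -1 has just completed the top row by marking cell (1,3), which is linear index 7. The expected output is a reward of 10 with the game over. As with the previous sketch, this assumes the helper is accessible, for example by saving it in its own file; the variable names are illustrative.
% Verify the winning condition for a completed top row.
winState = [-1 -1 -1
             0  1  0
             1  0  0 ];
[rwd, gameOver] = hComputeRewardAndIsDone(winState, 7, -1)   % expect rwd = 10, gameOver = true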