# Train Agent to Play Turn-Based Game

This example shows you how to train a deep Q-network (DQN) reinforcement learning agent to play a turn-based multi-player game.

### Overview

Environments in which turn-based logic controls agent execution are known as turn-based environments. Turns are analogous to environment steps, and one or more agents can take actions in a given turn. For example, in a two-player game such as chess, each turn requires one player (or agent) to take an action.

In this example, you create a turn-based multi-agent environment using MATLAB® functions and train an agent to play against a policy that takes random actions.

### Environment

The environment in this example is a simple turn-based two-player game. For this environment:

• The game contains a 3-by-3 grid with nine cells.

• Two adversarial players (agents) play the game. At each turn, a player marks a location in the grid. Player one has red squares and player two has blue circles.

• The observation of each agent is a vector obtained by flattening the 3-by-3 grid in column-major order. The vector contains zeros in unmarked locations, negative ones in locations marked by the agent, and ones in locations marked by the opponent agent.

• The action of an agent is the grid location to be marked (an integer between one and nine, inclusive).

• The game is won if an agent marks all three cells of a horizontal row, a vertical column, or either diagonal.

• The reward for the active agent is 10 if the move wins the game, plus one for each line (row, column, or diagonal) through the marked cell that now contains two of the agent's marks and one unmarked cell.

• The opponent agent receives the negative reward of the active agent.

• If an agent makes an illegal move, such as marking an already occupied cell, it receives a reward of `-10` and the simulation terminates.
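The observation and action encoding can be sketched in plain MATLAB, without the toolbox. This sketch assumes the cell values used by the local functions `resetGame` and `stepGame` at the end of this example (`-1` for player one's squares, `1` for player two's circles); the board itself is a hypothetical mid-game position.

```
% Hypothetical mid-game board: -1 = red square (player 1),
% 1 = blue circle (player 2), 0 = unmarked.
state = [-1  0  1;
          0 -1  0;
          1  0  0];

% Flatten in column-major order: element k of the vector is state(k).
obs1 = state(:);    % observation of player 1
obs2 = -state(:);   % observation of player 2 (signs flipped)

% With these signs, each agent sees its own marks as -1
% and the opponent's marks as 1.

% An action k (1 to 9) uses the same column-major cell index.
k = 8;
[r, c] = ind2sub(size(state), k);   % k = 8 is row 2, column 3
fprintf("Action %d marks row %d, column %d\n", k, r, c);

% Legal actions are the unmarked cells.
legalActions = find(obs1 == 0).';
disp(legalActions)
```

Because the action index and the observation index share the same column-major ordering, an agent can find its legal moves directly from its own observation vector.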

First, create the observation specification for each agent. The observation is a vector of `9` elements.

`oinfo = rlNumericSpec([9,1]);`

Next, create the action specification for each agent, which is a set of discrete cell indices.

`ainfo = rlFiniteSetSpec({1;2;3;4;5;6;7;8;9});`

For a multi-agent environment, the observation and action specifications are cell arrays that contain one specification per agent. Create the arrays.

```
obsInfo = {oinfo, oinfo};
actInfo = {ainfo, ainfo};
```

Create a turn-based environment using `rlTurnBasedFunctionEnv`. The command takes as input arguments the observation and action specifications of the agents and function handles for the step and reset operations.

Create the environment object. The functions `stepGame` and `resetGame` are provided at the end of this example.

`env = rlTurnBasedFunctionEnv(obsInfo, actInfo, @stepGame, @resetGame)`
```
env = 
  rlTurnBasedFunctionEnv with properties:

     StepFcn: @stepGame
    ResetFcn: @resetGame
        Info: [1x1 struct]
```

Add a listener for changes to the `Info` property of the environment. When the `Info` property is updated during simulation, the function `hPlotGame` plots the updated environment. To turn off visualization, set `doPlot` to `false`.

```
doPlot = true;
if doPlot
    addlistener(env,"Info","PostSet",@(~,~) hPlotGame(env));
end
```

### Agents

Fix the random seed for reproducibility.

`rng(0);`

Two adversarial agents play the game.

• An agent generating random allowed actions controls the blue circles in the game. This agent is not trained.

• A deep Q-network (DQN) agent controls the red squares in the game. You train this agent to win against the random policy.

Create a custom agent object using the `AgentWithRandomActions` class. For more information, see Create Custom Reinforcement Learning Agents.

`randomAgent = AgentWithRandomActions(oinfo,ainfo);`

View the implementation of `AgentWithRandomActions`.

`type("AgentWithRandomActions.m")`
```
classdef AgentWithRandomActions < rl.agent.CustomAgent
    % AgentWithRandomActions models a custom agent with random actions.

    % Copyright 2023 The MathWorks, Inc.

    properties (Access = private)
        SampleTime_
    end

    methods
        function this = AgentWithRandomActions(obsinfo, actinfo)
            setObservationInfo_(this, obsinfo);
            setActionInfo_(this,actinfo);
            this.SampleTime = 1;
        end
    end

    methods (Access = protected)
        function learnImpl(~,~)
            % no op because the agent does not learn
        end

        function action = getActionWithExplorationImpl(this, obs)
            action = getActionImpl(this, obs);
        end

        function action = getActionImpl(~, obs)
            % Generate random actions.
            x = obs{1};
            % legal moves are positions which are unmarked (value is 0)
            legalmoves = find(x==0);
            % choose a random legal move
            randidx = randperm(numel(legalmoves),1);
            action = legalmoves(randidx);
        end

        function ts = getSampleTime_(this)
            ts = this.SampleTime_;
        end

        function this = setSampleTime_(this,ts)
            this.SampleTime_ = ts;
        end
    end
end
```
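The move selection in `getActionImpl` can be exercised outside the class. The following sketch repeatedly picks a random unmarked cell on a flat 9-element board until the board is full, alternating the two players' values. It ignores the win check, so it always fills all nine cells; it only illustrates why the random agent never selects an occupied cell.

```
rng(0);                      % reproducible sketch
board = zeros(9,1);          % observation-style vector, 0 = unmarked
marks = [-1, 1];             % the two players' cell values, alternating
moves = zeros(1,9);

for turn = 1:9
    legalmoves = find(board == 0);              % unmarked cells only
    randidx = randperm(numel(legalmoves), 1);   % same selection as getActionImpl
    moves(turn) = legalmoves(randidx);
    board(moves(turn)) = marks(mod(turn-1,2) + 1);
end

% moves is a permutation of 1:9: no cell is ever chosen twice.
disp(moves)
```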

Next, specify options for training the DQN agent. The agent learns using the standard deep Q-network algorithm (set `UseDoubleDQN` to `false`) with a mini-batch size of `64` and a discount factor of `0.99`, which favors long-term rewards. For more information, see `rlDQNAgentOptions`.

```
agentOpts = rlDQNAgentOptions( ...
    UseDoubleDQN=false, ...
    MiniBatchSize=64, ...
    DiscountFactor=0.99);
```
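To get a feel for the discount factor, the following sketch computes the discounted return for a hypothetical sequence of rewards in which the agent sets up two marks in a line on its second move and wins on its third move. The reward values follow the environment description; the sequence itself is made up.

```
gamma = 0.99;            % DiscountFactor from agentOpts
rewards = [0 1 10];      % hypothetical rewards for the agent's three moves

% Discounted return: G = sum over k of gamma^(k-1) * r_k
discounts = gamma.^(0:numel(rewards)-1);
G = sum(discounts .* rewards);
fprintf("Discounted return: %.4f\n", G);
```

Assuming one discount factor is applied per move of the agent, and since each agent makes at most five moves per game, even a reward on the agent's last possible move is discounted by only 0.99^4, or about 0.96. With this discount factor, the win reward keeps almost its full weight no matter when it arrives.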

Specify options for the critic optimization algorithm. The critic learns using the `adam` algorithm with a learning rate of `1e-4`.

```
agentOpts.CriticOptimizerOptions.Algorithm = "adam";
agentOpts.CriticOptimizerOptions.LearnRate = 1e-4;
```

Specify exploration options for training. The agent uses an epsilon-greedy exploration strategy with an initial `Epsilon` value of `0.9`, which decays exponentially at a rate of `1e-4` per training step toward a minimum value of `0.01`.

```
agentOpts.EpsilonGreedyExploration.Epsilon = 0.9;
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-4;
agentOpts.EpsilonGreedyExploration.EpsilonMin = 0.01;
```
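The following sketch estimates how many training steps this exploration schedule needs to reach `EpsilonMin`. It assumes that, while epsilon is above `EpsilonMin`, the update epsilon = epsilon*(1 - EpsilonDecay) is applied once per training step of the agent. The variable `maxSteps` is an illustrative upper bound (five moves in each of 5000 episodes), not a value reported by training.

```
epsilon0   = 0.9;    % EpsilonGreedyExploration.Epsilon
decay      = 1e-4;   % EpsilonGreedyExploration.EpsilonDecay
epsilonMin = 0.01;   % EpsilonGreedyExploration.EpsilonMin

% Assumed update per training step while epsilon > epsilonMin:
%   epsilon = epsilon*(1 - decay)
% Closed form for the first step at which epsilon <= epsilonMin.
nSteps = ceil(log(epsilonMin/epsilon0) / log(1 - decay));
fprintf("Steps until epsilon reaches EpsilonMin: %d\n", nSteps);

% Cross-check by applying the update step by step.
epsilon = epsilon0;
count = 0;
while epsilon > epsilonMin
    epsilon = epsilon*(1 - decay);
    count = count + 1;
end
assert(count == nSteps);

% Illustrative bound: the DQN agent makes at most 5 moves per game.
maxSteps = 5000*5;
epsAfter = epsilon0*(1 - decay)^maxSteps;
fprintf("Epsilon after %d steps: %.4f\n", maxSteps, epsAfter);
```

Under these assumptions, the schedule needs roughly 45,000 steps to reach `EpsilonMin`, while 5000 episodes provide at most 25,000 steps for the DQN agent, so epsilon would still be about 0.074 when training ends.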

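Set the length of the experience buffer to `1e6`. Because the DQN agent makes at most five moves in each game, 5000 training episodes produce at most 25,000 experiences for this agent, so the buffer never has to discard old experiences and the agent samples mini-batches from its entire training history.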

`agentOpts.ExperienceBufferLength = 1e6;`

Create initialization options for the agent. The agent's critic uses a neural network with `256` units in each hidden layer.

`initOpts = rlAgentInitializationOptions(NumHiddenUnit=256);`

Create a default DQN agent object. For more information, see `rlDQNAgent`.

`agent = rlDQNAgent(oinfo, ainfo, initOpts, agentOpts);`

View the critic network of the agent.

```
critic = getCritic(agent);
criticNet = getModel(critic);
plot(layerGraph(criticNet));
```

View a summary of the critic network.

`summary(criticNet);`
```
   Initialized: true

   Number of learnables: 70.6k

   Inputs:
      1   'input_1'   9 features
```

### Training

In this example, only the agent controlling the red squares is trained.

To evaluate the simulation behavior of the agent during training, configure an evaluator object that periodically computes an evaluation score. In this example, the evaluator object runs `25` evaluation episodes every `100` training episodes, using the random seeds `0` through `24`, and uses the minimum episode reward as the evaluation score. You can use the evaluation score as a criterion to save agents during training. For more information on evaluator objects, see `rlEvaluator`.

```
evaluator = rlEvaluator( ...
    EvaluationFrequency=100, ...
    EvaluationStatisticType="MinEpisodeReward", ...
    NumEpisodes=25, ...
    RandomSeeds=0:24);
```
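Fixing the random seeds makes the evaluation episodes repeatable. The following plain-MATLAB sketch illustrates the idea using only the random agent's move selection on an empty board; it is not how `rlEvaluator` is implemented internally. Reseeding with the same value reproduces the same random move.

```
seeds = 0:24;                      % same seeds as the evaluator
firstMoves  = zeros(size(seeds));  % random move drawn after seeding
repeatMoves = zeros(size(seeds));  % move drawn after reseeding

for i = 1:numel(seeds)
    board = zeros(9,1);            % empty board, all cells legal
    legal = find(board == 0);

    rng(seeds(i));                 % seed before the episode
    firstMoves(i) = legal(randperm(numel(legal),1));

    rng(seeds(i));                 % reseed with the same value
    repeatMoves(i) = legal(randperm(numel(legal),1));
end

% Identical seeds give identical random draws.
disp(isequal(firstMoves, repeatMoves))
```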

For this training session:

• Run the training for `5000` episodes, and disable early stopping by setting `StopTrainingCriteria` to `"none"`.

• Specify `AgentGroups` as a cell array of agent indices, and omit the index of the agent that is not trained. The indices must correspond to the order of the agents specified in the `train` function.

• Compute the average episodic rewards using a moving window of size `20`.

• Save the agents whenever the evaluation score is `9` or higher. Because the score is the minimum reward over the `25` evaluation episodes, a value close to the reward received for winning the game (`10`) may indicate that the policy has learned sufficiently.

```
trainOpts = rlMultiAgentTrainingOptions( ...
    AgentGroups={1}, ...
    MaxEpisodes=5000, ...
    ScoreAveragingWindowLength=20, ...
    StopTrainingCriteria="none", ...
    SaveAgentCriteria="EvaluationStatistic", ...
    SaveAgentValue=9);
```
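The averaging window and the save criterion configured above can be illustrated on made-up numbers. The following sketch computes a trailing moving average over a window of `20` episodes with `movmean` and a minimum-reward evaluation score compared against the save threshold of `9`. The reward values are hypothetical, and the sketch approximates the idea rather than the toolbox's internal bookkeeping.

```
% Hypothetical episode rewards (not results from this example).
episodeRewards = [-10 2 11 -9 12 10 11 -10 13 11];
window = 20;   % ScoreAveragingWindowLength

% Trailing moving average: each value averages the current episode and
% up to window-1 preceding episodes (fewer at the start, 'shrink' ends).
avgRewards = movmean(episodeRewards, [window-1 0]);

% Evaluation score: minimum reward over the evaluation episodes.
evalRewards = [11 12 10 11 13];   % hypothetical evaluation rewards
score = min(evalRewards);
saveAgent = score >= 9;           % SaveAgentValue = 9

fprintf("Last average: %.2f, score: %d, save: %d\n", ...
    avgRewards(end), score, saveAgent);
```

Using the minimum rather than the mean makes the save criterion conservative: a single lost evaluation game (reward around `-10`) keeps the score below the threshold.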

Train the agents using the `train` function. Training can take several hours to complete depending on the available computational power. To save time, load the MAT-file `twoPlayerGameAgent.mat`, which contains a set of pretrained agents. To train the agents yourself, set `doTraining` to `true`.

```
doTraining = false;
if doTraining
    trainResults = train([agent,randomAgent], env, ...
        trainOpts, Evaluator=evaluator);
else
    load("twoPlayerGameAgent.mat");
end
```

A snapshot of the training progress is shown in the following figure. You may see different results from your training process.

### Simulation

Reset the random number generator for reproducibility.

`rng(0);`

Simulate the environment with the trained agent controlling both players. For more information on agent simulation, see `rlSimulationOptions` and `sim`.

```
simOptions = rlSimulationOptions();
experience = sim(env,[agent,agent],simOptions);
```

The visualization shows the sequence of turns in the game. The trained agent has learned actions that win the game.

### Local Functions

Custom reset function for the two-player game. The reset function initializes the environment, randomly selects an agent for the first turn, and returns the initial observations of the agents along with the `info` variable that passes information between simulation steps.

```
function [initialObs, info] = resetGame()
    % Reset the two-player game.

    % The state is a 3x3 matrix with 0s in unmarked cells,
    % -1s in cells marked with squares and
    % 1s in cells marked with circles.
    info.State = zeros(3,3);

    % The initial turn is randomly selected.
    info.ActiveAgentIndex = randperm(2,1);

    % The current environment step count.
    info.StepCount = 0;

    % Flag to keep track of invalid action.
    info.IsInvalidAction = false;

    % The initial observation is the state
    % of the environment as seen by each agent.
    initialObs = {info.State(:), -info.State(:)};
end
```

Custom step function for the two-player game. This function steps the environment dynamics to the next state.

```
function [nextobs, reward, isdone, info] = stepGame(action, info)
    % Get the active agent for the current turn.
    agentIdx = info.ActiveAgentIndex;

    % Player cell values:
    % Player 1 (red square) is identified by the value -1.
    % Player 2 (blue circle) is identified by the value 1.
    playerVals = [-1,1];

    % Get current player value.
    currentVal = playerVals(agentIdx);

    % Index of cell for the current player.
    % action{1} is the move of the current player,
    % that is, the index (between 1 and 9)
    % of new grid location to be marked.
    playerCell = action{1};

    % Advance environment to the next state.
    state = info.State;
    if state(playerCell) == 0
        % Move to next state.
        state(playerCell) = currentVal;
        % Compute reward and terminal condition.
        [rwd, isdone] = hComputeRewardAndIsDone( ...
            state, playerCell, currentVal);
    else
        % For an illegal action,
        % terminate with a large penalty.
        rwd = -10;
        isdone = true;
        info.IsInvalidAction = true;
    end

    % Compute the next agent index.
    nextIdx = mod(agentIdx,2) + 1;

    % The active agent receives the reward rwd,
    % and the other agent receives the reward -rwd.
    reward = [0,0];
    reward(agentIdx) = rwd;
    reward(nextIdx) = -rwd;

    % Next observation.
    nextobs = {state(:), -state(:)};

    % Update the info structure with:
    % 1. The state for the next step.
    % 2. The agent turn for the next step.
    % 3. The step count.
    info.State = state;
    info.ActiveAgentIndex = nextIdx;
    info.StepCount = info.StepCount + 1;
end
```
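The turn order and the zero-sum reward split in `stepGame` can be checked in isolation. The following sketch reuses the same expressions for both possible active agents, with a hypothetical step reward of `1` (for example, the active agent placed two marks in a line).

```
rwd = 1;   % hypothetical reward for the active agent's move

for agentIdx = 1:2
    nextIdx = mod(agentIdx,2) + 1;   % 1 -> 2, 2 -> 1
    reward = [0,0];
    reward(agentIdx) = rwd;          % active agent
    reward(nextIdx) = -rwd;          % opponent receives the negative
    fprintf("Active %d, next %d, reward [%d %d]\n", ...
        agentIdx, nextIdx, reward(1), reward(2));
end
```

The rewards in each step always sum to zero, so whatever one player gains, the other loses.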

Helper function to compute reward and terminal conditions.

```
function [reward, isover] = hComputeRewardAndIsDone( ...
        state, index, player)
    % Advance to the next state of the game by marking a cell.
    %
    % - If the agent marks three adjacent cells, reward is +10.
    % - If the agent marks two adjacent cells with the third
    %   cell in line unmarked then reward is +1.
    % - Otherwise, reward is 0.

    sz = size(state);
    reward = 0;
    isover = false;

    % Current row and column of active agent.
    [r,c] = ind2sub(sz, index);

    function nestedComputeReward_(arr)
        % arr is a row or column or a diagonal.
        if all(arr==player)
            % Row/column/diagonal complete, game over.
            reward = reward + 10;
            isover = true;
        elseif sum(arr==player)==2 && any(arr==0)
            % Player marks adjacent cells,
            % possible win next turn.
            reward = reward + 1;
        end
    end

    % Horizontal cells.
    hcells = state(r,:);
    nestedComputeReward_(hcells);

    % Vertical cells.
    vcells = state(:,c);
    nestedComputeReward_(vcells);

    % Check only the diagonals that contain the marked cell:
    % r==c is the main diagonal,
    % r+c==4 is the other diagonal.
    if r==c
        % Main diagonal cells.
        mdcells = [state(1,1) state(2,2) state(3,3)];
        nestedComputeReward_(mdcells);
    end
    if r+c==4
        % Other diagonal cells.
        odcells = [state(3,1) state(2,2) state(1,3)];
        nestedComputeReward_(odcells);
    end

    % If all cells are marked the game is over.
    if ~any(state==0)
        isover = true;
    end
end
```
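As a cross-check of the line logic in `hComputeRewardAndIsDone`, the following sketch detects a completed row, column, or diagonal by counting a player's marks along each line. This is an alternative formulation written for illustration, not the function the example uses, and it detects wins only; it does not compute the `+1` shaping reward. The anonymous function name `hasWon` and the board are hypothetical.

```
% hasWon (hypothetical helper) returns true if every cell of some row,
% column, or diagonal equals player (-1 for squares, 1 for circles).
hasWon = @(state, player) ...
    any(sum(state == player, 1) == 3) || ...   columns
    any(sum(state == player, 2) == 3) || ...   rows
    sum(diag(state) == player) == 3 || ...     main diagonal
    sum(diag(flipud(state)) == player) == 3;  % other diagonal

% Hypothetical board: circles (1) fill the main diagonal.
state = [ 1 -1  0;
         -1  1  0;
          0 -1  1];

disp(hasWon(state, 1))    % circles
disp(hasWon(state, -1))   % squares
```

Counting matches per line avoids computing the marked cell's row and column, at the cost of checking all eight lines on every call; the example's helper checks only the lines through the marked cell, which is also what it needs for the `+1` reward.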