
Error while creating custom critic function

Harsh on 26 Feb 2024
Answered: Amal Raj on 14 Mar 2024
customCriticNetwork = [
imageInputLayer([1 1 1], 'Normalization', 'none', 'Name', 'observation')
fullyConnectedLayer(400, 'Name', 'CriticFC1', 'WeightLearnRateFactor', 1e-4, 'BiasLearnRateFactor', 1e-4)
reluLayer('Name', 'CriticRelu1')
fullyConnectedLayer(300, 'Name', 'CriticFC2', 'WeightLearnRateFactor', 1e-4, 'BiasLearnRateFactor', 1e-4)
reluLayer('Name', 'CriticRelu2')
fullyConnectedLayer(1, 'Name', 'output', 'WeightLearnRateFactor', 1e-4, 'BiasLearnRateFactor', 1e-4)]; % Change the output size to 1
% Create the custom critic dlnetwork
dlnet = dlnetwork(customCriticNetwork);
% Create the custom critic
customCritic = rlQValueFunction(dlnet, obsInfo, actInfo);
Error using rlQValueFunction
Number of input layers for state-action-value function deep neural network must equal the number of
observation and action specifications.
Error in RL_Agent (line 37)
customCritic = rlQValueFunction(dlnet, obsInfo, actInfo);

Answers (1)

Amal Raj on 14 Mar 2024
Hey Harsh.
The error means that rlQValueFunction expects the network to have one input layer for each observation specification and one for each action specification (here one of each, so two input layers in total). Your network has a single input layer, 'observation', so there is no input for the action. A Q-value critic must take both the observation and the action as inputs and return one scalar value, Q(s,a).
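A quick way to confirm the mismatch is to compare the number of network inputs with the number of specifications. This sketch assumes example specifications (a 1-by-1-by-1 observation and three discrete actions) and uses a shortened version of your network that keeps only the single imageInputLayer and the output layer:

```matlab
% Example specifications (assumed; use the obsInfo/actInfo of your environment)
obsInfo = rlNumericSpec([1 1 1]);
actInfo = rlFiniteSetSpec([1 2 3]);

% Shortened version of the original network: it has only one input layer
layers = [
    imageInputLayer([1 1 1], 'Normalization', 'none', 'Name', 'observation')
    fullyConnectedLayer(1, 'Name', 'output')];
dlnet = dlnetwork(layers);

% rlQValueFunction needs one network input per observation spec plus one per action spec
numNetInputs = numel(dlnet.InputNames);
numRequired  = numel(obsInfo) + numel(actInfo);
fprintf('Network inputs: %d, required: %d\n', numNetInputs, numRequired);
```

If the two numbers differ, rlQValueFunction raises the error shown in the question.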
Here is a version of your network with separate observation and action input paths that are concatenated before the hidden layers. The specifications are example values (a scalar observation and three discrete actions); replace them with your own obsInfo and actInfo:
% Example specifications: replace with the obsInfo/actInfo of your environment
obsInfo = rlNumericSpec([1 1]);       % scalar observation
actInfo = rlFiniteSetSpec([1 2 3]);   % three discrete actions
% Observation input path
obsPath = featureInputLayer(prod(obsInfo.Dimension), 'Normalization', 'none', 'Name', 'observation');
% Action input path
actPath = featureInputLayer(prod(actInfo.Dimension), 'Normalization', 'none', 'Name', 'action');
% Common path: concatenate observation and action, then the original hidden layers
commonPath = [
concatenationLayer(1, 2, 'Name', 'concat')
fullyConnectedLayer(400, 'Name', 'CriticFC1', 'WeightLearnRateFactor', 1e-4, 'BiasLearnRateFactor', 1e-4)
reluLayer('Name', 'CriticRelu1')
fullyConnectedLayer(300, 'Name', 'CriticFC2', 'WeightLearnRateFactor', 1e-4, 'BiasLearnRateFactor', 1e-4)
reluLayer('Name', 'CriticRelu2')
fullyConnectedLayer(1, 'Name', 'output', 'WeightLearnRateFactor', 1e-4, 'BiasLearnRateFactor', 1e-4)]; % scalar Q(s,a)
% Assemble the graph and connect both inputs to the concatenation layer
lgraph = layerGraph(obsPath);
lgraph = addLayers(lgraph, actPath);
lgraph = addLayers(lgraph, commonPath);
lgraph = connectLayers(lgraph, 'observation', 'concat/in1');
lgraph = connectLayers(lgraph, 'action', 'concat/in2');
% Create the custom critic dlnetwork (two inputs: observation and action)
dlnet = dlnetwork(lgraph);
% Create the custom critic, mapping each input layer to its specification
customCritic = rlQValueFunction(dlnet, obsInfo, actInfo, ...
'ObservationInputNames', {'observation'}, ...
'ActionInputNames', {'action'});
The network now has two input layers, matching the one observation specification and the one action specification, so rlQValueFunction can create the critic. The ObservationInputNames and ActionInputNames arguments tell it explicitly which input layer receives the observation and which receives the action.
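Because the example action specification is discrete (rlFiniteSetSpec), another option is a vector Q-value critic. rlVectorQValueFunction takes only the observation as input and outputs one Q-value per possible action, so the network keeps a single input layer and its output size becomes numel(actInfo.Elements). This is a sketch under assumed example specifications (a scalar observation and three actions); the layer sizes are copied from your network and the learn-rate factors are omitted for brevity:

```matlab
% Example specifications (assumed; use the obsInfo/actInfo of your environment)
obsInfo = rlNumericSpec([1 1]);
actInfo = rlFiniteSetSpec([1 2 3]);

% Single observation input; one output per discrete action
net = [
    featureInputLayer(prod(obsInfo.Dimension), 'Normalization', 'none', 'Name', 'observation')
    fullyConnectedLayer(400, 'Name', 'CriticFC1')
    reluLayer('Name', 'CriticRelu1')
    fullyConnectedLayer(300, 'Name', 'CriticFC2')
    reluLayer('Name', 'CriticRelu2')
    fullyConnectedLayer(numel(actInfo.Elements), 'Name', 'output')];
dlnet = dlnetwork(net);

% Vector Q-value critic: Q(s,a) for every action a in a single forward pass
critic = rlVectorQValueFunction(dlnet, obsInfo, actInfo);

% Evaluate the critic on a random observation (one value per action)
qValues = getValue(critic, {rand(obsInfo.Dimension)});
disp(qValues)
```

Which form fits better depends on the agent: rlVectorQValueFunction is only valid for discrete (finite-set) actions, while the two-input rlQValueFunction form works for both discrete and continuous actions.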
