
PPO and LSTM agent creation

Sourabh on 12 Dec 2023
Commented: Sourabh on 23 Dec 2023
I am trying to implement PPO with LSTM layers, and I am getting this error:
"To train an agent that has states, all actor and critic representations for that agent must have states."
But I am giving the states as input, so why am I getting this error?
-----------------------------------------------
obsInfo = rlNumericSpec([3 1],...
    'LowerLimit',[ -inf -inf -inf ]',...
    'UpperLimit',[ inf inf inf ]');
numObservations = obsInfo.Dimension(1);
actInfo = rlNumericSpec([2 1],...
    'LowerLimit',[ -inf -inf ]',...
    'UpperLimit',[ inf inf ]');
numActions = actInfo.Dimension(1);
env = rlSimulinkEnv('simmodelppo','simmodelppo/RL Agent',...
    obsInfo,actInfo);
env.ResetFcn = @(in)localResetFcn(in);
rng(0)
criticNetwork = [
    sequenceInputLayer(prod(obsInfo.Dimension));
    fullyConnectedLayer(64);
    tanhLayer;
    fullyConnectedLayer(64);
    tanhLayer;
    lstmLayer(5,'Name','lstm1');
    fullyConnectedLayer(1)];
criticNetwork = dlnetwork(criticNetwork);
critic = rlValueFunction(criticNetwork,obsInfo);
commonPath = [
    featureInputLayer(prod(obsInfo.Dimension),Name="comPathIn")
    fullyConnectedLayer(150)
    tanhLayer
    fullyConnectedLayer(1,Name="comPathOut")
    ];
% Define mean value path
meanPath = [
    fullyConnectedLayer(50,Name="meanPathIn")
    tanhLayer
    fullyConnectedLayer(50,Name="fc_2")
    tanhLayer
    lstmLayer(5,'Name','lstm1');
    fullyConnectedLayer(prod(actInfo.Dimension))
    leakyReluLayer(0.01,Name="meanPathOut")
    ];
% Define standard deviation path
sdevPath = [
    fullyConnectedLayer(50,"Name","stdPathIn")
    tanhLayer
    lstmLayer(5,'Name','lstm2');
    fullyConnectedLayer(prod(actInfo.Dimension));
    reluLayer
    scalingLayer(Scale=0.9,Name="stdPathOut")
    ];
% Add layers to layerGraph object
actorNet = layerGraph(commonPath);
actorNet = addLayers(actorNet,meanPath);
actorNet = addLayers(actorNet,sdevPath);
% Connect paths
actorNet = connectLayers(actorNet,"comPathOut","meanPathIn/in");
actorNet = connectLayers(actorNet,"comPathOut","stdPathIn/in");
actorNetwork = dlnetwork(actorNet);
actor = rlContinuousGaussianActor(actorNetwork, obsInfo, actInfo, ...
    "ActionMeanOutputNames","meanPathOut",...
    "ActionStandardDeviationOutputNames","stdPathOut",...
    ObservationInputNames="comPathIn");
%%
agentOpts = rlPPOAgentOptions(...
    'SampleTime',800,...
    'ClipFactor',0.2,...
    'NumEpoch',3,...
    'EntropyLossWeight',0.025,...
    'AdvantageEstimateMethod','finite-horizon',...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',64, ...
    'ExperienceHorizon',128);
agent = rlPPOAgent(actor,critic,agentOpts);
agent.AgentOptions.CriticOptimizerOptions.LearnRate = 0.0001;
agent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;
agent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;
agent.AgentOptions.ActorOptimizerOptions.LearnRate = 0.0003;
%%
maxepisodes = 6000;
maxsteps = 1;
trainingOpts = rlTrainingOptions(...
    'MaxEpisodes',maxepisodes,...
    'MaxStepsPerEpisode',maxsteps,...
    'ScoreAveragingWindowLength',5, ...
    'Verbose',false,...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',-20);
% TO TRAIN
doTraining = true;
if doTraining
    trainingStats = train(agent,env,trainingOpts);
    % save('agent_new.mat','agent') %%% to save agent ###
else
    % Load pretrained agent for the example.
    load('agent_old.mat','agent')
end
%%
function in = localResetFcn(in) %%%%%%% RANDOM INPUT GENERATOR %%%%%%%
    % randomize setpoints -- ensure feasible
    set_point = 1 + 0.0*rand; % Set-point [0,1]
    in = setBlockParameter(in,'simmodelppo/Set','Value',num2str(set_point));
end

Answers (2)

Venu on 19 Dec 2023
Edited: Venu on 20 Dec 2023
The error is likely occurring because the LSTM layers require explicit handling of their states, which is not just about feeding in the external states (observations) but also about managing the internal states of the LSTM layers. In the code you provided, there is a local function "localResetFcn" that is used as the environment's reset function. However, this function currently only sets the parameter 'Value' for the block 'simmodelppo/Set' in your Simulink model.
In the code snippet you've provided, there is no explicit handling of the LSTM states, so unless the "rlContinuousGaussianActor" and "rlValueFunction" objects handle this internally in a way that is not shown, you would need to add this functionality.
Here's an example of how you might modify your "localResetFcn" to include resetting the LSTM states:
function in = localResetFcn(in) %%%%%%% RANDOM INPUT GENERATOR %%%%%%%
    % randomize setpoints -- ensure feasible
    set_point = 1 + 0.0*rand; % Set-point [0,1]
    in = setBlockParameter(in, 'simmodelppo/Set', 'Value', num2str(set_point));
    % Reset the LSTM states of the actor and critic networks
    % Note: The following is an example and may need adjustment
    % to match the specifics of your MATLAB version and network setup.
    % Reset actor LSTM states
    actor = getActor(agent);
    actor = resetState(actor);
    agent = setActor(agent, actor);
    % Reset critic LSTM states
    critic = getCritic(agent);
    critic = resetState(critic);
    agent = setCritic(agent, critic);
end
Since the "localResetFcn" currently does not have access to the 'agent' variable, you would need to modify your training setup to ensure that the agent's states can be reset at the start of each episode. This might involve changes to how the 'agent' variable is passed around or stored.
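A minimal sketch of that idea, assuming a release in which getActor/setActor, getCritic/setCritic and resetState are available for PPO agents and their actor and critic objects (R2022a or later, as far as I know): the reset happens in the script, where agent is in scope, instead of inside localResetFcn, which only receives the simulation input object. A default recurrent PPO agent is used here as a stand-in for the agent built in the question, so the snippet runs on its own.

```matlab
% Observation and action specifications (same dimensions as in the question)
obsInfo = rlNumericSpec([3 1]);
actInfo = rlNumericSpec([2 1]);

% Stand-in agent: default PPO agent with recurrent (LSTM) actor and critic
initOpts = rlAgentInitializationOptions(UseRNN=true);
agent = rlPPOAgent(obsInfo,actInfo,initOpts);

% Reset the hidden states where the agent variable is in scope,
% rather than inside localResetFcn
actor = resetState(getActor(agent));
agent = setActor(agent,actor);
critic = resetState(getCritic(agent));
agent = setCritic(agent,critic);
```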
4 comments
Venu on 20 Dec 2023
Edited: Venu on 23 Dec 2023
Does the issue still persist?
Sourabh on 23 Dec 2023
It's fixed now. I had to change the featureInputLayer to a sequenceInputLayer in the critic and actor.
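A sketch of that change for the actor, assuming the same observation and action dimensions as in the question (the explicit infinite limits are left out so the snippet runs on its own): only the first layer of the common path changes from featureInputLayer to sequenceInputLayer, so that the actor, like the critic, is built as a recurrent network. The mean and standard deviation paths are copied from the question.

```matlab
% Observation and action specifications (same dimensions as in the question)
obsInfo = rlNumericSpec([3 1]);
actInfo = rlNumericSpec([2 1]);

% Common path: sequenceInputLayer (was featureInputLayer) makes the actor recurrent
commonPath = [
    sequenceInputLayer(prod(obsInfo.Dimension),Name="comPathIn")
    fullyConnectedLayer(150)
    tanhLayer
    fullyConnectedLayer(1,Name="comPathOut")
    ];

% Mean path, unchanged from the question
meanPath = [
    fullyConnectedLayer(50,Name="meanPathIn")
    tanhLayer
    fullyConnectedLayer(50,Name="fc_2")
    tanhLayer
    lstmLayer(5,Name="lstm1")
    fullyConnectedLayer(prod(actInfo.Dimension))
    leakyReluLayer(0.01,Name="meanPathOut")
    ];

% Standard deviation path, unchanged from the question
sdevPath = [
    fullyConnectedLayer(50,Name="stdPathIn")
    tanhLayer
    lstmLayer(5,Name="lstm2")
    fullyConnectedLayer(prod(actInfo.Dimension))
    reluLayer
    scalingLayer(Scale=0.9,Name="stdPathOut")
    ];

% Assemble and connect the paths, then create the Gaussian actor
actorNet = layerGraph(commonPath);
actorNet = addLayers(actorNet,meanPath);
actorNet = addLayers(actorNet,sdevPath);
actorNet = connectLayers(actorNet,"comPathOut","meanPathIn/in");
actorNet = connectLayers(actorNet,"comPathOut","stdPathIn/in");
actor = rlContinuousGaussianActor(dlnetwork(actorNet),obsInfo,actInfo, ...
    ActionMeanOutputNames="meanPathOut", ...
    ActionStandardDeviationOutputNames="stdPathOut", ...
    ObservationInputNames="comPathIn");
```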



Emmanouil Tzorakoleftherakis on 21 Dec 2023
Hi,
With LSTM policies, BOTH the actor and the critic should have LSTM layers. That's why you are getting this error.
LSTM policies tend to be harder to architect, so I would suggest using the default agent feature to get an initial architecture. See, for example, here. Don't forget to indicate that you want an RNN policy in the agent initialization options.
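A minimal sketch of the default-agent approach, assuming the observation and action dimensions from the question: rlAgentInitializationOptions with UseRNN=true asks rlPPOAgent to build recurrent actor and critic networks, which can then be extracted with getModel and used as a starting point for a custom design. NumHiddenUnit=64 is an arbitrary illustrative value, not a recommendation.

```matlab
% Observation and action specifications (same dimensions as in the question)
obsInfo = rlNumericSpec([3 1]);
actInfo = rlNumericSpec([2 1]);

% Request default recurrent (LSTM) networks for both actor and critic
initOpts = rlAgentInitializationOptions(UseRNN=true,NumHiddenUnit=64);
agent = rlPPOAgent(obsInfo,actInfo,initOpts);

% Extract the generated networks to inspect or customize them
actorNet = getModel(getActor(agent));
criticNet = getModel(getCritic(agent));
actorNet.Layers
```

The PPO options from the question can also be passed as an extra argument, for example rlPPOAgent(obsInfo,actInfo,initOpts,agentOpts).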
Hope this helps
