Discrete cartpole environment does not train properly

ryuuzi on 25 Oct 2024 at 11:39
Answered: Hiro Yoshino on 5 Nov 2024 at 7:17
I created a discrete cartpole environment by following "Create custom environment from class template" and imported it into Reinforcement Learning Designer.
However, training did not converge stably. I have tried various things, but I could not come up with a fix.
Any advice would be appreciated.
classdef matlab < rl.env.MATLABEnvironment
    properties
        % Acceleration due to gravity in m/s^2
        Gravity = 9.8
        % Mass of the cart
        MassCart = 1.0
        % Mass of the pole
        MassPole = 0.1
        % Half the length of the pole
        Length = 0.5
        % Max force the input can apply
        MaxForce = 10
        % Sample time
        Ts = 0.02
        % Angle at which to fail the episode
        ThetaThresholdRadians = 12 * pi/180
        % Distance at which to fail the episode
        XThreshold = 2.4
        % Reward each time step the cart-pole is balanced
        RewardForNotFalling = 1
        % Penalty when the cart-pole fails to balance
        PenaltyForFalling = -5
    end
    properties
        % System state [x,dx,theta,dtheta]'
        State = zeros(4,1)
    end
    properties (Access = protected)
        % Internal flag to indicate that the episode has finished
        IsDone = false
    end
    properties (Transient, Access = private)
        Visualizer = []
    end
    methods
        function this = matlab()
            % Define observation and action specifications
            ObservationInfo = rlNumericSpec([4 1]);
            ObservationInfo.Name = 'CartPole States';
            ObservationInfo.Description = 'x, dx, theta, dtheta';
            ActionInfo = rlFiniteSetSpec([-1 1]);
            ActionInfo.Name = 'CartPole Action';
            this = this@rl.env.MATLABEnvironment(ObservationInfo, ActionInfo);
            updateActionInfo(this);
        end
        function set.State(this,state)
            validateattributes(state,{'numeric'},{'finite','real','vector','numel',4},'','State');
            this.State = double(state(:));
            notifyEnvUpdated(this);
        end
        function set.Length(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Length');
            this.Length = val;
            notifyEnvUpdated(this);
        end
        function set.Gravity(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Gravity');
            this.Gravity = val;
        end
        function set.MassCart(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MassCart');
            this.MassCart = val;
        end
        function set.MassPole(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MassPole');
            this.MassPole = val;
        end
        function set.MaxForce(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MaxForce');
            this.MaxForce = val;
            updateActionInfo(this);
        end
        function set.Ts(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Ts');
            this.Ts = val;
        end
        function set.ThetaThresholdRadians(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','ThetaThresholdRadians');
            this.ThetaThresholdRadians = val;
            notifyEnvUpdated(this);
        end
        function set.XThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','XThreshold');
            this.XThreshold = val;
            notifyEnvUpdated(this);
        end
        function set.RewardForNotFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','RewardForNotFalling');
            this.RewardForNotFalling = val;
        end
        function set.PenaltyForFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','PenaltyForFalling');
            this.PenaltyForFalling = val;
        end
        function [observation,reward,isdone,loggedSignals] = step(this,action)
            loggedSignals = [];
            % Get action
            force = getForce(this,action);
            % Unpack state vector
            state = this.State;
            %x = state(1);
            x_dot = state(2);
            theta = state(3);
            theta_dot = state(4);
            % Apply motion equations
            costheta = cos(theta);
            sintheta = sin(theta);
            totalmass = this.MassCart + this.MassPole;
            polemasslength = this.MassPole*this.Length;
            temp = (force + polemasslength * theta_dot * theta_dot * sintheta) / totalmass;
            thetaacc = (this.Gravity * sintheta - costheta * temp) / (this.Length * (4.0/3.0 - this.MassPole * costheta * costheta / totalmass));
            xacc = temp - polemasslength * thetaacc * costheta / totalmass;
            % Euler integration
            observation = state + this.Ts.*[x_dot;xacc;theta_dot;thetaacc];
            this.State = observation;
            x = observation(1);
            theta = observation(3);
            isdone = abs(x) > this.XThreshold || abs(theta) > this.ThetaThresholdRadians;
            this.IsDone = isdone;
            % Get reward
            reward = getReward(this,x,force);
        end
        function initialState = reset(this)
            % Randomize the initial pendulum angle within +/- 0.05 rad
            T0 = 2*0.05*rand - 0.05;
            % Thetadot
            Td0 = 0;
            % X
            X0 = 0;
            % Xdot
            Xd0 = 0;
            initialState = [X0;Xd0;T0;Td0];
            this.State = initialState;
        end
        function varargout = plot(this)
            % Visualizes the environment
            if isempty(this.Visualizer) || ~isvalid(this.Visualizer)
                this.Visualizer = rl.env.viz.CartPoleVisualizer(this);
            else
                bringToFront(this.Visualizer);
            end
            if nargout
                varargout{1} = this.Visualizer;
            end
        end
    end
    methods (Access = protected)
        function force = getForce(this,action)
            if ~ismember(action,this.ActionInfo.Elements)
                error(message('rl:env:CartPoleDiscreteInvalidAction',sprintf('%g',-this.MaxForce),sprintf('%g',this.MaxForce)));
            end
            force = action;
        end
        % Update the action info based on max force
        function updateActionInfo(this)
            this.ActionInfo.Elements = this.MaxForce*[-1 10];
        end
        function Reward = getReward(this,~,~)
            if ~this.IsDone
                Reward = this.RewardForNotFalling;
            else
                Reward = this.PenaltyForFalling;
            end
        end
    end
end
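
(For reference, a class-based environment like this can be instantiated and sanity-checked at the command line before importing it into Reinforcement Learning Designer. The snippet below is only a minimal sketch, assuming the class above is saved as matlab.m on the MATLAB path; validateEnvironment, getActionInfo, and step are standard Reinforcement Learning Toolbox calls.)

% Instantiate the custom environment defined above
env = matlab();

% Check that reset/step are consistent with the observation and action specs
validateEnvironment(env)

% Take one manual step using a valid action taken from the action spec
actInfo = getActionInfo(env);
[obs,rwd,isdone] = step(env,actInfo.Elements(1))

If this runs cleanly, the env object can then be imported into Reinforcement Learning Designer from the MATLAB workspace.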

Accepted Answer

Hiro Yoshino on 5 Nov 2024 at 7:17
A discrete cartpole environment is included among the shipped examples, so opening a working one and examining its contents should be a useful reference (the answer may well be in there).
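
For comparison, the shipped discrete cart-pole environment can be created and inspected directly from the command line; the snippet below is only a sketch using standard Reinforcement Learning Toolbox functions (rlPredefinedEnv, getActionInfo). One difference that stands out in the posted class: the constructor defines the action set as [-1 1], but updateActionInfo then overwrites it with this.MaxForce*[-1 10], giving the actions -10 and 100 rather than -10 and 10, which by itself could explain the unstable training.

% Create the built-in discrete cart-pole environment for comparison
refEnv = rlPredefinedEnv("CartPole-Discrete");

% Its action specification lists just two force values
getActionInfo(refEnv)

% Open the source of the shipped class (if it is on the path) to compare with the custom one
edit(class(refEnv))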

More Answers (0)
