How to define a custom loss function and minimize it with Adam, like tf.train.AdamOptimizer().minimize(loss)?
I feed X = [4,100] into an LSTM and take the last hidden state [4,1]. I use this state vector to estimate Q. I want to use the Adam optimizer to minimize a custom loss function, but the options offered by the official tools cannot do this. So I want to define the loss function myself and minimize it with Adam. How can I do this in MATLAB, as in Python: train = tf.train.AdamOptimizer().minimize(loss)?
z=[0.01,0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,0.99];
z_normal=norminv(z);
numFeatures = 4;
A=4;
X = reshape(X_test.',4,100,[]);
Y = reshape(Y_test.',1,[]);
dlX = dlarray(X,'CBT');
dlY = dlarray(Y,'BT');
numHiddenUnits = 16;
outputdim=4;
H0 = zeros(numHiddenUnits,1);
C0 = zeros(numHiddenUnits,1);
weights = dlarray(randn(4*numHiddenUnits,numFeatures),'CU');
recurrentWeights = dlarray(randn(4*numHiddenUnits,numHiddenUnits),'CU');
bias = dlarray(randn(4*numHiddenUnits,1),'C');
[outputs,hiddenState,cellState] = lstm(dlX,H0,C0,weights,recurrentWeights,bias);
state_h=hiddenState(:,end);
w_out=normrnd(0,1,[outputdim,numHiddenUnits]);
b_out=normrnd(0,1,[outputdim,1]);
out=w_out*state_h+b_out;
params=out+[0;1;1;1];
mu=params(1,:);
sig=params(2,:);
utail=params(3,:);
vtail=params(4,:);
factor1=exp(z_normal.'*utail)/A+1;
factor2=exp(z_normal.'*utail)/A+1; % note: identical to factor1; vtail is never used below
factor=factor1.*factor2;
Q=factor.*z_normal.'.*sig+mu;
error=dlY-Q;
error1=z.'.*error;
error2=(z.'-1).*error;
loss=mean(max(error1,error2));
Answers (1)
Sai Bhargav Avula
on 18 Feb 2020
Hi,
You can create a custom loss function by writing a function of the form loss = myLoss(Y,T), where Y is the network predictions and T is the targets. Compute this loss inside the modelGradient function of a custom training loop, where it is used to calculate the gradients that update the learnable parameters.
This link explains how to create custom layers and loss functions in MATLAB.
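As a rough illustration of that approach, here is a minimal sketch of a custom training loop in which dlfeval and dlgradient compute the gradients of a hand-written loss and adamupdate applies the Adam step, which together play the role of minimize(loss). Several things here are assumptions and not part of the question: the data is synthetic, the sizes, initial scaling, learning rate and iteration count are arbitrary, the function name modelGradients is illustrative, the second tail factor uses vTail with a negative exponent (the posted code reuses utail), and the normal quantiles are computed with erfcinv so that norminv is not required.

```matlab
% Sketch only: synthetic data, arbitrary sizes and hyperparameters (assumptions).
rng(0);
numFeatures = 4; numHiddenUnits = 16; outputDim = 4; A = 4;
numObs = 8; numSteps = 100;

z = [0.01 0.05:0.05:0.95 0.99].';      % 21 quantile levels as a column
zNormal = -sqrt(2)*erfcinv(2*z);       % equals norminv(z), without the Statistics toolbox

X = dlarray(randn(numFeatures,numObs,numSteps),'CBT');  % placeholder inputs
Y = randn(1,numObs);                                    % placeholder targets

% Keep every learnable parameter in one struct so that dlgradient and
% adamupdate can treat them together.
params.weights          = dlarray(0.1*randn(4*numHiddenUnits,numFeatures));
params.recurrentWeights = dlarray(0.1*randn(4*numHiddenUnits,numHiddenUnits));
params.bias             = dlarray(zeros(4*numHiddenUnits,1));
params.wOut             = dlarray(0.1*randn(outputDim,numHiddenUnits));
params.bOut             = dlarray(zeros(outputDim,1));

avgGrad = []; avgSqGrad = [];          % Adam state, initialized empty
learnRate = 0.01;
numIterations = 50;
lossHistory = zeros(numIterations,1);

for iteration = 1:numIterations
    % dlfeval traces the computation so dlgradient can differentiate it.
    [loss,gradients] = dlfeval(@modelGradients,params,X,Y,z,zNormal,A);
    % One Adam step on all parameters.
    [params,avgGrad,avgSqGrad] = adamupdate(params,gradients, ...
        avgGrad,avgSqGrad,iteration,learnRate);
    lossHistory(iteration) = double(extractdata(loss));
end

function [loss,gradients] = modelGradients(params,X,Y,z,zNormal,A)
    numHiddenUnits = size(params.recurrentWeights,2);
    H0 = zeros(numHiddenUnits,1);
    C0 = zeros(numHiddenUnits,1);
    % Second output is the hidden state after the last time step (C-by-B).
    [~,hiddenState] = lstm(X,H0,C0,params.weights, ...
        params.recurrentWeights,params.bias);
    h = stripdims(hiddenState);

    p = params.wOut*h + params.bOut + [0;1;1;1];
    mu = p(1,:); sig = p(2,:); uTail = p(3,:); vTail = p(4,:);

    % Quantile function; the vTail factor is an assumption (see lead-in).
    factor = (exp(zNormal.*uTail)/A + 1).*(exp(-zNormal.*vTail)/A + 1);
    Q = factor.*zNormal.*sig + mu;     % 21-by-numObs

    % Pinball (quantile) loss averaged over quantiles and observations.
    err = Y - Q;
    loss = mean(max(z.*err,(z-1).*err),'all');

    gradients = dlgradient(loss,params);
end
```

The loss must be a scalar computed from traced dlarray operations inside the function passed to dlfeval; otherwise dlgradient cannot differentiate it. Plotting lossHistory is a quick way to check whether the optimization is making progress.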
Hope this helps!