How to define a custom loss function and minimize it like tf.train.AdamOptimizer().minimize(loss)?

I feed X (size [4,100]) into an LSTM and take the last hidden state ([4,1]). I use this state vector to estimate Q. I want to use the Adam optimizer to minimize a custom loss function, but the options offered by the built-in training tools cannot do this. So I want to define my own loss function and minimize it with Adam. How can I do this, like in Python: train = tf.train.AdamOptimizer().minimize(loss)?
z=[0.01,0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.95,0.99];
z_normal=norminv(z);
numFeatures = 4;
A=4;
X = reshape(X_test.',4,100,[]);
Y = reshape(Y_test.',1,[]);
dlX = dlarray(X,'CBT');
dlY = dlarray(Y,'BT');
numHiddenUnits = 16;
outputdim=4;
H0 = zeros(numHiddenUnits,1);
C0 = zeros(numHiddenUnits,1);
weights = dlarray(randn(4*numHiddenUnits,numFeatures),'CU');
recurrentWeights = dlarray(randn(4*numHiddenUnits,numHiddenUnits),'CU');
bias = dlarray(randn(4*numHiddenUnits,1),'C');
[outputs,hiddenState,cellState] = lstm(dlX,H0,C0,weights,recurrentWeights,bias);
state_h=hiddenState(:,end);
w_out=normrnd(0,1,[outputdim,numHiddenUnits]);
b_out=normrnd(0,1,[outputdim,1]);
out=w_out*state_h+b_out;
params=out+[0;1;1;1];
mu=params(1,:);
sig=params(2,:);
utail=params(3,:);
vtail=params(4,:);
factor1=exp(z_normal.'*utail)/A+1;
factor2=exp(-z_normal.'*vtail)/A+1;   % assumes vtail controls the left tail (vtail is otherwise unused)
factor=factor1.*factor2;
Q=factor.*z_normal.'.*sig+mu;
err=dlY-Q;                  % "err" avoids shadowing MATLAB's built-in error function
err1=z.'.*err;
err2=(z.'-1).*err;
loss=mean(max(err1,err2));  % pinball (quantile) loss
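The last lines of the code implement the pinball (quantile) loss. Written out for one target value y, with e = y - Q_τ for each of the K = 21 quantile levels τ in z, it is:

$$
\rho_\tau(e) = \max\bigl(\tau e,\; (\tau - 1)\,e\bigr) = e\,\bigl(\tau - \mathbf{1}[e < 0]\bigr),
\qquad
\mathcal{L} = \frac{1}{K}\sum_{k=1}^{K} \rho_{\tau_k}\bigl(y - Q_{\tau_k}\bigr)
$$

For example, with τ = 0.9: e = 2 gives max(1.8, -0.2) = 1.8, while e = -2 gives max(-1.8, 0.2) = 0.2, so under-predicting a high quantile costs nine times more than over-predicting it by the same amount.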

Answers (1)

Sai Bhargav Avula on 18 Feb 2020
Hi,
You can create a custom loss function as a function of the form loss = myLoss(Y,T), where Y are the network predictions and T are the targets. Inside the modelGradient function, compute this loss and pass it to dlgradient to obtain the gradients used to update the learnable parameters.
This link explains how to create custom layers and loss functions in MATLAB.
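A minimal sketch of that pattern, as a rough MATLAB counterpart of tf.train.AdamOptimizer().minimize(loss): the gradients come from dlgradient evaluated through dlfeval, and adamupdate applies one Adam step per iteration. The synthetic data, the two-parameter model (Q_τ = mu + sig·z_τ), the learning rate, and the names myLoss, quantileFcn, lossFcn and gradFcn are hypothetical choices for illustration, not taken from the question.

```matlab
% Sketch: custom training loop minimizing a custom loss with Adam.
% Assumptions (hypothetical): synthetic data and a model with only two
% learnable parameters, mu and sig.
rng(0);                                   % reproducible synthetic data
tau  = [0.1 0.5 0.9];                     % quantile levels (1-by-K)
zTau = sqrt(2) * erfinv(2*tau - 1);       % standard normal quantiles, same as norminv(tau)
y    = 3 + 2*randn(1, 500);               % hypothetical targets (1-by-N)

% Learnable parameters stored as a struct of dlarrays.
params.mu  = dlarray(0);
params.sig = dlarray(1);

% Custom loss of the form loss = myLoss(Y,T): pinball loss averaged over all
% quantile levels and observations. Y is K-by-1, T is 1-by-N.
myLoss = @(Y, T) mean(mean(max(tau.' .* (T - Y), (tau.' - 1) .* (T - Y))));

% Model: predicted quantiles Q_tau = mu + sig * z_tau (K-by-1).
quantileFcn = @(p) p.mu + p.sig .* zTau.';
lossFcn     = @(p) myLoss(quantileFcn(p), y);

% dlgradient must be called inside a function evaluated by dlfeval;
% here that function is anonymous and returns the gradient struct.
gradFcn = @(p) dlgradient(lossFcn(p), p);

initialLoss   = extractdata(lossFcn(params));
averageGrad   = [];                       % Adam moment estimates, start empty
averageSqGrad = [];
learnRate     = 0.05;
numIterations = 500;
lossHistory   = zeros(1, numIterations);

for iteration = 1:numIterations
    grad = dlfeval(gradFcn, params);      % gradients of the custom loss
    [params, averageGrad, averageSqGrad] = adamupdate(params, grad, ...
        averageGrad, averageSqGrad, iteration, learnRate);   % one Adam step
    lossHistory(iteration) = extractdata(lossFcn(params));
end

muFinal  = extractdata(params.mu);
sigFinal = extractdata(params.sig);
fprintf('loss %.4f -> %.4f, mu = %.3f, sig = %.3f\n', ...
    initialLoss, lossHistory(end), muFinal, sigFinal);
```

Anonymous functions keep the sketch in a single script; a named modelGradient function returning [loss, gradients] works the same way. In the question's setting, params would presumably hold weights, recurrentWeights, bias, w_out and b_out as dlarrays, with the lstm call and the computation of Q moved inside the loss function.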
Hope this helps!

Categories

Deep Learning Toolbox

Version

R2019b
