Unexpected loss reduction using custom training loop in Deep Learning Toolbox
29 views (last 30 days)
MathWorks Support Team
on 19 Jul 2023
Answered: MathWorks Support Team on 3 Aug 2023
I have created a custom training loop following the documentation example: https://www.mathworks.com/help/releases/R2023a/deeplearning/ug/train-network-using-custom-training-loop.html
However, since I use the same loss function for training and validation, I altered the "modelLoss" function so that the "forward" call happens outside of it. For example:
[Y,state] = forward(net,X);
[loss,gradients] = dlfeval(@modelLoss,net,Y,T);

function [loss,gradients] = modelLoss(net,Y,T)
    % Calculate cross-entropy loss.
    loss = crossentropy(Y,T);
    % Calculate gradients of loss with respect to learnable parameters.
    gradients = dlgradient(loss,net.Learnables);
end
Now the loss during training does not decrease as expected. How can I resolve this issue?
Accepted Answer
MathWorks Support Team
on 19 Jul 2023
When the "dlgradient" function is used inside a second function which is called by "dlfeval", automatic differentiation is used to calculate the gradients. The "dlfeval" function traces the operations when calculating the gradient and therefore, for the loss to be calculated correctly, the functions related to finding the gradient (e.g. "forward") must remain inside the "modelloss" function called by the "dlfeval" function.
For more information, refer to the Deep Learning Toolbox documentation on automatic differentiation.
Moving the "forward" function back inside the "modelLoss" function will resolve the issue. Additionally, since, the gradient is not required for validation, using the "dlfeval" function to calculate the validation loss introduces unnecessary overhead and decreases performance.
0 comments
More Answers (0)