Problem in automatic gradient calculation using Deep Learning Toolbox
GAURAV YADAV
on 8 Sep 2021
Commented: GAURAV YADAV
on 9 Sep 2021
Hello,
Could anyone tell me why dlgradient throws an error in the code below? It seems that the loss is somehow not being traced back to the network weights. Why is this happening?
% Input: one observation (B) with 3 features (C); target: one scalar value
x0 = dlarray([1,2,3],'BC');
y0 = dlarray(14,'BC');
% Fully connected network 3 -> 5 -> 5 -> 1 with tanh activations
fcnn_graph = layerGraph;
layers = [featureInputLayer(3,"Normalization","none","Name","InputLayer")
    fullyConnectedLayer(5,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc1")
    tanhLayer("Name","active_1")
    fullyConnectedLayer(5,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc2")
    tanhLayer("Name","active_2")
    fullyConnectedLayer(1,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc3")
    ];
fcnn_graph = addLayers(fcnn_graph,layers);
dlnet = dlnetwork(fcnn_graph);
% Direct call: dlgradient inside modelGradients throws the error here
[gradients,state,loss] = modelGradients(dlnet,x0,y0);

function [gradients,state,loss] = modelGradients(dlnet,X,Y)
    [YPred,state] = forward(dlnet,X);
    % Note: crossentropy expects class probabilities (e.g. softmax output);
    % for a numeric target such as 14, mse is the usual choice
    loss = crossentropy(YPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);
    loss = double(gather(extractdata(loss)));
end
Accepted Answer
Philip Brown
on 9 Sep 2021
For tracing to happen, you need to evaluate your modelGradients function through dlfeval; see the dlfeval documentation page for more details. Replace your line:
[gradients,state,loss] = modelGradients(dlnet,x0,y0);
with
[gradients,state,loss] = dlfeval(@modelGradients, dlnet,x0,y0);
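When modelGradients runs inside dlfeval, operations on the dlarray inputs (including the learnable parameters of dlnet) are traced, so dlgradient can differentiate the loss with respect to the network weights.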
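For reference, here is a minimal self-contained sketch of the corrected workflow. It is an illustration under stated assumptions, not the accepted answer's exact code: it uses mse instead of crossentropy because the target is a single number, it uses default initializers, and it replaces the modelGradients local function with a hypothetical anonymous function gradFcn so the snippet runs as a plain script. It also shows that calling the same function directly, without dlfeval, fails.

```matlab
% Sketch (assumptions): same architecture as the question, default initializers,
% mse loss instead of crossentropy, and an anonymous function instead of modelGradients.
x0 = dlarray([1,2,3],'BC');   % one observation with 3 features
y0 = dlarray(14,'BC');        % one scalar target

layers = [featureInputLayer(3,"Normalization","none","Name","InputLayer")
    fullyConnectedLayer(5,"Name","fc1")
    tanhLayer("Name","active_1")
    fullyConnectedLayer(5,"Name","fc2")
    tanhLayer("Name","active_2")
    fullyConnectedLayer(1,"Name","fc3")];
dlnet = dlnetwork(layerGraph(layers));

% Hypothetical helper: forward pass, loss, and gradient w.r.t. all learnables
gradFcn = @(net,X,Y) dlgradient(mse(forward(net,X),Y), net.Learnables);

% Calling it directly: nothing is traced, so dlgradient raises an error
directCallFailed = false;
try
    gradFcn(dlnet,x0,y0);
catch err
    directCallFailed = true;
    disp(err.message)
end

% Calling it through dlfeval: the inputs are traced and dlgradient succeeds
gradients = dlfeval(gradFcn,dlnet,x0,y0);
disp(gradients)   % table with Layer, Parameter, Value variables
```

An anonymous function returns only its first output, so if you also need the loss and the network state, a named function such as modelGradients from the question (called through dlfeval) is the better fit.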
Philip Brown
on 9 Sep 2021
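You can use the Learnables property of the dlnetwork, which stores the learnable parameters (weights and biases) in a table with one row per parameter and the variables Layer, Parameter, and Value. For example, to get the weights of the 'fc1' layer (row 1, column 3 of that table), you can use: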
dlnet.Learnables{1,3}{1}
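Indexing by position (row 1, column 3) depends on the row order of the table. A minimal sketch that instead selects the row by layer name and parameter name is shown below; the network is rebuilt with default initializers (an assumption) so the snippet is self-contained, and lt, idx, and W1 are illustrative variable names.

```matlab
% Rebuild a network like the one in the question (default initializers assumed)
layers = [featureInputLayer(3,"Normalization","none","Name","InputLayer")
    fullyConnectedLayer(5,"Name","fc1")
    tanhLayer("Name","active_1")
    fullyConnectedLayer(5,"Name","fc2")
    tanhLayer("Name","active_2")
    fullyConnectedLayer(1,"Name","fc3")];
dlnet = dlnetwork(layerGraph(layers));

lt = dlnet.Learnables;          % table with variables Layer, Parameter, Value
% Logical mask selecting the row that holds the fc1 weights
idx = strcmp(lt.Layer,"fc1") & strcmp(lt.Parameter,"Weights");
W1 = lt.Value{idx};             % dlarray of size 5-by-3 (OutputSize x InputSize)
W1data = extractdata(W1);       % plain numeric array, e.g. for inspection or plotting
```

strcmp is used rather than == so the mask works whether the name variables are string arrays or cell arrays of character vectors.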