
Problem in automatic gradient calculation using Deep Learning Toolbox

Hello,
Could anyone please tell me why the dlgradient function throws an error here? It seems that the loss is somehow not being traced back to the network weights. Could anyone suggest why this is happening?
x0 = dlarray([1,2,3],'BC');
y0 = dlarray(14,'BC');
fcnn_graph = layerGraph;
layers = [
    featureInputLayer(3,"Normalization","none","Name","InputLayer")
    fullyConnectedLayer(5,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc1")
    tanhLayer("Name","active_1")
    fullyConnectedLayer(5,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc2")
    tanhLayer("Name","active_2")
    fullyConnectedLayer(1,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc3")
    ];
fcnn_graph = addLayers(fcnn_graph,layers);
dlnet = dlnetwork(fcnn_graph);
[gradients,state,loss] = modelGradients(dlnet,x0,y0);

function [gradients,state,loss] = modelGradients(dlnet,X,Y)
    [YPred,state] = forward(dlnet,X);
    loss = crossentropy(YPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);
    loss = double(gather(extractdata(loss)));
end

Accepted Answer

Philip Brown on 9 Sep 2021
To enable tracing, you need to pass your modelGradients function to dlfeval; see the dlfeval documentation page for more details. Replace your line:
[gradients,state,loss] = modelGradients(dlnet,x0,y0);
with
[gradients,state,loss] = dlfeval(@modelGradients, dlnet,x0,y0);
This will ensure that your modelGradients function has tracing between the input weights and the loss.
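For context, here is a minimal sketch of how this dlfeval call typically sits inside a custom training loop. The adamupdate solver call, learning rate, and iteration count are illustrative assumptions, not part of the original question:
averageGrad = [];
averageSqGrad = [];
learnRate = 0.01;      % assumed learning rate (illustration only)
numIterations = 100;   % assumed number of iterations (illustration only)

for iteration = 1:numIterations
    % dlfeval enables automatic-differentiation tracing inside modelGradients,
    % so dlgradient can trace the loss back to dlnet.Learnables.
    [gradients,state,loss] = dlfeval(@modelGradients,dlnet,x0,y0);
    dlnet.State = state;   % empty here, since no layers carry state

    % Update the learnable parameters with the Adam solver.
    [dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,gradients, ...
        averageGrad,averageSqGrad,iteration,learnRate);
end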
4 comments
Philip Brown on 9 Sep 2021
You can use the Learnables property of the dlnetwork; this stores them as a table. For example, for the 'fc1' layer weights, you can use:
dlnet.Learnables{1,3}{1}
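As an alternative to positional indexing, here is a sketch of a name-based lookup on the Learnables table (assuming the layer and parameter names used above; the variable fc1Weights is just an illustrative name):
% Look up the 'fc1' weights by name rather than by row/column position.
% The Learnables table has columns Layer, Parameter, and Value.
idx = strcmp(dlnet.Learnables.Layer,"fc1") & strcmp(dlnet.Learnables.Parameter,"Weights");
fc1Weights = dlnet.Learnables.Value{idx};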


More Answers (0)
