Problem with automatic gradient calculation using Deep Learning Toolbox
GAURAV YADAV on 8 Sep 2021
Commented: GAURAV YADAV on 9 Sep 2021
Hello,
Could anyone tell me why the dlgradient function throws an error here? It seems that the loss is somehow not being traced back to the network's learnable weights. Why is this happening?
x0 = dlarray([1,2,3],'BC');
y0 = dlarray(14,'BC');
fcnn_graph = layerGraph;
layers = [featureInputLayer(3,"Normalization","none","Name","InputLayer")
          fullyConnectedLayer(5,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc1")
          tanhLayer("Name","active_1")
          fullyConnectedLayer(5,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc2")
          tanhLayer("Name","active_2")
          fullyConnectedLayer(1,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc3")
          ];
fcnn_graph = addLayers(fcnn_graph,layers);
dlnet = dlnetwork(fcnn_graph);
[gradients,state,loss] = modelGradients(dlnet,x0,y0);
function [gradients,state,loss] = modelGradients(dlnet,X,Y)
    [YPred,state] = forward(dlnet,X);
    loss = crossentropy(YPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);
    loss = double(gather(extractdata(loss)));
end
Accepted Answer
Philip Brown on 9 Sep 2021
To make sure tracing happens, you need to call your modelGradients function through dlfeval (see the documentation for more details). Replace your line:
[gradients,state,loss] = modelGradients(dlnet,x0,y0);
with 
[gradients,state,loss] = dlfeval(@modelGradients, dlnet,x0,y0);
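This lets dlgradient trace the computation from the network's learnable weights to the loss: dlfeval evaluates modelGradients with automatic differentiation enabled, and dlgradient only works inside a function that is evaluated this way.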
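A minimal sketch of the complete script with the fix applied is below. Besides calling modelGradients through dlfeval, it swaps crossentropy for mse. That swap is an assumption about the intent: crossentropy expects predicted probabilities (for example softmax outputs), while the target here is the real value 14, which suggests a regression problem. The network and the rest of modelGradients are as in the question.

```matlab
% Sketch: the question's script with modelGradients evaluated by dlfeval.
% Assumption: y0 is a real-valued regression target, so mse replaces crossentropy.
x0 = dlarray([1,2,3],'BC');   % one observation with 3 features
y0 = dlarray(14,'BC');        % one real-valued target

layers = [featureInputLayer(3,"Normalization","none","Name","InputLayer")
          fullyConnectedLayer(5,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc1")
          tanhLayer("Name","active_1")
          fullyConnectedLayer(5,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc2")
          tanhLayer("Name","active_2")
          fullyConnectedLayer(1,"WeightsInitializer","glorot","BiasInitializer","ones","Name","fc3")];
dlnet = dlnetwork(layerGraph(layers));

% dlfeval enables tracing, so dlgradient inside modelGradients can
% differentiate the loss with respect to dlnet.Learnables.
[gradients,state,loss] = dlfeval(@modelGradients,dlnet,x0,y0);

% gradients is a table with the same rows as dlnet.Learnables.
disp(gradients)
fprintf("loss = %g\n",loss)

function [gradients,state,loss] = modelGradients(dlnet,X,Y)
    [YPred,state] = forward(dlnet,X);
    loss = mse(YPred,Y);                          % regression loss (assumed)
    gradients = dlgradient(loss,dlnet.Learnables);
    loss = double(gather(extractdata(loss)));     % plain double for display
end
```

The weights are initialized randomly (glorot), so the printed loss differs from run to run.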
4 comments
Philip Brown on 9 Sep 2021
You can use the Learnables property of the dlnetwork, which stores the learnable parameters (weights and biases) as a table with the columns Layer, Parameter, and Value. For example, to get the weights of the 'fc1' layer, use:
dlnet.Learnables{1,3}{1}
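Indexing by row number depends on the layer order. A sketch that looks the parameter up by layer name instead is below; it uses a trimmed, hypothetical version of the question's network (only fc1 and fc3 are kept) so that it runs on its own. The same mask should also work on the gradients table returned by dlgradient, since that table has the same rows as dlnet.Learnables.

```matlab
% Sketch: select a learnable parameter by layer name instead of row number.
% Hypothetical trimmed network; layer names follow the question.
layers = [featureInputLayer(3,"Normalization","none","Name","InputLayer")
          fullyConnectedLayer(5,"Name","fc1")
          tanhLayer("Name","active_1")
          fullyConnectedLayer(1,"Name","fc3")];
dlnet = dlnetwork(layerGraph(layers));

learnables = dlnet.Learnables;   % table with columns Layer, Parameter, Value
disp(learnables)

% Logical mask for the row holding the fc1 weights.
% strcmp works whether the columns hold strings or character vectors.
isFc1Weights = strcmp(learnables.Layer,"fc1") & strcmp(learnables.Parameter,"Weights");

W = learnables.Value{isFc1Weights};   % dlarray, OutputSize-by-InputSize (5x3)
Wnumeric = extractdata(W);            % underlying numeric array
disp(Wnumeric)
```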