Programmatically determine which Deep Learning layer properties contain learnables
Matt J
on 6 Dec 2023
Answered: Vishnu Keyen
on 8 Nov 2024 at 19:26
In the Deep Learning Toolbox, there are a variety of layer object types. The display methods of these objects indicate which object properties contain learnable parameter data. For example, for LSTMLayer objects, the learnable parameters are stored in the properties "InputWeights", "RecurrentWeights", and "Bias" as shown below. My question is, is there a way, given a layer object, to programmatically determine the subset of its properties that are learnable?
layer = lstmLayer(100,'Name','lstm1')
Accepted Answer
Pratyush Swain
on 14 Dec 2023
Hi Matt,
I understand you want to determine which properties of a deep learning layer object are learnable. There is no direct way to query this from the layer itself, but a workaround is to wrap the layer object (here, the LSTM layer) in a "dlnetwork" and then read the network's learnable parameters.
Please refer to the example implementation below:
% Define an LSTM layer
layer = lstmLayer(100,'Name','lstm1')
% Wrap the layer in an uninitialized dlnetwork (no input layer is required)
net = dlnetwork(layer,Initialize=false);
% Retrieve the learnable parameters as a table
learnables = net.Learnables;
% Display the table; its Parameter column lists the learnable properties
disp(learnables)
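As the output shows, "dlnetwork" forms a single-layer network containing "lstm1", and its Learnables property is a table with one row per learnable parameter (columns Layer, Parameter, and Value). The Parameter column gives the names of the learnable properties: "InputWeights", "RecurrentWeights", and "Bias".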
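The same idea can be applied to several layer types in a loop. This is a minimal sketch, assuming a Deep Learning Toolbox release that supports dlnetwork with Initialize=false (the syntax used above); the particular layers (gruLayer, fullyConnectedLayer) and the variable names are chosen only for illustration.

```matlab
% Layers to inspect (illustrative choice; any built-in layers should work)
layersToCheck = {lstmLayer(100, 'Name', 'lstm1'), ...
                 gruLayer(50, 'Name', 'gru1'), ...
                 fullyConnectedLayer(10, 'Name', 'fc1')};
learnableNames = cell(size(layersToCheck));
for k = 1:numel(layersToCheck)
    % Wrap one layer in an uninitialized dlnetwork (no input layer needed)
    net = dlnetwork(layersToCheck{k}, Initialize=false);
    % The Parameter column holds the learnable property names; convert to
    % string in case the column type differs between releases
    learnableNames{k} = unique(string(net.Learnables.Parameter), 'stable');
    % Print "layerName: Param1, Param2, ..."
    fprintf("%s: %s\n", layersToCheck{k}.Name, strjoin(learnableNames{k}, ", "));
end
```

Because the names come from the network's own Learnables table, nothing about the layer type has to be hard-coded.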
For more information, please refer to:
Hope this helps.
More Answers (1)
Vishnu Keyen
on 8 Nov 2024 at 19:26
Let's define a network:
layers = [sequenceInputLayer(32, 'Name', 'input')
lstmLayer(128, 'OutputMode', 'sequence', 'Name', 'lstm')]
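For a layer array like this one, the dlnetwork approach from the accepted answer can also report the learnables of every layer at once. A minimal sketch, assuming the Learnables table has Layer and Parameter columns (converted to string here in case a release stores them differently); because the array starts with an input layer, dlnetwork can initialize it directly.

```matlab
layers = [sequenceInputLayer(32, 'Name', 'input')
          lstmLayer(128, 'OutputMode', 'sequence', 'Name', 'lstm')];
net = dlnetwork(layers);          % input layer present, so it initializes
tbl = net.Learnables;             % one row per learnable parameter
layerNames = string(tbl.Layer);
paramNames = string(tbl.Parameter);
% Layers without learnables (such as the input layer) do not appear in tbl
for name = unique(layerNames, 'stable')'
    fprintf("%s: %s\n", name, strjoin(paramNames(layerNames == name), ", "));
end
```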
Since each layer is an object of a class in the nnet.cnn.layer package, you can determine its type with the class() function.
For example,
layerType = class(layers(2))
You can then search the class name for the layer type you are looking for.
For example:
contains(class(layers(2)),'lstm','IgnoreCase',true)
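A substring match on the class name also matches other classes whose names contain "lstm", such as a bidirectional LSTM layer. A sketch using isa instead tests for the exact layer type; it assumes the built-in class names nnet.cnn.layer.LSTMLayer and nnet.cnn.layer.GRULayer, and the variable names are illustrative.

```matlab
layers = [sequenceInputLayer(32, 'Name', 'input')
          lstmLayer(128, 'OutputMode', 'sequence', 'Name', 'lstm')];

% Substring match (as above): true for any class name containing "lstm"
isLstmLike = contains(class(layers(2)), 'lstm', 'IgnoreCase', true);

% Exact type test: true only for nnet.cnn.layer.LSTMLayer objects
isLstm = isa(layers(2), 'nnet.cnn.layer.LSTMLayer');

% Flag every LSTM or GRU layer in the array
isRecurrent = false(size(layers));
for k = 1:numel(layers)
    isRecurrent(k) = isa(layers(k), 'nnet.cnn.layer.LSTMLayer') || ...
                     isa(layers(k), 'nnet.cnn.layer.GRULayer');
end
```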
Once you have identified an LSTM or GRU layer, you can check for specific parameters such as RecurrentWeights.
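A minimal sketch of that check uses isprop. The candidate list below is a hypothetical, hand-maintained set of property names that hold learnables in common built-in layers; it is not exhaustive, and isprop only tells you that a property exists, not that it is learnable, so the dlnetwork approach from the accepted answer is more general.

```matlab
layer = lstmLayer(128, 'OutputMode', 'sequence', 'Name', 'lstm');

% Hypothetical candidate list (hand-maintained, not exhaustive)
candidates = {'InputWeights', 'RecurrentWeights', 'Bias', ...
              'Weights', 'Offset', 'Scale'};

% Keep only the candidate names that exist as properties of this layer
hasProp = cellfun(@(p) isprop(layer, p), candidates);
present = candidates(hasProp);
disp(present)
```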