How to use Matlab trainnet to train a network without an explicit output layer (R2024a)
I've attempted to train a CNN that assigns N numeric values to each input image, depending on image characteristics. It seemed that the network's final layer could be a fully connected layer with N outputs (I have not found a linear output layer in Deep Network Designer). I am not sure whether I can use a non-linear output layer instead, because this is fundamentally a regression task.
However, when I use a fully connected layer in place of an output layer, trainnet repeatedly gives errors indicating that the network must have an output layer.
So basically, I have two questions:
1) Is it possible to use trainnet with a network that has no output layer? It is difficult to imagine that a built-in training function has an oversight like this. Do I really need to construct a custom training loop if my network has no output layer?
2) Are there any alternatives? In essence, all I am looking for is an output layer that is either a) linear or b) does not change the previous layer's output. Just anything that is compatible with a regression task.
If any clarification is needed on my issue or network construction, I would be happy to provide it.
Thank you so much for your help!
Deep Learning Toolbox Version 24.1 (R2024a) , trainnet function, Matlab 2024.
2 comments
Matt J
9 Aug 2024
I can't reproduce that. Here is an example of a simple network training where the final layer is a fully connected layer. No error messages:
ds=combine( arrayDatastore(rand(3),IterationDim=3) , ...
arrayDatastore(rand(1),IterationDim=3) );
layers=[imageInputLayer([3,3,1]),fullyConnectedLayer(1)];
trainnet(ds,layers,'mse', trainingOptions('adam',TargetDataFormats="CB"))
Answers (2)
Aditya
8 Aug 2024
Edited: Aditya
8 Aug 2024
To address your query:
1) As far as I know, the "trainnet" function cannot be used without an explicit output layer. It expects a complete network architecture, including an output layer, so that it can define the loss function and perform backpropagation during training.
2) Use a fully connected layer with N outputs and set the loss function to "mse", since this is a regression task. I am not sure why you are getting the error you mention when doing this. It would help if you could provide the code you are using (layer architecture and trainingOptions).
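The suggestion in 2) can be sketched roughly as follows. This is only a minimal illustration: the image size, the number of responses, the layer choices, and the random data are all made-up assumptions, not the asker's actual setup.

```matlab
% Hypothetical sizes and random data, for illustration only.
numResponses = 3;                          % N numeric values per image
numObs = 32;
X = rand(28,28,1,numObs,"single");         % images: H-by-W-by-C-by-observations
T = rand(numObs,numResponses,"single");    % targets: observations-by-N

% No output layer: the network ends at the fully connected layer.
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,Padding="same")
    reluLayer
    globalAveragePooling2dLayer
    fullyConnectedLayer(numResponses)];

options = trainingOptions("adam", ...
    MaxEpochs=2, ...
    MiniBatchSize=16, ...
    Verbose=false, ...
    Plots="none");

% The loss is passed as an argument instead of being part of the network.
net = trainnet(X,T,layers,"mse",options);
```

If something like this runs without complaint, the output-layer error probably comes from a leftover layer elsewhere in the network rather than from the final fully connected layer.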
You could also refer to the MATLAB documentation example "Train Convolutional Neural Network for Regression".
Hope this helps!
3 comments
Aditya
9 Aug 2024
Edited: Aditya
9 Aug 2024
Yes. When you call trainnet with the "mse" loss function, no separate regression output layer is needed: the mean squared error is computed directly on the output of the final fully connected layer during training.
You can also look at the MATLAB documentation for regressionLayer: https://in.mathworks.com/help/deeplearning/ref/regressionlayer.html
That page recommends using "trainnet" with "mse" instead of regressionLayer.
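As a rough sketch of what that change looks like in practice, the snippet below takes an older-style layer array ending in regressionLayer, drops the output layer, and checks that the remaining network returns the fully connected layer's values directly. The layer sizes and the random input are assumptions chosen only for illustration.

```matlab
% Older trainNetwork-style layer array that ends in regressionLayer.
layersOld = [
    imageInputLayer([28 28 1],Normalization="none")
    fullyConnectedLayer(3)
    regressionLayer];

% For trainnet, drop the output layer; the loss ("mse") is supplied separately.
layersNew = layersOld(1:end-1);

% The final fully connected layer is linear, so the network returns its
% activations unchanged.
net = dlnetwork(layersNew);
X = dlarray(rand(28,28,1,4,"single"),"SSCB");   % 4 random images
Y = predict(net,X);
size(Y)                                         % responses by observations
```

The same layersNew array could then be passed to trainnet together with "mse".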
Hope this clarifies the doubt!
Matt J
9 Aug 2024
Edited: Matt J
9 Aug 2024
1) Is it possible to use trainnet in a network without an output layer? It is difficult to imagine that a built-in training function has an oversight like this.
trainnet is always to be used without an output layer. The loss function is specified using the lossFcn input argument.
2) Are there any alternatives? In essence, all I am looking for is an output layer that is either a) linear or b) does not change the previous layer's output. Just anything that is compatible with a regression task.
lossFcn can also be a function handle with the syntax loss = f(Y1,...,Yn,T1,...,Tm), where Y1,...,Yn are dlarray objects that correspond to the n network predictions and T1,...,Tm are dlarray objects that correspond to the m targets.
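A minimal sketch of the function-handle form, assuming a single prediction Y and a single target T. The Huber loss and its transition point of 0.5 are arbitrary choices for illustration, and the data is random.

```matlab
% Hypothetical data: 32 random images, 2 responses each.
X = rand(28,28,1,32,"single");
T = rand(32,2,"single");
layers = [imageInputLayer([28 28 1]) fullyConnectedLayer(2)];

% Custom loss: Y is the network prediction, T the target (both dlarray).
% Huber loss with an arbitrarily chosen transition point.
lossFcn = @(Y,T) huber(Y,T,TransitionPoint=0.5);

options = trainingOptions("adam",MaxEpochs=2,Verbose=false,Plots="none");
net = trainnet(X,T,layers,lossFcn,options);
```

Any function that takes the predictions and targets and returns a scalar loss should fit the same pattern.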