Entrenar una red con varias salidas
Este ejemplo muestra cómo entrenar una red de deep learning con varias salidas que predicen tanto etiquetas como ángulos de rotación de dígitos manuscritos.
Cargar los datos de entrenamiento
Cargue los datos de dígitos. Los datos contienen imágenes de dígitos, así como las etiquetas de los dígitos y sus ángulos de rotación respecto a la vertical.
load DigitsDataTrain
Cree un objeto arrayDatastore
para las imágenes, las etiquetas y los ángulos y, después, utilice la función combine
para crear un único almacén de datos que contenga todos los datos de entrenamiento.
dsXTrain = arrayDatastore(XTrain,IterationDimension=4); dsT1Train = arrayDatastore(labelsTrain); dsT2Train = arrayDatastore(anglesTrain); dsTrain = combine(dsXTrain,dsT1Train,dsT2Train); classNames = categories(labelsTrain); numClasses = numel(classNames); numObservations = numel(labelsTrain);
Visualice algunas imágenes de los datos de entrenamiento.
idx = randperm(numObservations,64); I = imtile(XTrain(:,:,:,idx)); figure imshow(I)
Definir el modelo de deep learning
Defina la red siguiente, que predice tanto etiquetas como ángulos de rotación.
Un bloque convolution-batchnorm-ReLU con 16 filtros de 5 por 5.
Dos bloques convolution-batchnorm-ReLU con 32 filtros de 3 por 3 cada uno.
Una conexión de omisión alrededor de los dos bloques anteriores, que contenga un bloque convolution-batchnorm-ReLU con 32 filtros de 1 por 1.
Combine la conexión de omisión mediante la adición.
Para la salida de clasificación, una rama con una operación totalmente conectada con un tamaño de 10 (el número de clases) y una operación softmax.
Para la salida de regresión, una rama con una operación totalmente conectada con un tamaño de 1 (el número de respuestas).
Defina el bloque principal de capas.
net = dlnetwork; layers = [ imageInputLayer([28 28 1],Normalization="none") convolution2dLayer(5,16,Padding="same") batchNormalizationLayer reluLayer(Name="relu_1") convolution2dLayer(3,32,Padding="same",Stride=2) batchNormalizationLayer reluLayer convolution2dLayer(3,32,Padding="same") batchNormalizationLayer reluLayer additionLayer(2,Name="add") fullyConnectedLayer(numClasses) softmaxLayer(Name="softmax")]; net = addLayers(net,layers);
Añada la conexión de omisión.
layers = [ convolution2dLayer(1,32,Stride=2,Name="conv_skip") batchNormalizationLayer reluLayer(Name="relu_skip")]; net = addLayers(net,layers); net = connectLayers(net,"relu_1","conv_skip"); net = connectLayers(net,"relu_skip","add/in2");
Añada la capa totalmente conectada para la regresión.
layers = fullyConnectedLayer(1,Name="fc_2"); net = addLayers(net,layers); net = connectLayers(net,"add","fc_2");
Visualice la gráfica de una capa.
figure plot(net)
Especificar las opciones de entrenamiento
Especifique las opciones de entrenamiento. Para escoger entre las opciones se requiere un análisis empírico. Para explorar diferentes configuraciones de opciones de entrenamiento mediante la ejecución de experimentos, puede utilizar la app Experiment Manager.
options = trainingOptions("adam", ... Plots="training-progress", ... Verbose=false);
Entrenar una red neuronal
Entrene la red neuronal con la función trainnet
. Para la clasificación, utilice una función de pérdida personalizada que sea la pérdida de entropía cruzada de las etiquetas predichas y objetivo más 0,1 veces la pérdida de error cuadrático medio de los ángulos predichos y objetivo. De forma predeterminada, la función trainnet
usa una GPU en caso de que esté disponible. Para utilizar una GPU se requiere una licencia de Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). De lo contrario, la función usa la CPU. Para especificar el entorno de ejecución, utilice la opción de entrenamiento ExecutionEnvironment
.
Defina la función de pérdida personalizada como un identificador de función. Defina una pérdida que corresponda a la pérdida de entropía cruzada de las etiquetas predichas y objetivo más el error cuadrático medio de los ángulos predichos y objetivo, escalado por un factor de 0,1.
lossFcn = @(Y1,Y2,T1,T2) crossentropy(Y1,T1) + 0.1*mse(Y2,T2);
Entrenar la red neuronal.
net = trainnet(dsTrain,net,lossFcn,options);
Probar un modelo
Cargue los datos de dígitos. Los datos contienen imágenes de dígitos, así como las etiquetas de los dígitos y sus ángulos de rotación respecto a la vertical.
load DigitsDataTest
Realice predicciones con la función minibatchpredict
y convierta las puntuaciones de clasificación en etiquetas con la función scores2label
. De forma predeterminada, la función minibatchpredict
usa una GPU en caso de que esté disponible. Para utilizar una GPU se requiere una licencia de Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). De lo contrario, la función usa la CPU. Para seleccionar el entorno de ejecución manualmente, utilice el argumento ExecutionEnvironment
de la función minibatchpredict
.
[scores,Y2] = minibatchpredict(net,XTest); Y1 = scores2label(scores,classNames);
Calcule la precisión de clasificación de las etiquetas.
accuracy = mean(Y1 == labelsTest)
accuracy = 0.9732
Calcule el error cuadrático medio entre los ángulos predichos y objetivo.
err = rmse(Y2,anglesTest)
err = single
6.9265
Vea algunas de las imágenes con sus predicciones. Muestre los ángulos predichos en rojo y las etiquetas correctas en verde.
idx = randperm(size(XTest,4),9); figure for i = 1:9 subplot(3,3,i) I = XTest(:,:,:,idx(i)); imshow(I) hold on sz = size(I,1); offset = sz/2; theta = Y2(idx(i)); plot(offset*[1-tand(theta) 1+tand(theta)],[sz 0],"r--") thetaTest = anglesTest(idx(i)); plot(offset*[1-tand(thetaTest) 1+tand(thetaTest)],[sz 0],"g--") hold off label = Y1(idx(i)); title("Label: " + string(label)) end
Consulte también
dlarray
| dlgradient
| dlfeval
| sgdmupdate
| batchNormalizationLayer
| convolution2dLayer
| reluLayer
| fullyConnectedLayer
| softmaxLayer
| minibatchqueue
| onehotencode
| onehotdecode