Contenido principal

Entrenar una red con varias salidas

Este ejemplo muestra cómo entrenar una red de deep learning con varias salidas que predicen tanto etiquetas como ángulos de rotación de dígitos manuscritos.

Cargar los datos de entrenamiento

Cargue los datos de dígitos. Los datos contienen imágenes de dígitos, así como las etiquetas de los dígitos y sus ángulos de rotación respecto a la vertical.

load DigitsDataTrain

Cree un objeto arrayDatastore para las imágenes, las etiquetas y los ángulos y, después, utilice la función combine para crear un único almacén de datos que contenga todos los datos de entrenamiento.

dsXTrain = arrayDatastore(XTrain,IterationDimension=4);
dsT1Train = arrayDatastore(labelsTrain);
dsT2Train = arrayDatastore(anglesTrain);

dsTrain = combine(dsXTrain,dsT1Train,dsT2Train);

classNames = categories(labelsTrain);
numClasses = numel(classNames);
numObservations = numel(labelsTrain);

Visualice algunas imágenes de los datos de entrenamiento.

idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

Definir el modelo de deep learning

Defina la red siguiente, que predice tanto etiquetas como ángulos de rotación.

  • Un bloque convolution-batchnorm-ReLU con 16 filtros de 5 por 5.

  • Dos bloques convolution-batchnorm-ReLU con 32 filtros de 3 por 3 cada uno.

  • Una conexión de omisión alrededor de los dos bloques anteriores, que contenga un bloque convolution-batchnorm-ReLU con 32 filtros de 1 por 1.

  • Combine la conexión de omisión mediante la adición.

  • Para la salida de clasificación, una rama con una operación totalmente conectada con un tamaño de 10 (el número de clases) y una operación softmax.

  • Para la salida de regresión, una rama con una operación totalmente conectada con un tamaño de 1 (el número de respuestas).

Defina el bloque principal de capas.

net = dlnetwork;

layers = [
    imageInputLayer([28 28 1],Normalization="none")

    convolution2dLayer(5,16,Padding="same")
    batchNormalizationLayer
    reluLayer(Name="relu_1")

    convolution2dLayer(3,32,Padding="same",Stride=2)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer

    additionLayer(2,Name="add")

    fullyConnectedLayer(numClasses)
    softmaxLayer(Name="softmax")];

net = addLayers(net,layers);

Añada la conexión de omisión.

layers = [
    convolution2dLayer(1,32,Stride=2,Name="conv_skip")
    batchNormalizationLayer
    reluLayer(Name="relu_skip")];

net = addLayers(net,layers);
net = connectLayers(net,"relu_1","conv_skip");
net = connectLayers(net,"relu_skip","add/in2");

Añada la capa totalmente conectada para la regresión.

layers = fullyConnectedLayer(1,Name="fc_2");
net = addLayers(net,layers);
net = connectLayers(net,"add","fc_2");

Visualice la gráfica de una capa.

figure
plot(net)

Especificar las opciones de entrenamiento

Especifique las opciones de entrenamiento. Para escoger entre las opciones se requiere un análisis empírico. Para explorar diferentes configuraciones de opciones de entrenamiento mediante la ejecución de experimentos, puede utilizar la app Experiment Manager.

options = trainingOptions("adam", ...
    Plots="training-progress", ...
    Verbose=false);

Entrenar una red neuronal

Entrene la red neuronal con la función trainnet. Para la clasificación, utilice una función de pérdida personalizada que sea la pérdida de entropía cruzada de las etiquetas predichas y objetivo más 0,1 veces la pérdida de error cuadrático medio de los ángulos predichos y objetivo. De forma predeterminada, la función trainnet usa una GPU en caso de que esté disponible. Para utilizar una GPU se requiere una licencia de Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). De lo contrario, la función usa la CPU. Para especificar el entorno de ejecución, utilice la opción de entrenamiento ExecutionEnvironment.

Defina la función de pérdida personalizada como un identificador de función. Defina una pérdida que corresponda a la pérdida de entropía cruzada de las etiquetas predichas y objetivo más el error cuadrático medio de los ángulos predichos y objetivo, escalado por un factor de 0,1.

lossFcn = @(Y1,Y2,T1,T2) crossentropy(Y1,T1) + 0.1*mse(Y2,T2);

Entrenar la red neuronal.

net = trainnet(dsTrain,net,lossFcn,options);

Probar un modelo

Cargue los datos de dígitos. Los datos contienen imágenes de dígitos, así como las etiquetas de los dígitos y sus ángulos de rotación respecto a la vertical.

load DigitsDataTest

Realice predicciones con la función minibatchpredict y convierta las puntuaciones de clasificación en etiquetas con la función scores2label. De forma predeterminada, la función minibatchpredict usa una GPU en caso de que esté disponible. Para utilizar una GPU se requiere una licencia de Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). De lo contrario, la función usa la CPU. Para seleccionar el entorno de ejecución manualmente, utilice el argumento ExecutionEnvironment de la función minibatchpredict.

[scores,Y2] = minibatchpredict(net,XTest);
Y1 = scores2label(scores,classNames);

Calcule la precisión de clasificación de las etiquetas.

accuracy = mean(Y1 == labelsTest)
accuracy = 0.9732

Calcule el error cuadrático medio entre los ángulos predichos y objetivo.

err = rmse(Y2,anglesTest)
err = single
    6.9265

Vea algunas de las imágenes con sus predicciones. Muestre los ángulos predichos en rojo y las etiquetas correctas en verde.

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on

    sz = size(I,1);
    offset = sz/2;

    theta = Y2(idx(i));
    plot(offset*[1-tand(theta) 1+tand(theta)],[sz 0],"r--")

    thetaTest = anglesTest(idx(i));
    plot(offset*[1-tand(thetaTest) 1+tand(thetaTest)],[sz 0],"g--")

    hold off
    label = Y1(idx(i));
    title("Label: " + string(label))
end

Consulte también

| | | | | | | | | | |

Temas