Main Content

Entrenar una red neuronal convolucional para regresión

Este ejemplo muestra cómo ajustar un modelo de regresión mediante redes neuronales convolucionales para predecir los ángulos de rotación de dígitos manuscritos.

Las redes neuronales convolucionales (CNN o ConvNets) son herramientas fundamentales en deep learning y resultan especialmente adecuadas para analizar datos de imágenes. Por ejemplo, puede utilizar las CNN para clasificar imágenes. Para predecir datos continuos, como ángulos y distancias, puede incluir una capa de regresión al final de la red.

El ejemplo crea una arquitectura de red neuronal convolucional, entrena una red y utiliza la red entrenada para predecir ángulos de dígitos manuscritos rotados. Estas predicciones resultan útiles para el reconocimiento óptico de caracteres.

Opcionalmente, puede utilizar imrotate (Image Processing Toolbox™) para rotar las imágenes y boxplot (Statistics and Machine Learning Toolbox™) para crear una gráfica de cajas.

Cargar datos

El conjunto de datos contiene imágenes sintéticas de dígitos manuscritos junto con los ángulos correspondientes (en grados) que se rota cada imagen.

Cargue las imágenes de entrenamiento y validación como arreglos 4D mediante digitTrain4DArrayData y digitTest4DArrayData. Las salidas YTrain e YValidation son los ángulos de rotación en grados. Cada conjunto de datos de entrenamiento y validación contiene 5000 imágenes.

[XTrain,~,YTrain] = digitTrain4DArrayData;
[XValidation,~,YValidation] = digitTest4DArrayData;

Muestre 20 imágenes aleatorias de entrenamiento mediante imshow.

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Figure contains 20 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

Comprobar la normalización de datos

Cuando se entrenan redes neuronales, a menudo es útil asegurarse de que sus datos están normalizados en todas las etapas de la red. La normalización ayuda a estabilizar y acelerar el entrenamiento de la red mediante un gradiente descendente. Si sus datos se han escalado mal, la pérdida puede convertirse en NaN y los parámetros de red pueden divergir durante el entrenamiento. Entre las formas habituales de normalizar datos se incluye reescalar los datos de manera que su intervalo sea [0,1] o que tenga una media de cero y una desviación estándar de uno. Puede normalizar los siguientes datos:

  • Datos de entrada. Normalice los predictores antes de introducirlos en la red. En este ejemplo, las imágenes de entrada ya están normalizadas al intervalo [0,1].

  • Salidas de capas. Puede normalizar las salidas de cada capa convolucional y cada capa totalmente conectada mediante una capa de normalización de lotes.

  • Respuestas. Si utiliza capas de normalización de lotes para normalizar las salidas de capas al final de la red, las predicciones de la red se normalizan cuando comienza el entrenamiento. Si la respuesta tiene una escala muy diferente a esas predicciones, puede que el entrenamiento de la red no pueda converger. Si su respuesta se ha escalado mal, pruebe a normalizarla y vea si el entrenamiento de la red mejora. Si normaliza la respuesta antes del entrenamiento, debe transformar las predicciones de la red entrenada para obtener las predicciones de la respuesta original.

Represente la distribución de la respuesta. La respuesta (el ángulo de rotación en grados) se distribuye aproximadamente de forma uniforme entre -45 y 45, que es correcto y no requiere normalización. En los problemas de clasificación, las salidas son probabilidades de clases, que siempre están normalizadas.

figure
histogram(YTrain)
axis tight
ylabel('Counts')
xlabel('Rotation Angle')

Figure contains an axes object. The axes object contains an object of type histogram.

En general, los datos no tienen que estar normalizados exactamente. Sin embargo, si entrena la red de este ejemplo para predecir 100*YTrain o YTrain+500 en lugar de YTrain, la pérdida es NaN y los parámetros de la red divergen cuando el entrenamiento comienza. Este resultado se produce incluso aunque la única diferencia entre una red que predice aY + b y una red que predice Y es un reescalado sencillo de los pesos y los sesgos de la capa totalmente conectada final.

Si la distribución de la entrada o respuesta es muy irregular o está muy desviada, también puede realizar transformaciones no lineales (por ejemplo, tomando logaritmos) a los datos antes de entrenar la red.

Crear capas de red

Para resolver el problema de regresión, cree las capas de la red e incluya una capa de regresión al final de la red.

La primera capa define el tamaño y el tipo de los datos de entrada. Las imágenes de entrada son de 28 por 28 por 1. Cree una capa de entrada de imagen del mismo tamaño que las imágenes de entrenamiento.

Las capas intermedias de la red definen la arquitectura principal de la red, donde tiene lugar la mayoría de los cálculos y el aprendizaje.

Las capas finales definen el tamaño y el tipo de los datos de salida. Para problemas de regresión, una capa totalmente conectada debe preceder a la capa de regresión al final de la red. Cree una capa de salida totalmente conectada de tamaño 1 y una capa de regresión.

Combine todas las capas juntas en un arreglo Layer.

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    dropoutLayer(0.2)
    fullyConnectedLayer(1)
    regressionLayer];

Entrenar la red

Cree las opciones de entrenamiento de la red. Entrene durante 30 épocas. Establezca la tasa de aprendizaje inicial en 0.001 y disminuya la tasa de aprendizaje tras 20 épocas. Monitorice la precisión de la red durante el entrenamiento especificando datos de validación y la frecuencia de validación. El software entrena la red según los datos de entrenamiento y calcula la precisión de los datos de validación en intervalos regulares durante el entrenamiento. Los datos de validación no se utilizan para actualizar los pesos de la red. Active la gráfica de progreso del entrenamiento y desactive la salida de la ventana de comandos.

miniBatchSize  = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',30, ...
    'InitialLearnRate',1e-3, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',20, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'Verbose',false);

Cree la red mediante trainNetwork. Este comando utiliza una GPU compatible si está disponible. Utilizar una GPU requiere Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). De lo contrario, trainNetwork utiliza la CPU.

net = trainNetwork(XTrain,YTrain,layers,options);

{"String":"Figure Training Progress (24-Jul-2022 21:28:52) contains 2 axes objects and another object of type uigridlayout. Axes object 1 contains 10 objects of type patch, text, line. Axes object 2 contains 10 objects of type patch, text, line.","Tex":[],"LaTex":[]}

Examine los detalles de la arquitectura de red contenida en la propiedad Layers de net.

net.Layers
ans = 
  18x1 Layer array with layers:

     1   'imageinput'         Image Input           28x28x1 images with 'zerocenter' normalization
     2   'conv_1'             2-D Convolution       8 3x3x1 convolutions with stride [1  1] and padding 'same'
     3   'batchnorm_1'        Batch Normalization   Batch normalization with 8 channels
     4   'relu_1'             ReLU                  ReLU
     5   'avgpool2d_1'        2-D Average Pooling   2x2 average pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv_2'             2-D Convolution       16 3x3x8 convolutions with stride [1  1] and padding 'same'
     7   'batchnorm_2'        Batch Normalization   Batch normalization with 16 channels
     8   'relu_2'             ReLU                  ReLU
     9   'avgpool2d_2'        2-D Average Pooling   2x2 average pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv_3'             2-D Convolution       32 3x3x16 convolutions with stride [1  1] and padding 'same'
    11   'batchnorm_3'        Batch Normalization   Batch normalization with 32 channels
    12   'relu_3'             ReLU                  ReLU
    13   'conv_4'             2-D Convolution       32 3x3x32 convolutions with stride [1  1] and padding 'same'
    14   'batchnorm_4'        Batch Normalization   Batch normalization with 32 channels
    15   'relu_4'             ReLU                  ReLU
    16   'dropout'            Dropout               20% dropout
    17   'fc'                 Fully Connected       1 fully connected layer
    18   'regressionoutput'   Regression Output     mean-squared-error with response 'Response'

Probar la red

Pruebe el rendimiento de la red evaluando la precisión de los datos de validación.

Utilice predict para predecir los ángulos de rotación de las imágenes de validación.

YPredicted = predict(net,XValidation);

Evaluar el rendimiento

Evalúe el rendimiento del modelo calculando:

  1. El porcentaje de predicciones dentro de un margen de error aceptable

  2. El error cuadrático medio raíz (RMSE) del ángulo de rotación predicho y del real

Calcule el error de predicción entre el ángulo de rotación predicho y el real.

predictionError = YValidation - YPredicted;

Calcule el número de predicciones dentro de un margen de error aceptable a partir de los ángulos reales. Establezca el umbral en 10 grados. Calcule el porcentaje de predicciones dentro de este umbral.

thr = 10;
numCorrect = sum(abs(predictionError) < thr);
numValidationImages = numel(YValidation);

accuracy = numCorrect/numValidationImages
accuracy = 0.9644

Utilice el error cuadrático medio raíz (RMSE) para medir las diferencias entre el ángulo de rotación predicho y el real.

squares = predictionError.^2;
rmse = sqrt(mean(squares))
rmse = single
    4.6253

Visualizar predicciones

Visualice las predicciones en una gráfica de dispersión. Represente los valores predichos frente a los valores reales.

figure
scatter(YPredicted,YValidation,'+')
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],'r--')

Figure contains an axes object. The axes object contains 2 objects of type scatter, line.

Corregir rotaciones de dígitos

También puede utilizar funciones de Image Processing Toolbox para nivelar los dígitos y mostrarlos juntos. Rote 49 dígitos de muestra de acuerdo con sus ángulos de rotación predichos mediante imrotate (Image Processing Toolbox).

idx = randperm(numValidationImages,49);
for i = 1:numel(idx)
    image = XValidation(:,:,:,idx(i));
    predictedAngle = YPredicted(idx(i));  
    imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop');
end

Muestre los dígitos originales con sus rotaciones corregidas. Puede utilizar montage (Image Processing Toolbox) para mostrar los dígitos juntos en una única imagen.

figure
subplot(1,2,1)
montage(XValidation(:,:,:,idx))
title('Original')

subplot(1,2,2)
montage(imagesRotated)
title('Corrected')

Figure contains 2 axes objects. Axes object 1 with title Original contains an object of type image. Axes object 2 with title Corrected contains an object of type image.

Consulte también

|

Temas relacionados