Entrenar una red neuronal convolucional para regresión
Este ejemplo muestra cómo ajustar un modelo de regresión mediante redes neuronales convolucionales para predecir los ángulos de rotación de dígitos manuscritos.
Las redes neuronales convolucionales (CNN o ConvNets) son herramientas fundamentales en deep learning y resultan especialmente adecuadas para analizar datos de imágenes. Por ejemplo, puede utilizar las CNN para clasificar imágenes. Para predecir datos continuos, como ángulos y distancias, puede incluir una capa de regresión al final de la red.
El ejemplo crea una arquitectura de red neuronal convolucional, entrena una red y utiliza la red entrenada para predecir ángulos de dígitos manuscritos rotados. Estas predicciones resultan útiles para el reconocimiento óptico de caracteres.
Opcionalmente, puede utilizar imrotate
(Image Processing Toolbox™) para rotar las imágenes y boxplot
(Statistics and Machine Learning Toolbox™) para crear una gráfica de cajas.
Cargar datos
El conjunto de datos contiene imágenes sintéticas de dígitos manuscritos junto con los ángulos correspondientes (en grados) que se rota cada imagen.
Cargue las imágenes de entrenamiento y validación como arreglos 4D mediante digitTrain4DArrayData
y digitTest4DArrayData
. Las salidas YTrain
e YValidation
son los ángulos de rotación en grados. Cada conjunto de datos de entrenamiento y validación contiene 5000 imágenes.
[XTrain,~,YTrain] = digitTrain4DArrayData; [XValidation,~,YValidation] = digitTest4DArrayData;
Muestre 20 imágenes aleatorias de entrenamiento mediante imshow
.
numTrainImages = numel(YTrain); figure idx = randperm(numTrainImages,20); for i = 1:numel(idx) subplot(4,5,i) imshow(XTrain(:,:,:,idx(i))) end
Comprobar la normalización de datos
Cuando se entrenan redes neuronales, a menudo es útil asegurarse de que sus datos están normalizados en todas las etapas de la red. La normalización ayuda a estabilizar y acelerar el entrenamiento de la red mediante un gradiente descendente. Si sus datos se han escalado mal, la pérdida puede convertirse en NaN
y los parámetros de red pueden divergir durante el entrenamiento. Entre las formas habituales de normalizar datos se incluye reescalar los datos de manera que su intervalo sea [0,1] o que tenga una media de cero y una desviación estándar de uno. Puede normalizar los siguientes datos:
Datos de entrada. Normalice los predictores antes de introducirlos en la red. En este ejemplo, las imágenes de entrada ya están normalizadas al intervalo [0,1].
Salidas de capas. Puede normalizar las salidas de cada capa convolucional y cada capa totalmente conectada mediante una capa de normalización de lotes.
Respuestas. Si utiliza capas de normalización de lotes para normalizar las salidas de capas al final de la red, las predicciones de la red se normalizan cuando comienza el entrenamiento. Si la respuesta tiene una escala muy diferente a esas predicciones, puede que el entrenamiento de la red no pueda converger. Si su respuesta se ha escalado mal, pruebe a normalizarla y vea si el entrenamiento de la red mejora. Si normaliza la respuesta antes del entrenamiento, debe transformar las predicciones de la red entrenada para obtener las predicciones de la respuesta original.
Represente la distribución de la respuesta. La respuesta (el ángulo de rotación en grados) se distribuye aproximadamente de forma uniforme entre -45 y 45, que es correcto y no requiere normalización. En los problemas de clasificación, las salidas son probabilidades de clases, que siempre están normalizadas.
figure histogram(YTrain) axis tight ylabel('Counts') xlabel('Rotation Angle')
En general, los datos no tienen que estar normalizados exactamente. Sin embargo, si entrena la red de este ejemplo para predecir 100*YTrain
o YTrain+500
en lugar de YTrain
, la pérdida es NaN
y los parámetros de la red divergen cuando el entrenamiento comienza. Este resultado se produce incluso aunque la única diferencia entre una red que predice aY + b y una red que predice Y es un reescalado sencillo de los pesos y los sesgos de la capa totalmente conectada final.
Si la distribución de la entrada o respuesta es muy irregular o está muy desviada, también puede realizar transformaciones no lineales (por ejemplo, tomando logaritmos) a los datos antes de entrenar la red.
Crear capas de red
Para resolver el problema de regresión, cree las capas de la red e incluya una capa de regresión al final de la red.
La primera capa define el tamaño y el tipo de los datos de entrada. Las imágenes de entrada son de 28 por 28 por 1. Cree una capa de entrada de imagen del mismo tamaño que las imágenes de entrenamiento.
Las capas intermedias de la red definen la arquitectura principal de la red, donde tiene lugar la mayoría de los cálculos y el aprendizaje.
Las capas finales definen el tamaño y el tipo de los datos de salida. Para problemas de regresión, una capa totalmente conectada debe preceder a la capa de regresión al final de la red. Cree una capa de salida totalmente conectada de tamaño 1 y una capa de regresión.
Combine todas las capas juntas en un arreglo Layer
.
layers = [ imageInputLayer([28 28 1]) convolution2dLayer(3,8,'Padding','same') batchNormalizationLayer reluLayer averagePooling2dLayer(2,'Stride',2) convolution2dLayer(3,16,'Padding','same') batchNormalizationLayer reluLayer averagePooling2dLayer(2,'Stride',2) convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer dropoutLayer(0.2) fullyConnectedLayer(1) regressionLayer];
Entrenar la red
Cree las opciones de entrenamiento de la red. Entrene durante 30 épocas. Establezca la tasa de aprendizaje inicial en 0.001 y disminuya la tasa de aprendizaje tras 20 épocas. Monitorice la precisión de la red durante el entrenamiento especificando datos de validación y la frecuencia de validación. El software entrena la red según los datos de entrenamiento y calcula la precisión de los datos de validación en intervalos regulares durante el entrenamiento. Los datos de validación no se utilizan para actualizar los pesos de la red. Active la gráfica de progreso del entrenamiento y desactive la salida de la ventana de comandos.
miniBatchSize = 128; validationFrequency = floor(numel(YTrain)/miniBatchSize); options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... 'MaxEpochs',30, ... 'InitialLearnRate',1e-3, ... 'LearnRateSchedule','piecewise', ... 'LearnRateDropFactor',0.1, ... 'LearnRateDropPeriod',20, ... 'Shuffle','every-epoch', ... 'ValidationData',{XValidation,YValidation}, ... 'ValidationFrequency',validationFrequency, ... 'Plots','training-progress', ... 'Verbose',false);
Cree la red mediante trainNetwork
. Este comando utiliza una GPU compatible si está disponible. Utilizar una GPU requiere Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). De lo contrario, trainNetwork
utiliza la CPU.
net = trainNetwork(XTrain,YTrain,layers,options);
Examine los detalles de la arquitectura de red contenida en la propiedad Layers
de net
.
net.Layers
ans = 18x1 Layer array with layers: 1 'imageinput' Image Input 28x28x1 images with 'zerocenter' normalization 2 'conv_1' 2-D Convolution 8 3x3x1 convolutions with stride [1 1] and padding 'same' 3 'batchnorm_1' Batch Normalization Batch normalization with 8 channels 4 'relu_1' ReLU ReLU 5 'avgpool2d_1' 2-D Average Pooling 2x2 average pooling with stride [2 2] and padding [0 0 0 0] 6 'conv_2' 2-D Convolution 16 3x3x8 convolutions with stride [1 1] and padding 'same' 7 'batchnorm_2' Batch Normalization Batch normalization with 16 channels 8 'relu_2' ReLU ReLU 9 'avgpool2d_2' 2-D Average Pooling 2x2 average pooling with stride [2 2] and padding [0 0 0 0] 10 'conv_3' 2-D Convolution 32 3x3x16 convolutions with stride [1 1] and padding 'same' 11 'batchnorm_3' Batch Normalization Batch normalization with 32 channels 12 'relu_3' ReLU ReLU 13 'conv_4' 2-D Convolution 32 3x3x32 convolutions with stride [1 1] and padding 'same' 14 'batchnorm_4' Batch Normalization Batch normalization with 32 channels 15 'relu_4' ReLU ReLU 16 'dropout' Dropout 20% dropout 17 'fc' Fully Connected 1 fully connected layer 18 'regressionoutput' Regression Output mean-squared-error with response 'Response'
Probar la red
Pruebe el rendimiento de la red evaluando la precisión de los datos de validación.
Utilice predict
para predecir los ángulos de rotación de las imágenes de validación.
YPredicted = predict(net,XValidation);
Evaluar el rendimiento
Evalúe el rendimiento del modelo calculando:
El porcentaje de predicciones dentro de un margen de error aceptable
El error cuadrático medio raíz (RMSE) del ángulo de rotación predicho y del real
Calcule el error de predicción entre el ángulo de rotación predicho y el real.
predictionError = YValidation - YPredicted;
Calcule el número de predicciones dentro de un margen de error aceptable a partir de los ángulos reales. Establezca el umbral en 10 grados. Calcule el porcentaje de predicciones dentro de este umbral.
thr = 10; numCorrect = sum(abs(predictionError) < thr); numValidationImages = numel(YValidation); accuracy = numCorrect/numValidationImages
accuracy = 0.9626
Utilice el error cuadrático medio raíz (RMSE) para medir las diferencias entre el ángulo de rotación predicho y el real.
squares = predictionError.^2; rmse = sqrt(mean(squares))
rmse = single
4.6600
Visualizar predicciones
Visualice las predicciones en una gráfica de dispersión. Represente los valores predichos frente a los valores reales.
figure scatter(YPredicted,YValidation,'+') xlabel("Predicted Value") ylabel("True Value") hold on plot([-60 60], [-60 60],'r--')
Corregir rotaciones de dígitos
También puede utilizar funciones de Image Processing Toolbox para nivelar los dígitos y mostrarlos juntos. Rote 49 dígitos de muestra de acuerdo con sus ángulos de rotación predichos mediante imrotate
(Image Processing Toolbox).
idx = randperm(numValidationImages,49); for i = 1:numel(idx) image = XValidation(:,:,:,idx(i)); predictedAngle = YPredicted(idx(i)); imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop'); end
Muestre los dígitos originales con sus rotaciones corregidas. Puede utilizar montage
(Image Processing Toolbox) para mostrar los dígitos juntos en una única imagen.
figure subplot(1,2,1) montage(XValidation(:,:,:,idx)) title('Original') subplot(1,2,2) montage(imagesRotated) title('Corrected')
Consulte también
regressionLayer
| classificationLayer