Entrenar una red residual para clasificar imágenes
En este ejemplo se muestra cómo crear una red neuronal de deep learning con conexiones residuales y entrenarla con datos CIFAR-10. Las conexiones residuales son un elemento popular en las arquitecturas de red neuronal convolucional. El uso de conexiones residuales mejora el flujo de gradiente a través de la red y permite el entrenamiento de redes más profundas.
Para muchas aplicaciones, basta con usar una red que conste de una secuencia simple de capas. Sin embargo, algunas aplicaciones requieren redes con una estructura de gráfica más compleja en la que las capas pueden tener entradas de diferentes capas y salidas a varias capas. Estos tipos de redes se suelen llamar redes gráficas acíclicas dirigidas (DAG). Una red residual (ResNet) es un tipo de red DAG que tiene conexiones residuales (o de atajo) que evitan las capas principales de la red. En MATLAB, las redes DAG están representadas por objetos dlnetwork
. Las conexiones residuales permiten que los gradientes de parámetros se propaguen con mayor facilidad desde la capa de salida a las primeras capas de la red, lo que posibilita entrenar redes más profundas. Esta mayor profundidad de la red puede dar lugar a mayores precisiones en tareas más difíciles.
Una arquitectura ResNet se compone de capas iniciales, seguidas de pilas que contienen bloques residuales y, por último, las capas finales. Hay tres tipos de bloques residuales:
Bloque residual inicial: este bloque aparece al comienzo de la primera pila. Este ejemplo usa componentes de cuello de botella; por lo tanto, este bloque contiene las mismas capas que el bloque de submuestreo, solo que con un tramo de
[1,1]
en la primera capa convolucional. Para obtener más información, consulteresnetNetwork
.Bloque residual estándar: este bloque aparece en cada pila, después del primer bloque residual de submuestreo. Este bloque aparece varias veces en cada pila y conserva los tamaños de activación.
Bloque residual de submuestreo: este bloque aparece al inicio de cada pila (excepto la primera) y solo aparece una vez en cada pila. La primera unidad convolucional del bloque de submuestreo reduce las dimensiones espaciales en un factor de dos.
La profundidad de cada pila puede variar; este ejemplo entrena una red residual con tres pilas de profundidad decreciente. La primera pila tiene una profundidad de cuatro, la segunda de tres y la última de dos.
Cada bloque residual contiene capas de deep learning. Para obtener más información sobre las capas de cada bloque, consulte resnetNetwork
.
Para crear y entrenar una red residual que sea adecuada para la clasificación de imágenes, siga estos pasos:
Cree una red residual con la función
resnetNetwork
.Entrene la red con la función
trainnet
. La red entrenada es un objetodlnetwork
.Realice la clasificación y predicción con datos nuevos.
También puede cargar redes residuales preentrenadas para clasificar imágenes. Para obtener más información, consulte Redes neuronales profundas preentrenadas.
Preparar los datos
Descargue el conjunto de datos CIFAR-10 [1]. El conjunto de datos contiene 60.000 imágenes. Cada imagen tiene un tamaño de 32 por 32 píxeles y tres canales de color (RGB). El tamaño del conjunto de datos es de 175 MB. Según la conexión a Internet, el proceso de descarga puede tardar un tiempo.
datadir = tempdir; downloadCIFARData(datadir);
Downloading CIFAR-10 dataset (175 MB). This can take a while...done.
Cargue las imágenes de entrenamiento y prueba CIFAR-10 como arreglos 4D. El conjunto de entrenamiento contiene 50.000 imágenes y el de prueba contiene 10.000. Use las imágenes de prueba CIFAR-10 para la validación de la red.
[XTrain,TTrain,XValidation,TValidation] = loadCIFARData(datadir);
Puede visualizar una muestra aleatoria de las imágenes de entrenamiento usando el código siguiente.
figure; idx = randperm(size(XTrain,4),20); im = imtile(XTrain(:,:,:,idx),ThumbnailSize=[96,96]); imshow(im)
Cree un objeto augmentedImageDatastore
para usarlo durante el entrenamiento de la red. Durante el entrenamiento, el almacén de datos voltea aleatoriamente las imágenes de entrenamiento a lo largo del eje vertical y, de manera aleatoria, las traslada hasta cuatro píxeles en horizontal y vertical. El aumento de datos ayuda a evitar que la red se sobreajuste y memorice los detalles exactos de las imágenes de entrenamiento.
imageSize = [32 32 3]; pixelRange = [-4 4]; imageAugmenter = imageDataAugmenter( ... RandXReflection=true, ... RandXTranslation=pixelRange, ... RandYTranslation=pixelRange); augimdsTrain = augmentedImageDatastore(imageSize,XTrain,TTrain, ... DataAugmentation=imageAugmenter, ... OutputSizeMode="randcrop");
Definir la arquitectura de red
Use la función resnetNetwork
para crear una red residual adecuada para este conjunto de datos.
Las imágenes CIFAR-10 son de 32 por 32 píxeles; por lo tanto, use un tamaño de filtro inicial pequeño de 3 y un tramo inicial de 1. Establezca el número de filtros iniciales en 16.
La primera pila de la red comienza con un bloque residual inicial. Las siguientes pilas comienzan con un bloque residual de submuestreo. Las primeras unidades convolucionales de los bloques de submuestreo reducen las dimensiones espaciales en un factor de dos. Para que la cantidad de computación necesaria en cada capa convolucional sea aproximadamente la misma en toda la red, multiplique por dos el número de filtros cada vez que realice un submuestreo espacial. Establezca la profundidad de la pila en
[4 3 2]
y el número de filtros en[16 32 64]
.
initialFilterSize = 3; numInitialFilters = 16; initialStride = 1; numFilters = [16 32 64]; stackDepth = [4 3 2];
Cree una red residual 2D.
net = resnetNetwork(imageSize,10, ... InitialFilterSize=initialFilterSize, ... InitialNumFilters=numInitialFilters, ... InitialStride=initialStride, ... InitialPoolingLayer="none", ... StackDepth=[4 3 2], ... NumFilters=[16 32 64]);
Visualice la red.
plot(net);
Opciones de entrenamiento
Especificar las opciones de entrenamiento. Entrene la red durante 80 épocas. Seleccione una tasa de aprendizaje que sea proporcional al tamaño del minilote y redúzcala en un factor de 10 después de 60 épocas. Valide la red una vez por época usando los datos de validación.
miniBatchSize = 128; learnRate = 0.1*miniBatchSize/128; valFrequency = floor(size(XTrain,4)/miniBatchSize); options = trainingOptions("sgdm", ... InitialLearnRate=learnRate, ... MaxEpochs=80, ... MiniBatchSize=miniBatchSize, ... VerboseFrequency=valFrequency, ... Shuffle="every-epoch", ... Plots="training-progress", ... Verbose=false, ... ValidationData={XValidation,TValidation}, ... ValidationFrequency=valFrequency, ... LearnRateSchedule="piecewise", ... LearnRateDropFactor=0.1, ... LearnRateDropPeriod=60);
Entrenar la red
Para entrenar la red con trainnet
, establezca el indicador doTraining
en true
. Para la clasificación, utilice la pérdida de entropía cruzada. De forma predeterminada, la función trainnet
usa una GPU en caso de que esté disponible. Para entrenar en una GPU se requiere una licencia de Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). De lo contrario, la función trainnet
usa la CPU. Para especificar el entorno de ejecución, utilice la opción de entrenamiento ExecutionEnvironment
.
En caso contrario, cargue una red preentrenada.
doTraining = false; if doTraining net = trainnet(augimdsTrain,net,'crossentropy',options); else load("trainedResidualNetwork.mat","net"); end
Evaluar la red entrenada
Calcule la precisión final de la red en el conjunto de entrenamiento (sin aumento de datos) y el conjunto de validación. Para hacer predicciones con varias observaciones, utilice la función minibatchpredict
. Para convertir las puntuaciones de predicción en etiquetas, utilice la función scores2label
. La función minibatchpredict
usa automáticamente una GPU en caso de que esté disponible. Para utilizar una GPU se requiere una licencia de Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). De lo contrario, la función usa la CPU.
scores = minibatchpredict(net,XValidation); [YValPred,probs] = scores2label(scores,categories(TValidation)); validationError = mean(YValPred ~= TValidation); scores = minibatchpredict(net,XTrain); YTrainPred = scores2label(scores,categories(TTrain)); trainError = mean(YTrainPred ~= TTrain); disp("Training error: " + trainError*100 + "%")
Training error: 4.168%
disp("Validation error: " + validationError*100 + "%")
Validation error: 9.13%
Represente la matriz de confusión. Muestre la precisión y la recuperación de cada clase mediante el uso de resúmenes de columnas y filas. La red confunde de manera habitual gatos con perros.
figure(Units="normalized",Position=[0.2 0.2 0.4 0.4]); cm = confusionchart(TValidation,YValPred); cm.Title = "Confusion Matrix for Validation Data"; cm.ColumnSummary = "column-normalized"; cm.RowSummary = "row-normalized";
Puede visualizar una muestra aleatoria de nueve imágenes de prueba junto con sus clases predichas y las probabilidades de esas clases usando el código siguiente.
figure idx = randperm(size(XValidation,4),9); for i = 1:numel(idx) subplot(3,3,i) imshow(XValidation(:,:,:,idx(i))); prob = num2str(100*max(probs(idx(i),:)),3); predClass = char(YValPred(idx(i))); title([predClass + ", " + prob + "%"]) end
Referencias
[1] Krizhevsky, Alex. "Learning multiple layers of features from tiny images." (2009). https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf
[2] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning for image recognition." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770-778. 2016.
Consulte también
trainnet
| trainingOptions
| dlnetwork
| resnetNetwork
| resnet3dNetwork
| analyzeNetwork