Main Content

Entrenar una red con un bucle de entrenamiento personalizado

En este ejemplo se muestra cómo entrenar una red que clasifica dígitos manuscritos con una programación de tasa de aprendizaje personalizada.

Puede entrenar la mayoría de tipos de redes neuronales utilizando las funciones trainnet y trainingOptions. Si la función trainingOptions no proporciona las opciones que necesita (por ejemplo, una programación de tasa de aprendizaje personalizada), puede definir su propio bucle de entrenamiento personalizado mediante los objetos dlarray y dlnetwork para la diferenciación automática. Para ver un ejemplo de cómo volver a entrenar una red de deep learning preentrenada mediante la función trainnet, consulte Retrain Neural Network to Classify New Images.

El entrenamiento de una red neuronal profunda es una tarea de optimización. Considerando una red neuronal como una función f(X;θ), donde X es la entrada de la red y θ es el conjunto de parámetros que se pueden aprender, puede optimizar θ para que minimice parte del valor de pérdida en función de los datos de entrenamiento. Por ejemplo, optimice los parámetros que se pueden aprender θ de modo que, para unas entradas determinadas X con los objetivos correspondientes T, minimicen el error entre las predicciones Y=f(X;θ) y T.

La función de pérdida usada depende del tipo de tarea. Por ejemplo:

  • En tareas de clasificación, puede minimizar el error de entropía cruzada entre las predicciones y los objetivos.

  • En tareas de regresión, puede minimizar el error cuadrático medio entre las predicciones y los objetivos.

Puede optimizar el objetivo mediante el gradiente descendente: minimice la pérdida L actualizando iterativamente los parámetros que se pueden aprender θ dando pasos hacia el mínimo utilizando los gradientes de pérdida con respecto a los parámetros que se pueden aprender. Los algoritmos de gradiente descendente suelen actualizar los parámetros que se pueden aprender utilizando una variante de un paso de actualización de la forma θt+1=θt-ρL, donde t es el número de iteración, ρ es la tasa de aprendizaje y L denota los gradientes (las derivadas de la pérdida con respecto a los parámetros que se pueden aprender).

En este ejemplo se entrena una red para clasificar dígitos manuscritos con la programación de tasa de aprendizaje de decaimiento basado en el tiempo: para cada iteración, el solver utiliza la tasa de aprendizaje dada por ρt=ρ01+k t, donde t es el número de iteración, ρ0 es la tasa de aprendizaje inicial y k es el decaimiento.

Cargar los datos de entrenamiento

Cargue los datos de dígitos como un almacén de datos de imágenes mediante la función imageDatastore y especifique la carpeta que contiene los datos de imagen.

unzip("DigitsData.zip")

imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

Divida los datos en conjuntos de entrenamiento y de validación. Reserve el 10% de los datos para la validación mediante la función splitEachLabel.

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,"randomize");

La red usada en este ejemplo requiere imágenes de entrada de un tamaño de 28 por 28 por 1. Para cambiar automáticamente el tamaño de las imágenes de entrenamiento, utilice un almacén de datos de imágenes aumentado. Especifique operaciones de aumento adicionales para realizar en las imágenes de entrenamiento: traslade aleatoriamente las imágenes hasta 5 píxeles a lo largo de los ejes vertical y horizontal. El aumento de datos ayuda a evitar que la red se sobreajuste y memorice los detalles exactos de las imágenes de entrenamiento.

inputSize = [28 28 1];
pixelRange = [-5 5];

imageAugmenter = imageDataAugmenter( ...
    RandXTranslation=pixelRange, ...
    RandYTranslation=pixelRange);

augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,DataAugmentation=imageAugmenter);

Para cambiar el tamaño de las imágenes de validación de forma automática sin realizar más aumentos de datos, utilice un almacén de datos de imágenes aumentadas sin especificar ninguna operación adicional de preprocesamiento.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Determine el número de clases de los datos de entrenamiento.

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

Definir la red

Defina la red para la clasificación de imágenes.

  • Para la entrada de imagen, especifique una capa de entrada de imagen con un tamaño de entrada que coincida con los datos de entrenamiento.

  • No normalice la entrada de la imagen; establezca la opción Normalization de la capa de entrada en "none".

  • Especifique tres bloques convolución-batchnorm-ReLU.

  • Rellene la entrada a las capas de convolución de modo que la salida tenga el mismo tamaño estableciendo la opción Padding como "same".

  • Para la primera capa de convolución, especifique 20 filtros de tamaño 5. Para las capas de convolución restantes, especifique 20 filtros de tamaño 3.

  • Para la clasificación, especifique una capa totalmente conectada con un tamaño que coincida con el número de clases.

  • Para asignar la salida a probabilidades, incluya una capa softmax.

Al entrenar una red mediante un bucle de entrenamiento personalizado, no incluya una capa de salida.

layers = [
    imageInputLayer(inputSize,Normalization="none")
    convolution2dLayer(5,20,Padding="same")
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding="same")
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding="same")
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

Cree un objeto dlnetwork a partir del arreglo de capas.

net = dlnetwork(layers)
net = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.

Definir la función de pérdida del modelo

El entrenamiento de una red neuronal profunda es una tarea de optimización. Considerando una red neuronal como una función f(X;θ), donde X es la entrada de la red y θ es el conjunto de parámetros que se pueden aprender, puede optimizar θ para que minimice parte del valor de pérdida en función de los datos de entrenamiento. Por ejemplo, optimice los parámetros que se pueden aprender θ de modo que, para unas entradas determinadas X con los objetivos correspondientes T, minimicen el error entre las predicciones Y=f(X;θ) y T.

Cree la función modelLoss, que aparece en la sección Función de pérdida del modelo del ejemplo, que toma como entrada el objeto dlnetwork, un minilote de datos de entrada con los objetivos correspondientes, y devuelve la pérdida, los gradientes de la pérdida con respecto a los parámetros que se pueden aprender y el estado de la red.

Especificar las opciones de entrenamiento

Entrene con un tamaño de minilote de 128 durante diez épocas.

numEpochs = 10;
miniBatchSize = 128;

Especifique las opciones para la optimización de SGDM. Especifique una tasa de aprendizaje inicial de 0,01 con un decaimiento de 0,01 y un momento de 0,9.

initialLearnRate = 0.01;
decay = 0.01;
momentum = 0.9;

Entrenar un modelo

Cree un objeto minibatchqueue que procese y gestione minilotes de imágenes durante el entrenamiento. Para cada minilote:

  • Utilice la función de preprocesamiento de minilotes personalizada preprocessMiniBatch (definida al final de este ejemplo) para convertir las etiquetas en variables codificadas one-hot.

  • Dé formato a los datos de imagen con las etiquetas de dimensión "SSCB" (espacial, espacial, canal, lote). De forma predeterminada, el objeto minibatchqueue convierte los datos en objetos dlarray con el tipo subyacente single. No dé formato a las etiquetas de clase.

  • Descarte los minilotes parciales.

  • Entrene en una GPU, si se dispone de ella. De forma predeterminada, el objeto minibatchqueue convierte cada salida en gpuArray si hay una GPU disponible. Utilizar una GPU requiere Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox).

mbq = minibatchqueue(augimdsTrain,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@preprocessMiniBatch,...
    MiniBatchFormat=["SSCB" ""], ...
    PartialMiniBatch="discard");

Inicialice el parámetro de velocidad para el solver SGDM.

velocity = [];

Calcule el número total de iteraciones para monitorizar el progreso del entrenamiento.

numObservationsTrain = numel(imdsTrain.Files);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;

Inicialice el objeto TrainingProgressMonitor. Dado que el cronómetro empieza cuando crea el objeto de monitorización, asegúrese de crear el objeto cerca del bucle de entrenamiento.

monitor = trainingProgressMonitor( ...
    Metrics="Loss", ...
    Info=["Epoch" "LearnRate"], ...
    XLabel="Iteration");

Entrene la red con un bucle de entrenamiento personalizado. Para cada época, cambie el orden de los datos y pase en bucle por minilotes de datos. Para cada minilote:

  • Evalúe la pérdida, los gradientes y el estado del modelo utilizando las funciones dlfeval y modelLoss y actualice el estado de la red.

  • Determine la tasa de aprendizaje para la programación de tasa de aprendizaje de decaimiento basado en el tiempo.

  • Actualice los parámetros de red con la función sgdmupdate.

  • Actualice la pérdida, la tasa de aprendizaje y los valores de época en la monitorización del progreso del entrenamiento.

  • Detenga el proceso si la propiedad Stop está establecida como verdadero. El valor de la propiedad Stop del objeto TrainingProgressMonitor cambia a verdadero cuando hace clic en el botón Stop.

epoch = 0;
iteration = 0;

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    
    epoch = epoch + 1;

    % Shuffle data.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop

        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [X,T] = next(mbq);
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelLoss function and update the network state.
        [loss,gradients,state] = dlfeval(@modelLoss,net,X,T);
        net.State = state;
        
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        
        % Update the network parameters using the SGDM optimizer.
        [net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum);
        
        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch,LearnRate=learnRate);
        monitor.Progress = 100 * iteration/numIterations;
    end
end

Probar un modelo

Pruebe la precisión de clasificación del modelo comparando las predicciones en un conjunto de validación con las etiquetas verdaderas.

Después del entrenamiento, para hacer predicciones sobre nuevos datos no se requieren etiquetas. Cree un objeto minibatchqueue que contenga solo los predictores de los datos de prueba:

  • Para ignorar las etiquetas para las pruebas, establezca el número de salidas de la cola de minilotes en 1.

  • Especifique el mismo tamaño de minilote utilizado para el entrenamiento.

  • Preprocese los predictores mediante la función preprocessMiniBatchPredictors, que se enumera al final del ejemplo.

  • Para la salida única del almacén de datos, especifique el formato de los minilotes "SSCB" (espacial, espacial, canal, lote).

numOutputs = 1;

mbqTest = minibatchqueue(augimdsValidation,numOutputs, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatchPredictors, ...
    MiniBatchFormat="SSCB");

Pase en bucle por los minilotes y clasifique las imágenes mediante la función modelPredictions, que se enumera al final del ejemplo.

YTest = modelPredictions(net,mbqTest,classes);

Evalúe la precisión de clasificación.

TTest = imdsValidation.Labels;
accuracy = mean(TTest == YTest)
accuracy = 0.9220

Visualice las predicciones en una gráfica de confusión.

figure
confusionchart(TTest,YTest)

Los valores grandes de la diagonal indican predicciones precisas para la clase correspondiente. Los valores grandes fuera de la diagonal indican una fuerte confusión entre las clases correspondientes.​

Funciones de apoyo

Función de pérdida del modelo

La función modelLoss toma un objeto net de dlnetwork, un minilote de datos de entrada X con objetivos correspondientes T y devuelve la pérdida, los gradientes de la pérdida con respecto a los parámetros que se pueden aprender en net y el estado de la red. Para calcular los gradientes automáticamente, utilice la función dlgradient.

function [loss,gradients,state] = modelLoss(net,X,T)

% Forward data through network.
[Y,state] = forward(net,X);

% Calculate cross-entropy loss.
loss = crossentropy(Y,T);

% Calculate gradients of loss with respect to learnable parameters.
gradients = dlgradient(loss,net.Learnables);

end

Función de predicciones del modelo

La función modelPredictions toma un objeto net de dlnetwork, un minibatchqueue de datos de entrada mbq y las clases de red, y calcula las predicciones del modelo iterando sobre todos los datos en el objeto minibatchqueue. La función utiliza la función onehotdecode para encontrar la clase predicha con la puntuación más alta.

function Y = modelPredictions(net,mbq,classes)

Y = [];

% Loop over mini-batches.
while hasdata(mbq)
    X = next(mbq);

    % Make prediction.
    scores = predict(net,X);

    % Decode labels and append to output.
    labels = onehotdecode(scores,classes,1)';
    Y = [Y; labels];
end

end

Función de preprocesamiento de minilotes

La función preprocessMiniBatch preprocesa un minilote de predictores y etiquetas dando los siguientes pasos:

  1. Preprocesa las imágenes usando la función preprocessMiniBatchPredictors.

  2. Extrae los datos de la etiqueta del arreglo de celdas entrante y los concatena en un arreglo categórico a lo largo de la segunda dimensión.

  3. Hace una codificación one-hot de las etiquetas categóricas en arreglos numéricos. La codificación en la primera dimensión produce un arreglo codificado que coincide con la forma de la salida de la red.

function [X,T] = preprocessMiniBatch(dataX,dataT)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(dataX);

% Extract label data from cell and concatenate.
T = cat(2,dataT{1:end});

% One-hot encode labels.
T = onehotencode(T,1);

end

Función de preprocesamiento de predictores de minilotes

La función preprocessMiniBatchPredictors preprocesa un minilote de predictores extrayendo los datos de imagen del arreglo de celdas de entrada y concatenándolos en un arreglo numérico. Para la entrada en escala de grises, la concatenación sobre la cuarta dimensión añade una tercera dimensión a cada imagen, para usarla como dimensión de canal única.

function X = preprocessMiniBatchPredictors(dataX)

% Concatenate.
X = cat(4,dataX{1:end});

end

Consulte también

| | | | | | | | | |

Temas relacionados