Main Content

Esta página es para la versión anterior. La página correspondiente en inglés ha sido eliminada en la versión actual.

Personalizar salidas durante un entrenamiento de red de deep learning

En este ejemplo se muestra cómo definir una función de salida que se ejecute con cada iteración durante el entrenamiento de redes neuronales de deep learning. Si especifica funciones de salida usando el argumento de par nombre-valor 'OutputFcn' de trainingOptions, trainNetwork llama a estas funciones una vez antes del inicio del entrenamiento, después de cada iteración y una vez cuando el entrenamiento ha finalizado. Cada vez que se llama a una función de salida, trainNetwork pasa una estructura que contiene información, como el número de iteración actual, pérdidas o precisión. Puede utilizar las funciones de salida para mostrar o representar información de progreso o para detener el entrenamiento. Para detener el entrenamiento antes de tiempo, haga que su función de salida devuelva true. Si cualquier función de salida devuelve true, el entrenamiento finaliza y trainNetwork devuelve la red más reciente.

Para detener el entrenamiento cuando la pérdida del conjunto de validación pare de decrecer, especifique los datos de validación y una paciencia de validación con los argumentos de par nombre-valor 'ValidationData' y 'ValidationPatience' de trainingOptions, respectivamente. La paciencia de validación es el número de veces que la pérdida, en el conjunto de validación, puede ser mayor o igual que la menor pérdida previa antes de que el entrenamiento de la red se detenga. Puede añadir criterios de detención adicionales mediante funciones de salida. En este ejemplo se muestra cómo crear una función de salida que detenga el entrenamiento cuando la precisión de clasificación de los datos de validación deje de mejorar. La función de salida se define al final del script.

Cargue los datos de entrenamiento, que contienen 5000 imágenes de dígitos. Reserve 1000 de las imágenes para la validación de la red.

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

Construya una red para clasificar los datos de las imágenes de los dígitos.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Especifique opciones para el entrenamiento de la red. Para validar la red a intervalos regulares durante el entrenamiento, especifique los datos de validación. Elija el valor 'ValidationFrequency' para que la red se valide una vez por época.

Para detener el entrenamiento cuando la precisión de clasificación del conjunto de validación deje de mejorar, configure stopIfAccuracyNotImproving como una función de salida. El segundo argumento de entrada de stopIfAccuracyNotImproving es el número de veces que la precisión, en el conjunto de validación, puede ser menor o igual que la mayor precisión previa antes de que el entrenamiento de la red se detenga. Elija cualquier valor grande para el número máximo de épocas que desea entrenar. El entrenamiento no debería alcanzar la época final, ya que se detendría automáticamente.

miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));

Entrene la red. El entrenamiento se detiene cuando la precisión de validación deja de aumentar.

net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:02 |        7.81% |       12.70% |       2.7155 |       2.5169 |          0.0100 |
|       1 |          31 |       00:00:07 |       71.88% |       74.70% |       0.8805 |       0.8124 |          0.0100 |
|       2 |          62 |       00:00:12 |       87.50% |       88.00% |       0.3855 |       0.4426 |          0.0100 |
|       3 |          93 |       00:00:19 |       94.53% |       94.00% |       0.2198 |       0.2544 |          0.0100 |
|       4 |         124 |       00:00:25 |       96.09% |       96.40% |       0.1454 |       0.1754 |          0.0100 |
|       5 |         155 |       00:00:31 |       98.44% |       97.70% |       0.0982 |       0.1298 |          0.0100 |
|       6 |         186 |       00:00:36 |       99.22% |       97.90% |       0.0788 |       0.1132 |          0.0100 |
|       7 |         217 |       00:00:40 |      100.00% |       98.00% |       0.0554 |       0.0937 |          0.0100 |
|       8 |         248 |       00:00:45 |      100.00% |       97.90% |       0.0430 |       0.0863 |          0.0100 |
|       9 |         279 |       00:00:50 |      100.00% |       98.10% |       0.0336 |       0.0787 |          0.0100 |
|      10 |         310 |       00:00:54 |      100.00% |       98.40% |       0.0269 |       0.0685 |          0.0100 |
|      11 |         341 |       00:01:00 |      100.00% |       98.40% |       0.0233 |       0.0621 |          0.0100 |
|      12 |         372 |       00:01:04 |      100.00% |       98.70% |       0.0210 |       0.0572 |          0.0100 |
|      13 |         403 |       00:01:08 |      100.00% |       98.80% |       0.0185 |       0.0538 |          0.0100 |
|      14 |         434 |       00:01:12 |      100.00% |       98.90% |       0.0161 |       0.0512 |          0.0100 |
|      15 |         465 |       00:01:17 |      100.00% |       98.80% |       0.0140 |       0.0491 |          0.0100 |
|      16 |         496 |       00:01:20 |      100.00% |       98.70% |       0.0124 |       0.0466 |          0.0100 |
|      17 |         527 |       00:01:24 |      100.00% |       98.90% |       0.0110 |       0.0444 |          0.0100 |
|======================================================================================================================|
Training finished: Stopped by OutputFcn.

Figure Training Progress (29-Aug-2023 21:19:35) contains 2 axes objects and another object of type uigridlayout. Axes object 1 with xlabel Iteration, ylabel Loss contains 8 objects of type patch, text, line. Axes object 2 with xlabel Iteration, ylabel Accuracy (%) contains 8 objects of type patch, text, line.

Definir funciones de salida

Defina la función de salida stopIfAccuracyNotImproving(info,N), que detiene el entrenamiento de la red si la mejor precisión de clasificación de los datos de validación no mejora en N validaciones de red seguidas. Este criterio es similar al criterio de detención integrado que usa la pérdida de validación, a excepción de que se aplica a la precisión de clasificación en vez de a la pérdida.

function stop = stopIfAccuracyNotImproving(info,N)

stop = false;

% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag

% Clear the variables when training starts.
if info.State == "start"
    bestValAccuracy = 0;
    valLag = 0;

elseif ~isempty(info.ValidationAccuracy)

    % Compare the current validation accuracy to the best accuracy so far,
    % and either set the best accuracy to the current accuracy, or increase
    % the number of validations for which there has not been an improvement.
    if info.ValidationAccuracy > bestValAccuracy
        valLag = 0;
        bestValAccuracy = info.ValidationAccuracy;
    else
        valLag = valLag + 1;
    end

    % If the validation lag is at least N, that is, the validation accuracy
    % has not improved for at least N validations, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end

end

end

Consulte también

|

Temas relacionados