Esta página es para la versión anterior. La página correspondiente en inglés ha sido eliminada en la versión actual.
Personalizar salidas durante un entrenamiento de red de deep learning
En este ejemplo se muestra cómo definir una función de salida que se ejecute con cada iteración durante el entrenamiento de redes neuronales de deep learning. Si especifica funciones de salida usando el argumento de par nombre-valor 'OutputFcn'
de trainingOptions
, trainNetwork
llama a estas funciones una vez antes del inicio del entrenamiento, después de cada iteración y una vez cuando el entrenamiento ha finalizado. Cada vez que se llama a una función de salida, trainNetwork
pasa una estructura que contiene información, como el número de iteración actual, pérdidas o precisión. Puede utilizar las funciones de salida para mostrar o representar información de progreso o para detener el entrenamiento. Para detener el entrenamiento antes de tiempo, haga que su función de salida devuelva true
. Si cualquier función de salida devuelve true
, el entrenamiento finaliza y trainNetwork
devuelve la red más reciente.
Para detener el entrenamiento cuando la pérdida del conjunto de validación pare de decrecer, especifique los datos de validación y una paciencia de validación con los argumentos de par nombre-valor 'ValidationData'
y 'ValidationPatience'
de trainingOptions
, respectivamente. La paciencia de validación es el número de veces que la pérdida, en el conjunto de validación, puede ser mayor o igual que la menor pérdida previa antes de que el entrenamiento de la red se detenga. Puede añadir criterios de detención adicionales mediante funciones de salida. En este ejemplo se muestra cómo crear una función de salida que detenga el entrenamiento cuando la precisión de clasificación de los datos de validación deje de mejorar. La función de salida se define al final del script.
Cargue los datos de entrenamiento, que contienen 5000 imágenes de dígitos. Reserve 1000 de las imágenes para la validación de la red.
[XTrain,YTrain] = digitTrain4DArrayData; idx = randperm(size(XTrain,4),1000); XValidation = XTrain(:,:,:,idx); XTrain(:,:,:,idx) = []; YValidation = YTrain(idx); YTrain(idx) = [];
Construya una red para clasificar los datos de las imágenes de los dígitos.
layers = [ imageInputLayer([28 28 1]) convolution2dLayer(3,8,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,16,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer fullyConnectedLayer(10) softmaxLayer classificationLayer];
Especifique opciones para el entrenamiento de la red. Para validar la red a intervalos regulares durante el entrenamiento, especifique los datos de validación. Elija el valor 'ValidationFrequency'
para que la red se valide una vez por época.
Para detener el entrenamiento cuando la precisión de clasificación del conjunto de validación deje de mejorar, configure stopIfAccuracyNotImproving
como una función de salida. El segundo argumento de entrada de stopIfAccuracyNotImproving
es el número de veces que la precisión, en el conjunto de validación, puede ser menor o igual que la mayor precisión previa antes de que el entrenamiento de la red se detenga. Elija cualquier valor grande para el número máximo de épocas que desea entrenar. El entrenamiento no debería alcanzar la época final, ya que se detendría automáticamente.
miniBatchSize = 128; validationFrequency = floor(numel(YTrain)/miniBatchSize); options = trainingOptions('sgdm', ... 'InitialLearnRate',0.01, ... 'MaxEpochs',100, ... 'MiniBatchSize',miniBatchSize, ... 'VerboseFrequency',validationFrequency, ... 'ValidationData',{XValidation,YValidation}, ... 'ValidationFrequency',validationFrequency, ... 'Plots','training-progress', ... 'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));
Entrene la red. El entrenamiento se detiene cuando la precisión de validación deja de aumentar.
net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU. Initializing input data normalization. |======================================================================================================================| | Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning | | | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate | |======================================================================================================================| | 1 | 1 | 00:00:02 | 7.81% | 12.70% | 2.7155 | 2.5169 | 0.0100 | | 1 | 31 | 00:00:07 | 71.88% | 74.70% | 0.8805 | 0.8124 | 0.0100 | | 2 | 62 | 00:00:12 | 87.50% | 88.00% | 0.3855 | 0.4426 | 0.0100 | | 3 | 93 | 00:00:19 | 94.53% | 94.00% | 0.2198 | 0.2544 | 0.0100 | | 4 | 124 | 00:00:25 | 96.09% | 96.40% | 0.1454 | 0.1754 | 0.0100 | | 5 | 155 | 00:00:31 | 98.44% | 97.70% | 0.0982 | 0.1298 | 0.0100 | | 6 | 186 | 00:00:36 | 99.22% | 97.90% | 0.0788 | 0.1132 | 0.0100 | | 7 | 217 | 00:00:40 | 100.00% | 98.00% | 0.0554 | 0.0937 | 0.0100 | | 8 | 248 | 00:00:45 | 100.00% | 97.90% | 0.0430 | 0.0863 | 0.0100 | | 9 | 279 | 00:00:50 | 100.00% | 98.10% | 0.0336 | 0.0787 | 0.0100 | | 10 | 310 | 00:00:54 | 100.00% | 98.40% | 0.0269 | 0.0685 | 0.0100 | | 11 | 341 | 00:01:00 | 100.00% | 98.40% | 0.0233 | 0.0621 | 0.0100 | | 12 | 372 | 00:01:04 | 100.00% | 98.70% | 0.0210 | 0.0572 | 0.0100 | | 13 | 403 | 00:01:08 | 100.00% | 98.80% | 0.0185 | 0.0538 | 0.0100 | | 14 | 434 | 00:01:12 | 100.00% | 98.90% | 0.0161 | 0.0512 | 0.0100 | | 15 | 465 | 00:01:17 | 100.00% | 98.80% | 0.0140 | 0.0491 | 0.0100 | | 16 | 496 | 00:01:20 | 100.00% | 98.70% | 0.0124 | 0.0466 | 0.0100 | | 17 | 527 | 00:01:24 | 100.00% | 98.90% | 0.0110 | 0.0444 | 0.0100 | |======================================================================================================================| Training finished: Stopped by OutputFcn.
Definir funciones de salida
Defina la función de salida stopIfAccuracyNotImproving(info,N)
, que detiene el entrenamiento de la red si la mejor precisión de clasificación de los datos de validación no mejora en N
validaciones de red seguidas. Este criterio es similar al criterio de detención integrado que usa la pérdida de validación, a excepción de que se aplica a la precisión de clasificación en vez de a la pérdida.
function stop = stopIfAccuracyNotImproving(info,N) stop = false; % Keep track of the best validation accuracy and the number of validations for which % there has not been an improvement of the accuracy. persistent bestValAccuracy persistent valLag % Clear the variables when training starts. if info.State == "start" bestValAccuracy = 0; valLag = 0; elseif ~isempty(info.ValidationAccuracy) % Compare the current validation accuracy to the best accuracy so far, % and either set the best accuracy to the current accuracy, or increase % the number of validations for which there has not been an improvement. if info.ValidationAccuracy > bestValAccuracy valLag = 0; bestValAccuracy = info.ValidationAccuracy; else valLag = valLag + 1; end % If the validation lag is at least N, that is, the validation accuracy % has not improved for at least N validations, then return true and % stop training. if valLag >= N stop = true; end end end
Consulte también
trainNetwork
| trainingOptions