How can I validate a CNN after training?

Kwasi on 14 Mar 2022
Commented: Kwasi on 21 Mar 2022
I have 4 samples, each containing about 51,000 images. I train the network, but each training run ends with a sudden drop in validation accuracy. I tried increasing the mini-batch size, which helped a little, but I cannot increase it further because I run out of memory.
After reading some help online, I think I can train the network without validation data, but how do I validate after training?
Thank you in advance.
% reset(gpuDevice(1));
cz1 = fullfile('v1');
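% Labels are assumed to come from the subfolder names. With 'LabelSource','none',
% imds.Labels is empty, which breaks countEachLabel, splitEachLabel and the accuracy check.
imds = imageDatastore(cz1,'LabelSource','foldernames','IncludeSubfolders',true,'FileExtensions','.mat','ReadFcn',@customreader);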
tbl = countEachLabel(imds)
[trainingSet,validationSet, testSet] = splitEachLabel(imds,0.8, 0.1, 0.1 ...
...%399,49,49 ...
...%100,10,10 ...
,'randomized');
layers = [
imageInputLayer([480 640 1],'Normalization', ...
'none','Name','input')
%layer 1
convolution2dLayer([3 3],64,'Stride',[1 1],'Padding',1,'Name','conv1')
batchNormalizationLayer('Name','BN1')
reluLayer('Name','relu1')
maxPooling2dLayer([3 3],'Stride',2,'Name','MP1')
%layer 4
convolution2dLayer([3 3],64,'Stride',[1 1],'Padding',1,'Name','conv4')
batchNormalizationLayer('Name','BN4')
reluLayer('Name','relu4')
%layer 5
convolution2dLayer([3 3],64,'Stride',[2 2],'Padding',1,'Name','conv5')
additionLayer(2,'Name','add2')
batchNormalizationLayer('Name','BN5')
fullyConnectedLayer(512,'Name','fc1')
batchNormalizationLayer('Name','BN7')
reluLayer('Name','relu7')
%layer 7
fullyConnectedLayer(256,'Name','fc2')
reluLayer('Name','relu8')
%layer 8
fullyConnectedLayer(4,'Name','fc3')
softmaxLayer('Name','softmax')
classificationLayer('Name','classif')
]
lgraph = layerGraph(layers);
% NOTE: 'add1' and 'MP3' are not defined in the layer array as posted (it jumps
% from layer 1 to layer 4), so these connections only work once those layers are restored.
skipConv1 = convolution2dLayer(1,64,'Stride',2,'Name','skipConv1');
lgraph = addLayers(lgraph,skipConv1);
lgraph = connectLayers(lgraph,'MP1','skipConv1');
lgraph = connectLayers(lgraph,'skipConv1','add1/in2');
skipConv2 = convolution2dLayer(1,64,'Stride',2,'Name','skipConv2');
lgraph = addLayers(lgraph,skipConv2);
lgraph = connectLayers(lgraph,'MP3','skipConv2');
lgraph = connectLayers(lgraph,'skipConv2','add2/in2');
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.001, ...
'MaxEpochs',10, ...
'Shuffle','every-epoch', ...
'ValidationData',validationSet, ...
'ValidationFrequency',4000, ...
'ValidationPatience',1, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',1,...
'Verbose',false, ...
'Plots','training-progress', ...
'MiniBatchSize', 20, ...
'CheckpointPath', 't7g');
net = trainNetwork(trainingSet,lgraph,options);
[YPred,scores] = classify(net,testSet,'MiniBatchSize',20);
[S,I] = maxk(scores',5);
YValidation = testSet.Labels;
accuracy = sum(YPred == YValidation)/numel(YValidation)
top5 = sum(sum(tbl.Label(I)' == YValidation))/numel(YValidation)
function data = customreader(filename)
% Read the variable 'frame' from the MAT-file and return it as the image data.
s = load(filename,'frame');
data = s.frame;
end

Accepted Answer

Joss Knight on 16 Mar 2022
The easiest way to validate a classification network after training is to do exactly what your example code does to check the accuracy on your test set, but with your validation set instead. To compute the cross-entropy loss rather than the accuracy, you might need to implement the cross-entropy calculation yourself. Alternatively, you could pass your validation data in place of the training data and train for a few iterations to get some numbers.
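As a minimal sketch of that post-training check, the snippet below computes accuracy and mean cross-entropy from a score matrix and the true labels. The `scores`, `classNames` and `YTrue` values here are made-up stand-ins so that the snippet runs on its own; in practice they would presumably come from `[YPred,scores] = classify(net,validationSet,'MiniBatchSize',20)` and `validationSet.Labels`, with the score columns assumed to follow the class order of the network's output layer.

```matlab
% Hypothetical stand-in data: 3 observations, 4 classes (rows sum to 1).
classNames = categorical({'a';'b';'c';'d'});          % assumed column order of scores
scores = [0.70 0.10 0.10 0.10;
          0.20 0.50 0.20 0.10;
          0.25 0.25 0.40 0.10];
YTrue = categorical({'a';'b';'d'},categories(classNames));

% Accuracy: predicted class is the column with the highest score.
[~,predIdx] = max(scores,[],2);
YPred = classNames(predIdx);
accuracy = mean(YPred == YTrue);

% Cross-entropy: negative mean log-probability assigned to the true class.
trueIdx = double(YTrue);                               % category index of each true label
linIdx = sub2ind(size(scores),(1:numel(YTrue))',trueIdx);
pTrue = max(scores(linIdx),eps);                       % guard against log(0)
loss = -mean(log(pTrue));

fprintf('accuracy = %.4f, cross-entropy = %.4f\n',accuracy,loss);
```

Taking the index with `double(YTrue)` only lines up with the score columns when the label categories are in the same order as the columns, so that ordering is worth checking before trusting the loss value.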
The simplest way to get rid of this unusual drop in accuracy at the final iteration is to use moving-average statistics for your batch normalization layers: set the BatchNormalizationStatistics training option to 'moving'. (See the documentation.)
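A minimal sketch of how that option might be set, reusing a few values from the question's `trainingOptions` call. The version guard is an addition for illustration: it assumes the option is only available from R2021a (MATLAB 9.10) onward, and `baseArgs` is a hypothetical helper variable.

```matlab
% Name-value pairs shared by both branches (values taken from the question).
baseArgs = {'InitialLearnRate',0.001,'MaxEpochs',10,'MiniBatchSize',20};

if verLessThan('matlab','9.10')
    % Releases before R2021a (for example R2019b): the option is not available.
    options = trainingOptions('sgdm',baseArgs{:});
else
    % R2021a and later: use moving-average batch normalization statistics.
    options = trainingOptions('sgdm',baseArgs{:}, ...
        'BatchNormalizationStatistics','moving');
end
```

On an older release the fallback branch trains as before, so the final-iteration drop would still need to be checked with a separate evaluation on the validation set after training.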
  4 Comments
Joss Knight on 21 Mar 2022
No, it was introduced in R2021a.
Kwasi on 21 Mar 2022
Thank you for responding and the answer.

Release

R2019b
