Borrar filtros
Borrar filtros

trainMaskRCNN 함수를 이용한 훈련에서 조기 종료와 관련된 문제가 있습니다. (validation patience 관련)

3 visualizaciones (últimos 30 días)
Junhyeon
Junhyeon el 21 de Nov. de 2023
Movida: Angelo Yeo el 26 de Dic. de 2023
% Custom training experiment code
function output = ver_1120_training1(params, monitor)
load('C:\Users\UserPc\OneDrive\Master\Main project\2. AST\2. Extracting the color of PDA beads/Label data/ver_1116/polygonLabel.mat','polygonTable','validTable','detector');
load('C:\Users\UserPc\OneDrive\Master\Main project\2. AST\2. Extracting the color of PDA beads/Label data/ver_1120/polygonLabel.mat');
trainDs = transform(arrayDatastore(polygonTable,'OutputType','same'),@trainingDataTransformFcn);
validDs = transform(arrayDatastore(test,'OutputType','same'),@validationDataTransformFcn);
options = trainingOptions(params.solver,...
"MaxEpochs",100,"MiniBatchSize",2,"ExecutionEnvironment","multi-gpu",...
"Shuffle","every-epoch","InitialLearnRate",params.initLearnRate,"LearnRateSchedule","piecewise",...
"LearnRateDropPeriod",params.dropPeriod,"LearnRateDropFactor",params.dropFactor,...
"ValidationData",validDs,"ValidationFrequency",50,"ValidationPatience",20,...
"OutputNetwork","best-validation-loss",...
"BatchNormalizationStatistics","moving","ResetInputNormalization",false);
[trainedDetector, info] = trainMaskRCNN(trainDs,detector,options,"ExperimentMonitor",monitor);
output.trainedDetector = trainedDetector;
output.info = info;
end
안녕하세요.
MATLAB의 실험관리자 앱을 사용해서 Mask R-CNN의 훈련과 하이퍼-파라미터 튜닝을 진행하고 있습니다.
감지하고자 하는 객체의 class는 1개 종류 뿐입니다.
코드는 정상작동 되지만 한가지 문제가 있습니다.
저는 처음에 Validation patience를 5로 설정하였는데, 훈련이 5번의 검증(250번째 반복)만에 종료되었습니다.
Validation patience를 20으로 설정하면, 훈련이 대략 20~21번의 검증(약 1000번째 반복)만에 종료됩니다.
훈련 플롯을 살펴보면, Validation Loss는 점차 감소하는 추세입니다.
따라서 의도한대로 Validation loss가 20번 연속으로 감소하지 않으면 훈련이 종료되기 보다는
Validation patience로 설정한 횟수만큼 검증이 이어진 후에 훈련이 종료되는 것 같습니다.
조금 의아한 부분은 Mask Loss와 RMSE의 그래프가 초기에 0부터 시작하는 것처럼 보인다는 것입니다.
그렇다면 최소값이 0으로 설정되어 이러한 현상이 나타나는 것일까요?
그리고 오로지 첫번째 항목인 (Total) Loss에 대해서만 Early stop 여부가 결정되게 만들 수는 없을까요?
답변을 부탁드립니다.
감사합니다.

Respuesta aceptada

Junhyeon
Junhyeon el 21 de Nov. de 2023
Movida: Angelo Yeo el 26 de Dic. de 2023
답변이 달리지는 않았지만 거의 해결한 것 같습니다.
혹 저와 같은 문제를 겪는 분들을 위해 코멘트를 남깁니다.
trainingOptions 함수에서 validation patience 옵션을 설정하지 않고,
OutputFcn을 옵션과 커스텀 조기 종료 함수를 선언하니 정상적으로 훈련이 지속됩니다.
아래 저의 코드를 남깁니다.
function output = ver_1120_training1(params, monitor)
load('C:\Users\UserPc\OneDrive\Master\Main project\2. AST\2. Extracting the color of PDA beads/Label data/ver_1116/polygonLabel.mat','polygonTable','detector');
load('C:\Users\UserPc\OneDrive\Master\Main project\2. AST\2. Extracting the color of PDA beads/Label data/ver_1120/polygonLabel.mat');
trainDs = transform(arrayDatastore(polygonTable,'OutputType','same'),@trainingDataTransformFcn);
validDs = transform(arrayDatastore(test,'OutputType','same'),@validationDataTransformFcn);
options = trainingOptions(params.solver,"Momentum",params.momentum,...
"MaxEpochs",100,"MiniBatchSize",2,"ExecutionEnvironment","multi-gpu",...
"Shuffle","every-epoch","InitialLearnRate",params.initLearnRate,"LearnRateSchedule","piecewise",...
"LearnRateDropPeriod",params.dropPeriod,"LearnRateDropFactor",params.dropFactor,...
"ValidationData",validDs,"ValidationFrequency",50,...
"OutputNetwork","best-validation-loss",...
"BatchNormalizationStatistics","moving","ResetInputNormalization",false,...
"OutputFcn",@(info)stopFcn(info,10,params.initLearnRate));
[trainedDetector, info] = trainMaskRCNN(trainDs,detector,options,"ExperimentMonitor",monitor);
output.trainedDetector = trainedDetector;
output.info = info;
end
% validation patience를 대체하는 커스텀 조기 종료 함수
function stop = stopFcn(info, N, C)
stop = false;
persistent bestValLoss
persistent count
start = C;
% info.State 변수가 없으므로,
% 훈련 시작 시 count를 초기화 하는 조건에 initial learning rate를 이용한다.
% initial learning rate value가 유지되는 동안에는 훈련이 유지된다.
if start == info.LearnRate
bestValLoss = inf;
count = 0;
% 검증 손실값이 최저값을 갱신하면 count를 초기화한다.
elseif info.ValidationLoss < bestValLoss
bestValLoss = info.ValidationLoss;
count = 0;
% 검증 손실값이 같거나 증가하면 count에 1을 더한다.
else
count = count + 1;
end
% count가 지정된 임계값 (= validation patience) 이상이 되면 훈련을 중지한다.
if count >= N
stop = true;
end
end

Más respuestas (0)

Productos


Versión

R2023b

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!