Multiple inputs to a pre-trained model

Rayan Matlob on 5 Jul 2022
Edited: Rayan Matlob on 6 Jul 2022
I have three class folders (Good, Moderate, and Severe).
Each class folder has six subfolders (Original images, Red, Blue, Green, HUE, Value), where Red, Blue, Green, HUE, and Value contain the images obtained by applying filters to the images in the Original images folder.
I am using a pre-trained model (ResNet-50, or any other model you suggest). All images in all folders are numbered in the same sequence (each subfolder contains images 1 to 200).
How can I train the model so that each image from the Original images subfolder is fed to the model's input in parallel with the corresponding images from the other subfolders (Red, Blue, Green, HUE, Value)?
Note: for validation, I need to use only the Original images folder, and the model should fetch the corresponding images from the other subfolders.
My MATLAB code is below. Thanks!
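Because every subfolder numbers its images the same way, the filtered versions of an original image can be located from the original's path alone. Below is a minimal sketch, assuming each filtered copy keeps exactly the same file name and extension as the original (for example `17.png`) and that the subfolders are named exactly as listed above; the example path `data/Good/Original images/17.png` is hypothetical.

```matlab
% Hypothetical path of one original image: <root>/<class>/Original images/<n>.<ext>
origFile = fullfile('data', 'Good', 'Original images', '17.png');
% Names of the filtered subfolders (assumed to match the folder names on disk)
filterFolders = {'Red', 'Blue', 'Green', 'HUE', 'Value'};

% Split the path: origDir = <root>/<class>/Original images, name = '17', ext = '.png'
[origDir, name, ext] = fileparts(origFile);
classDir = fileparts(origDir);          % <root>/<class>

% Build the path of the same-numbered image in each filtered subfolder
partnerFiles = cell(1, numel(filterFolders));
for k = 1:numel(filterFolders)
    partnerFiles{k} = fullfile(classDir, filterFolders{k}, [name ext]);
end
disp(partnerFiles')
```

Since the partner paths depend only on the original file name, a datastore built from the Original images folder alone (for training or validation) is enough to reach the other five views. One possible route, not verified against this exact setup, is to run the same lines inside a `transform` function created with `'IncludeInfo',true`, whose info struct carries the file name of each image that is read.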
1 comment
Rayan Matlob on 6 Jul 2022
Edited: Rayan Matlob on 6 Jul 2022
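% Folder layout: <root>\<class>\<subfolder>\<image>, e.g. Good\Original images\1.png
rootFolder = 'C:\Users\Rayan\Desktop\9_8_balance_data\R_9_1_GSM_3';
origFolder = 'Original images';   % assumed name of the unfiltered subfolder; adjust if needed
imds = imageDatastore(rootFolder, 'IncludeSubfolders',true);
% Note: 'LabelSource','foldernames' labels each image with the folder that
% directly contains it (e.g. 'Red'), not with its class (Good/Moderate/Severe).
% Keep only the original images and take each label from the class folder.
parentDirs = cellfun(@fileparts, imds.Files, 'UniformOutput',false);
[classDirs, subNames] = cellfun(@fileparts, parentDirs, 'UniformOutput',false);
isOrig = strcmp(subNames, origFolder);
[~, classNames] = cellfun(@fileparts, classDirs(isOrig), 'UniformOutput',false);
imds.Files = imds.Files(isOrig);
imds.Labels = categorical(classNames);

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.77,'randomized');
numTrainImages = numel(imdsTrain.Labels);

net = resnet50;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);

% findLayersToReplace is a supporting function of the MathWorks transfer-learning
% example, not a toolbox function, so a copy must be on the path.
% (This edit call only opens it in the Editor.)
edit(fullfile(matlabroot,'examples','nnet','main','findLayersToReplace.m'))
[learnableLayer,classLayer] = findLayersToReplace(lgraph);
numClasses = numel(categories(imdsTrain.Labels));

% Replace the last learnable layer so the network outputs numClasses scores
if isa(learnableLayer,'nnet.cnn.layer.FullyConnectedLayer')
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name','new_fc', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
elseif isa(learnableLayer,'nnet.cnn.layer.Convolution2DLayer')
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
end
lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

% Replace the classification layer so it takes the new class names
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);
layers = lgraph.Layers;             % used in the original example to freeze early layers
connections = lgraph.Connections;

% Resize images to the input size expected by the network
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize);   % validate about once per epoch
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',60, ...
    'InitialLearnRate',0.00065, ...
    'Shuffle','every-epoch', ...
    'ValidationFrequency',valFrequency, ...
    'ValidationData',augimdsValidation, ...
    'Verbose',false, ...
    'Plots','training-progress');
net = trainNetwork(augimdsTrain,lgraph,options);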
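One common way to present all six views "in parallel" to a single pre-trained network is early fusion: concatenate them along the channel dimension so that each sample becomes one H×W×C array. This is a swapped-in technique, not something the question or MATLAB prescribes. In the sketch below, the random arrays are hypothetical stand-ins for images read with `imread`; the original is assumed to be RGB and the five filtered images single-channel, all already resized to 224×224.

```matlab
% Hypothetical stand-ins for one sample: the original RGB image followed by
% its five filtered versions (Red, Blue, Green, HUE, Value), assumed here to be
% single-channel and already resized to the network input size (224x224)
h = 224; w = 224;
views = {rand(h, w, 3), rand(h, w), rand(h, w), rand(h, w), rand(h, w), rand(h, w)};

% Check that all views share the same height and width before stacking
for k = 2:numel(views)
    assert(size(views{k}, 1) == h && size(views{k}, 2) == w, ...
        'View %d has a different image size', k);
end

% Concatenate along the 3rd (channel) dimension: one H x W x C array per sample
stacked = cat(3, views{:});
numChannels = size(stacked, 3);
fprintf('Stacked sample: %d x %d x %d\n', size(stacked, 1), size(stacked, 2), numChannels);
```

ResNet-50 expects 224×224×3 input, so with this approach the input layer would need to become something like `imageInputLayer([224 224 numChannels])`, and the first convolution would need a new `convolution2dLayer` with `numChannels` input channels. That convolution's pretrained weights then no longer apply directly. An alternative worth comparing is a multiple-input network with one branch per view, which keeps the 3-channel pretrained stem for the original image.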


Answers (0)

Categories: Deep Learning Toolbox

Release: R2021b
