
Neural network with multiple inputs

Raffaele Villa on 25 Sep 2023
Answered: Avadhoot on 4 Oct 2023
Hi,
I'm just starting to use neural networks in MATLAB, but I'm running into problems with the initial setup.
I'm designing a simple neural network that takes two 3-D images as input and, in a two-step process, first analyzes them separately and then combines the two branches of the network to produce a single output in the range [0, 1].
Here is a scheme of the network:
Here is my code:
clear
clc
% Load training data
[file,folder]=uigetfile('*.csv');
[trainImages,trainLabels] = Caricamento([folder,file]);
% trainImages=squeeze(trainImages);
dim=size(trainImages);
trainImages1=trainImages(:,1);
trainImages2=trainImages(:,2);
dim=size(trainImages1{1,1});
% trainImages1 = reshape(trainImages1,dim(1), dim(2), dim(3), 1, dim(4));
% trainImages2 = reshape(trainImages2,dim(1), dim(2), dim(3), 1, dim(4));
% Define the network architecture for input 1
layers1 = [
image3dInputLayer([dim(1) dim(2) dim(3)],'Name','input1')
convolution3dLayer(5,20,'Name','conv1')
reluLayer('Name','relu1')
maxPooling3dLayer(2,'Stride',2,'Name','maxpool1')];
% Define the network architecture for input 2
layers2 = [
image3dInputLayer([dim(1) dim(2) dim(3)],'Name','input2')
convolution3dLayer(5,20,'Name','conv2')
reluLayer('Name','relu2')
maxPooling3dLayer(2,'Stride',2,'Name','maxpool2')];
% classes = categories(trainLabels);
% Layers that merge the two branches and produce the classification output
middleLayers = [
depthConcatenationLayer(2,'Name','concat')
fullyConnectedLayer(2,'Name','fc')
softmaxLayer('Name','softmax')
classificationLayer('Name','output', 'Classes', 'auto')];
% Combine all layers
lgraph = layerGraph();
lgraph = addLayers(lgraph,layers1);
lgraph = addLayers(lgraph,layers2);
lgraph = addLayers(lgraph,middleLayers);
lgraph = connectLayers(lgraph,'maxpool1','concat/in1');
lgraph = connectLayers(lgraph,'maxpool2','concat/in2');
% Specify training options
options = trainingOptions('sgdm','MaxEpochs',15,...
'InitialLearnRate',0.0001);
% Train the network
convnet = trainNetwork([trainImages1,trainImages2],trainLabels,lgraph,options);
The Caricamento function loads a CSV file containing the location of every image and the corresponding label, giving three cell vectors (trainImages1, trainImages2, trainLabels) with sizes
size(trainImages1)
size(trainImages2)
size(trainLabels)
ans =
112 1
ans =
112 1
ans =
112 1
Each cell of the first two vectors has size
size(trainImages1{1})
ans =
93 65 10
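As a sanity check on the architecture, the spatial size of each branch output can be worked out by hand from the 93-by-65-by-10 input size. This is a sketch that assumes the layer defaults used in the code above (stride 1 and no padding for convolution3dLayer, a pool of size 2 with stride 2 for maxPooling3dLayer); the variable names are illustrative only. It gives 89-by-61-by-6 after convolution and 44-by-30-by-3-by-20 after pooling in each branch, so both branches match in height, width, and depth, as depthConcatenationLayer requires, and the concatenated output has 40 channels.

```matlab
% Hand computation of the branch output size (illustrative sketch).
% Assumes convolution3dLayer defaults: stride 1 and no padding.
inSize     = [93 65 10];   % height, width, depth of one input volume
filterSize = 5;            % convolution3dLayer(5,20,...)
numFilters = 20;
poolSize   = 2;            % maxPooling3dLayer(2,'Stride',2,...)
poolStride = 2;

% Convolution without padding: out = in - filterSize + 1
convSize = inSize - filterSize + 1;

% Max pooling without padding: out = floor((in - poolSize)/poolStride) + 1
branchSize = floor((convSize - poolSize) / poolStride) + 1;

% depthConcatenationLayer stacks the two branches along the channel dimension
concatChannels = 2 * numFilters;

fprintf('After conv: %d x %d x %d\n', convSize);
fprintf('Each branch outputs %d x %d x %d x %d\n', branchSize, numFilters);
fprintf('After concatenation: %d channels\n', concatChannels);
```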
Running the code gives me this error:
Error using trainNetwork
Invalid training data. Predictors and responses must have the same number of observations.
Error in cnn (line 49)
convnet = trainNetwork([trainImages1,trainImages2],trainLabels,lgraph,options);
What's wrong with my data?
I also tried loading the images as multidimensional arrays of size
size(trainImages1{1})
ans =
93 65 10 112
but I got the same error.
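When 3-D image data are held in a numeric array, trainNetwork expects the layout height-by-width-by-depth-by-channels-by-observations, so a 93-by-65-by-10-by-112 array is missing the singleton channel dimension. A minimal sketch of building the 5-D layout, using random volumes as a stand-in for the real data (the names vols, X1, X4, and X5 are hypothetical). For a multiple-input network the arrays still have to be wrapped in a datastore, so this is a preparation step rather than a fix on its own.

```matlab
% Stand-in data: a 112x1 cell vector of 93x65x10 volumes (random values)
numObs = 112;
vols = cell(numObs, 1);
for i = 1:numObs
    vols{i} = rand(93, 65, 10, 'single');
end

% Stack the volumes along dimension 5: h x w x d x c x N with c = 1
X1 = cat(5, vols{:});

% If the data are already a 4-D array (h x w x d x N), insert the
% singleton channel dimension with reshape
X4 = rand(93, 65, 10, numObs, 'single');
sz = size(X4);
X5 = reshape(X4, sz(1), sz(2), sz(3), 1, sz(4));

disp(size(X1))   % 93 65 10 1 112
disp(size(X5))   % 93 65 10 1 112
```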
Thank you in advance for your suggestions.
Raffaele Villa

Answers (1)

Avadhoot on 4 Oct 2023
Hi Raffaele,
I understand that you are passing a concatenation of two cell arrays as the input. The error occurs because this is not the data format that trainNetwork expects for a multiple-input network.
To train a multiple-input network, you need a single datastore that combines both sets of input data together with the responses. You can create one with the “arrayDatastore” function in MATLAB.
Please refer to the documentation of the “arrayDatastore” function for more information.
Here is some reference code:
cds = arrayDatastore([trainImages1,trainImages2,trainLabels]);
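This line creates an “arrayDatastore” containing 112 rows and 3 columns: the first image, the second image, and the label of each observation. Because the network ends in a classificationLayer, the labels must be categorical. You can then pass the “cds” variable directly to the “trainNetwork” function as below: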
convnet = trainNetwork(cds,lgraph,options);
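Another way to build the single datastore, which yields a combined datastore, is to create one arrayDatastore per input (iterating over the observation dimension) plus one for the labels, and join them with combine. This is a sketch under stated assumptions: the two image sets have already been stacked into 5-D arrays of size h-by-w-by-d-by-1-by-N, the labels are a categorical column vector, and the random arrays and the label names 'neg'/'pos' are placeholders, not the poster's data. Each read returns a 1-by-3 cell {input1, input2, label}, that is, one column per input plus one for the response.

```matlab
% Placeholder data: N volumes per input, stacked as h x w x d x c x N
N  = 8;
X1 = rand(93, 65, 10, 1, N, 'single');   % first input (stand-in)
X2 = rand(93, 65, 10, 1, N, 'single');   % second input (stand-in)
T  = categorical(randi([0 1], N, 1), [0 1], {'neg','pos'});  % hypothetical labels

% One datastore per input, each reading one observation along dimension 5
dsX1 = arrayDatastore(X1, 'IterationDimension', 5);
dsX2 = arrayDatastore(X2, 'IterationDimension', 5);
dsT  = arrayDatastore(T);                % iterates over rows by default

% combine concatenates the reads horizontally: {input1, input2, label}
dsTrain = combine(dsX1, dsX2, dsT);

% Inspect one observation before training
sample = read(dsTrain);
disp(size(sample))          % 1 3
disp(size(sample{1}))       % 93 65 10
disp(class(sample{3}))      % categorical
```

With the real data, dsTrain would take the place of cds in the call trainNetwork(dsTrain,lgraph,options).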
Please refer to the documentation of the “trainNetwork” function for more information.
I hope it helps,
Regards,
Avadhoot.

Categories

Find more on Get Started with Statistics and Machine Learning Toolbox in Help Center and File Exchange.

Release

R2023a
