1D CNN for sequence-to-label classification, model input errors

I created a 1-D CNN to classify numerical sequences into 5 classes, but I cannot get the input format right to train the model:
data = num2cell(data, 2); % sequence input format
label = categorical(label); % response input format
for i=1:length(data)
    data{i} = data{i}';   % transpose each row sequence into a column
end
[idxTrain,idxTest] = trainingPartitions(numel(data), [0.9 0.1]);
dataTrain = data(idxTrain);
labelTrain = label(idxTrain);
dataTest = data(idxTest);
labelTest = label(idxTest);
options = trainingOptions("adam",...
'InitialLearnRate',1e-3,...
'LearnRateDropFactor',0.1,...
'LearnRateDropPeriod',20,...
'MaxEpochs',60,...
'MiniBatchSize',36,...
'LearnRateSchedule','piecewise', ...
'Shuffle','every-epoch', ...
'Verbose',false, ...
'Plots','training-progress');
[net,model_performance] = trainNetwork(dataTrain, labelTrain, CNN, options);
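The 'piecewise' schedule in these options multiplies the learning rate by 'LearnRateDropFactor' each time 'LearnRateDropPeriod' epochs pass. The per-epoch rates implied by the values used here can be computed directly. This sketch is only arithmetic on the option values, not something trainNetwork returns:

```matlab
% Option values copied from the trainingOptions call above.
initialLR  = 1e-3;   % 'InitialLearnRate'
dropFactor = 0.1;    % 'LearnRateDropFactor'
dropPeriod = 20;     % 'LearnRateDropPeriod'
maxEpochs  = 60;     % 'MaxEpochs'

epochs = 1:maxEpochs;
% floor((epochs - 1) / dropPeriod) counts the drop periods completed
% before each epoch starts; each completed period applies one drop.
lr = initialLR * dropFactor .^ floor((epochs - 1) / dropPeriod);
```

So epochs 1 to 20 train at 1e-3, epochs 21 to 40 at 1e-4, and epochs 41 to 60 at 1e-5.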
The input formats are as follows:
sequence input: cell array
label input: categorical (one scalar label per sequence); the classes are ["N","L","R","A","V"]
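For the layouts listed above: sequence layers in Deep Learning Toolbox read each cell of a sequence cell array as a C-by-S matrix (C channels by S time steps), and the trainNetwork documentation describes sequence-to-label responses as an N-by-1 categorical vector. A small sketch with hypothetical toy data (4 sequences of 10 samples, assuming one sequence per row of data, which is a guess about the attached file) shows what the num2cell step and the transposing loop each produce:

```matlab
% Hypothetical toy data: 4 sequences of 10 samples each, one row per sequence.
data  = rand(4, 10);
label = categorical(["N"; "L"; "R"; "A"]);   % 4-by-1 categorical, one label per sequence

seqRows = num2cell(data, 2);                                      % 4-by-1 cell, each cell 1-by-10
seqCols = cellfun(@transpose, seqRows, 'UniformOutput', false);   % each cell 10-by-1

% Interpreted as C-by-S:
size(seqRows{1})   % 1 10 -> 1 channel, 10 time steps
size(seqCols{1})   % 10 1 -> 10 channels, 1 time step
```

The transposing loop in the question produces the second layout, so each observation reaches the network as a single time step with one channel per sample. Whether that is intended depends on the input size the network's sequenceInputLayer was built with, which the post does not show.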

Accepted Answer

Debraj Maji on 28 Nov 2023
Edited: Debraj Maji on 28 Nov 2023
I understand that you are trying to train a 1-D CNN.
The response variable passed to 'trainNetwork' should be a cell array of categorical row vectors (every cell is a categorical row vector), not a plain categorical vector, because the input consists of multiple observations. In your case, 'labelTrain' should not be a 900-by-1 categorical vector but a 1-by-900 cell array in which each cell is a categorical row vector.
One possible way to modify the above code to accommodate this change is to build a cell array where each element is a categorical row vector. I have included the corrected code below for your reference:
clear;
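load test\CNN.mat;   % attached file; provides data, label and the network layers CNN
data = num2cell(data,2); % sequence input format: one cell per row
label = categorical(label); % response input format
for i=1:length(data)
    data{i} = data{i}';   % transpose each sequence into a column
end
label1 = cell(1, numel(label));   % preallocate a 1-by-N cell array of labels
for i = 1:numel(label)
    label1{i} = label(i);   % each cell holds one categorical label
end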
[idxTrain,idxTest] = trainingPartitions(numel(data), [0.9 0.1]);
dataTrain = data(idxTrain);
labelTrain = label1(idxTrain);
dataTest = data(idxTest);
labelTest = label1(idxTest);
options = trainingOptions("adam",...
'InitialLearnRate',1e-3,...
'LearnRateDropFactor',0.1,...
'LearnRateDropPeriod',20,...
'MaxEpochs',60,...
'MiniBatchSize',36,...
'LearnRateSchedule','piecewise', ...
'Shuffle','every-epoch', ...
'Verbose',false, ...
'Plots','training-progress');
[net,model_performance] = trainNetwork(dataTrain, labelTrain, CNN, options);
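Both versions of the code call trainingPartitions, which, as far as I know, is a supporting function attached to some MathWorks examples rather than a built-in toolbox function. If it is not on the MATLAB path, a 90/10 random split can be sketched with base MATLAB randperm (numObs = 900 is illustrative; in practice use numel(data)):

```matlab
% Random 90/10 train/test split without the example helper trainingPartitions.
rng(0);                              % fix the seed so the split is reproducible
numObs   = 900;                      % illustrative; use numel(data) in practice
idx      = randperm(numObs);         % random permutation of 1..numObs
numTrain = round(0.9 * numObs);      % 90% of the observations for training
idxTrain = idx(1:numTrain);          % first 90% of the shuffled indices
idxTest  = idx(numTrain+1:end);      % remaining 10% for testing
```

The resulting index vectors can then be used exactly like those from trainingPartitions, e.g. dataTrain = data(idxTrain).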
After running the above code with the attached data, I obtained the accuracy and loss curves attached below:
I hope this resolves your query.
With regards,
Debraj.
