Using a transformer neural network for a classification task

inputSize = 12;    % 12 features per time step
numClasses = 4;    % 4 target classes
numChannels = inputSize;
maxPosition = 256;
numHeads = 4;
numKeyChannels = numHeads*32;
layers = [
sequenceInputLayer(numChannels,Name="input")
positionEmbeddingLayer(numChannels, maxPosition, Name="pos-emb")
additionLayer(2, Name="add")
selfAttentionLayer(numHeads,numKeyChannels,'AttentionMask','causal')
selfAttentionLayer(numHeads,numKeyChannels)
indexing1dLayer("last")
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer];
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph, "input", "add/in2");  % skip connection: raw input into the second input of 'add'
maxEpochs = 100;
miniBatchSize = 32;
learningRate = 0.001;
solver = 'adam';
shuffle = 'every-epoch';
gradientThreshold = 10;
executionEnvironment = "auto"; % chooses local GPU if available, otherwise CPU
options = trainingOptions(solver, ...
'Plots','training-progress', ...
'MaxEpochs', maxEpochs, ...
'MiniBatchSize', miniBatchSize, ...
'Shuffle', shuffle, ...
'InitialLearnRate', learningRate, ...
'GradientThreshold', gradientThreshold, ...
'ExecutionEnvironment', executionEnvironment);
The input size is 12, so there are 12 features.
numClasses is 4, so I am classifying the data into 4 classes.
But I get the following error when I try to run it:
"
Error in test123_20240727 (line 195)
net=trainNetwork(XTrain, YTrain, layers, options);
Caused by:
Layer 'add': Unconnected input. Each layer input must be connected to the output of another layer.
"
Line 195 is "net=trainNetwork(XTrain, YTrain, layers, options);".
Can anyone help me with this?
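The question does not show how XTrain and YTrain are built. A minimal sketch of the format trainNetwork expects for sequence-to-label classification, assuming hypothetical random data: XTrain is a cell array of 12-by-T matrices (channels by time steps) and YTrain is a categorical vector with 4 classes. The dataset size and sequence lengths are made up; lengths are kept at or below maxPosition = 256 because the position embedding is sized for at most that many positions.

```matlab
% Hypothetical synthetic data, only to illustrate the expected shapes.
rng(0);                          % reproducible random data
numObservations = 100;           % assumed dataset size (not from the question)
inputSize  = 12;                 % 12 features per time step
numClasses = 4;                  % 4 target classes

% Each observation is a channels-by-time matrix of variable length.
XTrain = cell(numObservations, 1);
for i = 1:numObservations
    seqLen = randi([20 100]);                 % stays below maxPosition = 256
    XTrain{i} = randn(inputSize, seqLen);
end

% One label per sequence, as a categorical column vector with categories "1".."4".
YTrain = categorical(randi(numClasses, numObservations, 1), 1:numClasses);
```

With real data, the same shapes apply: one cell per sequence, one categorical label per cell.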
Umar on 29 Jul 2024
Hi @haohaoxuexi1,
If you are still having issues with modifying your code, please let us know. We will be happy to help you out.
haohaoxuexi1 on 29 Jul 2024
@Umar Hi Umar, I'm fine for now. I'll let you know if I have further questions.


Accepted Answer

Joss Knight on 29 Jul 2024

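You passed layers instead of lgraph to trainNetwork. A Layer array is connected only in sequence, so the second input of the addition layer ('add/in2') has no source. The skip connection you added with connectLayers exists only in lgraph.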

Umar on 29 Jul 2024
@Joss Knight, thanks for jumping in. Could you show with a code snippet how to pass lgraph to trainNetwork? Thanks again for your help.
Joss Knight on 13 Aug 2024
Use
net=trainNetwork(XTrain, YTrain, lgraph, options);
instead of
net=trainNetwork(XTrain, YTrain, layers, options);
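The difference between the two calls can be seen by listing the connections of a cut-down copy of the network. This is a sketch assuming R2024a with Deep Learning Toolbox; it keeps only the first three layers of the question's network, which is enough to show the missing skip connection.

```matlab
% First three layers of the question's network (12 channels, maxPosition = 256).
layers = [
    sequenceInputLayer(12, Name="input")
    positionEmbeddingLayer(12, 256, Name="pos-emb")
    additionLayer(2, Name="add")];

% Built from the array alone, the graph is wired only in sequence:
% input -> pos-emb -> add/in1. Nothing feeds add/in2 ("Unconnected input").
lgraph = layerGraph(layers);
disp(lgraph.Connections)

% connectLayers returns a NEW graph with the skip connection input -> add/in2.
% The 'layers' array itself is unchanged, so it must not be passed to training.
lgraph = connectLayers(lgraph, "input", "add/in2");
disp(lgraph.Connections)
```

After the connectLayers call, the Connections table lists both add/in1 and add/in2, and that graph is what trainNetwork needs.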




Release

R2024a

