Using self attention layer with 2D images
Hi,
I am wondering how to use the self-attention layer for image classification with a CNN without needing to flatten the data, as is done in this example:
% load digit dataset
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', 'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.7, 'randomized');
% define network architecture
layers = [
    imageInputLayer([28 28 1], 'Name', 'input')
    convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv1')
    batchNormalizationLayer('Name', 'bn1')
    reluLayer('Name', 'relu1')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool1')
    convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv2')
    batchNormalizationLayer('Name', 'bn2')
    reluLayer('Name', 'relu2')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool2')
    flattenLayer('Name', 'flatten')
    selfAttentionLayer(8, 64, 'Name', 'self_attention')
    fullyConnectedLayer(10, 'Name', 'fc')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'output')]
% set training options
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.01, ...
    'MaxEpochs', 5, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', imdsValidation, ...
    'ValidationFrequency', 30, ...
    'Verbose', false, ...
    'Plots', 'training-progress')
% train the network
net = trainNetwork(imdsTrain, layers, options);
Answers (1)
Neha
on 21 Nov 2023
Hi Mahmoud,
I understand that you want to use a self-attention layer for image classification. The self-attention layer, also known as the multi-head self-attention layer, is commonly used in Transformer models such as BERT and vision transformers (ViT). Its primary function is to learn the relationships between positions within the input data. This input is usually sequential, representing either a temporal sequence or 1-D spatial information. It is therefore necessary to use the "flattenLayer" so that the input to the "selfAttentionLayer" is one-dimensional.
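To illustrate what "relationships between positions" means for image data, here is a minimal sketch in plain MATLAB (no toolbox calls). It treats each spatial location of a feature map as one position in a sequence and applies single-head scaled dot-product attention. The 7-by-7-by-64 size is an assumption matching what "maxpool2" would produce for 28-by-28 inputs, and the variables featureMap, Wq, Wk and Wv are hypothetical random stand-ins for activations and learned weights. This is not how "selfAttentionLayer" is implemented internally, and the reshape used here is only an illustration of turning 2-D positions into a 1-D sequence.

```matlab
% Hypothetical feature map of size H x W x C (random values, for illustration only).
rng(0);
H = 7; W = 7; C = 64;
featureMap = randn(H, W, C);

% Collapse the two spatial dimensions into one sequence of H*W positions,
% each described by C channels.
tokens = reshape(featureMap, H*W, C);          % 49 x 64

% Hypothetical projection weights standing in for learned parameters.
Wq = randn(C, C) / sqrt(C);
Wk = randn(C, C) / sqrt(C);
Wv = randn(C, C) / sqrt(C);

Q = tokens * Wq;                               % queries, 49 x 64
K = tokens * Wk;                               % keys,    49 x 64
V = tokens * Wv;                               % values,  49 x 64

% Scaled dot-product scores between every pair of positions.
scores = (Q * K') / sqrt(C);                   % 49 x 49

% Row-wise softmax (subtract the row maximum for numerical stability).
scores = scores - max(scores, [], 2);
weights = exp(scores) ./ sum(exp(scores), 2);  % each row sums to 1

% Each output position is a weighted mix of the values at all positions.
attended = weights * V;                        % 49 x 64
```

Row i of weights shows how strongly position i attends to each of the 49 positions, which is the pairwise relationship that the layer learns.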
Hope this helps!