How does selfAttentionLayer work? Validating it with brief code

How does selfAttentionLayer work in detail, step by step? Can you reproduce its computation directly from the paper's formulas, and thereby verify the correctness and consistency of selfAttentionLayer?
Official description:
A self-attention layer computes single-head or multihead self-attention of its input.
The layer:
  1. Computes the queries, keys, and values from the input
  2. Computes the scaled dot-product attention across heads using the queries, keys, and values
  3. Merges the results from the heads
  4. Performs a linear transformation on the merged result
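For reference, the four steps above correspond to the scaled dot-product and multi-head attention formulas of Vaswani et al., "Attention Is All You Need" (2017). Below they are rewritten in the channels-by-time (C×T) layout that MATLAB uses, with X the C×T input. This is my own transcription, assuming (as the accepted answer's code does) that head h takes a contiguous block of rows of Q, K, and V. Stacking the paper's per-head projection matrices row-wise gives the single N1×C matrix W_Q (likewise W_K, W_V); the bias terms are specific to the layer, since the paper's projections have none.

```latex
% H = numHeads, N_1 = queryDims, N_2 = NumValueChannels, N_3 = OutputSize
% Q_h, K_h: rows (h-1)d_k+1, ..., h d_k of Q, K;  V_h: rows (h-1)d_v+1, ..., h d_v of V
% Each bias vector is added to every column (time step).
\begin{aligned}
\text{Step 1:}\quad & Q = W_Q X + b_Q \in \mathbb{R}^{N_1 \times T}, \quad
                      K = W_K X + b_K \in \mathbb{R}^{N_1 \times T}, \quad
                      V = W_V X + b_V \in \mathbb{R}^{N_2 \times T} \\
\text{Step 2:}\quad & d_k = N_1 / H, \quad d_v = N_2 / H, \quad
                      O_h = V_h \, \operatorname{softmax}_{\text{col}}\!\left(\frac{K_h^{\top} Q_h}{\sqrt{d_k}}\right) \in \mathbb{R}^{d_v \times T},
                      \quad h = 1, \dots, H \\
\text{Step 3:}\quad & O = \begin{bmatrix} O_1 \\ \vdots \\ O_H \end{bmatrix} \in \mathbb{R}^{N_2 \times T} \\
\text{Step 4:}\quad & Y = W_O O + b_O \in \mathbb{R}^{N_3 \times T}
\end{aligned}

% Link to the paper's row layout (tokens as rows, softmax along each row):
\left(\operatorname{softmax}_{\text{row}}\!\left(\frac{Q_h^{\top} K_h}{\sqrt{d_k}}\right) V_h^{\top}\right)^{\top}
  = V_h \, \operatorname{softmax}_{\text{col}}\!\left(\frac{K_h^{\top} Q_h}{\sqrt{d_k}}\right)
```

The last identity holds because transposing a row-wise softmax gives a column-wise softmax of the transposed matrix, so using K'Q instead of QK' is only a layout choice.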

Respuesta aceptada

cui,xingxing on 11 Jan 2024
Edited: cui,xingxing on 27 Apr 2024
Here is a simple workflow I wrote, using an input with only two dimensions, "CT" (channel × time), to illustrate how each step works.
The comment after each variable gives its dimensions.
%% Verify that the selfAttentionLayer computation matches a manual calculation
XTrain = rand(10,20); % C*T: 10 channels, 20 time steps
numClasses = 4;
numHeads = 6;
queryDims = 48; % N1 = 48 query/key channels in total (48/6 = 8 per head)
layers = [inputLayer(size(XTrain),"CT");
selfAttentionLayer(numHeads,queryDims,NumValueChannels=12,OutputSize=15,Name="sa"); % N2 = 12 value channels (2 per head), N3 = 15 output channels
layerNormalizationLayer;
fullyConnectedLayer(numClasses);
softmaxLayer];
net = dlnetwork(layers);
% analyzeNetwork(net)
XTrain = dlarray(XTrain,"CT");
[act1,act2] = predict(net,XTrain,Outputs=["input","sa"]); % outputs of the input layer and of the self-attention layer
act1 = extractdata(act1);% CT
act2 = extractdata(act2);% CT
% learnable parameters of the self-attention layer
layerSA = net.Layers(2);
QWeights = layerSA.QueryWeights; % N1*C
KWeights = layerSA.KeyWeights;% N1*C
VWeights = layerSA.ValueWeights;% N2*C
outputW = layerSA.OutputWeights;% N3*N2
Qbias = layerSA.QueryBias; % N1*1
Kbias = layerSA.KeyBias;% N1*1
Vbias = layerSA.ValueBias; % N2*1
outputB = layerSA.OutputBias;% N3*1
% Step 1: compute queries, keys, and values (bias added to every column)
q = QWeights*act1+Qbias; % N1*T
k = KWeights*act1+Kbias;% N1*T
v = VWeights*act1+Vbias;% N2*T
% Step 2: scaled dot-product attention, one head at a time
numChannelsQPerHeads = size(q,1)/numHeads;% 1*1
numChannelsVPerHeads = size(v,1)/numHeads;% 1*1
attentionM = cell(1,numHeads);
for i = 1:numHeads
idxQRange = numChannelsQPerHeads*(i-1)+1:numChannelsQPerHeads*i;
idxVRange = numChannelsVPerHeads*(i-1)+1:numChannelsVPerHeads*i;
qi = q(idxQRange,:);% diQ*T
ki = k(idxQRange,:);% diQ*T
vi = v(idxVRange,:);% diV*T
% scaled dot-product attention for head i
dk = size(qi,1);% 1*1
attentionScores = mysoftmax(ki'*qi./sqrt(dk));% T*T (keys x queries); MATLAB's internal code uses k'*q, not q'*k, so the softmax runs over dim 1 (keys)
attentionM{i} = vi*attentionScores; % diV*T
end
% Step 3: merge (concatenate) the head outputs along the channel dimension
attention = cat(1,attentionM{:}); % N2*T,N2 = diV*numHeads
% Step 4: output linear projection
act_ = outputW*attention+outputB;% N3*T
act2(1,:)
ans = 1×20
-0.5919 -0.5888 -0.5905 -0.5916 -0.5902 -0.5956 -0.5936 -0.5910 -0.5906 -0.5922 -0.5943 -0.5926 -0.5915 -0.5920 -0.5947 -0.5935 -0.5925 -0.5932 -0.5917 -0.5884
act_(1,:)
ans = 1×20
-0.5919 -0.5888 -0.5905 -0.5916 -0.5902 -0.5956 -0.5936 -0.5910 -0.5906 -0.5922 -0.5943 -0.5926 -0.5915 -0.5920 -0.5947 -0.5935 -0.5925 -0.5932 -0.5917 -0.5884
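The comment in step 2 says MATLAB uses k'*q rather than q'*k. The sketch below (plain MATLAB, no toolbox needed; the sizes and random data are arbitrary assumptions, not the layer's values) computes one head both in the answer's column layout and in the paper's row layout, softmax(Q*K'/sqrt(dk))*V, and checks that one result is the transpose of the other.

```matlab
% Hypothetical sizes chosen only for illustration (a single head)
rng(0);                          % reproducible random data
dk = 8; dv = 2; T = 20;
q = randn(dk, T);                % per-head queries, one column per time step
k = randn(dk, T);                % per-head keys
v = randn(dv, T);                % per-head values

% Column layout, as in the answer: softmax down each column (over keys)
S = k' * q / sqrt(dk);           % T-by-T, S(s,t) = k_s . q_t / sqrt(dk)
A = exp(S - max(S, [], 1));
A = A ./ sum(A, 1);
colOut = v * A;                  % dv-by-T

% Paper layout: tokens are rows, softmax applied along each row
Qp = q'; Kp = k'; Vp = v';       % T-by-dk, T-by-dk, T-by-dv
Sp = Qp * Kp' / sqrt(dk);        % T-by-T, rows index queries (Sp = S')
Ap = exp(Sp - max(Sp, [], 2));
Ap = Ap ./ sum(Ap, 2);
paperOut = Ap * Vp;              % T-by-dv

% The two results should agree up to round-off (paperOut is the transpose)
maxDiff = max(abs(colOut - paperOut'), [], 'all');
disp(maxDiff)
```

maxDiff should come out at round-off level, which is the numerical counterpart of the transpose identity between the two layouts.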
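I have reproduced its working process in the simplest possible way; I hope it helps others. The first row of the layer output (act2) and of the manual result (act_) agree to the displayed precision. The local function mysoftmax used in step 2 is: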
function out = mysoftmax(X,dim)
arguments
X
dim = 1;
end
X = X-max(X,[],dim); % subtract the max so exp(X) cannot overflow to Inf; the result is unchanged
X = exp(X);
out = X./sum(X,dim);
end
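The max-subtraction line in mysoftmax only guards against overflow: softmax is unchanged when the same constant is subtracted from every input. A quick check with made-up inputs (the values are illustrative, not taken from the network):

```matlab
% Made-up logits, large enough that exp overflows in double precision
x = [1000; 1001; 1002];

naive = exp(x) ./ sum(exp(x));              % exp(1000) is Inf, so Inf/Inf gives NaN
shifted = x - max(x);                       % [-2; -1; 0]
stable = exp(shifted) ./ sum(exp(shifted));

% For moderate inputs both forms agree: softmax ignores a common shift
y = [1; 2; 3];
a = exp(y) ./ sum(exp(y));
b = exp(y - max(y)) ./ sum(exp(y - max(y)));

disp([naive stable])
disp(max(abs(a - b)))
```

naive is NaN in every entry, while stable is roughly [0.0900; 0.2447; 0.6652], the same values as a.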
-------------------------Off-topic interlude, 2024-------------------------------
I am currently looking for a job in the field of CV algorithm development, based in Shenzhen, Guangdong, China, or a remote support position. I would be very grateful if anyone is willing to offer me a job or make a recommendation. My preliminary resume can be found at: https://cuixing158.github.io/about/ . Thank you!
Email: cuixingxing150@gmail.com


Category

Parallel and Cloud

Release

R2023b
