How can I freeze specific neuron weights in a feedforward network during training?

6 views (last 30 days)
I am trying to train a simple feedforward network in which I need to freeze the weights and biases of certain neurons in a particular layer during training, so that those values do not change from epoch to epoch. I am aware that it is possible to freeze an entire layer, but I am not sure how to freeze the weights of only a few specific neurons.

Accepted Answer

Matt J on 5 Mar 2025
Edited: Matt J on 11 Mar 2025
I don't think you can do it without the Deep Learning Toolbox. If you have the Deep Learning Toolbox, you can define a custom layer with mask parameters that indicate which weights and biases are to be frozen:
classdef MaskedFullyConnectedLayer < nnet.layer.Layer
    %Defines a fully connected layer with masks indicating which weights and
    %biases are to be learned and which are to be "frozen". The gradients with
    %respect to frozen variables are always masked to zero.
    properties (Learnable)
        Weights
        Bias
    end
    properties
        MaskWeights
        MaskBias
        InputSize
        OutputSize
    end
    methods
        function layer = MaskedFullyConnectedLayer(inputSize, outputSize, maskWeights, maskBias, name)
            % Constructor for masked fully connected layer
            layer.Name = name;
            layer.Description = "Masked Fully Connected Layer with Frozen Weights/Bias";
            layer.InputSize = inputSize;
            layer.OutputSize = outputSize;
            % Initialize weights and biases
            layer.Weights = randn(outputSize, inputSize) * 0.01;
            layer.Bias = zeros(outputSize, 1);
            % Store masks, inverted so that 1 = trainable, 0 = frozen
            layer.MaskWeights = ~maskWeights;
            layer.MaskBias = ~maskBias;
        end
        function Y = predict(layer, X)
            Y = layer.Weights*X + layer.Bias;
        end
        function [dLdX, dLdW, dLdB] = backward(layer, X, ~, dLdY, ~)
            xsiz = size(X);
            X    = permute(X,[1,3,2]);
            dLdY = permute(dLdY,[1,3,2]);
            % Compute gradients
            dLdW = batchmtimes(dLdY,'none', X,'transpose'); % Gradient w.r.t. Weights
            dLdB = sum(dLdY,3);                             % Gradient w.r.t. Bias
            % Apply masks to prevent updates to frozen parameters
            dLdW = dLdW .* layer.MaskWeights;
            dLdB = dLdB .* layer.MaskBias;
            % Compute gradient w.r.t. input
            dLdX = reshape( pagemtimes(layer.Weights,'transpose', dLdY,'none') , xsiz);
        end
    end
end
function out = batchmtimes(X,transpX, Y,transpY)
    %Assumes X and Y have already been permuted with permute(__,[1,3,2])
    out = sum(pagemtimes(X,transpX, Y,transpY),3);
end
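Before wiring the layer into a network, you can sanity-check the masking by calling backward directly on a small instance. This is only a minimal sketch with plain double arrays and arbitrary sizes/masks; in an actual training run the inputs come from the framework, but the masked gradient entries should come out as zero either way:
% Sketch: verify that gradients for frozen parameters are masked to zero.
% Freeze all weights of output neuron 1 and the first bias element.
maskW = false(3,4);  maskW(1,:) = true;
maskB = false(3,1);  maskB(1)   = true;
L = MaskedFullyConnectedLayer(4, 3, maskW, maskB, "fc_test");
X    = randn(4, 8);   % 8 observations in columns
dLdY = randn(3, 8);   % upstream gradient, same shape as the layer output
[~, dLdW, dLdB] = backward(L, X, [], dLdY, []);
disp(dLdW(1,:))       % weight gradients of the frozen neuron: all zero
disp(dLdB(1))         % gradient of the frozen bias element: zero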
An example of usage (assuming the class above is saved as MaskedFullyConnectedLayer.m on the MATLAB path) would be:
% Define input and output sizes
inputSize = 10;
outputSize = 5;
% Define logical masks for the frozen variables (1 = frozen, 0 = trainable)
frozenWeights = false(outputSize,inputSize);
frozenWeights(1:3, :) = 1; % Freeze all weights feeding the first 3 output neurons
frozenBias = false(outputSize,1);
frozenBias(1:2) = 1;       % Freeze the first 2 bias elements
% Create the custom layer
maskedLayer = MaskedFullyConnectedLayer(inputSize, outputSize, frozenWeights, frozenBias, "MaskedFC");
maskedLayer.Weights(frozenWeights) = ___ ; % assign desired values to the frozen weights
maskedLayer.Bias(frozenBias) = ___ ;       % assign desired values to the frozen biases
% Add to a network (example with dummy layers)
layers = [
    featureInputLayer(inputSize)
    maskedLayer
    reluLayer
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer
    ];
analyzeNetwork(layers);
% Create dummy data for training
XTrain = randn(100, inputSize);
YTrain = categorical(randi([1, 2], 100, 1));
% Train the network
options = trainingOptions('adam', 'MaxEpochs', 5, 'Verbose', true);
net = trainNetwork(XTrain, YTrain, layers, options);
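After training, you can check that the freeze held by comparing the frozen entries of the trained layer to their starting values. A minimal sketch follows, assuming the masked layer is the second element of net.Layers as in the example above; note that a nonzero L2Regularization in trainingOptions (its default is 1e-4) is applied after the custom backward and so may still nudge the frozen parameters slightly; set it to 0 if an exact freeze is required.
% Sketch: the frozen entries should match their initial values after training.
trainedLayer = net.Layers(2);   % assumes the masked layer is the 2nd layer, as above
maxFrozenWeightChange = max(abs(trainedLayer.Weights(frozenWeights) - maskedLayer.Weights(frozenWeights)));
maxFrozenBiasChange   = max(abs(trainedLayer.Bias(frozenBias)       - maskedLayer.Bias(frozenBias)));
fprintf('Max change in frozen weights: %g\n', maxFrozenWeightChange);
fprintf('Max change in frozen biases:  %g\n', maxFrozenBiasChange);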
5 comments
Matt J on 11 Mar 2025
I edited my post with a corrected version. Please try again.


More Answers (0)
