
Is it possible to quantize a projected NN?

Silvia on 17 May 2024
Commented: Silvia on 23 May 2024
I have trained an LSTM model to filter some signals (approximately following this example: https://it.mathworks.com/help/signal/ug/denoise-eeg-signals-using-differentiable-signal-processing-layers.html). Since the aim is to implement the NN on hardware, I then reduced the number of learnables using projection-based compression, which brought them down from 4.3k to 1.5k. Next, I improved the accuracy of the model by fine-tuning the projected LSTM. Finally, I would like to quantize the model to 8 bits. When I use the quantize function in MATLAB, I get the following error: "The class deep.internal.quantization.config.ConfigType has no Constant property or Static method named 'InputProjector'."
Is it possible to quantize a projected NN?
I'm fairly new to this area, so be kind :)
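For a rough sense of what 8-bit quantization would buy on hardware, the weight storage can be estimated from the learnable counts that appear in the script further down this thread (`totalLearnables = 4352 + 33` and `totalLearnablesProjected = 1537 + 33`). This is a back-of-the-envelope sketch that assumes the weights are currently stored in single precision (4 bytes each, the usual default for networks trained with `trainnet`) and ignores activations, quantization scale factors, and any per-layer overhead a real fixed-point implementation would add.

```latex
\begin{align*}
\text{unprojected, float32} &= 4385 \times 4\ \text{B} = 17540\ \text{B} \approx 17.5\ \text{kB}\\
\text{projected, float32}   &= 1570 \times 4\ \text{B} = 6280\ \text{B} \approx 6.3\ \text{kB}\\
\text{projected, int8}      &= 1570 \times 1\ \text{B} = 1570\ \text{B} \approx 1.6\ \text{kB}
\end{align*}
```

Under these assumptions, 8-bit quantization would shrink weight storage by a further factor of 4 on top of the reduction already obtained by projection.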
2 comments
Venu on 19 May 2024
The error you encountered suggests that the direct application of MATLAB's built-in quantization function to a projected NN might not be straightforward due to compatibility issues with custom or non-standard layer types introduced during the projection process. Can you share more details regarding the projection method that you implemented?
Silvia on 21 May 2024
Thank you @Venu!
Here is my code for training the LSTM model, building the LSTM model with a projected layer, fine-tuning the projected model and, finally, the quantization part:
%% Training LSTM Neural Network to Denoise Signals
numFeatures = 1; % One input channel and one output channel (single-channel sequence in, single-channel sequence out)
numHiddenUnits = 32; % Amount of information the layer remembers between time steps (the hidden state)
% If the number of hidden units is too large,
% the layer can overfit the training data
% Common practice: use a power of 2
layers = [
sequenceInputLayer(numFeatures)
lstmLayer(numHiddenUnits, OutputMode="sequence")
%dropoutLayer(0.5) % to reduce overfitting (good practice: 0.2 to 0.5)
fullyConnectedLayer(numFeatures)
];
maxEpochs = 10;
miniBatchSize = 32;
options = trainingOptions("adam", ...
Metrics="rmse", ...
MaxEpochs=maxEpochs, ...
MiniBatchSize=miniBatchSize, ...
InitialLearnRate=0.001, ...
GradientThreshold=1, ...
Plots="training-progress", ...
Shuffle="every-epoch", ...
Verbose=false, ...
ValidationData=ds_Validate, ...
ValidationFrequency=floor(trainsetSize/miniBatchSize), ...
OutputNetwork="best-validation-loss");
%trainingFlag = "Train networks";
trainingFlag = "Download networks";
if trainingFlag == "Train networks"
rawNet9 = trainnet(ds_Train,layers,"mse",options);
save rawNet9.mat rawNet9
else
% Download pre-trained networks
load rawNet9.mat rawNet9
end
network = rawNet9;
analyzeNetwork(network)
totalLearnables = 4352 + 33; % LSTM layer (4352) + fully connected layer (33)
%% LSTM NN Compression via Projection
% - keep the same number of hidden units
% - set the output projector size to 25% of the number of hidden units
% - set the input projector size to 75% of the input size
% - take the max with 1 so that both projector sizes are positive
%   (with numFeatures = 1, floor(0.75) = 0, so the input projector size is 1)
outputProjectorSize = max(1, floor(0.25*numHiddenUnits));
inputProjectorSize = max(1, floor(0.75*numFeatures));
layersProjected = [ ...
sequenceInputLayer(numFeatures)
lstmProjectedLayer(numHiddenUnits,outputProjectorSize,inputProjectorSize)
fullyConnectedLayer(numFeatures)];
%trainingFlag = "Train networks";
trainingFlag = "Download networks";
if trainingFlag == "Train networks"
rawNet9Projected = trainnet(ds_Train,layersProjected,"mse",options);
save rawNet9Projected.mat rawNet9Projected
else
% Download pre-trained networks
load rawNet9Projected.mat rawNet9Projected
end
analyzeNetwork(rawNet9Projected)
totalLearnablesProjected = 1537 + 33; % projected LSTM layer (1537) + fully connected layer (33)
figure
bar([totalLearnables totalLearnablesProjected])
xticklabels(["Unprojected","Projected"])
xlabel("Network")
ylabel("Number of Learnables")
title("Number of Learnables")
%% Fine-Tune Compressed Network
optionsFineTuning = trainingOptions("adam", ...
Metrics="rmse", ...
MaxEpochs=30, ...
MiniBatchSize=miniBatchSize, ...
InitialLearnRate=0.0005, ...
GradientThreshold=1, ...
Plots="training-progress", ...
Shuffle="every-epoch", ...
Verbose=false, ...
ValidationData=ds_Validate, ...
ValidationFrequency=floor(trainsetSize/miniBatchSize), ...
OutputNetwork="best-validation-loss");
%trainingFlag = "Train networks";
trainingFlag = "Download networks";
if trainingFlag == "Train networks"
% Fine-tune starting from the trained projected network
rawNet9FineTuned = trainnet(ds_Train,rawNet9Projected,"mse",optionsFineTuning);
save rawNet9FineTuned.mat rawNet9FineTuned
else
% Download pre-trained networks
load rawNet9FineTuned.mat rawNet9FineTuned
end
%% Quantize Projected Network
% Create a dlquantizer object for the network to be quantized
quantObjNetwork = dlquantizer(rawNet9FineTuned, 'ExecutionEnvironment','MATLAB');
% Use the calibrate function to exercise the network with the calibration
% data and collect range statistics for the weights, biases, and
% activations at each layer
calResults = calibrate(quantObjNetwork, arrdsCal);
quantizedNetwork = quantize(quantObjNetwork); % this call raises the 'InputProjector' ConfigType error
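The learnable counts hard-coded in the script (`4352 + 33` and `1537 + 33`) can be reproduced by hand, which also shows where projection saves parameters. This sketch assumes the factorization described in the `lstmProjectedLayer` documentation, where a weight matrix W is replaced by the stored pair WQ and Q, with W the input weights, R the recurrent weights, b the bias, H = 32 hidden units, C = 1 input channel, output projector size P_o = 8 and input projector size P_i = 1 (the values the script computes).

```latex
\begin{align*}
\text{LSTM:}\quad
  &\underbrace{4HC}_{W} + \underbrace{4H^2}_{R} + \underbrace{4H}_{b}
   = 128 + 4096 + 128 = 4352\\
\text{Projected LSTM:}\quad
  &\underbrace{4HP_i}_{WQ_i} + \underbrace{CP_i}_{Q_i}
   + \underbrace{4HP_o}_{RQ_o} + \underbrace{HP_o}_{Q_o} + \underbrace{4H}_{b}\\
  &= 128 + 1 + 1024 + 256 + 128 = 1537\\
\text{Fully connected:}\quad
  &\underbrace{H \cdot 1}_{\text{weights}} + \underbrace{1}_{\text{bias}} = 33
\end{align*}
```

All of the saving comes from the recurrent weights (4096 becomes 1024 + 256). Because `numFeatures = 1`, `inputProjectorSize` evaluates to `max(1, 0) = 1`, so the input side actually gains one learnable (the 1-by-1 projector) rather than shrinking. The fully connected layer is unchanged in both networks, consistent with both totals ending in `+ 33`.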


Accepted Answer

Katja Mogalle on 23 May 2024
Unfortunately, it is currently not possible to quantize an LSTM or a projected LSTM layer. I'll mention your request to the development team.
At the moment, the only compression technique available for LSTM layers is projection, as you have already discovered. If you are trying to reduce the size of the network, you could also try a GRU layer instead of an LSTM layer. GRU layers have fewer parameters than LSTM layers (three sets of gate weights instead of four), and you can apply projection to them as well.
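To put rough numbers on the GRU suggestion, the same parameter bookkeeping can be applied to a GRU, which has three gate weight blocks where the LSTM has four. This sketch keeps the sizes from the question's script (H = 32 hidden units, C = 1 input channel, output projector size P_o = 8, input projector size P_i = 1), writes W, R and b for the input weights, recurrent weights and bias, and Q_i, Q_o for the projectors. It assumes a GRU with a single bias vector (some GRU variants add a separate 3H recurrent bias, which would add 96 learnables) and assumes that `gruProjectedLayer` applies the same "store WQ and Q" factorization as `lstmProjectedLayer`.

```latex
\begin{align*}
\text{GRU:}\quad
  &\underbrace{3HC}_{W} + \underbrace{3H^2}_{R} + \underbrace{3H}_{b}
   = 96 + 3072 + 96 = 3264 \quad (\text{LSTM: } 4352)\\
\text{Projected GRU:}\quad
  &\underbrace{3HP_i}_{WQ_i} + \underbrace{CP_i}_{Q_i}
   + \underbrace{3HP_o}_{RQ_o} + \underbrace{HP_o}_{Q_o} + \underbrace{3H}_{b}\\
  &= 96 + 1 + 768 + 256 + 96 = 1217 \quad (\text{projected LSTM: } 1537)
\end{align*}
```

Under these assumptions the gate-weight and bias terms scale by 3/4, while the output projector term H P_o = 256 stays the same, because it depends only on the hidden size and the projector size.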
Hope that helps.
1 comment
Silvia on 23 May 2024
Thank you very much @Katja Mogalle for passing on the request and for the suggestion.
I am not familiar with GRU layers but I am very curious to try this approach!

