How to use trainlm with L2 regularization

Hongyun Wang on 31 Aug 2023
Commented: Shivansh on 11 Sep 2023
As a simple test problem, I am training a neural network with 1 hidden layer for function fitting. The training method "trainlm" works well when I set
net.performParam.regularization=0.
To prevent overfitting (among other purposes), I would like to introduce L2 regularization. However, when I set
net.performParam.regularization=1e-6 (or any other positive number),
the training stopped at iteration 3 with "Maximum MU reached".
Can we use trainlm with L2 regularization at all?
Hongyun Wang on 10 Sep 2023
Hi Shivansh,
Thank you for your reply! I tried the example below, and again it terminated with 'Maximum MU reached'. If I comment out 'return' in the code below, the rest of the code first trains the network without regularization and then trains it with regularization. We see that the performance goal of 1e-3 is then reached in the presence of regularization. My question is how to use trainlm with regularization to reach the specified performance goal directly. Regularization is supposed to make the optimization problem better defined and more robust (the Hessian more positive definite).
Hongyun
% Generate training data
numSamples = 10000; % Number of training samples
X = -5+10*rand(1,numSamples); % input data
Y = sin(X); % output targets
% Create a feedforward neural network
hiddenSize = 256; % Number of hidden units
net = feedforwardnet(hiddenSize);
% Use Levenberg-Marquardt training
net.trainFcn='trainlm';
net.divideMode='none'; % Use all data for training
% Set up the training parameters
net.trainParam.epochs = 1000; % Maximum number of epochs
net.trainParam.goal=1e-3;
net.trainParam.min_grad=1e-6;
net.trainParam.showCommandLine = true; % Display training progress in command window
net.trainParam.showWindow = false; % Do not show training GUI
% Train the network with regularization
net.performParam.regularization = 1e-6;
net = train(net, X, Y);
% Evaluate the trained network on the training data
Y_pred = net(X);
mseTrain = mean((Y_pred - Y).^2); % named mseTrain so it does not shadow the built-in mse function
% Display the mean squared error
disp(['Mean Squared Error: ', num2str(mseTrain)]);
%
return
%%
% First train the network without regularization
net.performParam.regularization = 0;
net = train(net, X, Y);
% Then train again with regularization, starting from these weights; the performance goal is now reached
net.performParam.regularization = 1e-6;
net = train(net, X, Y);
% Evaluate the trained network on the training data
Y_pred = net(X);
mseTrain = mean((Y_pred - Y).^2); % named mseTrain so it does not shadow the built-in mse function
% Display the mean squared error
disp(['Mean Squared Error: ', num2str(mseTrain)]);
%
Shivansh on 11 Sep 2023
Hi Hongyun,
The code above can run with regularization, but the parameters need to be chosen consistently with each other. Some possible workarounds are:
  • Decrease the complexity of the model (it works with hiddenSize = 32).
  • Decrease the strength of the regularization (it works with net.performParam.regularization = 1e-7).
These changes let the code run, but they may not lead to optimal results.
You get results when you train first without regularization and then with regularization because the first training moves the weights close to a good solution and the second training then refines it.
Another way to resolve the problem with similar parameters is to run the regularized training first. When you train the network with regularization first, the algorithm reaches the maximum mu value and terminates prematurely, as you mentioned. However, when you then train the network without regularization, it starts from the weights obtained in the previous training and continues the optimization. Because the network already starts from weights close to the optimal solution, the unregularized training can further improve the performance and reach the desired goal.


Answers (1)

Ashu on 6 Sep 2023
Hey Wang,
I understand that you are training a network with "trainlm" and the training stops with "Maximum MU reached". The problem with regularization is that it is difficult to determine the optimum value for the regularization parameter (net.performParam.regularization). If the value is too small, it does little to prevent overfitting; if it is too large, the network does not adequately fit the training data.
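Because a good value is hard to guess, one common approach (not part of this answer) is to try a few candidate values and compare the error on data held out from training. The sketch below assumes the Deep Learning Toolbox; the candidate values, the 2000-sample sin data set, the 400-sample validation subset and hiddenSize = 32 are illustrative assumptions, not values from the thread.

```matlab
% Sketch: choose net.performParam.regularization by held-out error.
% Assumed (illustrative): candidate values, data size, split size, hiddenSize = 32.
rng(0);                                 % fix the random seed for data and initial weights
X = -5 + 10*rand(1, 2000);              % inputs, uniform on [-5, 5]
Y = sin(X);                             % targets
isVal = false(1, numel(X));
isVal(1:400) = true;                    % inputs are random, so this is a random subset

candidates = [0 1e-7 1e-6 1e-5];        % regularization values to compare
valErr = zeros(size(candidates));
for k = 1:numel(candidates)
    net = feedforwardnet(32, 'trainlm');
    net.divideMode = 'none';            % train on all non-validation samples
    net.performParam.regularization = candidates(k);
    net.trainParam.showWindow = false;  % no training GUI
    net = train(net, X(~isVal), Y(~isVal));
    valErr(k) = mean((net(X(isVal)) - Y(isVal)).^2);   % held-out mean squared error
end

[~, best] = min(valErr);
fprintf('regularization = %g: validation MSE = %g\n', [candidates; valErr]);
fprintf('lowest validation error at regularization = %g\n', candidates(best));
```

Some runs may still stop with "Maximum MU reached"; the loop simply records the held-out error of whatever network training returns. With noise-free targets like these, the comparison may well favor little or no regularization.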
The following suggestions might help you in training the network better.
1. To resolve this issue, you can experiment with the training parameters of "trainlm", such as increasing "net.trainParam.mu_max" or "net.trainParam.mu_dec".
For the full list of parameters, please refer to the "trainlm" documentation.
2. Use automated regularization (trainbr): the weights and biases of the network are assumed to be random variables with specified distributions. The regularization parameters are related to the unknown variances associated with these distributions, and you can estimate these parameters using statistical techniques.
Please refer to the "trainbr" documentation to learn more about this.
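Both suggestions amount to a few property settings. A minimal configuration sketch, assuming the Deep Learning Toolbox: mu_dec = 0.01 and mu_inc = 2 are the values reported later in this thread, while mu_max = 1e12 is an illustrative assumption. The comments give the trainlm defaults as documented (mu_dec 0.1, mu_inc 10, mu_max 1e10).

```matlab
% Option 1 (sketch): change how trainlm adapts its damping parameter mu.
net = feedforwardnet(256, 'trainlm');
net.divideMode = 'none';
net.performParam.regularization = 1e-6;
net.trainParam.mu_dec = 0.01;   % mu is multiplied by this after a step that lowers performance (default 0.1)
net.trainParam.mu_inc = 2;      % mu is multiplied by this while steps fail to lower it (default 10)
net.trainParam.mu_max = 1e12;   % assumed value; training stops with "Maximum MU reached" above this (default 1e10)

% Option 2 (sketch): Bayesian regularization, which estimates the balance
% between the error term and the weight term during training.
netBR = feedforwardnet(256, 'trainbr');
netBR.divideMode = 'none';
```

Either network is then trained with train(net, X, Y) as in the question's script.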
I hope this was helpful.
Hongyun Wang on 8 Sep 2023
I reduced net.trainParam.mu_dec from 0.1 to 0.01 and net.trainParam.mu_inc from 10 to 2. It did work to some extent. What puzzles me is that the L2-regularization term should make the optimization problem better posed and more robust. It may prevent the NN from fitting the data properly, but it should not make the iteration diverge.
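The intuition in this comment can be checked against a textbook Levenberg-Marquardt step for the regularized objective. The derivation below assumes the performance is (1 - r) times the mean squared error plus r times the mean squared weights and biases, with r = net.performParam.regularization, N error terms e_i, n weights and biases w_j, and J the Jacobian of the errors with respect to the weights. It is a sketch of the standard formulation, not a statement of how trainlm computes its step.

```latex
F(\mathbf{w}) = \frac{1-r}{N}\sum_{i=1}^{N} e_i(\mathbf{w})^2 + \frac{r}{n}\sum_{j=1}^{n} w_j^2
  = \mathbf{z}(\mathbf{w})^\top \mathbf{z}(\mathbf{w}),
\qquad
\mathbf{z}(\mathbf{w}) = \begin{bmatrix} \sqrt{(1-r)/N}\,\mathbf{e}(\mathbf{w}) \\ \sqrt{r/n}\,\mathbf{w} \end{bmatrix},
\qquad
\mathbf{J}_z = \begin{bmatrix} \sqrt{(1-r)/N}\,\mathbf{J} \\ \sqrt{r/n}\,\mathbf{I} \end{bmatrix}

(\mathbf{J}_z^\top \mathbf{J}_z + \mu \mathbf{I})\,\Delta\mathbf{w} = -\mathbf{J}_z^\top \mathbf{z}
\;\Longleftrightarrow\;
\left( \frac{1-r}{N}\mathbf{J}^\top\mathbf{J} + \left(\frac{r}{n} + \mu\right)\mathbf{I} \right)\Delta\mathbf{w}
  = -\left( \frac{1-r}{N}\mathbf{J}^\top\mathbf{e} + \frac{r}{n}\mathbf{w} \right)
```

In this formulation the weight term only adds r/n to the diagonal, acting as extra damping. The matrix on the left is positive definite for any mu > 0, so the step is a descent direction for F, and as mu grows the step shrinks toward a short gradient step. Stopping with "Maximum MU reached" would then mean that no step reduced F even with heavy damping; possible explanations include being already at a point where no further decrease is measurable, or a step formed differently from this one. The thread does not settle which applies to trainlm.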
Hongyun Wang on 8 Sep 2023
Is it possible to monitor the SSE part and the SSW part of the performance during training? If not, how can we calculate the SSE and SSW parts from the trained net after training? Is there a function for this task, or do we need to compute them from the coefficients in the trained net?
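After training, both parts can be computed from the trained network. A minimal sketch, assuming the Deep Learning Toolbox and assuming the regularized performance is (1 - r)*mse + r*msw, where r = net.performParam.regularization and msw is the mean of the squared weights and biases (the documentation describes regularization as the proportion of performance attributed to weights and biases). The small data set, hiddenSize = 32 and r = 1e-7 are illustrative choices so that trainlm runs quickly. getwb collects all weights and biases into one vector, and perform returns the toolbox's own value for comparison.

```matlab
% Sketch: split the regularized performance into error (SSE) and weight (SSW) parts.
% Assumed (illustrative): data, hiddenSize = 32, regularization = 1e-7.
rng(0);
X = -5 + 10*rand(1, 1000);
Y = sin(X);
net = feedforwardnet(32, 'trainlm');
net.divideMode = 'none';
net.performParam.regularization = 1e-7;
net.trainParam.showWindow = false;
net = train(net, X, Y);

Yp  = net(X);                 % network outputs
err = Y - Yp;                 % errors (targets minus outputs)
wb  = getwb(net);             % all weights and biases as one column vector

SSE = sum(err(:).^2);         % sum of squared errors
SSW = sum(wb.^2);             % sum of squared weights and biases
r   = net.performParam.regularization;

% Assumed combination: (1 - r)*mean squared error + r*mean squared weights and biases
perfManual  = (1 - r)*SSE/numel(err) + r*SSW/numel(wb);
perfToolbox = perform(net, Y, Yp);    % uses net.performFcn and net.performParam

fprintf('SSE = %g, SSW = %g\n', SSE, SSW);
fprintf('manual = %g, perform() = %g\n', perfManual, perfToolbox);
```

If the two printed values disagree, the toolbox is combining or normalizing the terms differently from the assumed formula, and the manual split should be adjusted to match.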


Version

R2021a
