
lbfgsupdate

Update parameters using limited-memory BFGS (L-BFGS)

Since R2023a

    Description

    Update the network learnable parameters in a custom training loop using the limited-memory BFGS (L-BFGS) algorithm.

    The L-BFGS algorithm [1] is a quasi-Newton method that approximates the Broyden-Fletcher-Goldfarb-Shanno (BFGS) algorithm. Use the L-BFGS algorithm for small networks and data sets that you can process in a single batch.

    Note

    This function applies the L-BFGS optimization algorithm to update network parameters in custom training loops. To train a neural network with the trainnet function and the L-BFGS solver instead, use the trainingOptions function and set the solver to "lbfgs".


    [netUpdated,solverStateUpdated] = lbfgsupdate(net,lossFcn,solverState) updates the learnable parameters of the network net using the L-BFGS algorithm with the specified loss function and solver state. Use this syntax in a training loop to iteratively update a network defined as a dlnetwork object.

    [parametersUpdated,solverStateUpdated] = lbfgsupdate(parameters,lossFcn,solverState) updates the learnable parameters in parameters using the L-BFGS algorithm with the specified loss function and solver state. Use this syntax in a training loop to iteratively update the learnable parameters of a network defined as a function.

    ___ = lbfgsupdate(___,Name=Value) specifies additional options using one or more name-value arguments.
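The second syntax can be illustrated with a minimal sketch, assuming a hypothetical linear model defined as a function rather than as a dlnetwork object. The synthetic data, the parameter field names Weight and Bias, and the helper functions model and modelLoss are assumptions made for illustration only. The stopping checks reuse the solver state properties that appear in the example on this page.

```matlab
% Minimal sketch: fit Y = Weight*X + Bias with a model defined as a function.
% Data, field names, and helper functions are illustrative assumptions.
X = dlarray(linspace(-1,1,50));   % unformatted 1-by-50 predictors
T = 2*X + 1;                      % synthetic targets

% Learnable parameters stored in a structure of dlarray objects.
parameters.Weight = dlarray(0);
parameters.Bias = dlarray(0);

% lbfgsupdate expects a function of one input that returns [loss,gradients].
lossFcn = @(parameters) dlfeval(@modelLoss,parameters,X,T);

solverState = lbfgsState;
for iteration = 1:50
    [parameters,solverState] = lbfgsupdate(parameters,lossFcn,solverState);

    % Stop once the solver makes no further progress.
    if solverState.GradientsNorm < 1e-6 || ...
            solverState.StepNorm < 1e-6 || ...
            solverState.LineSearchStatus == "failed"
        break
    end
end

function [loss,gradients] = modelLoss(parameters,X,T)
    Y = model(parameters,X);
    loss = mean((Y - T).^2,"all");            % mean squared error
    gradients = dlgradient(loss,parameters);  % same structure as parameters
end

function Y = model(parameters,X)
    Y = parameters.Weight.*X + parameters.Bias;
end
```

Because the loss function calls dlgradient, it is wrapped in dlfeval, in the same way as the dlnetwork example on this page.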

    Examples


    Read the transmission casing data from the CSV file "transmissionCasingData.csv".

    filename = "transmissionCasingData.csv";
    tbl = readtable(filename,TextType="String");

    Convert the labels for prediction to categorical using the convertvars function.

    labelName = "GearToothCondition";
    tbl = convertvars(tbl,labelName,"categorical");

    To train a network using categorical features, first convert the categorical predictors to the categorical data type using the convertvars function. Specify a string array containing the names of all the categorical input variables.

    categoricalPredictorNames = ["SensorCondition" "ShaftCondition"];
    tbl = convertvars(tbl,categoricalPredictorNames,"categorical");

    Loop over the categorical input variables. For each variable, convert the categorical values to one-hot encoded vectors using the onehotencode function.

    for i = 1:numel(categoricalPredictorNames)
        name = categoricalPredictorNames(i);
        tbl.(name) = onehotencode(tbl.(name),2);
    end

    View the first few rows of the table.

    head(tbl)
        SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    SensorCondition    ShaftCondition    GearToothCondition
        ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ______________    __________________
    
        -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13             0    1             1    0          No Tooth Fault  
        -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12             0    1             1    0          No Tooth Fault  
          1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13             0    1             0    1          No Tooth Fault  
          1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13             0    1             0    1          No Tooth Fault  
          1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39             0    1             0    1          No Tooth Fault  
          1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39             0    1             0    1          No Tooth Fault  
          1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39             0    1             0    1          No Tooth Fault  
          1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39             0    1             0    1          No Tooth Fault  
    

    Extract the training data.

    predictorNames = ["SigMean" "SigMedian" "SigRMS" "SigVar" "SigPeak" "SigPeak2Peak" ...
        "SigSkewness" "SigKurtosis" "SigCrestFactor" "SigMAD" "SigRangeCumSum" ...
        "SigCorrDimension" "SigApproxEntropy" "SigLyapExponent" "PeakFreq" ...
        "HighFreqPower" "EnvPower" "PeakSpecKurtosis" "SensorCondition" "ShaftCondition"];
    XTrain = table2array(tbl(:,predictorNames));
    numInputFeatures = size(XTrain,2);

    Extract the targets and convert them to one-hot encoded vectors.

    TTrain = tbl.(labelName);
    TTrain = onehotencode(TTrain,2);
    numClasses = size(TTrain,2);

    Convert the predictors and targets to dlarray objects with format "BC" (batch, channel).

    XTrain = dlarray(XTrain,"BC");
    TTrain = dlarray(TTrain,"BC");

    Define the network architecture.

    numHiddenUnits = 16;
    
    layers = [
        featureInputLayer(numInputFeatures)
        fullyConnectedLayer(numHiddenUnits)
        layerNormalizationLayer
        reluLayer
        fullyConnectedLayer(numClasses)
        softmaxLayer];
    
    net = dlnetwork(layers);

    Define the modelLoss function, listed in the Model Loss Function section of the example. This function takes as input a neural network, input data, and targets. The function returns the loss and the gradients of the loss with respect to the network learnable parameters.

    The lbfgsupdate function requires a loss function with the syntax [loss,gradients] = f(net). Create a function handle that evaluates the modelLoss function using dlfeval and takes the network as its single input argument.

    lossFcn = @(net) dlfeval(@modelLoss,net,XTrain,TTrain);

    Initialize an L-BFGS solver state object with a maximum history size of 3 and an initial inverse Hessian approximation factor of 1.1.

    solverState = lbfgsState( ...
        HistorySize=3, ...
        InitialInverseHessianFactor=1.1);

    Train the network for a maximum of 200 iterations. Stop training early when the norm of the gradients or the norm of the step is smaller than 0.00001, or when the line search fails. Print the training loss at the first iteration and at every 10th iteration thereafter.

    maxIterations = 200;
    gradientTolerance = 1e-5;
    stepTolerance = 1e-5;
    
    iteration = 0;
    
    while iteration < maxIterations
        iteration = iteration + 1;
        [net, solverState] = lbfgsupdate(net,lossFcn,solverState);
    
        if iteration==1 || mod(iteration,10)==0
            fprintf("Iteration %d: Loss: %e\n",iteration,solverState.Loss);
        end
    
        if solverState.GradientsNorm < gradientTolerance || ...
                solverState.StepNorm < stepTolerance || ...
                solverState.LineSearchStatus == "failed"
            break
        end
    end
    Iteration 1: Loss: 9.343236e-01
    Iteration 10: Loss: 4.721475e-01
    Iteration 20: Loss: 4.678575e-01
    Iteration 30: Loss: 4.666964e-01
    Iteration 40: Loss: 4.665921e-01
    Iteration 50: Loss: 4.663871e-01
    Iteration 60: Loss: 4.662519e-01
    Iteration 70: Loss: 4.660451e-01
    Iteration 80: Loss: 4.645303e-01
    Iteration 90: Loss: 4.591753e-01
    Iteration 100: Loss: 4.562556e-01
    Iteration 110: Loss: 4.531167e-01
    Iteration 120: Loss: 4.489444e-01
    Iteration 130: Loss: 4.392228e-01
    Iteration 140: Loss: 4.347853e-01
    Iteration 150: Loss: 4.341757e-01
    Iteration 160: Loss: 4.325102e-01
    Iteration 170: Loss: 4.321948e-01
    Iteration 180: Loss: 4.318990e-01
    Iteration 190: Loss: 4.313784e-01
    Iteration 200: Loss: 4.311314e-01
    

    Model Loss Function

    The modelLoss function takes as input a neural network net, input data X, and targets T. The function returns the loss and the gradients of the loss with respect to the network learnable parameters.

    function [loss, gradients] = modelLoss(net, X, T)
    
    Y = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
    
    end

    Input Arguments


    net: Neural network, specified as a dlnetwork object.

    The function updates the Learnables property of the dlnetwork object. net.Learnables is a table with three variables:

    • Layer — Layer name, specified as a string scalar.

    • Parameter — Parameter name, specified as a string scalar.

    • Value — Parameter value, specified as a cell array containing a dlarray object.

    parameters: Learnable parameters, specified as a dlarray object, a numeric array, a cell array, a structure, or a table.

    If you specify parameters as a table, it must contain these variables:

    • Layer — Layer name, specified as a string scalar.

    • Parameter — Parameter name, specified as a string scalar.

    • Value — Parameter value, specified as a cell array containing a dlarray object.

    You can specify parameters as a container of learnable parameters for your network using a cell array, structure, or table, or using nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray objects or numeric values with the data type double or single.

    If parameters is a numeric array, then lossFcn must not use the dlgradient function.
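As a minimal sketch of the numeric-array case, the loss function below computes its gradient analytically, so it needs neither dlgradient nor dlfeval. The matrix A, the vector b, and the helper function quadraticLoss are hypothetical and chosen only to give a small convex problem.

```matlab
% Minimal sketch: minimize 0.5*||A*w - b||^2 with numeric parameters.
% A, b, and quadraticLoss are illustrative assumptions.
A = [3 1; 1 2];
b = [1; 1];
w = zeros(2,1);

% The handle is called directly, without dlfeval.
lossFcn = @(w) quadraticLoss(w,A,b);

solverState = lbfgsState;
for iteration = 1:50
    [w,solverState] = lbfgsupdate(w,lossFcn,solverState);

    if solverState.GradientsNorm < 1e-8 || ...
            solverState.StepNorm < 1e-8 || ...
            solverState.LineSearchStatus == "failed"
        break
    end
end

function [loss,gradients] = quadraticLoss(w,A,b)
    residual = A*w - b;
    loss = 0.5*sum(residual.^2);
    gradients = A'*residual;   % analytic gradient, no automatic differentiation
end
```

The minimizer of this problem is the solution of A*w = b, which gives a simple check on the result.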

    lossFcn: Loss function, specified as a function handle or an AcceleratedFunction object with the syntax [loss,gradients] = f(net), where loss and gradients correspond to the loss and gradients of the loss with respect to the learnable parameters, respectively.

    To parametrize a model loss function that has a call to the dlgradient function, specify the loss function as @(net) dlfeval(@modelLoss,net,arg1,...,argN), where modelLoss is a function with the syntax [loss,gradients] = modelLoss(net,arg1,...,argN) that returns the loss and gradients of the loss with respect to the learnable parameters in net given arguments arg1,...,argN.

    If parameters is a numeric array, then the loss function must not use the dlgradient or dlfeval functions.

    If the loss function has more than two outputs, also specify the NumLossFunctionOutputs argument.

    Data Types: function_handle
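One possible way to use an AcceleratedFunction object is to accelerate the model loss function with dlaccelerate and evaluate the accelerated function inside dlfeval. The sketch below assumes a small hypothetical network and random data. It is an illustration of the pattern, not a recommendation of a particular architecture.

```matlab
% Minimal sketch: accelerate the model loss function before using it
% with lbfgsupdate. Network size and random data are illustrative assumptions.
X = dlarray(rand(4,20),"CB");   % 4 features, 20 observations
T = dlarray(rand(1,20),"CB");   % 1 response, 20 observations

net = dlnetwork([
    featureInputLayer(4)
    fullyConnectedLayer(8)
    tanhLayer
    fullyConnectedLayer(1)]);

% dlaccelerate returns an AcceleratedFunction object.
accfun = dlaccelerate(@modelLoss);
lossFcn = @(net) dlfeval(accfun,net,X,T);

solverState = lbfgsState;
[net,solverState] = lbfgsupdate(net,lossFcn,solverState);

function [loss,gradients] = modelLoss(net,X,T)
    Y = forward(net,X);
    loss = mse(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end
```

Acceleration tends to pay off when the same loss function is evaluated many times with inputs of the same size, which is the typical situation in a full-batch L-BFGS loop.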

    solverState: Solver state, specified as an lbfgsState object or [].

    Name-Value Arguments

    Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

    Example: lbfgsupdate(net,lossFcn,solverState,LineSearchMethod="strong-wolfe") updates the learnable parameters in net and searches for a learning rate that satisfies the strong Wolfe conditions.

    LineSearchMethod: Method to find a suitable learning rate, specified as one of these values:

    • "weak-wolfe" — Search for a learning rate that satisfies the weak Wolfe conditions. This method maintains a positive definite approximation of the inverse Hessian matrix.

    • "strong-wolfe" — Search for a learning rate that satisfies the strong Wolfe conditions. This method maintains a positive definite approximation of the inverse Hessian matrix.

    • "backtracking" — Search for a learning rate that satisfies sufficient decrease conditions. This method does not maintain a positive definite approximation of the inverse Hessian matrix.

    MaxNumLineSearchIterations: Maximum number of line search iterations to determine the learning rate, specified as a positive integer.

    Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64

    NumLossFunctionOutputs: Number of loss function outputs, specified as an integer greater than or equal to two. Set this option when lossFcn has more than two output arguments.

    Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
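The name-value arguments can be combined in a single call. The following sketch assumes a hypothetical loss function that returns the network predictions as a third output, so it sets NumLossFunctionOutputs to 3 and also selects the backtracking line search. The network and the random data are placeholders.

```matlab
% Minimal sketch: name-value arguments with a three-output loss function.
% The network, the data, and the extra output Y are illustrative assumptions.
X = dlarray(rand(4,20),"CB");
T = dlarray(rand(1,20),"CB");

net = dlnetwork([
    featureInputLayer(4)
    fullyConnectedLayer(1)]);

lossFcn = @(net) dlfeval(@modelLoss,net,X,T);

solverState = lbfgsState;
[net,solverState] = lbfgsupdate(net,lossFcn,solverState, ...
    LineSearchMethod="backtracking", ...
    NumLossFunctionOutputs=3);

function [loss,gradients,Y] = modelLoss(net,X,T)
    Y = forward(net,X);                           % extra, third output
    loss = mse(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end
```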

    Output Arguments


    netUpdated: Updated network, returned as a dlnetwork object.

    The function updates the Learnables property of the dlnetwork object.

    parametersUpdated: Updated learnable parameters, returned as an object with the same type as parameters.

    solverStateUpdated: Updated solver state, returned as an lbfgsState object.

    Algorithms


    Limited-Memory BFGS

    The L-BFGS algorithm [1] is a quasi-Newton method that approximates the Broyden-Fletcher-Goldfarb-Shanno (BFGS) algorithm. Use the L-BFGS algorithm for small networks and data sets that you can process in a single batch.

    The algorithm updates learnable parameters $W$ at iteration $k+1$ using the update step given by

    $$W_{k+1} = W_k - \eta_k B_k^{-1} \nabla J(W_k),$$

    where $W_k$ denotes the weights at iteration $k$, $\eta_k$ is the learning rate at iteration $k$, $B_k$ is an approximation of the Hessian matrix at iteration $k$, and $\nabla J(W_k)$ denotes the gradients of the loss with respect to the learnable parameters at iteration $k$.

    The L-BFGS algorithm computes the matrix-vector product $B_k^{-1} \nabla J(W_k)$ directly. The algorithm does not require computing the inverse of $B_k$.

    To save memory, the L-BFGS algorithm does not store and invert the dense Hessian matrix $B$. Instead, the algorithm uses the approximation $B_{k-m}^{-1} \approx \lambda_k I$, where $m$ is the history size, the inverse Hessian factor $\lambda_k$ is a scalar, and $I$ is the identity matrix. The algorithm then stores the scalar inverse Hessian factor only. The algorithm updates the inverse Hessian factor at each step.

    To compute the matrix-vector product $B_k^{-1} \nabla J(W_k)$ directly, the L-BFGS algorithm uses this recursive algorithm:

    1. Set $r = B_{k-m}^{-1} \nabla J(W_k)$, where $m$ is the history size.

    2. For $i = m, \ldots, 1$:

      1. Let $\beta = \frac{1}{s_{k-i}^\top y_{k-i}} y_{k-i}^\top r$, where $s_{k-i}$ and $y_{k-i}$ are the step and gradient differences for iteration $k-i$, respectively.

      2. Set $r = r + s_{k-i}(a_{k-i} - \beta)$, where $a$ is derived from $s$, $y$, and the gradients of the loss with respect to the learnable parameters. For more information, see [1].

    3. Return $B_k^{-1} \nabla J(W_k) = r$.
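The recursion above can be sketched in plain MATLAB using the standard two-loop form from the L-BFGS literature [1]. This is an illustrative sketch, not the toolbox implementation. The function name twoLoopRecursion and the inputs g, S, Y, and lambda are hypothetical. The sketch includes a first loop, assumed here to be how the coefficients a are derived from the stored step and gradient differences, before it applies the scalar initial approximation and runs the loop described in the steps above.

```matlab
% Check the recursion against the explicit BFGS inverse update for one pair.
s = [1; 0.5];        % step difference
y = [2; 1.5];        % gradient difference
g = [1; -1];         % current gradient
lambda = 0.7;        % scalar inverse Hessian factor

r = twoLoopRecursion(g,s,y,lambda);

rho = 1/(y'*s);
H = (eye(2) - rho*(s*y'))*(lambda*eye(2))*(eye(2) - rho*(y*s')) + rho*(s*s');
discrepancy = norm(r - H*g);   % should be at rounding-error level

function r = twoLoopRecursion(g,S,Y,lambda)
% S(:,j) and Y(:,j) hold step and gradient differences, oldest pair first.
m = size(S,2);
rho = zeros(m,1);
a = zeros(m,1);

% First loop (newest to oldest): compute the coefficients a.
q = g;
for j = m:-1:1
    rho(j) = 1/(S(:,j)'*Y(:,j));
    a(j) = rho(j)*(S(:,j)'*q);
    q = q - a(j)*Y(:,j);
end

% Apply the initial approximation lambda*I.
r = lambda*q;

% Second loop (oldest to newest): the update with beta described above.
for j = 1:m
    beta = rho(j)*(Y(:,j)'*r);
    r = r + S(:,j)*(a(j) - beta);
end
end
```

The recursion only needs inner products and vector updates on the stored pairs, so the cost per iteration grows linearly with the history size and the number of parameters, and no dense matrix is formed.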

    References

    [1] Liu, Dong C., and Jorge Nocedal. "On the limited memory BFGS method for large scale optimization." Mathematical programming 45, no. 1 (August 1989): 503-528. https://doi.org/10.1007/BF01589116.


    Version History

    Introduced in R2023a