
Create Custom Deep Learning Training Plot

Since R2023b

This example shows how to create a custom training plot that updates at each iteration during training of deep learning neural networks using trainnet.

You can specify neural network training options using trainingOptions. You can create training plots using the "Plots" and "Metrics" name-value pair arguments. To create a custom training plot and further customize training beyond the options available in trainingOptions, specify output functions by using the "OutputFcn" name-value pair argument of trainingOptions. trainnet calls these functions once before the start of training, after each training iteration, and once after training has finished.

Each time the output functions are called, trainnet passes a structure containing information such as the current iteration number, loss, and accuracy.
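The stopping mechanism can be tried without training anything. The following is a minimal sketch, not part of the original example: `stopAtIteration` and `maxIterations` are hypothetical names, and the structures passed to it are hand-built mocks that contain only an `Iteration` field. Training stops when an output function returns `true`.

```matlab
% Hypothetical output function: request a stop once the iteration count
% reaches maxIterations (an illustrative value, not from the example).
maxIterations = 50;
stopAtIteration = @(info) ~isempty(info.Iteration) && info.Iteration >= maxIterations;

% Call it with hand-built structures that mimic the Iteration field of
% the info structure (mock values, no training involved).
stopEarly = stopAtIteration(struct('Iteration',10))   % false: keep training
stopLate = stopAtIteration(struct('Iteration',50))    % true: stop training
```

With trainnet, such a handle would be passed as `OutputFcn=stopAtIteration` in trainingOptions. The `~isempty` check makes the function return `false` when a field has no value yet.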

The network trained in this example classifies the gear tooth condition of a transmission system into two categories, "Tooth Fault" and "No Tooth Fault", based on a mixture of numeric sensor readings, statistics, and categorical labels. For more information, see Train Neural Network with Tabular Data.

The custom output function defined in this example plots the natural logarithms of the gradient norm, step norm, training loss, and validation loss during training, and stops training early once the training loss is lower than the desired loss threshold.

[Figure: custom training plot]

Load and Preprocess Training Data

Read the transmission casing data from the CSV file "transmissionCasingData.csv".

filename = "transmissionCasingData.csv";
tbl = readtable(filename,TextType="string");

Convert the labels for prediction and the categorical predictors to categorical arrays using the convertvars function. In this data set, there are two categorical features, "SensorCondition" and "ShaftCondition".

labelName = "GearToothCondition";
categoricalPredictorNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,[labelName categoricalPredictorNames],"categorical");
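To see what convertvars does to a single variable, here is a small sketch on a toy table. The table `toyTbl` and its values are hypothetical and unrelated to the transmission data set.

```matlab
% Toy table with hypothetical values: the string variable Mode becomes a
% categorical variable; the numeric variable Value is left unchanged.
toyTbl = table(["on";"off";"on"],[1;2;3],VariableNames=["Mode" "Value"]);
toyTbl = convertvars(toyTbl,"Mode","categorical");
class(toyTbl.Mode)          % 'categorical'
categories(toyTbl.Mode)     % {'off';'on'}, sorted alphabetically
```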

To train a network using categorical features, you must convert the categorical features to numeric. You can do this using the onehotencode function.

for i = 1:numel(categoricalPredictorNames)
    name = categoricalPredictorNames(i);
    tbl.(name) = onehotencode(tbl.(name),2);
end
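After this loop, each categorical predictor is stored as a multi-column numeric variable, with one column per category. The following sketch shows the encoding on a hypothetical categorical column `level`; it requires Deep Learning Toolbox for onehotencode.

```matlab
% Toy categorical column with hypothetical categories.
level = categorical(["low";"high";"low"]);
categories(level)                 % {'high';'low'}
% featureDim = 2 places the encoding along the columns: one row per
% observation, one column per category, in the order of categories(level).
encoded = onehotencode(level,2)   % [0 1; 1 0; 0 1]
```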

Set aside data for testing. Partition the data into a training set containing 80% of the data, a validation set containing 10% of the data, and a test set containing the remaining 10% of the data. To partition the data, use the trainingPartitions function, attached to this example as a supporting file. To access this file, open the example as a live script.

numObservations = size(tbl,1);
[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.80 0.1 0.1]);

tblTrain = tbl(idxTrain,:);
tblValidation = tbl(idxValidation,:);
tblTest = tbl(idxTest,:);
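If the trainingPartitions supporting file is not available, a random 80/10/10 split can be sketched with randperm. This is an assumption about the kind of split the helper performs, not its actual implementation; `numObs` and the `...Sketch` index names are hypothetical.

```matlab
% Hypothetical stand-in for the trainingPartitions supporting file:
% shuffle the observation indices and split them 80/10/10.
numObs = 200;                          % example size, not the real data
rng(0)                                 % reproducible shuffle
idx = randperm(numObs);
numTrainObs = floor(0.8*numObs);
numValObs = floor(0.1*numObs);
idxTrainSketch = idx(1:numTrainObs);
idxValSketch = idx(numTrainObs+1:numTrainObs+numValObs);
idxTestSketch = idx(numTrainObs+numValObs+1:end);   % remaining observations
```

The three index sets are disjoint and together cover every observation; any rounding remainder goes to the test set.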

Convert the data to a format that the trainnet function supports. Convert the predictors to a numeric array using the table2array function, and extract the targets as a categorical array.

predictorNames = ["SigMean" "SigMedian" "SigRMS" "SigVar" "SigPeak" "SigPeak2Peak" ...
    "SigSkewness" "SigKurtosis" "SigCrestFactor" "SigMAD" "SigRangeCumSum" ...
    "SigCorrDimension" "SigApproxEntropy" "SigLyapExponent" "PeakFreq" ...
    "HighFreqPower" "EnvPower" "PeakSpecKurtosis" "SensorCondition" "ShaftCondition"];

XTrain = table2array(tblTrain(:,predictorNames));
TTrain = tblTrain.(labelName);

XValidation = table2array(tblValidation(:,predictorNames));
TValidation = tblValidation.(labelName);

XTest = table2array(tblTest(:,predictorNames));
TTest = tblTest.(labelName);
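Because the one-hot-encoded predictors are multi-column variables, table2array expands each of them into several columns. This is why the number of features is later taken from `size(XTrain,2)` rather than from `numel(predictorNames)`. A sketch on a hypothetical toy table:

```matlab
% Toy table (hypothetical values): one scalar variable and one two-column
% variable, like a one-hot-encoded predictor.
toy = table([1.5;2.5],[1 0;0 1],VariableNames=["Reading" "Encoded"]);
toyArray = table2array(toy)
% table2array concatenates the variables horizontally, so the two-column
% variable contributes two columns: toyArray is 2-by-3.
```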

Network Architecture

Define the neural network architecture.

  • For feature input, specify a feature input layer with the number of features. Normalize the input using Z-score normalization.

  • Specify a fully connected layer with a size of 16, followed by a layer normalization layer and a ReLU layer.

  • For classification output, specify a fully connected layer with a size that matches the number of classes, followed by a softmax layer.

numFeatures = size(XTrain,2);
hiddenSize = 16;
classNames = categories(tbl{:,labelName});
numClasses = numel(classNames);

layers = [
    featureInputLayer(numFeatures,Normalization="zscore")
    fullyConnectedLayer(hiddenSize)
    layerNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
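To get a sense of how small this network is, you can count its learnable parameters by hand. The sketch below assumes 22 input features, which would hold if each of the two categorical predictors had two categories; that value is hypothetical, so check `size(XTrain,2)` for the real count. You can compare the result with, for example, `analyzeNetwork(layers)`.

```matlab
% Learnable-parameter count for the layer stack above, assuming a
% hypothetical numFeatures of 22 (18 numeric + 2 two-category predictors).
numFeaturesAssumed = 22;
hiddenSizeSketch = 16;
numClassesSketch = 2;                                           % two gear tooth conditions
fc1 = numFeaturesAssumed*hiddenSizeSketch + hiddenSizeSketch;   % weights + bias
layerNorm = 2*hiddenSizeSketch;                                 % scale + offset per channel
fc2 = hiddenSizeSketch*numClassesSketch + numClassesSketch;     % weights + bias
numLearnables = fc1 + layerNorm + fc2
```

The input normalization statistics are not counted because they are not learnable.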

Initialize Custom Plot

The next step in training your neural network is to define the training options using trainingOptions, including your custom output function. To create an animated plot that updates during training, first set up a figure with animatedline objects, and then pass the line handles to your output function.

Create a 3-by-1 tiled chart layout. Define animatedline objects to plot gradientNorm in the top tile and stepNorm in the middle tile. Define animatedline objects to plot trainingLoss and validationLoss in the bottom tile. Save the animatedline handles to a struct called lines.

tiledlayout(3,1);

C = colororder;

nexttile
lines.gradientNormLine = animatedline(Color=C(1,:));
ylabel("log(gradientNorm)")
ylim padded

nexttile
lines.stepNormLine = animatedline(Color=C(1,:));
ylabel("log(stepNorm)")
ylim padded

nexttile
lines.trainingLossLine = animatedline(Color=C(1,:));
lines.validationLossLine = animatedline(Color=C(2,:));
xlabel("Iterations")
ylabel("log(loss)")
ylim padded
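The output function only ever appends points to these lines. The following off-screen sketch shows that pattern in isolation: addpoints appends a point, drawnow limitrate renders pending updates, and getpoints reads the stored data back. The figure `fig`, the line `demoLine`, and the loss values are hypothetical.

```matlab
% Off-screen demo of the addpoints pattern used by the output function.
% The loss values are hypothetical.
fig = figure(Visible="off");
demoLine = animatedline;
demoLoss = [0.9 0.6 0.45];
for k = 1:numel(demoLoss)
    addpoints(demoLine,k,log(demoLoss(k)))   % append one point per iteration
    drawnow limitrate                        % render pending updates
end
[xDemo,yDemo] = getpoints(demoLine);         % read the stored points back
close(fig)
```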

Define Training Options

Use the function updatePlotAndStopTraining defined at the bottom of this page to update the animatedline objects and to stop training early when the training loss is smaller than a desired loss threshold. Use the "OutputFcn" name-value pair argument of trainingOptions to pass this function to trainnet.

Specify the training options:

  • Train using the L-BFGS solver. This solver is well suited to small networks and to data sets that fit in memory.

  • Train using the CPU. Because the network and data are small, the CPU is better suited.

  • Validate the network every 5 iterations using the validation data.

  • Suppress the verbose output.

  • Include the custom output function updatePlotAndStopTraining.

Define the loss threshold.

lossThreshold = 0.4;
options = trainingOptions("lbfgs", ...
    ExecutionEnvironment="cpu", ...
    ValidationData={XValidation,TValidation}, ...
    ValidationFrequency=5, ...
    Verbose=false, ...
    OutputFcn=@(info)updatePlotAndStopTraining(info,lines,lossThreshold));
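The anonymous function passed to OutputFcn captures the values of `lines` and `lossThreshold` at the moment `options` is created. Because the animatedline objects are handles, the output function still updates the lines in the figure, but changing `lossThreshold` afterward has no effect unless you recreate the options. A minimal sketch of this capture behavior, using the hypothetical names `thresholdSketch` and `isBelow`:

```matlab
% Anonymous functions capture variable values when they are created.
% thresholdSketch is a hypothetical variable used only for this demo.
thresholdSketch = 0.4;
isBelow = @(loss) loss < thresholdSketch;   % captures 0.4 here
thresholdSketch = 0.1;                      % later change is not seen
isBelow(0.3)                                % true: compares against 0.4
```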

Add the gradient tolerance to the top plot, the step tolerance to the middle plot, and the loss threshold to the bottom plot.

nexttile(1)
yline(log(options.GradientTolerance),"--","log(gradientTolerance)")

nexttile(2)
yline(log(options.StepTolerance),"--","log(stepTolerance)")

nexttile(3)
yline(log(lossThreshold),"--","log(lossThreshold)")
legend(["log(trainingLoss)","log(validationLoss)",""],"Location","eastoutside")
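Plotting on a logarithmic scale does not change where training stops: because the natural logarithm is strictly increasing, the training loss curve falls below the dashed line exactly when the loss itself falls below the threshold. A quick check with hypothetical loss values:

```matlab
% Hypothetical losses compared against a threshold of 0.4, both directly
% and after taking logarithms, as in the bottom tile.
thresholdDemo = 0.4;
losses = [0.7 0.41 0.39];
belowRaw = losses < thresholdDemo            % [false false true]
belowLog = log(losses) < log(thresholdDemo)  % same result
```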

Train Neural Network

Train the network. To display the reason why training stops, use two output arguments. If training stops because the training loss is smaller than the loss threshold, then the field StopReason of the info output argument is "Stopped by OutputFcn".

[net,info] = trainnet(XTrain,TTrain,layers,"crossentropy",options);

disp(info.StopReason)
Stopped by OutputFcn

Test Network

Predict the classification scores for the test data using the trained network, and then convert the scores to labels using the onehotdecode function.

scoresTest = predict(net,XTest);
YTest = onehotdecode(scoresTest,classNames,2);
accuracy = mean(YTest==TTest)
accuracy = 0.8636
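The scores returned by predict come from the final softmax layer, and onehotdecode assigns each observation the class with the highest score. The sketch below reproduces both steps on hypothetical raw outputs `Z` with hypothetical class names "A" and "B"; the softmax is written out by hand for illustration.

```matlab
% Hypothetical raw outputs of the last fully connected layer for three
% observations (rows) and two classes (columns).
Z = [2.0 0.5; -1.0 1.0; 0.3 0.1];
% Row-wise softmax; subtracting the row maximum avoids overflow without
% changing the result.
E = exp(Z - max(Z,[],2));
scoresDemo = E ./ sum(E,2);
% Decode each row to the class with the highest score.
labelsDemo = onehotdecode(scoresDemo,{'A';'B'},2)
```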

Custom Output Function

Define the output function updatePlotAndStopTraining(info,lines,lossThreshold), which plots the logarithm of gradient norm, step norm, training loss, and validation loss. It also stops training when the training loss is smaller than the loss threshold. Training stops when the output function returns true.

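function stop = updatePlotAndStopTraining(info,lines,lossThreshold)

% Read the quantities of interest from the info structure.
iteration = info.Iteration;
gradientNorm = info.GradientNorm;
stepNorm = info.StepNorm;
trainingLoss = info.TrainingLoss;
validationLoss = info.ValidationLoss;

% The training quantities can be empty, for example in the call before
% training starts.
if ~isempty(trainingLoss)
    addpoints(lines.gradientNormLine,iteration,log(gradientNorm))
    addpoints(lines.stepNormLine,iteration,log(stepNorm))
    addpoints(lines.trainingLossLine,iteration,log(trainingLoss))
end

% The validation loss is available only at validation iterations.
if ~isempty(validationLoss)
    addpoints(lines.validationLossLine,iteration,log(validationLoss))
end

% Render the new points, limiting the update rate.
drawnow limitrate

% Stop training once the training loss falls below the threshold.
% Return false (not an empty array) when no training loss is available.
stop = ~isempty(trainingLoss) && trainingLoss < lossThreshold;
end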
