
Make Predictions Using dlnetwork Object

This example shows how to make predictions using a dlnetwork object by looping over mini-batches.

For large data sets, or when predicting on hardware with limited memory, make predictions on mini-batches of the data rather than on the whole data set at once. The minibatchpredict function splits the data into mini-batches and loops over them for you.
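A minimal sketch of the same idea written by hand is shown here. It calls the predict function on one mini-batch at a time. It assumes a hypothetical, randomly initialized stand-in network with a 28-by-28-by-1 image input and 10 softmax outputs, random input data, and an arbitrary mini-batch size of 32. None of these come from the trained digits network used later in this example.

```matlab
% Hypothetical stand-in network: 28-by-28 grayscale input, 10 class scores out
layers = [
    imageInputLayer([28 28 1],Normalization="none")
    fullyConnectedLayer(10)
    softmaxLayer];
net = dlnetwork(layers);   % learnable parameters are initialized randomly

X = rand(28,28,1,100,"single");   % 100 random stand-in "images"
miniBatchSize = 32;                % assumed value; tune to the available memory
numObservations = size(X,4);

scores = zeros(numObservations,10,"single");   % one row of class scores per observation

for startIdx = 1:miniBatchSize:numObservations
    % Indices of the observations in this mini-batch (the last one may be smaller)
    batchIdx = startIdx:min(startIdx+miniBatchSize-1,numObservations);

    % Wrap the batch as a formatted dlarray: spatial, spatial, channel, batch
    dlX = dlarray(X(:,:,:,batchIdx),"SSCB");

    % Forward pass in prediction mode; the output has format "CB" (10-by-batch size)
    dlY = predict(net,dlX);

    % Store the scores with observations along the first dimension
    scores(batchIdx,:) = extractdata(dlY)';
end
```

Each call to predict processes only one mini-batch, so the memory needed for intermediate activations scales with the mini-batch size rather than with the size of the whole data set.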

Load dlnetwork Object

Load a trained dlnetwork object and the corresponding class names. The neural network has one input and two outputs. It takes images of handwritten digits as input and predicts the digit label and the angle of rotation.

load dlnetDigits
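To picture what a network with one input and two outputs looks like, the following sketch builds a hypothetical, randomly initialized stand-in network with one image input, a 10-class softmax branch, and a single-value regression branch. It then calls predict with the Outputs option to get both outputs in one call. The layer names and the architecture are assumptions for illustration only; they are not the architecture stored in dlnetDigits.

```matlab
% Hypothetical stand-in for a two-output network (not the trained dlnetDigits network)
layers = [
    imageInputLayer([28 28 1],Normalization="none",Name="in")
    fullyConnectedLayer(10,Name="fc_class")
    softmaxLayer(Name="softmax")];
lgraph = layerGraph(layers);

% Second branch: a single regression output fed directly from the input layer
lgraph = addLayers(lgraph,fullyConnectedLayer(1,Name="fc_angle"));
lgraph = connectLayers(lgraph,"in","fc_angle");

net = dlnetwork(lgraph);   % initializes the learnable parameters randomly

% Predict on a mini-batch of 5 random stand-in "images"
dlX = dlarray(rand(28,28,1,5,"single"),"SSCB");
[scores,angles] = predict(net,dlX,Outputs=["softmax","fc_angle"]);

disp(size(scores))   % 10 class scores per observation
disp(size(angles))   % 1 predicted value per observation
```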

Load Data for Prediction

Load the digits test data for prediction.

load DigitsDataTest

View the class names.

classNames
classNames = 10x1 cell
    {'0'}
    {'1'}
    {'2'}
    {'3'}
    {'4'}
    {'5'}
    {'6'}
    {'7'}
    {'8'}
    {'9'}

View some of the images and the corresponding labels and angles of rotation.

numObservations = size(XTest,4);
numPlots = 9;
idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = labelsTest(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + anglesTest(idx(i)))
end

Make Predictions

Make predictions using the minibatchpredict function. Because the network has two outputs, the function returns two arrays: the classification scores and the predicted angles of rotation. Convert the classification scores to labels using the scores2label function.

By default, the minibatchpredict function uses a GPU if one is available; otherwise, it uses the CPU. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To specify the execution environment, use the ExecutionEnvironment option.
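As an illustration of the ExecutionEnvironment option, the following sketch forces CPU execution and also sets the MiniBatchSize option, which controls how many observations are processed per iteration. It uses a hypothetical stand-in network and random data in place of net and XTest, and the mini-batch size of 16 is an arbitrary choice.

```matlab
% Hypothetical stand-in network: 28-by-28 grayscale input, 10 class scores out
layers = [
    imageInputLayer([28 28 1],Normalization="none")
    fullyConnectedLayer(10)
    softmaxLayer];
net = dlnetwork(layers);

X = rand(28,28,1,50,"single");        % 50 random stand-in "images"
classNames = cellstr(string(0:9)');   % class names "0" to "9" as a 10-by-1 cell array

% Run on the CPU and process 16 observations per mini-batch
scores = minibatchpredict(net,X, ...
    MiniBatchSize=16, ...
    ExecutionEnvironment="cpu");

% Convert the scores to categorical labels
labels = scores2label(scores,classNames);
```

Smaller mini-batches reduce peak memory use at the cost of more iterations; the predictions themselves should not depend on the mini-batch size.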

[scoresTest,Y2Test] = minibatchpredict(net,XTest);
Y1Test = scores2label(scoresTest,classNames);
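For single-label classification, scores2label assigns each observation the class with the highest score. The sketch below uses a small made-up score matrix (hypothetical values, one row per observation and one column per class) to show that taking the row-wise maximum by hand gives the same labels.

```matlab
% Hypothetical scores for 4 observations (rows) and 3 classes (columns)
scores = [0.70 0.20 0.10
          0.10 0.30 0.60
          0.20 0.50 0.30
          0.05 0.15 0.80];
classNames = {'0';'1';'2'};

% Built-in conversion from scores to categorical labels
Y = scores2label(scores,classNames);

% Same idea by hand: pick the column with the largest score in each row
[~,k] = max(scores,[],2);
Ymanual = categorical(classNames(k),classNames);
```

Both approaches label the four observations as 0, 2, 1, and 2.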

Visualize some of the predictions.

idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = Y1Test(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + Y2Test(idx(i)))
end
