Perform Instance Segmentation Using SOLOv2

This example shows how to segment object instances of randomly rotated machine parts in a bin using a deep learning SOLOv2 network.

Instance segmentation is a computer vision technique in which you detect and localize objects while simultaneously generating a segmentation map for each of the detected instances. For more information about instance segmentation with SOLOv2, see Get Started with SOLOv2 for Instance Segmentation.

This example first shows how to perform instance segmentation using a pretrained SOLOv2 network that can detect a single class. Then, you can optionally configure and train a SOLOv2 network using transfer learning, and evaluate prediction results.

Download Pretrained SOLOv2 Network

By default, this example downloads a pretrained version of the SOLOv2 instance segmentation network using the downloadTrainedNetwork helper function. The helper function is attached to this example as a supporting file. You can use the pretrained network to run the entire example without waiting for training to complete.

trainedSOLOv2_url = "https://ssd.mathworks.com/supportfiles/vision/data/trainedSOLOv2BinDataset.zip";
downloadTrainedNetwork(trainedSOLOv2_url,pwd);
Downloading pretrained network.
This can take several minutes to download...
Done.
load("trainedSOLOv2.mat");

Download Bin Picking Dataset

This example uses the bin picking data set. The data set contains 150 images of 3-D pipe connectors, generated with Simulink® software. The data consists of images of machine parts lying at random orientations inside a bin, viewed from different angles and under different lighting conditions. The data set contains instance mask information for every object in every image, and combines all types of machine parts into a single class.

Specify dataDir as the location of the data set. Download the data set using the downloadBinObjectData helper function. This function is attached to the example as a supporting file.

dataDir = fullfile(tempdir,"BinDataset");
dataset_url = "https://ssd.mathworks.com/supportfiles/vision/data/binDataset.zip";
downloadBinObjectData(dataset_url,dataDir);
Downloading Bin Object Dataset.
This can take several minutes to download...
Done.

Perform Instance Segmentation

Read a sample image from the data set.

sampleImage = imread("testBinDataImage.png");

Predict the mask, labels, and confidence scores for each object instance using the segmentObjects function.

[masks,labels,scores] = segmentObjects(net,sampleImage,Threshold=0.4);
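
For a single image, segmentObjects returns the masks as an H-by-W-by-N logical array with one page per detected instance, together with N labels and N scores. The following sketch uses small synthetic arrays in place of real network output to show how that layout can be queried. The array sizes, the score values, and the 0.5 cutoff are illustrative assumptions and do not come from this example's data.

```matlab
% Synthetic stand-in for segmentObjects output: three 8-by-8 instance masks.
toyMasks = false(8,8,3);
toyMasks(1:3,1:3,1) = true;       % 9-pixel instance
toyMasks(4:7,2:5,2) = true;       % 16-pixel instance
toyMasks(6:8,7:8,3) = true;       % 6-pixel instance
toyScores = [0.92; 0.45; 0.81];   % made-up confidence scores

% Pixel area of each instance: sum over the rows and columns of every page.
areas = squeeze(sum(toyMasks,[1 2]));

% Keep only the instances at or above an illustrative confidence cutoff.
keep = toyScores >= 0.5;
confidentMasks = toyMasks(:,:,keep);
confidentAreas = areas(keep);
```

Filtering the outputs this way is an alternative to rerunning segmentObjects with a higher Threshold value.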

Display the instance masks over the image using the insertObjectMask function. Specify a colormap using the lines function, so that each object instance appears in a different color. Use the getBoxFromMask helper function to generate a bounding box for each segmented object instance, and overlay the boxes on the image with the confidence scores as labels.

maskColors = lines(numel(labels));
overlayedMasks = insertObjectMask(sampleImage,masks,MaskColor=maskColors);
imshow(overlayedMasks)
boxes = getBoxFromMask(masks);
showShape("rectangle",boxes,Label="Scores: "+num2str(scores),LabelOpacity=0.4);

Prepare Data for Training

Create a file datastore that reads the annotation data from MAT files. Use the matReaderBinData function, attached to the example as a supporting file, to parse the MAT files and return the corresponding training data as a 1-by-4 cell array containing image data, bounding boxes, labels, and object masks, in that order.

annsDir = fullfile(dataDir,"synthetic_parts_dataset","annotations");
ds = fileDatastore(annsDir,FileExtensions=".mat",ReadFcn=@(x)matReaderBinData(x,dataDir));
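
The following sketch builds one hypothetical training sample from synthetic data to illustrate the 1-by-4 cell layout. The element order (image, boxes, labels, masks) is assumed from the way this example later unpacks a previewed sample. The array sizes and box values are made up, and the sketch does not reproduce what matReaderBinData does internally.

```matlab
% Hypothetical 1-by-4 training sample built from synthetic data.
toyImage = zeros(6,6,3,"uint8");        % placeholder RGB image
toyInstanceMasks = false(6,6,2);        % one logical page per instance
toyInstanceMasks(1:2,1:3,1) = true;
toyInstanceMasks(4:6,4:5,2) = true;
toyBoxes = [1 1 3 2; 4 4 2 3];          % one [x y width height] row per instance
toyLabels = categorical(["Object"; "Object"]);

% Element order assumed: image, boxes, labels, masks.
sample = {toyImage,toyBoxes,toyLabels,toyInstanceMasks};

% Consistency check: one box, one label, and one mask page per instance.
numInstances = size(sample{2},1);
isConsistent = numel(sample{3}) == numInstances && ...
    size(sample{4},3) == numInstances;
```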

Partition Data

To improve the reproducibility of this example, set the global random state to the default state.

rng("default");

Split the data into training, validation, and test sets. Because the total number of images is relatively small, allocate a relatively large percentage (70%) of the data for training. Allocate 15% for validation and the rest for testing.

numImages = length(ds.Files);
numTrain = floor(0.7*numImages);
numVal = floor(0.15*numImages);

shuffledIndices = randperm(numImages);
trainDS = subset(ds,shuffledIndices(1:numTrain));
valDS   = subset(ds,shuffledIndices(numTrain+1:numTrain+numVal));
testDS  = subset(ds,shuffledIndices(numTrain+numVal+1:end));
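
As a check on the arithmetic, the following sketch repeats the split on plain indices, assuming the 150 images stated for this data set. The variable names are chosen here so that they do not clash with the example's own variables.

```matlab
nImages = 150;                      % data set size stated earlier
nTrain = floor(0.7*nImages);        % 105 training images
nVal = floor(0.15*nImages);         % 22 validation images
nTest = nImages - nTrain - nVal;    % 23 test images

rng("default");
idx = randperm(nImages);
trainIdx = idx(1:nTrain);
valIdx = idx(nTrain+1:nTrain+nVal);
testIdx = idx(nTrain+nVal+1:end);

% The three index sets are disjoint and together cover every image.
isPartition = isequal(sort([trainIdx valIdx testIdx]),1:nImages);
```

Because floor rounds the validation count down from 22.5, the extra image falls into the test set, which is why the evaluation later processes 23 images.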

Visualize Training Data

Preview the ground truth data for training by reading a sample image from the training subset of the file datastore.

gsSample = preview(trainDS);
gsImg = gsSample{1};
boxes = gsSample{2};
labels = gsSample{3};
masks  = gsSample{4};

Visualize the ground truth data by using the insertObjectMask function to overlay the instance masks on the sample image, then display the corresponding bounding boxes and labels.

overlayedMasks = insertObjectMask(gsImg,masks,Opacity=0.5);
imshow(overlayedMasks)
showShape("rectangle",boxes,Label=string(labels),Color="green");

Define SOLOv2 Network Architecture

Create the SOLOv2 instance segmentation model by using the solov2 object. Specify the name of the pretrained SOLOv2 instance segmentation network trained on the COCO data set, and specify the class name. Use the optional InputSize name-value argument to specify the size to which all images are resized.

networkToTrain = solov2("resnet50-coco","Object",InputSize=[736 1280 3]);

Specify Training Options

Specify network training options using the trainingOptions (Deep Learning Toolbox) function. Train the instance segmentation network using the SGDM solver for five epochs. Multiply the learning rate by a drop factor of 0.99 after every epoch. To stabilize training in the initial iterations, clip large gradients by setting the GradientThreshold name-value argument to 35. Specify the ValidationData name-value argument as the validation data, valDS.

options = trainingOptions("sgdm", ...
    InitialLearnRate=0.0005, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropPeriod=1, ...
    LearnRateDropFactor=0.99, ...
    Momentum=0.9, ...
    MaxEpochs=5, ...
    MiniBatchSize=4, ...
    ExecutionEnvironment="auto", ...
    VerboseFrequency=5, ...
    Plots="training-progress", ...
    ResetInputNormalization=false, ...
    ValidationData=valDS, ...
    ValidationFrequency=25, ...
    GradientThreshold=35, ...
    OutputNetwork="best-validation-loss");
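
To see what these schedule settings amount to, the following sketch tabulates the learning rate for each epoch. It assumes the usual piecewise behavior, in which the first epoch runs at the initial rate and the rate is multiplied by the drop factor once per drop period after that. It does not call trainingOptions.

```matlab
initialLearnRate = 0.0005;
dropFactor = 0.99;
dropPeriod = 1;      % epochs between drops
maxEpochs = 5;

epochs = 1:maxEpochs;
numDrops = floor((epochs-1)/dropPeriod);   % drops applied before each epoch
learnRates = initialLearnRate * dropFactor.^numDrops;

for k = epochs
    fprintf("Epoch %d: learning rate %.6g\n",k,learnRates(k));
end
```

Under this assumption the rate falls by only about 4% over the five epochs, so the schedule stays close to a constant learning rate.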

Train SOLOv2 Network

To train the network, set the doTraining variable to true. Train the network by using the trainSOLOV2 function. To reuse the extracted features from the pretrained backbone network and optimize the detection heads for the data set, freeze the feature extraction subnetwork by specifying the FreezeSubNetwork name-value argument.

Train on one or more GPUs, if they are available. Using a GPU requires a Parallel Computing Toolbox™ license and a CUDA®-enabled NVIDIA® GPU. For more information, see GPU Computing Requirements (Parallel Computing Toolbox). Training takes about 15 minutes on an NVIDIA Titan RTX™ with 24 GB of memory.

doTraining = false;
if doTraining       
    net = trainSOLOV2(trainDS,networkToTrain,options,FreezeSubNetwork="backbone");
    modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
    save(fullfile(tempdir,"trainedSOLOv2"+modelDateTime+".mat"), ...
        "net");
else
    load("trainedSOLOv2.mat");
end

Evaluate Trained SOLOv2 Network

Evaluate the trained SOLOv2 network by measuring the average precision. Precision is the fraction of predicted instances that correctly match a ground truth object, and average precision summarizes the precision across all recall levels.

Detect the instance masks for all test images.

resultsDS = segmentObjects(net,testDS,Threshold=0.1);
Running SoloV2 network
--------------------------
* Processed 23 images.

Calculate the average precision (AP) and mean average precision (mAP) metrics by using the evaluateInstanceSegmentation function. In this example, AP and mAP are identical because the objects are in only one class.

metrics = evaluateInstanceSegmentation(resultsDS,testDS,0.5);
display(metrics.DatasetMetrics)
  1×3 table

    NumObjects      mAP          AP    
    __________    _______    __________

       184        0.98784    {[0.9878]}
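
The value 0.5 passed to evaluateInstanceSegmentation is understood here to be the overlap threshold: the intersection over union (IoU) that a predicted mask must reach against a ground truth mask to count as a correct detection. The following sketch computes mask IoU on synthetic masks rather than on data from this example. It illustrates the measure only and is not the function's internal implementation.

```matlab
gtMask = false(6,6);
gtMask(2:4,2:4) = true;         % 9-pixel ground truth instance

predGood = false(6,6);
predGood(2:4,2:5) = true;       % 12 pixels, fully covers the ground truth

predPoor = false(6,6);
predPoor(3:5,3:5) = true;       % 9 pixels, shifted by one row and one column

% IoU = |intersection| / |union| on logical masks.
maskIoU = @(a,b) nnz(a & b) / nnz(a | b);

iouGood = maskIoU(gtMask,predGood);   % 9/12 = 0.75
iouPoor = maskIoU(gtMask,predPoor);   % 4/14, about 0.29

overlapThreshold = 0.5;
isMatchGood = iouGood >= overlapThreshold;   % true
isMatchPoor = iouPoor >= overlapThreshold;   % false
```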

Display the metrics for every test image to identify which images are not performing as expected.

display(metrics.ImageMetrics)
  23×3 table

          NumObjects      mAP          AP    
          __________    _______    __________

    1         8               1    {[     1]}
    2         8               1    {[     1]}
    3         8               1    {[     1]}
    4         8               1    {[     1]}
    5         8               1    {[     1]}
    6         8               1    {[     1]}
    7         8               1    {[     1]}
    8         8               1    {[     1]}
    9         8               1    {[     1]}
    10        8               1    {[     1]}
    11        8               1    {[     1]}
    12        8               1    {[     1]}
    13        8               1    {[     1]}
    14        8               1    {[     1]}
    15        8               1    {[     1]}
    16        8         0.85938    {[0.8594]}
    17        8         0.85938    {[0.8594]}
    18        8               1    {[     1]}
    19        8               1    {[     1]}
    20        8               1    {[     1]}
    21        8               1    {[     1]}
    22        8               1    {[     1]}
    23        8               1    {[     1]}

A precision/recall (PR) curve highlights how precise an instance segmentation model is at varying levels of recall. The ideal precision is 1 at all recall levels. Extract the precision, recall, and average precision metrics from the evaluateInstanceSegmentation function output.

precision = metrics.ClassMetrics.Precision;
recall = metrics.ClassMetrics.Recall;
averagePrecision = metrics.ClassMetrics.AP;

Plot the PR curve for the test data.

figure
plot(recall{1},precision{1})
title(sprintf("Average Precision for Single Class Instance Segmentation: " + "%.2f",averagePrecision{1}))
xlabel("Recall")
ylabel("Precision")
grid on
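
To make the PR curve concrete, the following toy sketch accumulates precision and recall over a handful of made-up detections sorted by descending score. The match flags and the ground truth count are hypothetical. The average precision uses a simple non-interpolated formula, which may differ from the interpolation that evaluateInstanceSegmentation applies.

```matlab
% Made-up detections, already sorted by descending confidence score.
isTruePositive = logical([1 1 0 1 0]);
numGroundTruth = 4;     % hypothetical number of ground truth objects

cumTP = cumsum(isTruePositive);
cumFP = cumsum(~isTruePositive);

toyPrecision = cumTP ./ (cumTP + cumFP);
toyRecall = cumTP / numGroundTruth;

% Non-interpolated AP: sum the precision at each true positive and
% divide by the number of ground truth objects.
toyAP = sum(toyPrecision(isTruePositive)) / numGroundTruth;
```

Each false positive lowers the precision without raising the recall, which produces the downward steps in a PR curve. One ground truth object is never detected in this toy data, so the recall stops at 0.75.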

Supporting Function

The getBoxFromMask function converts instance masks to bounding boxes.

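function boxes = getBoxFromMask(masks)
% Convert an H-by-W-by-N stack of instance masks to N-by-4 boxes.

% Preallocate so that an empty mask stack returns a 0-by-4 array.
numMasks = size(masks,3);
boxes = zeros(numMasks,4);

for idx = 1:numMasks
    mask = masks(:,:,idx);
    % Row and column subscripts of every foreground pixel.
    [ptsR,ptsC] = find(mask);

    minR = min(ptsR);
    maxR = max(ptsR);
    minC = min(ptsC);
    maxC = max(ptsC);

    % One row per instance: [minC minR maxC-minC maxR-minR].
    boxes(idx,:) = [minC minR maxC-minC maxR-minR];
end
end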
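
As a worked example, the following sketch applies the same min/max computation to a single synthetic mask. The mask size and position are made up for illustration.

```matlab
toyMask = false(10,12);
toyMask(3:6,4:9) = true;     % foreground spans rows 3 to 6 and columns 4 to 9

[ptsR,ptsC] = find(toyMask);
toyBox = [min(ptsC) min(ptsR) max(ptsC)-min(ptsC) max(ptsR)-min(ptsR)];
% toyBox is [4 3 5 3]
```

The third and fourth elements are differences between the extreme subscripts, so they are one less than the number of pixel columns and rows that the mask covers (6 and 4 here).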
