Define Nested Deep Learning Layer

If Deep Learning Toolbox™ does not provide the layer you require for your classification or regression problem, then you can define your own custom layer using this example as a guide. For a list of built-in layers, see List of Deep Learning Layers.

To create a custom layer that itself defines a layer graph, you can specify a dlnetwork object as a learnable parameter. This method is known as network composition. You can use network composition to:

  • Create a single custom layer that represents a block of learnable layers, for example, a residual block.

  • Create a network with control flow, for example, a network with a section that can dynamically change depending on the input data.

  • Create a network with loops, for example, a network with sections that feed their output back into themselves.

For more information, see Deep Learning Network Composition.
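
In outline, the pattern looks like this sketch. The class and property names here are illustrative; the rest of this example builds a complete version.

classdef exampleNestedLayer < nnet.layer.Layer
    properties (Learnable)
        Network  % Nested dlnetwork object
    end

    methods
        function Z = predict(layer,X)
            % Forward through the nested dlnetwork object.
            X = dlarray(X,'SSCB');         % Label dimensions for the nested network
            Z = predict(layer.Network,X);  % dlnetwork predict function
            Z = stripdims(Z);              % Return unformatted data
        end
    end
end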

This example shows how to create a custom layer representing a residual block. The custom layer residualBlockLayer contains a learnable block consisting of convolution, group normalization, ReLU, and addition layers. The block also includes a skip connection, which can optionally contain a convolution layer and a group normalization layer. The layer has a single input that is used twice, as the input to each branch. This diagram highlights the residual block structure.

Diagram: Structure of the residual block. The main branch contains a convolution, group normalization, ReLU, second convolution, second group normalization, addition, and ReLU layer connected in series. A skip connection runs from the block input to the addition layer. The optional convolution and group normalization layers appear in series on the skip connection.

To define a custom deep learning layer, you can use the template provided in this example, which takes you through the following steps:

  1. Name the layer – Give the layer a name so that you can use it in MATLAB®.

  2. Declare the layer properties – Specify the properties of the layer and which parameters are learned during training.

  3. Create a constructor function (optional) – Specify how to construct the layer and initialize its properties. If you do not specify a constructor function, then at creation, the software initializes the Name, Description, and Type properties with [] and sets the number of layer inputs and outputs to 1.

  4. Create forward functions – Specify how data passes forward through the layer (forward propagation) at prediction time and at training time.

  5. Create a backward function (optional) – Specify the derivatives of the loss with respect to the input data and the learnable parameters (backward propagation). If you do not specify a backward function, then the forward functions must support dlarray objects.

Layer with Learnable Parameters Template

Copy the layer with learnable parameters template into a new file in MATLAB. This template outlines the structure of a layer with learnable parameters and includes the functions that define the layer behavior.

classdef myLayer < nnet.layer.Layer % & nnet.layer.Formattable (Optional) 

    properties
        % (Optional) Layer properties.

        % Layer properties go here.
    end

    properties (Learnable)
        % (Optional) Layer learnable parameters.

        % Layer learnable parameters go here.
    end
    
    methods
        function layer = myLayer()
            % (Optional) Create a myLayer.
            % This function must have the same name as the class.

            % Layer constructor function goes here.
        end
        
        function [Z1, …, Zm] = predict(layer, X1, …, Xn)
            % Forward input data through the layer at prediction time and
            % output the result.
            %
            % Inputs:
            %         layer       - Layer to forward propagate through
            %         X1, ..., Xn - Input data
            % Outputs:
            %         Z1, ..., Zm - Outputs of layer forward function
            
            % Layer forward function for prediction goes here.
        end

        function [Z1, …, Zm, memory] = forward(layer, X1, …, Xn)
            % (Optional) Forward input data through the layer at training
            % time and output the result and a memory value.
            %
            % Inputs:
            %         layer       - Layer to forward propagate through
            %         X1, ..., Xn - Input data
            % Outputs:
            %         Z1, ..., Zm - Outputs of layer forward function
            %         memory      - Memory value for custom backward propagation

            % Layer forward function for training goes here.
        end

        function [dLdX1, …, dLdXn, dLdW1, …, dLdWk] = ...
                backward(layer, X1, …, Xn, Z1, …, Zm, dLdZ1, …, dLdZm, memory)
            % (Optional) Backward propagate the derivative of the loss  
            % function through the layer.
            %
            % Inputs:
            %         layer             - Layer to backward propagate through
            %         X1, ..., Xn       - Input data
            %         Z1, ..., Zm       - Outputs of layer forward function            
            %         dLdZ1, ..., dLdZm - Gradients propagated from the next layers
            %         memory            - Memory value from forward function
            % Outputs:
            %         dLdX1, ..., dLdXn - Derivatives of the loss with respect to the
            %                             inputs
            %         dLdW1, ..., dLdWk - Derivatives of the loss with respect to each
            %                             learnable parameter
            
            % Layer backward function goes here.
        end
    end
end

Name Layer

First, give the layer a name. In the first line of the class file, replace the existing name myLayer with residualBlockLayer.

classdef residualBlockLayer < nnet.layer.Layer
    ...
end

Next, rename the myLayer constructor function (the first function in the methods section) so that it has the same name as the layer.

    methods
        function layer = residualBlockLayer()           
            ...
        end

        ...
     end

Save Layer

Save the layer class file in a new file named residualBlockLayer.m. The file name must match the layer name. To use the layer, you must save the file in the current folder or in a folder on the MATLAB path.

Declare Properties and Learnable Parameters

Declare the layer properties in the properties section and declare learnable parameters by listing them in the properties (Learnable) section.

By default, custom intermediate layers have these properties.

  • Name – Layer name, specified as a character vector or a string scalar. To include a layer in a layer graph, you must specify a nonempty, unique layer name. If you train a series network with the layer and Name is set to '', then the software automatically assigns a name to the layer at training time.

  • Description – One-line description of the layer, specified as a character vector or a string scalar. This description appears when the layer is displayed in a Layer array. If you do not specify a layer description, then the software displays the layer class name.

  • Type – Type of the layer, specified as a character vector or a string scalar. The value of Type appears when the layer is displayed in a Layer array. If you do not specify a layer type, then the software displays the layer class name.

  • NumInputs – Number of inputs of the layer, specified as a positive integer. If you do not specify this value, then the software automatically sets NumInputs to the number of names in InputNames. The default value is 1.

  • InputNames – Input names of the layer, specified as a cell array of character vectors. If you do not specify this value and NumInputs is greater than 1, then the software automatically sets InputNames to {'in1',...,'inN'}, where N is equal to NumInputs. The default value is {'in'}.

  • NumOutputs – Number of outputs of the layer, specified as a positive integer. If you do not specify this value, then the software automatically sets NumOutputs to the number of names in OutputNames. The default value is 1.

  • OutputNames – Output names of the layer, specified as a cell array of character vectors. If you do not specify this value and NumOutputs is greater than 1, then the software automatically sets OutputNames to {'out1',...,'outM'}, where M is equal to NumOutputs. The default value is {'out'}.

If the layer has no other properties, then you can omit the properties section.

Tip

If you are creating a layer with multiple inputs, then you must set either the NumInputs or InputNames properties in the layer constructor. If you are creating a layer with multiple outputs, then you must set either the NumOutputs or OutputNames properties in the layer constructor. For an example, see Define Custom Deep Learning Layer with Multiple Inputs.
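
For instance, here is a minimal sketch of a constructor for a hypothetical two-input layer. The layer name and description are illustrative, not part of this example.

        function layer = myWeightedAdditionLayer(name)
            % Hypothetical constructor for a layer with two inputs.
            layer.NumInputs = 2;  % Or, equivalently, set layer.InputNames = {'in1','in2'}
            layer.Name = name;
            layer.Description = "Weighted addition of 2 inputs";
        end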

The residual block layer does not require any additional properties, so you can remove the properties section.

This custom layer has only one learnable parameter, the residual block itself, specified as a dlnetwork object. Declare this learnable parameter in the properties (Learnable) section and call the parameter Network.

    properties (Learnable)
        % Layer learnable parameters
    
        % Residual block.
        Network
    end

Create Constructor Function

Create the function that constructs the layer and initializes the layer properties. Specify any variables required to create the layer as inputs to the constructor function.

The residual block layer constructor function requires four input arguments:

  • Number of convolutional filters

  • Stride (optional, with default stride 1)

  • Flag to include convolution in skip connection (optional, with default flag false)

  • Layer name (optional, with default name '')

In the constructor function residualBlockLayer, specify the required input argument numFilters and the optional arguments as name-value pairs with the name NameValueArgs. Add a comment to the top of the function that explains the syntax of the function.

        function layer = residualBlockLayer(numFilters,NameValueArgs)
            % layer = residualBlockLayer(numFilters) creates a residual
            % block layer with the specified number of filters.
            %
            % layer = residualBlockLayer(numFilters,Name,Value) specifies
            % additional options using one or more name-value pair
            % arguments:
            % 
            %     'Stride'                 - Stride of convolution operation 
            %                                (default 1)
            %
            %     'IncludeSkipConvolution' - Flag to include convolution in
            %                                skip connection
            %                                (default false)
            %
            %     'Name'                   - Layer name
            %                                (default '')

            ...
        end

Parse Input Arguments

Parse the input arguments using an arguments block. List the arguments in the same order as the function syntax and specify the default values. Then, extract the values from the NameValueArgs input.

            % Parse input arguments.
            arguments
                numFilters                
                NameValueArgs.Stride = 1
                NameValueArgs.IncludeSkipConvolution = false
                NameValueArgs.Name = ''
            end
            
            stride = NameValueArgs.Stride;
            includeSkipConvolution = NameValueArgs.IncludeSkipConvolution;
            name = NameValueArgs.Name;

Initialize Layer Properties

In the constructor function, initialize the layer properties, including the dlnetwork object. Replace the comment % Layer constructor function goes here with code that initializes the layer properties.

Set the Name property to the input argument name.

            % Set layer name.
            layer.Name = name;

Give the layer a one-line description by setting the Description property of the layer. Set the description to describe the layer and any optional properties.

            % Set layer description.
            description = "Residual block with " + numFilters + " filters, stride " + stride;
            if includeSkipConvolution
                description = description + ", and skip convolution";
            end
            layer.Description = description;

Specify the type of the layer by setting the Type property. The value of Type appears when the layer is displayed in a Layer array.

            % Set layer type.
            layer.Type = "Residual Block";

Define the residual block. You can create the residual block layers as an uninitialized nested dlnetwork object without an input layer and allow the software to automatically initialize the learnable and state parameters at training time. For more information, see Automatically Initialize Learnable dlnetwork Objects for Training.

First, create a layer array containing the main layers of the block and convert it to a layer graph.

            % Define nested layer graph.
            layers = [
                convolution2dLayer(3,numFilters,'Padding','same','Stride',stride,'Name','conv1')
                groupNormalizationLayer('all-channels','Name','gn1')
                reluLayer('Name','relu1')
                convolution2dLayer(3,numFilters,'Padding','same','Name','conv2')
                groupNormalizationLayer('channel-wise','Name','gn2')
                
                additionLayer(2,'Name','add')
                reluLayer('Name','relu2')];
            
            lgraph = layerGraph(layers);

Next, add the skip connection. If the includeSkipConvolution flag is true, then also include a convolution layer and group normalization layer in the skip connection.

            % Add skip connection.
            if includeSkipConvolution
                layers = [
                    convolution2dLayer(1,numFilters,'Stride',stride,'Name','convSkip')
                    groupNormalizationLayer('all-channels','Name','gnSkip')];
                
                lgraph = addLayers(lgraph,layers);
                lgraph = connectLayers(lgraph,'gnSkip','add/in2'); 
            end

Because there is no input layer, this network has two unconnected inputs. If the network does not have the skip connection, then the input to the 'conv1' layer and one of the inputs to the 'add' layer are unconnected. If the network does have the skip connection, then the unconnected inputs are the inputs to the 'conv1' and 'convSkip' layers.

Finally, convert the layer graph to a dlnetwork object and set the layer Network property. Create an uninitialized dlnetwork object. The weights and learnable parameters in the dlnetwork object are automatically initialized when the complete network is assembled for training.

            % Convert to dlnetwork.
            dlnet = dlnetwork(lgraph,'Initialize',false);
            
            % Set Network property.
            layer.Network = dlnet;
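
As an illustrative check (run at the command line with the constructor variables in the workspace, not as part of the class file), you can confirm which inputs are unconnected by inspecting the InputNames property of the uninitialized network. The exact display can vary by release.

dlnet.InputNames
% Without the skip convolution, expect the unconnected inputs
% {'conv1'} and {'add/in2'}; with it, {'conv1'} and {'convSkip'}.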

View the completed constructor function.

        function layer = residualBlockLayer(numFilters,NameValueArgs)
            % layer = residualBlockLayer(numFilters) creates a residual
            % block layer with the specified number of filters.
            %
            % layer = residualBlockLayer(numFilters,Name,Value) specifies
            % additional options using one or more name-value pair
            % arguments:
            % 
            %     'Stride'                 - Stride of convolution operation 
            %                                (default 1)
            %
            %     'IncludeSkipConvolution' - Flag to include convolution in
            %                                skip connection
            %                                (default false)
            %
            %     'Name'                   - Layer name
            %                                (default '')
    
            % Parse input arguments.
            arguments
                numFilters
                NameValueArgs.Stride = 1
                NameValueArgs.IncludeSkipConvolution = false
                NameValueArgs.Name = ''
            end
    
            stride = NameValueArgs.Stride;
            includeSkipConvolution = NameValueArgs.IncludeSkipConvolution;
            name = NameValueArgs.Name;
    
            % Set layer name.
            layer.Name = name;
    
            % Set layer description.
            description = "Residual block with " + numFilters + " filters, stride " + stride;
            if includeSkipConvolution
                description = description + ", and skip convolution";
            end
            layer.Description = description;
            
            % Set layer type.
            layer.Type = "Residual Block";
    
            % Define nested layer graph.
            layers = [
                convolution2dLayer(3,numFilters,'Padding','same','Stride',stride,'Name','conv1')
                groupNormalizationLayer('all-channels','Name','gn1')
                reluLayer('Name','relu1')
                convolution2dLayer(3,numFilters,'Padding','same','Name','conv2')
                groupNormalizationLayer('channel-wise','Name','gn2')
    
                additionLayer(2,'Name','add')
                reluLayer('Name','relu2')];
    
            lgraph = layerGraph(layers);
    
            % Add skip connection.
            if includeSkipConvolution
                layers = [
                    convolution2dLayer(1,numFilters,'Stride',stride,'Name','convSkip')
                    groupNormalizationLayer('all-channels','Name','gnSkip')];
     
                lgraph = addLayers(lgraph,layers);
                lgraph = connectLayers(lgraph,'gnSkip','add/in2');  
            end 
    
            % Convert to dlnetwork.
            dlnet = dlnetwork(lgraph,'Initialize',false);
    
            % Set Network property.
            layer.Network = dlnet;
        end

With this constructor function, the command residualBlockLayer(64,'Stride',2,'IncludeSkipConvolution',true,'Name','res5') creates a residual block layer with 64 filters, a stride of 2, a convolution in the skip connection, and with the name 'res5'. The required sizes of weights and parameters are determined when the completed network is assembled for training.
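
With the constructor complete, you can use the layer like any built-in layer. For example, here is a hedged sketch of the layer in a layer array; the surrounding architecture is illustrative only, not a recommended design.

% Illustrative usage of residualBlockLayer in a layer array.
layers = [
    imageInputLayer([224 224 3])
    convolution2dLayer(7,64,'Stride',2,'Padding','same','Name','conv')
    residualBlockLayer(64)
    residualBlockLayer(128,'Stride',2,'IncludeSkipConvolution',true)
    globalAveragePooling2dLayer('Name','gap')
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];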

Create Forward Functions

Create the layer forward functions to use at prediction time and training time.

Create a function named predict that propagates the data forward through the layer at prediction time and outputs the result.

The syntax for predict is [Z1,…,Zm] = predict(layer,X1,…,Xn), where X1,…,Xn are the n layer inputs and Z1,…,Zm are the m layer outputs. The values n and m must correspond to the NumInputs and NumOutputs properties of the layer.

Tip

If the number of inputs to predict can vary, then use varargin instead of X1,…,Xn. In this case, varargin is a cell array of the inputs, where varargin{i} corresponds to Xi. If the number of outputs can vary, then use varargout instead of Z1,…,Zm. In this case, varargout is a cell array of the outputs, where varargout{j} corresponds to Zj.
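
For instance, here is a minimal sketch of a predict function with a variable number of inputs. This is illustrative only; the residual block layer in this example has a fixed single input.

        function Z = predict(layer,varargin)
            % Hypothetical layer that sums a variable number of inputs.
            Z = varargin{1};
            for i = 2:numel(varargin)
                Z = Z + varargin{i};
            end
        end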

Tip

If the custom layer has a dlnetwork object for a learnable parameter, then in the predict function of the custom layer, use the predict function for the dlnetwork. Using the dlnetwork object predict function ensures that the software uses the correct layer operations for prediction.

Because the residual block has only one input and one output, the syntax for predict for the custom layer is Z = predict(layer,X).

By default, the layer uses predict as the forward function at training time. To use a different forward function at training time, or retain a value required for a custom backward function, you must also create a function named forward.

The dimensions of the inputs depend on the type of data and the output of the connected layers.

  • 2-D images – Input size: h-by-w-by-c-by-N, where h, w, and c correspond to the height, width, and number of channels of the images, respectively, and N is the number of observations. Observation dimension: 4.

  • 3-D images – Input size: h-by-w-by-d-by-c-by-N, where h, w, d, and c correspond to the height, width, depth, and number of channels of the 3-D images, respectively, and N is the number of observations. Observation dimension: 5.

  • Vector sequences – Input size: c-by-N-by-S, where c is the number of features of the sequences, N is the number of observations, and S is the sequence length. Observation dimension: 2.

  • 2-D image sequences – Input size: h-by-w-by-c-by-N-by-S, where h, w, and c correspond to the height, width, and number of channels of the images, respectively, N is the number of observations, and S is the sequence length. Observation dimension: 4.

  • 3-D image sequences – Input size: h-by-w-by-d-by-c-by-N-by-S, where h, w, d, and c correspond to the height, width, depth, and number of channels of the 3-D images, respectively, N is the number of observations, and S is the sequence length. Observation dimension: 5.

Layers that output sequences can output sequences of any length or output data with no time dimension. Note that when you train a network that outputs sequences by using the trainNetwork function, the lengths of the input and output sequences must match.

For the residual block layer, a forward pass of the layer is simply a forward pass of the dlnetwork object. To pass the input data to the dlnetwork object, you must first convert it to a formatted dlarray object.

Implement this operation in the custom layer function predict. To perform a forward pass of the dlnetwork for prediction, use the predict function for dlnetwork objects. In this case, the input to the residual block layer is used as the input to both of the unconnected inputs to the dlnetwork object, so the syntax for predict for the dlnetwork object is Z = predict(dlnet,X,X).

Because the layers in the dlnetwork object do not behave differently during training and prediction, and because the residual block layer does not require memory or a different forward function for training, you can remove the forward function from the class file.

Create the predict function and add a comment to the top of the function that explains the syntaxes of the function.

        function Z = predict(layer, X)
            % Forward input data through the layer at prediction time and
            % output the result.
            %
            % Inputs:
            %         layer - Layer to forward propagate through
            %         X     - Input data
            % Outputs:
            %         Z - Output of layer forward function
                       
            % Convert input data to formatted dlarray.
            X = dlarray(X,'SSCB');
            
            % Predict using network.
            dlnet = layer.Network;
            Z = predict(dlnet,X,X);
            
            % Strip dimension labels.
            Z = stripdims(Z);
        end

Because the predict function uses only functions that support dlarray objects, defining the backward function is optional. For a list of functions that support dlarray objects, see List of Functions with dlarray Support.

Completed Layer

View the completed layer class file.

classdef residualBlockLayer < nnet.layer.Layer
    % Example custom residual block layer.


    properties (Learnable)
        % Layer learnable parameters
    
        % Residual block.
        Network
    end
    
    methods
        function layer = residualBlockLayer(numFilters,NameValueArgs)
            % layer = residualBlockLayer(numFilters) creates a residual
            % block layer with the specified number of filters.
            %
            % layer = residualBlockLayer(numFilters,Name,Value) specifies
            % additional options using one or more name-value pair
            % arguments:
            % 
            %     'Stride'                 - Stride of convolution operation 
            %                                (default 1)
            %
            %     'IncludeSkipConvolution' - Flag to include convolution in
            %                                skip connection
            %                                (default false)
            %
            %     'Name'                   - Layer name
            %                                (default '')
    
            % Parse input arguments.
            arguments
                numFilters
                NameValueArgs.Stride = 1
                NameValueArgs.IncludeSkipConvolution = false
                NameValueArgs.Name = ''
            end
    
            stride = NameValueArgs.Stride;
            includeSkipConvolution = NameValueArgs.IncludeSkipConvolution;
            name = NameValueArgs.Name;
    
            % Set layer name.
            layer.Name = name;
    
            % Set layer description.
            description = "Residual block with " + numFilters + " filters, stride " + stride;
            if includeSkipConvolution
                description = description + ", and skip convolution";
            end
            layer.Description = description;
            
            % Set layer type.
            layer.Type = "Residual Block";
    
            % Define nested layer graph.
            layers = [
                convolution2dLayer(3,numFilters,'Padding','same','Stride',stride,'Name','conv1')
                groupNormalizationLayer('all-channels','Name','gn1')
                reluLayer('Name','relu1')
                convolution2dLayer(3,numFilters,'Padding','same','Name','conv2')
                groupNormalizationLayer('channel-wise','Name','gn2')
    
                additionLayer(2,'Name','add')
                reluLayer('Name','relu2')];
    
            lgraph = layerGraph(layers);
    
            % Add skip connection.
            if includeSkipConvolution
                layers = [
                    convolution2dLayer(1,numFilters,'Stride',stride,'Name','convSkip')
                    groupNormalizationLayer('all-channels','Name','gnSkip')];
     
                lgraph = addLayers(lgraph,layers);
                lgraph = connectLayers(lgraph,'gnSkip','add/in2');  
            end 
    
            % Convert to dlnetwork.
            dlnet = dlnetwork(lgraph,'Initialize',false);
    
            % Set Network property.
            layer.Network = dlnet;
        end
        
        function Z = predict(layer, X)
            % Forward input data through the layer at prediction time and
            % output the result.
            %
            % Inputs:
            %         layer - Layer to forward propagate through
            %         X     - Input data
            % Outputs:
            %         Z - Output of layer forward function
                       
            % Convert input data to formatted dlarray.
            X = dlarray(X,'SSCB');
            
            % Predict using network.
            dlnet = layer.Network;
            Z = predict(dlnet,X,X);
            
            % Strip dimension labels.
            Z = stripdims(Z);
        end
    end
end

GPU Compatibility

If the layer forward functions fully support dlarray objects, then the layer is GPU compatible. Otherwise, to be GPU compatible, the layer functions must support inputs and return outputs of type gpuArray (Parallel Computing Toolbox).

Many MATLAB built-in functions support gpuArray (Parallel Computing Toolbox) and dlarray input arguments. For a list of functions that support dlarray objects, see List of Functions with dlarray Support. For a list of functions that execute on a GPU, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox). To use a GPU for deep learning, you must also have a supported GPU device. For information on supported devices, see GPU Support by Release (Parallel Computing Toolbox). For more information on working with GPUs in MATLAB, see GPU Computing in MATLAB (Parallel Computing Toolbox).

In this example, the MATLAB functions used in predict all support dlarray objects, so the layer is GPU compatible.
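
As an illustrative guard, assuming Parallel Computing Toolbox is installed, you can move input data to the GPU only when a supported device is available. This sketch uses the canUseGPU function.

% Sketch: use the GPU for the layer input data when one is available.
X = rand(56,56,64,1,'single');
if canUseGPU
    X = gpuArray(X);
end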

Check Validity of Layer Using checkLayer

Check the layer validity of the custom layer residualBlockLayer using the checkLayer function.

Create an instance of a residual block layer.

numFilters = 64;

layer = residualBlockLayer(numFilters)
layer = 
  residualBlockLayer with properties:

       Name: ''

   Learnable Parameters
    Network: [1x1 dlnetwork]

Check the layer validity using the checkLayer function. The layer expects 4-D array inputs, where the first three dimensions correspond to the height, width, and number of channels of the previous layer output, and the fourth dimension corresponds to the observations. Specify a typical input size and set the 'ObservationDimension' option to 4.

validInputSize = [56 56 64];
checkLayer(layer,validInputSize,'ObservationDimension',4)
Skipping GPU tests. No compatible GPU device found.
 
Skipping code generation compatibility tests. To check validity of the layer for code generation, specify the 'CheckCodegenCompatibility' and 'ObservationDimension' options.
 
Running nnet.checklayer.TestLayerWithoutBackward
.......... ...
Done nnet.checklayer.TestLayerWithoutBackward
__________

Test Summary:
	 13 Passed, 0 Failed, 0 Incomplete, 9 Skipped.
	 Time elapsed: 5.2518 seconds.

The function does not detect any issues with the layer.
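
You can run the same check on the other configurations of the layer. For example, here is a hedged sketch that checks the variant with a stride of 2 and a convolution in the skip connection.

% Check the skip-convolution variant of the layer.
layer = residualBlockLayer(64,'Stride',2,'IncludeSkipConvolution',true);
checkLayer(layer,validInputSize,'ObservationDimension',4)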
