
Train Neural ODE Network

This example shows how to train an augmented neural ordinary differential equation (ODE) network.

A neural ODE [1] is a deep learning operation that returns the solution of an ODE. In particular, given an input, a neural ODE operation outputs the numerical solution of the ODE y′=f(t,y,θ) for the time horizon (t0,t1) and the initial condition y(t0)=y0, where t and y denote the ODE function inputs and θ is a set of learnable parameters. Typically, the initial condition y0 is either the network input or, as in the case of this example, the output of another deep learning operation.
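Conceptually, the operation behaves like calling an ODE solver on a learnable function f. As a minimal sketch, for illustration only, with a fixed (non-learnable) f and the standard ode45 solver:

```matlab
% A neural ODE operation returns y(t1) for y' = f(t,y,theta), y(t0) = y0.
% Here f is a fixed function for illustration; in a neural ODE, f is a
% network with learnable parameters theta.
f = @(t,y) -2*y;           % example ODE function
tspan = [0 1];             % time horizon (t0, t1)
y0 = 1;                    % initial condition y(t0)
[~,y] = ode45(f,tspan,y0);
y(end)                     % numerical solution y(t1), close to exp(-2)
```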

An augmented neural ODE [2] operation improves upon a standard neural ODE by augmenting the input data with extra channels and then discarding the augmentation after the neural ODE operation. Empirically, augmented neural ODEs are more stable, generalize better, and have a lower computational cost than neural ODEs.
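The augment-then-discard pattern can be sketched with the same paddata and trimdata functions that the helper functions later in this example use (the sizes here are illustrative, not those of the network below):

```matlab
% Augment a batch of 14-by-14 single-channel observations with one extra
% all-zero channel, then discard it again, as this example's network does
% around its neural ODE operation.
X = dlarray(rand(14,14,1,5),"SSCB");          % spatial, spatial, channel, batch
idxC = finddim(X,"C");                        % index of the channel dimension
Z = paddata(X,2*size(X,idxC),Dimension=idxC); % augment: 1 -> 2 channels
Y = trimdata(Z,size(X,idxC),Dimension=idxC);  % discard: 2 -> 1 channel
```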

This example trains a simple convolutional neural network with an augmented neural ODE operation.


The ODE function is itself a neural network. In this example, the model uses a network with a convolution layer followed by a tanh layer.


The example shows how to train a neural network to classify images of digits using an augmented neural ODE operation.

Load Training Data

Load the training images XTrain and labels labelsTrain from the DigitsDataTrain MAT file.

load DigitsDataTrain

View the number of classes of the training data.

TTrain = labelsTrain;
classNames = categories(TTrain);
numClasses = numel(classNames)
numClasses = 10

View some images from the training data.

numObservations = size(XTrain,4);
idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

Define Neural Network Architecture

Define the following network, which classifies images.

  • A convolution-ReLU block with 8 3-by-3 filters with a stride of 2

  • An augmentation layer that concatenates an array of zeros to the input such that the output has twice as many channels as the input

  • A neural ODE operation with an ODE function containing a convolution-tanh block with 16 3-by-3 filters

  • A discard augmentation layer that trims trailing elements in the channel dimension so that the output has half as many channels as the input

  • For classification output, a fully connected operation of size 10 (the number of classes) followed by a softmax operation


A neural ODE layer outputs the solution of a specified ODE function. For this example, specify a neural network containing a convolution layer and a tanh layer as the ODE function.


The neural ODE network must have matching input and output sizes. To calculate the input size of the neural network in the ODE layer, note that:

  • The input data for the image classification network are arrays of 28-by-28-by-1 images.

  • The images flow through a convolution layer with 8 filters that downsamples by a factor of 2.

  • The output of the convolution layer flows through an augmentation layer that doubles the size of the channel dimension.

This means that the inputs to the neural ODE layer are 14-by-14-by-16 arrays, where the spatial dimensions have size 14 and the channel dimension has size 16. The spatial sizes are 14 because the convolution layer downsamples the 28-by-28 images by a factor of two. The channel size is 16 because the convolution layer outputs 8 channels (its number of filters) and the augmentation layer doubles the number of channels.
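The size bookkeeping above amounts to two lines of arithmetic (values taken from the network in this example):

```matlab
% Sizes entering the neural ODE layer.
numFilters = 8;              % filters in the first convolution layer
spatialODE = 28/2            % stride-2 convolution downsamples 28 -> 14
channelsODE = 2*numFilters   % zero augmentation doubles 8 -> 16 channels
```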

Create the neural network to use for the neural ODE layer. Because the network does not have an input layer, do not initialize the network.

numFilters = 8;

layersODE = [
    convolution2dLayer(3,2*numFilters,Padding="same")
    tanhLayer];

netODE = dlnetwork(layersODE,Initialize=false);

Create the image classification network. For the augmentation and discard augmentation layers, use function layers with the channelAugmentation and discardChannelAugmentation functions listed in the Channel Augmentation Function and Discard Channel Augmentation Function sections of the example, respectively. To access these functions, open the example as a live script.

inputSize = size(XTrain,1:3);
filterSize = 3;
tspan = [0 0.1];

layers = [
    imageInputLayer(inputSize)
    convolution2dLayer(filterSize,numFilters,Padding="same",Stride=2)
    reluLayer
    functionLayer(@channelAugmentation,Formattable=true)
    neuralODELayer(netODE,tspan)
    functionLayer(@discardChannelAugmentation,Formattable=true)
    fullyConnectedLayer(numClasses)
    softmaxLayer];

Specify Training Options

Specify the training options. Choosing among the options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.

  • Train using the Adam solver.

  • Train with a learning rate of 0.01.

  • Shuffle the data every epoch.

  • Monitor the training progress in a plot and display the accuracy.

  • Disable the verbose output.

options = trainingOptions("adam", ...
    InitialLearnRate=0.01, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

Train the neural network using the trainnet function. For classification, use cross-entropy loss. By default, the trainnet function uses a GPU if one is available. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainnet function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.

net = trainnet(XTrain,TTrain,layers,"crossentropy",options);

Test Model

Test the classification accuracy of the model by comparing the predictions on a held-out test set with the true labels.

Load the test data.

load DigitsDataTest
TTest = labelsTest;

After training, making predictions on new data does not require the labels. Create a minibatchqueue object containing only the predictors of the test data:

  • Set the number of outputs of the mini-batch queue to 1.

  • Preprocess the predictors using the preprocessPredictors function, listed in the Mini-Batch Predictors Preprocessing Function section of the example.

  • For the single output of the datastore, specify the mini-batch format "SSCB" (spatial, spatial, channel, batch).

dsTest = arrayDatastore(XTest,IterationDimension=4);

mbqTest = minibatchqueue(dsTest,1, ...
    MiniBatchFormat="SSCB", ...
    MiniBatchFcn=@preprocessPredictors);

Loop over the mini-batches and classify the images using the modelPredictions function, listed in the Model Predictions Function section of the example.

YTest = modelPredictions(net,mbqTest,classNames);

Visualize the predictions in a confusion matrix.
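One way to do this is with a confusion matrix chart (assuming, as above, that TTest and YTest are categorical arrays of true and predicted labels):

```matlab
figure
confusionchart(TTest,YTest)
```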


Calculate the classification accuracy.

accuracy = mean(TTest==YTest)
accuracy = 0.9262

Channel Augmentation Function

The channelAugmentation function pads the channel dimension of the input data X so that the output has twice as many channels as the input.

function Z = channelAugmentation(X)

idxC = finddim(X,"C");
szC = size(X,idxC);
Z = paddata(X,2*szC,Dimension=idxC);

end

Discard Channel Augmentation Function

The discardChannelAugmentation function trims the channel dimension of the input data X so that the output has half as many channels as the input.

function Z = discardChannelAugmentation(X)

idxC = finddim(X,"C");
szC = size(X,idxC);
Z = trimdata(X,floor(szC/2),Dimension=idxC);

end


Model Predictions Function

The modelPredictions function takes as input the neural network, a mini-batch queue of input data mbq, and the class names, and computes the model predictions by iterating over all data. The function uses the onehotdecode function to find the predicted classes with the highest score.

function predictions = modelPredictions(net,mbq,classNames)

predictions = [];

while hasdata(mbq)
    X = next(mbq);
    Y = predict(net,X);
    Y = onehotdecode(Y,classNames,1)';
    predictions = [predictions; Y];
end

end

Mini-Batch Predictors Preprocessing Function

The preprocessPredictors function preprocesses a mini-batch of predictors by extracting the image data from the input cell array and concatenating the data into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image to use as a singleton channel dimension.

function X = preprocessPredictors(dataX)

X = cat(4,dataX{:});

end



References

  1. Chen, Ricky T. Q., Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. “Neural Ordinary Differential Equations.” Preprint, submitted June 19, 2018.

  2. Dupont, Emilien, Arnaud Doucet, and Yee Whye Teh. “Augmented Neural ODEs.” Preprint, submitted October 26, 2019.
