
Signal Classification Using Transfer Learning

This example shows how to use a preconfigured template in the Experiment Manager app to set up a signal classification experiment that uses transfer learning. The goal of the experiment is to train a classifier to determine if an electrocardiogram (ECG) signal exhibits arrhythmia, congestive heart failure, or normal sinus rhythm. The training and testing data consists of 96 ECG recordings from the MIT-BIH Arrhythmia Database, 30 recordings from the BIDMC Congestive Heart Failure Database, and 36 recordings from the MIT-BIH Normal Sinus Rhythm Database [1], [2], [3].

Experiment Manager (Deep Learning Toolbox) enables you to train networks with different hyperparameter combinations and compare results as you search for the specifications that best solve your problem. Signal Processing Toolbox™ provides preconfigured templates that you can use to quickly set up signal processing experiments. For more information about Experiment Manager templates, see Quickly Set Up Experiment Using Preconfigured Template (Deep Learning Toolbox).

Open Experiment

First, open the template. In the Experiment Manager toolstrip, click New and select Project. In the dialog box, click Blank Project, scroll to the Signal Classification Experiments section, click Signal Classification Using Transfer Learning, and optionally specify a name for the project folder.


Built-in training experiments consist of a description, an initialization function, a table of hyperparameters, a setup function, a collection of metric functions to evaluate the results of the experiment, and a set of supporting files.

The Description field contains a textual description of the experiment. For this example, the description is:

Signal Classification Using Transfer Learning
The Initialization Function section specifies optional code to initialize the experiment before running any trials. In this experiment, the initialization function loads the data, extracts signal scalograms to use as features, and distributes the scalograms into training and validation data sets. To open the Initialization Function in the MATLAB® Editor, click Edit. The function returns a structure called output, which contains three fields:

  • dsTrain contains the training data.

  • dsValid contains the validation data.

  • classNames contains the names of the signal classes.

In this example, the initialization function has these sections:

  • Download and Load Training Data downloads the data files into the temporary directory for the system and creates a signal datastore that points to the data set.

  • Resize Data resizes the data to a specific input size, if required by your model or feature extraction methods.

  • Extract Features uses the continuous wavelet transform (CWT) to compute the scalogram of each signal. The CWT has the same time resolution as the input signal, so each scalogram has one column per signal sample.

  • Split Data splits the data into training and validation sets.

  • Read Data into Memory loads all your data into memory to speed up the training process, if your system has enough resources.

  • Supporting Functions includes functionality for feature extraction and data resizing.

The Hyperparameters section specifies the strategy and hyperparameter values to use. This experiment follows the Exhaustive Sweep strategy, in which Experiment Manager trains the network using every combination of hyperparameter values specified in the hyperparameter table. In this case, the experiment has three hyperparameters:

  • solver specifies the solver to use for training the neural network. For more information, see trainingOptions (Deep Learning Toolbox). The options are two stochastic solvers:

    • "adam" — Adaptive moment estimation (Adam)

    • "sgdm" — Stochastic gradient descent with momentum

  • initialLearnRate specifies the initial learning rate. If the learning rate is too low, then training can take a long time. If the learning rate is too high, then training might reach a suboptimal result or diverge. The options are 1e-4 and 1e-3.

  • pretrainedNetwork is an optional parameter that specifies the name of the pretrained network that performs the classification. To access this parameter, click Add and select Add From Suggested List. The options are:

    • "squeezenet" — SqueezeNet neural network [4].

    • "googlenet" — GoogLeNet neural network [5]. This option requires the Deep Learning Toolbox™ Model for GoogLeNet Network support package.
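Because the Exhaustive Sweep strategy trains one network per combination of hyperparameter values, the number of trials is the product of the number of values for each hyperparameter. The following sketch, which uses only base MATLAB and is not part of the template, lists the combinations for the values above: 2 solvers × 2 learning rates × 2 networks gives 8 trials. If you do not add pretrainedNetwork, the sweep has 4 trials.

```matlab
% Sketch: enumerate every hyperparameter combination that an exhaustive
% sweep visits. Values are copied from the hyperparameter table above.
solvers    = ["adam" "sgdm"];
learnRates = [1e-4 1e-3];
networks   = ["squeezenet" "googlenet"];

% ndgrid returns index grids that cover the full Cartesian product
[iS,iR,iN] = ndgrid(1:numel(solvers),1:numel(learnRates),1:numel(networks));

% Indexing a row vector returns a row vector, so transpose to get columns
trials = table(solvers(iS(:))',learnRates(iR(:))',networks(iN(:))', ...
    VariableNames=["solver","initialLearnRate","pretrainedNetwork"]);
disp(trials)
fprintf("Number of trials: %d\n",height(trials));
```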

The Setup Function section specifies a function that configures the training data, network architecture, and training options for the experiment. To open the Setup Function in the MATLAB Editor, click Edit. The input to the setup function is params, a structure with fields from the hyperparameter table. The InitializationFunctionOutput field of params contains the output of the initialization function. The function returns four outputs that the experiment passes to trainnet (Deep Learning Toolbox):

  • dsTrain — A datastore that contains the training signals and their corresponding labels

  • net — A dlnetwork object that specifies the pretrained neural network

  • lossFcn — A cross-entropy loss function for classification tasks, returned as "crossentropy"

  • options — A trainingOptions object that contains training algorithm details

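In this example, the setup function has these sections:

  • Load Training Data loads the training and validation data from the output of the initialization function.

  • Load Pretrained Network loads the pretrained network specified by params.pretrainedNetwork, or SqueezeNet if that hyperparameter is not defined.

  • Edit Network for Transfer Learning sets the learning rate factor of the last learnable layers of the network.

  • Format Network Input converts the scalograms to RGB images resized to the input size of the network.

  • Define Training Hyperparameters specifies the training options and the loss function.

  • Supporting Functions includes functionality for formatting the network input and modifying the pretrained network.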

The Post-Training Custom Metrics section specifies optional functions that Experiment Manager evaluates each time it finishes training the network. This experiment does not include metric functions.

The Supporting Files section enables you to identify, add, or remove files required by your experiment. This experiment does not use supporting files.

Run Experiment

When you run the experiment, Experiment Manager repeatedly trains the network defined by the setup function. Each trial uses a different combination of hyperparameter values. By default, Experiment Manager runs one trial at a time. If you have Parallel Computing Toolbox™, you can run multiple trials at the same time or offload your experiment as a batch job in a cluster.

  • To run one trial of the experiment at a time, on the Experiment Manager toolstrip, set Mode to Sequential and click Run.

  • To run multiple trials at the same time, set Mode to Simultaneous and click Run. For more information, see Run Experiments in Parallel (Deep Learning Toolbox).

  • To offload the experiment as a batch job, set Mode to Batch Sequential or Batch Simultaneous, specify your cluster and pool size, and click Run. For more information, see Offload Experiments as Batch Jobs to a Cluster (Deep Learning Toolbox).

A table of results displays the training accuracy, training loss, validation accuracy, and validation loss values for each trial.

Evaluate Results

To find the best result for your experiment, sort the table of results by validation accuracy:

  1. Point to the Validation Accuracy column.

  2. Click the triangle ▼ icon.

  3. Select Sort in Descending Order.

The trial with the highest validation accuracy appears at the top of the results table.

To record observations about the results of your experiment, add an annotation:

  1. In the results table, right-click the Validation Accuracy cell of the best trial.

  2. Select Add Annotation.

  3. In the Annotations pane, enter your observations in the text box.

For each experiment trial, Experiment Manager produces a training progress plot, a confusion matrix for training data, and a confusion matrix for validation data. To see one of those plots, select the trial and click the corresponding button in the Review Results gallery on the Experiment Manager toolstrip.

Close Experiment

In the Experiment Browser pane, right-click the experiment name and select Close Project. Experiment Manager closes the experiment and the results contained in the project.

Initialization Function

Download and Load Training Data

If you intend to place the data files in a folder different from the temporary directory, replace tempdir with your folder name. To limit the run time, this experiment uses only 40% of the data. To use more data, increase the value of dataRatio.

function output = ExperimentInitialization()

dataURL = "https://raw.githubusercontent.com/mathworks/physionet_ECG_data/main/ECGData.zip";
datasetFolder = fullfile(tempdir,"PhysionetECGData");
if ~exist(datasetFolder,"dir")
     mkdir(datasetFolder)
     zipFile = websave(fullfile(tempdir,"ECGData.zip"),dataURL);
     unzip(zipFile,datasetFolder)
     delete(zipFile)
end

ECGData = importdata(fullfile(datasetFolder,"ECGData.mat"));
data = ECGData.Data;
numSignals = size(data,1);
data = mat2cell(data,ones(numSignals,1),size(data,2));
labels = ECGData.Labels;
labels = categorical(labels);
classNames = categories(labels);
fs = 128;
ds = combine(signalDatastore(data,SampleRate=fs),arrayDatastore(labels));

rng("default")
dataRatio = 0.4;
idx = splitlabels(labels,dataRatio);
ds = subset(ds,idx{1});
labels = labels(idx{1});

Resize Data

If your model or feature extraction methods require no specific signal length, comment out this line.

ds = transform(ds,@helperResizeData);

Extract Features

Use the scalograms of the signals as features. To extract the scalograms, use the continuous wavelet transform implemented in cwt (Wavelet Toolbox). Extract data over the frequency range from 0.5 Hz to 50 Hz.

dsFeature = transform(ds, ...
    @(x)helperExtractFeatures(x, ...
    fs, ...
    FrequencyRange=[0.5,50]));
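The frequency restriction in helperExtractFeatures is a logical mask on the frequency vector returned by cwt, applied to the rows of the coefficient matrix. This base-MATLAB sketch shows that step in isolation. The frequency vector and coefficients are synthetic stand-ins for the outputs of cwt, chosen only to make the mask easy to check. Note that the comparison is strict, so a row at exactly 0.5 Hz is excluded.

```matlab
% Sketch: keep only the scalogram rows whose frequency lies strictly inside
% the band (0.5 Hz, 50 Hz), mirroring the logical mask in helperExtractFeatures.
% Assumption: f and s are synthetic stand-ins for the outputs of cwt.
f = [64; 40; 20; 10; 1; 0.5; 0.25];               % hypothetical frequencies (Hz)
s = complex(rand(numel(f),6),rand(numel(f),6));   % hypothetical CWT coefficients

band = [0.5 50];
keep = f > band(1) & f < band(2);   % strict inequalities, as in the helper
scalogram = abs(s(keep,:));         % magnitude of the retained rows
```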

Split Data

  • The experiment uses the training set to train the model. Use 80% of the data for the training set.

  • The experiment uses the validation set to evaluate the performance of the trained model during hyperparameter tuning. Use the remaining 20% of the data for the validation set.

  • If you intend to evaluate the generalization performance of your finalized model, set aside some data for testing as well.

trainValIdx = splitlabels(labels,[0.8,0.2]);
trainIdx = trainValIdx{1};
validIdx = trainValIdx{2};
dsTrain = subset(dsFeature,trainIdx);
dsValid = subset(dsFeature,validIdx);
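The splitlabels function produces a split that preserves the class proportions of labels in each subset. To illustrate the idea, here is a hand-written stratified 80/20 split in base MATLAB. It is a sketch under stated assumptions, not a reimplementation of splitlabels: the labels are hypothetical, and the real function can round counts and order indices differently.

```matlab
% Sketch of a stratified 80/20 split, assuming labels is a categorical vector.
% Illustration only; splitlabels may round and order indices differently.
rng("default")
labels = categorical([repmat("ARR",10,1); repmat("CHF",5,1); repmat("NSR",5,1)]); % hypothetical
p = 0.8;

trainIdx = [];
validIdx = [];
cats = categories(labels);
for k = 1:numel(cats)
    idxK = find(labels == cats{k});            % indices that belong to class k
    idxK = idxK(randperm(numel(idxK)));        % shuffle within the class
    nTrain = round(p*numel(idxK));             % per-class training count
    trainIdx = [trainIdx; idxK(1:nTrain)];     %#ok<AGROW>
    validIdx = [validIdx; idxK(nTrain+1:end)]; %#ok<AGROW>
end
```

Splitting each class separately keeps the smaller classes (CHF and NSR here) represented in both subsets, which a single random split of a small data set does not guarantee.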

Read Data into Memory

If the data does not fit into memory, set fitMemory to false.

fitMemory = true;
if fitMemory
    dataTrain = readall(dsTrain);
    XTrain = signalDatastore(dataTrain(:,1));
    TTrain = arrayDatastore(vertcat(dataTrain(:,2)),OutputType="same");
    dsTrain = combine(XTrain,TTrain);
    
    dataValid = readall(dsValid);
    XValid = signalDatastore(dataValid(:,1));
    TValid = arrayDatastore(vertcat(dataValid(:,2)),OutputType="same");
    dsValid = combine(XValid,TValid);
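end

% Return the data sets and class names to the setup function
output.dsTrain = dsTrain;
output.dsValid = dsValid;
output.classNames = classNames;
end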

Supporting Functions

helperExtractFeatures function.  Extract features used for training.

  • inputCell is a two-element cell array that contains an ECG signal vector and a categorical label.

  • outputCell is a two-element cell array that contains the ECG signal features and a categorical label.

function outputCell = helperExtractFeatures(inputCell,fs,Nvargs)
    arguments
        inputCell
        fs = 1
        Nvargs.FrequencyRange = [0,fs/2]
    end
    sigs = inputCell(:,1);
    features = cell(size(sigs));
    for idx = 1:length(sigs)
        [s,f] = cwt(sigs{idx},fs);
        findices = f > Nvargs.FrequencyRange(1) & f < Nvargs.FrequencyRange(2);
        features{idx} = abs(s(findices,:));
    end
    outputCell = [features inputCell(:,2)];
end

helperResizeData function.  Pad or truncate an input ECG signal to targetLength samples.

  • inputCell is a two-element cell array that contains an ECG signal and a label.

  • outputCell is a two-element cell array that contains the resized, normalized signal and its label.

Pad or truncate the data to the target length and normalize it.

function outputCell = helperResizeData(inputCell)
    targetLength = 60000;
    sig = padsequences(inputCell(:,1),2, ...
        Length=targetLength, ...
        PaddingValue="symmetric");
    sig = normalize(sig);
    outputCell = [{sig} inputCell(2)];
end
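The resize-and-normalize step in helperResizeData can be illustrated without padsequences. This sketch deliberately swaps in plain zero padding at the end of the signal to keep it short, whereas helperResizeData uses PaddingValue="symmetric". The target length and signals are hypothetical. After padding or truncation, normalize applies a z-score so the result has zero mean and unit standard deviation.

```matlab
% Sketch: force a row-vector signal to L samples (truncate or zero-pad at
% the end), then z-score it with normalize.
% Assumption: zero padding replaces the "symmetric" padding used by
% padsequences in helperResizeData.
resizeTo = @(x,L) [x(1:min(numel(x),L)) zeros(1,L-min(numel(x),L))];

xShort = [1 2 3 4 5];              % hypothetical signal shorter than the target
y = normalize(resizeTo(xShort,8)); % padded to 8 samples, then z-scored

xLong = 1:10;                      % hypothetical signal longer than the target
z = resizeTo(xLong,4);             % truncated to the first 4 samples
```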

Setup Function

Load Training Data

Use the InitializationFunctionOutput field of params to access the data returned by the initialization function.

function [dsTrain,net,lossFcn,options] = Experiment_setup(params)

initData = params.InitializationFunctionOutput;
dsTrain = initData.dsTrain;
dsValid = initData.dsValid;

Load Pretrained Network

Load SqueezeNet unless specified otherwise.

numClasses = numel(initData.classNames);
if ~isfield(params, "pretrainedNetwork")
    params.pretrainedNetwork = "squeezenet";
end
net = imagePretrainedNetwork(params.pretrainedNetwork,NumClasses=numClasses);

Edit Network for Transfer Learning

In transfer learning, a common approach is to freeze most of the pretrained layers and train only the last few. This example instead keeps all layers learnable and multiplies the learning rate of the parameters in the last NumberofLearnableLayers layers by LearnRateFactor, so that the newly initialized final layer adapts faster than the pretrained layers.

net = helperModifyPretrainedNetwork(net,NumberofLearnableLayers=1,LearnRateFactor=10);
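The learn rate factor scales the global learning rate for the affected parameters, so the last layer trains with InitialLearnRate × LearnRateFactor. This short calculation applies that arithmetic to the two initialLearnRate values in the hyperparameter table with the factor of 10 used above.

```matlab
% Worked arithmetic: the effective learning rate of a parameter is the
% global learning rate multiplied by its learn rate factor.
initialLearnRates = [1e-4 1e-3];  % values from the hyperparameter table
learnRateFactor = 10;             % value passed to helperModifyPretrainedNetwork
effectiveRates = initialLearnRates*learnRateFactor;
fprintf("Global %.0e -> last layer %.0e\n",[initialLearnRates; effectiveRates]);
```

With initialLearnRate set to 1e-3, the last layer therefore trains at 1e-2, which is worth keeping in mind when interpreting trials that diverge.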

Format Network Input

The pretrained networks expect three-channel RGB images. Rescale each scalogram to integer color indices, use the ind2rgb function with a jet colormap to convert the resulting indexed image to an RGB image, and resize the image to the input size required by the network.

expectedInputSize = net.Layers(1).InputSize;
dsTrain = helperFormatInput(dsTrain,expectedInputSize);
dsValid = helperFormatInput(dsValid,expectedInputSize);
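The conversion inside helperFormatInput can be tried on its own. In this sketch a random matrix stands in for a scalogram, and the number of index levels is matched to the colormap length (256). As far as we know, ind2rgb clips out-of-range floating-point indices to the first or last colormap entry, so using more index levels than colors would collapse the largest values onto a single color. The imresize step is omitted here because it requires Image Processing Toolbox.

```matlab
% Sketch: turn a nonnegative scalogram into a three-channel RGB image.
% Assumption: a random matrix stands in for a real scalogram.
scalogram = rand(40,100);              % hypothetical: rows = frequencies, cols = time
cmap = jet(256);                       % one color per index level
idx = round(rescale(scalogram,1,256)); % integer indices 1..256
rgb = ind2rgb(idx,cmap);               % rows-by-cols-by-3 array, values in [0,1]
```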

Define Training Hyperparameters

  • To set the batch size of the data for training, set MiniBatchSize to 20.

  • To specify the optimizer, use params.solver.

  • To specify the initial learning rate, use params.initialLearnRate.

  • To read and preprocess the transformed datastore in serial, set PreprocessingEnvironment to "serial". To use a parallel pool instead, set PreprocessingEnvironment to "parallel".

  • For classification tasks, use cross-entropy loss.

miniBatchSize = 20;
options = trainingOptions(params.solver, ...
    InitialLearnRate=params.initialLearnRate, ...
    MaxEpochs=30, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false, ...
    ValidationData=dsValid, ...
    ExecutionEnvironment="auto", ...
    PreprocessingEnvironment="serial");

lossFcn = "crossentropy";
end
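For single-label classification, cross-entropy loss compares the predicted class probabilities with one-hot targets. This base-MATLAB sketch computes the standard form of the loss for a hypothetical mini-batch of two observations and three classes, averaged over observations. It illustrates the formula; it is not the internal implementation that trainnet uses.

```matlab
% Sketch of categorical cross-entropy for one mini-batch, assuming Y holds
% softmax probabilities (classes x observations) and T holds one-hot targets.
Y = [0.7 0.2; 0.2 0.5; 0.1 0.3];   % hypothetical network outputs, 2 observations
T = [1 0; 0 1; 0 0];               % hypothetical one-hot targets
loss = -sum(T.*log(Y),"all")/size(Y,2);   % average over observations
```

Only the probability assigned to the true class contributes, so the loss here is -(log(0.7) + log(0.5))/2.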

Supporting Functions

helperFormatInput function.  Convert the scalogram features to RGB images of the size expected by the network.

  • inputDs is a combined datastore that contains ECG scalogram features and categorical labels.

  • outputDs is a combined datastore that contains ECG scalograms formatted as resized images and also categorical labels.

function outputDs = helperFormatInput(inputDs,ExpectedFeatureSize)
featuresDs = inputDs.UnderlyingDatastores{1};
featuresDsNew = transform(featuresDs, ...
    @(x){imresize(ind2rgb(round(rescale(x,1,256)),jet(256)),ExpectedFeatureSize(1:2))});
if isprop(featuresDs,"File")
    outputDs = combine(featuresDsNew,inputDs.UnderlyingDatastores{2});
else
    outputDs = combine(signalDatastore(readall(featuresDsNew)),inputDs.UnderlyingDatastores{2});
end
end

helperModifyPretrainedNetwork function.  Adapt the pretrained network for transfer learning.

  • net is a pretrained network.

  • NumberofLearnableLayers specifies the number of final learnable layers whose learning rate factor is set to LearnRateFactor.

  • LearnRateFactor specifies the learning rate factor for the learnable parameters of those layers.

function net = helperModifyPretrainedNetwork(net,Nvargs)
arguments
    net
    Nvargs.NumberofLearnableLayers = 1
    Nvargs.LearnRateFactor = 10
end

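% Each layer can own several rows of net.Learnables (for example, Weights
% and Bias), so count distinct layer names rather than rows
layerNames = unique(net.Learnables.Layer,"stable");
for ii = 1:Nvargs.NumberofLearnableLayers
    learnableIndex = find(net.Learnables.Layer==layerNames(end+1-ii))';
    for idx = learnableIndex
        net = setLearnRateFactor(net, ...
            net.Learnables.Layer(idx), ...
            net.Learnables.Parameter(idx), ...
            Nvargs.LearnRateFactor);
    end
end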
end
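Because each layer usually contributes more than one row to net.Learnables, selecting the last N layers means grouping rows by layer name. This sketch shows that selection on a hand-built table with hypothetical layer names, standing in for net.Learnables so that no network needs to be loaded.

```matlab
% Sketch: find the Learnables rows that belong to the last N distinct layers.
% Assumption: a hand-built table with hypothetical names stands in for net.Learnables.
Learnables = table(["conv1";"conv1";"conv10";"conv10"], ...
                   ["Weights";"Bias";"Weights";"Bias"], ...
                   VariableNames=["Layer","Parameter"]);
N = 1;
layerNames = unique(Learnables.Layer,"stable");  % distinct layers, original order
lastLayers = layerNames(end-N+1:end);            % the last N layers
rows = find(ismember(Learnables.Layer,lastLayers));
```

With N = 1 both the Weights and Bias rows of the final layer are selected, and with N = 2 every row in this small table is selected.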

References

[1] Goldberger, Ary L., Luis A. N. Amaral, Leon Glass, Jeffery M. Hausdorff, Plamen Ch. Ivanov, Roger G. Mark, Joseph E. Mietus, George B. Moody, Chung-Kang Peng, and H. Eugene Stanley. "PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals." Circulation. Vol. 101, No. 23, 2000, pp. e215–e220. doi: 10.1161/01.CIR.101.23.e215.

[2] Moody, G. B., and R. G. Mark. "The impact of the MIT-BIH Arrhythmia Database." IEEE Engineering in Medicine and Biology Magazine. Vol. 20, No. 3, May–June 2001, pp. 45–50. (PMID: 11446209).

[3] Baim, D. S., W. S. Colucci, E. S. Monrad, H. S. Smith, R. F. Wright, A. Lanoue, D. F. Gauthier, B. J. Ransil, W. Grossman, and E. Braunwald. "Survival of patients with severe congestive heart failure treated with oral milrinone." Journal of the American College of Cardiology. Vol. 7, No. 3, 1986, pp. 661–670.

[4] Iandola, Forrest N., Song Han, Matthew W. Moskewicz, Khalid Ashraf, William J. Dally, and Kurt Keutzer. "SqueezeNet: AlexNet-Level Accuracy with 50x Fewer Parameters and <0.5MB Model Size." arXiv, November 4, 2016. https://doi.org/10.48550/arXiv.1602.07360.

[5] Szegedy, Christian, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. "Going Deeper with Convolutions." arXiv, September 16, 2014. https://doi.org/10.48550/arXiv.1409.4842.
