Signal Segmentation by Sweeping Hyperparameters
This example shows how to use a preconfigured template in the Experiment Manager app to set up a signal segmentation experiment involving time-frequency features. The goal of the experiment is to train a network to divide electrocardiogram (ECG) signals into P-wave, QRS-complex, and T-wave beat morphologies to assess the overall health of the heart and look for abnormalities. The training and testing data consists of 105 two-channel ECG signals labeled by an automated expert system and stored in 210 MAT files [1], [2].
Experiment Manager (Deep Learning Toolbox) enables you to train networks with different hyperparameter combinations and compare results as you search for the specifications that best solve your problem. Signal Processing Toolbox™ provides preconfigured templates that you can use to quickly set up signal processing experiments. For more information about Experiment Manager templates, see Quickly Set Up Experiment Using Preconfigured Template (Deep Learning Toolbox).
Open Experiment
First, open the template. In the Experiment Manager toolstrip, click New and select Project. In the dialog box, click Blank Project, scroll to the Signal Classification Experiments section, click Signal Segmentation by Sweeping Hyperparameters, and optionally specify a name for the project folder.
Built-in training experiments consist of a description, an initialization function, a table of hyperparameters, a setup function, a collection of metric functions to evaluate the results of the experiment, and a set of supporting files.
The Description field contains a textual description of the experiment. For this example, the description is:
Signal Segmentation by Sweeping Hyperparameters
The Hyperparameters section specifies the strategy and hyperparameter values to use. This experiment follows the Exhaustive Sweep strategy, in which Experiment Manager trains the network using every combination of hyperparameter values specified in the hyperparameter table. In this case, the experiment has two hyperparameters:
featureType specifies the feature to extract from each signal and use to train the network. The options are:
- "none": Use raw data to train the network.
- "fsst": Extract the Fourier synchrosqueezed transform of each signal and use the transforms to train the network.
- "cwt": Extract the continuous wavelet transform of each signal and use the transforms to train the network.
initialLearnRate specifies the initial learning rate. If the learning rate is too low, then training can take a long time. If the learning rate is too high, then training might reach a suboptimal result or diverge. The options are 1e-3 and 1e-2.
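With the Exhaustive Sweep strategy, the number of trials is the product of the numbers of values, here 3 × 2 = 6. The following is a minimal sketch in base MATLAB, not part of the template, that enumerates those combinations and builds one params structure per trial, assuming that params holds exactly the two fields named in the hyperparameter table. The names featureTypes, learnRates, and trialParams are illustrative, and the trial order here need not match the order Experiment Manager uses.

```matlab
% Hyperparameter values from the hyperparameter table
featureTypes = ["none","fsst","cwt"];
learnRates = [1e-3 1e-2];

% Exhaustive sweep: every combination of the two hyperparameters
[iFeat,iRate] = ndgrid(1:numel(featureTypes),1:numel(learnRates));
numTrials = numel(iFeat);

% One params structure per trial (illustrative, not Experiment Manager internals)
trialParams = repmat(struct("featureType","","initialLearnRate",0),numTrials,1);
for k = 1:numTrials
    trialParams(k).featureType = featureTypes(iFeat(k));
    trialParams(k).initialLearnRate = learnRates(iRate(k));
end

disp(numTrials)   % 6
```

Adding a value to either hyperparameter multiplies the number of trials, which is why the sweep stays small here.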
The Setup Function section specifies a function that configures the training data, network architecture, and training options for the experiment. To open the setup function in the MATLAB® Editor, click Edit. The input to the setup function is params, a structure with fields from the hyperparameter table. The function returns four outputs that the experiment passes to trainnet (Deep Learning Toolbox):
- dsTrain: A datastore that contains the training signals and their corresponding labels
- layers: A layer array that defines the neural network architecture
- lossFcn: A cross-entropy loss function for segmentation tasks, returned as "crossentropy"
- options: A trainingOptions object that contains training algorithm details
In this example, the setup function has these sections:
- Download and Load Training Data downloads the data files into the temporary directory for the system, creates a signal datastore that points to the data set, and selects a random subset of the files.
- Format Labels restructures the data labels so they are in the form of a signal mask, which is how the network expects them for a segmentation task. A signal mask is a series of categorical labels with the same length as the signal. Use this section if your labels are arranged differently, for example, as a region-of-interest (ROI) table.
- Resize Data resizes the data to a specific input size, if required by your model or feature extraction methods.
- Extract Features uses params.featureType to specify the feature extraction function. The continuous wavelet transform and the Fourier synchrosqueezed transform both have the same time resolution as the input signals.
- Split Data splits the data into training and validation sets.
- Read Data into Memory loads all your data into memory to speed up the training process, if your system has enough resources.
- Define Network Architecture creates a neural network using an array of layers that includes a bidirectional long short-term memory (BiLSTM) layer and a fully connected layer. For more information, see Example Deep Learning Networks Architectures (Deep Learning Toolbox).
- Define Training Hyperparameters calls trainingOptions (Deep Learning Toolbox) to set the hyperparameters to use when training the network.
- Supporting Functions includes functionality for feature extraction, label formatting, and data resizing.
The Post-Training Custom Metrics section specifies optional functions that Experiment Manager evaluates each time it finishes training the network. This experiment does not include metric functions.
The Supporting Files section enables you to identify, add, or remove files that your experiment requires. This experiment does not use supporting files.
Run Experiment
When you run the experiment, Experiment Manager repeatedly trains the network defined by the setup function. Each trial uses a different combination of hyperparameter values. By default, Experiment Manager runs one trial at a time. If you have Parallel Computing Toolbox™, you can run multiple trials at the same time or offload your experiment as a batch job in a cluster.
- To run one trial of the experiment at a time, on the Experiment Manager toolstrip, set Mode to Sequential and click Run.
- To run multiple trials at the same time, set Mode to Simultaneous and click Run. For more information, see Run Experiments in Parallel (Deep Learning Toolbox).
- To offload the experiment as a batch job, set Mode to Batch Sequential or Batch Simultaneous, specify your cluster and pool size, and click Run. For more information, see Offload Experiments as Batch Jobs to a Cluster (Deep Learning Toolbox).
A table of results displays the training accuracy, training loss, validation accuracy, and validation loss values for each trial.
Evaluate Results
To find the best result for your experiment, sort the table of results by validation accuracy:
Point to the Validation Accuracy column.
Click the triangle ▼ icon.
Select Sort in Descending Order.
The trial with the highest validation accuracy appears at the top of the results table.
To record observations about the results of your experiment, add an annotation:
In the results table, right-click the Validation Accuracy cell of the best trial.
Select Add Annotation.
In the Annotations pane, enter your observations in the text box.
For each experiment trial, Experiment Manager produces a training progress plot, a confusion matrix for training data, and a confusion matrix for validation data. To see one of those plots, select the trial and click the corresponding button in the Review Results gallery on the Experiment Manager toolstrip.
Close Experiment
In the Experiment Browser pane, right-click the project name and select Close Project. Experiment Manager closes the experiment and the results contained in the project.
Setup Function
Download and Load Training Data
If you intend to place the data files in a folder different from the temporary directory, replace tempdir with your folder name. To limit the run time, this experiment uses only 40% of the data. To use more data, increase the value of dataRatio.
function [dsTrain,layers,lossFcn,options] = Experiment_setup(params)
% Download and extract the data set if it is not already present
dataURL = "https://www.mathworks.com/supportfiles/SPT/data/QTDatabaseECGData1.zip";
datasetFolder = fullfile(tempdir,"QTDataset");
if ~exist(datasetFolder,"dir")
    % websave needs a file name, not a folder, as its first argument
    zipFile = websave(fullfile(tempdir,"QTDatabaseECGData.zip"),dataURL);
    unzip(zipFile,tempdir)
    delete(zipFile)
end
ds = signalDatastore(datasetFolder,SignalVariableNames=["ecgSignal","signalRegionLabels"]);

% Keep a random fraction of the files to limit the run time
rng("default")
dataRatio = 0.4;
numSignals = numel(ds.Files);
dsSubset = subset(ds,randperm(numSignals,ceil(dataRatio*numSignals)));
numSignals = numel(dsSubset.Files);
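As a standalone check, separate from the setup function, this sketch reproduces the subset arithmetic for the 210 MAT files in the data set. randperm(n,k) returns k distinct indices, so no file is selected twice.

```matlab
rng("default")                          % reproducible draw, as in the setup function
numSignals = 210;                       % number of MAT files in the data set
dataRatio = 0.4;
numSubset = ceil(dataRatio*numSignals); % number of files kept
idx = randperm(numSignals,numSubset);   % distinct file indices, no repeats
fprintf("Using %d of %d files\n",numel(idx),numSignals)
```

With dataRatio = 0.4, the experiment keeps 84 of the 210 files.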
Format Labels
If your labels are already in the form of a signal mask, comment out this line.
dsFormat = transform(dsSubset,@helperFormatLabel);
Resize Data
If your model or feature extraction methods require no specific signal length, comment out this line.
dsResize = transform(dsFormat,@helperResizeData);
Extract Features
In segmentation tasks, you must create features that have the same time resolution as the input signals. The Fourier synchrosqueezed transform and the continuous wavelet transform have this characteristic. This experiment explores both possibilities.
dsFeature = transform(dsResize,@(x)helperExtractFeatures(x,FeatureType=params.featureType));
% Read one sample to determine the number of feature channels
feature = preview(dsFeature);
featureSize = size(feature{1});
Split Data
The experiment uses the training set to train the model. Use 80% of the data for the training set.
The experiment uses the validation set to evaluate the performance of the trained model during hyperparameter tuning. Use the remaining 20% of the data for the validation set.
If you intend to evaluate the generalization performance of your finalized model, set aside some data for testing as well.
% Set the test ratio explicitly to 0 so that every index goes to training or validation
[trainIdx,validIdx] = dividerand(numSignals,0.8,0.2,0);
dsTrain = subset(dsFeature,trainIdx);
dsValid = subset(dsFeature,validIdx);
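dividerand performs a random split of indices. For illustration only, an equivalent 80/20 split can be sketched in base MATLAB with randperm. This swaps in a different technique from the one the template calls, and the rounding of the training count is an assumption that may differ from dividerand by one index. The value numSignals = 84 assumes the 40% subset of 210 files.

```matlab
rng("default")
numSignals = 84;                    % assumed size of the 40% subset
perm = randperm(numSignals);        % random ordering of all file indices
numTrain = round(0.8*numSignals);   % training share (rounding is an assumption)
trainIdx = perm(1:numTrain);
validIdx = perm(numTrain+1:end);    % remaining indices for validation
```

Because the indices select files rather than segments, all segments cut from one file stay in the same set.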
Read Data into Memory
If the data does not fit into memory, set fitMemory to false.
fitMemory = true;
if fitMemory
    % Read the training set into memory and wrap it in in-memory datastores
    dataTrain = readall(dsTrain);
    XTrain = signalDatastore(dataTrain(:,1));
    TTrain = arrayDatastore(dataTrain(:,2),OutputType="same");
    dsTrain = combine(XTrain,TTrain);

    % Do the same for the validation set
    dataValid = readall(dsValid);
    XValid = signalDatastore(dataValid(:,1));
    TValid = arrayDatastore(dataValid(:,2),OutputType="same");
    dsValid = combine(XValid,TValid);
end
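The in-memory branch wraps the labels in arrayDatastore with OutputType="same" so that each read returns the stored mask in a single cell rather than a doubly nested one. A toy sketch of that pattern, assuming made-up 8-sample signals and masks and using arrayDatastore for the signals as well (the template uses signalDatastore there):

```matlab
% Toy stand-ins for the readall output: two signals and their label masks
sigs = {rand(1,8); rand(1,8)};
masks = {categorical(repmat("n/a",1,8)); categorical(repmat("P",1,8))};

% With a cell array input and OutputType="same", each read returns a
% 1-by-1 cell holding one stored element
XDs = arrayDatastore(sigs,OutputType="same");
TDs = arrayDatastore(masks,OutputType="same");
dsMem = combine(XDs,TDs);

obs = read(dsMem);    % one observation: {signal, mask}
```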
Define Network Architecture
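Use a bidirectional long short-term memory (BiLSTM) layer that outputs a full sequence, followed by a fully connected layer whose output size equals the number of classes and a softmax layer. The label masks have four classes: "P", "QRS", "T", and "n/a". For segmentation tasks, use cross-entropy loss.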
numClasses = 4;
layers = [
    sequenceInputLayer(featureSize(1))
    bilstmLayer(50,OutputMode="sequence")
    fullyConnectedLayer(numClasses)
    softmaxLayer];
lossFcn = "crossentropy";
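The "crossentropy" loss compares the softmax scores with the one-hot target at every time step. A base MATLAB sketch of that per-time-step computation, with hypothetical scores for four classes over three time steps. The averaging over time steps is an assumption for illustration; trainnet may normalize the loss differently.

```matlab
% Hypothetical softmax scores: 4 classes (rows) by 3 time steps (columns)
Y = [0.70 0.10 0.20
     0.10 0.60 0.20
     0.10 0.20 0.50
     0.10 0.10 0.10];
% One-hot targets: classes 1, 2, and 3 at the three time steps
T = [1 0 0
     0 1 0
     0 0 1
     0 0 0];
perStep = -sum(T.*log(Y),1);   % cross-entropy at each time step
loss = mean(perStep);          % averaged over time steps (assumed normalization)
```

Because every time step contributes a term, the network is penalized for each misclassified sample, not just once per sequence.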
Define Training Hyperparameters
- To set the batch size of the data for training, set MiniBatchSize to 20.
- To use the adaptive moment estimation (Adam) optimizer, specify the "adam" option.
- Because the training data has sequences with rows and columns corresponding to channels and time steps, respectively, specify the input data format as "CTB" (channel, time, batch).
- To specify the initial learning rate, use params.initialLearnRate.
- To halve the learning rate every 10 epochs, set LearnRateSchedule to "piecewise", LearnRateDropFactor to 0.5, and LearnRateDropPeriod to 10. Training runs for at most 15 epochs.
- To use the parallel pool to read the transformed datastore, set PreprocessingEnvironment to "parallel". This experiment sets it to "serial".
miniBatchSize = 20;
options = trainingOptions("adam", ...
    InputDataFormats="CTB", ...
    InitialLearnRate=params.initialLearnRate, ...
    MaxEpochs=15, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.5, ...
    LearnRateDropPeriod=10, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=false, ...
    ValidationData=dsValid, ...
    ExecutionEnvironment="auto", ...
    PreprocessingEnvironment="serial");
end
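With the piecewise schedule in these options, the learning rate stays at its initial value for the first 10 epochs and is halved for epochs 11 through 15. A sketch of the resulting per-epoch rate, assuming the drop is applied once every LearnRateDropPeriod epochs as the piecewise schedule describes:

```matlab
initialLearnRate = 1e-3;   % one of the two swept values
dropFactor = 0.5;          % LearnRateDropFactor
dropPeriod = 10;           % LearnRateDropPeriod
maxEpochs = 15;            % MaxEpochs

epochs = 1:maxEpochs;
% The rate is multiplied by dropFactor after every dropPeriod completed epochs
lr = initialLearnRate * dropFactor.^floor((epochs-1)/dropPeriod);
disp(lr([1 10 11 15]))
```

For the trial with initialLearnRate = 1e-2, the same schedule gives 5e-3 for the last five epochs.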
Supporting Functions
helperExtractFeatures function. Extract the features to use for training the network.
inputCell is a two-column cell array that contains the ECG signal segments and their categorical label masks.
outputCell is a two-column cell array that contains the features extracted from each segment and the corresponding categorical label masks, which have the same length as the segments.
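Compute time-frequency maps by using fsst or cwt (Wavelet Toolbox). Keep only the frequency bins between 0.5 Hz and 40 Hz, and stack the real and imaginary parts of each transform so that the network receives real-valued input. For "none", pass the raw signal segments through unchanged.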
function outputCell = helperExtractFeatures(inputCell,Nvargs)
arguments
    inputCell
    Nvargs.FeatureType = "fsst"
end
Fs = 250;   % sampling rate (Hz)
sigs = inputCell(:,1);
features = cell(size(sigs));
switch Nvargs.FeatureType
    case "fsst"
        for idx = 1:length(sigs)
            [s,f] = fsst(sigs{idx},Fs,kaiser(128));
            f_indices = (f > 0.5) & (f < 40);
            % Stack real and imaginary parts as separate channels
            features{idx} = [real(s(f_indices,:)); imag(s(f_indices,:))];
        end
    case "cwt"
        for idx = 1:length(sigs)
            [s,f] = cwt(sigs{idx},Fs);
            f_indices = (f > 0.5) & (f < 40);
            features{idx} = [real(s(f_indices,:)); imag(s(f_indices,:))];
        end
    case "none"
        features = sigs;
end
outputCell = [features inputCell(:,2)];
end
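The number of feature channels, featureSize(1) in the setup function, follows from the frequency band. A sketch of that count for the "fsst" option, assuming fsst returns a one-sided grid of winLen/2+1 bins spaced Fs/winLen apart for a real signal analyzed with a 128-sample window. The preview call in the setup function remains the authoritative source. The "cwt" count depends on the default cwt scales and is not computed here, and "none" gives one channel.

```matlab
Fs = 250;                     % sampling rate used in helperExtractFeatures
winLen = 128;                 % length of the kaiser(128) window
% Assumed one-sided fsst frequency grid: winLen/2+1 bins
f = (0:winLen/2)'*Fs/winLen;
keep = (f > 0.5) & (f < 40);  % same band as helperExtractFeatures
numBins = nnz(keep);          % bins from about 1.95 Hz to 39.06 Hz
numChannels = 2*numBins;      % real and imaginary parts stacked
```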
helperFormatLabel function. Convert region labels to a mask of labels of size equal to the size of the input ECG signal.
inputCell is a two-element cell array that contains an ECG signal vector and a table of region labels.
outputCell is a two-element cell array that contains the ECG signal vector and a categorical label vector mask of the same length as the signal.
When extracting categorical masks, resolve overlapping regions by giving priority to QRS complexes, then P waves, then T waves. Label samples that belong to no region as "n/a".
function outputCell = helperFormatLabel(inputCell)
sig = inputCell{1};
roiTable = inputCell{2};
L = length(sig);
M = signalMask(roiTable);
% Resolve overlapping regions in the order QRS, P, T
priorityOrder = ["QRS";"P";"T"];
[~,priorityIdx] = ismember(priorityOrder,M.Categories);
mask = catmask(M,L,OverlapAction="prioritizeByList",PriorityList=priorityIdx);
% Samples outside every region get the "n/a" label
mask(ismissing(mask)) = "n/a";
outputCell = {sig,mask};
end
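The prioritize-by-list rule resolves overlaps so that the higher-priority label wins. A base MATLAB sketch of the same rule, not the catmask implementation, using hypothetical region boundaries: painting labels from lowest to highest priority lets QRS overwrite P, and P overwrite T, wherever they overlap.

```matlab
L = 12;                                   % hypothetical signal length
% Hypothetical regions: label and [start end] sample indices
regions = {"T",  [1 5];
           "P",  [4 7];
           "QRS",[6 9]};
priorityOrder = ["QRS";"P";"T"];          % highest priority first

mask = repmat("n/a",L,1);                 % samples in no region stay "n/a"
% Paint lowest priority first so higher-priority labels overwrite overlaps
for p = numel(priorityOrder):-1:1
    row = find([regions{:,1}] == priorityOrder(p));
    lim = regions{row,2};
    mask(lim(1):lim(2)) = priorityOrder(p);
end
mask = categorical(mask);
```

Samples 1 to 3 end up as T, 4 and 5 as P, 6 to 9 as QRS, and 10 to 12 as n/a.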
helperResizeData function. Break the input ECG signal and label mask into targetLength-sample segments, where targetLength is 5000 samples (20 seconds at the 250 Hz sampling rate). Truncate the signal and mask so that each has an integer number of segments.
inputCell is a two-element cell array that contains an ECG signal and a label mask.
outputCell is a two-column cell array that contains all the 5000-sample signal segments and label masks that could be generated from the input data.
function outputCell = helperResizeData(inputCell)
targetLength = 5000;
sig = inputCell{1};
mask = inputCell{2};
% Drop trailing samples that do not fill a whole segment
numChunks = floor(size(sig,1)/targetLength);
sig = sig(1:(numChunks*targetLength));
mask = mask(1:(numChunks*targetLength));
% Reshape column-wise, then transpose so that each row is one segment
sigOut = reshape(sig,targetLength,numChunks)';
sigOut = num2cell(sigOut,2);
lblOut = reshape(mask,targetLength,numChunks)';
lblOut = num2cell(lblOut,2);
outputCell = [sigOut lblOut];
end
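A toy run of the same reshape logic, with a hypothetical targetLength of 4 samples instead of 5000 so that the output is readable. reshape fills columns first, so each column holds consecutive samples; the transpose turns each column into a row, which matches the row-wise num2cell call in helperResizeData.

```matlab
targetLength = 4;                                 % 5000 in helperResizeData
sig = (1:10)';                                    % 10-sample column signal
numChunks = floor(size(sig,1)/targetLength);      % 2 full segments
sig = sig(1:numChunks*targetLength);              % drop the trailing 2 samples
segments = reshape(sig,targetLength,numChunks)';  % one segment per row
disp(segments)
%      1     2     3     4
%      5     6     7     8
```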
References
[1] Goldberger, Ary L., Luis A. N. Amaral, Leon Glass, Jeffery M. Hausdorff, Plamen Ch. Ivanov, Roger G. Mark, Joseph E. Mietus, George B. Moody, Chung-Kang Peng, and H. Eugene Stanley. "PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals." Circulation. Vol. 101, No. 23, 2000, pp. e215–e220. doi: 10.1161/01.CIR.101.23.e215.
[2] Laguna, Pablo, Roger G. Mark, Ary L. Goldberger, and George B. Moody. "A Database for Evaluation of Algorithms for Measurement of QT and Other Waveform Intervals in the ECG." Computers in Cardiology. Vol. 24, 1997, pp. 673–676.
See Also
Apps
- Experiment Manager (Deep Learning Toolbox) | Signal Labeler
Functions
- cwt (Wavelet Toolbox) | fsst | trainnet (Deep Learning Toolbox) | trainingOptions (Deep Learning Toolbox)