Train Voice Activity Detection in Noise Model Using Deep Learning

This example shows how to detect regions of speech in a low signal-to-noise environment using deep learning. You train a bidirectional long short-term memory (BiLSTM) network from scratch to perform voice activity detection (VAD) and compare that network to a pretrained deep learning-based VAD. To explore the model trained from scratch in this example, see Voice Activity Detection in Noise Using Deep Learning. To use an off-the-shelf deep learning-based VAD, see detectspeechnn.


Voice activity detection is an essential component of many audio systems, such as automatic speech recognition, speaker recognition, and audio conferencing. Voice activity detection can be especially challenging in low signal-to-noise ratio (SNR) situations, where noise obscures the speech.

For reproducibility, set the random seed to default.

rng default

In high SNR scenarios, traditional speech detection algorithms perform adequately. Read in an audio file that consists of words spoken with pauses between them, resample it to the 16 kHz working sample rate, and listen to it.

fs = 16e3;
[speech,fileFs] = audioread("MaleVolumeUp-16-mono-6secs.ogg");
speech = resample(speech,fs,fileFs);
sound(speech,fs)

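Use the detectSpeech function to locate regions of speech. When called without output arguments, detectSpeech plots the detected regions.

detectSpeech(speech,fs)

The detectSpeech function correctly identifies all regions of speech.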


Load two noise signals and resample them to the audio sample rate.

[noise200,fileFs200] = audioread("WashingMachine-16-8-mono-200secs.mp3");
[noise1000,fileFs1000] = audioread("WashingMachine-16-8-mono-1000secs.mp3");
noise200 = resample(noise200,fs,fileFs200);
noise1000 = resample(noise1000,fs,fileFs1000);

Use the supporting function mixSNR to corrupt the clean speech signal with washing machine noise at a desired SNR level in dB. Listen to the corrupted audio.

SNR = -10;
noisySpeech = mixSNR(speech,noise200,SNR);
sound(noisySpeech,fs)
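The scaling inside mixSNR follows from the definition SNR = 20*log10(norm(s)/norm(n)) in dB. The following minimal sketch uses random vectors as hypothetical stand-ins for speech and noise, rescales the noise to hit a target SNR, and then measures the SNR that results. It uses only base MATLAB.

```matlab
% Hypothetical stand-ins for a speech signal and a noise recording
rng(0)
signal = randn(16000,1);
noise = 0.1*randn(16000,1);
targetSNR = -10;   % desired SNR in dB

% Choose the noise norm so that 20*log10(norm(signal)/norm(scaledNoise)) = targetSNR
goalNoiseNorm = norm(signal)/10^(targetSNR/20);
scaledNoise = noise*(goalNoiseNorm/norm(noise));
noisySignal = signal + scaledNoise;

% Measure the SNR that was actually achieved
achievedSNR = 20*log10(norm(signal)/norm(scaledNoise));
fprintf("Achieved SNR: %.2f dB\n",achievedSNR)
```

At -10 dB, the noise norm is 10^(10/20), or about 3.16, times the signal norm. Because mixSNR afterwards peak-normalizes the mixture, which scales speech and noise by the same factor, that final step does not change the SNR.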


Call detectSpeech on the noisy speech signal.

detectSpeech(noisySpeech,fs)

The function fails to detect the speech regions given the very low SNR. The remainder of the example walks through training and evaluating deep learning-based VAD networks that can perform well under low SNR.


Download and Prepare Data

Download and extract the Google Speech Commands Dataset [1].

downloadFolder = matlab.internal.examples.downloadSupportFile("audio","");
dataFolder = tempdir;
unzip(downloadFolder,dataFolder)
dataset = fullfile(dataFolder,"google_speech");

Create audioDatastore objects to point to the training and validation data sets.

adsTrain = audioDatastore(fullfile(dataset,"train"),IncludeSubfolders=true);
adsValidation = audioDatastore(fullfile(dataset,"validation"),IncludeSubfolders=true);

Construct Train and Validation Signals

The Google dataset consists of isolated words. Use the supporting function constructSignal to construct training and validation signals that consist of isolated words and regions of silence. The constructSignal function also returns ground truth binary masks indicating the regions of speech in the training and validation signals.

[audioTrain,TTrainPerSample] = constructSignal(adsTrain,fs,1000);
[audioValidation,TValidationPerSample] = constructSignal(adsValidation,fs,200);

Listen to the first 10 seconds of the constructed signal. Use signalMask and plotsigroi to visualize the signal and ground truth binary mask.

duration = 10;
sound(audioTrain(1:duration*fs),fs)
mask = signalMask(TTrainPerSample,SampleRate=fs);
plotsigroi(mask,audioTrain)
xlim([0,duration])
title("Clean Signal ("+duration+" seconds)")

Add Noise to Train and Validation Signals

Use the supporting function mixSNR to corrupt the training and validation signals with noise.

audioTrain = mixSNR(audioTrain,noise1000,SNR);
audioValidation = mixSNR(audioValidation,noise200,SNR);
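mixSNR needs a noise segment exactly as long as the signal. When the noise recording is shorter, the noise is repeated and a random segment is taken. The following minimal base-MATLAB sketch shows that step with hypothetical toy lengths. Valid start indices run from 1 to numel(temp) - numSamples + 1, so randi is called with that upper bound.

```matlab
rng(0)
numSamples = 10;          % hypothetical signal length (samples)
noise = (1:4)';           % hypothetical noise recording shorter than the signal

% Repeat the noise until it is at least as long as the signal
numReps = ceil(numSamples/numel(noise));   % 3 repetitions -> 12 samples
temp = repmat(noise,numReps,1);

% Pick a random start so that the whole segment stays inside temp
start = randi(numel(temp) - numSamples + 1);
noiseSegment = temp(start:start+numSamples-1);
```

If numSamples were 12 in this sketch, the upper bound would be 1 and randi(1) always returns 1. An upper bound of numel(temp) - numSamples - 1 would instead be -1, and randi errors for a nonpositive maximum.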

Listen to the first 10 seconds of the training signal and visualize the signal and mask.

sound(audioTrain(1:duration*fs),fs)
plotsigroi(mask,audioTrain)
xlim([0,duration])
title("Training Signal ("+duration+" seconds)")

Input Pipeline

Define an audioFeatureExtractor to extract the following spectral features: spectralCentroid, spectralCrest, spectralEntropy, spectralFlux, spectralKurtosis, spectralRolloffPoint, spectralSkewness, spectralSlope, and the periodicity feature harmonicRatio. Extract features using a 256-point Hann window with 50% overlap.

afe = audioFeatureExtractor(SampleRate=fs, ...
    Window=hann(256,"Periodic"), ...
    OverlapLength=128, ...
    spectralCentroid=true, ...
    spectralCrest=true, ...
    spectralEntropy=true, ...
    spectralFlux=true, ...
    spectralKurtosis=true, ...
    spectralRolloffPoint=true, ...
    spectralSkewness=true, ...
    spectralSlope=true, ...
    harmonicRatio=true);

featuresTrain = extract(afe,audioTrain);

Display the dimensions of the features matrix. The first dimension corresponds to the number of windows the signal was broken into (it depends on the signal length, window length, and overlap length). The second dimension is the number of features used in this example.

[numWindows,numFeatures] = size(featuresTrain)
numWindows = 124999
numFeatures = 9
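The displayed window count can be checked by hand. Assuming extract returns one feature vector per complete analysis window with no padding, a signal of N samples with window length W and hop length H yields floor((N - W)/H) + 1 windows:

```matlab
fs = 16e3;               % sample rate (Hz)
N = 1000*fs;             % 1000-second training signal (samples)
W = 256;                 % analysis window length (samples)
H = W - 128;             % hop length = window length - overlap length
numWindowsExpected = floor((N - W)/H) + 1   % 124999
```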

In classification applications, it is good practice to normalize all features to have zero mean and unit standard deviation.

Compute the mean and standard deviation for each feature, and use them to normalize the data.

M = mean(featuresTrain,1);
S = std(featuresTrain,[],1);
featuresTrain = (featuresTrain - M) ./ S;
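With implicit expansion, the 1-by-numFeatures vectors M and S are applied to every row of the features matrix. A quick sanity check on hypothetical random features confirms that each column ends up with zero mean and unit standard deviation:

```matlab
rng(0)
features = 3 + 2*randn(500,9);       % hypothetical features: 500 windows x 9 features

M = mean(features,1);                % 1x9 per-feature means
S = std(features,[],1);              % 1x9 per-feature standard deviations
featuresNorm = (features - M)./S;    % implicit expansion across rows

disp(max(abs(mean(featuresNorm,1))))     % close to 0
disp(max(abs(std(featuresNorm,[],1)-1))) % close to 0
```

Implicit expansion requires R2016b or later; in older releases, bsxfun performs the same operation.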

Extract features from the validation signal using the same process. Normalize the validation features with the mean and standard deviation computed from the training features.

XValidation = extract(afe,audioValidation);
XValidation = (XValidation - M) ./ S;

Each feature vector corresponds to 256 samples of data (the window length), sampled every 128 samples (the hop length). For each window, set the expected voice/no voice value to the mode of the baseline mask values corresponding to those 256 samples. Convert the voice/no voice mask to categorical.

windowLength = numel(afe.Window);
overlapLength = afe.OverlapLength;

TTrain = mode(buffer(TTrainPerSample,windowLength,overlapLength,"nodelay"),1);

TTrain = categorical(TTrain);

Do the same for the validation mask.

TValidation = mode(buffer(TValidationPerSample,windowLength,overlapLength,"nodelay"),1);

TValidation = categorical(TValidation);
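The buffer-then-mode step can also be written as an explicit loop, which makes the per-window majority vote visible. This minimal sketch uses a hypothetical 1024-sample mask with speech at samples 301 through 800, a 256-sample window, and a 128-sample hop. When the last window is complete, as here, it should match the columns produced by buffer with the "nodelay" option; buffer additionally zero-pads a final partial frame.

```matlab
% Hypothetical per-sample mask: 1 = speech, 0 = no speech (speech at 301..800)
maskPerSample = [zeros(300,1); ones(500,1); zeros(224,1)];   % 1024 samples
windowLength = 256;
hopLength = 128;                      % 50% overlap

numWindows = floor((numel(maskPerSample) - windowLength)/hopLength) + 1;
windowLabels = zeros(1,numWindows);
for k = 1:numWindows
    startIdx = (k-1)*hopLength + 1;                   % first sample of window k
    segment = maskPerSample(startIdx:startIdx+windowLength-1);
    windowLabels(k) = mode(segment);                  % majority vote (ties -> 0)
end
disp(windowLabels)
```

The second window holds only 84 speech samples out of 256, so it is labeled as no speech, while the sixth window holds 160 and is labeled as speech.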

Use the supporting function featureBuffer to split the training features and the mask into sequences with a duration of approximately 8 seconds and a 75% overlap between consecutive sequences.

sequenceDuration = 8;
analysisHopLength = numel(afe.Window) - afe.OverlapLength;
sequenceLength = round(sequenceDuration*fs/analysisHopLength);

overlapPercent = 0.75;

XTrain = featureBuffer(featuresTrain',sequenceLength,overlapPercent);
TTrain = featureBuffer(TTrain,sequenceLength,overlapPercent);
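The sequence sizes follow directly from the hop length and the formula inside featureBuffer. At a 128-sample hop, 8 seconds is 1000 feature vectors; a 75% overlap starts a new sequence every 250 vectors (2 seconds). The sketch below repeats that arithmetic for the 124999 training feature vectors reported earlier:

```matlab
fs = 16e3;                  % sample rate (Hz)
analysisHopLength = 128;    % samples between feature vectors
numFeatureVectors = 124999; % feature vectors in the training signal
sequenceDuration = 8;       % seconds
overlapPercent = 0.75;

sequenceLength = round(sequenceDuration*fs/analysisHopLength);        % 1000 vectors
sequenceHop = sequenceLength - round(overlapPercent*sequenceLength);  % 250 vectors
numSequences = floor((numFeatureVectors - sequenceLength)/sequenceHop) + 1  % 496
```

Because featureBuffer uses floor, the last 249 feature vectors, which do not fill a complete sequence, are not included in XTrain.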

Network Architecture

LSTM networks can learn long-term dependencies between time steps of sequence data. This example uses the bidirectional LSTM layer bilstmLayer (Deep Learning Toolbox) to look at the sequence in both forward and backward directions.

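% Two BiLSTM layers followed by a two-class (speech/nonspeech) classifier.
% The hidden-unit count of 200 is an assumed value; tune it for your data.
layers = [ ...
    sequenceInputLayer(afe.FeatureVectorLength)
    bilstmLayer(200,OutputMode="sequence")
    bilstmLayer(200,OutputMode="sequence")
    fullyConnectedLayer(2)
    softmaxLayer];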

Training Options

To define parameters for training, use trainingOptions (Deep Learning Toolbox). Use the Adam optimizer with a mini-batch size of 64 and a piecewise learning rate schedule that drops the learning rate by a factor of 10 every 5 epochs.

maxEpochs = 20;
miniBatchSize = 64;
options = trainingOptions("adam", ...
    MaxEpochs=maxEpochs, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    Verbose=false, ...
    ValidationFrequency=floor(numel(XTrain)/miniBatchSize), ...
    ValidationData={XValidation.',TValidation}, ...
    Plots="training-progress", ...
    LearnRateSchedule="piecewise", ...
    Metrics="Accuracy", ...
    LearnRateDropFactor=0.1, ...
    LearnRateDropPeriod=5, ...
    InputDataFormats="CTB");

Train Network

To train the network, use trainnet.

speechDetectNet = trainnet(XTrain,TTrain,layers,"crossentropy",options);

Evaluate Trained Network

Estimate voice activity in the validation signal using the trained network. Convert the network scores to categorical labels, convert the labels to a double-valued 0/1 mask, and then replicate the window-based decisions to sample-based decisions.

YValidation = predict(speechDetectNet,XValidation);
YValidation = scores2label(YValidation,unique(TValidation));
YValidation = double(YValidation)-1;
wL = numel(afe.Window);
hL = wL - afe.OverlapLength;
YValidationPerSample = [repelem(YValidation(1),floor(wL/2 + hL/2),1);
    repelem(YValidation(2:end-1),hL,1);
    repelem(YValidation(end),ceil(wL/2 + hL/2),1)];
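The replication assigns each window's decision to the hop-length span of samples around its center, while the first and last windows also cover the signal edges. The three pieces therefore add up to (numWindows - 1)*hL + wL samples, which is the length of a signal that the windows tile exactly. A minimal check with hypothetical per-window decisions:

```matlab
wL = 256;                        % window length (samples)
hL = 128;                        % hop length (samples)
decisions = [0;0;1;1;1;1;0];     % hypothetical per-window decisions (7 windows)
numWindows = numel(decisions);

% Edge windows cover half a window plus half a hop; inner windows cover one hop
perSample = [repelem(decisions(1),floor(wL/2 + hL/2),1);
    repelem(decisions(2:end-1),hL,1);
    repelem(decisions(end),ceil(wL/2 + hL/2),1)];

expectedLength = (numWindows - 1)*hL + wL;
fprintf("%d samples (expected %d)\n",numel(perSample),expectedLength)
```

For the 200-second validation signal, this gives 24998*128 + 256 = 3,200,000 samples, the same length as TValidationPerSample, which confusionchart requires.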

Calculate and plot the validation confusion matrix from the vectors of actual and estimated labels. Save the results for later analysis.

cc = confusionchart(TValidationPerSample,YValidationPerSample, ...
    title="speechDetect - Validation Confusion Chart");

speechDetectResults = cc.NormalizedValues;

Evaluate Pretrained VAD Network

The vadnet network is a pretrained network for voice activity detection. You can use it with the vadnetPreprocess and vadnetPostprocess functions for applications such as transfer learning, or you can use detectspeechnn, which encapsulates vadnetPreprocess, vadnet, and vadnetPostprocess for inference-only applications. The vadnet network performs well under everyday adverse conditions; however, it fails in cases of extremely low SNR, such as the -10 dB SNR used in this example. Also, vadnet was trained to detect regions of continuous speech (several words in a row), not isolated words. As a result, the pretrained vadnet fails on the validation signal in this example.

Load in the pretrained vadnet model.

net = audioPretrainedNetwork("vadnet");

Extract features from the validation signal using the same input pipeline used to train the network.

XValidation = vadnetPreprocess(audioValidation,fs);

Predict the VAD mask.

y = predict(net,gpuArray(XValidation));

vadnet is a regression network and requires additional post-processing to determine decision boundaries. Use vadnetPostprocess to determine the boundaries of voice activity regions.

boundaries = vadnetPostprocess(audioValidation,fs,y);

The vadnetPostprocess function returns the decisions as time boundaries. To convert the boundaries to a binary mask that corresponds to the original signal samples, use sigroi2binmask.

YValidationPerSample = double(sigroi2binmask(boundaries,size(audioValidation,1)));
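Conceptually, converting region boundaries to a sample mask sets every sample inside each region to true. The following base-MATLAB sketch shows the same idea, assuming the boundaries are a two-column matrix of inclusive start and end sample indices (the form sigroi2binmask accepts). The values are hypothetical.

```matlab
numSamples = 20;
boundaries = [3 6; 10 12];          % hypothetical regions: [start end] sample indices

mask = false(numSamples,1);
for r = 1:size(boundaries,1)
    mask(boundaries(r,1):boundaries(r,2)) = true;   % mark region r as speech
end
speechSamples = find(mask).'
```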

To create a confusion chart to analyze the error, use confusionchart (Deep Learning Toolbox).

confusionchart(TValidationPerSample,YValidationPerSample, ...
    title="vadnet - Validation Confusion Chart");

Transfer Learning

Apply transfer learning to the pretrained vadnet to make use of both the pretrained weights and the network architecture.

Extract features from the audio.

featuresTrain = vadnetPreprocess(audioTrain,fs);

Buffer the ground truth mask so that decisions correspond to the analysis windows used in vadnetPreprocess.

windowLength = 400;
overlapLength = 240;
TTrainPerSamplePadded = [zeros(floor(windowLength/2),1);TTrainPerSample;zeros(ceil(windowLength/2),1)];
TTrain = mode(buffer(TTrainPerSamplePadded,windowLength,overlapLength,"nodelay"),1);

Buffer the validation mask.

TValidationPerSamplePadded = [zeros(floor(windowLength/2),1);TValidationPerSample;zeros(ceil(windowLength/2),1)];
TValidation = mode(buffer(TValidationPerSamplePadded,windowLength,overlapLength,"nodelay"),1);

Split the long training signal into overlapping sequences for training. Do the same for the ground-truth mask.

sequenceDuration = 8;
analysisHopLength = windowLength - overlapLength;
sequenceLength = round(sequenceDuration*fs/analysisHopLength);

overlapPercent = 0.75;

XTrain = featureBuffer(featuresTrain,sequenceLength,overlapPercent);
TTrain = featureBuffer(TTrain,sequenceLength,overlapPercent);
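The window and overlap values used for vadnet correspond to common speech framing at 16 kHz: a 25 ms window with a 10 ms hop, or 800 frames per 8-second sequence. The following arithmetic check also shows where a frame lands after the half-window padding applied to the mask, assuming frames start at the first padded sample.

```matlab
fs = 16e3;
windowLength = 400;
overlapLength = 240;
hopLength = windowLength - overlapLength;             % 160 samples

windowMs = 1000*windowLength/fs;                      % 25 ms
hopMs = 1000*hopLength/fs;                            % 10 ms
sequenceLength = round(8*fs/hopLength);               % 800 frames per 8-second sequence

% With floor(windowLength/2) zeros prepended, frame k spans padded samples
% (k-1)*hopLength + (1:windowLength); subtracting the padding gives original indices
padFront = floor(windowLength/2);
k = 3;
frameCenter = (k-1)*hopLength + (windowLength + 1)/2 - padFront   % 320.5
```

In other words, the padding centers frame k about half a sample after original sample (k-1)*hopLength.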

To define parameters for training, use trainingOptions (Deep Learning Toolbox).

miniBatchSize = 12;
maxEpochs = 9;
options = trainingOptions("adam", ...
    InitialLearnRate=0.01, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropPeriod=3, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    ValidationFrequency=floor(numel(XTrain)/miniBatchSize), ...
    ValidationData={XValidation,TValidation}, ...
    Verbose=false, ...
    Plots="training-progress", ...
    MaxEpochs=maxEpochs, ...
    OutputNetwork="best-validation-loss");

To train the network, use trainnet.

noisyvadnet = trainnet(XTrain,TTrain,net,"mse",options);

Estimate voice activity in the validation signal using the trained network. Postprocess the predictions using vadnetPostprocess, then convert the boundaries in time to a sample-based mask.

y = predict(noisyvadnet,gpuArray(XValidation));
boundaries = vadnetPostprocess(audioValidation,fs,y);
YValidationPerSample = double(sigroi2binmask(boundaries,size(audioValidation,1)));

Calculate and plot the validation confusion matrix from the vectors of actual and estimated labels. Save the results for later analysis.

cc = confusionchart(TValidationPerSample,YValidationPerSample, ...
    title="noisyvadnet - Validation Confusion Chart");

noisyvadnetResults = cc.NormalizedValues;

Compare Networks

There are several considerations when choosing a network, such as size, inference speed, error, and streaming capabilities.

Streaming

The speechDetectNet trained from scratch in this example is well-suited for streaming inference because its BiLSTM layers retain state between calls. See Voice Activity Detection in Noise Using Deep Learning for an example of using speechDetect for streaming voice activity detection.

The vadnet architecture consists of convolutional, recurrent, and fully connected layers, and is not well-suited for low-latency streaming. See the vadnet documentation for an example of streaming VAD using vadnet.

Network Size

Compare the network sizes.

networks = ["speechDetect","noisyvadnet"];
b = bar(reordercats(categorical(networks),networks),[whos("speechDetectNet").bytes/1024,whos("noisyvadnet").bytes/1024]);
title("Network Size")
ylabel("Size (KB)")
grid on
b.FaceColor = "flat";
b.CData(2,:) = [0.8500 0.3250 0.0980];

Network Inference Speed

Compare the network inference speeds. The simple speechDetect architecture has faster inference on both the CPU and the GPU for short durations (chunks of approximately 8 seconds or less). For longer durations, speechDetect is faster than noisyvadnet on the GPU and slower on the CPU.

durationsToTest = [1,5,10,20,40];
environment = ["CPU","GPU"];

speechDetectSpeed = zeros(numel(durationsToTest),numel(environment));
noisyvadnetSpeed = zeros(numel(durationsToTest),numel(environment));
for jj = 1:numel(environment)
    for ii = 1:numel(durationsToTest)
        idx = 1:durationsToTest(ii)*fs;
        speechDetectFeatures = extract(afe,audioValidation(idx))';
        vadnetFeatures = vadnetPreprocess(audioValidation(idx),fs);

        switch environment(jj)
            case "CPU"
                speechDetectSpeed(ii,1) = timeit(@()predict(speechDetectNet,speechDetectFeatures.'),1);
                noisyvadnetSpeed(ii,1) = timeit(@()predict(noisyvadnet,vadnetFeatures),1);
            case "GPU"
                speechDetectSpeed(ii,2) = gputimeit(@()predict(speechDetectNet,gpuArray(speechDetectFeatures.')),1);
                noisyvadnetSpeed(ii,2) = gputimeit(@()predict(noisyvadnet,gpuArray(vadnetFeatures)),1);
        end
    end
end

% One tile per compute environment
tiledlayout(2,1)
for ii = 1:numel(environment)
    nexttile
    plot(durationsToTest,speechDetectSpeed(:,ii),"b-", ...
        durationsToTest,noisyvadnetSpeed(:,ii),"r-", ...
        durationsToTest,speechDetectSpeed(:,ii),"bo", ...
        durationsToTest,noisyvadnetSpeed(:,ii),"ro")
    legend(networks,Location="best")
    grid on
    xlabel("Audio Duration (s)")
    ylabel("Computation Duration (s)")
    title("Inference Speed ("+environment(ii)+")")
end

Network Error

Use the previously calculated confusion charts to display common statistics for error analysis. Accuracy, recall, precision, and F1 score are all derived from the confusion matrices previously plotted.

Accuracy is defined as the ratio of correctly predicted observations to the total observations. It is the most intuitive metric but can be misleading for imbalanced data sets. For example, if speech is only present in 5% of the audio, then classifying all audio as nonspeech would result in 95% accuracy.


Recall, also called sensitivity, is the ratio of correctly predicted positive observations to all observations that belong to the positive class. Recall answers the question: Of all speech regions, how many were correctly classified? A low recall indicates that regions of speech were misclassified as regions of nonspeech.


Precision is the ratio of correctly predicted positive observations to the total predicted positive observations. Precision answers the question: Of all the observations the network classified as speech, how many were actually speech? A low precision indicates that regions of nonspeech were misclassified as regions of speech.


F1 score is the harmonic mean of the precision and recall: it accounts for both false positives and false negatives.
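Written in terms of the counts that the code below extracts from each confusion matrix (TP, TN, FP, and FN, with speech as the positive class), the four statistics are:

```latex
\begin{aligned}
\text{Accuracy} &= \frac{TP + TN}{TP + TN + FP + FN} \\
\text{Recall} &= \frac{TP}{TP + FN} \\
\text{Precision} &= \frac{TP}{TP + FP} \\
F_1 &= 2\,\frac{\text{Precision}\cdot\text{Recall}}{\text{Precision} + \text{Recall}} = \frac{2\,TP}{2\,TP + FP + FN}
\end{aligned}
```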


The true measure of a network depends on your application. In real-world situations, you usually optimize a cost function that weights the costs of false positives and false negatives.

TP = speechDetectResults(2,2);
TN = speechDetectResults(1,1);
FP = speechDetectResults(1,2);
FN = speechDetectResults(2,1);
speechDetectAccuracy = (TP+TN)/(TP+TN+FP+FN);
speechDetectRecall = TP/(TP+FN);
speechDetectPrecision = TP/(TP+FP);
speechDetectF1Score = 2*(speechDetectRecall*speechDetectPrecision)/(speechDetectRecall+speechDetectPrecision);

TP = noisyvadnetResults(2,2);
TN = noisyvadnetResults(1,1);
FP = noisyvadnetResults(1,2);
FN = noisyvadnetResults(2,1);
noisyvadnetAccuracy = (TP+TN)/(TP+TN+FP+FN);
noisyvadnetRecall = TP/(TP+FN);
noisyvadnetPrecision = TP/(TP+FP);
noisyvadnetF1Score = 2*(noisyvadnetRecall*noisyvadnetPrecision)/(noisyvadnetRecall+noisyvadnetPrecision);

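metricNames = ["Accuracy","Recall","Precision","F1 Score"];
bar(reordercats(categorical(metricNames),metricNames), ...
    [speechDetectAccuracy,noisyvadnetAccuracy; ...
    speechDetectRecall,noisyvadnetRecall; ...
    speechDetectPrecision,noisyvadnetPrecision; ...
    speechDetectF1Score,noisyvadnetF1Score]);
legend(networks,Location="best")
title("Error Analysis")
grid on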
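The metric code above is repeated for each network. One way to avoid the duplication is a small anonymous function that takes a 2-by-2 confusion matrix laid out like the confusion charts here (rows are true classes, columns are predicted classes, nonspeech first). The name vadMetrics is hypothetical. This sketch computes F1 as 2TP/(2TP + FP + FN), which equals the harmonic-mean form but stays defined when nothing is predicted as speech.

```matlab
% vadMetrics (hypothetical helper): C is a 2x2 confusion matrix with
% rows = true class, columns = predicted class, nonspeech (0) first.
% Output: [accuracy, recall, precision, F1 score]
vadMetrics = @(C) [(C(1,1) + C(2,2))/sum(C(:)), ...
    C(2,2)/(C(2,2) + C(2,1)), ...
    C(2,2)/(C(2,2) + C(1,2)), ...
    2*C(2,2)/(2*C(2,2) + C(1,2) + C(2,1))];

% Example: 5% speech, and every observation predicted as nonspeech
C = [950 0; 50 0];
m = vadMetrics(C)
```

In this example the accuracy is 0.95 even though no speech is detected. Recall and F1 are 0, and precision is NaN (0/0) because no observations were predicted as speech, which is the imbalanced-data pitfall described for accuracy.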

Supporting Functions

Convert Feature Vectors to Sequences

function sequences = featureBuffer(features,featureVectorsPerSequence,overlapPercent)
% y = featureBuffer(x,sequenceLength,overlapPercent) buffers a sequence of
% feature vectors, x, into sequences of length sequenceLength overlapped by
% overlapPercent. The output sequences are returned in a cell array for
% consumption by trainnet.

featureVectorOverlap = round(overlapPercent*featureVectorsPerSequence);
hopLength = featureVectorsPerSequence - featureVectorOverlap;

N = floor((size(features,2) - featureVectorsPerSequence)/hopLength) + 1;
sequences = cell(N,1);

idx = 1;
for jj = 1:N
    sequences{jj} = features(:,idx:idx + featureVectorsPerSequence - 1);
    idx = idx + hopLength;
end
end

function [noisySignal,requestedNoise] = mixSNR(signal,noise,ratio)
% [noisySignal,requestedNoise] = mixSNR(signal,noise,ratio) returns a noisy
% version of the signal, noisySignal. The noisy signal has been mixed with
% noise at the specified ratio in dB.

numSamples = size(signal,1);

% Convert noise to mono
noise = mean(noise,2);

% Trim or expand noise to match signal size
if size(noise,1)>=numSamples
    % Choose a random starting index such that you still have numSamples
    % after indexing the noise.
    start = randi(size(noise,1) - numSamples + 1);
    noise = noise(start:start+numSamples-1);
else
    % Repeat the noise until it is at least numSamples long, then choose a
    % random starting index that keeps numSamples samples within range.
    numReps = ceil(numSamples/size(noise,1));
    temp = repmat(noise,numReps,1);
    start = randi(size(temp,1) - numSamples + 1);
    noise = temp(start:start+numSamples-1);
end

signalNorm = norm(signal);
noiseNorm = norm(noise);

goalNoiseNorm = signalNorm/(10^(ratio/20));
factor = goalNoiseNorm/noiseNorm;

requestedNoise = noise.*factor;
noisySignal = signal + requestedNoise;

noisySignal = noisySignal./max(abs(noisySignal));
end

Construct Signal

function [audio,mask] = constructSignal(ds,fs,duration)
% [audio,mask] = constructSignal(ds,fs,duration) constructs an audio signal
% of the specified duration (in seconds) by concatenating samples from the
% audioDatastore ds with random durations of silence between them.

win = hamming(50e-3*fs,"periodic");

% Create a signal of the specified duration by combining multiple speech
% files from the datastore. Use detectSpeech to remove unwanted portions
% of each file. Insert a random period of silence between speech segments.
% Preallocate the signal.
N = duration*fs;
audio = zeros(N,1);

% Preallocate the voice activity training mask. Values of 1 in the mask
% correspond to samples located in areas with voice activity. Values of 0
% correspond to areas with no voice activity.
mask = zeros(N,1);

% Specify a maximum silence segment duration of 2 seconds.
maxSilenceSegment = 2;

% Construct the training signal by calling read on the datastore in a loop.
numSamples = 1;
while numSamples < N
    data = read(ds);
    data = data ./ max(abs(data)); % Scale amplitude

    % Determine regions of speech
    idx = detectSpeech(data,fs,Window=win);

    % If a region of speech is detected
    if ~isempty(idx)

        % Extend the indices by five frames
        idx(1,1) = max(1,idx(1,1) - 5*numel(win));
        idx(1,2) = min(length(data),idx(1,2) + 5*numel(win));

        % Isolate the speech
        data = data(idx(1,1):idx(1,2));

        % Write speech segment to training signal
        audio(numSamples:numSamples+numel(data)-1) = data;

        % Set VAD baseline
        mask(numSamples:numSamples+numel(data)-1) = true;

        % Random silence period
        numSilenceSamples = randi(maxSilenceSegment*fs,1,1);
        numSamples = numSamples + numel(data) + numSilenceSamples;
    end
end
audio = audio(1:N);
mask = mask(1:N);
end


[1] Warden, P. "Speech Commands: A Public Dataset for Single-Word Speech Recognition." 2017. Copyright Google 2017. The Speech Commands Dataset is licensed under the Creative Commons Attribution 4.0 license.