Wavelet Time Scattering with GPU Acceleration — Music Genre Classification

This example shows how to accelerate the computation of wavelet scattering features using gpuArray (Parallel Computing Toolbox) and a parallelized "depth-first" version of wavelet time scattering. You must have a CUDA-enabled NVIDIA GPU with compute capability 3.0 or higher. See GPU Support by Release (Parallel Computing Toolbox) for details.

This example reproduces the CPU version found in Music Genre Classification Using Wavelet Time Scattering. The example uses wavelet scattering features with a support vector machine to classify the genre of a musical excerpt. The example uses the audio datastore to manage reading the audio files from disk and to create training and test sets.

GTZAN Dataset

The data set used in this example is the GTZAN Genre Collection [7][8]. The data is provided as a zipped tar archive of approximately 1.2 GB, and the uncompressed data set requires about 3 GB of disk space. Extracting the compressed tar file from the link provided in the references creates a folder with ten subfolders, each named for the genre of the music samples it contains. The genres are: blues, classical, country, disco, hiphop, jazz, metal, pop, reggae, and rock. There are 100 examples of each genre, and each audio file consists of about 30 seconds of data sampled at 22050 Hz. In the original paper, the authors extracted a number of time-domain and frequency-domain features, including mel-frequency cepstral (MFC) coefficients, from each music example and used a Gaussian mixture model (GMM) classifier to achieve an accuracy of 61 percent [7]. Subsequently, deep learning networks have been applied to this data. In most cases, these deep learning approaches consist of convolutional neural networks (CNN) with the MFC coefficients or spectrograms as the input. These approaches have achieved accuracies of around 84 percent [4]. An LSTM approach with spectrogram time slices achieved 79 percent accuracy [3], and time-domain and frequency-domain features coupled with an ensemble learning approach (AdaBoost) achieved 82 percent accuracy on a test set [2]. Recently, a sparse representation machine learning approach achieved approximately 89 percent accuracy [6].

Wavelet Scattering Framework

The wavelet scattering framework in this example uses two filter banks, with 8 wavelets per octave in the first filter bank and 1 wavelet per octave in the second filter bank. The invariance scale is set to 0.5 seconds, which corresponds to 11025 samples at the sampling rate of 22050 Hz.
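The following lines are a quick arithmetic check of these numbers, not part of the original workflow. They convert the 0.5-second invariance scale to samples and compute how much of each roughly 30-second clip the signal length of 2^19 samples covers. The variable names are illustrative only.

```matlab
% Sketch: relate the scattering parameters to samples at the GTZAN sample rate
fs = 22050;                       % sampling frequency (Hz)
invarianceScaleSec = 0.5;         % invariance scale (seconds)
signalLength = 2^19;              % samples analyzed per audio file

invarianceScaleSamples = invarianceScaleSec*fs;   % 11025 samples
analyzedSeconds = signalLength/fs;                % about 23.78 seconds

fprintf('Invariance scale: %d samples\n', invarianceScaleSamples);
fprintf('Analyzed duration: %.2f s of each ~30 s clip\n', analyzedSeconds);
```

Because 2^19 samples is shorter than a full clip, the scattering features describe roughly the first 24 seconds of each excerpt.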

Create the wavelet time scattering network.

sf = helpergpuscat1('SignalLength',2^19,'SamplingFrequency',22050,...
    'InvarianceScale',0.5);

Audio Datastore

Use audioDatastore to manage the GTZAN music genre collection. Each subfolder of the collection is named for the genre it represents. Set the 'IncludeSubFolders' property to true to instruct the audio datastore to use subfolders, and set the 'LabelSource' property to 'foldernames' to create data labels based on the subfolder names. Set pathToData to a character vector containing the top-level data folder on your machine. The top-level data folder must contain ten subfolders, each named for one of the ten genres, and those subfolders must contain only audio files of the corresponding genres. This example assumes that the top-level folder is inside your MATLAB™ tempdir folder and is called 'genres'. Extracting the compressed tar file should provide a folder containing the genre subfolders.

pathToData = fullfile(tempdir,'genres');
location = pathToData;
ads = audioDatastore(location,'IncludeSubFolders',true,...
    'LabelSource','foldernames');

Run the following to obtain a count of the musical genres in the data set.

countEachLabel(ads)
ans=10×2 table
      Label      Count
    _________    _____

    blues         100 
    classical     100 
    country       100 
    disco         100 
    hiphop        100 
    jazz          100 
    metal         100 
    pop           100 
    reggae        100 
    rock          100 

As previously stated, there are 10 genres with 100 files each.

Training and Test Sets

Create training and test sets to develop and test the classifier. Allocate 80% of the data for training and hold out the remaining 20% for testing. The shuffle function of the audio datastore randomly shuffles the data; shuffle the data before splitting it by label to randomize the order. In this example, we set the random number generator seed for reproducibility. Use the audio datastore splitEachLabel function to perform the 80-20 split. splitEachLabel ensures that all classes are equally represented in each set.

ads = shuffle(ads);
[adsTrain,adsTest] = splitEachLabel(ads,0.8);
countEachLabel(adsTrain)
ans=10×2 table
      Label      Count
    _________    _____

    blues         80  
    classical     80  
    country       80  
    disco         80  
    hiphop        80  
    jazz          80  
    metal         80  
    pop           80  
    reggae        80  
    rock          80  

countEachLabel(adsTest)
ans=10×2 table
      Label      Count
    _________    _____

    blues         20  
    classical     20  
    country       20  
    disco         20  
    hiphop        20  
    jazz          20  
    metal         20  
    pop           20  
    reggae        20  
    rock          20  

As expected, there are 800 records in the training data and 200 records in the test data: 80 examples of each genre in the training set and 20 examples of each genre in the test set.
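To illustrate what a per-label split guarantees, the following sketch performs an 80/20 split on a shuffled vector of labels, one class at a time. It is a hypothetical, simplified stand-in written with base MATLAB functions, not the splitEachLabel implementation, and it uses numeric labels 1 through 10 in place of the genre names.

```matlab
% Sketch (hypothetical): per-label 80/20 split of a shuffled label vector.
% This only illustrates the class balance that splitEachLabel provides.
numClasses = 10;
filesPerClass = 100;
labelsDemo = repelem((1:numClasses)', filesPerClass);  % 100 labels per class
labelsDemo = labelsDemo(randperm(numel(labelsDemo)));  % shuffle the order

isTrain = false(size(labelsDemo));
for c = 1:numClasses
    idx = find(labelsDemo == c);          % positions of class c
    nTrain = round(0.8*numel(idx));       % 80% of this class
    isTrain(idx(1:nTrain)) = true;        % first 80% (after shuffling) train
end

trainCounts = accumarray(labelsDemo(isTrain), 1)';    % files per class, training
testCounts  = accumarray(labelsDemo(~isTrain), 1)';   % files per class, test
```

Splitting within each class, rather than across the whole shuffled list, is what keeps every genre at exactly 80 training and 20 test files.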

To obtain the scattering features, use the helper function helperbatchscatfeatures, which computes the natural logarithm of the scattering features for the first 2^19 samples of each audio file and subsamples the scattering windows by a factor of 8. The source code for helperbatchscatfeatures is listed in the appendix. The wavelet scattering features are computed in batches of 64 signals. Using gpuArray (Parallel Computing Toolbox) with a CUDA-enabled NVIDIA GPU significantly accelerates this batch computation. With this scattering framework, batch size, and GPU (NVIDIA Titan XP), the GPU implementation computes the scattering features approximately 7 times faster than the CPU implementation.

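N = 2^19;
batchsize = 64;
scTrain = [];
while hasdata(adsTrain)
    % Log scattering features for the next batch of up to 64 signals
    sc = helperbatchscatfeatures(adsTrain,sf,N,batchsize);
    % Concatenate batches along the third (signal) dimension
    scTrain = cat(3,scTrain,sc);
end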
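Because 800 is not a multiple of 64, the training loop reads 12 full batches and a final batch of 32 signals. The sketch below works through this bookkeeping and applies the same 1:8:end subsampling used in helperbatchscatfeatures to a placeholder array. The random array and its 256 windows before subsampling are assumptions, chosen only to be consistent with the 341 paths and 32 windows per signal reported in this example; they are not actual scattering coefficients.

```matlab
% Sketch: batch bookkeeping and window subsampling on placeholder data.
% Assumption: 341 paths and 256 windows per signal before subsampling.
numFilesDemo = 800;                                     % training signals
batchSizeDemo = 64;                                     % signals per batch
numBatchesDemo = ceil(numFilesDemo/batchSizeDemo);      % 13 batches
lastBatchDemo = numFilesDemo - (numBatchesDemo-1)*batchSizeDemo;  % 32 signals

Sdemo = rand(341, 256, lastBatchDemo);  % paths-by-windows-by-signals placeholder
scDemo = Sdemo(:, 1:8:end, :);          % keep every 8th scattering window
size(scDemo)                            % returns 341 32 32
```

For the test set, the same arithmetic gives ceil(200/64) = 4 batches, the last one holding 8 signals.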

Record the number of time windows in the scattering transform for label creation.

numTimeWindows = size(sc,2);

Repeat the feature extraction process for the test data.

scTest = [];
while hasdata(adsTest)
    sc = helperbatchscatfeatures(adsTest,sf,N,batchsize);
    scTest = cat(3,scTest,sc);
end

Determine the number of paths in the scattering network and reshape the training and test features into 2-D matrices.

[~,npaths] = sf.dfspaths();
Npaths = sum(npaths);
TrainFeatures = permute(scTrain,[2 3 1]);
TrainFeatures = reshape(TrainFeatures,[],Npaths,1);
TestFeatures = permute(scTest,[2 3 1]);
TestFeatures = reshape(TestFeatures,[],Npaths,1);

Each row of TrainFeatures and TestFeatures is one scattering time window across the Npaths (341) paths in the scattering transform of each audio signal. For each music sample, there are numTimeWindows such time windows. Accordingly, the feature matrix for the training data is 25600-by-341. The number of rows is equal to the number of training examples (800) multiplied by the number of scattering windows per example (32). Similarly, the scattering feature matrix for the test data is 6400-by-341. There are 200 test examples and 32 windows per example. Create a genre label for each of the 32 windows in the wavelet scattering feature matrix for the training data.

trainLabels = adsTrain.Labels;
numTrainSignals = numel(trainLabels);
trainLabels = repmat(trainLabels,1,numTimeWindows);
trainLabels = reshape(trainLabels',numTrainSignals*numTimeWindows,1);

Repeat the process for the test data.

testLabels = adsTest.Labels;
numTestSignals = numel(testLabels);
testLabels = repmat(testLabels,1,numTimeWindows);
testLabels = reshape(testLabels',numTestSignals*numTimeWindows,1);
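The feature reshaping and the label replication must produce the same row order; otherwise each scattering window would be paired with the wrong genre. The following sketch checks this with small dummy dimensions (3 paths, 4 windows, and 2 signals standing in for 341, 32, and 800) by filling each signal's features with its signal number. It uses only base MATLAB and its own variable names.

```matlab
% Sketch: confirm that features and labels end up in the same row order.
% Dummy sizes stand in for 341 paths, 32 windows, and 800 signals.
nPaths = 3; nWin = 4; nSig = 2;

% Fill each signal's features with its signal number to track row origin
scDemo = zeros(nPaths, nWin, nSig);
for s = 1:nSig
    scDemo(:, :, s) = s;
end

featDemo = permute(scDemo, [2 3 1]);       % windows-by-signals-by-paths
featDemo = reshape(featDemo, [], nPaths);  % (nWin*nSig)-by-nPaths, window index fastest

labelDemo = (1:nSig)';                          % one label per signal
labelDemo = repmat(labelDemo, 1, nWin);         % signals-by-windows
labelDemo = reshape(labelDemo', nSig*nWin, 1);  % one label per feature row

[featDemo(:, 1) labelDemo]                 % both columns read 1 1 1 1 2 2 2 2
```

In both cases the window index varies fastest, so all windows of the first signal precede those of the second. With the real dimensions, this ordering yields the 25600-by-341 training matrix (800 signals times 32 windows) and 25600 matching labels.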

In this example, use a multi-class support vector machine (SVM) classifier with a cubic polynomial kernel. Fit the SVM to the training data.

template = templateSVM(...
    'KernelFunction', 'polynomial', ...
    'PolynomialOrder', 3, ...
    'KernelScale', 'auto', ...
    'BoxConstraint', 1, ...
    'Standardize', true);
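Classes = {'blues','classical','country','disco','hiphop','jazz',...
    'metal','pop','reggae','rock'};
% Train the binary learners in parallel (requires Parallel Computing Toolbox)
trainingOptions = struct('UseParallel',true);
classificationSVM = fitcecoc(...
    TrainFeatures, ...
    trainLabels, ...
    'Learners', template, ...
    'Coding', 'onevsone','ClassNames',categorical(Classes), ...
    'Options', trainingOptions);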
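With one-versus-one coding, fitcecoc trains one binary SVM for each pair of classes, and each binary learner uses only the observations from its two classes. The following arithmetic, based on the counts in this example, shows the resulting size of the problem.

```matlab
% Arithmetic sketch: size of the one-versus-one ECOC problem in this example
numClasses = 10;                          % GTZAN genres
numLearners = nchoosek(numClasses, 2);    % one binary SVM per class pair: 45
rowsPerLearner = 2*80*32;                 % 2 genres x 80 signals x 32 windows = 5120
```

Each of the 45 learners therefore sees only a fraction of the 25600 training rows, which is one reason the one-versus-one design parallelizes well.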

Test Set Prediction

Use the SVM model fit to the scattering transforms of the training data to predict music genres for the test data. Recall that there are 32 time windows for each signal in the scattering transform. Use a simple majority vote to predict the genre. The helper function helperMajorityVote obtains the mode of the genre labels over all 32 scattering windows. If there is no unique mode, helperMajorityVote assigns the label 'NoUniqueMode', which counts as a classification error. This label results in an extra column in the confusion matrix. The source code for helperMajorityVote is listed in the appendix.

predLabels = predict(classificationSVM,TestFeatures);
[TestVotes,TestCounts] = helperMajorityVote(predLabels,adsTest.Labels,categorical(Classes));
testAccuracy = sum(eq(TestVotes,adsTest.Labels))/numTestSignals*100
testAccuracy = 87.5000

The test accuracy, testAccuracy, is 87.5 percent. This accuracy is comparable with state-of-the-art results on the GTZAN dataset.

Display the confusion matrix to inspect the genre-by-genre accuracy rates. Recall that there are 20 examples in each class.

confusionchart(adsTest.Labels,TestVotes);

The diagonal of the confusion matrix plot shows that the classification accuracies for the individual genres are generally good. Extract these genre accuracies and plot them separately.

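cm = confusionmat(TestVotes,adsTest.Labels);
% Remove the column for the extra 'NoUniqueMode' category
cm(:,end) = [];
% Each genre has 20 test examples
genreAccuracy = diag(cm)./20*100;
figure;
bar(genreAccuracy)
set(gca,'XTickLabel',Classes)
ylabel('Percent Correct')
title('GPU Example: Percent Correct by Genre (Test Set)');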
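The percentages in this section follow directly from counts. The sketch below converts the reported overall accuracy back to a number of correctly classified test signals and shows the per-genre conversion used for genreAccuracy. The per-genre counts are hypothetical values for illustration, not results from this example.

```matlab
% Sketch: how the reported percentages relate to counts
nTest = 200;                               % test signals
accPercent = 87.5;                         % overall test accuracy reported above
nCorrect = accPercent/100*nTest;           % 175 correctly classified signals

% Per-genre accuracy, as in diag(cm)./20*100: each genre has 20 test
% signals, so each signal is worth 5 percentage points.
perGenreCorrect = [20 19 18];              % hypothetical counts, not results
perGenreAccuracy = perGenreCorrect./20*100 % 100 95 90
```

Because each genre has only 20 test examples, per-genre accuracies move in steps of 5 percentage points, so small differences between genres should be interpreted with care.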

Summary

This example demonstrated the use of wavelet time scattering and the audio datastore in music genre classification. In this example, wavelet time scattering achieved a classification accuracy comparable to state-of-the-art performance on the GTZAN dataset. Unlike other approaches that require extracting a number of time-domain and frequency-domain features, wavelet scattering required the specification of only a single parameter, the scale of the time invariant. The audio datastore enabled us to efficiently manage the transfer of a large dataset from disk into MATLAB, randomize the data, and accurately retain genre membership of the randomized data throughout the classification workflow.

References

  1. Anden, J. and Mallat, S. 2014. Deep scattering spectrum. IEEE Transactions on Signal Processing, Vol. 62, No. 16, pp. 4114-4128.

  2. Bergstra, J., Casagrande, N., Erhan, D., Eck, D., and Kegl, B. 2006. Aggregate features and AdaBoost for music classification. Machine Learning, Vol. 65, No. 2-3, pp. 473-484.

  3. Irvin, J., Chartock, E., and Hollander, N. 2016. Recurrent neural networks with attention for genre classification.

  4. Li, T., Chan, A.B., and Chun, A. 2010. Automatic musical pattern feature extraction using convolutional neural network. International Conference on Data Mining and Applications.

  5. Mallat, S. 2012. Group invariant scattering. Communications on Pure and Applied Mathematics, Vol. 65, No. 10, pp. 1331-1398.

  6. Panagakis, Y., Kotropoulos, C.L., and Arce, G.R. 2014. Music genre classification via joint sparse low-rank representation of audio features. IEEE Transactions on Audio, Speech, and Language Processing, Vol. 22, No. 12, pp. 1905-1917.

  7. Tzanetakis, G. and Cook, P. 2002. Music genre classification of audio signals. IEEE Transactions on Speech and Audio Processing, Vol. 10, No. 5, pp. 293-302.

  8. GTZAN Genre Collection.

Appendix — Supporting Functions

helperMajorityVote — This function returns the mode of the class labels predicted over a number of feature vectors. In wavelet time scattering, we obtain a class label for each time window. If no unique mode is found, a label of 'NoUniqueMode' is returned to denote a classification error.

function [ClassVotes,ClassCounts] = helperMajorityVote(predLabels,origLabels,classes)
% This function is in support of wavelet scattering examples only. It may
% change or be removed in a future release.

% Make categorical arrays if the labels are not already categorical
predLabels = categorical(predLabels);
origLabels = categorical(origLabels);
% Expects both predLabels and origLabels to be categorical vectors
Npred = numel(predLabels);
Norig = numel(origLabels);
Nwin = Npred/Norig;
predLabels = reshape(predLabels,Nwin,Norig);
ClassCounts = countcats(predLabels);
[mxcount,idx] = max(ClassCounts);
ClassVotes = classes(idx);
% Check for any ties in the maximum values and ensure they are marked as
% error if the mode occurs more than once
modecnt = modecount(ClassCounts,mxcount);
ClassVotes(modecnt>1) = categorical({'NoUniqueMode'});
ClassVotes = ClassVotes(:);

    function modecnt = modecount(ClassCounts,mxcount)
        % Count how many classes reach the maximum vote for each signal
        modecnt = Inf(size(ClassCounts,2),1);
        for nc = 1:size(ClassCounts,2)
            modecnt(nc) = histc(ClassCounts(:,nc),mxcount(nc));
        end
    end
end
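The tie check in helperMajorityVote matters because, for numeric input, mode silently returns the smallest of several equally frequent values. The sketch below reproduces the voting logic with numeric class indices instead of categorical labels, marking ties with 0 in place of 'NoUniqueMode'. It is a simplified, hypothetical illustration, not the helper itself.

```matlab
% Sketch (hypothetical): majority vote over windows with explicit tie detection.
% Numeric class indices replace categorical labels; 0 marks "no unique mode".
numClasses = 3;
% Predicted class per window; each column holds the windows of one signal
pred = [1 2 3
        1 2 3
        1 1 3
        2 1 2];

votes = zeros(1, size(pred, 2));
for s = 1:size(pred, 2)
    counts = accumarray(pred(:, s), 1, [numClasses 1]);  % votes per class
    [mx, idx] = max(counts);
    if nnz(counts == mx) > 1
        votes(s) = 0;        % tie: treated as a classification error
    else
        votes(s) = idx;      % unique mode
    end
end
votes                        % 1 0 3
```

The second signal splits its four windows evenly between classes 1 and 2, so it receives no label, whereas mode(pred(:,2)) would have returned 1 without any warning.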

helperbatchscatfeatures — This function returns the wavelet time scattering feature matrix for a batch of signals read from a datastore. In this case, we use the natural logarithm of the wavelet scattering coefficients. The scattering feature matrix is computed on the first 2^19 samples of each signal. The scattering features are subsampled in time by a factor of 8.

function sc = helperbatchscatfeatures(ds,sf,N,batchsize)
% This function is only intended to support examples in the Wavelet
% Toolbox. It may be changed or removed in a future release.

% Read batch of data from audio datastore
batch = helperReadBatch(ds,N,batchsize);
% Obtain scattering features
S = sf.featureMatrix(batch,'transform','log');
% Subsample the features
sc = S(:,1:8:end,:);
end

helperReadBatch — This function reads batches of a specified size from a datastore. Each column of the output is a separate signal from the datastore. The output may have fewer columns than the batchsize if the datastore does not have enough records.

function batchout = helperReadBatch(ds,N,batchsize)
% This function is only in support of Wavelet Toolbox examples. It may
% change or be removed in a future release.
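% batchout = helperReadBatch(ds,N,batchsize) where
%   ds is the datastore
%   N is the number of samples to keep from each signal
%   batchsize is the maximum number of signals to read

kk = 1;

while hasdata(ds) && kk <= batchsize
    tmpRead = read(ds);
    % Store the first N samples of each signal as one column
    batchout(:,kk) = tmpRead(1:N); %#ok<AGROW>
    kk = kk+1;
end
end
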
