Main Content

Clasificación de secuencias mediante deep learning

Este ejemplo muestra cómo clasificar datos secuenciales mediante una red de memoria de corto-largo plazo (LSTM).

Para entrenar una red neuronal profunda con la que clasificar datos secuenciales, se puede utilizar una red de LSTM. Una red de LSTM permite introducir datos secuenciales en una red y hacer predicciones basadas en las unidades de tiempo individuales de los datos secuenciales.

Este ejemplo usa el conjunto de datos de vocales japonesas como se describe en [1] y [2]. En este ejemplo se entrena una red de LSTM para reconocer al hablante dados los datos de series de tiempo que representan dos vocales japonesas pronunciadas consecutivamente. Los datos de entrenamiento contienen datos de series de tiempo para nueve hablantes. Cada secuencia cuenta con 12 características y diferentes longitudes. El conjunto de datos contiene 270 observaciones de entrenamiento y 370 observaciones de prueba.

Cargar datos secuenciales

Cargue los datos de entrenamiento de las vocales japonesas. XTrain es un arreglo de celdas que contiene 270 secuencias de dimensión 12 con diferentes longitudes. Y es un vector categórico de las etiquetas "1","2",...,"9", que se corresponden con los nueve hablantes. Las entradas de XTrain son matrices con 12 filas (una fila por característica) y un número variable de columnas (una columna por unidad de tiempo).

[XTrain,YTrain] = japaneseVowelsTrainData;
XTrain(1:5)
ans=5×1 cell array
    {12x20 double}
    {12x26 double}
    {12x22 double}
    {12x20 double}
    {12x21 double}

Visualice la primera serie de tiempo en una gráfica. Cada línea se corresponde con una característica.

figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
numFeatures = size(XTrain{1},1);
legend("Feature " + string(1:numFeatures),Location="northeastoutside")

Figure contains an axes object. The axes object with title Training Observation 1 contains 12 objects of type line. These objects represent Feature 1, Feature 2, Feature 3, Feature 4, Feature 5, Feature 6, Feature 7, Feature 8, Feature 9, Feature 10, Feature 11, Feature 12.

Preparar datos para el relleno

Durante el entrenamiento, de forma predeterminada, el software divide los datos de entrenamiento en minilotes y rellena las secuencias de manera que tengan la misma longitud. Si se rellenan demasiado, se puede producir un efecto negativo en el rendimiento de la red.

Para evitar que el proceso de entrenamiento añada demasiado relleno, puede ordenar los datos de entrenamiento por longitud de secuencia y elegir un tamaño de minilote para que las secuencias de un minilote tengan una longitud similar. La siguiente figura muestra el efecto de rellenar secuencias antes y después de ordenar los datos.

Obtenga las longitudes de las secuencias para cada observación.

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

Ordene los datos por longitud de secuencia.

[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

Visualice las longitudes de las secuencias ordenadas en una gráfica de barras.

figure
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

Figure contains an axes object. The axes object with title Sorted Data contains an object of type bar.

Elija un tamaño de minilote de 27 para dividir los datos de entrenamiento de manera uniforme y reducir la cantidad de relleno en los minilotes. La siguiente figura ilustra el relleno añadido a las secuencias.

miniBatchSize = 27;

Definir la arquitectura de red de LSTM

Defina la arquitectura de la red de LSTM. Especifique el tamaño de la entrada para tener secuencias de tamaño 12 (la dimensión de los datos de entrada). Especifique una capa de LSTM bidireccional con 100 unidades ocultas y obtenga como salida el último elemento de la secuencia. Por último, especifique nueve clases incluyendo una capa totalmente conectada de tamaño 9, seguida de una capa softmax y una capa de clasificación.

Si tiene acceso a secuencias completas en el momento de la predicción, podrá usar una capa de LSTM bidireccional en su red. Una capa de LSTM bidireccional aprende de la secuencia completa en cada unidad de tiempo. Si no tiene acceso a la secuencia completa en el momento de la predicción, por ejemplo, si está pronosticando valores o prediciendo una unidad de tiempo cada vez, utilice una capa de LSTM en su lugar.

inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    bilstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  5x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 12 dimensions
     2   ''   BiLSTM                  BiLSTM with 100 hidden units
     3   ''   Fully Connected         9 fully connected layer
     4   ''   Softmax                 softmax
     5   ''   Classification Output   crossentropyex

Ahora, especifique las opciones de entrenamiento. Especifique el solver en "adam", el umbral del gradiente en 1 y el número máximo de épocas en 50. Para rellenar los datos y que tengan la misma longitud que las secuencias más largas, especifique la longitud de la secuencia en "longest". Para asegurarse de que los datos permanecen ordenados por longitud de la secuencia, especifique que no se cambie el orden nunca.

Dado que los minilotes son pequeños y tienen secuencias cortas, el entrenamiento es más adecuado para la CPU. Establezca la opción ExecutionEnvironment en "cpu". Para realizar un entrenamiento en una GPU, si está disponible, establezca la opción ExecutionEnvironment en "auto" (este es el valor por defecto).

options = trainingOptions("adam", ...
    ExecutionEnvironment="cpu", ...
    GradientThreshold=1, ...
    MaxEpochs=50, ...
    MiniBatchSize=miniBatchSize, ...
    SequenceLength="longest", ...
    Shuffle="never", ...
    Verbose=0, ...
    Plots="training-progress");

Entrenar la red de LSTM

Entrene la red de LSTM con las opciones de entrenamiento especificadas usando trainNetwork.

net = trainNetwork(XTrain,YTrain,layers,options);

{"String":"Figure Training Progress (24-Jul-2022 21:29:52) contains 2 axes objects and another object of type uigridlayout. Axes object 1 contains 9 objects of type patch, text, line. Axes object 2 contains 9 objects of type patch, text, line.","Tex":[],"LaTex":[]}

Probar la red de LSTM

Cargue el conjunto de prueba y clasifique las secuencias por hablantes.

Cargue los datos de prueba de las vocales japonesas. XTest es un arreglo de celdas que contiene 370 secuencias de dimensión 12 con diferentes longitudes. YTest es un vector categórico de las etiquetas "1","2",...,"9", que se corresponden con los nueve hablantes.

[XTest,YTest] = japaneseVowelsTestData;
XTest(1:3)
ans=3×1 cell array
    {12x19 double}
    {12x17 double}
    {12x19 double}

La red de LSTM net se ha entrenado utilizando minilotes de secuencias de longitud similar. Asegúrese de que los datos de prueba se organizan de la misma forma. Ordene los datos de prueba por longitud de secuencia.

numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,2);
end

[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
YTest = YTest(idx);

Clasifique los datos de prueba. Para reducir la cantidad de relleno introducida por el proceso de clasificación, especifique el mismo tamaño de minilote utilizado para el entrenamiento. Para aplicar el mismo relleno que en los datos de entrenamiento, especifique la longitud de secuencia en "longest".

YPred = classify(net,XTest, ...
    MiniBatchSize=miniBatchSize, ...
    SequenceLength="longest");

Calcule la precisión de clasificación de las predicciones.

acc = sum(YPred == YTest)./numel(YTest)
acc = 0.9730

Referencias

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Consulte también

| | | |

Temas relacionados