Main Content

Visualizar activaciones de redes de LSTM

Este ejemplo muestra cómo investigar y visualizar las características aprendidas por las redes de LSTM extrayendo las activaciones.

Cargue una red preentrenada. JapaneseVowelsNet es una red de LSTM preentrenada en el conjunto de datos de vocales japonesas, como se describe en [1] y [2]. Se ha entrenado con las secuencias ordenadas por su longitud con un tamaño de minilote de 27.

load JapaneseVowelsNet

Visualice la arquitectura de red.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Cargue los datos de prueba.

[XTest,YTest] = japaneseVowelsTestData;

Visualice la primera serie de tiempo en una gráfica. Cada línea se corresponde con una característica.

X = XTest{1};

figure
plot(XTest{1}')
xlabel("Time Step")
title("Test Observation 1")
numFeatures = size(XTest{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')

Figure contains an axes object. The axes object with title Test Observation 1 contains 12 objects of type line. These objects represent Feature 1, Feature 2, Feature 3, Feature 4, Feature 5, Feature 6, Feature 7, Feature 8, Feature 9, Feature 10, Feature 11, Feature 12.

Por cada unidad de tiempo de las secuencias, obtenga las activaciones producidas por la capa de LSTM (capa 2) para una unidad de tiempo en concreto y actualice el estado de la red.

sequenceLength = size(X,2);
idxLayer = 2;
outputSize = net.Layers(idxLayer).NumHiddenUnits;

for i = 1:sequenceLength
    features(:,i) = activations(net,X(:,i),idxLayer);
    [net, YPred(i)] = classifyAndUpdateState(net,X(:,i));
end

Visualice las primeras diez unidades ocultas mediante un mapa de calor.

figure
heatmap(features(1:10,:));
xlabel("Time Step")
ylabel("Hidden Unit")
title("LSTM Activations")

Figure contains an object of type heatmap. The chart of type heatmap has title LSTM Activations.

El mapa de calor muestra la intensidad con la que se activa cada unidad oculta y destaca cómo cambian las activaciones a lo largo del tiempo.

Referencias

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Consulte también

| | | | |

Temas relacionados