Visualizar activaciones de redes de LSTM
Este ejemplo muestra cómo investigar y visualizar las características aprendidas por las redes de LSTM extrayendo las activaciones.
Cargue una red preentrenada. JapaneseVowelsNet
es una red de LSTM preentrenada en el conjunto de datos de vocales japonesas, como se describe en [1] y [2]. Se ha entrenado con las secuencias ordenadas por su longitud con un tamaño de minilote de 27.
load JapaneseVowelsNet
Visualice la arquitectura de red.
net.Layers
ans = 5x1 Layer array with layers: 1 'sequenceinput' Sequence Input Sequence input with 12 dimensions 2 'lstm' LSTM LSTM with 100 hidden units 3 'fc' Fully Connected 9 fully connected layer 4 'softmax' Softmax softmax 5 'classoutput' Classification Output crossentropyex with '1' and 8 other classes
Cargue los datos de prueba.
[XTest,YTest] = japaneseVowelsTestData;
Visualice la primera serie de tiempo en una gráfica. Cada línea se corresponde con una característica.
X = XTest{1}; figure plot(XTest{1}') xlabel("Time Step") title("Test Observation 1") numFeatures = size(XTest{1},1); legend("Feature " + string(1:numFeatures),'Location','northeastoutside')
Por cada unidad de tiempo de las secuencias, obtenga las activaciones producidas por la capa de LSTM (capa 2) para una unidad de tiempo en concreto y actualice el estado de la red.
sequenceLength = size(X,2); idxLayer = 2; outputSize = net.Layers(idxLayer).NumHiddenUnits; for i = 1:sequenceLength features(:,i) = activations(net,X(:,i),idxLayer); [net, YPred(i)] = classifyAndUpdateState(net,X(:,i)); end
Visualice las primeras diez unidades ocultas mediante un mapa de calor.
figure heatmap(features(1:10,:)); xlabel("Time Step") ylabel("Hidden Unit") title("LSTM Activations")
El mapa de calor muestra la intensidad con la que se activa cada unidad oculta y destaca cómo cambian las activaciones a lo largo del tiempo.
Referencias
[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.
[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
Consulte también
trainNetwork
| trainingOptions
| lstmLayer
| bilstmLayer
| sequenceInputLayer
| activations