Predecir y actualizar el estado de una red en Simulink
En este ejemplo se muestra cómo predecir respuestas de una red neuronal recurrente entrenada en Simulink® mediante el bloque Stateful Predict
. Este ejemplo usa una red de memoria de corto-largo plazo (LSTM) preentrenada.
Cargar una red preentrenada
Cargue JapaneseVowelsNet
, una red de LSTM preentrenada en el conjunto de datos de vocales japonesas, como se describe en [1] y [2]. Esta red se ha entrenado con las secuencias ordenadas por su longitud con un tamaño de minilote de 27.
load JapaneseVowelsNet
Visualice la arquitectura de red.
analyzeNetwork(net);
Cargar datos de prueba
Cargue los datos de prueba de las vocales japonesas. XTest
es un arreglo de celdas que contiene 370 secuencias de dimensión 12 con diferentes longitudes. TTest
es un vector categórico de las etiquetas "1","2",...,"9", que se corresponden con los nueve hablantes.
Cree un arreglo de horario simin
con filas con marcas de tiempo y copias repetidas de X
.
load JapaneseVowelsTestData X = XTest{94}; numTimeSteps = size(X,2); simin = timetable(repmat(X,1,4)','TimeStep',seconds(0.2));
Modelo de Simulink para la predicción de respuestas
El modelo de Simulink para la predicción de respuestas contiene un bloque Stateful Predict
para predecir las puntuaciones y un bloque From Workspace
para cargar la secuencia de datos de entrada en las unidades de tiempo.
Para restablecer la red neuronal recurrente a su estado inicial durante la simulación, coloque el bloque Stateful Predict
dentro de un Resettable Subsystem
y use la señal de control Reset
como activador.
open_system('StatefulPredictExample');
Configurar modelo para la simulación
Ajuste los parámetros de configuración del modelo del bloque Stateful Predict
.
set_param('StatefulPredictExample/Stateful Predict','NetworkFilePath','JapaneseVowelsNet.mat'); set_param('StatefulPredictExample', 'SimulationMode', 'Normal');
Ejecutar la simulación
Para calcular las respuestas de la red JapaneseVowelsNet
, ejecute la simulación. Las puntuaciones de la predicción se guardan en el área de trabajo de MATLAB®.
out = sim('StatefulPredictExample');
Represente las puntuaciones de la predicción. La gráfica muestra cómo cambian las puntuaciones de la predicción entre unidades de tiempo.
scores = squeeze(out.yPred.Data(:,:,1:numTimeSteps)); classNames = string(net.Layers(end).Classes); figure lines = plot(scores'); xlim([1 numTimeSteps]) legend("Class " + classNames,'Location','northwest') xlabel("Time Step") ylabel("Score") title("Prediction Scores Over Time Steps")
Destaque las puntuaciones de la predicción a lo largo de las unidades de tiempo de la clase correcta.
trueLabel = TTest(94); lines(trueLabel).LineWidth = 3;
Muestre la predicción final de unidades de tiempo en una gráfica de barras.
figure bar(scores(:,end)) title("Final Prediction Scores") xlabel("Class") ylabel("Score")
Referencias
[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.
[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
Consulte también
Stateful Predict | Stateful Classify | Predict | Image Classifier