Main Content

Predecir y actualizar el estado de una red en Simulink

En este ejemplo se muestra cómo predecir respuestas de una red neuronal recurrente entrenada en Simulink® mediante el bloque Stateful Predict. Este ejemplo usa una red de memoria de corto-largo plazo (LSTM) preentrenada.

Cargar una red preentrenada

Cargue JapaneseVowelsNet, una red de LSTM preentrenada en el conjunto de datos de vocales japonesas, como se describe en [1] y [2]. Esta red se ha entrenado con las secuencias ordenadas por su longitud con un tamaño de minilote de 27.

load JapaneseVowelsNet

Visualice la arquitectura de red.

analyzeNetwork(net);

Cargar datos de prueba

Cargue los datos de prueba de las vocales japonesas. XTest es un arreglo de celdas que contiene 370 secuencias de dimensión 12 con diferentes longitudes. TTest es un vector categórico de las etiquetas "1","2",...,"9", que se corresponden con los nueve hablantes.

Cree un arreglo de horario simin con filas con marcas de tiempo y copias repetidas de X.

load JapaneseVowelsTestData
X = XTest{94};
numTimeSteps = size(X,2);
simin = timetable(repmat(X,1,4)','TimeStep',seconds(0.2));

Modelo de Simulink para la predicción de respuestas

El modelo de Simulink para la predicción de respuestas contiene un bloque Stateful Predict para predecir las puntuaciones y un bloque From Workspace para cargar la secuencia de datos de entrada en las unidades de tiempo.

Para restablecer la red neuronal recurrente a su estado inicial durante la simulación, coloque el bloque Stateful Predict dentro de un Resettable Subsystem y use la señal de control Reset como activador.

open_system('StatefulPredictExample');

Configurar modelo para la simulación

Ajuste los parámetros de configuración del modelo del bloque Stateful Predict.

set_param('StatefulPredictExample/Stateful Predict','NetworkFilePath','JapaneseVowelsNet.mat');
set_param('StatefulPredictExample', 'SimulationMode', 'Normal');

Ejecutar la simulación

Para calcular las respuestas de la red JapaneseVowelsNet, ejecute la simulación. Las puntuaciones de la predicción se guardan en el área de trabajo de MATLAB®.

out = sim('StatefulPredictExample');

Represente las puntuaciones de la predicción. La gráfica muestra cómo cambian las puntuaciones de la predicción entre unidades de tiempo.

scores = squeeze(out.yPred.Data(:,:,1:numTimeSteps));

classNames = string(net.Layers(end).Classes);
figure
lines = plot(scores');
xlim([1 numTimeSteps])
legend("Class " + classNames,'Location','northwest')
xlabel("Time Step")
ylabel("Score")
title("Prediction Scores Over Time Steps")

Destaque las puntuaciones de la predicción a lo largo de las unidades de tiempo de la clase correcta.

trueLabel = TTest(94);
lines(trueLabel).LineWidth = 3;

Muestre la predicción final de unidades de tiempo en una gráfica de barras.

figure
bar(scores(:,end))
title("Final Prediction Scores")
xlabel("Class")
ylabel("Score")

Referencias

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Consulte también

| | |

Temas relacionados