
Export LSTM to ONNX with proper input information

Brita Linnestad on 7 Jul 2022
Commented: Sivylla Paraskevopoulou on 12 Jul 2022
I have created an LSTM network and converted it to ONNX using MATLAB's exportONNXNetwork. The ONNX network will be loaded in Java using OrtSession.
Layers: [5x1 nnet.cnn.layer.Layer]
layers =
  1  Sequence Input          Sequence input with 6 dimensions (numberOfFeatures)
  2  LSTM                    LSTM with 50 hidden units
  3  Fully Connected         Fully connected layer with output size 2
  4  Softmax                 softmax
  5  Classification Output   crossentropyex
Sequence length is 24.
Using exportONNXNetwork(netlstm,filename), the only reported input is 'sequenceinput'.
How can I set up exportONNXNetwork so that the ONNX model holds all the input information needed when loading the model in Java?

Answers (1)

Sivylla Paraskevopoulou on 7 Jul 2022
I am not sure what you mean by "more/all input information". If you mean that you want a network that can be used for prediction, you must train the layer graph that you created and then export the trained network and not the layer graph.
2 comments
Brita Linnestad on 11 Jul 2022
I have trained the layer graph, and then exported the trained network.
When I load the trained network in Java using OrtSession, I get an OrtSession runtime error:
Non-zero status code returned while running LSTM node. Name:'lstm' Status Message: Input initial_h must have shape {1,24,50}. Actual:{1,1,50}
Before exporting my model from MATLAB, how can I set initial_h, or any other information OrtSession needs, so that the model runs properly?
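For context on the error message: in the ONNX LSTM operator specification, the optional input initial_h has shape [num_directions, batch_size, hidden_size], and with the default layout (layout = 0) the input X has shape [seq_length, batch_size, input_size], so the batch size is read from axis 1 of X. The expected shape {1,24,50} therefore means ONNX Runtime inferred a batch size of 24, which equals the sequence length, while the stored initial_h {1,1,50} corresponds to a batch size of 1. One possible reading, not confirmed in this thread, is that the sequence and batch axes were swapped somewhere between export and the tensor fed from Java. The minimal sketch below only reproduces that shape rule; the class and method names are hypothetical.

```java
import java.util.Arrays;

/**
 * Sketch of the ONNX LSTM shape rule for the optional initial_h input.
 * With layout = 0, X is [seq_length, batch_size, input_size] and
 * initial_h must be [num_directions, batch_size, hidden_size].
 * Class and method names are hypothetical, for illustration only.
 */
public final class LstmShapes {

    private LstmShapes() {}

    /** Expected initial_h shape for an ONNX LSTM node with layout = 0. */
    public static long[] expectedInitialH(long[] xShape, long numDirections, long hiddenSize) {
        if (xShape.length != 3) {
            throw new IllegalArgumentException("X must be rank 3, got " + Arrays.toString(xShape));
        }
        long batchSize = xShape[1]; // axis 1 of X is the batch axis when layout = 0
        return new long[] {numDirections, batchSize, hiddenSize};
    }

    public static void main(String[] args) {
        // One observation of 24 steps x 6 features written in N x S x C order...
        long[] fedAsNsc = {1, 24, 6};
        // ...is read by the LSTM node as seq_length = 1, batch_size = 24.
        System.out.println(Arrays.toString(expectedInitialH(fedAsNsc, 1, 50))); // [1, 24, 50]

        // The same data in S x N x C order gives batch_size = 1.
        long[] seqMajor = {24, 1, 6};
        System.out.println(Arrays.toString(expectedInitialH(seqMajor, 1, 50))); // [1, 1, 50]
    }
}
```

Whether the exported graph inserts its own transpose before the LSTM node depends on the exporter, so the actual input shape the model expects is best checked on the loaded model rather than assumed.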
Sivylla Paraskevopoulou on 12 Jul 2022
In MATLAB, if your input data is a vector sequence, sequenceInputLayer expects the data in the format CSN, where C is the number of features (channels), S is the sequence length, and N is the number of observations. For an example of how to train a network with a vector sequence input, see Train Network for Sequence Classification.
When you export the network to ONNX, the input tensor shape should be NSC. I am not sure what is happening to the input when you load the ONNX model into an OrtSession.
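On the Java side, the change from MATLAB's per-observation C-by-S orientation to the NSC layout described above amounts to a transpose into a flat row-major buffer. The minimal sketch below assumes the exported model really expects a single observation as N x S x C = {1, 24, 6} (as stated in the answer, not verified here) and that the features arrive as a C-by-S float matrix; the class and method names are hypothetical. The resulting buffer and shape would then be wrapped into an ONNX Runtime tensor, for example with `OnnxTensor.createTensor`, and the shape the model actually declares can be inspected with `OrtSession.getInputInfo()`.

```java
import java.util.Arrays;

/**
 * Sketch: pack one C-by-S feature matrix (channels x time steps, the
 * MATLAB orientation for one observation) into a flat row-major
 * N x S x C buffer with N = 1. Names are hypothetical.
 */
public final class NscPacker {

    private NscPacker() {}

    /** NSC shape {1, S, C} for a C-by-S matrix. */
    public static long[] nscShape(float[][] channelsBySteps) {
        int c = channelsBySteps.length;
        int s = c == 0 ? 0 : channelsBySteps[0].length;
        return new long[] {1L, s, c};
    }

    /** Transposes C-by-S data into a flat N x S x C buffer (N = 1). */
    public static float[] toNsc(float[][] channelsBySteps) {
        int numChannels = channelsBySteps.length;
        if (numChannels == 0) {
            return new float[0];
        }
        int numSteps = channelsBySteps[0].length;
        float[] out = new float[numSteps * numChannels];
        for (int ch = 0; ch < numChannels; ch++) {
            if (channelsBySteps[ch].length != numSteps) {
                throw new IllegalArgumentException("ragged input at channel " + ch);
            }
            for (int t = 0; t < numSteps; t++) {
                // Row-major index of (n = 0, s = t, c = ch) in an N x S x C tensor.
                out[t * numChannels + ch] = channelsBySteps[ch][t];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // 6 features x 24 time steps, matching the network in the question.
        float[][] x = new float[6][24];
        for (int ch = 0; ch < 6; ch++) {
            for (int t = 0; t < 24; t++) {
                x[ch][t] = ch * 100 + t;
            }
        }
        float[] nsc = toNsc(x);
        System.out.println(Arrays.toString(nscShape(x))); // [1, 24, 6]
        System.out.println(nsc[1]); // channel 1 at step 0: 100.0
    }
}
```

Keeping the shape computation next to the packing code makes it harder for the buffer and the declared shape to drift apart, which is the kind of mismatch the initial_h error points to.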


Categories

Deep Learning Toolbox


Release

R2021a
