Import PyTorch LSTM Model into Matlab
Hey Guys,
I am currently trying to use my PyTorch LSTM (trained with PyTorch Lightning) in MATLAB, but I have no idea how to use the importNetworkFromPyTorch function with an LSTM. The structure of the model is the following:
LSTM -> Linear -> Sigmoid
The LSTM properties (https://docs.pytorch.org/docs/stable/generated/torch.nn.LSTM.html) are (num_inputs=3, nhid=5, nlayers=5), i.e. input_size=3, hidden_size=5, num_layers=5, so the Linear layer is (in=5, out=1).
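As a cross-check of how these hyperparameters chain together, here is a minimal pure-Python sketch that lists the parameter shapes the torch.nn.LSTM documentation gives for each layer (weight_ih_l[k], weight_hh_l[k], bias_ih_l[k], bias_hh_l[k]) plus those of the Linear head. It assumes a unidirectional LSTM with bias=True and no proj_size; the helper names are hypothetical.

```python
import math


def lstm_param_shapes(input_size, hidden_size, num_layers):
    """Parameter shapes of a unidirectional torch.nn.LSTM (bias=True, proj_size=0)."""
    shapes = {}
    for k in range(num_layers):
        # Layer 0 reads the raw features; deeper layers read the previous layer's hidden state.
        layer_in = input_size if k == 0 else hidden_size
        # The four gate matrices (input, forget, cell, output) are stacked along dim 0.
        shapes[f"weight_ih_l{k}"] = (4 * hidden_size, layer_in)
        shapes[f"weight_hh_l{k}"] = (4 * hidden_size, hidden_size)
        shapes[f"bias_ih_l{k}"] = (4 * hidden_size,)
        shapes[f"bias_hh_l{k}"] = (4 * hidden_size,)
    return shapes


def linear_param_shapes(in_features, out_features):
    """Parameter shapes of torch.nn.Linear (bias=True)."""
    return {"weight": (out_features, in_features), "bias": (out_features,)}


def count(shapes):
    """Total number of scalar parameters across all shapes."""
    return sum(math.prod(shape) for shape in shapes.values())


lstm = lstm_param_shapes(input_size=3, hidden_size=5, num_layers=5)
# The Linear layer consumes the last LSTM layer's hidden state, hence in_features=5.
head = linear_param_shapes(in_features=5, out_features=1)
print(count(lstm), count(head))  # 1160 6
```

The Sigmoid adds no parameters, so comparing these counts with the parameters of the imported network may serve as a quick sanity check after import.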
The training data has the shape [BS, 600, 3], where BS is the batch size, 600 is the number of time steps, and 3 is the number of input features per time step. The shape of the hidden state is [5, BS, 5].
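These shapes follow the conventions in the torch.nn.LSTM documentation. The sketch below tabulates them, assuming the module was built with batch_first=True (the [BS, 600, 3] layout suggests this, but the post does not state it) and is unidirectional. The function name is hypothetical, and the Linear output assumes the head is applied to every time step.

```python
def lstm_shapes(batch_size, seq_len, input_size=3, hidden_size=5, num_layers=5):
    """Expected tensor shapes for a unidirectional torch.nn.LSTM built with batch_first=True."""
    # batch_first=True reorders only input and output; h_0/c_0 stay (num_layers, N, H_out).
    state = (num_layers, batch_size, hidden_size)
    return {
        "input": (batch_size, seq_len, input_size),
        "h_0": state,
        "c_0": state,
        "output": (batch_size, seq_len, hidden_size),
        # Assumption: Linear(5, 1) is applied to every time step; nn.Linear acts on the last dim.
        "linear_out": (batch_size, seq_len, 1),
    }


# batch_size=32 is an arbitrary illustrative value.
for name, shape in lstm_shapes(batch_size=32, seq_len=600).items():
    print(name, shape)
```

Calling lstm_shapes(batch_size=1, seq_len=1) gives h_0 and c_0 of (5, 1, 5), which matches the hidden_input tuple used for tracing.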
So my problem is that I do not understand which input sizes I have to pass to the importNetworkFromPyTorch function.
I expect it to be something like this:
net = importNetworkFromPyTorch("example/path/model.pt",PyTorchInputSizes={[NaN,3], [2, 5, NaN, 5]})
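I exported the traced model as follows:
import torch
# Trace the forward pass of the wrapped nn.Module with example inputs, then save it as TorchScript
traced_model = torch.jit.trace(model.model.forward, (input, hidden_input))
torch.jit.save(traced_model, "model.pt")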
The shape of input is [3] and the shape of hidden_input is ([5, 1, 5], [5, 1, 5]) (one for the hidden state and one for the cell state).
Can you please tell me how to use the importNetworkFromPyTorch function?
Answers (1)
Gayathri
on 15 May 2025
Can you please confirm which MATLAB version you are using? And are you getting any errors when running the importNetworkFromPyTorch command in MATLAB?
I can see in the MATLAB documentation that importing LSTM layers is only supported from MATLAB R2025a. Please upgrade to MATLAB R2025a to import the LSTM model.
Hope this helps!