Why Are Hidden State and Cell State Vectors Zero After Training an LSTM Model with trainNetwork Functionality?
10 visualizaciones (últimos 30 días)
Mostrar comentarios más antiguos
Shubham Baisthakur
el 10 de Oct. de 2023
Comentada: Shubham Baisthakur
el 20 de Oct. de 2023
I am training an LSTM model using the trainNetwork functionality and follwing is the architecture of my model:
layers = [ ...
sequenceInputLayer(size(X_train{1},1))
layerNormalizationLayer
lstmLayer(x.num_hidden_units,'OutputMode','sequence')
fullyConnectedLayer(x.num_layers_ffnn)
dropoutLayer(0.1)
fullyConnectedLayer(1)
regressionLayer];
And I am training this using the following command:
options = trainingOptions('adam', ...
'MaxEpochs', 75, ...
'MiniBatchSize', x.batch_size, ...
'SequenceLength', 'longest', ...
'Shuffle', 'once', ...
'L2Regularization',0.01,...
'ValidationData',{X_val,Y_val}, ...
'ValidationFrequency',10,...
'Verbose',false,...
'ExecutionEnvironment','multi-gpu');
% Train the LSTM network
net = trainNetwork(X_train, Y_train, layers, options);
After training the model, the Hidden state and Cell state values for the LSTM layer is a vector of zeros. Why is this happening? I expect these vectors to have non-zero values to ensure the long term dependency between input and output parameters is captured.
0 comentarios
Respuesta aceptada
Neha
el 20 de Oct. de 2023
Hi Shubham,
The LSTM (Long Short-Term Memory) layer in a neural network is designed to remember values over arbitrary time intervals which indeed helps in maintaining and learning long-term dependencies. However, after training, the hidden and cell states of the LSTM layer are reset to zero. This is standard behavior for LSTMs, and it doesn't mean that the LSTM layer has not learned anything or that it's not working properly.
If you want to maintain the state of LSTM for some reason (like in case of time series prediction where you want the model to remember the state from the previous sequence), you can refer to the explanation for Open Loop Forecasting and Closed Loop Forecasting in the following documentation link:
Here "predictAndUpdateState" function has been used which updates the network state at every timestep.
Hope this helps!
Más respuestas (0)
Ver también
Categorías
Más información sobre Custom Training Loops en Help Center y File Exchange.
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!