LSTM Layer input size.

8 visualizaciones (últimos 30 días)
Alex
Alex el 17 de Sept. de 2023
Comentada: Alex el 18 de Sept. de 2023
Hi all, quick question.
I am learning regression LSTMs and in the following code I've got a TrainX variable which corresponds to [batchSize, sequenceLength, inputSize], which is [5950 x 14 x5].
clc; clear; close all;
% LSTM - Test 1 (AAPL)
% Read the CSV file into a table
data = readtable('AAPL.csv');
% Display the first few rows of the table
head(data)
% Extract data
Date = data.Date;
OpenP = data.Open;
HighP = data.High;
LowP = data.Low;
CloseP = data.Close;
CloseAdjP = data.AdjClose;
Volume = data.Volume;
%Scaling min-max
scaledOpenP = (OpenP - min(OpenP)) / (max(OpenP) - min(OpenP));
scaledHighP = (HighP - min(HighP)) / (max(HighP) - min(HighP));
scaledCloseP = (CloseP - min(CloseP)) / (max(CloseP) - min(CloseP));
scaledLowP = (LowP - min(LowP)) / (max(LowP) - min(LowP));
scaledCloseAdjP = (CloseAdjP - min(CloseAdjP)) / (max(CloseAdjP) - min(CloseAdjP));
scaledVolume = (Volume - min(Volume)) / (max(Volume) - min(Volume));
TrainX = [];
TrainY = [];
n_future = 1;
n_past = 14;
len = length(scaledCloseP);
for i = 1:len-n_past-n_future+1
% Extract sequences for TrainX
TrainX(end+1, :, 1) = scaledOpenP(i:i+n_past-1);
TrainX(end, :, 2) = scaledHighP(i:i+n_past-1);
%TrainX(end, :, 3) = scaledCloseP(i:i+n_past-1);
TrainX(end, :, 3) = scaledLowP(i:i+n_past-1);
TrainX(end, :, 4) = scaledCloseAdjP(i:i+n_past-1);
TrainX(end, :, 5) = scaledVolume(i:i+n_past-1);
% Next day close price for TrainY
TrainY(end+1, 1) = scaledCloseP(i+n_past+n_future-1);
end
%TrainX = permute(TrainX, [2, 1, 3]);
% Define the number of features
numFeatures = [14 5];
numHiddenUnits1 = 500; % First LSTM layer
numHiddenUnits2 = 200; % Second LSTM layer
dropoutRate = 0.2; % Dropout rate
numResponses = 1;
layers = [ ...
sequenceInputLayer(numFeatures, Normalization="zscore")
lstmLayer(numHiddenUnits1)
fullyConnectedLayer(numResponses)
regressionLayer];
analyzeNetwork(layers);
options = trainingOptions('adam', ...
'MaxEpochs', 10, ...
'MiniBatchSize', 16, ...
'Verbose', 1, ...
'Plots', 'training-progress');
% Train the network
net = trainNetwork(TrainX, TrainY, layers, options);
I have used standart input for LSTM (At least as I did it in Python).
I think the issues is numFeatures = [14 5]; How should I specify it?
When I run the code matlab gives out the following error
Error using trainNetwork
Invalid network.
Error in LSTM_1 (line 86)
net = trainNetwork(TrainX, TrainY, layers, options);
Caused by:
Layer 2: LSTM layers must have scalar input size, but input size (14×5) was received. Try using a
flatten layer before the LSTM layer.

Respuesta aceptada

Ben
Ben el 18 de Sept. de 2023
For sequenceInputLayer you don't need to specify the sequence length as a feature. So you would just need numFeatures = 5.
For batches of sequence data in trainNetwork you need each observation in the batch to be a cell, this applies to the input and output sequences - this is to allow for cases where each sequence might have a different length. Additionally each cell should contain an array with size "NumFeatures x SequenceLength".
Here's one way you could do that with data like yours
batchSize = 5950;
sequenceLength = 14;
numFeatures = 5;
% generate random data for example
TrainX = randn(batchSize,sequenceLength,numFeatures);
% permute batch x sequence x features -> features x sequence x batch
TrainX = permute(TrainX,[3,2,1]);
% convert to cell
TrainX = num2cell(TrainX,[1,2]);
% flatten into column vector of cell-s
TrainX = TrainX(:);
To demonstrate, here's how you would use trainNetwork to train an LSTM that attempts to just memorise the input:
layers = [
sequenceInputLayer(numFeatures)
lstmLayer(numFeatures)
regressionLayer];
opts = trainingOptions("sgdm");
trainNetwork(TrainX,TrainX,layers,opts)
  1 comentario
Alex
Alex el 18 de Sept. de 2023
That worked.
Thank you very much! :)

Iniciar sesión para comentar.

Más respuestas (0)

Categorías

Más información sobre Sequence and Numeric Feature Data Workflows en Help Center y File Exchange.

Productos


Versión

R2023b

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by