LSTM not outputting sequence
I am attempting to do sequence-to-sequence classification.
I have N time series, each a sequence of observations, and each observation collects p features.
I build a cell array XTrain with N entries and set XTrain{i} to be the i-th time series in my database.
I have two classes. I build a cell array YTrain with N entries, where YTrain{i} is a categorical vector giving the class label at each time step.
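To make the data layout concrete, here is a minimal sketch that builds XTrain and YTrain from random data. The sizes (N, p, the per-series lengths in seqLen) and the class names 'class1'/'class2' are hypothetical. It assumes that, because inputSize = [p, 1, 1] describes an image-sequence input, each sequence is stored as a p-by-1-by-1-by-T array (height, width, channels, time), and each label sequence is a 1-by-T categorical row vector.

```matlab
% Hypothetical sizes for illustration only; the post does not state them.
N = 4;                      % number of time series
p = 10;                     % features per observation
seqLen = [50 60 55 70];     % time steps per series (lengths may differ)

XTrain = cell(N,1);
YTrain = cell(N,1);
for i = 1:N
    T = seqLen(i);
    % inputSize = [p 1 1] is an image-sequence input, so each sequence is
    % assumed to be a p-by-1-by-1-by-T array (height-width-channel-time).
    XTrain{i} = rand(p,1,1,T);
    % One label per time step: a 1-by-T categorical row vector.
    YTrain{i} = categorical(randi(2,1,T), [1 2], {'class1','class2'});
end
```

Because the label sequences have the same length as the input sequences, a network with 'OutputMode','sequence' is expected to score every time step.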
Now I build the following network:
inputSize = [p, 1, 1];
filterSize = [2 1];
numFilters = 20;
numHiddenUnits = 128;
numClasses = 2;
layers = [ ...
sequenceInputLayer(inputSize,'Name','input')
sequenceFoldingLayer('Name','fold')
convolution2dLayer(filterSize,numFilters,'Name','conv1')
reluLayer('Name','relu1')
convolution2dLayer(filterSize,numFilters,'Name','conv2')
reluLayer('Name','relu2')
flattenLayer('Name','flatten')
sequenceUnfoldingLayer('Name','unfold')
lstmLayer(numHiddenUnits,'OutputMode','sequence','Name','lstm')
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','softmax')
classificationLayer('Name','classification')];
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph,'fold/miniBatchSize','unfold/miniBatchSize');
maxEpochs = 1;
miniBatchSize = 2;
options = trainingOptions('adam', ...
'ExecutionEnvironment','cpu', ...
'MaxEpochs',maxEpochs, ...
'MiniBatchSize',miniBatchSize, ...
'GradientThreshold',1, ...
'Verbose',false, ...
'Plots','training-progress');
net = trainNetwork(XTrain,YTrain,lgraph,options);
However, if I then run:
YScores = predict(net,XTrain,'MiniBatchSize',1);
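the output is a cell array whose i-th entry is a single vector of class probabilities (one value per class, with no time dimension).
This is INCORRECT. For sequence-to-sequence classification it should contain class probabilities for every time step, i.e. one column of probabilities per time step of the i-th sequence.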
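One way to check the symptom is to compare each entry of YScores with the length of the matching input sequence. The sketch below uses a hypothetical helper, hasPerStepScores, and synthetic stand-ins for XTrain and YScores so that it runs on its own. It assumes that, with 'OutputMode','sequence', each entry of the predict output should be numClasses-by-T for an input of T time steps, and that T sits in dimension 4 because of the [p 1 1] image-sequence input. Local functions in scripts need R2016b or later.

```matlab
% Synthetic stand-ins so this runs on its own; in practice use
% YScores = predict(net,XTrain,'MiniBatchSize',1) with your real XTrain.
numClasses = 2;
XTrain  = {rand(10,1,1,50); rand(10,1,1,60)};
YScores = {rand(numClasses,50); rand(numClasses,1)};  % 2nd entry mimics the symptom

ok = hasPerStepScores(YScores, XTrain, numClasses);
disp(ok.')   % 1 where sequence i has one score column per time step

function ok = hasPerStepScores(YScores, XTrain, numClasses)
% Hypothetical helper: entry i of YScores is expected to be numClasses-by-T,
% where T is the number of time steps in XTrain{i} (dimension 4 here).
ok = false(numel(XTrain),1);
for i = 1:numel(XTrain)
    T = size(XTrain{i},4);
    ok(i) = isequal(size(YScores{i}), [numClasses T]);
end
end
```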
4 comments
John Malik
on 18 Dec 2019
Ridwan Alam
on 25 Dec 2019
Hey John, did you get a solution? Please share. Thanks!
Mohammad Sami
on 26 Dec 2019
Can you try putting the sequenceUnfoldingLayer before the flattenLayer?
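A sketch of that reordering, keeping every other layer and option from the question unchanged: the unfolding layer restores the time dimension first, and the flatten layer then turns each time step's feature map into a vector for the LSTM. The value p = 10 is a hypothetical placeholder (it must be at least 3 here, since each unpadded [2 1] convolution shrinks the height by one). This thread does not confirm that the reordering fixes the output shape.

```matlab
p = 10;                 % hypothetical number of features; use your own value
inputSize = [p, 1, 1];
filterSize = [2 1];
numFilters = 20;
numHiddenUnits = 128;
numClasses = 2;

layers = [ ...
    sequenceInputLayer(inputSize,'Name','input')
    sequenceFoldingLayer('Name','fold')
    convolution2dLayer(filterSize,numFilters,'Name','conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(filterSize,numFilters,'Name','conv2')
    reluLayer('Name','relu2')
    sequenceUnfoldingLayer('Name','unfold')   % restore the time dimension first
    flattenLayer('Name','flatten')            % then flatten each time step
    lstmLayer(numHiddenUnits,'OutputMode','sequence','Name','lstm')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classification')];

lgraph = layerGraph(layers);
% The unfolding layer still needs the mini-batch size from the folding layer.
lgraph = connectLayers(lgraph,'fold/miniBatchSize','unfold/miniBatchSize');
```

Running analyzeNetwork(lgraph) before training is one way to inspect the activation sizes at each layer and confirm that the time dimension reaches the LSTM.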
John Malik
on 26 Dec 2019
Answers (0)