improve LSTM with test data for traffic

arash rad on 8 Oct 2022
Hi everyone,
I am using LSTM code from GitHub. At first it ran, but it cannot predict the traffic flow data at each time step, so I added the following for loop with an if statement: if the prediction is not equal to the test value, use the test data instead of the predicted data.
for i = 1:Nt
    if YPred(i) ~= YTest(i)
        YPred(i) = YTest(i);
    end
end
YPred = round(YPred);
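For reference, the check-and-replace idea described above can be sketched with made-up toy values (not the traffic data). The comparison has to be element-wise (`YPred(i) ~= YTest(i)`), and the assignment needs `=`, because `==` only compares and discards the result; a logical mask does the same without a loop. Note that either form makes `YPred` identical to `YTest`, so the RMSE drops to zero without the LSTM forecasting anything better: the replacement happens after prediction and does not change the network's state or weights.

```matlab
% Toy values for illustration only (not from zafar_queue.xlsx)
YTest = [12 15 9 20 18];
YPred = [11 15 10 20 17];

% Element-wise loop: compare one time step at a time and assign with "="
YLoop = YPred;
for i = 1:numel(YTest)
    if YLoop(i) ~= YTest(i)
        YLoop(i) = YTest(i);   % "==" here would only compare, not assign
    end
end

% Equivalent vectorized form using a logical mask
mask = YPred ~= YTest;
YMask = YPred;
YMask(mask) = YTest(mask);

% The "forecast" is now just the observed data, so the error is exactly 0
rmse = sqrt(mean((YMask - YTest).^2));
```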
Thank you for your help.
The whole code is:
clc;clear all;close all
warning off
flow_data = readtable('zafar_queue.xlsx');
Y = flow_data.nVehContrib;
data = Y';
%
% About 4 hours and 20 minutes of data for training
% About 40 minutes for testing
numTimeStepsTrain = floor(0.95*numel(data));
dataTrain = data(1:numTimeStepsTrain);
dataTest = data(numTimeStepsTrain+1:end);
% Standardize the training data (zero mean, unit variance)
mu = mean(dataTrain);
sig = std(dataTrain);
dataTrainStandardized = (dataTrain - mu) / sig;
XTrain = dataTrainStandardized(1:end-1);
YTrain = dataTrainStandardized(2:end);
%LSTM Net Architecture Def
numFeatures = 1;
numResponses = 1;
numHiddenUnits = 200;
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits)
    fullyConnectedLayer(numResponses)
    regressionLayer];
options = trainingOptions('adam', ...
    'MaxEpochs',500, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.005, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',150, ...
    'LearnRateDropFactor',0.25, ...
    'Verbose',1, ...
    'Plots','training-progress');
%
% Train LSTM Net
net = trainNetwork(XTrain,YTrain,layers,options);
% Standardize the test data with the training mean and standard deviation
dataTestStandardized = (dataTest - mu) / sig;
XTest = dataTestStandardized(1:end);
net = predictAndUpdateState(net,XTrain);
[net,YPred] = predictAndUpdateState(net,YTrain(end));
%
% Predict as long as the test period
numTimeStepsTest = numel(XTest);
for i = 2:numTimeStepsTest
    [net,YPred(:,i)] = predictAndUpdateState(net,YPred(:,i-1),'ExecutionEnvironment','cpu');
end
% RMSE calculation of test data set
YTest = dataTest(1:end);
YTest = (YTest - mu) / sig;
rmse = sqrt(mean((YPred-YTest).^2))
% Denormalize Data
YPred = sig*YPred + mu;
YTest = sig*YTest + mu;
% X axis: start time of each one-minute period
x_data = seconds(flow_data.begin);
x_train = x_data(1:numTimeStepsTrain);
x_train = x_train';
x_pred = x_data(numTimeStepsTrain:numTimeStepsTrain+numTimeStepsTest);
YPred = round(YPred);
Nt = length(YTest);
% Replace every forecast that differs from the observed value with the observed value
for i = 1:Nt
    if YPred(i) ~= YTest(i)
        YPred(i) = YTest(i);
    end
end
YPred = round(YPred);
% Train + Predict Plot
figure
plot(x_train(1:end),dataTrain(1:end))
hold on
plot(x_pred,[data(numTimeStepsTrain) YPred],'.-')
% hold off
xlabel("time")
ylabel("Flow")
title("Forecast")
legend(["Observed" "Forecast"])
% RMSE Plot : Test + Predict Plot
figure
subplot(2,1,1)
plot(YTest)
hold on
plot(YPred,'.-')
hold off
legend(["Observed" "Forecast"])
ylabel("Traffic Flow")
title("Forecast")
subplot(2,1,2)
stem(YPred - YTest)
xlabel("period of time")
ylabel("Error")
title("RMSE = " + rmse)
% Train + Test + Predict Plot
figure
plot(x_data,Y)
hold on
plot(x_pred,[data(numTimeStepsTrain) YPred],'.-')
hold off
xlabel("one-min period")
ylabel("Traffic Flow")
title("Compare Data")
legend(["Raw" "Forecast"])
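If the aim is for each forecast to use the observed traffic flow of the previous minute rather than the network's own previous prediction, one option is what MATLAB's "Time Series Forecasting Using Deep Learning" example calls updating the network state with observed values: after priming the state with `XTrain`, feed the actual standardized test values to `predictAndUpdateState` one step at a time. Unlike overwriting `YPred` afterwards, the network still makes every prediction itself, so the RMSE remains a genuine one-step-ahead error. The sketch below follows the variable names of the script above, but so that it runs without a trained network, the hypothetical `stepFcn` is a persistence stand-in (it predicts that the next value equals the current one) and the toy vectors are made up.

```matlab
% Toy standardized values, made up for illustration (not the traffic data)
YTrain = [-0.5 0.2 0.9];        % YTrain(end) is the last training observation
XTest  = [1.1 0.7 -0.3 0.4];    % standardized observed test values

% Hypothetical stand-in for the LSTM: returns its state unchanged and predicts
% that the next value equals the current input (a persistence forecast).
% With the real network, re-prime the state first (the closed-loop loop above
% has already changed it):
%   net = resetState(net);
%   net = predictAndUpdateState(net, XTrain);
% and replace the stepFcn call below with
%   [net, YPredOpen(i)] = predictAndUpdateState(net, XOpen(i), 'ExecutionEnvironment', 'cpu');
stepFcn = @(state, x) deal(state, x);

numTimeStepsTest = numel(XTest);

% YPredOpen(i) targets XTest(i), so its input is the observation one step
% earlier: the last training value for i = 1, then XTest(i-1)
XOpen = [YTrain(end) XTest(1:end-1)];

state = [];
YPredOpen = zeros(1, numTimeStepsTest);
for i = 1:numTimeStepsTest
    [state, YPredOpen(i)] = stepFcn(state, XOpen(i));
end

% One-step-ahead error in standardized units
rmseOpen = sqrt(mean((YPredOpen - XTest).^2));
```

The closed-loop `YPred` from the script is still useful for comparison, since it shows how far the model can forecast without receiving any new measurements.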

Answers (0)
