
resetState

Reset model internal states

Since R2024a

    Description

    resetState(d) resets the internal states of the deepSignalAnomalyDetectorLSTMForecaster model d. This clears the model's memory of previously processed data so that you can use detect on a new signal.

    Note

    resetState does not reset the trained network.
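    Internal state and trained parameters are separate in any stateful streaming computation. The following sketch is a loose analogy only: it does not use the LSTM detector and does not show how the detector is implemented. It uses the base MATLAB filter function. The fixed coefficients b and a play the role of the trained network, and the final conditions that filter returns play the role of the internal state. Passing those conditions to the next call is analogous to calling detect with KeepState=true, and replacing them with zeros is analogous to resetState. The coefficients and test signals are arbitrary values chosen for illustration.

```matlab
% Analogy only: a recursive filter whose final conditions act as "state".
b = 0.2;                        % fixed coefficients (analogous to the trained network)
a = [1 -0.8];
x1 = sin(2*pi*(0:99)'/20);      % first signal (arbitrary)
x2 = cos(2*pi*(0:99)'/20);      % second, unrelated signal (arbitrary)

% Process the first signal in one call, starting from a zero state
yBatch = filter(b,a,x1);

% Stream the first signal one sample at a time, carrying the state forward
z = zeros(max(length(a),length(b))-1,1);   % cleared state
yStream = zeros(size(x1));
for k = 1:length(x1)
    [yStream(k),z] = filter(b,a,x1(k),z);
end

% Without a reset, the state left over from x1 affects the start of x2
yNoReset = filter(b,a,x2,z);

% Clearing the state (all zeros) gives a fresh start on x2
yReset = filter(b,a,x2,zeros(size(z)));
```

    Streaming x1 with the state carried over matches the single-call result to within rounding. The output for x2 starts differently unless the state is cleared first. In the same way, resetState(D) discards the state left over from a previous signal without changing the trained network.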


    Examples


    Load the file sineWaveAnomalyData.mat, which contains two sets of synthetic three-channel sinusoidal signals.

    • sineWaveNormal contains the 10 sinusoids used to train the anomaly detector. Each signal has a series of small-amplitude impact-like imperfections but otherwise has stable amplitude and frequency.

    • sineWaveAbnormal contains three signals of similar length and amplitude to the training data. One of the signals has an abrupt, finite-time change in frequency. Another signal has a finite-duration amplitude change in one of its channels. A third has random spikes in each channel.
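    If you do not have sineWaveAnomalyData.mat, the following sketch builds stand-in data with a comparable layout. Each variable is a cell array whose elements are samples-by-channels matrices with three channels, which is the layout the example assumes when it indexes a signal with sg(kj,:). The names myNormal and myAbnormal, the signal length, the frequencies, and the anomaly location are hypothetical choices and do not reflect the contents of the MAT-file. The sketch omits the impact-like imperfections and reproduces only the amplitude-change type of anomaly.

```matlab
% Hypothetical stand-in data; NOT the contents of sineWaveAnomalyData.mat
N = 1000;                      % samples per signal (assumed)
t = (0:N-1)';                  % column time index, so rows are samples
f = [0.010 0.015 0.020];       % normalized frequency per channel (assumed)

% Ten anomaly-free three-channel sinusoids with random phase offsets.
% t*f is an N-by-3 outer product; the 1-by-3 phase row expands across rows.
myNormal = cell(10,1);
for k = 1:10
    myNormal{k} = sin(2*pi*t*f + 2*pi*rand(1,3));
end

% One signal with a finite-duration amplitude change in channel 2
sig = sin(2*pi*t*f);
idx = 400:500;                 % anomalous region (assumed)
sig(idx,2) = 2*sig(idx,2);     % double the amplitude in that region only
myAbnormal = {sig};
```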

    Plot three normal signals and the three signals with anomalies.

    load sineWaveAnomalyData
     
    tiledlayout(3,2,TileSpacing="compact",Padding="compact")
    rnd = randperm(length(sineWaveNormal));
    for kj = 1:length(sineWaveAbnormal)
        nexttile
        plot(sineWaveNormal{rnd(kj)})
        title("Normal Signal")
        nexttile
        plot(sineWaveAbnormal{kj})
        title("Signal with Anomalies")
    end

    Figure: a 3-by-2 grid of plots. The left column shows three normal signals, titled "Normal Signal", and the right column shows the three signals titled "Signal with Anomalies". Each plot shows the three channels of its signal.

    Create a long short-term memory (LSTM) forecaster object to detect the anomalies in the abnormal signals. Specify a window length of 10 samples.

    D = deepSignalAnomalyDetector(3,"lstmforecaster",windowLength=10);

    Train the forecaster using the anomaly-free sinusoids. Use training options for the adaptive moment estimation (Adam) optimizer, set the maximum number of epochs to 100, and train on the CPU. For more information, see trainingOptions (Deep Learning Toolbox).

    opts = trainingOptions("adam",MaxEpochs=100,ExecutionEnvironment="cpu");
    trainDetector(D,sineWaveNormal,opts)
        Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss
        _________    _____    ___________    _________    ____________
                1        1       00:00:00        0.001          0.6369
               50       50       00:00:03        0.001         0.19706
              100      100       00:00:05        0.001        0.064225
    Training stopped: Max epochs completed
    Computing threshold...
    Threshold computation completed.
    

    Use the trained detector to find the anomalies in the first signal. First, reset the state of the detector. Then stream the data one sample at a time and have the detector keep its state after each reading. For each one-sample frame, detect returns the reconstruction loss and a label. Samples where the loss exceeds the detector threshold, which was computed after training, are labeled as anomalous.

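    resetState(D)                 % clear state left over from any earlier data
    
    sg = sineWaveAbnormal{1};     % samples-by-3 signal
    anoms = NaN(size(sg));        % labels (the loop fills the first column)
    losss = zeros(size(sg));      % losses (the loop fills the first column)
    
    for kj = 1:size(sg,1)
        frame = sg(kj,:);         % one three-channel sample
        % KeepState=true keeps the detector's internal state for the next call
        [lb,lo] = detect(D,frame, ...
            KeepState=true,ExecutionEnvironment="cpu");
        anoms(kj) = lb;
        losss(kj) = lo;
    end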
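    The loop produces one label per sample. If you want contiguous anomalous regions instead, you can find the samples where the labels switch on and off. The following base MATLAB sketch does this for a short hypothetical loss vector and threshold, using the rule described above that a sample is anomalous when its loss exceeds the threshold. In the example, you would apply the same steps to the first column of anoms. Only that column is filled because anoms is preallocated with the size of the three-channel signal and indexed linearly.

```matlab
% Hypothetical per-sample loss values and threshold, for illustration only
loss = [0.1 0.2 0.9 1.1 0.3 0.2 1.5 1.7 1.2 0.1];
thr = 0.8;

lbl = loss > thr;                   % per-sample anomaly flags (logical)

% Pad with false so that regions touching either end are still detected
edges = diff([false lbl false]);
regionStart = find(edges == 1);     % first sample of each anomalous region
regionStop  = find(edges == -1) - 1;% last sample of each anomalous region

regions = [regionStart' regionStop']   % [3 4; 7 9]
```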

    Plot the signal with anomalies, the reconstruction loss, and the anomaly labels that mark each sample of the signal as anomalous or not.

    figure
    tiledlayout("vertical")
    nexttile
    plot(sg)
    nexttile
    plot(losss)
    nexttile
    stem(anoms,".")

    Figure: three stacked plots showing the first signal with anomalies, its reconstruction loss, and the per-sample anomaly labels as a stem plot.

    Reset the state of the detector and find the anomalies in the third signal. Plot the signal, the reconstruction loss, and the anomaly labels that mark each sample of the signal as anomalous or not.

    resetState(D)
    
    sg = sineWaveAbnormal{3};
    anoms = NaN(size(sg));
    losss = zeros(size(sg));
    
    for kj = 1:length(sg)
        frame = sg(kj,:);
        [lb,lo] = detect(D,frame, ...
            KeepState=true,ExecutionEnvironment="cpu");
        anoms(kj) = lb;
        losss(kj) = lo;
    end
    
    figure
    tiledlayout("vertical")
    nexttile
    plot(sg)
    nexttile
    plot(losss)
    nexttile
    stem(anoms,".")

    Figure: three stacked plots showing the third signal with anomalies, its reconstruction loss, and the per-sample anomaly labels as a stem plot.

    Input Arguments


    d

    Anomaly detector, specified as a deepSignalAnomalyDetectorLSTMForecaster object. Use the deepSignalAnomalyDetector function to create d.

    Version History

    Introduced in R2024a