Classify Text Data Using Deep Learning
This example shows how to classify text data using a deep learning long short-term memory (LSTM) network.
Text data is naturally sequential. A piece of text is a sequence of words, which might have dependencies between them. To learn and use long-term dependencies to classify sequence data, use an LSTM neural network. An LSTM network is a type of recurrent neural network (RNN) that can learn long-term dependencies between time steps of sequence data.
To input text to an LSTM network, first convert the text data into numeric sequences. You can achieve this using a word encoding that maps documents to sequences of numeric indices. For better results, also include a word embedding layer in the network. Word embeddings map words in a vocabulary to numeric vectors rather than scalar indices. These embeddings capture semantic details of the words, so that words with similar meanings have similar vectors. They also model relationships between words through vector arithmetic. For example, the relationship "Rome is to Italy as Paris is to France" is approximately described by the equation Italy − Rome + Paris ≈ France.
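To see this vector arithmetic on a real embedding, the following sketch applies it to a pretrained fastText embedding. It assumes the fastText English word embedding support package for Text Analytics Toolbox is installed (fastTextWordEmbedding errors otherwise) and uses word2vec and vec2word to look up vectors and find the nearest vocabulary word. The nearest word depends on the embedding, so treat the result as illustrative rather than guaranteed.

```matlab
% Word-vector arithmetic with a pretrained embedding (illustrative sketch).
% Assumes the fastText English word embedding support package for
% Text Analytics Toolbox is installed.
emb = fastTextWordEmbedding;

% Look up the embedding vectors of the three known words.
italy = word2vec(emb,"Italy");
rome  = word2vec(emb,"Rome");
paris = word2vec(emb,"Paris");

% Return the vocabulary word whose vector is closest to Italy - Rome + Paris.
word = vec2word(emb,italy - rome + paris)
```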
There are four steps in training and using the LSTM network in this example:
Import and preprocess the data.
Convert the words to numeric sequences using a word encoding.
Create and train an LSTM network with a word embedding layer.
Classify new text data using the trained LSTM network.
Import Data
Import the factory reports data. This data contains labeled textual descriptions of factory events. To import the text data as strings, specify the text type to be 'string'.
filename = "factoryReports.csv";
data = readtable(filename,'TextType','string');
head(data)
Description                                                              Category                Urgency     Resolution               Cost
_____________________________________________________________________    ____________________    ________    ____________________    _____
"Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
"Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
"There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
"Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
"Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
"Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
"A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
"Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38
The goal of this example is to classify events by the label in the Category column. To divide the data into classes, convert these labels to categorical.
data.Category = categorical(data.Category);
classNames = categories(data.Category);
View the distribution of the classes in the data using a histogram.
figure
histogram(data.Category);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")
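The histogram is easier to interpret alongside the exact class counts. A minimal sketch, assuming data and classNames from the previous steps; counts and classCounts are illustrative names that are not part of the original example.

```matlab
% Count the reports in each class. countcats returns the counts in the
% order of categories(data.Category), which is the order of classNames.
counts = countcats(data.Category);
classCounts = table(classNames,counts,'VariableNames',{'Class','Count'})
```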
Next, partition the data into a training partition and a held-out partition for validation and testing. Specify the holdout percentage to be 20%.
cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);
Extract the text data and labels from the partitioned tables.
textDataTrain = dataTrain.Description;
textDataValidation = dataValidation.Description;
YTrain = dataTrain.Category;
YValidation = dataValidation.Category;
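Because cvpartition receives the class labels as its first argument, it stratifies the holdout split, so each class should appear in roughly the same proportion in both partitions. The following sketch checks this, assuming YTrain, YValidation, and classNames from the previous steps; propTrain and propValidation are illustrative names.

```matlab
% Fraction of each class in the training and validation partitions.
% countcats keeps the category order of classNames for both subsets.
propTrain = countcats(YTrain) / numel(YTrain);
propValidation = countcats(YValidation) / numel(YValidation);
table(classNames,propTrain,propValidation, ...
    'VariableNames',{'Class','TrainFraction','ValidationFraction'})
```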
To check that you have imported the data correctly, visualize the training text data using a word cloud.
figure
wordcloud(textDataTrain);
title("Training Data")
Preprocess Text Data
Create a function that tokenizes and preprocesses the text data. The function preprocessText, listed at the end of the example, performs these steps:
Tokenize the text using tokenizedDocument.
Convert the text to lowercase using lower.
Erase the punctuation using erasePunctuation.
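The three steps can be traced on a single report from the data set. This sketch applies them inline to one string, without the preprocessText helper; str and doc are illustrative names.

```matlab
% Trace the preprocessing steps on a single report.
str = "Mixer tripped the fuses.";
doc = tokenizedDocument(str);   % tokens: Mixer, tripped, the, fuses, .
doc = lower(doc);               % lowercase every token
doc = erasePunctuation(doc);    % erase punctuation; the "." token is removed
joinWords(doc)
```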
Preprocess the training data and the validation data using the preprocessText function.
documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);
View the first few preprocessed training documents.
documentsTrain(1:5)
ans =
  5×1 tokenizedDocument:
     9 tokens: items are occasionally getting stuck in the scanner spools
    10 tokens: loud rattling and banging sounds are coming from assembler pistons
    10 tokens: there are cuts to the power when starting the plant
     5 tokens: fried capacitors in the assembler
     4 tokens: mixer tripped the fuses
Convert Documents to Sequences
To input the documents into an LSTM network, use a word encoding to convert the documents into sequences of numeric indices.
To create a word encoding, use the wordEncoding function.
enc = wordEncoding(documentsTrain);
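The encoding assigns each vocabulary word a positive integer index. As a quick check, this sketch (assuming enc from the previous step) reads the vocabulary size and converts between an index and its word with ind2word and word2ind. Which word gets index 1 depends on the random training partition; firstWord and idx are illustrative names.

```matlab
% Vocabulary size of the encoding.
numWords = enc.NumWords

% Round trip between an index and its word.
firstWord = ind2word(enc,1)
idx = word2ind(enc,firstWord)
```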
The next conversion step is to pad and truncate the documents so that they all have the same length. The trainingOptions function provides options to pad and truncate input sequences automatically. However, these options are not well suited for sequences of word vectors. Instead, pad and truncate the sequences manually. Left-padding and truncating the sequences of word vectors might improve training.
To pad and truncate the documents, first choose a target length, and then truncate documents that are longer than it and left-pad documents that are shorter than it. For best results, the target length should be short without discarding large amounts of data. To find a suitable target length, view a histogram of the training document lengths.
documentLengths = doclength(documentsTrain);
figure
histogram(documentLengths)
title("Document Lengths")
xlabel("Length")
ylabel("Number of Documents")
Most of the training documents have fewer than 10 tokens. Use 10 as the target length for truncation and padding.
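The histogram gives a visual estimate; the exact fraction of documents that a given target length keeps intact takes one line to compute. A sketch, assuming documentLengths from the previous step; targetLength and fractionNotTruncated are hypothetical names for illustration.

```matlab
% Fraction of training documents that fit within the target length,
% that is, documents that would not lose any tokens to truncation.
targetLength = 10;   % hypothetical name for the length chosen above
fractionNotTruncated = mean(documentLengths <= targetLength)
```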
Convert the documents to sequences of numeric indices using doc2sequence. To truncate or left-pad the sequences to have length 10, set the 'Length' option to 10.
sequenceLength = 10;
XTrain = doc2sequence(enc,documentsTrain,'Length',sequenceLength);
XTrain(1:5)
ans=5×1 cell array
{[ 0 1 2 3 4 5 6 7 8 9]}
{[10 11 12 13 14 2 15 16 17 18]}
{[ 19 2 20 21 7 22 23 24 7 25]}
{[ 0 0 0 0 0 26 27 6 7 17]}
{[ 0 0 0 0 0 0 28 29 7 30]}
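The output shows that shorter documents are left-padded with zeros. One plausible reason left padding helps here is that the LSTM layer in this example uses output mode 'last', so the network classifies from its state at the final time step; with left padding, the final time steps are real words rather than padding. This sketch compares the default with right padding for one short training document using the 'PaddingDirection' name-value argument of doc2sequence; shortIdx, XLeft, and XRight are illustrative names.

```matlab
% Pick the first training document that is shorter than the target length.
shortIdx = find(doclength(documentsTrain) < sequenceLength,1);

% Default (left) padding versus right padding for that document.
XLeft  = doc2sequence(enc,documentsTrain(shortIdx),'Length',sequenceLength);
XRight = doc2sequence(enc,documentsTrain(shortIdx),'Length',sequenceLength, ...
    'PaddingDirection','right');

XLeft{1}
XRight{1}
```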
Convert the validation documents to sequences using the same options.
XValidation = doc2sequence(enc,documentsValidation,'Length',sequenceLength);
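Each sequence should now be a 1-by-10 row vector: one channel (the word index) by 10 time steps. This layout is why the network uses an input size of 1 and the 'CTB' (channel, time, batch) data format. A quick sanity check, assuming XTrain, XValidation, and sequenceLength from above; allSequences and allOK are illustrative names.

```matlab
% Confirm every training and validation sequence is 1-by-sequenceLength:
% 1 channel (the word index) by sequenceLength time steps.
allSequences = [XTrain; XValidation];
allOK = all(cellfun(@(x) isequal(size(x),[1 sequenceLength]),allSequences))
```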
Create and Train LSTM Network
Define the LSTM network architecture. To input sequence data into the network, include a sequence input layer and set the input size to 1. Next, include a word embedding layer of dimension 50 and the same number of words as the word encoding. Next, include an LSTM layer and set the number of hidden units to 80. To use the LSTM layer for a sequence-to-label classification problem, set the output mode to 'last'. Finally, add a fully connected layer with the same size as the number of classes, and a softmax layer.
inputSize = 1;
embeddingDimension = 50;
numHiddenUnits = 80;
numWords = enc.NumWords;
numClasses = numel(categories(YTrain));
layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer]
layers =
  5×1 Layer array with layers:
     1   ''   Sequence Input         Sequence input with 1 dimensions
     2   ''   Word Embedding Layer   Word embedding layer with 50 dimensions and 423 unique words
     3   ''   LSTM                   LSTM with 80 hidden units
     4   ''   Fully Connected        4 fully connected layer
     5   ''   Softmax                softmax
Specify Training Options
Specify the training options:
Train using the Adam solver.
Specify the input data format 'CTB' (channel, time, batch).
Specify a mini-batch size of 16.
Clip the gradients using a gradient threshold of 2.
Shuffle the data every epoch.
Specify the validation data using the 'ValidationData' option.
Monitor the training progress by setting the 'Plots' option to 'training-progress'.
Monitor the accuracy by setting the 'Metrics' option to 'accuracy'.
Suppress verbose output by setting the 'Verbose' option to false.
By default, trainnet uses a GPU if one is available. Otherwise, it uses the CPU. To specify the execution environment manually, use the 'ExecutionEnvironment' name-value pair argument of trainingOptions. Training on a CPU can take significantly longer than training on a GPU. Training with a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
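As an example of setting the execution environment explicitly, this sketch picks the GPU when canUseGPU reports that one is usable and falls back to the CPU otherwise. This mirrors the default behavior, so it is mainly useful when you want the choice to be visible or want to override it. executionEnvironment and optionsExplicit are illustrative names, and the remaining training options are omitted for brevity.

```matlab
% Choose the execution environment explicitly. canUseGPU returns true
% only when Parallel Computing Toolbox and a supported GPU are available.
if canUseGPU
    executionEnvironment = "gpu";
else
    executionEnvironment = "cpu";
end

% Other training options are omitted here for brevity.
optionsExplicit = trainingOptions('adam', ...
    'ExecutionEnvironment',executionEnvironment);
```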
options = trainingOptions('adam', ...
    'InputDataFormats','CTB', ...
    'MiniBatchSize',16, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'Plots','training-progress', ...
    'Metrics','accuracy', ...
    'Verbose',false);
Train the LSTM network using the trainnet function.
net = trainnet(XTrain,YTrain,layers,"crossentropy",options);
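Before classifying new data, you can measure performance on the held-out validation set using the same prediction functions that this example uses for new reports. A sketch, assuming net, XValidation, YValidation, and classNames from the previous steps; the variable names are illustrative. The accuracy varies between runs because the partition and the weight initialization are random.

```matlab
% Predict class scores for the validation sequences.
scoresValidation = minibatchpredict(net,XValidation,InputDataFormats="CTB");
YPredValidation = scores2label(scoresValidation,classNames);

% Fraction of validation reports classified correctly.
accuracy = mean(YPredValidation(:) == YValidation(:))

% Per-class breakdown of correct and incorrect predictions.
figure
confusionchart(YValidation(:),YPredValidation(:))
```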
Predict Using New Data
Classify the event type of three new reports. Create a string array containing the new reports.
reportsNew = [ ...
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];
Preprocess the text data using the same preprocessing steps as the training documents.
documentsNew = preprocessText(reportsNew);
Convert the text data to sequences using doc2sequence with the same options as when creating the training sequences.
XNew = doc2sequence(enc,documentsNew,'Length',sequenceLength);
Classify the new sequences using the trained LSTM network.
scores = minibatchpredict(net,XNew,InputDataFormats="CTB");
labelsNew = scores2label(scores,classNames)
labelsNew = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
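To see how confident the network is in each prediction, take the largest score for each report. This sketch assumes scores has one row per report and one column per class, in the order of classNames; topScore, topIdx, and predictedClass are illustrative names.

```matlab
% Largest score per report and the class it belongs to.
% Assumes one row per report and one column per class (order of classNames).
[topScore,topIdx] = max(scores,[],2);
predictedClass = string(classNames(topIdx));
table(reportsNew,predictedClass,topScore)
```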
Preprocessing Function
The function preprocessText performs these steps:
Tokenize the text using tokenizedDocument.
Convert the text to lowercase using lower.
Erase the punctuation using erasePunctuation.
function documents = preprocessText(textData)
% Tokenize the text.
documents = tokenizedDocument(textData);
% Convert to lowercase.
documents = lower(documents);
% Erase punctuation.
documents = erasePunctuation(documents);
end
See Also
fastTextWordEmbedding (Text Analytics Toolbox) | wordEmbeddingLayer (Text Analytics Toolbox) | tokenizedDocument (Text Analytics Toolbox) | lstmLayer | trainnet | trainingOptions | dlnetwork | doc2sequence (Text Analytics Toolbox) | sequenceInputLayer | wordcloud (Text Analytics Toolbox)
Related Topics
- Generate Text Using Deep Learning
- Word-By-Word Text Generation Using Deep Learning (Text Analytics Toolbox)
- Create Simple Text Model for Classification (Text Analytics Toolbox)
- Analyze Text Data Using Topic Models (Text Analytics Toolbox)
- Analyze Text Data Using Multiword Phrases (Text Analytics Toolbox)
- Train a Sentiment Classifier (Text Analytics Toolbox)
- Sequence Classification Using Deep Learning
- Deep Learning in MATLAB