predict
Predict labels using classification tree
Description
label = predict(Mdl,X) returns a vector of predicted class labels for the predictor data in the table or matrix X, based on the trained classification tree Mdl.

label = predict(Mdl,X,Name,Value) uses additional options specified by one or more Name,Value pair arguments. For example, you can specify to prune Mdl to a particular level before predicting labels.
[label,score,node,cnum] = predict(___) uses any of the input arguments in the previous syntaxes and additionally returns:

- A matrix of classification scores (score) indicating the likelihood that a label comes from a particular class. For classification trees, scores are posterior probabilities. For each observation in X, the predicted class label corresponds to the minimum expected misclassification cost among all classes.
- A vector of predicted node numbers for the classification (node).
- A vector of predicted class numbers for the classification (cnum).
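For instance, here is a minimal sketch of the four-output syntax using Fisher's iris data, which ships with Statistics and Machine Learning Toolbox:

% Minimal sketch of the four-output syntax.
load fisheriris
Mdl = fitctree(meas,species);                 % train a classification tree
[label,score,node,cnum] = predict(Mdl,meas);
label(1)    % predicted class label, e.g., {'setosa'}
score(1,:)  % posterior probabilities over Mdl.ClassNames
node(1)     % leaf node of Mdl that the observation reaches
cnum(1)     % index of the predicted class in Mdl.ClassNames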
Input Arguments
Mdl — Trained classification tree
ClassificationTree model object | CompactClassificationTree model object

Trained classification tree, specified as a ClassificationTree or CompactClassificationTree model object. That is, Mdl is a trained classification model returned by fitctree or compact.
X — Predictor data to be classified
numeric matrix | table

Predictor data to be classified, specified as a numeric matrix or table.

Each row of X corresponds to one observation, and each column corresponds to one variable.

For a numeric matrix:

- The variables making up the columns of X must have the same order as the predictor variables that trained Mdl.
- If you trained Mdl using a table (for example, Tbl), then X can be a numeric matrix if Tbl contains all numeric predictor variables. To treat numeric predictors in Tbl as categorical during training, identify categorical predictors using the CategoricalPredictors name-value pair argument of fitctree. If Tbl contains heterogeneous predictor variables (for example, numeric and categorical data types) and X is a numeric matrix, then predict throws an error.

For a table:

- predict does not support multicolumn variables or cell arrays other than cell arrays of character vectors.
- If you trained Mdl using a table (for example, Tbl), then all predictor variables in X must have the same variable names and data types as those that trained Mdl (stored in Mdl.PredictorNames). However, the column order of X does not need to correspond to the column order of Tbl. Tbl and X can contain additional variables (response variables, observation weights, etc.), but predict ignores them.
- If you trained Mdl using a numeric matrix, then the predictor names in Mdl.PredictorNames and the corresponding predictor variable names in X must be the same. To specify predictor names during training, see the PredictorNames name-value pair argument of fitctree. All predictor variables in X must be numeric vectors. X can contain additional variables (response variables, observation weights, etc.), but predict ignores them.

Data Types: table | double | single
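As a hedged illustration of the table rules above (the variable names SL, SW, PL, and PW are made up for this sketch), a model trained on a table accepts a table with the same predictor names even when the columns are reordered:

load fisheriris
Tbl = array2table(meas,'VariableNames',{'SL','SW','PL','PW'});
Mdl = fitctree(Tbl,species);
Xnew = Tbl(1:5,[3 4 1 2]);  % same variable names; column order does not matter
predict(Mdl,Xnew)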
Name-Value Arguments
Specify optional pairs of arguments as
Name1=Value1,...,NameN=ValueN
, where Name
is
the argument name and Value
is the corresponding value.
Name-value arguments must appear after other arguments, but the order of the
pairs does not matter.
Before R2021a, use commas to separate each name and value, and enclose
Name
in quotes.
Subtrees
— Pruning level
0 (default) | vector of nonnegative integers | 'all'
Pruning level, specified as the comma-separated pair consisting
of 'Subtrees'
and a vector of nonnegative integers
in ascending order or 'all'
.
If you specify a vector, then all elements must be at least 0
and
at most max(Mdl.PruneList)
. 0
indicates
the full, unpruned tree and max(Mdl.PruneList)
indicates
the completely pruned tree (i.e., just the root node).
If you specify 'all'
, then predict
operates
on all subtrees (i.e., the entire pruning sequence). This specification
is equivalent to using 0:max(Mdl.PruneList)
.
predict
prunes Mdl
to
each level indicated in Subtrees
, and then estimates
the corresponding output arguments. The size of Subtrees
determines
the size of some output arguments.
To invoke Subtrees
, the properties PruneList
and PruneAlpha
of Mdl
must
be nonempty. In other words, grow Mdl
by setting 'Prune','on'
,
or by pruning Mdl
using prune
.
Example: 'Subtrees','all'
Data Types: single
| double
| char
| string
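For instance, here is a sketch of predicting at every pruning level at once. It assumes Mdl was grown with 'Prune','on' (the fitctree default), so Mdl.PruneList is nonempty:

load fisheriris
Mdl = fitctree(meas,species);
T = max(Mdl.PruneList);                    % deepest pruning level (root node)
labels = predict(Mdl,meas,'Subtrees',0:T);
size(labels)                               % N-by-(T+1): one column per subtree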
Output Arguments
label
— Predicted class labels
vector | array
Predicted
class labels, returned as a vector or array. Each entry of label
corresponds
to the class with minimal expected cost for the corresponding row
of X
.
Suppose Subtrees is a numeric vector containing T elements (for 'all', see Subtrees), and X has N rows.

If the response data type is char and:

- T = 1, then label is a character matrix containing N rows. Each row contains the predicted label produced by subtree Subtrees.
- T > 1, then label is an N-by-T cell array.

Otherwise, label is an N-by-T array having the same data type as the response. (The software treats string arrays as cell arrays of character vectors.)

In the latter two cases, column j of label contains the vector of predicted labels produced by subtree Subtrees(j).
score
— Posterior probabilities
numeric matrix
Posterior probabilities, returned as a numeric matrix of size N
-by-K
,
where N
is the number of observations (rows) in X
,
and K
is the number of classes (in Mdl.ClassNames
). score(i,j)
is
the posterior probability that row i
of X
is
of class j
.
If Subtrees
has T
elements,
and X
has N
rows, then score
is
an N
-by-K
-by-T
array,
and node
and cnum
are N
-by-T
matrices.
node — Node numbers
numeric vector

Node numbers for the predictions, returned as a numeric vector. Each entry of node corresponds to the predicted leaf node in Mdl for the corresponding row of X.

cnum — Class numbers
numeric vector

Class numbers corresponding to the predicted labels, returned as a numeric vector. Each entry of cnum corresponds to a predicted class number for the corresponding row of X.
Examples
Predict Labels Using a Classification Tree
Examine predictions for a few rows in a data set left out of training.
Load Fisher's iris data set.
load fisheriris
Partition the data into training (50%) and validation (50%) sets.
n = size(meas,1);
rng(1) % For reproducibility
idxTrn = false(n,1);
idxTrn(randsample(n,round(0.5*n))) = true; % Training set logical indices
idxVal = idxTrn == false; % Validation set logical indices
Grow a classification tree using the training set.
Mdl = fitctree(meas(idxTrn,:),species(idxTrn));
Predict labels for the validation data. Count the number of misclassified observations.
label = predict(Mdl,meas(idxVal,:));
label(randsample(numel(label),5)) % Display several predicted labels
ans = 5x1 cell
{'setosa' }
{'setosa' }
{'setosa' }
{'virginica' }
{'versicolor'}
numMisclass = sum(~strcmp(label,species(idxVal)))
numMisclass = 3
The software misclassifies three out-of-sample observations.
Estimate Class Posterior Probabilities Using a Classification Tree
Load Fisher's iris data set.
load fisheriris
Partition the data into training (50%) and validation (50%) sets.
n = size(meas,1);
rng(1) % For reproducibility
idxTrn = false(n,1);
idxTrn(randsample(n,round(0.5*n))) = true; % Training set logical indices
idxVal = idxTrn == false; % Validation set logical indices
Grow a classification tree using the training set, and then view it.
Mdl = fitctree(meas(idxTrn,:),species(idxTrn));
view(Mdl,'Mode','graph')
The resulting tree has four levels.
Estimate posterior probabilities for the test set using subtrees pruned to levels 1 and 3.
[~,Posterior] = predict(Mdl,meas(idxVal,:),'Subtrees',[1 3]);
Mdl.ClassNames
ans = 3x1 cell
{'setosa' }
{'versicolor'}
{'virginica' }
Posterior(randsample(size(Posterior,1),5),:,:) % Display several posterior probabilities
ans =
ans(:,:,1) =

    1.0000         0         0
    1.0000         0         0
    1.0000         0         0
         0         0    1.0000
         0    0.8571    0.1429


ans(:,:,2) =

    0.3733    0.3200    0.3067
    0.3733    0.3200    0.3067
    0.3733    0.3200    0.3067
    0.3733    0.3200    0.3067
    0.3733    0.3200    0.3067
The elements of Posterior
are class posterior probabilities:
- Rows correspond to observations in the validation set.
- Columns correspond to the classes as listed in Mdl.ClassNames.
- Pages correspond to the subtrees.
The subtree pruned to level 1 is more sure of its predictions than the subtree pruned to level 3 (i.e., the root node).
More About
Predicted Class Label
predict classifies by minimizing the expected misclassification cost:

$$\hat{y} = \underset{y=1,\ldots,K}{\arg\min} \sum_{j=1}^{K} \hat{P}(j \mid x)\, C(y \mid j),$$

where:

- $\hat{y}$ is the predicted classification.
- K is the number of classes.
- $\hat{P}(j \mid x)$ is the posterior probability of class j for observation x.
- $C(y \mid j)$ is the cost of classifying an observation as y when its true class is j.
Score (tree)
For trees, the score of a classification of a leaf node is the posterior probability of the classification at that node. The posterior probability of the classification at a node is the number of training sequences that lead to that node with the classification, divided by the number of training sequences that lead to that node.
For an example, see Posterior Probability Definition for Classification Tree.
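As an illustrative check, assuming the tree stores its per-node class probabilities in the ClassProbability property (one row per node), the scores that predict returns should match the stored probabilities of the leaf nodes it reports:

load fisheriris
Mdl = fitctree(meas,species);
[~,score,node] = predict(Mdl,meas);
isequal(score,Mdl.ClassProbability(node,:))  % expected to return true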
True Misclassification Cost
The true misclassification cost is the cost of classifying an observation into an incorrect class.
You can set the true misclassification cost per class by using the 'Cost'
name-value argument when you create the classifier. Cost(i,j)
is the cost
of classifying an observation into class j
when its true class is
i
. By default, Cost(i,j)=1
if
i~=j
, and Cost(i,j)=0
if i=j
.
In other words, the cost is 0
for correct classification and
1
for incorrect classification.
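For example, here is a sketch of a nondefault cost matrix supplied at training time; the factor of 5 is arbitrary and only for illustration:

load fisheriris
C = ones(3) - eye(3);  % default costs: 1 off the diagonal, 0 on it
C(2,3) = 5;            % make classifying a true class 2 as class 3 five times costlier
Mdl = fitctree(meas,species,'Cost',C);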
Expected Cost
The expected misclassification cost per observation is an averaged cost of classifying the observation into each class.
Suppose you have Nobs
observations that you want to classify with a trained
classifier, and you have K
classes. You place the observations
into a matrix X
with one observation per row.
The expected cost matrix CE
has size
Nobs
-by-K
. Each row of
CE
contains the expected (average) cost of classifying
the observation into each of the K
classes.
The entry CE(n,k) is

$$CE(n,k) = \sum_{i=1}^{K} \hat{P}(i \mid X(n))\, C(k \mid i),$$

where:

- K is the number of classes.
- $\hat{P}(i \mid X(n))$ is the posterior probability of class i for observation X(n).
- $C(k \mid i)$ is the true misclassification cost of classifying an observation as k when its true class is i.
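Here is a hedged sketch of this computation, reproducing predict's labels from its posterior scores and the stored cost matrix (ties, if any, may break differently):

load fisheriris
Mdl = fitctree(meas,species);
[label,posterior] = predict(Mdl,meas);
CE = posterior*Mdl.Cost;          % CE(n,k) = sum_i P(i|X(n))*C(k|i)
[~,k] = min(CE,[],2);             % class index minimizing expected cost
isequal(label,Mdl.ClassNames(k))  % expected to return true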
Predictive Measure of Association
The predictive measure of association is a value that indicates the similarity between decision rules that split observations. Among all possible decision splits that are compared to the optimal split (found by growing the tree), the best surrogate decision split yields the maximum predictive measure of association. The second-best surrogate split has the second-largest predictive measure of association.
Suppose xj and xk are predictor variables j and k, respectively, and j ≠ k. At node t, the predictive measure of association between the optimal split xj < u and a surrogate split xk < v is

$$\lambda_{jk} = \frac{\min(P_L, P_R) - \left(1 - P_{L_j L_k} - P_{R_j R_k}\right)}{\min(P_L, P_R)},$$

where:

- $P_L$ is the proportion of observations in node t, such that xj < u. The subscript L stands for the left child of node t.
- $P_R$ is the proportion of observations in node t, such that xj ≥ u. The subscript R stands for the right child of node t.
- $P_{L_j L_k}$ is the proportion of observations at node t, such that xj < u and xk < v.
- $P_{R_j R_k}$ is the proportion of observations at node t, such that xj ≥ u and xk ≥ v.

Observations with missing values for xj or xk do not contribute to the proportion calculations.

λjk is a value in (–∞,1]. If λjk > 0, then xk < v is a worthwhile surrogate split for xj < u.
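A small numeric sketch of this formula, with made-up proportions that are not taken from a real tree:

PL = 0.6;      % proportion of node-t observations with xj < u
PR = 1 - PL;   % proportion with xj >= u
PLjLk = 0.55;  % proportion with xj < u and xk < v
PRjRk = 0.35;  % proportion with xj >= u and xk >= v
lambda = (min(PL,PR) - (1 - PLjLk - PRjRk))/min(PL,PR)  % 0.75 > 0: worthwhile surrogate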
Algorithms
predict
generates predictions by following
the branches of Mdl
until it reaches a leaf node
or a missing value. If predict
reaches a leaf node,
it returns the classification of that node.
If predict
reaches a node with a missing value
for a predictor, its behavior depends on the setting of the Surrogate
name-value
pair when fitctree
constructs Mdl
.
- Surrogate = 'off' (default) — predict returns the label with the largest number of training samples that reach the node.
- Surrogate = 'on' — predict uses the best surrogate split at the node. If all surrogate split variables with positive predictive measure of association are missing, predict returns the label with the largest number of training samples that reach the node. For a definition, see Predictive Measure of Association.
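A sketch of the surrogate behavior, predicting an observation with one predictor value missing:

load fisheriris
MdlOn = fitctree(meas,species,'Surrogate','on');
Xmiss = meas(1,:);
Xmiss(3) = NaN;        % remove one predictor value
predict(MdlOn,Xmiss)   % routed through the best surrogate split at each affected node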
Alternative Functionality
Simulink Block
To integrate the prediction of a classification tree model into Simulink®, you can use the ClassificationTree
Predict block in the Statistics and Machine Learning Toolbox™ library or a MATLAB® Function block with the predict
function. For
examples, see Predict Class Labels Using ClassificationTree Predict Block and Predict Class Labels Using MATLAB Function Block.
When deciding which approach to use, consider the following:
- If you use the Statistics and Machine Learning Toolbox library block, you can use the Fixed-Point Tool (Fixed-Point Designer) to convert a floating-point model to fixed point.
- Support for variable-size arrays must be enabled for a MATLAB Function block with the predict function.
- If you use a MATLAB Function block, you can use MATLAB functions for preprocessing or post-processing before or after predictions in the same MATLAB Function block.
Extended Capabilities
Tall Arrays
Calculate with arrays that have more rows than fit in memory.
This function fully supports tall arrays. You can use models trained on either in-memory or tall data with this function.
For more information, see Tall Arrays.
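A hedged sketch, using an in-memory model with tall predictor data:

load fisheriris
Mdl = fitctree(meas,species);     % in-memory model
tX = tall(meas);                  % tall version of the predictor data
label = gather(predict(Mdl,tX));  % gather collects the deferred result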
C/C++ Code Generation
Generate C and C++ code using MATLAB® Coder™.
Usage notes and limitations:
- You can generate C/C++ code for both predict and update by using a coder configurer. Or, generate code only for predict by using saveLearnerForCoder, loadLearnerForCoder, and codegen.
  - Code generation for predict and update — Create a coder configurer by using learnerCoderConfigurer and then generate code by using generateCode. Then you can update model parameters in the generated code without having to regenerate the code.
  - Code generation for predict — Save a trained model by using saveLearnerForCoder. Define an entry-point function that loads the saved model by using loadLearnerForCoder and calls the predict function. Then use codegen (MATLAB Coder) to generate code for the entry-point function.
- To generate single-precision C/C++ code for predict, specify the name-value argument "DataType","single" when you call the loadLearnerForCoder function.
- You can also generate fixed-point C/C++ code for predict. Fixed-point code generation requires an additional step that defines the fixed-point data types of the variables required for prediction. Create a fixed-point data type structure by using the data type function generated by generateLearnerDataTypeFcn, and use the structure as an input argument of loadLearnerForCoder in an entry-point function. Generating fixed-point C/C++ code requires MATLAB Coder™ and Fixed-Point Designer™.
- This table contains notes about the arguments of predict. Arguments not included in this table are fully supported.

Argument: Mdl
Notes and Limitations:
- For the usage notes and limitations of the model object, see Code Generation of the CompactClassificationTree object.

Argument: X
Notes and Limitations:
- For general code generation, X must be a single-precision or double-precision matrix or a table containing numeric variables, categorical variables, or both.
- In the coder configurer workflow, X must be a single-precision or double-precision matrix.
- For fixed-point code generation, X must be a fixed-point matrix.
- The number of rows, or observations, in X can be a variable size, but the number of columns in X must be fixed.
- If you want to specify X as a table, then your model must be trained using a table, and your entry-point function for prediction must do the following:
  - Accept data as arrays.
  - Create a table from the data input arguments and specify the variable names in the table.
  - Pass the table to predict.
  For an example of this table workflow, see Generate Code to Classify Data in Table. For more information on using tables in code generation, see Code Generation for Tables (MATLAB Coder) and Table Limitations for Code Generation (MATLAB Coder).

Argument: label
Notes and Limitations:
- If the response data type is char and codegen cannot determine that the value of Subtrees is a scalar, then label is a cell array of character vectors.

Argument: 'Subtrees'
Notes and Limitations:
- Names in name-value arguments must be compile-time constants. For example, to allow user-defined pruning levels in the generated code, include {coder.Constant('Subtrees'),coder.typeof(0,[1,n],[0,1])} in the -args value of codegen (MATLAB Coder), where n is max(Mdl.PruneList).
- The 'Subtrees' name-value pair argument is not supported in the coder configurer workflow.
- For fixed-point code generation, the 'Subtrees' value must be coder.Constant('all') or have an integer data type.
For more information, see Introduction to Code Generation.
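Here is a hedged sketch of the saveLearnerForCoder workflow described above; the file name 'treeMdl' and the entry-point name predictTree are illustrative:

% At the MATLAB command line:
%   load fisheriris
%   Mdl = fitctree(meas,species);
%   saveLearnerForCoder(Mdl,'treeMdl');
%   codegen predictTree -args {coder.typeof(0,[Inf 4],[1 0])}

function label = predictTree(X) %#codegen
% Entry-point function: load the saved tree and predict labels for X.
Mdl = loadLearnerForCoder('treeMdl');
label = predict(Mdl,X);
end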
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
Usage notes and limitations:
The
predict
function does not support decision tree models trained with surrogate splits.
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
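A sketch, which requires Parallel Computing Toolbox and a supported GPU; the model must not use surrogate splits:

load fisheriris
Mdl = fitctree(meas,species);        % no surrogate splits (default)
label = predict(Mdl,gpuArray(meas)); % prediction runs on the GPU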
Version History
Introduced in R2011a
See Also
fitctree
| compact
| prune
| loss
| edge
| margin
| CompactClassificationTree
| ClassificationTree