
fitrensemble

Fit ensemble of learners for regression

Syntax

Mdl = fitrensemble(Tbl,ResponseVarName)
Mdl = fitrensemble(Tbl,formula)
Mdl = fitrensemble(Tbl,Y)
Mdl = fitrensemble(X,Y)
Mdl = fitrensemble(___,Name,Value)

Description

Mdl = fitrensemble(Tbl,ResponseVarName) returns the trained regression ensemble model object (Mdl) that contains the results of boosting 100 regression trees using LSBoost and the predictor and response data in the table Tbl. ResponseVarName is the name of the response variable in Tbl.

Mdl = fitrensemble(Tbl,formula) applies formula to fit the model to the predictor and response data in the table Tbl. formula is an explanatory model of the response and a subset of predictor variables in Tbl used to fit Mdl. For example, 'Y~X1+X2+X3' fits the response variable Tbl.Y as a function of the predictor variables Tbl.X1, Tbl.X2, and Tbl.X3.

Mdl = fitrensemble(Tbl,Y) treats all variables in the table Tbl as predictor variables. Y is the vector of responses, which is not in Tbl.

Mdl = fitrensemble(X,Y) uses the predictor data in the matrix X and response data in the vector Y.

Mdl = fitrensemble(___,Name,Value) uses additional options specified by one or more Name,Value pair arguments and any of the input arguments in the previous syntaxes. For example, you can specify the number of learning cycles, the ensemble-aggregation method, or to implement 10-fold cross-validation.

Examples

Create a regression ensemble that predicts the fuel economy of a car given the number of cylinders, volume displaced by the cylinders, horsepower, and weight. Then, train another ensemble using fewer predictors. Compare the in-sample predictive accuracies of the ensembles.

Load the carsmall data set. Store the variables to be used in training in a table.

load carsmall
Tbl = table(Cylinders,Displacement,Horsepower,Weight,MPG);

Train a regression ensemble.

Mdl1 = fitrensemble(Tbl,'MPG');

Mdl1 is a RegressionEnsemble model. Some notable characteristics of Mdl1 are:

  • The ensemble-aggregation algorithm is 'LSBoost'.

  • Because the ensemble-aggregation method is a boosting algorithm, regression trees that allow a maximum of 10 splits compose the ensemble.

  • One hundred trees compose the ensemble.

Because MPG is a variable in the MATLAB® Workspace, you can obtain the same result by entering

Mdl1 = fitrensemble(Tbl,MPG);

Use the trained regression ensemble to predict the fuel economy for a four-cylinder car with a 200-cubic-inch displacement, 150 horsepower, and a weight of 3000 lbs.

pMPG = predict(Mdl1,[4 200 150 3000])
pMPG = 25.6467

Train a new ensemble using all predictors in Tbl except Displacement.

formula = 'MPG ~ Cylinders + Horsepower + Weight';
Mdl2 = fitrensemble(Tbl,formula);

Compare the resubstitution MSEs between Mdl1 and Mdl2.

mse1 = resubLoss(Mdl1)
mse1 = 0.3096
mse2 = resubLoss(Mdl2)
mse2 = 0.5861

The in-sample MSE for the ensemble that trains on all predictors is lower.

Estimate the generalization error of an ensemble of boosted regression trees.

Load the carsmall data set. Choose the number of cylinders, volume displaced by the cylinders, horsepower, and weight as predictors of fuel economy.

load carsmall
X = [Cylinders Displacement Horsepower Weight];

Cross-validate an ensemble of regression trees using 10-fold cross-validation. Using a decision tree template, specify that each tree should be split once only.

rng(1); % For reproducibility
t = templateTree('MaxNumSplits',1);
Mdl = fitrensemble(X,MPG,'Learners',t,'CrossVal','on');

Mdl is a RegressionPartitionedEnsemble model.

Plot the cumulative, 10-fold cross-validated, mean-squared error (MSE). Display the estimated generalization error of the ensemble.

kflc = kfoldLoss(Mdl,'Mode','cumulative');
figure;
plot(kflc);
ylabel('10-fold cross-validated MSE');
xlabel('Learning cycle');

estGenError = kflc(end)
estGenError = 25.1238

kfoldLoss returns the generalization error by default. However, plotting the cumulative loss allows you to monitor how the loss changes as weak learners accumulate in the ensemble.

The ensemble achieves an MSE of around 23.5 after accumulating about 30 weak learners.

If you are satisfied with the generalization error of the ensemble, then, to create a predictive model, train the ensemble again using all of the settings except cross-validation. However, it is good practice to tune hyperparameters such as the maximum number of decision splits per tree and the number of learning cycles.
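
For example, a minimal sketch of that retraining step, assuming the template t, the predictor matrix X, and MPG from this example are still in the workspace:

MdlFinal = fitrensemble(X,MPG,'Learners',t); % same settings, no 'CrossVal'
predict(MdlFinal,[4 200 150 3000])           % predict for a hypothetical car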

This example shows how to optimize hyperparameters automatically using fitrensemble. The example uses the carsmall data.

Load the data.

load carsmall

Find hyperparameters that minimize five-fold cross-validation loss by using automatic hyperparameter optimization.

For reproducibility, set the random seed and use the 'expected-improvement-plus' acquisition function.

rng(1)
Mdl = fitrensemble([Horsepower,Weight],MPG,'OptimizeHyperparameters','auto',...
    'HyperparameterOptimizationOptions',struct('AcquisitionFunctionName',...
    'expected-improvement-plus'))
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |       Method | NumLearningC-|    LearnRate |  MinLeafSize |
|      | result |             | runtime     | (observed)  | (estim.)    |              | ycles        |              |              |
|===================================================================================================================================|
|    1 | Best   |      3.0003 |       1.039 |      3.0003 |      3.0003 |      LSBoost |           33 |      0.41003 |           14 |
|    2 | Accept |      3.1411 |      0.5451 |      3.0003 |      3.0109 |          Bag |           20 |            - |           25 |
|    3 | Accept |      4.8671 |     0.29049 |      3.0003 |      3.0049 |      LSBoost |           12 |     0.068144 |            2 |
|    4 | Accept |      4.1814 |      5.9125 |      3.0003 |      3.0826 |          Bag |          303 |            - |           41 |
|    5 | Best   |      2.9281 |     0.66414 |      2.9281 |      2.9342 |      LSBoost |           28 |      0.25811 |           14 |
|    6 | Accept |      6.3689 |      0.2674 |      2.9281 |      2.9284 |      LSBoost |           10 |     0.003762 |           14 |
|    7 | Accept |       2.931 |     0.25223 |      2.9281 |      2.9286 |      LSBoost |           10 |      0.28649 |           20 |
|    8 | Accept |      4.1823 |      3.1803 |      2.9281 |      2.9283 |      LSBoost |          152 |      0.26553 |           49 |
|    9 | Best   |      2.9093 |     0.24291 |      2.9093 |      2.9094 |          Bag |           10 |            - |           10 |
|   10 | Accept |      3.0763 |     0.26593 |      2.9093 |      2.9096 |      LSBoost |           10 |      0.35342 |            8 |
|   11 | Accept |      2.9689 |      1.1113 |      2.9093 |      2.9328 |      LSBoost |           52 |      0.99586 |           24 |
|   12 | Accept |      4.1823 |     0.31728 |      2.9093 |      2.9279 |      LSBoost |           12 |      0.97917 |           50 |
|   13 | Best   |      2.8562 |      10.015 |      2.8562 |      2.8368 |          Bag |          496 |            - |           15 |
|   14 | Accept |      3.0078 |     0.25195 |      2.8562 |      2.8459 |          Bag |           10 |            - |            2 |
|   15 | Accept |      2.9582 |      10.375 |      2.8562 |      2.8483 |          Bag |          499 |            - |            4 |
|   16 | Accept |      3.2591 |      0.6166 |      2.8562 |      2.8559 |      LSBoost |           25 |      0.99852 |           13 |
|   17 | Accept |      3.0629 |      1.0847 |      2.8562 |      2.8559 |          Bag |           47 |            - |            1 |
|   18 | Accept |      3.6183 |     0.80539 |      2.8562 |      2.8559 |      LSBoost |           33 |      0.99923 |            3 |
|   19 | Accept |      3.6616 |      5.0798 |      2.8562 |       2.856 |      LSBoost |          235 |      0.99077 |            1 |
|   20 | Accept |      6.0321 |      4.7379 |      2.8562 |      2.8558 |      LSBoost |          208 |    0.0010121 |            1 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |       Method | NumLearningC-|    LearnRate |  MinLeafSize |
|      | result |             | runtime     | (observed)  | (estim.)    |              | ycles        |              |              |
|===================================================================================================================================|
|   21 | Accept |      2.9087 |     0.26443 |      2.8562 |       2.856 |          Bag |           10 |            - |            6 |
|   22 | Accept |      3.3691 |       8.014 |      2.8562 |       2.856 |      LSBoost |          371 |      0.13218 |            9 |
|   23 | Accept |      2.9046 |     0.26803 |      2.8562 |      2.8705 |          Bag |           11 |            - |           17 |
|   24 | Accept |      2.9699 |     0.65019 |      2.8562 |      2.8699 |      LSBoost |           28 |      0.52539 |           21 |
|   25 | Best   |      2.8554 |      9.7454 |      2.8554 |      2.8558 |          Bag |          481 |            - |           14 |
|   26 | Accept |      2.8708 |      10.081 |      2.8554 |      2.8577 |          Bag |          494 |            - |            8 |
|   27 | Best   |      2.8421 |      9.8712 |      2.8421 |      2.8508 |          Bag |          487 |            - |           13 |
|   28 | Accept |      2.8425 |      9.9592 |      2.8421 |      2.8464 |          Bag |          490 |            - |           13 |
|   29 | Accept |      2.9523 |      10.619 |      2.8421 |      2.8463 |      LSBoost |          500 |      0.13668 |           19 |
|   30 | Accept |      2.8525 |      9.5464 |      2.8421 |      2.8468 |          Bag |          475 |            - |           12 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 164.3211 seconds.
Total objective function evaluation time: 116.075

Best observed feasible point:
    Method    NumLearningCycles    LearnRate    MinLeafSize
    ______    _________________    _________    ___________

     Bag             487              NaN           13     

Observed objective function value = 2.8421
Estimated objective function value = 2.8468
Function evaluation time = 9.8712

Best estimated feasible point (according to models):
    Method    NumLearningCycles    LearnRate    MinLeafSize
    ______    _________________    _________    ___________

     Bag             490              NaN           13     

Estimated objective function value = 2.8468
Estimated function evaluation time = 9.8819
Mdl = 
  classreg.learning.regr.RegressionBaggedEnsemble
                         ResponseName: 'Y'
                CategoricalPredictors: []
                    ResponseTransform: 'none'
                      NumObservations: 94
    HyperparameterOptimizationResults: [1×1 BayesianOptimization]
                           NumTrained: 490
                               Method: 'Bag'
                         LearnerNames: {'Tree'}
                 ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                              FitInfo: []
                   FitInfoDescription: 'None'
                       Regularization: []
                            FResample: 1
                              Replace: 1
                     UseObsForLearner: [94×490 logical]


  Properties, Methods

The optimization searched over the methods for regression (Bag and LSBoost), over NumLearningCycles, over the LearnRate for LSBoost, and over the tree learner MinLeafSize. The output is the regression ensemble with the minimum estimated cross-validation loss.
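
The optimization results are also stored in the returned model. As a sketch, you can query the best point found through the BayesianOptimization object:

results = Mdl.HyperparameterOptimizationResults; % BayesianOptimization object
bestPoint(results)                               % best estimated feasible point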

One way to create an ensemble of boosted regression trees that has satisfactory predictive performance is to tune the decision tree-complexity level using cross-validation. While searching for an optimal complexity level, tune the learning rate to minimize the number of learning cycles.

Load the carsmall data set. Choose the number of cylinders, volume displaced by the cylinders, horsepower, and weight as predictors of fuel economy.

load carsmall
Tbl = table(Cylinders,Displacement,Horsepower,Weight,MPG);

To search for the optimal tree-complexity level:

  1. Cross-validate a set of ensembles. Exponentially increase the tree-complexity level for subsequent ensembles from decision stump (one split) to at most n - 1 splits, where n is the sample size. Also, vary the learning rate for each ensemble between 0.1 and 1.

  2. Estimate the cross-validated mean-squared error (MSE) for each ensemble.

  3. For each tree-complexity level, compare the cumulative, cross-validated MSE of the ensembles by plotting them against the number of learning cycles. Plot separate curves for each learning rate on the same figure.

  4. Choose the curve that achieves the minimal MSE, and note the corresponding learning cycle and learning rate.

Cross-validate a deep regression tree and a stump. Because the data contain missing values, use surrogate splits. These regression trees serve as benchmarks.

rng(1); % For reproducibility
MdlDeep = fitrtree(Tbl,'MPG','CrossVal','on','MergeLeaves','off',...
    'MinParentSize',1,'Surrogate','on');
MdlStump = fitrtree(Tbl,'MPG','MaxNumSplits',1,'CrossVal','on',...
    'Surrogate','on');

Cross-validate an ensemble of 150 boosted regression trees using 5-fold cross-validation. Using a tree template:

  • Vary the maximum number of splits using the values in the sequence {2^0, 2^1, ..., 2^m}, where m is such that 2^m is no greater than n - 1.

  • Turn on surrogate splits.

For each variant, adjust the learning rate using each value in the set {0.1, 0.25, 0.5, 1}.

n = size(Tbl,1);
m = floor(log2(n - 1));
learnRate = [0.1 0.25 0.5 1];
numLR = numel(learnRate);
maxNumSplits = 2.^(0:m);
numMNS = numel(maxNumSplits);
numTrees = 150;
Mdl = cell(numMNS,numLR);

for k = 1:numLR
    for j = 1:numMNS
        t = templateTree('MaxNumSplits',maxNumSplits(j),'Surrogate','on');
        Mdl{j,k} = fitrensemble(Tbl,'MPG','NumLearningCycles',numTrees,...
            'Learners',t,'KFold',5,'LearnRate',learnRate(k));
    end
end

Estimate the cumulative, cross-validated MSE of each ensemble.

kflAll = @(x)kfoldLoss(x,'Mode','cumulative');
errorCell = cellfun(kflAll,Mdl,'Uniform',false);
error = reshape(cell2mat(errorCell),[numTrees numel(maxNumSplits) numel(learnRate)]);
errorDeep = kfoldLoss(MdlDeep);
errorStump = kfoldLoss(MdlStump);

Plot how the cross-validated MSE behaves as the number of trees in the ensemble increases. On each plot, draw one curve per learning rate, and create a separate plot for each tree-complexity level. Choose a subset of tree-complexity levels to plot.

mnsPlot = [1 round(numel(maxNumSplits)/2) numel(maxNumSplits)];
figure;
for k = 1:3
    subplot(2,2,k);
    plot(squeeze(error(:,mnsPlot(k),:)),'LineWidth',2);
    axis tight;
    hold on;
    h = gca;
    plot(h.XLim,[errorDeep errorDeep],'-.b','LineWidth',2);
    plot(h.XLim,[errorStump errorStump],'-.r','LineWidth',2);
    plot(h.XLim,min(min(error(:,mnsPlot(k),:))).*[1 1],'--k');
    h.YLim = [10 50];
    xlabel 'Number of trees';
    ylabel 'Cross-validated MSE';
    title(sprintf('MaxNumSplits = %0.3g', maxNumSplits(mnsPlot(k))));
    hold off;
end
hL = legend([cellstr(num2str(learnRate','Learning Rate = %0.2f'));...
        'Deep Tree';'Stump';'Min. MSE']);
hL.Position(1) = 0.6;

Each curve contains a minimum cross-validated MSE occurring at the optimal number of trees in the ensemble.

Identify the maximum number of splits, number of trees, and learning rate that yield the lowest MSE overall.

[minErr,minErrIdxLin] = min(error(:));
[idxNumTrees,idxMNS,idxLR] = ind2sub(size(error),minErrIdxLin);

fprintf('\nMin. MSE = %0.5f',minErr)
fprintf('\nOptimal Parameter Values:\nNum. Trees = %d',idxNumTrees);
fprintf('\nMaxNumSplits = %d\nLearning Rate = %0.2f\n',...
    maxNumSplits(idxMNS),learnRate(idxLR))
Min. MSE = 17.01148
Optimal Parameter Values:
Num. Trees = 38
MaxNumSplits = 4
Learning Rate = 0.10

Create a predictive ensemble based on the optimal hyperparameters and the entire training set.

tFinal = templateTree('MaxNumSplits',maxNumSplits(idxMNS),'Surrogate','on');
MdlFinal = fitrensemble(Tbl,'MPG','NumLearningCycles',idxNumTrees,...
    'Learners',tFinal,'LearnRate',learnRate(idxLR))
MdlFinal = 

  classreg.learning.regr.RegressionEnsemble
           PredictorNames: {1×4 cell}
             ResponseName: 'MPG'
    CategoricalPredictors: []
        ResponseTransform: 'none'
          NumObservations: 94
               NumTrained: 38
                   Method: 'LSBoost'
             LearnerNames: {'Tree'}
     ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'
                  FitInfo: [38×1 double]
       FitInfoDescription: {2×1 cell}
           Regularization: []


MdlFinal is a RegressionEnsemble. To predict the fuel economy of a car given its number of cylinders, volume displaced by the cylinders, horsepower, and weight, pass the predictor data and MdlFinal to predict.
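
For example, reusing the hypothetical four-cylinder car from the first example (this sketch assumes the numeric predictor columns keep the order Cylinders, Displacement, Horsepower, Weight):

pMPG = predict(MdlFinal,[4 200 150 3000])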

Input Arguments

Sample data used to train the model, specified as a table. Each row of Tbl corresponds to one observation, and each column corresponds to one predictor variable. Tbl can contain one additional column for the response variable. Multi-column variables and cell arrays other than cell arrays of character vectors are not allowed.

  • If Tbl contains the response variable and you want to use all remaining variables as predictors, then specify the response variable using ResponseVarName.

  • If Tbl contains the response variable, and you want to use a subset of the remaining variables only as predictors, then specify a formula using formula.

  • If Tbl does not contain the response variable, then specify the response data using Y. The length of the response variable and the number of rows of Tbl must be equal.

Note

To save memory and execution time, supply X and Y instead of Tbl.

Data Types: table
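
For example, a minimal sketch that extracts a matrix and vector from the carsmall variables used in the examples above:

load carsmall
X = [Cylinders Displacement Horsepower Weight]; % predictor matrix
Y = MPG;                                        % response vector
Mdl = fitrensemble(X,Y);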

Response variable name, specified as the name of the response variable in Tbl.

You must specify ResponseVarName as a character vector or string scalar. For example, if Tbl.Y is the response variable, then specify ResponseVarName as 'Y'. Otherwise, fitrensemble treats all columns of Tbl, including Tbl.Y, as predictor variables.

Data Types: char | string

Explanatory model of the response and a subset of the predictor variables, specified as a character vector or string scalar in the form of 'Y~X1+X2+X3'. In this form, Y represents the response variable, and X1, X2, and X3 represent the predictor variables. The variables must be variable names in Tbl (Tbl.Properties.VariableNames).

To specify a subset of variables in Tbl as predictors for training the model, use a formula. If you specify a formula, then the software does not use any variables in Tbl that do not appear in formula.

Data Types: char | string

Predictor data, specified as numeric matrix.

Each row corresponds to one observation, and each column corresponds to one predictor variable.

The length of Y and the number of rows of X must be equal.

To specify the names of the predictors in the order of their appearance in X, use the PredictorNames name-value pair argument.

Data Types: single | double

Response, specified as a numeric vector. Each element in Y is the response to the observation in the corresponding row of X or Tbl. The length of Y and the number of rows of X or Tbl must be equal.

Data Types: single | double

Name-Value Pair Arguments

Specify optional comma-separated pairs of Name,Value arguments. Name is the argument name and Value is the corresponding value. Name must appear inside quotes. You can specify several name and value pair arguments in any order as Name1,Value1,...,NameN,ValueN.

Example: 'NumLearningCycles',500,'Method','Bag','Learners',templateTree(),'CrossVal','on' cross-validates an ensemble of 500 bagged regression trees using 10-fold cross-validation.

Note

You cannot use any cross-validation name-value pair argument along with the 'OptimizeHyperparameters' name-value pair argument. You can modify the cross-validation for 'OptimizeHyperparameters' only by using the 'HyperparameterOptimizationOptions' name-value pair argument.

General Ensemble Options

Ensemble-aggregation method, specified as the comma separated pair consisting of 'Method' and 'Bag' or 'LSBoost'.

Value        Description
'Bag'        Bootstrap aggregation (bagging), for example, random forest
'LSBoost'    Least-squares boosting (LSBoost)

Example: 'Method','Bag'

Number of ensemble learning cycles, specified as a positive integer. At every learning cycle, the software trains one weak learner for every template object in Learners. Consequently, the software trains NumLearningCycles*numel(Learners) learners.

The software composes the ensemble using all trained learners and stores them in Mdl.Trained.

For more details, see Tips.

Example: 'NumLearningCycles',500

Data Types: single | double

Weak learners to use in the ensemble, specified as 'tree', a regression tree template object, or a cell vector of regression tree template objects.

  • 'tree' specifies to train an ensemble of regression trees using default regression tree options.

    • For bagging, fitrensemble grows deep trees by default. For a complete list of defaults, see templateTree.

    • For LSBoost, fitrensemble grows shallow trees by default. That is, 'Learners' is templateTree('MaxNumSplits',10).

  • A regression tree template object specifies how to grow all trees in the ensemble. You can create a regression tree template using templateTree.

  • A cell vector of m regression tree templates specifies to grow m regression trees per learning cycle (see NumLearningCycles). For example, for an ensemble composed of two types of regression trees, supply {t1 t2}, where t1 and t2 are regression tree templates returned by templateTree.

fitrensemble sets the maximum number of splits for regression tree weak learners to 10 when both of these conditions occur:

  • You do not set a value for the 'MaxNumSplits' property of the weak learners.

  • You use a boosting algorithm as the ensemble-aggregation method for fitrensemble.

For example, if you specify 'Learners',templateTree() and 'Method','LSBoost', then fitrensemble sets the maximum number of splits of the tree weak learners to 10.

Otherwise, fitrensemble defers to the regression tree template object to choose any required default values for the 'Learners' name-value pair argument.

For details on the number of learners to train, see NumLearningCycles and Tips.

Example: 'Learners',templateTree('MaxNumSplits',5)
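
For instance, a sketch that grows two types of trees per learning cycle (50 cycles, so 100 trees in total), assuming a training table Tbl with response MPG as in the examples above:

t1 = templateTree('MaxNumSplits',1);  % stumps
t2 = templateTree('MaxNumSplits',10); % deeper trees
Mdl = fitrensemble(Tbl,'MPG','Learners',{t1 t2},'NumLearningCycles',50);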

Printout frequency, specified as the comma-separated pair consisting of 'NPrint' and a positive integer or 'off'.

To track the number of weak learners or folds that fitrensemble trained so far, specify a positive integer. That is, if you specify the positive integer m:

  • Without also specifying any cross-validation option (for example, CrossVal), then fitrensemble displays a message to the command line every time it completes training m weak learners.

  • And a cross-validation option, then fitrensemble displays a message to the command line every time it finishes training m folds.

If you specify 'off', then fitrensemble does not display a message when it completes training weak learners.

Tip

When training an ensemble of many weak learners on a large data set, specify a positive integer for NPrint.

Example: 'NPrint',5

Data Types: single | double | char | string

Categorical predictors list, specified as the comma-separated pair consisting of 'CategoricalPredictors' and one of the values in this table.

  • Vector of positive integers — Each entry in the vector is an index value corresponding to the column of the predictor data (X or Tbl) that contains a categorical variable.

  • Logical vector — A true entry means that the corresponding column of predictor data (X or Tbl) is a categorical variable.

  • Character matrix — Each row of the matrix is the name of a predictor variable. The names must match the entries in PredictorNames. Pad the names with extra blanks so each row of the character matrix has the same length.

  • String array or cell array of character vectors — Each element in the array is the name of a predictor variable. The names must match the entries in PredictorNames.

  • 'all' — All predictors are categorical.

By default, if the predictor data is in a table (Tbl), fitrensemble assumes that a variable is categorical if it contains logical values, categorical values, a string array, or a cell array of character vectors. If the predictor data is a matrix (X), fitrensemble assumes all predictors are continuous. To identify any categorical predictors when the data is a matrix, use the 'CategoricalPredictors' name-value pair argument.

Example: 'CategoricalPredictors','all'

Data Types: single | double | logical | char | string | cell
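
For example, a sketch that flags the first column of a carsmall predictor matrix as categorical, because Cylinders takes only a few discrete values:

load carsmall
X = [Cylinders Displacement Horsepower Weight];
Mdl = fitrensemble(X,MPG,'CategoricalPredictors',1);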

Predictor variable names, specified as the comma-separated pair consisting of 'PredictorNames' and a string array of unique names or cell array of unique character vectors. The functionality of 'PredictorNames' depends on the way you supply the training data.

  • If you supply X and Y, then you can use 'PredictorNames' to give the predictor variables in X names.

    • The order of the names in PredictorNames must correspond to the column order of X. That is, PredictorNames{1} is the name of X(:,1), PredictorNames{2} is the name of X(:,2), and so on. Also, size(X,2) and numel(PredictorNames) must be equal.

    • By default, PredictorNames is {'x1','x2',...}.

  • If you supply Tbl, then you can use 'PredictorNames' to choose which predictor variables to use in training. That is, fitrensemble uses only the predictor variables in PredictorNames and the response variable in training.

    • PredictorNames must be a subset of Tbl.Properties.VariableNames and cannot include the name of the response variable.

    • By default, PredictorNames contains the names of all predictor variables.

    • It is a good practice to specify the predictors for training using either 'PredictorNames' or formula only.

Example: 'PredictorNames',{'SepalLength','SepalWidth','PetalLength','PetalWidth'}

Data Types: string | cell
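
For example, a sketch that trains on only two of the variables in the carsmall table Tbl from the examples above:

Mdl = fitrensemble(Tbl,'MPG','PredictorNames',{'Horsepower','Weight'});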

Response variable name, specified as the comma-separated pair consisting of 'ResponseName' and a character vector or string scalar.

  • If you supply Y, then you can use 'ResponseName' to specify a name for the response variable.

  • If you supply ResponseVarName or formula, then you cannot use 'ResponseName'.

Example: 'ResponseName','response'

Data Types: char | string

Response transformation, specified as the comma-separated pair consisting of 'ResponseTransform' and either 'none' or a function handle. The default is 'none', which means @(y)y, or no transformation. For a MATLAB® function or a function you define, use its function handle. The function handle must accept a vector (the original response values) and return a vector of the same size (the transformed response values).

Example: Suppose you create a function handle that applies an exponential transformation to an input vector by using myfunction = @(y)exp(y). Then, you can specify the response transformation as 'ResponseTransform',myfunction.

Data Types: char | string | function_handle

Cross-Validation Options

Cross-validation flag, specified as the comma-separated pair consisting of 'Crossval' and 'on' or 'off'.

If you specify 'on', then the software implements 10-fold cross-validation.

To override this cross-validation setting, use one of these name-value pair arguments: CVPartition, Holdout, KFold, or Leaveout. To create a cross-validated model, you can use one cross-validation name-value pair argument at a time only.

Alternatively, cross-validate Mdl later by passing it to crossval.

Example: 'Crossval','on'

Cross-validation partition, specified as the comma-separated pair consisting of 'CVPartition' and a cvpartition partition object as created by cvpartition. The partition object specifies the type of cross-validation and the indexing for the training and validation sets.

To create a cross-validated model, you can use one of these four name-value pair arguments only: 'CVPartition', 'Holdout', 'KFold', or 'Leaveout'.

Example: Suppose you create a random partition for 5-fold cross-validation on 500 observations by using cvp = cvpartition(500,'KFold',5). Then, you can specify the cross-validated model by using 'CVPartition',cvp.
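
A minimal sketch of that workflow, assuming a predictor matrix X and response vector Y with 500 rows:

cvp = cvpartition(500,'KFold',5);          % random 5-fold partition
Mdl = fitrensemble(X,Y,'CVPartition',cvp); % cross-validated ensemble
kfoldLoss(Mdl)                             % cross-validated MSE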

Fraction of the data used for holdout validation, specified as the comma-separated pair consisting of 'Holdout' and a scalar value in the range (0,1). If you specify 'Holdout',p, then the software completes these steps:

  1. Randomly select and reserve p*100% of the data as validation data, and train the model using the rest of the data.

  2. Store the compact, trained model in the Trained property of the cross-validated model.

To create a cross-validated model, you can use one of these four name-value pair arguments only: CVPartition, Holdout, KFold, or Leaveout.

Example: 'Holdout',0.1

Data Types: double | single

Number of folds to use in a cross-validated model, specified as the comma-separated pair consisting of 'KFold' and a positive integer value greater than 1. If you specify 'KFold',k, then the software completes these steps.

  1. Randomly partition the data into k sets.

  2. For each set, reserve the set as validation data, and train the model using the other k – 1 sets.

  3. Store the k compact, trained models in the cells of a k-by-1 cell vector in the Trained property of the cross-validated model.

To create a cross-validated model, you can use one of these four name-value pair arguments only: CVPartition, Holdout, KFold, or Leaveout.

Example: 'KFold',5

Data Types: single | double

Leave-one-out cross-validation flag, specified as the comma-separated pair consisting of 'Leaveout' and 'on' or 'off'. If you specify 'Leaveout','on', then, for each of the n observations (where n is the number of observations excluding missing observations, specified in the NumObservations property of the model), the software completes these steps:

  1. Reserve the observation as validation data, and train the model using the other n – 1 observations.

  2. Store the n compact, trained models in the cells of an n-by-1 cell vector in the Trained property of the cross-validated model.

To create a cross-validated model, you can use one of these four name-value pair arguments only: CVPartition, Holdout, KFold, or Leaveout.

Example: 'Leaveout','on'

Other Regression Options

Observation weights, specified as the comma-separated pair consisting of 'Weights' and a numeric vector of positive values or name of a variable in Tbl. The software weighs the observations in each row of X or Tbl with the corresponding value in Weights. The size of Weights must equal the number of rows of X or Tbl.

If you specify the input data as a table Tbl, then Weights can be the name of a variable in Tbl that contains a numeric vector. In this case, you must specify Weights as a character vector or string scalar. For example, if the weights vector W is stored as Tbl.W, then specify it as 'W'. Otherwise, the software treats all columns of Tbl, including W, as predictors or the response when training the model.

The software normalizes the values of Weights to sum to 1.

By default, Weights is ones(n,1), where n is the number of observations in X or Tbl.

Data Types: double | single | char | string
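
For example, a sketch that stores hypothetical weights as a table variable W and refers to it by name:

Tbl.W = rand(height(Tbl),1);                 % hypothetical observation weights
Mdl = fitrensemble(Tbl,'MPG','Weights','W'); % W is not used as a predictor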

Sampling Options

Fraction of the training set to resample for every weak learner, specified as the comma-separated pair consisting of 'FResample' and a positive scalar in (0,1].

To use 'FResample', set Method to 'Bag' or set Resample to 'on'.

Example: 'FResample',0.75

Data Types: single | double

Flag indicating sampling with replacement, specified as the comma-separated pair consisting of 'Replace' and 'off' or 'on'.

  • For 'on', the software samples the training observations with replacement.

  • For 'off', the software samples the training observations without replacement. If you set Resample to 'on', then the software samples training observations assuming uniform weights. If you also specify a boosting method, then the software boosts by reweighting observations.

Unless you set Method to 'Bag' or set Resample to 'on', Replace has no effect.

Example: 'Replace','off'

Flag indicating to resample, specified as the comma-separated pair consisting of 'Resample' and 'off' or 'on'.

  • If Method is any boosting method, then:

    • 'Resample','on' specifies to sample training observations using updated weights as the multinomial sampling probabilities.

    • 'Resample','off' specifies to reweight observations at every learning iteration. This setting is the default.

  • If Method is 'Bag', then 'Resample' must be 'on'. The software resamples a fraction of the training observations (see FResample) with or without replacement (see Replace).

LSBoost Method Options

Learning rate for shrinkage, specified as the comma-separated pair consisting of 'LearnRate' and a numeric scalar in the interval (0,1].

To train an ensemble using shrinkage, set LearnRate to a value less than 1; for example, 0.1 is a popular choice. Training an ensemble using shrinkage requires more learning iterations, but often achieves better accuracy.

Example: 'LearnRate',0.1

Data Types: single | double
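
For example, a sketch that compensates for a small learning rate with more learning cycles, assuming the carsmall table Tbl from the examples above:

Mdl = fitrensemble(Tbl,'MPG','NumLearningCycles',500,'LearnRate',0.1);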

Hyperparameter Optimization

Parameters to optimize, specified as the comma-separated pair consisting of 'OptimizeHyperparameters' and one of the following:

  • 'none' — Do not optimize.

  • 'auto' — Use {'Method','NumLearningCycles','LearnRate'} along with the default parameters for the specified Learners:

    • Learners = 'tree' (default) — {'MinLeafSize'}

    Note

    For hyperparameter optimization, Learners must be a single argument, not a string array or cell array.

  • 'all' — Optimize all eligible parameters.

  • String array or cell array of eligible parameter names

  • Vector of optimizableVariable objects, typically the output of hyperparameters

The optimization attempts to minimize the cross-validation loss (error) for fitrensemble by varying the parameters. To control the cross-validation type and other aspects of the optimization, use the HyperparameterOptimizationOptions name-value pair.

Note

'OptimizeHyperparameters' values override any values you set using other name-value pair arguments. For example, setting 'OptimizeHyperparameters' to 'auto' causes the 'auto' values to apply.

The eligible parameters for fitrensemble are:

  • Method — Eligible methods are 'Bag' or 'LSBoost'.

  • NumLearningCycles — fitrensemble searches among positive integers, by default log-scaled with range [10,500].

  • LearnRate — fitrensemble searches among positive reals, by default log-scaled with range [1e-3,1].

  • MinLeafSize — fitrensemble searches among integers log-scaled in the range [1,max(2,floor(NumObservations/2))].

  • MaxNumSplits — fitrensemble searches among integers log-scaled in the range [1,max(2,NumObservations-1)].

  • NumVariablesToSample — fitrensemble searches among integers in the range [1,max(2,NumPredictors)].

Set nondefault parameters by passing a vector of optimizableVariable objects that have nondefault values. For example,

load carsmall
params = hyperparameters('fitrensemble',[Horsepower,Weight],MPG,'Tree');
params(4).Range = [1,20];

Pass params as the value of OptimizeHyperparameters.
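
For example, a sketch completing the snippet above:

Mdl = fitrensemble([Horsepower,Weight],MPG,'OptimizeHyperparameters',params);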

By default, iterative display appears at the command line, and plots appear according to the number of hyperparameters in the optimization. For the optimization and plots, the objective function is log(1 + cross-validation loss) for regression and the misclassification rate for classification. To control the iterative display, set the Verbose field of the 'HyperparameterOptimizationOptions' name-value pair argument. To control the plots, set the ShowPlots field of the 'HyperparameterOptimizationOptions' name-value pair argument.

For an example, see Optimize Regression Ensemble.

Example: 'auto'

Options for optimization, specified as the comma-separated pair consisting of 'HyperparameterOptimizationOptions' and a structure. This argument modifies the effect of the OptimizeHyperparameters name-value pair argument. All fields in the structure are optional.

Field Name — Values — Default:

Optimizer

  • 'bayesopt' — Use Bayesian optimization. Internally, this setting calls bayesopt.

  • 'gridsearch' — Use grid search with NumGridDivisions values per dimension.

  • 'randomsearch' — Search at random among MaxObjectiveEvaluations points.

'gridsearch' searches in a random order, using uniform sampling without replacement from the grid. After optimization, you can get a table in grid order by using the command sortrows(Mdl.HyperparameterOptimizationResults).

Default: 'bayesopt'

AcquisitionFunctionName

  • 'expected-improvement-per-second-plus'

  • 'expected-improvement'

  • 'expected-improvement-plus'

  • 'expected-improvement-per-second'

  • 'lower-confidence-bound'

  • 'probability-of-improvement'

For details, see the bayesopt AcquisitionFunctionName name-value pair argument, or Acquisition Function Types.

Default: 'expected-improvement-per-second-plus'

MaxObjectiveEvaluations — Maximum number of objective function evaluations. Default: 30 for 'bayesopt' or 'randomsearch', and the entire grid for 'gridsearch'.

MaxTime — Time limit, specified as a positive real. The time limit is in seconds, as measured by tic and toc. Run time can exceed MaxTime because MaxTime does not interrupt function evaluations. Default: Inf

NumGridDivisions — For 'gridsearch', the number of values in each dimension. The value can be a vector of positive integers giving the number of values for each dimension, or a scalar that applies to all dimensions. This field is ignored for categorical variables. Default: 10

ShowPlots — Logical value indicating whether to show plots. If true, this field plots the best objective function value against the iteration number. If there are one or two optimization parameters, and if Optimizer is 'bayesopt', then ShowPlots also plots a model of the objective function against the parameters. Default: true

SaveIntermediateResults — Logical value indicating whether to save results when Optimizer is 'bayesopt'. If true, this field overwrites a workspace variable named 'BayesoptResults' at each iteration. The variable is a BayesianOptimization object. Default: false

Verbose — Display to the command line.

  • 0 — No iterative display

  • 1 — Iterative display

  • 2 — Iterative display with extra information

For details, see the bayesopt Verbose name-value pair argument. Default: 1

UseParallel — Logical value indicating whether to run Bayesian optimization in parallel, which requires Parallel Computing Toolbox™. For details, see Parallel Bayesian Optimization. Default: false

Repartition — Logical value indicating whether to repartition the cross-validation at every iteration. If false, the optimizer uses a single partition for the optimization. true usually gives the most robust results because this setting takes partitioning noise into account. However, for good results, true requires at least twice as many function evaluations. Default: false

Use no more than one of the following three field names. If you do not specify any cross-validation field, the default is 'Kfold',5.

CVPartition — A cvpartition object, as created by cvpartition.

Holdout — A scalar in the range (0,1) representing the holdout fraction.

Kfold — An integer greater than 1.

Example: 'HyperparameterOptimizationOptions',struct('MaxObjectiveEvaluations',60)

Data Types: struct
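
For example, a sketch that combines several of these fields, assuming a predictor matrix X and response vector Y:

opts = struct('MaxObjectiveEvaluations',60,'Kfold',10,'ShowPlots',false);
Mdl = fitrensemble(X,Y,'OptimizeHyperparameters','auto',...
    'HyperparameterOptimizationOptions',opts);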

Output Arguments

Trained ensemble model, returned as one of the model objects in this table.

Model Object                     Cross-Validation Options Specified?   Method Setting        Resample Setting
RegressionBaggedEnsemble         No                                    'Bag'                 'on'
RegressionEnsemble               No                                    'LSBoost'             'off'
RegressionPartitionedEnsemble    Yes                                   'LSBoost' or 'Bag'    'off' or 'on'

The name-value pair arguments that control cross-validation are CrossVal, Holdout, KFold, Leaveout, and CVPartition.

To reference properties of Mdl, use dot notation. For example, to access or display the cell vector of weak learner model objects for an ensemble that has not been cross-validated, enter Mdl.Trained at the command line.
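
For example, a sketch that retrieves and displays the first weak learner:

firstTree = Mdl.Trained{1};    % a CompactRegressionTree
view(firstTree,'Mode','graph') % display the tree structure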

Tips

  • NumLearningCycles can vary from a few dozen to a few thousand. Usually, an ensemble with good predictive power requires from a few hundred to a few thousand weak learners. However, you do not have to train an ensemble for that many cycles at once. You can start by growing a few dozen learners, inspect the ensemble performance, and then, if necessary, train more weak learners using resume (see the sketch after these tips).

  • Ensemble performance depends on the ensemble setting and the setting of the weak learners. That is, if you specify weak learners with default parameters, then the ensemble can perform poorly. Therefore, like ensemble settings, it is good practice to adjust the parameters of the weak learners using templates, and to choose values that minimize generalization error.

  • If you specify to resample using Resample, then it is good practice to resample the entire data set. That is, use the default setting of 1 for FResample.

  • After training a model, you can generate C/C++ code that predicts responses for new data. Generating C/C++ code requires MATLAB Coder™ . For details, see Introduction to Code Generation.
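
As a sketch of the first tip, assuming the carsmall table Tbl from the examples above:

Mdl = fitrensemble(Tbl,'MPG','NumLearningCycles',50); % start with a small ensemble
resubLoss(Mdl)                                        % inspect performance
Mdl = resume(Mdl,50);                                 % train 50 more weak learners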

Algorithms

  • For details of ensemble-aggregation algorithms, see Ensemble Algorithms.

  • If you specify 'Method','LSBoost', then the software grows shallow decision trees by default. You can adjust tree depth by specifying the MaxNumSplits, MinLeafSize, and MinParentSize name-value pair arguments using templateTree.

  • For dual-core systems and above, fitrensemble parallelizes training using Intel® Threading Building Blocks (TBB). For details on Intel TBB, see https://software.intel.com/en-us/intel-tbb.

References

[1] Breiman, L. “Bagging Predictors.” Machine Learning. Vol. 26, pp. 123–140, 1996.

[2] Breiman, L. “Random Forests.” Machine Learning. Vol. 45, pp. 5–32, 2001.

[3] Freund, Y. and R. E. Schapire. “A Decision-Theoretic Generalization of On-Line Learning and an Application to Boosting.” J. of Computer and System Sciences, Vol. 55, pp. 119–139, 1997.

[4] Friedman, J. “Greedy function approximation: A gradient boosting machine.” Annals of Statistics, Vol. 29, No. 5, pp. 1189–1232, 2001.

[5] Hastie, T., R. Tibshirani, and J. Friedman. The Elements of Statistical Learning, Second Edition. Springer, New York, 2008.

Extended Capabilities

Introduced in R2016b