How to set a regression tree's cross-validation error function to something other than MSE or MCR?

When training a regression tree, I want to use an error function other than MSE or MCR. The training code generated by the Regression Learner app is:
% Perform cross-validation
partitionedModel = crossval(trainedModel.RegressionTree, 'KFold', 5);
% Compute validation predictions
validationPredictions = kfoldPredict(partitionedModel);
% Compute validation RMSE
validationRMSE = sqrt(kfoldLoss(partitionedModel, 'LossFun', 'mse'));
I can change 'LossFun' in kfoldLoss, but that only affects validation of the trained model. How do I change the loss function during training?

Answers (1)

Kausthub on 12 September 2023
Hi Jie Li,
I understand that you would like to know how to change the ‘LossFun’ of a Regression Tree and use another error function instead of Mean Squared Error (MSE) or Misclassification Rate (MCR) when using cross-validation during the training phase.
Regression Trees do not have weights to train, nor it has the back propagation step to utilize the error and learn or update the weights. Regression Trees utilize ‘Node Splitting Rules’ to determine how to split the nodes instead of a ‘LossFun. MATLAB provides three ‘Node Splitting Rules which are, Standard CART, Curvature Test and Interaction Test whose details are provided here (https://www.mathworks.com/help/stats/fitrtree.html#butl1ll_head).
You can change the node splitting rule using the 'PredictorSelection' name-value parameter. The default rule is standard CART, which uses the MSE reduction of each candidate split, whereas the curvature and interaction tests use the p-values of chi-square tests. You may refer to the following documentation to learn more about the 'PredictorSelection' parameter: https://www.mathworks.com/help/stats/fitrtree.html#bt6cr84-PredictorSelection
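For reference, here is a minimal sketch of selecting each rule through 'PredictorSelection' (the three documented values are 'allsplits', 'curvature', and 'interaction-curvature'; the carsmall data set is used purely for illustration):
load carsmall
X = [Weight, Cylinders];
% Standard CART (the default): chooses splits that maximize the MSE reduction
treeCART = fitrtree(X, MPG, 'PredictorSelection', 'allsplits');
% Curvature test: picks the split predictor via chi-square curvature tests
treeCurv = fitrtree(X, MPG, 'PredictorSelection', 'curvature');
% Interaction test: also accounts for pairwise interactions between predictors
treeInt  = fitrtree(X, MPG, 'PredictorSelection', 'interaction-curvature');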
The crossval function returns a partitioned model whose folds are trained according to the chosen node splitting rule; it does not compute an error value itself, so it does not take a 'LossFun' parameter. The loss function is only applied afterwards, by kfoldLoss. You may refer to the following documentation to learn more about the crossval function: https://www.mathworks.com/help/stats/regressiontree.crossval.html
An example of creating a cross-validated regression tree with 'PredictorSelection' set to 'curvature' is:
load carsmall
% Train a regression tree on weight and cylinder count, selecting split
% predictors with the curvature test
tree = fitrtree([Weight, Cylinders], MPG, ...
    'CategoricalPredictors', 2, 'MinParentSize', 20, ...
    'PredictorNames', {'W','C'}, 'PredictorSelection', 'curvature');
% Cross-validate the tree and compute the validation MSE
partitionedModel = crossval(tree, 'KFold', 5);
kfoldLoss(partitionedModel, 'LossFun', 'mse')
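If the goal is simply to evaluate the cross-validated folds with a metric other than MSE, note that 'LossFun' in kfoldLoss also accepts a custom function handle (per the kfoldLoss documentation for regression models, with the signature lossvalue = lossfun(Y, Yfit, W)). A minimal sketch using mean absolute error, reusing partitionedModel from the example above:
% Custom loss: weighted mean absolute error (MAE) instead of MSE.
% Y is the observed response, Yfit the fold predictions, W the observation weights.
maeLoss = @(Y, Yfit, W) sum(W .* abs(Y - Yfit)) / sum(W);
validationMAE = kfoldLoss(partitionedModel, 'LossFun', maeLoss)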
You may refer to the documentation links above for more information.
Hope this helps and clarifies your query regarding the steps to set a regression tree’s ‘crossval’ error function!
