
kfoldfun

Cross-validate function for classification

    Description

    vals = kfoldfun(CVMdl,fun) cross-validates the function fun by applying fun to the data stored in the cross-validated model CVMdl. You must pass fun as a function handle.
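    The following minimal sketch shows the calling convention. The anonymous function accepts all seven inputs that kfoldfun passes, but uses only Ytest. It returns the number of validation-fold observations, so vals has one value per fold. The name foldSize is illustrative. Each observation falls in exactly one validation fold, so the entries of vals should sum to the 150 observations in the data set.

```matlab
% Minimal kfoldfun call: report the size of each validation fold.
load fisheriris                      % meas (150x4), species (150x1 cellstr)
Mdl = fitctree(meas,species);        % classification tree
rng(1);                              % reproducible fold assignment
CVMdl = crossval(Mdl);               % default 10-fold cross-validation

% kfoldfun calls fun with seven inputs, so declare all of them,
% even the ones the function does not use.
foldSize = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) numel(Ytest);

vals = kfoldfun(CVMdl,foldSize)      % 10-by-1, one row per fold
```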


    Examples


    Train a classification tree, and then cross-validate it by using a custom k-fold loss function.

    Load Fisher’s iris data set.

    load fisheriris

    Train a classification tree.

    Mdl = fitctree(meas,species);

    Mdl is a ClassificationTree model.

    Cross-validate Mdl using the default 10-fold cross-validation. Compute the classification error (proportion of misclassified observations) for the validation-fold observations.

    rng(1); % For reproducibility
    CVMdl = crossval(Mdl);
    L = kfoldLoss(CVMdl,'LossFun','classiferror')
    L = 0.0467
    
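    kfoldLoss returns a single pooled value. The following sketch uses kfoldfun to get the classification error for each fold instead. The names foldErr and perFold are illustrative, and the function ignores the observation weights, which are uniform for this data. The 150 observations split into 10 equal folds of 15, so the mean of the per-fold errors should match the kfoldLoss value up to rounding.

```matlab
load fisheriris
Mdl = fitctree(meas,species);
rng(1); % For reproducibility
CVMdl = crossval(Mdl);

% Fraction of misclassified validation observations in one fold
% (observation weights are ignored).
foldErr = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    mean(~strcmp(predict(CMP,Xtest),Ytest));

perFold = kfoldfun(CVMdl,foldErr);   % 10-by-1 vector of fold errors
pooled = kfoldLoss(CVMdl,'LossFun','classiferror');
[mean(perFold) pooled]               % agree when the folds have equal size
```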

    Examine the result when misclassifying a flower as versicolor costs 10 and any other misclassification costs 1. The custom function noversicolor, shown at the end of this example, applies these costs to the validation-fold observations of each fold.

    Compute the mean misclassification cost over the 10 folds by using noversicolor.

    mean(kfoldfun(CVMdl,@noversicolor))
    ans = 0.2267
    

    This code creates the function noversicolor.

    function averageCost = noversicolor(CMP,~,~,~,Xtest,Ytest,~)
    % noversicolor Example custom cross-validation function
    %    Attributes a cost of 10 for misclassifying an iris as versicolor, and
    %    a cost of 1 for any other misclassification. This example function
    %    requires the fisheriris data set. Observation weights are ignored.
    Ypredict = predict(CMP,Xtest);
    misclassified = not(strcmp(Ypredict,Ytest)); % Wrong predictions
    classifiedAsVersicolor = strcmp(Ypredict,'versicolor'); % Predicted as versicolor
    cost = sum(misclassified) + ...
        9*sum(misclassified & classifiedAsVersicolor); % 1 per error, 10 per wrong versicolor
    averageCost = cost/numel(Ytest); % Average cost per validation observation
    end
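    noversicolor hard-codes its costs. The following sketch generalizes it to an arbitrary K-by-K misclassification cost matrix C. It assumes that the rows of C are true classes and the columns are predicted classes, both in the order of classNames. It maps labels to indices with categorical and looks up each observation's cost. The names classNames, C, toIdx, and costFun are illustrative, and weights are ignored as in noversicolor. The C below reproduces the noversicolor costs, so the per-fold values should agree with noversicolor.

```matlab
load fisheriris
Mdl = fitctree(meas,species);
rng(1); % For reproducibility
CVMdl = crossval(Mdl);

classNames = {'setosa','versicolor','virginica'};
% C(i,j) is the cost of predicting class j when the true class is i.
% Wrongly predicting versicolor costs 10; any other error costs 1.
C = [0 10 1;
     1  0 1;
     1 10 0];

% Convert a cellstr of labels into indices into classNames (and C).
toIdx = @(y) double(categorical(y,classNames));

% Average cost over one validation fold (weights are ignored).
costFun = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    mean(C(sub2ind(size(C),toIdx(Ytest),toIdx(predict(CMP,Xtest)))));

vals = kfoldfun(CVMdl,costFun);      % 10-by-1 per-fold average cost
mean(vals)
```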

    Input Arguments


    CVMdl — Cross-validated model

    Cross-validated model, specified as a ClassificationPartitionedModel object, ClassificationPartitionedEnsemble object, or ClassificationPartitionedGAM object.

    fun — Cross-validated function

    Cross-validated function, specified as a function handle. fun has the syntax:

    testvals = fun(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest)

    • CMP is a compact model stored in one element of the CVMdl.Trained property.

    • Xtrain is the training matrix of predictor values.

    • Ytrain is the training array of response values.

    • Wtrain is the vector of training observation weights.

    • Xtest and Ytest are the test data, with associated weights Wtest.

    • The returned value testvals must have the same size across all folds.

    Data Types: function_handle
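    To make the roles of the seven inputs concrete, the following sketch writes out a loop that should give the same result as kfoldfun for this model. It assumes that the compact model CVMdl.Trained{k} was trained on training(CVMdl.Partition,k) and validated on test(CVMdl.Partition,k). This is an illustration, not the toolbox implementation. The names numCorrect and manual are illustrative, and the loop passes unit weights, which numCorrect ignores.

```matlab
load fisheriris
Mdl = fitctree(meas,species);
rng(1); % For reproducibility
CVMdl = crossval(Mdl);

% Number of correctly classified validation observations in one fold.
numCorrect = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    sum(strcmp(predict(CMP,Xtest),Ytest));

vals = kfoldfun(CVMdl,numCorrect);

% Hand-written equivalent: apply the function to each fold's compact
% model with that fold's training and validation split.
c = CVMdl.Partition;                 % cvpartition object
manual = zeros(CVMdl.KFold,1);
for k = 1:CVMdl.KFold
    tr = training(c,k);              % logical index of training rows
    te = test(c,k);                  % logical index of validation rows
    manual(k) = numCorrect(CVMdl.Trained{k}, ...
        meas(tr,:),species(tr),ones(nnz(tr),1), ...
        meas(te,:),species(te),ones(nnz(te),1));
end
isequal(vals,manual)
```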

    Output Arguments


    vals — Cross-validation results

    Cross-validation results, returned as a numeric matrix. vals contains the testvals outputs from all folds, concatenated vertically. For example, if testvals from every fold is a numeric vector of length N, kfoldfun returns a KFold-by-N numeric matrix with one row per fold.

    Data Types: double
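    The following sketch illustrates the vertical concatenation. It assumes that fun returns a 1-by-3 row vector of per-class counts for the validation fold, so vals should be a 10-by-3 matrix with one row per fold. The name classCounts is illustrative. Each observation falls in exactly one validation fold, so each column should sum to the 50 observations of its class.

```matlab
load fisheriris
Mdl = fitctree(meas,species);
rng(1); % For reproducibility
CVMdl = crossval(Mdl);

% Per-class counts in one validation fold, returned as a 1-by-3 row.
classCounts = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    [sum(strcmp(Ytest,'setosa')), ...
     sum(strcmp(Ytest,'versicolor')), ...
     sum(strcmp(Ytest,'virginica'))];

vals = kfoldfun(CVMdl,classCounts);  % 10-by-3: one row per fold
size(vals)
sum(vals,1)                          % each observation is counted once
```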


    Version History

    Introduced in R2011a
