loss
Description
L = loss(thresholder,Tbl) computes the classification loss (specified by thresholder.LossFun) by using the fairnessThresholder object thresholder and the table data Tbl.
L = loss(thresholder,X,attribute,Y) computes the classification loss (specified by thresholder.LossFun) by using the fairnessThresholder object thresholder, the matrix data X, the sensitive attribute specified by attribute, and the true class labels Y.
Examples
Adjust Score Threshold for Fairness
Train a tree ensemble for binary classification, and compute the disparate impact for each group in the sensitive attribute. To reduce the disparate impact value of the nonreference group, adjust the score threshold for classifying observations.
Load the data census1994, which contains the data set adultdata and the test data set adulttest. The data sets consist of demographic information from the US Census Bureau that can be used to predict whether an individual makes over $50,000 per year. Preview the first few rows of adultdata.
load census1994
head(adultdata)
age    workClass           fnlwgt        education    education_num    marital_status           occupation           relationship     race     sex       capital_gain    capital_loss    hours_per_week    native_country    salary
___    ________________    __________    _________    _____________    _____________________    _________________    _____________    _____    ______    ____________    ____________    ______________    ______________    ______
39     State-gov           77516         Bachelors    13               Never-married            Adm-clerical         Not-in-family    White    Male      2174            0               40                United-States     <=50K
50     Self-emp-not-inc    83311         Bachelors    13               Married-civ-spouse       Exec-managerial      Husband          White    Male      0               0               13                United-States     <=50K
38     Private             2.1565e+05    HS-grad      9                Divorced                 Handlers-cleaners    Not-in-family    White    Male      0               0               40                United-States     <=50K
53     Private             2.3472e+05    11th         7                Married-civ-spouse       Handlers-cleaners    Husband          Black    Male      0               0               40                United-States     <=50K
28     Private             3.3841e+05    Bachelors    13               Married-civ-spouse       Prof-specialty       Wife             Black    Female    0               0               40                Cuba              <=50K
37     Private             2.8458e+05    Masters      14               Married-civ-spouse       Exec-managerial      Wife             White    Female    0               0               40                United-States     <=50K
49     Private             1.6019e+05    9th          5                Married-spouse-absent    Other-service        Not-in-family    Black    Female    0               0               16                Jamaica           <=50K
52     Self-emp-not-inc    2.0964e+05    HS-grad      9                Married-civ-spouse       Exec-managerial      Husband          White    Male      0               0               45                United-States     >50K
Each row contains the demographic information for one adult. The information includes sensitive attributes, such as age, marital_status, relationship, race, and sex. The third column fnlwgt contains observation weights, and the last column salary shows whether a person has a salary less than or equal to $50,000 per year (<=50K) or greater than $50,000 per year (>50K).
Remove observations with missing values.
adultdata = rmmissing(adultdata);
adulttest = rmmissing(adulttest);
Partition adultdata into training and validation sets. Use 60% of the observations for the training set trainingData and 40% of the observations for the validation set validationData.
rng("default") % For reproducibility
c = cvpartition(adultdata.salary,"Holdout",0.4);
trainingIdx = training(c);
validationIdx = test(c);
trainingData = adultdata(trainingIdx,:);
validationData = adultdata(validationIdx,:);
Train a boosted ensemble of trees using the training data set trainingData. Specify the response variable, predictor variables, and observation weights by using the variable names in the adultdata table. Use random undersampling boosting as the ensemble aggregation method.
predictors = ["capital_gain","capital_loss","education", ...
    "education_num","hours_per_week","occupation","workClass"];
Mdl = fitcensemble(trainingData,"salary", ...
    PredictorNames=predictors, ...
    Weights="fnlwgt",Method="RUSBoost");
Predict salary values for the observations in the test data set adulttest, and calculate the classification error.
labels = predict(Mdl,adulttest);
L = loss(Mdl,adulttest)
L = 0.2080
The model accurately predicts the salary categorization for approximately 80% of the test set observations.
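As a check on the arithmetic, an error of 0.2080 corresponds to an accuracy of 1 - 0.2080 = 0.7920. The following minimal sketch shows the relationship on made-up data. The variables yTrue and yPred are hypothetical and are not part of this example, and the calculation ignores observation weights, which the loss function can take into account.

```matlab
% Hypothetical true and predicted labels (not taken from census1994)
yTrue = ["<=50K"; ">50K"; "<=50K"; "<=50K"; ">50K"];
yPred = ["<=50K"; "<=50K"; "<=50K"; "<=50K"; ">50K"];

% Unweighted classification error: fraction of mismatched labels
classifError = mean(yPred ~= yTrue);

% Accuracy is the complement of the classification error
accuracy = 1 - classifError;
```

Here one of the five predictions is wrong, so the error is 0.2 and the accuracy is 0.8.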
Compute fairness metrics with respect to the sensitive attribute sex by using the test set model predictions. In particular, find the disparate impact for each group in sex. Use the report and plot object functions of fairnessMetrics to display the results.
evaluator = fairnessMetrics(adulttest,"salary", ...
    SensitiveAttributeNames="sex",Predictions=labels, ...
    ModelNames="Ensemble",Weights="fnlwgt");
evaluator.PositiveClass
ans = categorical
>50K
evaluator.ReferenceGroup
ans = 'Male'
report(evaluator,BiasMetrics="DisparateImpact")
ans=2×4 table
ModelNames SensitiveAttributeNames Groups DisparateImpact
__________ _______________________ ______ _______________
Ensemble sex Female 0.73792
Ensemble sex Male 1
plot(evaluator,"DisparateImpact")
For the nonreference group (Female), the disparate impact value is the proportion of predictions in the group with a positive class value (>50K) divided by the proportion of predictions in the reference group (Male) with a positive class value. Ideally, disparate impact values are close to 1.
To try to improve the nonreference group disparate impact value, you can adjust model predictions by using the fairnessThresholder function. The function uses validation data to search for an optimal score threshold that maximizes accuracy while satisfying fairness bounds. For observations in the critical region below the optimal threshold, the function changes the labels so that the fairness constraints hold for the reference and nonreference groups. By default, the function tries to find a score threshold so that the disparate impact value for the nonreference group is in the range [0.8,1.25].
fairnessMdl = fairnessThresholder(Mdl,validationData,"sex","salary")
fairnessMdl =
  fairnessThresholder with properties:
               Learner: [1x1 classreg.learning.classif.CompactClassificationEnsemble]
    SensitiveAttribute: 'sex'
       ReferenceGroups: Male
          ResponseName: 'salary'
         PositiveClass: >50K
        ScoreThreshold: 1.6749
            BiasMetric: 'DisparateImpact'
       BiasMetricValue: 0.9702
       BiasMetricRange: [0.8000 1.2500]
        ValidationLoss: 0.2017
fairnessMdl is a fairnessThresholder model object. Note that the predict function of the ensemble model Mdl returns scores that are not posterior probabilities. Instead, scores are in the range (-∞,∞), and the maximum score for each observation is greater than 0. For observations whose maximum scores are less than the new score threshold (fairnessMdl.ScoreThreshold), the predict function of the fairnessMdl object adjusts the prediction. If the observation is in the nonreference group, the function assigns the observation to the positive class. If the observation is in the reference group, the function assigns the observation to the negative class. These adjustments do not always result in a change in the predicted label.
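The adjustment rule described above can be sketched as follows. This is a simplified illustration on made-up data, not the internal implementation of predict. The variables maxScore, labels, and group are hypothetical, and the threshold is the ScoreThreshold value displayed for fairnessMdl.

```matlab
% Hypothetical inputs: maximum score per observation, initial labels, groups
maxScore = [2.5; 0.4; 1.0; 3.1; 0.9];
labels   = [">50K"; "<=50K"; ">50K"; "<=50K"; "<=50K"];
group    = ["Male"; "Female"; "Male"; "Female"; "Male"];

scoreThreshold = 1.6749;   % value displayed in fairnessMdl.ScoreThreshold
positiveClass  = ">50K";
negativeClass  = "<=50K";
referenceGroup = "Male";

% Observations whose maximum score falls below the threshold
inCriticalRegion = maxScore < scoreThreshold;
isReference      = group == referenceGroup;

% Nonreference group -> positive class, reference group -> negative class
adjusted = labels;
adjusted(inCriticalRegion & ~isReference) = positiveClass;
adjusted(inCriticalRegion & isReference)  = negativeClass;

% Number of labels that actually changed
numChanged = nnz(adjusted ~= labels);
```

Three observations fall below the threshold, but only two labels change: the last observation is in the reference group and already has the negative class label.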
Adjust the test set predictions by using the new score threshold, and calculate the classification error.
fairnessLabels = predict(fairnessMdl,adulttest);
fairnessLoss = loss(fairnessMdl,adulttest)
fairnessLoss = 0.2064
The new classification error (0.2064) is similar to the original classification error (0.2080).
Compare the disparate impact values across the two sets of test predictions: the original predictions computed using Mdl and the adjusted predictions computed using fairnessMdl.
newEvaluator = fairnessMetrics(adulttest,"salary", ...
    SensitiveAttributeNames="sex",Predictions=[labels,fairnessLabels], ...
    ModelNames=["Original","Adjusted"],Weights="fnlwgt");
newEvaluator.PositiveClass
ans = categorical
>50K
newEvaluator.ReferenceGroup
ans = 'Male'
report(newEvaluator,BiasMetrics="DisparateImpact")
ans=2×5 table
Metrics SensitiveAttributeNames Groups Original Adjusted
_______________ _______________________ ______ ________ ________
DisparateImpact sex Female 0.73792 1.0048
DisparateImpact sex Male 1 1
plot(newEvaluator,"di")
The disparate impact value for the nonreference group (Female) is closer to 1 when you use the adjusted predictions.
Input Arguments
thresholder — Fairness classification model
fairnessThresholder object
Fairness classification model, specified as a fairnessThresholder object. The ScoreThreshold property of the object must be nonempty.
Tbl — Data set
table
Data set, specified as a table. Each row of Tbl corresponds to one observation, and each column corresponds to one variable. If you use a table when creating the fairnessThresholder object, then you must use a table when using the loss function. The table must include all required predictor variables, the sensitive attribute, and the response variable. The table can include additional variables. Multicolumn variables and cell arrays other than cell arrays of character vectors are not allowed.
Data Types: table
X — Predictor data
numeric matrix
Predictor data, specified as a numeric matrix. Each row of X corresponds to one observation, and each column corresponds to one predictor variable. If you use a matrix when creating the fairnessThresholder object, then you must use a matrix when using the loss function.
X, attribute, and Y must have the same number of rows.
Data Types: single | double
attribute — Sensitive attribute
numeric column vector | logical column vector | character array | string array | cell array of character vectors | categorical column vector
Sensitive attribute, specified as a numeric column vector, logical column vector, character array, string array, cell array of character vectors, or categorical column vector.
Data Types: single | double | logical | char | string | cell | categorical
Y — Class labels
numeric column vector | logical column vector | character array | string array | cell array of character vectors | categorical column vector
Class labels, specified as a numeric column vector, logical column vector, character array, string array, cell array of character vectors, or categorical column vector.
If Y is a character array, then each row of the array must correspond to a class label. The data type of Y must be the same as the data type of the response variable used to create thresholder. If thresholder.Learner is a classification model object, then the distinct classes in Y must be a subset of the classes in thresholder.Learner.ClassNames.
Data Types: single | double | logical | char | string | cell | categorical
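The requirements on Y can be checked before calling loss. The following is a minimal sketch that uses hypothetical stand-in variables: classNames plays the role of thresholder.Learner.ClassNames, and Y is a made-up label vector. It assumes that the class names have the same data type as the response variable.

```matlab
% Hypothetical stand-ins for thresholder.Learner.ClassNames and new labels
classNames = ["<=50K"; ">50K"];
Y = [">50K"; "<=50K"; ">50K"];

% Y must have the same data type as the original response variable
sameType = strcmp(class(Y), class(classNames));

% Every distinct class in Y must appear in classNames
isSubset = all(ismember(unique(Y), classNames));
```

If either check is false, the labels need to be converted or filtered before they are passed to the loss function.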
Output Arguments
L — Classification loss
numeric scalar
Classification loss, returned as a numeric scalar. The loss function computes the classification loss specified by the LossFun value used when creating thresholder. The function uses the data set predictions, adjusted using the thresholder.ScoreThreshold value. For more information, see Reject Option-Based Classification.
Version History
Introduced in R2023a