---
title: "Assessing and Addressing Fairness in a Predictive Model of Medical Expenditure"
date: "July 18 2023"
output:
html_document:
df_print: paged
theme: journal
editor_options:
markdown:
wrap: 72
---
This notebook is based on the Python tutorial from IBM's
[AIF360](https://nbviewer.org/github/IBM/AIF360/blob/master/examples/tutorial_medical_expenditure.ipynb).
# Dataset
The dataset used in this tutorial was adapted from the ['Medical
Expenditure Panel Survey'](https://meps.ahrq.gov/mepsweb/).
- The specific data used is the 2015 Full Year Consolidated Data File
(h181.csv) as well as the 2016 Full Year Consolidated Data File
(h192.csv).
- The 2015 file contains data from rounds 3,4,5 of panel 19 (2014) and
rounds 1,2,3 of panel 20 (2015).
- The 2016 file contains data from rounds 3,4,5 of panel 20 (2015) and
rounds 1,2,3 of panel 21 (2016).
- The panel 19 dataset is split into 3 parts: a train, a validation,
and a test/deployment part.
- The panel 21 data is used for testing data/concept drift after model
training.
- For each dataset, the sensitive attribute is 'RACE', consisting of
  'Whites', defined by the features RACEV2X = 1 (White) and HISPANX = 2
  (non-Hispanic), and 'Non-Whites', which includes everyone else.
- Our modeling goal is to predict 'high' utilization. To measure
  utilization, the total number of trips requiring some sort of
  medical care was recorded. Respondents were considered to have high
  utilization if they had 10 or more visits (\~17% of respondents); a
  sketch of how these variables can be derived follows this list.
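The pre-built helper `process_meps_csv()` in `helpers.R` performs these derivations. As a minimal, illustrative sketch (assuming the raw MEPS columns `RACEV2X` and `HISPANX`, plus a hypothetical per-person visit count `utilization`), the two derived variables might be constructed as:

```{r, eval=FALSE}
# Illustrative only -- the tutorial relies on process_meps_csv() from helpers.R.
# `meps_raw` and `utilization` are hypothetical stand-ins for the raw h181.csv data.
meps_raw %>%
  mutate(
    RACE = ifelse(RACEV2X == 1 & HISPANX == 2, "White", "Non-White"),
    high_utilization = factor(ifelse(utilization >= 10, "high", "low"))
  )
```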
# Tidymodels
The [`tidymodels`](https://www.tidymodels.org/packages/) collection of R
packages simplifies and standardizes the process of modeling and machine
learning in R by providing a consistent framework that integrates with
the tidyverse. It emphasizes tidy data principles and offers a wide
range of modeling techniques, tuning, and evaluation functionalities.
The core idea is to specify `recipes` that encapsulate the pre- and post-processing steps
typically used in predictive modeling (or model fitting in general).
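As a small, generic illustration of the idea (not the recipe used later in this tutorial, which comes from `helpers.R`), a recipe declares the outcome/predictor roles and processing steps up front, and is only applied when the workflow is fit:

```{r, eval=FALSE}
# Generic sketch of a recipe; `outcome` and `training_data` are placeholders.
rec <- recipe(outcome ~ ., data = training_data) %>%
  step_dummy(all_nominal_predictors()) %>%     # dummy-encode categorical predictors
  step_normalize(all_numeric_predictors())     # center and scale numeric predictors
```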
# Downloading and Preprocessing data
**Documentation:**
<https://github.com/Trusted-AI/AIF360/blob/master/aif360/data/raw/meps/README.md>
Make sure to read and abide by the data use agreements in the above
link.
Download and run the data preparation script
```{bash, echo = TRUE, eval = FALSE}
curl https://raw.githubusercontent.com/Trusted-AI/AIF360/master/aif360/data/raw/meps/generate_data.R --output generate_data.R
Rscript generate_data.R
```
Additional setup for AIF360 (optional; used via `reticulate` for the Python comparison below)
```{bash, echo = TRUE, eval = FALSE}
conda create -n aif360 python=3 pip
conda activate aif360
pip install aif360
```
## 1. Load required libraries and Panel 19 data
```{r message=FALSE}
# Core modeling and data-manipulation stacks used throughout
# (these may also be attached by helpers.R)
library(tidymodels)
library(tidyverse)
library(ggplot2)
library(patchwork)
library(skimr)
library(knitr)
library(rmarkdown)
library(pander)
# optional: only needed for the Python/AIF360 comparison below
library(reticulate)
use_condaenv("aif360")
source("helpers.R")
meps_19 <- process_meps_csv("h181.csv", 19)
# optional: keep a processed copy on disk
readr::write_csv(meps_19, file = "h181-processed.csv")
```
## 2. Create training/validation/testing model splits
As in the AIF360 tutorial, we will split MEPS 19 into train (50%),
validation (30%), and test (20%) sets.
However, because `high_utilization` is unbalanced, we will use it as a
stratification factor to ensure comparable class proportions across the splits.
```{r}
meps_19 %>%
count(high_utilization) %>%
mutate(prop = n/sum(n)) %>%
kable()
```
Additionally, since we want direct access to the training dataset for
reweighting, we will split 50/50 and then further split the held-out half
60/40 into validation and test sets, yielding the 50/30/20 proportions above.
*NOTE:* This is not the **typical** workflow due to the apparent need to
define the case weights outside of the pre-processing recipe:
<https://www.tidyverse.org/blog/2022/05/case-weights/>
```{r}
set.seed(123)
meps_19_split <- initial_split(meps_19, strata = high_utilization, prop=.5)
```
We can now carry out fairness reweighting, applied to the training set only.
```{r}
meps_19_train <- training(meps_19_split)
# R version of reweighting, from helpers.R
meps_19_train <- reweight(
  meps_19_train,
  "high_utilization",
  "RACE")
# What do these weights look like?
meps_19_train %>%
  group_by(
    high_utilization,
    RACE
  ) %>%
  summarize(
    n = n(),
    # only the first value is needed, as the weights are defined per combination of outcome/feature
    instance_weight = instance_weights[1]
  ) %>%
  kable()
```
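The `reweight()` function comes from `helpers.R` and is not reproduced here. Conceptually, standard reweighing assigns each observation in the cell (outcome = y, group = a) the weight `P(Y = y) * P(A = a) / P(Y = y, A = a)`, so that the outcome and the protected attribute become independent in the weighted training data: combinations that are rarer than independence would predict are up-weighted, and the rest are down-weighted. A minimal sketch of that computation (purely illustrative, not the actual helper) could be:

```{r, eval=FALSE}
# Illustrative sketch of standard reweighing; the tutorial uses reweight() from helpers.R.
# `df` is a hypothetical data frame containing the outcome and protected-attribute columns.
reweight_sketch <- function(df, outcome, protected) {
  n_total <- nrow(df)
  df %>%
    add_count(.data[[outcome]], name = "n_y") %>%                       # marginal outcome counts
    add_count(.data[[protected]], name = "n_a") %>%                     # marginal group counts
    add_count(.data[[outcome]], .data[[protected]], name = "n_ya") %>%  # joint counts
    mutate(instance_weights = (n_y * n_a) / (n_total * n_ya)) %>%
    select(-n_y, -n_a, -n_ya)
}
```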
Alternatively, we can use the Python version via `reticulate`:
```{python}
from aif360.sklearn.preprocessing import Reweighing
import pandas as pd

# AIF360's Reweighing transformer, using RACE as the protected attribute
reweighing_obj = Reweighing(prot_attr='RACE')
# the aif360.sklearn API looks up the protected attribute in the DataFrame index
inp_df = r.meps_19_train.set_index('RACE', drop=False)
python_weights = reweighing_obj.fit_transform(X=inp_df, y=r.meps_19_train.high_utilization)[0]
python_weights.reset_index(drop=True, inplace=True)
# summarize one weight per (outcome, protected attribute) combination
python_weights.groupby(['high_utilization', 'RACE']).agg({'instance_weights':'first'})
```
With the training dataset in hand, we will now form testing/validation
splits.
```{r}
meps_19_initial_test <- testing(meps_19_split)
# As we only need the weights for the training data, set them to NA for clarity
meps_19_initial_test$instance_weights <- NA_real_
set.seed(4353)
meps_19_tv_split <- initial_split(meps_19_initial_test, strata = high_utilization, prop = .6)
# Keep the three splits together in a named list
meps_19_splits <- list(
  training = meps_19_train,
  validation = training(meps_19_tv_split),
  testing = testing(meps_19_tv_split)
)
rm(meps_19_train, meps_19_tv_split)
# Summarize the data splits
bind_rows(lapply(meps_19_splits, function(x){
  x %>%
    count(high_utilization) %>%
    mutate(
      prop = n/sum(n),
      total = sum(n)) %>%
    pivot_wider(id_cols = total, names_from = high_utilization, values_from = prop)
}), .id = "split") %>%
  kable()
```
## 3. Learning a Logistic Regression (elastic net) classifier
We will fit two versions of the model: one without and one with the
fairness weights.
*NOTE:* We are using the same variables as the AIF360 Python tutorial,
but some additional feature selection and engineering would probably be
useful.
For instance:
1.  How best to deal with feature encoding, including the presence of
    -1, which indicates 'Inapplicable' (a possible recoding sketch
    follows this list).
2.  Over-specified categories, such as the presence of both SEX=1 and
    SEX=2, either of which can be derived from the other.
3.  In addition to the terms of use, the variables are defined
    [here](https://meps.ahrq.gov/data_stats/download_data/pufs/h181/h181doc.shtml#Data).
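As a purely illustrative example of point 1 (not part of `meps_base_recipe()` in `helpers.R`), the 'Inapplicable' code could be recoded to `NA` and imputed inside a recipe:

```{r, eval=FALSE}
# Hypothetical feature-engineering sketch; the tutorial's actual pre-processing
# comes from meps_base_recipe() in helpers.R.
recipe(high_utilization ~ ., data = meps_19_splits$training) %>%
  # treat the MEPS 'Inapplicable' code (-1) as missing
  step_mutate_at(all_numeric_predictors(),
                 fn = function(x) ifelse(x == -1, NA_real_, x)) %>%
  # then impute before modeling
  step_impute_median(all_numeric_predictors())
```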
With that in mind, we will train an initial unweighted version.
### 3.1 Original (unweighted) classifier
```{r}
# specify the model and fitting backend
lr_mod <-
  logistic_reg(penalty = tune(), mixture = tune()) %>%
  set_engine("glmnet")
# get the pre-specified recipe from helpers.R
lr_recipe <- meps_base_recipe(meps_19_splits$training)
# add the model and recipe to a workflow
lr_workflow <-
  workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(lr_recipe)
# To use tune_grid we need a training/validation resample object. Note this is not the
# standard workflow, which would be `validation_split()` of the training dataset;
# here we form the sets manually.
train_val_sets <-
  make_splits(meps_19_splits$training, assessment = meps_19_splits$validation) %>%
  lst %>%
  manual_rset(ids = "Validation")
set.seed(9534)
# Perform the actual tuning with respect to the validation set. We measure balanced
# accuracy for comparison with the AIF360 tutorial. control_grid(save_pred = TRUE) is
# needed to save probabilities for downstream assessment; collecting area under the
# ROC curve also forces `tune_grid` to save probabilities.
lr_res <-
  lr_workflow %>%
  tune_grid(train_val_sets,
            grid = 10,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(bal_accuracy, roc_auc))
# Simply choose the parameters that give the best balanced accuracy
lr_best <-
  lr_res %>%
  select_best(metric = "bal_accuracy")
lr_best %>% kable()
```
#### 3.1.1 Assess alternate probability cutoffs
By default, the probability cutoff is 0.5. Especially as our data is
unbalanced, we can also examine other possible cutoffs with respect to
balanced accuracy and our measure of fairness, **disparate impact** (the
ratio of the rate of predicted high utilization in the protected group to
that in the non-protected group; values near 1.0 indicate parity).
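The cutoff-sweep helpers `alternate_cutoff_stats()` and `cutoff_stats()` come from `helpers.R` and are not reproduced here. As a rough, illustrative sketch of the quantity involved, disparate impact at a single cutoff could be computed roughly as below, assuming the cutoff is applied to the predicted probability of high utilization and that disparate impact is the Non-White rate of predicted high utilization divided by the White rate:

```{r, eval=FALSE}
# Illustrative sketch only; the tutorial's actual cutoff sweep comes from
# alternate_cutoff_stats() in helpers.R. `preds` is a hypothetical data frame
# with columns .pred_high (probability of 'high'), high_utilization, and RACE.
di_at_cutoff <- function(preds, cutoff) {
  preds %>%
    mutate(pred_high = .pred_high >= cutoff) %>%
    group_by(RACE) %>%
    summarize(rate_high = mean(pred_high), .groups = "drop") %>%
    summarize(disparate_impact =
                rate_high[RACE == "Non-White"] / rate_high[RACE == "White"])
}
```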
```{r}
lr_stats <-
lr_res %>%
collect_predictions(parameters = lr_best)
# retrieve the validation set annotated with the predicted probabilities;
# also ensure the protected variable (RACE) is encoded as a factor with
# levels ordered as (non-protected, protected)
meps_validation_results <-
bind_cols(
assessment(train_val_sets$splits[[1]]),
select(lr_stats, .pred_high, .pred_low, .pred_class)
) %>%
mutate(
RACE = factor(RACE, levels=c("White", "Non-White"))
)
validation_probs_thresh <-
alternate_cutoff_stats(
meps_validation_results,
outcome="high_utilization",
protected_feature="RACE"
)
#Choose the best one
best_acc = validation_probs_thresh %>%
arrange(-balanced_accuracy) %>%
head(n=1)
```
Examine the confusion matrices, by race, for the default prediction cutoff:
```{r}
validation_preds_by_race <- group_split(meps_validation_results, RACE)
# `group_split` does not supply list names, so add them
names(validation_preds_by_race) <- sapply(validation_preds_by_race, function(x) x$RACE[1])
lapply(validation_preds_by_race, function(x){
  conf_mat(x, truth = high_utilization, estimate = .pred_class,
           dnn = c("Predicted Utilization", "Utilization"))
})
```
```{r}
# class balance of the outcome among White respondents in the validation set
table(validation_preds_by_race$White$high_utilization)
```
Better balanced accuracy is achieved if we choose a lower cutoff:
```{r}
bind_rows(
filter(validation_probs_thresh, p_cutoff == .5),
best_acc
) %>% kable()
```
```{r}
#Add to our summary table
best_params = bind_cols(
version="Original",
best_acc,
lr_best
)
best_params %>% kable()
```
Note that although the balanced accuracy is **OK**, the disparate impact
measure shows room for improvement, as we are aiming for a value closer
to 1.0. To move closer to this goal we can turn to the relatively simple
approach of assigning a different weight to each combination of outcome
and protected feature category, as computed in Section 2.
### 3.2. Increasing classifier fairness by reweighting
For case-weights to work with `tidymodels`, they need to be (re)defined
by passing them to the `importance_weights()` function. The defined
recipe will automatically detect them. See this
[reference](https://recipes.tidymodels.org/reference/case_weights.html)
for more info.
```{r}
meps_19_splits_rw <- lapply(meps_19_splits, function(x){
mutate(x, instance_weights = importance_weights(instance_weights))
})
lr_recipe_rw <- meps_base_recipe(meps_19_splits_rw$training, addtl_vars="instance_weights")
#Verify that this worked
filter(
lr_recipe_rw$var_info,
role != "predictor"
) %>% kable()
```
In addition, the case weights also need to be added to the workflow
using the `add_case_weights` function.
```{r}
lr_workflow_rw <-
workflow() %>%
add_model(lr_mod) %>%
add_recipe(lr_recipe_rw) %>%
add_case_weights(instance_weights)
# Unfortunately we can't re-use the existing training/validation sets, as they need to contain the newly defined weight column
train_val_sets_rw <-
make_splits(meps_19_splits_rw$training, assessment=meps_19_splits_rw$validation) %>%
lst %>%
manual_rset(ids="Reweighted Validation")
set.seed(9534)
# As before, control_grid(save_pred = TRUE) is needed to save probabilities for downstream assessment
lr_res_rw <-
lr_workflow_rw %>%
tune_grid(train_val_sets_rw,
grid = 10,
control=control_grid(save_pred = TRUE),
metrics = metric_set(bal_accuracy, roc_auc))
#Choosing the best model parameters based on balanced accuracy
lr_best_rw <-
lr_res_rw %>%
select_best(metric = "bal_accuracy")
lr_best_rw %>% kable()
```
#### 3.2.1 Assess alternate probability cutoffs
As before, we will assess alternative probability cutoffs with respect to
balanced accuracy and fairness as measured by disparate impact.
```{r}
lr_stats_rw <-
lr_res_rw %>%
collect_predictions(parameters = lr_best_rw)
# retrieve the validation set annotated with the predicted probabilities
meps_validation_results_rw <-
bind_cols(
assessment(train_val_sets_rw$splits[[1]]),
select(
lr_stats_rw,
.pred_high,
.pred_low,
.pred_class)
) %>%
mutate(
    RACE = factor(RACE, levels = c("White", "Non-White"))
  )
validation_probs_thresh_rw <-
alternate_cutoff_stats(
meps_validation_results_rw,
outcome="high_utilization",
protected_feature="RACE"
)
best_acc_rw = validation_probs_thresh_rw %>%
arrange(-balanced_accuracy) %>%
head(n=1)
best_params_rw = bind_cols(
version="Reweighted",
best_acc_rw,
lr_best_rw
)
comb_params <- bind_rows(
best_params,
best_params_rw
)
comb_params %>% kable()
```
Interestingly, reweighting results in a fairly dramatic increase in
fairness, with only a 0.01 decrease in balanced accuracy at the
corresponding optimal probability cutoffs.
## 4. Visualize relationship between balanced accuracy and disparate impact
Since we have both a standard classifier and one utilizing fairness
weights, we can compare their results visually.
```{r}
comb_orig_rw_thresh <- bind_rows(
mutate(validation_probs_thresh, version="Original"),
mutate(validation_probs_thresh_rw, version="Reweighted")
)
```
```{r, echo=FALSE, eval=FALSE}
save(comb_orig_rw_thresh, file="tmp_validation_results.RData")
```
```{r, fig.width=9.5, fig.height=7}
# The 1 - p_cutoff conversion reflects that we want to calibrate for the 'high' (less
# frequent) class, whereas the calibrations were done for the base level ('low') class
p0 <- ggplot(data=comb_orig_rw_thresh, mapping=aes(x=p_cutoff, y=balanced_accuracy)) +
geom_line(linewidth=2) +
geom_vline(data=comb_params, mapping=aes(xintercept = p_cutoff), linetype="dashed") +
facet_wrap(~version, ncol=2) +
theme_bw() +
xlab("") +
ylab("Balanced Accuracy")
decons_di <- pivot_longer(data=comb_orig_rw_thresh, cols=`White`:disparate_impact)
p1 <- ggplot(data=filter(decons_di, name != "disparate_impact"), mapping=aes(x=p_cutoff, y=value, group=name, color=name)) +
geom_line(linewidth=2) +
geom_vline(data=comb_params, mapping=aes(xintercept = p_cutoff), linetype="dashed") +
scale_color_discrete(name="Group") +
facet_wrap(~version, ncol=2) +
theme_bw() +
xlab("") +
ylab("Prop. of Pred. High Utilization")
p2 <- ggplot(data=filter(decons_di, name == "disparate_impact"), mapping=aes(x=p_cutoff, y=value)) +
geom_line(linewidth=2) +
geom_hline(yintercept=1, linetype="dashed") +
geom_vline(data=comb_params, mapping=aes(xintercept = p_cutoff), linetype="dashed") +
facet_wrap(~version, ncol=2) +
theme_bw() +
xlab("Classification Threshold (High Utilization)") +
ylab("DI")
p0 / p1 / p2
```
## 5. Application to test set
Assuming we are happy with the performance of this model in the
validation set, we can fit to the entire training set and evaluate the
final performance on the test dataset.
Note that because we formed the splits outside of `tidymodels` this is a
slightly more complex procedure than simply calling `last_fit`.
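For reference, with a standard split the finalize-and-test step would collapse to something like the following (a sketch only; it does not apply here because our validation and test sets were carved manually out of the second half of `meps_19_split`):

```{r, eval=FALSE}
# Sketch of the typical tidymodels pattern; not run in this tutorial.
final_wf  <- finalize_workflow(lr_workflow_rw, lr_best_rw)  # plug in the tuned penalty/mixture
final_res <- last_fit(final_wf, meps_19_split)              # fit on training, evaluate on testing
collect_metrics(final_res)
```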
```{r}
#Here we need to specify the actual optimal parameters we want to use for the model
lr_mod_rw <- lr_mod %>%
update(parameters=
select(
filter(
comb_params,
version == "Reweighted"
),
penalty,
mixture
))
# Add this new model to the workflow, replacing the old version whose parameters
# were placeholders (tune()) used during tuning
lr_workflow_rw <-
lr_workflow_rw %>%
update_model(lr_mod_rw)
#Fit the model
lr_rw_trained <-
lr_workflow_rw %>%
fit(data = meps_19_splits_rw$training)
#Determine predictions for the test data for final assessment
lr_rw_test_aug <-
augment(
lr_rw_trained,
meps_19_splits_rw$testing
) %>%
mutate(
RACE = factor(RACE, levels=c("White", "Non-White"))
)
# Finally, using the optimal probability cutoff we will compute the corresponding
# balanced accuracy and disparate impact on the test set
lr_testing_results_rw <- cutoff_stats(
lr_rw_test_aug,
cutoff=filter(comb_params,
version == "Reweighted")$p_cutoff,
outcome="high_utilization",
protected_feature="RACE"
)
test_params_rw = cbind(version="Test",lr_testing_results_rw, lr_best_rw)
comb_params <- bind_rows(comb_params, test_params_rw)
comb_params %>% kable()
```
Application to the testing dataset shows dips in both balanced accuracy
and disparate impact, so it is both less accurate and less fair.
## 6. Assessing data/concept drift
Given that the model was trained on panel 19, we can also evaluate its
performance on other subsets of the data, such as panel 21. Decreases in
performance and/or fairness would indicate that we need to retrain the model.
Read in and prepare panel 21. In this case it is derived from another
file, so we should take care to ensure that it is comparable.
```{r}
meps_21 <- process_meps_csv("h192.csv", 21)
#also add in placeholder `instance_weights` column
meps_21 <- mutate(
meps_21,
instance_weights=importance_weights(NA_real_)
)
```
Since we are using a pre-processing recipe, it should fail if columns are
missing; additionally, we could implement type checks as part of the
recipe (a sketch follows). We can then examine how similar panel 21 is to
the portion of panel 19 used for training by comparing the summaries computed below.
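A minimal sketch of such checks (hypothetical additions; they are not part of `meps_base_recipe()`) might look like:

```{r, eval=FALSE}
# Hypothetical recipe checks; not used in this tutorial's recipe.
recipe(high_utilization ~ ., data = meps_19_splits_rw$training) %>%
  check_cols(all_predictors()) %>%      # error at bake() time if a predictor column is absent
  check_missing(all_predictors())       # error if any predictor contains missing values
```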
```{r, layout="l-body-outset"}
#We use the `prep` and `bake` functions to apply the specific pre-processing recipe and
#retrieve the input data in the same form as provided to the model
meps_19_prepped <-
prep(
lr_recipe_rw,
training = meps_19_splits_rw$training
)
meps_19_baked <- bake(meps_19_prepped, new_data=NULL)
# Note that the factor `high_utilization` is not particularly useful in this summary;
# dropping it here just simplifies things
meps_19_skim <- skim(
meps_19_baked,
-instance_weights,
-high_utilization
) %>% partition()
#apply the same transformations to meps_21 and create a summary table
meps_21_baked <- bake(meps_19_prepped, new_data = meps_21)
meps_21_skim <- skim(
meps_21_baked,
-instance_weights,
-high_utilization
) %>% partition()
# Note: using a `full_join` here will expose any differences in which columns are present
numeric_combined_skim <- full_join(
  as_tibble(meps_19_skim$numeric),
  as_tibble(meps_21_skim$numeric),
  by = c("skim_variable"),
  suffix = c(".19", ".21")
)
select(
numeric_combined_skim,
skim_variable,
mean.19, mean.21,
sd.19, sd.21,
p0.19, p0.21,
p25.19, p25.21,
p50.19, p50.21,
p75.19, p75.21,
p100.19, p100.21,
hist.19, hist.21
) %>% paged_table()
```
Note that most of these variables are categorical variables encoded as
0/1, so the `mean.*` columns give the proportion in each panel.
```{r}
panel21_tbl_aug <-
augment(
lr_rw_trained,
meps_21
) %>%
mutate(
RACE = factor(RACE, levels=c("White", "Non-White"))
)
panel21_tbl_stats <- cutoff_stats(
panel21_tbl_aug,
cutoff=filter(comb_params,
version == "Reweighted")$p_cutoff,
outcome="high_utilization",
protected_feature="RACE"
)
panel21_params_rw = cbind(version="Panel 21",panel21_tbl_stats, lr_best_rw)
comb_params <- bind_rows(comb_params, panel21_params_rw)
comb_params %>% kable()
```
The model's performance on panel 21 is fairly close to its performance on
the panel 19 test data, so we can stop here. At this point the model could
be retrained if needed (see the AIF360 Python tutorial).
```{r, R.options = list(width = 270)}
pander(sessionInfo())
```