
Explainability of {tidymodels} models with {iml}

In my previous blog post, I showed how you could use {tidymodels} to train several machine learning models. Now, let’s take a look at getting some explanations out of them, using the {iml} package. Originally I did not intend to write a separate blog post, but I encountered… an issue, or bug, when using {iml} and {tidymodels} together, and I felt it was important to write about it. Maybe it’s just me that’s missing something, and you, kind reader, might be able to give me an answer. But let’s first reload the models from last time (the same packages as in the previous blog post are loaded):

trained_models_list
## [[1]]
## #  10-fold cross-validation 
## # A tibble: 10 x 4
##    splits               id     .metrics          .notes          
##  * <list>               <chr>  <list>            <list>          
##  1 <split [23.6K/2.6K]> Fold01 <tibble [20 × 5]> <tibble [1 × 1]>
##  2 <split [23.6K/2.6K]> Fold02 <tibble [20 × 5]> <tibble [1 × 1]>
##  3 <split [23.6K/2.6K]> Fold03 <tibble [20 × 5]> <tibble [1 × 1]>
##  4 <split [23.6K/2.6K]> Fold04 <tibble [20 × 5]> <tibble [1 × 1]>
##  5 <split [23.6K/2.6K]> Fold05 <tibble [20 × 5]> <tibble [1 × 1]>
##  6 <split [23.6K/2.6K]> Fold06 <tibble [20 × 5]> <tibble [1 × 1]>
##  7 <split [23.6K/2.6K]> Fold07 <tibble [20 × 5]> <tibble [1 × 1]>
##  8 <split [23.6K/2.6K]> Fold08 <tibble [20 × 5]> <tibble [1 × 1]>
##  9 <split [23.6K/2.6K]> Fold09 <tibble [20 × 5]> <tibble [1 × 1]>
## 10 <split [23.6K/2.6K]> Fold10 <tibble [20 × 5]> <tibble [1 × 1]>
## 
## [[2]]
## #  10-fold cross-validation 
## # A tibble: 10 x 4
##    splits               id     .metrics          .notes          
##  * <list>               <chr>  <list>            <list>          
##  1 <split [23.6K/2.6K]> Fold01 <tibble [20 × 5]> <tibble [1 × 1]>
##  2 <split [23.6K/2.6K]> Fold02 <tibble [20 × 5]> <tibble [1 × 1]>
##  3 <split [23.6K/2.6K]> Fold03 <tibble [20 × 5]> <tibble [1 × 1]>
##  4 <split [23.6K/2.6K]> Fold04 <tibble [20 × 5]> <tibble [1 × 1]>
##  5 <split [23.6K/2.6K]> Fold05 <tibble [20 × 5]> <tibble [1 × 1]>
##  6 <split [23.6K/2.6K]> Fold06 <tibble [20 × 5]> <tibble [1 × 1]>
##  7 <split [23.6K/2.6K]> Fold07 <tibble [20 × 5]> <tibble [1 × 1]>
##  8 <split [23.6K/2.6K]> Fold08 <tibble [20 × 5]> <tibble [1 × 1]>
##  9 <split [23.6K/2.6K]> Fold09 <tibble [20 × 5]> <tibble [1 × 1]>
## 10 <split [23.6K/2.6K]> Fold10 <tibble [20 × 5]> <tibble [1 × 1]>
## 
## [[3]]
## #  10-fold cross-validation 
## # A tibble: 10 x 4
##    splits               id     .metrics          .notes          
##  * <list>               <chr>  <list>            <list>          
##  1 <split [23.6K/2.6K]> Fold01 <tibble [20 × 5]> <tibble [1 × 1]>
##  2 <split [23.6K/2.6K]> Fold02 <tibble [20 × 5]> <tibble [1 × 1]>
##  3 <split [23.6K/2.6K]> Fold03 <tibble [20 × 5]> <tibble [1 × 1]>
##  4 <split [23.6K/2.6K]> Fold04 <tibble [20 × 5]> <tibble [1 × 1]>
##  5 <split [23.6K/2.6K]> Fold05 <tibble [20 × 5]> <tibble [1 × 1]>
##  6 <split [23.6K/2.6K]> Fold06 <tibble [20 × 5]> <tibble [1 × 1]>
##  7 <split [23.6K/2.6K]> Fold07 <tibble [20 × 5]> <tibble [1 × 1]>
##  8 <split [23.6K/2.6K]> Fold08 <tibble [20 × 5]> <tibble [1 × 1]>
##  9 <split [23.6K/2.6K]> Fold09 <tibble [20 × 5]> <tibble [1 × 1]>
## 10 <split [23.6K/2.6K]> Fold10 <tibble [20 × 5]> <tibble [1 × 1]>
## 
## [[4]]
## #  10-fold cross-validation 
## # A tibble: 10 x 4
##    splits               id     .metrics          .notes          
##  * <list>               <chr>  <list>            <list>          
##  1 <split [23.6K/2.6K]> Fold01 <tibble [20 × 7]> <tibble [1 × 1]>
##  2 <split [23.6K/2.6K]> Fold02 <tibble [20 × 7]> <tibble [1 × 1]>
##  3 <split [23.6K/2.6K]> Fold03 <tibble [20 × 7]> <tibble [1 × 1]>
##  4 <split [23.6K/2.6K]> Fold04 <tibble [20 × 7]> <tibble [1 × 1]>
##  5 <split [23.6K/2.6K]> Fold05 <tibble [20 × 7]> <tibble [1 × 1]>
##  6 <split [23.6K/2.6K]> Fold06 <tibble [20 × 7]> <tibble [1 × 1]>
##  7 <split [23.6K/2.6K]> Fold07 <tibble [20 × 7]> <tibble [1 × 1]>
##  8 <split [23.6K/2.6K]> Fold08 <tibble [20 × 7]> <tibble [1 × 1]>
##  9 <split [23.6K/2.6K]> Fold09 <tibble [20 × 7]> <tibble [1 × 1]>
## 10 <split [23.6K/2.6K]> Fold10 <tibble [20 × 7]> <tibble [1 × 1]>
## 
## [[5]]
## #  10-fold cross-validation 
## # A tibble: 10 x 4
##    splits               id     .metrics          .notes          
##  * <list>               <chr>  <list>            <list>          
##  1 <split [23.6K/2.6K]> Fold01 <tibble [20 × 5]> <tibble [1 × 1]>
##  2 <split [23.6K/2.6K]> Fold02 <tibble [20 × 5]> <tibble [1 × 1]>
##  3 <split [23.6K/2.6K]> Fold03 <tibble [20 × 5]> <tibble [1 × 1]>
##  4 <split [23.6K/2.6K]> Fold04 <tibble [20 × 5]> <tibble [1 × 1]>
##  5 <split [23.6K/2.6K]> Fold05 <tibble [20 × 5]> <tibble [1 × 1]>
##  6 <split [23.6K/2.6K]> Fold06 <tibble [20 × 5]> <tibble [1 × 1]>
##  7 <split [23.6K/2.6K]> Fold07 <tibble [20 × 5]> <tibble [1 × 1]>
##  8 <split [23.6K/2.6K]> Fold08 <tibble [20 × 5]> <tibble [1 × 1]>
##  9 <split [23.6K/2.6K]> Fold09 <tibble [20 × 5]> <tibble [1 × 1]>
## 10 <split [23.6K/2.6K]> Fold10 <tibble [20 × 5]> <tibble [1 × 1]>

Let’s see which of the models performed best (in cross-validation):

trained_models_list %>%
  map(show_best, metric = "accuracy", n = 1)
## [[1]]
## # A tibble: 1 x 7
##    penalty mixture .metric  .estimator  mean     n std_err
##      <dbl>   <dbl> <chr>    <chr>      <dbl> <int>   <dbl>
## 1 6.57e-10  0.0655 accuracy binary     0.916    10 0.00179
## 
## [[2]]
## # A tibble: 1 x 7
##    mtry trees .metric  .estimator  mean     n std_err
##   <int> <int> <chr>    <chr>      <dbl> <int>   <dbl>
## 1    13  1991 accuracy binary     0.929    10 0.00172
## 
## [[3]]
## # A tibble: 1 x 7
##   num_terms prune_method .metric  .estimator  mean     n std_err
##       <int> <chr>        <chr>    <chr>      <dbl> <int>   <dbl>
## 1         5 backward     accuracy binary     0.904    10 0.00186
## 
## [[4]]
## # A tibble: 1 x 9
##    mtry trees tree_depth learn_rate .metric  .estimator  mean     n std_err
##   <int> <int>      <int>      <dbl> <chr>    <chr>      <dbl> <int>   <dbl>
## 1    12  1245         12     0.0770 accuracy binary     0.929    10 0.00175
## 
## [[5]]
## # A tibble: 1 x 7
##   hidden_units    penalty .metric  .estimator  mean     n std_err
##          <int>      <dbl> <chr>    <chr>      <dbl> <int>   <dbl>
## 1           10 0.00000307 accuracy binary     0.917    10 0.00209

It seems like the second model, the random forest, performed best (highest mean accuracy with the lowest standard error). So let’s retrain the model on the whole training set and see how it fares on the testing set:

rf_specs <- trained_models_list[[2]]

Let’s save the best model specification in a variable:

best_rf_spec <- show_best(rf_specs, metric = "accuracy", n = 1)

Let’s now retrain this model, using a workflow:

best_rf_model <- rand_forest(mode = "classification", mtry = best_rf_spec$mtry,
                             trees = best_rf_spec$trees) %>%
  set_engine("ranger")

preprocess <- recipe(job_search ~ ., data = pra) %>%
  step_dummy(all_predictors())

pra_wflow_best <- workflow() %>%
  add_recipe(preprocess) %>%
  add_model(best_rf_model)

best_model_fitted <- fit(pra_wflow_best, data = pra_train)
## Warning: The following variables are not factor vectors and will be ignored:
## `hours`

and let’s take a look at the confusion matrix:

predictions <- predict(best_model_fitted, new_data = pra_test) %>%
  bind_cols(pra_test)

predictions %>%
  mutate(job_search = as.factor(job_search)) %>%  
  accuracy(job_search, .pred_class)
## # A tibble: 1 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.924
predictions %>%
  mutate(job_search = as.factor(job_search)) %>%  
  conf_mat(job_search, .pred_class) 
##           Truth
## Prediction    N    S
##          N 2539  156
##          S   64  149

We see that predicting class S (“Sí”, meaning “yes” in Spanish) is tricky. One would probably need to use a technique such as SMOTE to deal with this class imbalance (see this blog post for more info). Anyway, this is not today’s topic.
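For the curious, here is a minimal, untested sketch of what that could look like, assuming the {themis} package (which provides SMOTE as a recipe step) is installed and that job_search is a factor:

# Hypothetical sketch, not run: oversample the minority class with SMOTE.
# step_smote() needs a factor outcome and numeric predictors, which is
# why it comes after step_dummy()
library("themis")

preprocess_smote <- recipe(job_search ~ ., data = pra) %>%
  step_dummy(all_predictors()) %>%
  step_smote(job_search)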

Let’s say that we are satisfied with the model and want some explanations out of it. I have already blogged about model explainability in the past, so if you want more details, you can read this blog post.

Now, what is important is that I have defined a complete workflow to deal with the data preprocessing and then the training of the model. So I’ll be using this workflow to get my explanations as well. What I mean by this is the following: to get explanations, we need a model and a way to get predictions out of it. As I have shown above, my fitted workflow is able to give me predictions, so I should have every ingredient I need. {iml}, the package I am using for explainability, provides several functions that all work the same way: you first define an object that takes as input the fitted model, the design matrix, the target variable and the prediction function. Let’s start by defining the design matrix and the target variable:

library("iml")

features <- pra_test %>%
  select(-job_search)

target <- pra_test %>%
  mutate(job_search = as.factor(job_search)) %>%  
  select(job_search)

Now, let’s define the predict function:

predict_wrapper <- function(model, newdata){
  workflows:::predict.workflow(object = model, new_data = newdata)
}

Because a workflow is a bit special, I need to define this wrapper function around workflows:::predict.workflow(). Normally, users should not have to deal with this function; as you can see, to access it I had to use the ::: operator, which gives access to a package’s internal functions, that is, functions used inside the package but not exported for users to call directly (in Python, these would be called private methods). I tried simply using the predict() function, which works interactively, but I ran into issues when passing it to the constructor below:

predictor <- Predictor$new(
                         model = best_model_fitted,
                         data = features, 
                         y = target,
                         predict.fun = predict_wrapper
                       )

This creates a Predictor object from which I am now able to get explanations. For example, for feature importance, I would write the following:

feature_importance <- FeatureImp$new(predictor, loss = "ce")

plot(feature_importance)

And this is where I noticed that something was wrong: the variables we are looking at are categorical, so why am I not seeing the categories? Why is the most important variable the contract type as a whole, rather than one particular category of the contract type? Remember that I created dummy variables using a recipe, so I was expecting variables like type_of_contract_type_1, type_of_contract_type_2, and so on.
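To make that expectation concrete, here is a small toy illustration (with a made-up column and levels, not the actual data) of how step_dummy() names the indicator columns:

# Toy example with a made-up factor column; step_dummy() creates one
# indicator column per level (dropping the reference level), named
# variable_level, e.g. type_of_contract_type_2
toy <- tibble(type_of_contract = factor(c("type_1", "type_2", "type_3")))

recipe(~ ., data = toy) %>%
  step_dummy(all_nominal()) %>%
  prep() %>%
  juice()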

This made me want to try to fit the model “the old way”, without using workflows. For this I need the prep(), juice() and bake() functions, which are included in the {recipes} package. I won’t go into much detail, but the idea is that prep() is used to train the recipe, computing whatever is needed to preprocess the data (such as means and standard deviations for normalization); for this, you should use the training data only. juice() returns the preprocessed training set, and bake() is then used to preprocess a new data set, for instance the test set, using the same estimated parameters that were obtained with prep().

Using workflows avoids having to do these steps manually, but what I am hoping is that doing this manually will solve my issue. So let’s try:

# without workflows
trained_recipe <- prep(preprocess, training = pra_train)
## Warning: The following variables are not factor vectors and will be ignored:
## `hours`
pra_train_prep <- juice(trained_recipe)


best_model_fit <- fit(best_rf_model, job_search ~ ., data = pra_train_prep)


pra_test_bake_features <- bake(trained_recipe, pra_test) %>%
  select(-job_search)


predict_wrapper2 <- function(model, newdata){
  predict(object = model, new_data = newdata)
}

predictor2 <- Predictor$new(
                          model = best_model_fit,
                          data = pra_test_bake_features, 
                          y = target,
                          predict.fun = predict_wrapper2
                        )

feature_importance2 <- FeatureImp$new(predictor2, loss = "ce")

plot(feature_importance2)

Eureka! As you can see, the issue is now solved: all the variables that were used for training the model now also show up in our explanations. I don’t know exactly what’s going on. Is this a bug? Is it because the {workflows} package streamlines the process so much that it somehow rebuilds the original features before returning the results? I have no idea. In any case, it would seem that for the time being, doing the training and explanations without the {workflows} package is the way to go if you require explanations as well.
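One avenue that might be worth exploring (an untested sketch on my part, assuming pull_workflow_fit() from {workflows} behaves as documented) is to keep the workflow for training, but extract the underlying parsnip fit from it for {iml}, pairing it with the bake()d features:

# Untested sketch: pull_workflow_fit() extracts the underlying parsnip
# model from a fitted workflow; paired with bake()d features, this might
# sidestep the issue without fitting the model a second time
fit_from_wflow <- pull_workflow_fit(best_model_fitted)

predictor3 <- Predictor$new(
                model = fit_from_wflow,
                data = pra_test_bake_features,
                y = target,
                predict.fun = predict_wrapper2
              )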

Hope you enjoyed! If you found this blog post useful, you might want to follow me on Twitter for blog post updates and watch my YouTube channel. If you want to support my blog and channel, you could buy me an espresso, donate via paypal.me, or buy my ebook on Leanpub.
