```r
knitr::opts_chunk$set(warning = FALSE, message = FALSE)

library(dplyr)
library(ggplot2)  # labs(), scale_colour_viridis_d()
library(flashlight)
library(plotly)
library(ranger)
library(lme4)
library(moderndive)
library(splitTools)
library(MetricsWeighted)

set.seed(4933)
```
There are several R packages devoted to model-agnostic interpretability, with DALEX and iml being among the best known. A couple of years ago, I programmed the alternative flashlight for different reasons, such as:
Since almost all plots in flashlight are constructed with `ggplot()`, it is super easy to turn them into interactive plotly objects: just add a simple `ggplotly()` to the end of the call.
We will use a sweet dataset with more than 20’000 houses (`house_prices` from the moderndive package) to model house prices using a set of derived features, such as the logarithmic living area. The location is represented by the postal code.
We first load the data and prepare some of the columns for modeling. Furthermore, we specify the set of features and the response.
```r
data("house_prices")

prep <- house_prices %>%
  mutate(
    log_price = log(price),
    log_sqft_living = log(sqft_living),
    log_sqft_lot = log(sqft_lot),
    log_sqft_basement = log1p(sqft_basement),  # log(1 + x): stays finite for houses without basement
    year = as.numeric(format(date, '%Y')),     # year of sale
    age = year - yr_built                      # age of the house at sale
  )

x <- c(
  "year", "age", "log_sqft_living", "log_sqft_lot",
  "bedrooms", "bathrooms", "log_sqft_basement",
  "condition", "waterfront", "zipcode"
)
y <- "log_price"

head(prep[c(y, x)])
```

```
## # A tibble: 6 x 11
##   log_price  year   age log_sqft_living log_sqft_lot bedrooms bathrooms
##       <dbl> <dbl> <dbl>           <dbl>        <dbl>    <int>     <dbl>
## 1      12.3  2014    59            7.07         8.64        3      1   
## 2      13.2  2014    63            7.85         8.89        3      2.25
## 3      12.1  2015    82            6.65         9.21        2      1   
## 4      13.3  2014    49            7.58         8.52        4      3   
## 5      13.1  2015    28            7.43         9.00        3      2   
## 6      14.0  2014    13            8.60        11.5         4      4.5 
## # ... with 4 more variables: log_sqft_basement <dbl>, condition <fct>,
## #   waterfront <lgl>, zipcode <fct>
```
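Why `log1p()` for the basement area? Many houses have no basement (`sqft_basement` equal to 0), and `log(0)` is `-Inf`, which would break model fitting. `log1p(x)` computes `log(1 + x)`, so zero maps to zero. The following lines use made-up toy values (not taken from the dataset) to illustrate this, together with the age computation from the `mutate()` call above.

```r
# Toy basement areas in square feet (illustrative values only)
sqft_basement <- c(0, 400, 1200)

log(sqft_basement)    # -Inf for a house without basement
log1p(sqft_basement)  # log(1 + x): 0 stays 0, larger areas are barely affected

# Age at sale = year of sale minus construction year (toy values)
sale_date <- as.Date(c("2014-10-13", "2015-02-25"))
yr_built <- c(1955, 1933)
year <- as.numeric(format(sale_date, "%Y"))
age <- year - yr_built
age
```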
Then, we split the dataset into 80% training and 20% test rows, stratified on the (binned) response `log_price`.
```r
idx <- partition(prep[[y]], c(train = 0.8, test = 0.2), type = "stratified")

train <- prep[idx$train, ]
test <- prep[idx$test, ]
```
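To see what stratifying on a numeric response means, here is a minimal base R sketch of the idea: cut the response into quantile bins and draw 80% of the rows within each bin, so that training and test sets get similar response distributions. The helper `stratified_split()` and its `n_bins` default are hypothetical and only illustrate the technique; this is not splitTools' actual implementation.

```r
# Hypothetical helper: stratified 80/20 split of a numeric response via quantile bins
stratified_split <- function(y, p_train = 0.8, n_bins = 10) {
  breaks <- unique(quantile(y, probs = seq(0, 1, length.out = n_bins + 1)))
  bins <- cut(y, breaks = breaks, include.lowest = TRUE)
  # Sample the training share separately within each bin
  train <- unlist(lapply(split(seq_along(y), bins), function(rows) {
    rows[sample.int(length(rows), size = round(p_train * length(rows)))]
  }), use.names = FALSE)
  list(train = sort(train), test = setdiff(seq_along(y), train))
}

set.seed(1)
y_toy <- rnorm(1000, mean = 13, sd = 0.5)  # toy log-price-like response
parts <- stratified_split(y_toy)

length(parts$train) / length(y_toy)        # close to 0.8
# Quantiles of the response should be similar in both parts
quantile(y_toy[parts$train], c(0.1, 0.5, 0.9))
quantile(y_toy[parts$test], c(0.1, 0.5, 0.9))
```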
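We fit two models: a linear mixed-effects model with a random intercept per postal code, and a random forest that always considers the postal code as a split candidate.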
```r
# Mixed-effects model
fit_lmer <- lmer(
  update(reformulate(x, "log_price"), . ~ . - zipcode + (1 | zipcode)),
  data = train
)

# Random forest
fit_rf <- ranger(
  reformulate(x, "log_price"),
  always.split.variables = "zipcode",
  data = train
)
cat("R-squared OOB:", fit_rf$r.squared)
```

```
## R-squared OOB: 0.8463311
```
Now, we are ready to inspect our two models regarding performance, variable importance, and effects.
First, we pack all model-dependent information into flashlights (the explainer objects) and combine them into a multiflashlight. As evaluation dataset, we pass the test data. This ensures that interpretability tools using the response (e.g., performance measures and permutation importance) are not biased by overfitting.
```r
fl_lmer <- flashlight(model = fit_lmer, label = "LMER")

# predict() on a ranger model returns an object; the numbers live in $predictions
fl_rf <- flashlight(
  model = fit_rf,
  label = "RF",
  predict_function = function(mod, X) predict(mod, X)$predictions
)

fls <- multiflashlight(
  list(fl_lmer, fl_rf),
  y = "log_price",
  data = test,
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)
```
Let’s evaluate RMSE and R-squared of the models on the hold-out dataset. Here, the mixed-effects model performs slightly better than the random forest:
```r
(light_performance(fls) %>%
  plot(fill = "darkred") +
  labs(title = "Model performance", x = element_blank())) %>%
  ggplotly()
```
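For reference, the two metrics boil down to a few lines of base R. This is a sketch of the unweighted case with made-up numbers; `rmse_manual()` and `r_squared_manual()` are hypothetical names, and MetricsWeighted's `rmse()` and `r_squared()` additionally support case weights.

```r
rmse_manual <- function(actual, predicted) {
  sqrt(mean((actual - predicted)^2))
}

r_squared_manual <- function(actual, predicted) {
  # 1 - residual sum of squares / total sum of squares around the mean
  1 - sum((actual - predicted)^2) / sum((actual - mean(actual))^2)
}

# Made-up log prices and predictions
actual <- c(12.3, 13.2, 12.1, 13.3)
predicted <- c(12.5, 13.0, 12.2, 13.1)

rmse_manual(actual, predicted)       # sqrt(0.13 / 4), about 0.18
r_squared_manual(actual, predicted)  # 1 - 0.13 / 1.1275, about 0.885
```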
Next, we assess variable importance via permutation importance. It measures by how much the RMSE on the test data increases when a variable is randomly shuffled before prediction. The results are quite similar for the two models.
```r
(light_importance(fls, v = x) %>%
  plot(fill = "darkred") +
  labs(title = "Permutation importance", y = "Increase in RMSE")) %>%
  ggplotly()
```
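The mechanics behind permutation importance fit in a short base R function. The sketch below uses a toy linear model and a hypothetical helper `permutation_importance()` with a single shuffle per feature (flashlight can repeat the shuffling); for brevity it evaluates on the training data, while the analysis above rightly uses the test data.

```r
set.seed(42)
n <- 500
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n), noise = rnorm(n))
toy$y <- 2 * toy$x1 + 0.5 * toy$x2 + rnorm(n, sd = 0.3)
fit <- lm(y ~ x1 + x2 + noise, data = toy)

rmse_toy <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

# Importance of v = RMSE with v shuffled minus RMSE on intact data
permutation_importance <- function(model, data, response, features) {
  base <- rmse_toy(data[[response]], predict(model, data))
  sapply(features, function(v) {
    shuffled <- data
    shuffled[[v]] <- sample(shuffled[[v]])  # break the link between v and the response
    rmse_toy(data[[response]], predict(model, shuffled)) - base
  })
}

imp <- permutation_importance(fit, toy, "y", c("x1", "x2", "noise"))
sort(imp, decreasing = TRUE)  # x1 dominates, noise is close to 0
```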
To get an impression of the effect of the living area, we select 200 observations and profile their predictions over increasing (log) living area, keeping everything else fixed (ceteris paribus). These ICE (individual conditional expectation) curves are vertically centered to highlight potential interaction effects. If all curves coincide, the living area does not interact with other features, and we can say that its effect is modelled in an additive way (no surprise for the additive linear mixed-effects model).
```r
(light_ice(fls, v = "log_sqft_living", n_max = 200, center = "middle") %>%
  plot(alpha = 0.05, color = "darkred") +
  labs(title = "Centered ICE plot", y = "log_price (shifted)")) %>%
  ggplotly()
```
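The following base R sketch on toy data shows why centered ICE curves reveal interactions. The hypothetical helpers `ice_matrix()` and `center_middle()` compute one curve per observation and subtract its value at the middle grid point, roughly what `center = "middle"` does. Curves of an additive model coincide after centering, while a model with an interaction term produces fanning curves.

```r
set.seed(7)
n <- 300
toy <- data.frame(x = runif(n, 6, 9), z = rnorm(n))
toy$y <- 0.8 * toy$x + 0.3 * toy$z + 0.5 * toy$x * toy$z + rnorm(n, sd = 0.2)
fit_add <- lm(y ~ x + z, data = toy)  # additive model
fit_int <- lm(y ~ x * z, data = toy)  # model with x:z interaction

# One row per observation, one column per grid value of feature v
ice_matrix <- function(model, data, v, grid) {
  t(sapply(seq_len(nrow(data)), function(i) {
    profile <- data[rep(i, length(grid)), , drop = FALSE]
    profile[[v]] <- grid  # vary v, keep all other features fixed
    predict(model, profile)
  }))
}

# Shift each curve so that it is 0 at the middle grid point
center_middle <- function(ice) {
  mid <- ceiling(ncol(ice) / 2)
  ice - ice[, mid]
}

grid <- seq(6, 9, length.out = 11)
rows <- toy[sample(n, 50), ]
ice_add <- center_middle(ice_matrix(fit_add, rows, "x", grid))
ice_int <- center_middle(ice_matrix(fit_int, rows, "x", grid))

max(apply(ice_add, 2, sd))  # essentially 0: centered curves coincide
max(apply(ice_int, 2, sd))  # clearly positive: curves fan out
```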
Averaging many uncentered ICE curves provides the famous partial dependence plot, introduced in Friedman’s seminal paper on gradient boosting machines (2001).
```r
(light_profile(fls, v = "log_sqft_living", n_bins = 21) %>%
  plot(rotate_x = FALSE) +
  labs(title = "Partial dependence plot", y = y) +
  scale_colour_viridis_d(begin = 0.2, end = 0.8)) %>%
  ggplotly()
```
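Averaging uncentered ICE curves over all rows is the same as setting the feature to a grid value for every row and averaging the predictions. A minimal base R sketch on toy data, with a hypothetical helper `partial_dependence()`; for a linear model, the resulting curve is a straight line with slope equal to the coefficient of the feature.

```r
set.seed(7)
n <- 300
toy <- data.frame(x = runif(n, 6, 9), z = rnorm(n))
toy$y <- 0.8 * toy$x + 0.3 * toy$z + rnorm(n, sd = 0.2)
fit <- lm(y ~ x + z, data = toy)

partial_dependence <- function(model, data, v, grid) {
  sapply(grid, function(g) {
    data[[v]] <- g              # set v to g for every row (local copy)
    mean(predict(model, data))  # average prediction = average of ICE curves at g
  })
}

grid <- seq(6, 9, length.out = 7)
pd <- partial_dependence(fit, toy, "x", grid)
diff(pd) / diff(grid)  # constant, equal to coef(fit)["x"]
```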
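The last figure extends the partial dependence plot with three additional curves, all evaluated on the hold-out dataset: the average observed response, the average prediction, and accumulated local effects (ALE), each per bin of `log_sqft_living`.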
```r
(light_effects(fls, v = "log_sqft_living", n_bins = 21) %>%
  plot(use = "all") +
  labs(title = "Different effect estimates", y = y) +
  scale_colour_viridis_d(begin = 0.2, end = 0.8)) %>%
  ggplotly()
```
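Two of these curves, the average observed response and the average prediction per bin of the feature, are easy to recompute by hand. The base R sketch below uses toy data where a second feature `z` is correlated with `x`. It shows why such curves can differ from partial dependence: binned averages also pick up the effect of correlated features, while partial dependence of a linear model only reflects the coefficient of `x`.

```r
set.seed(3)
n <- 1000
toy <- data.frame(x = runif(n, 6, 9), z = rnorm(n))
toy$z <- toy$z + (toy$x - 7.5)  # make z correlated with x
toy$y <- 0.8 * toy$x + 0.3 * toy$z + rnorm(n, sd = 0.2)
fit <- lm(y ~ x + z, data = toy)

bins <- cut(toy$x, breaks = seq(6, 9, by = 0.5), include.lowest = TRUE)
profiles <- data.frame(
  response  = tapply(toy$y, bins, mean),             # average observed y per bin
  predicted = tapply(predict(fit, toy), bins, mean)  # average prediction per bin
)
profiles
# Across bins, these curves rise by about 0.8 + 0.3 = 1.1 per unit of x,
# while partial dependence for x has slope coef(fit)["x"], about 0.8.
```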
Combining flashlight with plotly works well and provides crystal clear plots. They look quite cool when shipped in an HTML report (like this one…).