Overview

What makes a ML model black-box? It’s the interactions!

The first step in understanding interactions is to measure their strength. This is exactly what Friedman and Popescu’s H-statistics [1] do:

Statistic Short description How to read its value?
Hj2 Overall interaction strength per feature Proportion of prediction variability explained by interactions on feature j.
Hjk2 Pairwise interaction strength Proportion of joint effect variability of features j and k coming from their pairwise interaction.
Hjkl2 Three-way interaction strength Proportion of joint effect variability of three features coming from their three-way interaction.

See section Background for details and definitions.

{hstats} offers these statistics comparably fast and for any model, even for multi-output models, or models with case weights. Additionally, we provide a global statistic H2 measuring the proportion of prediction variability unexplained by main effects [5], and an experimental feature importance measure. After having identified strong interactions, their shape can be investigated by stratified partial dependence or ICE plots.

The core functions hstats(), partial_dep(), ice(), perm_importance(), and average_loss() can directly be applied to DALEX explainers, meta learners (mlr3, tidymodels, caret) and most other models. In case you need more flexibility, a tailored prediction function can be specified. Both data.frame and matrix data structures are supported.

Limitations

  1. H-statistics are based on partial dependence estimates and are thus as good or bad as these. One of their problems is that the model is applied to unseen/impossible feature combinations. In extreme cases, H-statistics intended to be in the range between 0 and 1 can become larger than 1. Accumulated local effects (ALE) [8] mend above problem of partial dependence estimates. They, however, depend on the notion of “closeness”, which is highly non-trivial in higher dimension and for discrete features.
  2. Due to their computational complexity of O(n2), where n is the number of rows considered, H-statistics are usually evaluated on relatively small subsets of the training (or validation/test) data. Consequently, the estimates are typically not very robust. To get more robust results, increase the default n_max = 500 of hstats().

Landscape

{hstats} is not the first R package to explore interactions. Here is an incomplete selection:

  • {gbm}: Implementation of m-wise interaction statistics of [1] for {gbm} models using the weighted tree-traversal method of [2] to estimate partial dependence functions.
  • {iml}: Implementation of overall and pairwise H-statistics.
  • {EIX}: Interaction statistics extracted from the tree structure of XGBoost and LightGBM.
  • {randomForestExplainer}: Interaction statistics extracted from the tree structure of random forests.
  • {vivid}: Cool visualization of interaction patterns. Partly based on {flashlight}.
  • {flashlight}: Model-agnostic implementation of some statistics of [1]. Planned to switch to the much faster {hstats}.

Installation

# From CRAN
install.packages("hstats")

# From Github
devtools::install_github("mayer79/hstats")

Usage

To demonstrate the typical workflow, we use a beautiful house price dataset with about 14,000 transactions from Miami-Dade County available in the {shapviz} package, and analyzed in [3]. We are going to model logarithmic sales prices with XGBoost.

Fit model

library(hstats)
library(ggplot2)
library(xgboost)
library(shapviz)

colnames(miami) <- tolower(colnames(miami))
miami$log_ocean <- log(miami$ocean_dist)
x <- c("log_ocean", "tot_lvg_area", "lnd_sqfoot", "structure_quality", "age", "month_sold")

# Train/valid split
set.seed(1)
ix <- sample(nrow(miami), 0.8 * nrow(miami))

y_train <- log(miami$sale_prc[ix])
y_valid <- log(miami$sale_prc[-ix])
X_train <- data.matrix(miami[ix, x])
X_valid <- data.matrix(miami[-ix, x])

dtrain <- xgb.DMatrix(X_train, label = y_train)
dvalid <- xgb.DMatrix(X_valid, label = y_valid)

# Fit via early stopping
fit <- xgb.train(
  params = list(learning_rate = 0.15, objective = "reg:squarederror", max_depth = 5),
  data = dtrain,
  watchlist = list(valid = dvalid),
  early_stopping_rounds = 20,
  nrounds = 1000,
  callbacks = list(cb.print.evaluation(period = 100))
)

# Mean squared error: 0.0515
average_loss(fit, X = X_valid, y = y_valid)

Interaction statistics

Let’s calculate different H-statistics via hstats():

# 4 seconds on simple laptop - a random forest will take 2 minutes
set.seed(782)
system.time(
  s <- hstats(fit, X = X_train)  #, approx = TRUE: twice as fast
)
s
# H^2 (normalized)
# [1] 0.10

plot(s)  # Or summary(s) for numeric output

# Save for later
# saveRDS(s, file = "h_statistics.rds")

Interpretation

  • H2: About 10% of prediction variability is unexplained by the sum of all main effects. The interaction effects seem to be important.
  • Hj2: The strongest overall interactions are associated with “log_ocean” (logarithmic distance to the ocean): About 6% of prediction variability can be attributed to its interactions.
  • Hjk2: About 8.5% of the joint effect variability of “log_ocean” and “age” comes from their pairwise interaction.

Remarks

  1. Pairwise statistics Hjk2 are calculated only for the features with strong overall interactions Hj2.
  2. H-statistics need to repeatedly calculate predictions on up to n2 rows. That is why {hstats} samples 500 rows by default. To get more robust results, increase this value at the price of slower run time.
  3. Pairwise statistics Hjk2 measures interaction strength relative to the combined effect of the two features. This does not necessarily show which interactions are strongest in absolute numbers. To do so, we can study unnormalized statistics:
plot(h2_pairwise(s, normalize = FALSE, squared = FALSE))

Since distance to the ocean and age have high values in overall interaction strength, it is not surprising that a strong relative pairwise interaction is translated into a strong absolute one.

Note: {hstats} can crunch three-way interaction statistics Hjkl2 as well. To calculate them for m features with strongest overall interactions, set threeway_m = m.

Describe interaction

Let’s study different plots to understand how the strong interaction between distance to the ocean and age looks like. We will check the following three visualizations.

  1. Stratified PDP
  2. Two-dimensional PDP (once as heatmap, once by representing the second variable on the color scale)
  3. Centered ICE plot with colors

They all reveal a substantial interaction between the two variables in the sense that the age effect gets weaker the closer to the ocean. Note that numeric BY features are automatically binned into quartile groups.

plot(partial_dep(fit, v = "age", X = X_train, BY = "log_ocean"), show_points = FALSE)

pd <- partial_dep(fit, v = c("age", "log_ocean"), X = X_train, grid_size = 1000)
plot(pd)
plot(pd, d2_geom = "line", show_points = FALSE)

ic <- ice(fit, v = "age", X = X_train, BY = "log_ocean")
plot(ic, center = TRUE)

Variable importance

In the spirit of [1], and related to [4], we can extract from the “hstats” objects a partial dependence based variable importance measure. It measures not only the main effect strength (see [4]), but also all its interaction effects. It is rather experimental, so use it with care (details in the section “Background”):

plot(pd_importance(s))

# Compared with four times repeated permutation importance regarding MSE
set.seed(10)
plot(perm_importance(fit, X = X_valid, y = y_valid))

Permutation importance returns the same order in this case:

DALEX

The main functions work smoothly on DALEX explainers:

library(hstats)
library(DALEX)
library(ranger)

set.seed(1)

fit <- ranger(Sepal.Length ~ ., data = iris)
ex <- DALEX::explain(fit, data = iris[, -1], y = iris[, 1])

s <- hstats(ex)
s  # 0.054
plot(s)

# Strongest relative interaction
plot(ice(ex, v = "Sepal.Width", BY = "Petal.Width"), center = TRUE)
plot(partial_dep(ex, v = "Sepal.Width", BY = "Petal.Width"), show_points = FALSE)
plot(partial_dep(ex, v = c("Sepal.Width", "Petal.Width"), grid_size = 200))

perm_importance(ex)

# Permutation importance
# Petal.Length  Petal.Width  Sepal.Width      Species 
#   0.59836442   0.11625137   0.07966910   0.03982554 

Strongest relative interaction shown as ICE plot.

Multivariate responses

{hstats} works also with multivariate output, see examples for probabilistic classification with

  • ranger,
  • LightGBM, and
  • XGBoost.

Common preparation

library(hstats)

ix <- c(1:40, 51:90, 101:140)
train <- iris[ix, ]
valid <- iris[-ix, ]

X_train <- data.matrix(train[, -5])
X_valid <- data.matrix(valid[, -5])
y_train <- train[[5]]
y_valid <- valid[[5]]

ranger

library(ranger)

set.seed(1)

fit <- ranger(Species ~ ., data = train, probability = TRUE)
average_loss(fit, X = valid, y = "Species", loss = "mlogloss")  # 0.02

perm_importance(fit, X = iris, y = "Species", loss = "mlogloss")

(s <- hstats(fit, X = iris[, -5]))
plot(s, normalize = FALSE, squared = FALSE)

ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width") |> 
  plot(center = TRUE)

LightGBM

Note: Versions < 4.0.0 require passing reshape = TRUE to the prediction function.

library(lightgbm)

set.seed(1)

params <- list(objective = "multiclass", num_class = 3, learning_rate = 0.2)
dtrain <- lgb.Dataset(X_train, label = as.integer(y_train) - 1)
dvalid <- lgb.Dataset(X_valid, label = as.integer(y_valid) - 1)

fit <- lgb.train(
  params = params, 
  data = dtrain,
  valids = list(valid = dvalid),
  early_stopping_rounds = 20,
  nrounds = 1000
)

# mlogloss: 9.331699e-05
average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss")

perm_importance(fit, X = X_valid, y = y_valid, loss = "mlogloss", m_rep = 100)
# Permutation importance regarding mlogloss
# Petal.Length  Petal.Width  Sepal.Width Sepal.Length 
#  2.624241332  1.011168660  0.082477177  0.009757393

partial_dep(fit, v = "Petal.Length", X = X_train) |> 
  plot(show_points = FALSE)

ice(fit, v = "Petal.Length", X = X_train) |> 
  plot(alpha = 0.05)

# Interaction statistics, including three-way stats
(H <- hstats(fit, X = X_train, threeway_m = 4))  
# 0.3010446 0.4167927 0.1623982

plot(H, ncol = 1)

XGBoost

Mind the reshape = TRUE sent to the prediction function.

library(xgboost)

set.seed(1)

params <- list(objective = "multi:softprob", num_class = 3, learning_rate = 0.2)
dtrain <- xgb.DMatrix(X_train, label = as.integer(y_train) - 1)
dvalid <- xgb.DMatrix(X_valid, label = as.integer(y_valid) - 1)

fit <- xgb.train(
  params = params, 
  data = dtrain,
  watchlist = list(valid = dvalid),
  early_stopping_rounds = 20,
  nrounds = 1000
)

# We need to pass reshape = TRUE to get a beautiful matrix
predict(fit, head(X_train, 2), reshape = TRUE)
#           [,1]        [,2]         [,3]
# [1,] 0.9974016 0.002130089 0.0004682819
# [2,] 0.9971375 0.002129525 0.0007328897

# mlogloss: 0.006689544
average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE)

partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |> 
  plot(show_points = FALSE)

ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |> 
  plot(alpha = 0.05)

perm_importance(
  fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE, m_rep = 100
)
# Permutation importance regarding mlogloss
# Petal.Length  Petal.Width Sepal.Length  Sepal.Width 
#  1.731532873  0.276671377  0.009158659  0.005717263 

# Interaction statistics including three-way stats
(H <- hstats(fit, X = X_train, reshape = TRUE, threeway_m = 4))  
# 0.02714399 0.16067364 0.11606973

plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1)

Non-probabilistic classification

When predictions are factor levels, {hstats} uses internal one-hot-encoding. Usually, probabilistic classification makes more sense though.

library(ranger)

set.seed(1)

fit <- ranger(Species ~ ., data = train)
average_loss(fit, X = valid, y = "Species", loss = "classification_error")  # 0

perm_importance(fit, X = valid, y = "Species", loss = "classification_error")
# Petal.Width Petal.Length Sepal.Length  Sepal.Width 
#       0.350        0.275        0.000        0.000 
       
(s <- hstats(fit, X = train[, -5]))
# H^2 (normalized)
#     setosa versicolor  virginica 
# 0.09885417 0.46151300 0.23238800 

plot(s, normalize = FALSE, squared = FALSE)

partial_dep(fit, v = "Petal.Length", X = train, BY = "Petal.Width") |> 
  plot(show_points = FALSE)

Meta-learning packages

Here, we provide examples for {tidymodels}, {caret}, and {mlr3}.

tidymodels

library(hstats)
library(tidymodels)

set.seed(1)

iris_recipe <- iris |> 
  recipe(Sepal.Length ~ .)

reg <- linear_reg() |>
  set_engine("lm")
  
iris_wf <- workflow() |>
  add_recipe(iris_recipe) |>
  add_model(reg)

fit <- iris_wf |>
  fit(iris)
  
s <- hstats(fit, X = iris[, -1])
s # 0 -> no interactions
plot(partial_dep(fit, v = "Petal.Width", X = iris))

imp <- perm_importance(fit, X = iris, y = "Sepal.Length")
imp
# Permutation importance
# Petal.Length      Species  Petal.Width  Sepal.Width 
#   4.39197781   0.35038891   0.11966090   0.09604322 

plot(imp)

caret

library(hstats)
library(caret)

set.seed(1)

fit <- train(
  Sepal.Length ~ ., 
  data = iris, 
  method = "lm", 
  tuneGrid = data.frame(intercept = TRUE),
  trControl = trainControl(method = "none")
)

h2(hstats(fit, X = iris[, -1]))  # 0

plot(ice(fit, v = "Petal.Width", X = iris), center = TRUE)
plot(perm_importance(fit, X = iris, y = "Sepal.Length"))

mlr3

library(hstats)
library(mlr3)
library(mlr3learners)

set.seed(1)

# Probabilistic classification
task_iris <- TaskClassif$new(id = "class", backend = iris, target = "Species")
fit_rf <- lrn("classif.ranger", predict_type = "prob")
fit_rf$train(task_iris)
s <- hstats(fit_rf, X = iris[, -5])
plot(s)

# Permutation importance (probabilistic using multi-logloss)
p <- perm_importance(
  fit_rf, X = iris, y = "Species", loss = "mlogloss", predict_type = "prob"
)
plot(p)

# Non-probabilistic using classification error
perm_importance(fit_rf, X = iris, y = "Species", loss = "classification_error") |> 
  plot()

Background

In [1], Friedman and Popescu introduced different statistics to measure interaction strength based on partial dependence functions. Closely following their notation, we will summarize the main ideas.

Partial dependence

Let F : Rp → R denote the prediction function that maps the p-dimensional feature vector x = (x1,…,xp) to its prediction. Furthermore, let Fs(xs) = Ex\s(F(xs,x\s)) be the partial dependence function of F on the feature subset xs, where s ⊆ {1, …, p}, as introduced in [2]. Here, the expectation runs over the joint marginal distribution of features x\s not in xs.

Given data, Fs(xs) can be estimated by the empirical partial dependence function

$$ \hat F_s(\boldsymbol x_s) = \frac{1}{n} \sum_{i = 1}^n F(\boldsymbol x_s, \boldsymbol x_{i \setminus s}), $$

where xi \ s, i = 1, …, n, are the observed values of x\s.

A partial dependence plot (PDP) plots the values of s(xs) over a grid of evaluation points xs. Its disaggregated version is called individual conditional expectation (ICE), see [7].

Overall interaction strength

If there are no interactions involving xj, we can decompose the prediction function F into the sum of the partial dependence Fj on xj and the partial dependence F\j on all other features x\j, i.e.,

F(x) = Fj(xj) + F\j(x\j).

Correspondingly, Friedman and Popescu’s statistic of overall interaction strength of xj is given by

$$ H_{j}^2 = \frac{\frac{1}{n} \sum_{i = 1}^n\big[F(\boldsymbol x_i) - \hat F_j(x_{ij}) - \hat F_{\setminus j}(\boldsymbol x_{i\setminus j})\big]^2}{\frac{1}{n} \sum_{i = 1}^n\big[F(\boldsymbol x_i)\big]^2}. $$

Remarks

  1. Partial dependence functions (and F) are all centered to mean 0.
  2. Partial dependence functions (and F) are evaluated over the data distribution. This is different to partial dependence plots, where one uses a fixed grid.
  3. Weighted versions follow by replacing all arithmetic means by corresponding weighted means.
  4. Multivariate predictions can be treated in a component-wise manner.
  5. Due to (typically undesired) extrapolation effects of partial dependence functions, depending on the model, values above 1 may occur.
  6. Hj2 = 0 means there are no interactions associated with xj. The higher the value, the more prediction variability comes from interactions with xj.
  7. Since the denominator is the same for all features, the values of the test statistics can be compared across features.

Pairwise interaction strength

Again following [1], if there are no interaction effects between features xj and xk, their two-dimensional partial dependence function Fjk can be written as the sum of the univariate partial dependencies, i.e.,

Fjk(xj,xk) = Fj(xj) + Fk(xk).

Correspondingly, Friedman and Popescu’s statistic of pairwise interaction strength between xj and xk is defined as

$$ H_{jk}^2 = \frac{A_{jk}}{\frac{1}{n} \sum_{i = 1}^n\big[\hat F_{jk}(x_{ij}, x_{ik})\big]^2} $$

where

$$ A_{jk} = \frac{1}{n} \sum_{i = 1}^n\big[\hat F_{jk}(x_{ij}, x_{ik}) - \hat F_j(x_{ij}) - \hat F_k(x_{ik})\big]^2. $$

Remarks

  1. Remarks 1 to 4 of Hj2 also apply here.
  2. Hjk2 = 0 means there are no interaction effects between xj and xk. The larger the value, the more of the joint effect of the two features comes from the interaction.
  3. Since the denominator differs between variable pairs, unlike Hj, this test statistic is difficult to compare between variable pairs. If both main effects are very weak, a negligible interaction can get a high Hjk2.

Modification

To be better able to compare pairwise interaction strength across variable pairs, and to overcome the problem mentioned in the last remark, we suggest as alternative the unnormalized test statistic on the scale of the predictions, i.e., $\sqrt{A_{jk}}$.

Furthermore, instead of focusing on pairwise calculations for the most important features, we can select features with strongest overall interactions.

Three-way interactions

[1] also describes a test statistic to measure three-way interactions: in case there are no three-way interactions between features xj, xk and xl, their three-dimensional partial dependence function Fjkl can be decomposed into lower order terms:

Fjkl(xj,xk,xl) = Bjkl − Cjkl

with

Bjkl = Fjk(xj,xk) + Fjl(xj,xl) + Fkl(xk,xl)

and

Cjkl = Fj(xj) + Fk(xk) + Fl(xl).

The squared and scaled difference between the two sides of the equation leads to the statistic

$$ H_{jkl}^2 = \frac{\frac{1}{n} \sum_{i = 1}^n \big[\hat F_{jkl}(x_{ij}, x_{ik}, x_{il}) - B^i_{jkl} + C^i_{jkl}\big]^2}{\frac{1}{n} \sum_{i = 1}^n \hat F_{jkl}(x_{ij}, x_{ik}, x_{il})^2}, $$

where

Bjkli = jk(xij,xik) + jl(xij,xil) + kl(xik,xil)

and

Cjkli = j(xij) + k(xik) + l(xil).

Similar remarks as for Hjk2 apply.

Total interaction strength of all variables together

If the model is additive in all features (no interactions), then

$$ F(\boldsymbol x) = \sum_{j}^{p} F_j(x_j), $$

i.e., the (centered) predictions can be written as the sum of the (centered) main effects.

To measure the relative amount of variability unexplained by all main effects, we can therefore study the test statistic of total interaction strength

$$ H^2 = \frac{\frac{1}{n} \sum_{i = 1}^n \left[F(\boldsymbol x_i) - \sum_{j = 1}^p\hat F_j(x_{ij})\right]^2}{\frac{1}{n} \sum_{i = 1}^n\left[F(\boldsymbol x_i)\right]^2}. $$

A value of 0 means there are no interaction effects at all. Due to (typically undesired) extrapolation effects of partial dependence functions, depending on the model, values above 1 may occur.

In [5], 1 − H2 is called additivity index. A similar measure using accumulated local effects is discussed in [6].

Workflow

Calculation of all Hj2 requires O(n2p) predictions, while calculating of all pairwise Hjk requires O(n2p2 predictions. Therefore, we suggest to reduce the workflow in two important ways:

  1. Evaluate the statistics only on a subset of the data, e.g., on n′ = 500 observations.
  2. Calculate Hj2 for all features. Then, select a small number $m = O(\sqrt{p})$ of features with highest Hj2 and do pairwise calculations only on this subset.

This leads to a total number of O(n2p) predictions. If also three-way interactions are to be studied, m should be of the order p1/3.

Variable importance (experimental)

[4] proposed the standard deviation of the partial dependence function as a measure of variable importance (for continuous predictors).

Since the partial dependence function suppresses interaction effects, we propose a different measure in the spirit of the interaction statistics above: If xj has no effects, the (centered) prediction function F equals the (centered) partial dependence F\j on all other features x\j, i.e.,

F(x) = F\j(x\j).

Therefore, the following measure of variable importance follows:

$$ PDI_{j} = \frac{\frac{1}{n} \sum_{i = 1}^n\big[F(\boldsymbol x_i) - \hat F_{\setminus j}(\boldsymbol x_{i\setminus j})\big]^2}{\frac{1}{n} \sum_{i = 1}^n\big[F(\boldsymbol x_i)\big]^2}. $$

It differs from Hj2 only by not subtracting the main effect of the j-th feature in the numerator. It can be read as the proportion of prediction variability unexplained by all other features. As such, it measures variable importance of the j-th feature, including its interaction effects.

References

  1. Friedman, Jerome H., and Bogdan E. Popescu. Predictive Learning via Rule Ensembles. The Annals of Applied Statistics 2, no. 3 (2008): 916-54.
  2. Friedman, Jerome H. Greedy Function Approximation: A Gradient Boosting Machine. Annals of Statistics 29, no. 5 (2001): 1189-1232.
  3. Mayer, Michael, Steven C. Bourassa, Martin Hoesli, and Donato Scognamiglio. Machine Learning Applications to Land and Structure Valuation.” Journal of Risk and Financial Management 15, no. 5 (2022): 193.
  4. Greenwell, Brandon M., Bradley C. Boehmke, and Andrew J. McCarthy. A Simple and Effective Model-Based Variable Importance Measure. Arxiv (2018).
  5. Żółkowski, Artur, Mateusz Krzyziński, and Paweł Fijałkowski. Methods for extraction of interactions from predictive models. Undergraduate thesis. Faculty of Mathematics and Information Science, Warsaw University of Technology (2023).
  6. Molnar, Christoph, Giuseppe Casalicchio, and Bernd Bischl”. Quantifying Model Complexity via Functional Decomposition for Better Post-hoc Interpretability, in Machine Learning and Knowledge Discovery in Databases, Springer International Publishing (2020): 193-204.
  7. Goldstein, Alex, Adam Kapelner, Justin Bleich, and Emil Pitkin. Peeking inside the black box: Visualizing statistical learning with plots of individual conditional expectation. Journal of Computational and Graphical Statistics, 24, no. 1 (2015): 44-65.
  8. Apley, Daniel W., and Jingyu Zhu, Visualizing the Effects of Predictor Variables in Black Box Supervised Learning Models, Arxiv (2016).