Main function of the package. Calculates exact permutation SHAP values with respect to a background dataset.
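To make "exact permutation SHAP values with respect to a background dataset" concrete, here is a brute-force sketch in base R. For each feature it averages the marginal contribution over all p! feature orderings. The value of a coalition is the mean prediction over bg_X after that coalition's columns are overwritten with the values of the row to be explained. The helper names (all_perms, coalition_value, perm_shap_sketch) are hypothetical. The sketch assumes one numeric prediction per row and unweighted background data, and it is not meant to show how permshap() organizes its computations internally.

```r
# Hypothetical helpers illustrating the definition of exact permutation SHAP.
# Assumes pred_fun(X) returns one numeric prediction per row (K = 1) and that
# "switched off" features are integrated out by an unweighted mean over bg_X.

# All orderings of a vector (p! of them)
all_perms <- function(v) {
  if (length(v) <= 1L) return(list(v))
  out <- list()
  for (i in seq_along(v)) {
    for (p in all_perms(v[-i])) out[[length(out) + 1L]] <- c(v[i], p)
  }
  out
}

# Value of coalition S: features in S are taken from the one-row data.frame x,
# all other features keep their values from the background data
coalition_value <- function(pred_fun, x, bg_X, S) {
  Z <- bg_X
  for (j in S) Z[[j]] <- x[[j]]  # length-1 value is recycled to all rows
  mean(pred_fun(Z))
}

# Average each feature's marginal contribution over all orderings
perm_shap_sketch <- function(pred_fun, x, bg_X, features) {
  perms <- all_perms(features)
  phi <- setNames(numeric(length(features)), features)
  for (perm in perms) {
    S <- character(0)
    v_prev <- coalition_value(pred_fun, x, bg_X, S)
    for (j in perm) {
      S <- c(S, j)
      v_new <- coalition_value(pred_fun, x, bg_X, S)
      phi[j] <- phi[j] + v_new - v_prev
      v_prev <- v_new
    }
  }
  phi / length(perms)
}

# Usage: explain row 51 of iris with rows 1 to 50 as background
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width, data = iris)
phi <- perm_shap_sketch(
  function(X) predict(fit, X),
  iris[51, ],
  iris[1:50, ],
  c("Sepal.Width", "Petal.Length", "Petal.Width")
)
phi
```

The cost grows with p! orderings times the number of background rows, which is why small background datasets and moderate numbers of features matter for exact computation.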

permshap(object, ...)

# S3 method for default
permshap(
  object,
  X,
  bg_X,
  pred_fun = stats::predict,
  feature_names = colnames(X),
  bg_w = NULL,
  parallel = FALSE,
  parallel_args = NULL,
  verbose = TRUE,
  ...
)

# S3 method for ranger
permshap(
  object,
  X,
  bg_X,
  pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
  feature_names = colnames(X),
  bg_w = NULL,
  parallel = FALSE,
  parallel_args = NULL,
  verbose = TRUE,
  ...
)

# S3 method for Learner
permshap(
  object,
  X,
  bg_X,
  pred_fun = NULL,
  feature_names = colnames(X),
  bg_w = NULL,
  parallel = FALSE,
  parallel_args = NULL,
  verbose = TRUE,
  ...
)

Arguments

object

Fitted model object.

...

Additional arguments passed to pred_fun(object, X, ...).

X

\((n \times p)\) matrix or data.frame with rows to be explained. The columns should only represent model features, not the response (but see feature_names on how to overrule this).

bg_X

Background data used to integrate out "switched off" features, often a subset of the training data (typically 50 to 500 rows). It should contain the same columns as X. In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value.

pred_fun

Prediction function of the form function(object, X, ...), providing \(K \ge 1\) numeric predictions per row. Its first argument represents the model object, its second argument a data structure like X. Additional (named) arguments are passed via the ... argument. The default, stats::predict(), will work in most cases.
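The default stats::predict() returns predictions on the scale that the model's predict method uses by default; for a glm this is the link scale (log-odds for a logistic regression). A minimal sketch of a custom pred_fun that explains a logistic regression on the probability scale instead, using mtcars as toy data (the names fit_glm and pred_prob are made up for illustration):

```r
# Toy logistic regression (mtcars is only used for illustration)
fit_glm <- glm(am ~ wt, data = mtcars, family = binomial)

# Custom pred_fun of the form function(object, X, ...): predict.glm() defaults
# to type = "link", so request the probability scale explicitly
pred_prob <- function(m, X, ...) stats::predict(m, X, type = "response", ...)

p <- pred_prob(fit_glm, mtcars)
head(p)

# It would then be passed to permshap(), e.g.:
# permshap(fit_glm, mtcars[1:2, "wt", drop = FALSE], bg_X = mtcars, pred_fun = pred_prob)
```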

feature_names

Optional vector of column names in X used to calculate SHAP values. By default, this equals colnames(X). Not supported if X is a matrix.

bg_w

Optional vector of case weights for each row of bg_X.
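One use of case weights is to compress a background dataset with repeated rows: keep each distinct row once and pass its multiplicity as weight. The sketch below, assuming bg_w enters the averages over bg_X as a weighted mean, checks that this gives the same average prediction as the uncompressed data. The row indices and weights are arbitrary choices for illustration.

```r
fit <- lm(Sepal.Length ~ Petal.Length + Species, data = iris)

# Background data with repeated rows ...
bg_full <- iris[c(1, 1, 1, 51, 51, 101), ]
# ... versus each distinct row once, with its multiplicity as case weight
bg_unique <- iris[c(1, 51, 101), ]
bg_w <- c(3, 2, 1)

# Unweighted mean on the full data vs. weighted mean on the compressed data
m_full <- mean(predict(fit, bg_full))
m_weighted <- weighted.mean(predict(fit, bg_unique), w = bg_w)
c(m_full, m_weighted)
```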

parallel

If TRUE, use parallel foreach::foreach() to loop over rows to be explained. A parallel backend must be registered beforehand, e.g., via the doFuture package; see the README for an example. Parallelization automatically disables the progress bar.

parallel_args

Named list of arguments passed to foreach::foreach(). Ideally, this is NULL (default). Only relevant if parallel = TRUE. Example on Windows: if object is a GAM fitted with package mgcv, then one might need to set parallel_args = list(.packages = "mgcv").

verbose

Set to FALSE to suppress the progress bar.

Value

An object of class "permshap" with the following components:

  • S: \((n \times p)\) matrix with SHAP values or, if the model output has dimension \(K > 1\), a list of \(K\) such matrices.

  • X: Same as input argument X.

  • baseline: Vector of length K representing the average prediction on the background data.
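Exact SHAP values satisfy efficiency: each row of S adds up to that row's prediction minus baseline. For a linear model without interactions, the exact values even have a closed form, \(\beta_j (x_j - \bar{x}_j)\), with \(\bar{x}_j\) the background mean of feature \(j\). The base-R sketch below builds such an S by hand (it does not call permshap()) and checks the identity; for this model, permshap(fit, X, bg_X = bg_X) should return the same matrix up to floating-point error.

```r
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width, data = iris)
feats <- c("Sepal.Width", "Petal.Length", "Petal.Width")
X <- iris[1:5, feats]
bg_X <- iris[51:150, ]

# Closed form for an additive linear model: beta_j * (x_j - background mean of x_j)
centered <- sweep(as.matrix(X), 2, colMeans(bg_X[feats]))
S <- sweep(centered, 2, coef(fit)[feats], "*")

# Average prediction on the background data
baseline <- mean(predict(fit, bg_X))

# Efficiency: row sums of S plus baseline reproduce the predictions
all.equal(unname(rowSums(S) + baseline), unname(predict(fit, X)))
```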

Methods (by class)

  • permshap(default): Default permutation SHAP method.

  • permshap(ranger): Permutation SHAP method for "ranger" models, see the README for an example.

  • permshap(Learner): Permutation SHAP method for "mlr3" models, see the README for an example.

Examples

# MODEL ONE: Linear regression
fit <- lm(Sepal.Length ~ ., data = iris)

# Select rows to explain (only feature columns)
X_explain <- iris[1:2, -1]

# Select small background dataset (could use all rows here because iris is small)
set.seed(1)
bg_X <- iris[sample(nrow(iris), 100), ]

# Calculate SHAP values
s <- permshap(fit, X_explain, bg_X = bg_X)
#> Exact permutation SHAP values
#> 
s
#> SHAP values of first 2 observations:
#>      Sepal.Width Petal.Length Petal.Width   Species
#> [1,]  0.21571169    -1.981893   0.3157855 0.5825284
#> [2,] -0.03223278    -1.981893   0.3157855 0.5825284

# MODEL TWO: Multi-response linear regression
fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
s <- permshap(fit, iris[1:4, 3:5], bg_X = bg_X)
#> Exact permutation SHAP values
#> 
s
#> SHAP values of first 2 observations:
#> $Sepal.Length
#>      Petal.Length Petal.Width  Species
#> [1,]    -2.165211 0.006007392 1.234918
#> [2,]    -2.165211 0.006007392 1.234918
#> 
#> $Sepal.Width
#>      Petal.Length Petal.Width  Species
#> [1,]   -0.3696749  -0.6246925 1.315597
#> [2,]   -0.3696749  -0.6246925 1.315597
#> 

# Non-feature columns can be dropped via 'feature_names'
s <- permshap(
  fit,
  iris[1:4, ],
  bg_X = bg_X,
  feature_names = c("Petal.Length", "Petal.Width", "Species")
)
#> Exact permutation SHAP values
#> 
s
#> SHAP values of first 2 observations:
#> $Sepal.Length
#>      Petal.Length Petal.Width  Species
#> [1,]    -2.165211 0.006007392 1.234918
#> [2,]    -2.165211 0.006007392 1.234918
#> 
#> $Sepal.Width
#>      Petal.Length Petal.Width  Species
#> [1,]   -0.3696749  -0.6246925 1.315597
#> [2,]   -0.3696749  -0.6246925 1.315597
#>