Main function of the package. Calculates exact permutation SHAP values with respect to a background dataset.
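To make "exact permutation SHAP with respect to a background dataset" concrete, here is a naive base-R sketch. It is an illustration only, not the package's implementation. It assumes that `pred_fun(X)` returns one numeric prediction per row and that "switched off" features are integrated out by averaging predictions over `bg_X`. The helper names `perms()`, `coalition_value()` and `permshap_sketch()` are hypothetical. Each feature's value is its marginal contribution averaged over all \(p!\) feature orderings. A practical implementation would evaluate each of the \(2^p\) coalitions only once instead of once per ordering.

```r
# Naive exact permutation SHAP for a single row (illustration only).
# Assumptions: pred_fun(X) returns one numeric prediction per row of X,
# and "switched off" features are integrated out by averaging over bg_X.

# All orderings of a vector (p! of them; only feasible for small p).
perms <- function(v) {
  if (length(v) <= 1) return(list(v))
  out <- list()
  for (i in seq_along(v)) {
    for (rest in perms(v[-i])) out[[length(out) + 1]] <- c(v[i], rest)
  }
  out
}

# Value of coalition S: average prediction when the features in S are taken
# from the row x and all other features come from the background rows.
coalition_value <- function(pred_fun, x, bg_X, S) {
  Z <- bg_X
  for (j in S) Z[[j]] <- x[[j]]
  mean(pred_fun(Z))
}

# Average each feature's marginal contribution over all orderings.
permshap_sketch <- function(pred_fun, x, bg_X, features = colnames(x)) {
  orders <- perms(features)
  phi <- setNames(numeric(length(features)), features)
  for (ord in orders) {
    S <- character(0)
    v_prev <- coalition_value(pred_fun, x, bg_X, S)
    for (j in ord) {
      S <- c(S, j)
      v_new <- coalition_value(pred_fun, x, bg_X, S)
      phi[j] <- phi[j] + v_new - v_prev
      v_prev <- v_new
    }
  }
  phi / length(orders)
}

# Usage on the linear model from the examples
fit <- lm(Sepal.Length ~ ., data = iris)
pred_fun <- function(X) predict(fit, X)
set.seed(1)
bg_X <- iris[sample(nrow(iris), 100), -1]
x <- iris[1, -1]
phi <- permshap_sketch(pred_fun, x, bg_X)
baseline <- mean(pred_fun(bg_X))
```

Because the contributions along each ordering telescope, `sum(phi) + baseline` equals the prediction for `x`.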
permshap(object, ...)
# S3 method for default
permshap(
object,
X,
bg_X,
pred_fun = stats::predict,
feature_names = colnames(X),
bg_w = NULL,
parallel = FALSE,
parallel_args = NULL,
verbose = TRUE,
...
)
# S3 method for ranger
permshap(
object,
X,
bg_X,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
feature_names = colnames(X),
bg_w = NULL,
parallel = FALSE,
parallel_args = NULL,
verbose = TRUE,
...
)
# S3 method for Learner
permshap(
object,
X,
bg_X,
pred_fun = NULL,
feature_names = colnames(X),
bg_w = NULL,
parallel = FALSE,
parallel_args = NULL,
verbose = TRUE,
...
)
object
: Fitted model object.
...
: Additional arguments passed to pred_fun(object, X, ...).
X
: \((n \times p)\) matrix or data.frame with rows to be explained. The columns should only represent model features, not the response (but see feature_names on how to overrule this).
bg_X
: Background data used to integrate out "switched off" features, often a subset of the training data (typically 50 to 500 rows). It should contain the same columns as X. In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value.
pred_fun
: Prediction function of the form function(object, X, ...), providing \(K \ge 1\) numeric predictions per row. Its first argument represents the model object, its second argument a data structure like X. Additional (named) arguments are passed via the ... argument. The default, stats::predict(), will work in most cases.
feature_names
: Optional vector of column names in X used to calculate SHAP values. By default, this equals colnames(X). Not supported if X is a matrix.
bg_w
: Optional vector of case weights for each row of bg_X.
parallel
: If TRUE, use parallel foreach::foreach() to loop over the rows to be explained. A parallel backend must be registered beforehand, e.g., via the doFuture package; see the README for an example. Parallelization automatically disables the progress bar.
parallel_args
: Named list of arguments passed to foreach::foreach(). Ideally, this is NULL (default). Only relevant if parallel = TRUE. Example on Windows: if object is a GAM fitted with package mgcv, then one might need to set parallel_args = list(.packages = "mgcv").
verbose
: Set to FALSE to suppress the progress bar.
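The arguments pred_fun, bg_X and bg_w work together. The following base-R sketch shows a custom pred_fun that returns probabilities instead of the link-scale values that stats::predict() gives for a glm. The model, the background sample and the weights are hypothetical. The sketch *assumes* that with bg_w the baseline becomes the bg_w-weighted mean prediction on bg_X. The actual permshap() call is shown only as a comment, so the block runs without the package.

```r
# Hypothetical example: logistic regression on mtcars, explained on the
# probability scale instead of the link scale returned by predict() by default.
fit <- glm(am ~ wt + hp, data = mtcars, family = binomial())

# pred_fun follows the documented form function(object, X, ...)
pred_fun <- function(m, X, ...) stats::predict(m, X, type = "response", ...)

# Small background sample with the same feature columns as the rows to explain
set.seed(1)
bg_rows <- sample(nrow(mtcars), 20)
bg_X <- mtcars[bg_rows, c("wt", "hp")]
X <- mtcars[1:3, c("wt", "hp")]

# Hypothetical case weights, one per background row
bg_w <- runif(nrow(bg_X))

# Assumed definition of the baseline: weighted mean prediction on bg_X
preds_bg <- pred_fun(fit, bg_X)
baseline <- weighted.mean(preds_bg, bg_w)

# With the package installed, the corresponding call would be:
# permshap(fit, X, bg_X = bg_X, pred_fun = pred_fun, bg_w = bg_w)
```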
An object of class "permshap" with the following components:
S
: \((n \times p)\) matrix with SHAP values or, if the model output has dimension \(K > 1\), a list of \(K\) such matrices.
X
: Same as input argument X.
baseline
: Vector of length \(K\) representing the average prediction on the background data.
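The components S and baseline are linked by the efficiency property of Shapley values: each row of S sums to that row's prediction minus baseline. For an additive model such as the first linear regression in the examples, exact SHAP values have a closed form. Each value is the feature's term contribution at the explained row minus its average contribution over the background data. This sketch computes that closed form in base R with predict(..., type = "terms") and checks efficiency. It is a cross-check under these assumptions, not how permshap() computes its output, though for this model the values should match those printed in the first example.

```r
# Illustration with an additive linear model (no package required).
fit <- lm(Sepal.Length ~ ., data = iris)
X_explain <- iris[1:2, -1]
set.seed(1)
bg_X <- iris[sample(nrow(iris), 100), ]

# For an additive model, a feature's SHAP value is its term contribution at x
# minus the average term contribution over the background data.
terms_x  <- predict(fit, X_explain, type = "terms")
terms_bg <- predict(fit, bg_X, type = "terms")
S <- sweep(terms_x, 2, colMeans(terms_bg))

# baseline: average prediction on the background data (unweighted here)
baseline <- mean(predict(fit, bg_X))

# Efficiency: each row of S sums to prediction minus baseline
pred <- predict(fit, X_explain)
all.equal(unname(rowSums(S) + baseline), unname(pred))  # TRUE
```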
permshap(default)
: Default permutation SHAP method.
permshap(ranger)
: Permutation SHAP method for "ranger" models; see the README for an example.
permshap(Learner)
: Permutation SHAP method for "mlr3" models; see the README for an example.
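The ranger method differs from the default mainly in its pred_fun, function(m, X, ...) stats::predict(m, X, ...)$predictions, because ranger's predict() returns a list rather than a numeric vector. The sketch below shows that pattern with a hypothetical stand-in class, toy_model, so that it runs without ranger installed. The class and its predict method are invented for illustration.

```r
# Hypothetical stand-in for a model class whose predict() method returns a
# list with a `predictions` element (the pattern used by ranger), not a vector.
toy_model <- structure(list(slope = 2), class = "toy_model")
predict.toy_model <- function(object, data, ...) {
  list(predictions = object$slope * data$x)
}

X <- data.frame(x = c(1, 2, 3))

# The default pred_fun would pass on a list, which is not a numeric prediction
raw <- stats::predict(toy_model, X)

# A pred_fun of the form used by the ranger method extracts the numeric vector
pred_fun <- function(m, X, ...) stats::predict(m, X, ...)$predictions
preds <- pred_fun(toy_model, X)
```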
# MODEL ONE: Linear regression
fit <- lm(Sepal.Length ~ ., data = iris)
# Select rows to explain (only feature columns)
X_explain <- iris[1:2, -1]
# Select small background dataset (could use all rows here because iris is small)
set.seed(1)
bg_X <- iris[sample(nrow(iris), 100), ]
# Calculate SHAP values
s <- permshap(fit, X_explain, bg_X = bg_X)
#> Exact permutation SHAP values
s
#> SHAP values of first 2 observations:
#> Sepal.Width Petal.Length Petal.Width Species
#> [1,] 0.21571169 -1.981893 0.3157855 0.5825284
#> [2,] -0.03223278 -1.981893 0.3157855 0.5825284
# MODEL TWO: Multi-response linear regression
fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
s <- permshap(fit, iris[1:4, 3:5], bg_X = bg_X)
#> Exact permutation SHAP values
s
#> SHAP values of first 2 observations:
#> $Sepal.Length
#> Petal.Length Petal.Width Species
#> [1,] -2.165211 0.006007392 1.234918
#> [2,] -2.165211 0.006007392 1.234918
#>
#> $Sepal.Width
#> Petal.Length Petal.Width Species
#> [1,] -0.3696749 -0.6246925 1.315597
#> [2,] -0.3696749 -0.6246925 1.315597
#>
# Non-feature columns can be dropped via 'feature_names'
s <- permshap(
fit,
iris[1:4, ],
bg_X = bg_X,
feature_names = c("Petal.Length", "Petal.Width", "Species")
)
#> Exact permutation SHAP values
s
#> SHAP values of first 2 observations:
#> $Sepal.Length
#> Petal.Length Petal.Width Species
#> [1,] -2.165211 0.006007392 1.234918
#> [2,] -2.165211 0.006007392 1.234918
#>
#> $Sepal.Width
#> Petal.Length Petal.Width Species
#> [1,] -0.3696749 -0.6246925 1.315597
#> [2,] -0.3696749 -0.6246925 1.315597
#>
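The last two examples print identical SHAP values. This base-R sketch is consistent with that. Restricting iris[1:4, ] to the chosen feature_names yields the same data as iris[1:4, 3:5]. With \(K = 2\) responses, the predictions form an \((n \times 2)\) matrix, so the baseline, the average background prediction, is a vector of length 2. The block does not call permshap(). It only checks these two facts about the inputs.

```r
fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
set.seed(1)
bg_X <- iris[sample(nrow(iris), 100), ]
feature_names <- c("Petal.Length", "Petal.Width", "Species")

# Restricting all columns of iris[1:4, ] to feature_names gives the same
# data as iris[1:4, 3:5]
X_all <- iris[1:4, ]
same_data <- isTRUE(all.equal(X_all[feature_names], iris[1:4, 3:5]))

# With K = 2 responses, predict() returns an (n x 2) matrix, so the baseline
# (average background prediction) is a vector of length 2
baseline <- colMeans(predict(fit, bg_X))
```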