R/light_breakdown.R
light_breakdown.Rd
Calculates sequential additive variable contributions (approximate SHAP) to the prediction of a single observation; see Gosiewska and Biecek (2019) in the reference and the details below.
light_breakdown(x, ...)
# Default S3 method
light_breakdown(x, ...)
# S3 method for class 'flashlight'
light_breakdown(
  x,
  new_obs,
  data = x$data,
  by = x$by,
  v = NULL,
  visit_strategy = c("importance", "permutation", "v"),
  n_max = Inf,
  n_perm = 20,
  seed = NULL,
  use_linkinv = FALSE,
  description = TRUE,
  digits = 2,
  ...
)
# S3 method for class 'multiflashlight'
light_breakdown(x, ...)
x
An object of class "flashlight" or "multiflashlight".
...
Further arguments passed to prettyNum() to format numbers in description text.
new_obs
One single new observation to calculate variable attribution for. Needs to be a data.frame of the same structure as data.
data
An optional data.frame.
by
An optional vector of column names used to restrict data to rows whose "by" values equal those of new_obs.
v
Vector of variable names to assess contribution for. Defaults to all except those specified by "y", "w" and "by".
visit_strategy
In what sequence should variables be visited? By "importance", by n_perm random "permutation" sequences, or as given in "v" (see Details).
n_max
Maximum number of rows in data to consider in the reference data. Set to a lower value if data is large.
n_perm
Number of permutations of random visit sequences. Only used if visit_strategy = "permutation".
seed
An integer random seed used to shuffle rows if n_max is smaller than the number of rows in data.
use_linkinv
Should retransformation function be applied? Default is FALSE.
description
Should descriptions be added? Default is TRUE.
digits
Passed to prettyNum() to format numbers in description text.
An object of class "light_breakdown" with the following elements:
data
A tibble with results.
by
Same as input by.
The breakdown algorithm works as follows: First, the visit order \((x_1, ..., x_m)\) of the variables v is specified.
Then, in the query data, the column \(x_1\) is set to the value of \(x_1\) of the single observation new_obs to be explained.
The change in the (weighted) average prediction on data measures the contribution of \(x_1\) to the prediction of new_obs.
This procedure is iterated over all \(x_i\) until, eventually, all rows in data are identical to new_obs.
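The steps above can be illustrated with a small base R sketch. It is not the flashlight implementation: the helper name breakdown_sketch and the fixed visit order are assumptions made for illustration, case weights are ignored (plain averages are used), and no "by" filtering or row subsampling is done.

```r
# Illustrative sketch only; breakdown_sketch() is a hypothetical helper,
# not part of the flashlight package.
fit <- lm(Sepal.Length ~ Species + Petal.Length, data = iris)
new_obs <- iris[1, ]
v <- c("Petal.Length", "Species")  # assumed fixed visit order

breakdown_sketch <- function(fit, data, new_obs, v) {
  # Average prediction on the untouched reference data
  baseline <- mean(predict(fit, newdata = data))
  contrib <- setNames(numeric(length(v)), v)
  current <- baseline
  for (x in v) {
    # Set the whole column x to the value of new_obs
    data[[x]] <- rep(new_obs[[x]], nrow(data))
    updated <- mean(predict(fit, newdata = data))
    # Change in average prediction = contribution of x
    contrib[x] <- updated - current
    current <- updated
  }
  list(baseline = baseline, contributions = contrib, prediction = current)
}

res <- breakdown_sketch(fit, iris, new_obs, v)
res$contributions

# Baseline plus all contributions reproduces the prediction for new_obs
all.equal(
  res$baseline + sum(res$contributions),
  unname(predict(fit, newdata = new_obs))
)
```

Because every visited column is overwritten for all rows, the final average prediction equals the prediction for new_obs once all model variables have been visited, so the contributions sum to the difference between that prediction and the baseline.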
A complication with this approach is that the visit order is relevant, at least for non-additive models. Ideally, the algorithm would be repeated for all possible permutations of v and its results averaged per variable.
This is basically what SHAP values do; see the reference below for an explanation.
Unfortunately, there is no efficient way to do this in a model-agnostic way.
We offer two visit strategies to approximate SHAP:
"importance": Uses the short-cut described in the reference below: the variables are sorted by the size of their contribution, calculated in the same way as in the breakdown algorithm but without iteration, i.e., starting from the original query data for each variable \(x_i\).
"permutation": Averages contributions from a small number of random permutations of v.
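The two strategies can be sketched in base R as follows. This is an illustration under assumptions, not the package code: the helper names avg_pred and sequential_contrib are hypothetical, sorting by decreasing absolute one-step contribution is one possible reading of "size", and case weights are ignored. A model with an interaction is used so that the visit order actually matters.

```r
# Illustrative sketch only; helper names are hypothetical.
fit <- lm(Sepal.Length ~ Species * Petal.Length, data = iris)
new_obs <- iris[1, ]
v <- c("Species", "Petal.Length")

avg_pred <- function(fit, data) mean(predict(fit, newdata = data))

# Sequential contributions for one given visit order
sequential_contrib <- function(fit, data, new_obs, visit_order) {
  out <- setNames(numeric(length(visit_order)), visit_order)
  current <- avg_pred(fit, data)
  for (x in visit_order) {
    data[[x]] <- rep(new_obs[[x]], nrow(data))
    updated <- avg_pred(fit, data)
    out[x] <- updated - current
    current <- updated
  }
  out
}

# "importance"-like order: one-step change per variable, each starting
# from the original data, sorted by decreasing absolute size (assumption)
one_step <- vapply(v, function(x) {
  d <- iris
  d[[x]] <- rep(new_obs[[x]], nrow(d))
  avg_pred(fit, d) - avg_pred(fit, iris)
}, numeric(1))
importance_order <- names(sort(abs(one_step), decreasing = TRUE))
contrib_importance <- sequential_contrib(fit, iris, new_obs, importance_order)

# "permutation"-like: average contributions over n_perm random visit orders
set.seed(1)
n_perm <- 20
perm_runs <- replicate(
  n_perm,
  sequential_contrib(fit, iris, new_obs, sample(v))[v]
)
contrib_permutation <- rowMeans(perm_runs)

contrib_importance
contrib_permutation
```

Both variants distribute the same total, namely the prediction for new_obs minus the average prediction on the reference data; they differ only in how that total is split across variables.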
Note that the minimum required elements in the (multi-)flashlight are "predict_function", "model", and "data". The latter can also be passed directly to light_breakdown(). By default, no retransformation function is applied.
light_breakdown(default): Default method not implemented yet.
light_breakdown(flashlight): Variable attribution to single observation for a flashlight.
light_breakdown(multiflashlight): Variable attribution to single observation for a multiflashlight.
A. Gosiewska and P. Biecek (2019). iBreakDown: Uncertainty of model explanations for non-additive predictive models. arXiv.
library(flashlight)

# First model: two covariates
fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris)
fl_part <- flashlight(
  model = fit_part, label = "part", data = iris, y = "Sepal.Length"
)
plot(light_breakdown(fl_part, new_obs = iris[1, ]))

# Second model: all covariates
fit_full <- lm(Sepal.Length ~ ., data = iris)
fl_full <- flashlight(
  model = fit_full, label = "full", data = iris, y = "Sepal.Length"
)

# Compare both models via a multiflashlight
fls <- multiflashlight(list(fl_part, fl_full))
plot(light_breakdown(fls, new_obs = iris[1, ]))