TL;DR, you can jump straight into the visuals and application with cheem::run_app(), but we suggest you read the introduction to get situated with the context first.
Non-linear models regularly make more accurate predictions than their linear counterparts. However, the number and complexity of their terms make them harder to interpret. Understanding how features (variables or predictors) influence predictions matters to a wide range of audiences. Attempts to bring interpretability to such complex models are an important aspect of eXplainable Artificial Intelligence (XAI).
Local explanations are one such tool used in XAI. They approximate feature importance in the vicinity of one instance (observation). In other words, they approximate the linear terms at the position of one in-sample or out-of-sample observation.
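The following minimal base-R sketch makes "an approximation of linear terms at the position of one observation" concrete. It is not how SHAP or cheem compute anything. It simply fits an ordinary linear model to points on a small grid around one instance of a hypothetical non-linear function, `f_nl(x1, x2) = x1^2 + 3 * x2`. Near the hypothetical instance `x0 = (2, 1)`, the fitted slopes recover the local gradient, 4 and 3.

```r
## Hypothetical non-linear "model": f(x1, x2) = x1^2 + 3 * x2
f_nl <- function(x1, x2) x1^2 + 3 * x2

## Hypothetical instance of interest
x0 <- c(x1 = 2, x2 = 1)

## Small symmetric grid of perturbations around x0
delta <- seq(-0.1, 0.1, length.out = 5)
nbhd  <- expand.grid(x1 = x0[["x1"]] + delta, x2 = x0[["x2"]] + delta)
nbhd$y <- f_nl(nbhd$x1, nbhd$x2)

## Local linear surrogate: its slopes approximate feature effects near x0
local_fit   <- lm(y ~ x1 + x2, data = nbhd)
local_terms <- coef(local_fit)[c("x1", "x2")]
print(local_terms)  ## the gradient at x0 is (4, 3)
```

Far from `x0`, the same linear terms would be a poor description of `f_nl`. This is why the support of a local explanation is worth exploring.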
If the analyst can explore how models lead to bad predictions, this can reveal issues with the data or suggest models that are more robust to misclassifications or extreme residuals. An analyst may also want to explore the support of feature contributions, that is, where the explanation makes sense and where it may be completely unreliable. We propose conducting this sort of analysis with interactive graphics, in the analysis and R package titled cheem.
This framework is broadly applicable to any model and compatible local explanation. We will illustrate with an xgboost::xgboost() model (xgb) and the tree SHAP local explanation from shapviz::shapviz(). The model attempts to predict housing sale price from 9 predictors for 338 sale events from one neighborhood (North Ames) in the 2018 Ames housing data.
The first things we need are the predictions and a local explanation (or other embedded space). Here we create an xgb model, make predictions, and find the SHAP values of each observation.
## Download if not installed
if(!require(cheem)) install.packages("cheem", dependencies = TRUE)
if(!require(xgboost)) install.packages("xgboost", dependencies = TRUE)
if(!require(shapviz)) install.packages("shapviz", dependencies = TRUE)
## Load onto session
library(cheem)
library(xgboost)
library(shapviz)
library(magrittr) ## provides the %>% pipe used below
## Setup
X <- amesHousing2018_NorthAmes[, 1:9]
Y <- amesHousing2018_NorthAmes$SalePrice
clas <- amesHousing2018_NorthAmes$SubclassMS
## Model and predict
ames_train <- data.matrix(X) %>% xgb.DMatrix(label = Y)
ames_xgb_fit <- xgboost(data = ames_train, max.depth = 3, nrounds = 25)
ames_xgb_pred <- predict(ames_xgb_fit, newdata = ames_train)
ames_xgb_pred %>% head()
## SHAP values
shp <- shapviz(ames_xgb_fit, X_pred = ames_train, X = X)
## Keep just the [n, p] local explanations
ames_xgb_shap <- shp$S
ames_xgb_shap %>% head()
Note that the choice of model, prediction, and local explanation (or other embedding) is left to the analyst and is not facilitated by cheem. Now let's prepare these spaces for visualization with a cheem::cheem_ls() call before we start our analysis.
ames_chm <- cheem_ls(X, Y, class = clas,
                     attr_df = ames_xgb_shap,
                     pred = ames_xgb_pred,
                     label = "North Ames, xgb, shap")
We have extracted tree SHAP, a feature importance measure in the vicinity of each observation. Next, we need to identify an instance of interest to explore. We do so with the linked brushing available in the global view. Then we vary the contributions of different features in a radial tour to test the support of an explanation.
To get a more complete view, let's look at approximations of the data space, attribution space, and model fits side by side, with linked brushing provided by plotly and crosstalk. We have identified an observation with a large Mahalanobis distance (in data space) and its closest neighbor in attribution space.
prim <- 1
comp <- 17
global_view(ames_chm, primary_obs = prim, comparison_obs = comp,
            height_px = 240, width_px = 720,
            as_ggplot = TRUE, color = "log_maha.data")
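The `color = "log_maha.data"` option above presumably colors instances by the log of their Mahalanobis distance in data space. The minimal base-R sketch below shows that quantity. It also shows how to find the nearest neighbor of a primary instance in attribution space. The inputs are simulated stand-ins (`data_x` and `attr_x` are hypothetical). The sketch uses only `stats::mahalanobis()` and `dist()` and is not cheem's internal code. Note that `mahalanobis()` returns the squared distance.

```r
set.seed(2022)
## Hypothetical stand-ins: an [n, p] data matrix and an [n, p] attribution matrix
n <- 50; p <- 4
data_x <- matrix(rnorm(n * p), n, p)
data_x[1, ] <- c(4, -4, 4, -4)  ## make instance 1 an outlier
attr_x <- data_x %*% diag(c(2, 1, 0.5, 0.1)) +
  matrix(rnorm(n * p, sd = 0.1), n, p)

## Squared Mahalanobis distance of each row from the column means, then its log
maha_sq  <- mahalanobis(data_x, center = colMeans(data_x), cov = cov(data_x))
log_maha <- log(maha_sq)

## Primary instance: the most extreme row in data space
prim_i <- which.max(log_maha)

## Its nearest neighbor in attribution space (Euclidean distance)
d_attr <- as.matrix(dist(attr_x))[prim_i, ]
d_attr[prim_i] <- Inf  ## exclude the instance itself
nn_i <- unname(which.min(d_attr))
c(primary = prim_i, nearest_in_attr = nn_i)
```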
From this global view, we want to identify a primary instance (PI) and, optionally, a comparison instance (CI) to explore. Misclassified observations and observations with large residuals are good targets for further exploration. One point sticks out in this case. Instance 243 (shown as *) is a Gentoo (purple) penguin, but the model predicts it to be a Chinstrap penguin. Penguin 169 (shown as x) is reasonably close by and is correctly predicted as Gentoo. In practice, we used linked brushing and misclassification information to guide our search.
There is a lot to unpack here. The normalized distributions of the feature attributions of all instances are shown as parallel coordinate lines. The PI and CI selected above are shown as dashed and dotted lines, respectively. The first thing we notice is that the PI's attribution is close to its (incorrect) prediction of Chinstrap (orange) in terms of bill length (bl) and flipper length (fl). In terms of bill depth and body mass (bd and bm), it is more like its observed species, Gentoo (purple). We select flipper length as the feature to manipulate.
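The code below calls `sug_basis()` and `sug_manip_var()`. Going by their comments, the first uses the PI's normalized attribution as the projection basis. The second suggests the feature with the largest separation between the PI and CI attributions. The following is a minimal base-R sketch of those two ideas under that assumption. It is not cheem's implementation, which may normalize differently. The attribution values and the helper `unit_basis()` are made up for illustration, and the feature names reuse the bl, bd, fl, bm abbreviations from the text.

```r
## Hypothetical [n, p] attribution matrix (made-up values)
attr_x <- rbind(
  c(bl =  0.30, bd = -0.10, fl =  0.50, bm = 0.05),
  c(bl =  0.25, bd = -0.05, fl = -0.40, bm = 0.10),
  c(bl = -0.20, bd =  0.30, fl =  0.10, bm = 0.20)
)
prim_i <- 1; comp_i <- 2

## Hypothetical helper: one row's attribution scaled to unit length, as a [p, 1] basis
unit_basis <- function(attr, i) {
  b <- attr[i, ]
  matrix(b / sqrt(sum(b^2)), ncol = 1, dimnames = list(colnames(attr), "x"))
}
bas <- unit_basis(attr_x, prim_i)

## Feature whose normalized attribution differs most between the PI and CI
gap <- abs(unit_basis(attr_x, prim_i) - unit_basis(attr_x, comp_i))[, 1]
mv  <- unname(which.max(gap))
colnames(attr_x)[mv]  ## "fl" for these made-up numbers
```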
## Normalized attribution basis of the PI
bas <- sug_basis(ames_xgb_shap, rownum = prim)
## Default feature to manipulate:
#### the feature with largest separation between PI and CI attribution
mv <- sug_manip_var(
ames_xgb_shap, primary_obs = prim, comparison_obs = comp)
## Make the radial tour
ggt <- radial_cheem_tour(
ames_chm, basis = bas, manip_var = mv,
primary_obs = prim, comparison_obs = comp, angle = .15)
## Animate it
animate_gganimate(ggt, fps = 6)
## Optional rendering arguments, e.g.: height = 2, width = 4.5, units = "in", res = 150
## Or as a plotly html widget
#animate_plotly(ggt, fps = 6)
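The radial tour changes the contribution of the manipulation variable while the projection basis keeps unit length. The following is a simplified base-R sketch of that idea for a 1D basis, in the spirit of the manual (radial) tour. It is not the spinifex or cheem implementation. The function `radial_basis()` and the example basis `b` are hypothetical. The basis is written as `cos(theta)` times the manip var's unit vector plus `sin(theta)` times the rest of the basis. An angle of 0 gives a full contribution, pi/2 gives none, and `theta0` reproduces the starting basis.

```r
## Hypothetical helper: rotate the manip var's contribution in a 1D basis
radial_basis <- function(b, k, theta) {
  b <- b / sqrt(sum(b^2))                   ## ensure unit length
  e_k <- replace(numeric(length(b)), k, 1)  ## unit vector of the manip var
  r <- b - b[k] * e_k                       ## part of b orthogonal to e_k
  r_hat <- r / sqrt(sum(r^2))               ## assumes b is not already e_k
  cos(theta) * e_k + sin(theta) * r_hat     ## unit length for every theta
}

b <- c(bl = 0.5, bd = -0.2, fl = 0.8, bm = 0.1)  ## hypothetical starting basis
k <- 3                                          ## manipulate "fl"
theta0 <- unname(acos(b[k] / sqrt(sum(b^2))))   ## angle that reproduces b

## Frames from full (theta = 0) to zero (theta = pi / 2) contribution of fl
frames <- sapply(seq(0, pi / 2, length.out = 5),
                 function(th) radial_basis(b, k, th))
round(frames["fl", ], 3)
```

Projecting the data onto each column of `frames` would give one frame of the animation.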
Starting from the attribution projection, this instance already looks more like its observed species, Gentoo, than its predicted Chinstrap. However, by frame 8, the basis has a full contribution of flipper length, and the instance does look more like the predicted Chinstrap. The parallel coordinate lines on the basis visual show a large gap in flipper length between the PI and CI, so let's check the original variables to digest this.
library(ggplot2)
peng <- spinifex::penguins_na.rm
prim <- 1
ggplot(peng, aes(x = bill_length_mm,
                 y = flipper_length_mm,
                 colour = species,
                 shape = species)) +
  geom_point() +
  ## Highlight PI, *
  geom_point(data = peng[prim, ],
             shape = 8, size = 5, alpha = 0.8) +
  ## Theme, scaling, color, and labels
  theme_bw() +
  theme(aspect.ratio = 1) +
  scale_color_brewer(palette = "Dark2") +
  labs(y = "Flipper length [mm]", x = "Bill length [mm]",
       color = "Observed species", shape = "Observed species")
This profile shows the two features that most distinguish the PI from the CI. The instance is nested among the Chinstrap penguins. This makes it particularly hard for a tree-based model to classify, because a decision tree can only partition on one feature value at a time (horizontal and vertical lines here).
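The following toy base-R sketch illustrates the point about axis-aligned partitions. It uses a hypothetical grid of points whose class boundary is diagonal (`x1 + x2 > 11`). The best single split on one feature (a decision stump) reaches only 75% accuracy, while the oblique rule that combines both features is exact. A full tree or forest can stack many splits into a staircase that approximates the diagonal, but each individual split remains axis-aligned.

```r
## Hypothetical toy grid with a diagonal class boundary: x1 + x2 > 11
grid <- expand.grid(x1 = 1:10, x2 = 1:10)
cls  <- (grid$x1 + grid$x2) > 11

## Best accuracy of a single axis-aligned split (decision stump) on one feature
stump_accuracy <- function(x, y) {
  accs <- vapply(sort(unique(x)), function(t) {
    acc <- mean((x > t) == y)
    max(acc, 1 - acc)  ## allow either side of the split to be the TRUE class
  }, numeric(1))
  max(accs)
}

axis_acc    <- max(stump_accuracy(grid$x1, cls), stump_accuracy(grid$x2, cls))
oblique_acc <- mean(((grid$x1 + grid$x2) > 11) == cls)
c(axis_aligned = axis_acc, oblique = oblique_acc)  ## 0.75 and 1
```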
We provide an interactive shiny application. Interactive features are made possible with plotly, crosstalk, and DT. We have preprocessed simulated and modern datasets for you to explore this analysis with. Alternatively, bring your own data by saving the return of cheem_ls() as an rds file. Follow along with the example in ?cheem_ls.
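Saving the object is plain base R with `saveRDS()`, and `readRDS()` restores an identical copy. The following is a minimal sketch with a hypothetical stand-in list, `my_chm`, written to a temporary path. In practice, you would save the real `cheem_ls()` return, such as `ames_chm`, under a file name of your choosing.

```r
## Hypothetical stand-in for the list returned by cheem_ls()
my_chm <- list(label    = "stand-in",
               attr_df  = data.frame(a = 1:3, b = c(0.2, -0.1, 0.4)))

## Save to an .rds file (a temporary path here)
rds_path <- file.path(tempdir(), "my_cheem.rds")
saveRDS(my_chm, file = rds_path)

## readRDS() restores an identical object
reloaded <- readRDS(rds_path)
identical(reloaded, my_chm)
```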
Interpretability of black-box models is important to maintain. Local explanations extend this interpretability by approximating feature importance in the vicinity of one instance. We propose a post-hoc analysis of these local explanations. First, we explore them in a global, full-instance context. Then we explore the support of a local explanation to see where it seems plausible or unreliable.
cheem is agnostic to the model and local explanation, but it requires both. Above, we illustrated with an xgb model and tree SHAP. Below, we demonstrate producing attribution spaces from other models and packages.
shapviz is being actively maintained and is hosted on CRAN. It is compatible with H2O, lgb, and xgb models.
https://github.com/ModelOriented/shapviz
if(!require(shapviz)) install.packages("shapviz")
if(!require(xgboost)) install.packages("xgboost")
library(shapviz)
library(xgboost)
set.seed(3653)
## Setup
X <- spinifex::penguins_na.rm[, 1:4]
Y <- as.integer(spinifex::penguins_na.rm$species) ## species codes, used as a numeric label
clas <- spinifex::penguins_na.rm$species
## Model and predict
peng_train <- data.matrix(X) %>%
xgb.DMatrix(label = Y)
peng_xgb_fit <- xgboost(data = peng_train, max.depth = 3, nrounds = 25)
peng_xgb_pred <- predict(peng_xgb_fit, newdata = peng_train)
## SHAP
peng_xgb_shap <- shapviz(peng_xgb_fit, X_pred = peng_train, X = X)
## Keep just the [n, p] local explanations
peng_xgb_shap <- peng_xgb_shap$S
treeshap is available on CRAN. It is compatible with many tree-based models, including gbm, lgb, rf, ranger, and xgb models.
https://github.com/ModelOriented/treeshap
if(!require(treeshap)) install.packages("treeshap")
if(!require(randomForest)) install.packages("randomForest")
library(treeshap)
library(randomForest)
## Setup
X <- spinifex::wine[, -(1:2)] ## all but the first two columns
Y <- spinifex::wine$Alcohol
clas <- spinifex::wine$Type
## Fit randomForest::randomForest
wine_rf_fit <- randomForest::randomForest(
X, Y, ntree = 125,
mtry = ifelse(is_discrete(Y), sqrt(ncol(X)), ncol(X) / 3),
nodesize = max(ifelse(is_discrete(Y), 1, 5), nrow(X) / 500))
wine_rf_pred <- predict(wine_rf_fit)
## treeshap::treeshap()
wine_rf_tshap <- wine_rf_fit %>%
treeshap::randomForest.unify(X) %>%
treeshap::treeshap(X, interactions = FALSE, verbose = FALSE)
## Keep just the [n, p] local explanations
wine_rf_tshap <- wine_rf_tshap$shaps
DALEX is a popular and versatile XAI package available on CRAN. It is compatible with many models, but it uses the original, slower variant of SHAP local explanations. Expect long run times for sizable data or complex models.
https://ema.drwhy.ai/shapley.html#SHAPRcode
if(!require(DALEX)) install.packages("DALEX")
library(DALEX)
## Setup
X <- dragons[, c(1:4, 6)]
Y <- dragons$life_length
clas <- dragons$colour
## Model and predict
drag_lm_fit <- lm(data = data.frame(Y, X), Y ~ .)
drag_lm_pred <- predict(drag_lm_fit)
## SHAP via DALEX, versatile but slow
drag_lm_exp <- explain(drag_lm_fit, data = X, y = Y,
                       label = "Dragons, LM, SHAP")
## DALEX::predict_parts_shap() is flexible, but slow and works one row at a time
drag_lm_shap <- t(vapply(seq_len(nrow(X)), function(i) {
  pps <- predict_parts_shap(drag_lm_exp, new_observation = X[i, ])
  ## Average each feature's contribution over permutations, in the column order of X
  as.vector(tapply(pps$contribution, pps$variable_name, mean, na.rm = TRUE)[colnames(X)])
}, numeric(ncol(X))))
## Keep just the [n, p] local explanations
drag_lm_shap <- as.data.frame(drag_lm_shap)
colnames(drag_lm_shap) <- colnames(X)
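For a linear model, interventional (marginal) SHAP values have a closed form: the contribution of feature j to instance i is `beta_j * (x_ij - mean(x_j))`. This gives a fast sanity check for a slow, sampled loop like the one above. DALEX's permutation-based estimates should approximate this closed form rather than match it exactly. The following is a minimal base-R sketch on simulated data (`sim_x`, `sim_y`, and their coefficients are hypothetical). It also checks additivity: each row's contributions sum to its prediction minus the mean prediction.

```r
set.seed(1)
## Hypothetical data for a linear model
n <- 100
sim_x <- data.frame(a = rnorm(n), b = runif(n), c = rnorm(n, mean = 5))
sim_y <- 2 * sim_x$a - 3 * sim_x$b + 0.5 * sim_x$c + rnorm(n)

fit  <- lm(y ~ ., data = data.frame(y = sim_y, sim_x))
pred <- predict(fit)
beta <- coef(fit)[colnames(sim_x)]

## Closed-form SHAP for a linear model: beta_j * (x_ij - mean(x_j))
lm_shap <- sweep(as.matrix(sim_x), 2, colMeans(sim_x)) %*% diag(beta)
colnames(lm_shap) <- colnames(sim_x)

## Additivity: contributions sum to prediction minus the mean prediction
all.equal(unname(rowSums(lm_shap)), unname(pred - mean(pred)))
```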