Comparison between ALEPlot and ale packages
Chitu Okoli
January 10, 2024
The {ALEPlot} package is the reference implementation of the brilliant idea of accumulated local effects (ALE) by Daniel Apley and Jingyu Zhu. However, it has not been updated since 2018. The {ale} package attempts to rewrite and extend the original base work.
In developing the ale package, we must ensure that we correctly implement the original ALE algorithm while extending it. Indeed, some permanent unit tests call {ALEPlot} to make sure that ale provides identical results for identical inputs. We thought that presenting some of these comparisons as a vignette might be helpful. We focus here on the examples from the {ALEPlot} package so that results are directly comparable.
Other than its extensions for ALE-based statistics, here are some of the main points in which ale differs from {ALEPlot} where it provides otherwise similar functionality:
- It uses ggplot2 instead of base R graphics. We consider ggplot2 to be a more modern and versatile graphics system.
- It saves plots as ggplot objects to a “plots” element of the return value; it does not automatically print the plot to the screen. As we show, this lets the user manipulate the plots more flexibly.
- In the plot, the Y outcome variable is displayed by default on its full absolute scale, centred on the median or mean, not on a scale relative to zero. (This option can be controlled with the relative_y argument to plot(), as we demonstrate.) We believe that such plots are more easily interpretable.
- In addition, there are numerous design choices to simplify the function interface based on tidyverse design principles.
One notable difference between the two packages is that the ale package does not and will not implement partial dependence plots (PDP). The package is focused exclusively on accumulated local effects (ALE); users who need PDPs may use {ALEPlot} or other implementations.
In each section here, we cover an example from {ALEPlot} and then reimplement it with ale.
We begin by loading the necessary libraries.
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
Simulated data with numeric outcomes (ALEPlot Example 2)
We begin with the second code example directly from the {ALEPlot} package. (We skip the first example because it is a subset of the second, simply without interactions.) Here is the code from the example to create a simulated dataset and train a neural network on it:
## R code for Example 2
## Load relevant packages
library(ALEPlot)
## Generate some data and fit a neural network supervised learning model
set.seed(0) # not in the original, but added for reproducibility
n = 5000
x1 <- runif(n, min = 0, max = 1)
x2 <- runif(n, min = 0, max = 1)
x3 <- runif(n, min = 0, max = 1)
x4 <- runif(n, min = 0, max = 1)
y = 4*x1 + 3.87*x2^2 + 2.97*exp(-5+10*x3)/(1+exp(-5+10*x3))+
13.86*(x1-0.5)*(x2-0.5)+ rnorm(n, 0, 1)
DAT <- data.frame(y, x1, x2, x3, x4)
nnet.DAT <- nnet::nnet(
y~., data = DAT, linout = TRUE, skip = FALSE, size = 6,
decay = 0.1, maxit = 1000, trace = FALSE
)
For the demonstration, x1 has a linear relationship with y, x2 and x3 have non-linear relationships, and x4 is a random variable with no relationship with y. x1 and x2 interact with each other in their relationship with y.
ALEPlot code
To create ALE data and plots, {ALEPlot} requires the creation of a custom prediction function:
## Define the predictive function
yhat <- function(X.model, newdata) as.numeric(predict(X.model, newdata,
type = "raw"))
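{ALEPlot} expects pred.fun to take the fitted model as X.model and a data frame as newdata, and to return a plain numeric vector with one prediction per row. The as.numeric() call above flattens whatever structure predict() returns for the model into such a vector. The sketch below follows the same convention with a base-R linear model. The toy data, fit_lm, and yhat_lm are illustrative stand-ins, not part of either package.

```r
# Illustrative data and model (stand-ins for DAT and nnet.DAT)
set.seed(1)
toy <- data.frame(x1 = runif(100), x2 = runif(100))
toy$y <- 2 * toy$x1 - toy$x2 + rnorm(100, sd = 0.1)
fit_lm <- lm(y ~ x1 + x2, data = toy)

# Same two-argument signature that ALEPlot expects for pred.fun:
# the model is passed as X.model and the predictor data as newdata.
yhat_lm <- function(X.model, newdata) {
  as.numeric(predict(X.model, newdata))  # plain, unnamed numeric vector
}

preds <- yhat_lm(fit_lm, toy[, c("x1", "x2")])
```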
Now the {ALEPlot} function can be called to create the ALE data and plot it. The function returns a specially formatted list with the ALE data; it can be saved for subsequent custom plotting.
## Calculate and plot the ALE main effects of x1, x2, x3, and x4
ALE.1 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 1, K = 500,
NA.plot = TRUE)
ALE.2 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 2, K = 500,
NA.plot = TRUE)
ALE.3 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 3, K = 500,
NA.plot = TRUE)
ALE.4 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 4, K = 500,
NA.plot = TRUE)
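The sketch below shows what the x.values and f.values elements of the returned list contain. It is a simplified base-R version of first-order ALE for a numeric predictor, under our own assumptions: interval edges at quantiles of the predictor, the same (X.model, newdata) convention as pred.fun, and centring by the count-weighted mean. K plays the role of the number of intervals. The function name ale_1d_sketch is hypothetical, and this is not {ALEPlot}'s source code.

```r
# Simplified first-order ALE (illustration only, not ALEPlot's implementation)
ale_1d_sketch <- function(X, X.model, pred.fun, J, K = 40) {
  x <- X[[J]]
  # Interval edges at quantiles of x; duplicates dropped, so K may shrink
  z <- unique(as.numeric(quantile(x, probs = seq(0, 1, length.out = K + 1), type = 1)))
  K <- length(z) - 1
  # Interval index of each observation (first interval includes its lower edge)
  bin <- as.integer(cut(x, breaks = z, include.lowest = TRUE))
  # Move x_J to the lower and upper edges of each observation's own interval
  X_lo <- X; X_lo[[J]] <- z[bin]
  X_hi <- X; X_hi[[J]] <- z[bin + 1]
  local_diff <- pred.fun(X.model, X_hi) - pred.fun(X.model, X_lo)
  # Average local effect per interval, then accumulate across intervals
  counts <- tabulate(bin, nbins = K)
  delta <- as.numeric(tapply(local_diff, factor(bin, levels = seq_len(K)), mean))
  delta[is.na(delta)] <- 0  # an empty interval contributes no change
  f <- c(0, cumsum(delta))
  # Centre so that the count-weighted average of interval midpoints is zero
  mid <- (f[-1] + f[-(K + 1)]) / 2
  f <- f - sum(mid * counts) / sum(counts)
  list(x.values = z, f.values = f)
}

# Usage with a known additive function: the ALE slope for x1 should be 3
set.seed(2)
X_demo <- data.frame(x1 = runif(500), x2 = runif(500))
pred_demo <- function(X.model, newdata) 3 * newdata$x1 + newdata$x2^2
res <- ale_1d_sketch(X_demo, X.model = NULL, pred.fun = pred_demo, J = "x1", K = 20)
```

Because the x2 term cancels in every local difference, the accumulated curve for x1 is a straight line with slope 3. This illustrates how ALE isolates one variable's effect.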
In the {ALEPlot} implementation, calling the function automatically prints a plot. While this provides some convenience if that is what the user wants, it is not so convenient if the user does not want to print a plot at the very point of ALE creation. It is particularly inconvenient for script building. Although it is possible to configure R to suspend graphic output before ALEPlot() is called and then restart it after the function call, this is not so straightforward: the function itself does not give any option to control this behaviour.
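One way to suspend graphic output is to open R's null PDF device before the call and close it afterwards. Any plot drawn in between is then discarded, while the function's return value is kept. The helper below, without_plotting(), is hypothetical; neither package provides it.

```r
# Hypothetical helper (not part of ALEPlot): evaluate an expression that draws
# plots, sending its graphics to a throwaway device and keeping its value.
without_plotting <- function(expr) {
  grDevices::pdf(file = NULL)                        # null PDF device: no file is written
  null_dev <- grDevices::dev.cur()
  on.exit(grDevices::dev.off(null_dev), add = TRUE)  # close it even on error
  expr                                               # lazy evaluation happens here
}

# Example: the plot is discarded, the returned value is kept
n_devices_before <- length(grDevices::dev.list())
value <- without_plotting({
  plot(1:10)
  sum(1:10)
})
```

Under the same assumption, wrapping a call such as ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 1, K = 500) in without_plotting() should keep the returned list without drawing anything on screen. This is a sketch, not a documented {ALEPlot} feature.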
ALE interactions can also be calculated and plotted:
## Calculate and plot the ALE second-order effects of {x1, x2} and {x1, x4}
ALE.12 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = c(1,2), K = 100,
NA.plot = TRUE)
ALE.14 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = c(1,4), K = 40,
NA.plot = TRUE)
If the output of {ALEPlot} has been saved to variables, then its contents can be plotted with finer user control using the generic R plot method:
## Manually plot the ALE main effects on the same scale for easier comparison
## of the relative importance of the four predictor variables
par(mfrow = c(3,2))
plot(ALE.1$x.values, ALE.1$f.values, type="l", xlab="x1",
ylab="ALE_main_x1", xlim = c(0,1), ylim = c(-2,2), main = "(a)")
plot(ALE.2$x.values, ALE.2$f.values, type="l", xlab="x2",
ylab="ALE_main_x2", xlim = c(0,1), ylim = c(-2,2), main = "(b)")
plot(ALE.3$x.values, ALE.3$f.values, type="l", xlab="x3",
ylab="ALE_main_x3", xlim = c(0,1), ylim = c(-2,2), main = "(c)")
plot(ALE.4$x.values, ALE.4$f.values, type="l", xlab="x4",
ylab="ALE_main_x4", xlim = c(0,1), ylim = c(-2,2), main = "(d)")
## Manually plot the ALE second-order effects of {x1, x2} and {x1, x4}
image(ALE.12$x.values[[1]], ALE.12$x.values[[2]], ALE.12$f.values, xlab = "x1",
ylab = "x2", main = "(e)")
contour(ALE.12$x.values[[1]], ALE.12$x.values[[2]], ALE.12$f.values, add=TRUE,
drawlabels=TRUE)
image(ALE.14$x.values[[1]], ALE.14$x.values[[2]], ALE.14$f.values, xlab = "x1",
ylab = "x4", main = "(f)")
contour(ALE.14$x.values[[1]], ALE.14$x.values[[2]], ALE.14$f.values, add=TRUE,
drawlabels=TRUE)
{ale} package equivalent
Now we demonstrate the same functionality with the ale package. We will work with the same model on the same data, so we will not create them again.
Before starting, we recommend that you enable progress bars to see how long procedures will take. Simply run the following code at the beginning of your R session:
# Run this in an R console; it will not work directly within an R Markdown or Quarto block
progressr::handlers(global = TRUE)
progressr::handlers('cli')
If you forget to do that, the ale package will do it automatically for you with a notification message.
To create the ALE data, we call ale(), which returns an object with various ALE elements.
Here are some notable differences compared to {ALEPlot}:
- In tidyverse style, the first argument is the data and the second is the model.
- Unlike {ALEPlot}, which works on only one variable (or one pair of variables) per call, ale generates ALE data for multiple variables in a dataset at once. By default, it generates ALE elements for all the predictor variables in the dataset that it is given; the user can specify a single variable or any subset of variables. We will cover more details in another vignette, but for our purposes here, we note the data element, which returns a list of the ALE data for each variable, and the plots element, which returns a list of ggplot plots.
- ale creates a default generic predict function that matches most standard R models. When the prediction type is not the default “response”, as in our case, the user can set the desired type with the pred_type argument. However, for more complex or non-standard prediction functions, ale supports custom functions with the pred_fun argument.
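The following code creates the ALE data for all the predictors. Since the plots are returned as a list, they can then easily be printed out all at once:
library(ale)
# Create ALE data for all predictors; nnet needs the "raw" prediction type
nn_ale <- ale(DAT, nnet.DAT, pred_type = "raw")
# Print plots
nn_plots <- plot(nn_ale)
nn_1D_plots <- nn_plots$distinct$y$plots[[1]]
patchwork::wrap_plots(nn_1D_plots, ncol = 2)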
The ale package plots have various features that enhance interpretability:
- The outcome y is displayed on its full original scale.
- A median band that shows the middle 5 percent of the y values is displayed. The idea is that any ALE values outside this band are at least somewhat significant.
- Similarly, there are 25th and 75th percentile markers to show the middle 50% of the y values. Any ALE y value beyond these markers indicates that the x variable is so strong that it alone, at the values indicated, can shift the y value by that much.
- Rug plots indicate the distribution of the data so that outliers are not over-interpreted.
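The sketch below shows how these reference lines can be read. It classifies ALE values, already on the absolute y scale, against a 5% median band and the 25th and 75th percentiles of y. We assume the band runs from the 47.5th to the 52.5th percentile. The functions reference_bands() and classify_ale() are hypothetical illustrations, not ale's internal code.

```r
# Reference lines computed from the outcome values y (illustrative sketch)
reference_bands <- function(y, band_pct = 0.05) {
  q <- quantile(y, probs = c(0.25, 0.5 - band_pct / 2, 0.5, 0.5 + band_pct / 2, 0.75),
                names = FALSE)
  c(q25 = q[1], band_lo = q[2], median = q[3], band_hi = q[4], q75 = q[5])
}

# Label each ALE value by the band it falls in
classify_ale <- function(ale_y, bands) {
  ifelse(ale_y < bands[["q25"]] | ale_y > bands[["q75"]], "beyond middle 50%",
         ifelse(ale_y < bands[["band_lo"]] | ale_y > bands[["band_hi"]],
                "outside median band", "within median band"))
}

bands <- reference_bands(0:1000)
labels <- classify_ale(c(500, 540, 900), bands)
```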
It might not be clear that the previous plots display exactly the same data as those shown above from ALEPlot. To make the comparison clearer, we can plot the ALEs on a zero-centred scale:
# Zero-centred ALE plots
nn_plots_zero <- plot(nn_ale, relative_y = 'zero')
nn_1D_plots_zero <- nn_plots_zero$distinct$y$plots[[1]]
patchwork::wrap_plots(nn_1D_plots_zero)
With these zero-centred plots, the full range of y values and the rug plots give some context that aids interpretation. (If the rugs look slightly different, it is because they are randomly jittered to avoid overplotting.)
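The two scales show identical shapes because, under our simplifying assumption that the median-centred display just adds median(y) to the zero-centred ALE values, the two differ by a single constant. The numbers below are made up for illustration.

```r
ale_zero <- c(-1.2, -0.4, 0.1, 0.9, 1.6)   # hypothetical zero-centred ALE values
y_demo <- c(2.0, 3.5, 4.1, 5.0, 7.3, 8.8)  # hypothetical outcome values

# Shift every value by the same constant, median(y_demo) = 4.55
ale_median_scale <- ale_zero + median(y_demo)

# Differences between neighbouring points, i.e. the shape, are unchanged
all.equal(diff(ale_median_scale), diff(ale_zero))
```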
The ale package also produces interaction plots; see the introductory vignette for details on how they are specified and created.
# Create and plot interactions
nn_ale_2D <- ale(DAT, nnet.DAT, pred_type = "raw", complete_d = 2)
# Print plots
nn_plots <- plot(nn_ale_2D)
nn_2D_plots <- nn_plots$distinct$y$plots[[2]]
nn_2D_plots |>
# extract list of x1 ALE outputs
purrr::walk(\(it.x1) {
# plot all x2 plots in each it.x1 element
patchwork::wrap_plots(it.x1, ncol = 2, nrow = 2) |>
print()
})
These interaction plots are heat maps that indicate the interaction regions that are above or below the average value of y with colours. Grey indicates no meaningful interaction; blue indicates a positive interaction effect; red indicates a negative effect. We find these easier to interpret than the contour maps from ALEPlot, especially since the colours in each plot are on the same scale and so the plots are directly comparable with each other.
The range of outcome (y) values is divided into quantiles, deciles by default. However, the middle quantiles are modified. Rather than showing the middle 10% or 20% of values, the middle band is much narrower: it shows the middle 5%. (This value is based on the notion of an alpha of 0.05 for confidence intervals; it can be customized with the median_band_pct argument.)
The legend shows the midpoint y value of each quantile, which is usually the mean of the boundaries of the quantile. The exception is the special middle quantile, whose displayed midpoint value is the median of the entire dataset.
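The binning and legend values described above can be sketched as follows. The sketch uses decile breaks of y, replaces the two middle deciles with a 47.5th to 52.5th percentile band around the median, and labels each bin by the mean of its boundaries, except the middle bin, which shows the median. The function heatmap_bins() is hypothetical and only mirrors our reading of this description, not ale's code.

```r
# Hypothetical reconstruction of the heat-map colour bins (not ale's code)
heatmap_bins <- function(y, band_pct = 0.05) {
  probs <- c(0, 0.1, 0.2, 0.3, 0.4, 0.5 - band_pct / 2,
             0.5 + band_pct / 2, 0.6, 0.7, 0.8, 0.9, 1)
  breaks <- quantile(y, probs = probs, names = FALSE)
  n_bins <- length(breaks) - 1
  # Legend value: mean of each bin's boundaries ...
  midpoints <- (breaks[-1] + breaks[-length(breaks)]) / 2
  # ... except the special middle bin, which shows the median of y
  midpoints[(n_bins + 1) / 2] <- median(y)
  list(breaks = breaks, midpoints = midpoints)
}

bins <- heatmap_bins(0:1000)
```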
The interpretation of these interaction plots is that in any given region, the interaction between x1 and x2 increases (blue) or decreases (red) y by the amount indicated over and above the separate individual direct effects of x1 and x2 shown in the one-way ALE plots above. It is not an indication of the total effect of both variables together but rather of the additional effect of their interaction beyond their individual effects. Thus, only the x1-x2 interaction shows any effect. For the interactions with x3, even though x3 indeed has a strong effect on y as we see in the one-way ALE plot above, it has no additional effect in interaction with the other variables, and so its interaction plots are entirely grey.
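The "over and above" interpretation can be checked on a toy function. Second-order ALE is built, in essence, from double differences of predictions at the corners of small grid cells. In such a difference, every term that depends on only one of the two variables cancels, so only the genuine interaction survives. The function f and its coefficients are hypothetical, loosely echoing the simulation formula.

```r
# Toy prediction function with main effects in x1 and x2 plus an x1*x2 interaction
f <- function(x1, x2) 4 * x1 + 3.87 * x2^2 + 13.86 * (x1 - 0.5) * (x2 - 0.5)

# Double difference over one grid cell [a1, b1] x [a2, b2]
double_diff <- function(f, a1, b1, a2, b2) {
  (f(b1, b2) - f(a1, b2)) - (f(b1, a2) - f(a1, a2))
}

# Main effects cancel; only the interaction remains: 13.86 * 0.1 * 0.2
d <- double_diff(f, a1 = 0.2, b1 = 0.3, a2 = 0.6, b2 = 0.8)

# A purely additive function has no second-order effect at all
g <- function(x1, x2) 4 * x1 + 3.87 * x2^2
d_additive <- double_diff(g, 0.2, 0.3, 0.6, 0.8)
```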
Real data with binary outcomes (ALEPlot Example 3)
The next code example from the {ALEPlot} package analyzes a real dataset with a binary outcome variable. Whereas the {ALEPlot} example has the user load a CSV file that might not be readily available, we make that dataset available as the census dataset. We load it here with the adjustments necessary to run the {ALEPlot} example.
## R code for Example 3
## Load relevant packages
library(ALEPlot)
library(gbm, quietly = TRUE)
#> Loaded gbm 2.2.2
#> This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
## Read data and fit a boosted tree supervised learning model
data(census, package = 'ale') # load ale package version of the data
data <-
census |>
as.data.frame() |> # ALEPlot is not compatible with the tibble format
select(age:native_country, higher_income) |> # Rearrange columns to match ALEPlot order
na.omit() # Drop rows with missing values
Although gradient boosted trees generally perform quite well, they are rather slow. Rather than having you wait for it to run, the code here downloads a pretrained GBM model. However, the code used to generate it is provided in comments so that you can see it and run it yourself if you want to. Note that the model call is based on data[,-c(3,4)], which drops the third and fourth variables (fnlwgt and education, respectively).
# # To generate the model, uncomment the following lines.
# # But GBM training is slow, so this vignette loads a pre-created model object.
# set.seed(0)
# gbm.data <- gbm(higher_income ~ ., data= data[,-c(3,4)],
# distribution = "bernoulli", n.trees=6000, shrinkage=0.02,
# interaction.depth=3)
# saveRDS(gbm.data, file.choose())
gbm.data <- url('https://github.com/tripartio/ale/raw/main/download/gbm.data_model.rds') |>
readRDS()
gbm.data
#> gbm(formula = higher_income ~ ., distribution = "bernoulli",
#> data = data[, -c(3, 4)], n.trees = 6000, interaction.depth = 3,
#> shrinkage = 0.02)
#> A gradient boosted model with bernoulli loss function.
#> 6000 iterations were performed.
#> There were 12 predictors of which 12 had non-zero influence.
ALEPlot code
As before, we create a custom prediction function and then call the {ALEPlot} function to generate the plots. The prediction type here is “link”, which represents the log odds in the gbm package.
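The logistic function ties the two prediction types together. A “link” prediction is the log odds, and the “response” prediction is the probability obtained by applying the inverse logit to it. The gbm model is not needed to see this. The sketch below uses a base-R logistic regression on simulated data as a stand-in; the data and names are illustrative only.

```r
set.seed(3)
sim <- data.frame(x = rnorm(200))
sim$higher <- rbinom(200, size = 1, prob = plogis(-1 + 2 * sim$x))
fit_glm <- glm(higher ~ x, family = binomial, data = sim)

link_pred <- predict(fit_glm, newdata = sim, type = "link")      # log odds
resp_pred <- predict(fit_glm, newdata = sim, type = "response")  # probabilities

# plogis() is the inverse logit: it maps log odds to probabilities
all.equal(plogis(link_pred), resp_pred)
```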
Creation of the ALE plots here is rather slow because the gbm predict function is slow. In this example, only age, education_num (number of years of education), and hours_per_week are plotted, along with the interaction between age and hours_per_week.
## Define the predictive function; note the additional arguments for the
## predict function in gbm
yhat <- function(X.model, newdata) as.numeric(predict(X.model, newdata,
n.trees = 6000, type="link"))
## Calculate and plot the ALE main and interaction effects for x_1, x_3,
## x_11, and {x_1, x_11}
par(mfrow = c(2,2), mar = c(4,4,2,2)+ 0.1)
ALE.1=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=1, K=500,
NA.plot = TRUE)
ALE.3=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=3, K=500,
NA.plot = TRUE)
ALE.11=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=11, K=500,
NA.plot = TRUE)
ALE.1and11=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=c(1,11),
K=50, NA.plot = FALSE)
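The hard-coded positions J = 1, 3 and 11 refer to columns of data[,-c(3,4,15)], not of the full dataset. The sketch below shows which variables those positions pick out and how to recover them by name instead. It assumes the census columns follow the order of the UCI Adult data; the name vector is our reconstruction, so check names(data) in practice.

```r
# Assumed column order of `data` after the select() above (check names(data))
census_cols <- c("age", "workclass", "fnlwgt", "education", "education_num",
                 "marital_status", "occupation", "relationship", "race", "sex",
                 "capital_gain", "capital_loss", "hours_per_week",
                 "native_country", "higher_income")

# Predictors passed to ALEPlot: drop fnlwgt, education and the outcome
predictor_cols <- census_cols[-c(3, 4, 15)]

# Which variables do J = 1, 3 and 11 refer to?
selected <- predictor_cols[c(1, 3, 11)]

# Safer: look the positions up by name instead of hard-coding them
J_by_name <- match(c("age", "education_num", "hours_per_week"), predictor_cols)
```

With the real data, the same lookup would be match(c("age", "education_num", "hours_per_week"), names(data[,-c(3,4,15)])).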
{ale} package equivalent
Here is the analogous code using the ale package. In this case, we also need to define a custom predict function because of the particular n.trees = 6000 argument. To speed things up, we provide a pretrained ale object. This is possible because ale returns objects with data and plots bundled together with no side effects (like automatic printing of created plots). (It is probably possible to similarly cache {ALEPlot} ALE objects, but it is not quite as straightforward.)
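Because such an object has no side effects, it can be cached with base R's saveRDS() and readRDS(), which is what the vignette does when it reads the pretrained objects through url(). The sketch below adds a compute-once pattern around that idea. The helper cached_compute() and the stand-in slow_fn() are hypothetical.

```r
# Hypothetical helper: compute an object once, then reuse the saved copy
cached_compute <- function(path, compute) {
  if (file.exists(path)) return(readRDS(path))  # reuse the cached object
  value <- compute()                            # slow step runs only once
  saveRDS(value, path)
  value
}

# Stand-in for a slow computation such as ale(); counts how often it runs
n_calls <- 0
slow_fn <- function() {
  n_calls <<- n_calls + 1
  list(f.values = c(-1, 0, 1))
}

path <- tempfile(fileext = ".rds")
first  <- cached_compute(path, slow_fn)  # computes and saves
second <- cached_compute(path, slow_fn)  # reads the saved copy
```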
Log odds
We display all the plots because it is easy to do so with the ale package, but we focus on age, education_num, and hours_per_week for comparison with ALEPlot. If the shapes of these plots look different, it is because ale tries as much as possible to display plots on the same y-axis coordinate scale for easy comparison across plots.
# Custom predict function that returns log odds
yhat <- function(object, newdata, type) {
predict(object, newdata, type='link', n.trees = 6000) |> # return log odds
as.numeric()
}
# Generate ALE data for all variables
# # To generate the ALE object, uncomment the following lines.
# # But it is very slow because it calculates ALE for all variables,
# # so this vignette loads a pre-created ale object.
# gbm_ale_link <- ale(
# # data[,-c(3,4)], gbm.data,
# data, gbm.data,
# pred_fun = yhat,
# max_num_bins = 500,
# sample_size = 600 # technical issue: sample_size must be > max_num_bins + 1
# )
# saveRDS(gbm_ale_link, file.choose())
gbm_ale_link <- url('https://github.com/tripartio/ale/raw/main/download/gbm_ale_link.rds') |>
readRDS()
# Print plots
gbm_link_plots <- plot(gbm_ale_link)
gbm_1D_link_plots <- gbm_link_plots$distinct$higher_income$plots[[1]]
patchwork::wrap_plots(gbm_1D_link_plots, ncol = 2)
Now we generate ALE data for all two-way interactions and then plot them. Again, note the interaction between age and hours_per_week. The interaction is minimal except for the extremely high cases of hours per week.
# # To generate the ALE object, uncomment the following lines.
# # But it is very slow because it calculates ALE for all variables,
# # so this vignette loads a pre-created ale object.
# gbm_ale_2D_link <- ale(
# # data[,-c(3,4)], gbm.data,
# data, gbm.data,
# complete_d = 2,
# pred_fun = yhat,
# max_num_bins = 500,
# sample_size = 600 # technical issue: sample_size must be > max_num_bins + 1
# )
# saveRDS(gbm_ale_2D_link, file.choose())
gbm_ale_2D_link <- url('https://github.com/tripartio/ale/raw/main/download/gbm_ale_2D_link.rds') |>
readRDS()
# Print plots
gbm_link_plots <- plot(gbm_ale_2D_link)
gbm_link_2D_plots <- gbm_link_plots$distinct$higher_income$plots[[2]]
gbm_link_2D_plots |>
# extract list of x1 ALE outputs
purrr::walk(\(it.x1) {
# plot all x2 plots in each it.x1 element
patchwork::wrap_plots(it.x1, ncol = 2) |>
print()
})
In some of these plots, we can see some white spots. These are the interaction zones where there is no data in the dataset to calculate the existence of an interaction. For example, let’s focus on the interactions of age with education_num:
gbm_link_2D_plots$age$education_num
Here, the grey zones in the majority of the plot indicate that there are minimal interaction effects for most of the data range. However, in the small interacting zone of people younger than about 30 years old with 14 to 16 years of education, we see that the log odds of higher income are around 0.9 lower than average. For the several white zones, there is no data in the dataset to support an estimate. For example, there is no one 35 to 45 years old with 15 years of education and there is no one 49 to 60 years old with 14 years of education; so, the model can say nothing about such interactions.
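On the log-odds scale of these plots, an interaction value is an additive shift in log odds. The arithmetic below converts such a shift into an odds multiplier and into a change in probability, assuming the value in that zone is about -0.9. The 10% baseline probability is our own illustrative choice, not a value read from the plot.

```r
shift <- -0.9                    # assumed log-odds shift read from the heat map
odds_multiplier <- exp(shift)    # about 0.41: odds fall by roughly 59%

baseline_p <- 0.10               # illustrative baseline probability
new_p <- plogis(qlogis(baseline_p) + shift)  # about 0.043

c(odds_multiplier = odds_multiplier, baseline_p = baseline_p, new_p = new_p)
```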
Predicted probabilities
Log odds are not necessarily the most interpretable way to express probabilities (though we will show shortly that they are sometimes uniquely valuable). So, we repeat the ALE creation using the “response” prediction type for probabilities and the default median centring of the plots.
As we can see, the shapes of the plots are similar, but the y axes are more easily interpretable as the probability (from 0 to 1) that a census respondent is in the higher income category. The median of around 10% or so indicates the median prediction of the GBM model: half of the respondents were predicted to have higher than a 10% likelihood of being higher income and half were predicted to have lower likelihood. The y-axis rug plots indicate that the predictions were generally rather extreme, either relatively close to 0 or 1, with few predictions in the middle.
# Custom predict function that returns predicted probabilities
yhat <- function(object, newdata, type) {
as.numeric(
predict(
object, newdata, n.trees = 6000,
type = "response" # return predicted probabilities
)
)
}
# Generate ALE data for all variables
# # To generate the ALE object, uncomment the following lines.
# # But it is slow because it calculates ALE for all variables,
# # so this vignette loads a pre-created ale object.
# gbm_ale_prob <- ale(
# # data[,-c(3,4)], gbm.data,
# data, gbm.data,
# pred_fun = yhat,
# max_num_bins = 500,
# sample_size = 600 # technical issue: sample_size must be > max_num_bins + 1
# )
# saveRDS(gbm_ale_prob, file.choose())
gbm_ale_prob <- url('https://github.com/tripartio/ale/raw/main/download/gbm_ale_prob.rds') |>
readRDS()
# Print plots
gbm_prob_plots <- plot(gbm_ale_prob)
gbm_1D_prob_plots <- gbm_prob_plots$distinct$higher_income$plots[[1]]
patchwork::wrap_plots(gbm_1D_prob_plots, ncol = 2)
Finally, we again generate two-way interactions, this time based on probabilities instead of on log odds. However, probabilities might not be the best choice for indicating interactions because, as we see from the rugs in the one-way ALE plots, the GBM model heavily concentrates its probabilities in the extremes near 0 and 1. Thus, the plots’ suggestions of strong interactions are likely exaggerated. In this case, the log odds ALEs shown above are probably more relevant.
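A related reason for preferring log odds here is that the logistic link is non-linear. A model with no interaction at all on the log-odds scale can still show one on the probability scale. The toy model below is additive in log odds, with made-up coefficients. A double difference over one grid cell, the local quantity behind second-order ALE, is zero for the log odds but not for the probabilities.

```r
# Purely additive on the log-odds scale (no interaction term)
logit_fun <- function(x1, x2) -2 + 1.5 * x1 + 1.5 * x2
prob_fun <- function(x1, x2) plogis(logit_fun(x1, x2))

# Double difference over the cell [0, 1] x [0, 1]
cell_interaction <- function(f) (f(1, 1) - f(0, 1)) - (f(1, 0) - f(0, 0))

on_log_odds <- cell_interaction(logit_fun)  # zero: no interaction
on_prob     <- cell_interaction(prob_fun)   # non-zero: an artefact of the link
```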
# # To generate the ALE object, uncomment the following lines.
# # But it is slow because it calculates ALE for all variables,
# # so this vignette loads a pre-created ale object.
# gbm_ale_2D_prob <- ale(
# # data[,-c(3,4)], gbm.data,
# data, gbm.data,
# complete_d = 2,
# pred_fun = yhat,
# max_num_bins = 500,
# sample_size = 600 # technical issue: sample_size must be > max_num_bins + 1
# )
# saveRDS(gbm_ale_2D_prob, file.choose())
gbm_ale_2D_prob <- url('https://github.com/tripartio/ale/raw/main/download/gbm_ale_2D_prob.rds') |>
readRDS()
# Print plots
gbm_prob_plots <- plot(gbm_ale_2D_prob)
gbm_prob_2D_plots <- gbm_prob_plots$distinct$higher_income$plots[[2]]
gbm_prob_2D_plots |>
# extract list of x1 ALE outputs
purrr::walk(\(it.x1) {
# plot all x2 plots in each it.x1 element
patchwork::wrap_plots(it.x1, ncol = 2) |>
print()
})