
The {ALEPlot} package is the reference implementation of the brilliant idea of accumulated local effects (ALE) by Daniel Apley and Jingyu Zhu. However, it has not been updated since 2018. The {ale} package attempts to rewrite and extend the original base work.

In developing the ale package, we must ensure that we correctly implement the original ALE algorithm while extending it. Indeed, some permanent unit tests call {ALEPlot} to make sure that ale provides identical results for identical inputs. We thought that presenting some of these comparisons as a vignette might be helpful. We focus here on the examples from the {ALEPlot} package so that results are directly comparable.

Other than its extensions for ALE-based statistics, here are some of the main points in which ale differs from {ALEPlot} where it provides otherwise similar functionality:

  • It uses ggplot2 instead of base R graphics. We consider ggplot2 to be a more modern and versatile graphics system.
  • It saves plots as ggplot objects to a “plots” element of the return value; it does not automatically print the plot to the screen. As we show, this lets the user manipulate the plots more flexibly.
  • In the plot, the Y outcome variable is displayed by default on its full absolute scale, centred on the median or mean, not on a scale relative to zero. (This option can be controlled with the relative_y argument to plot(), as we demonstrate.) We believe that such plots are more easily interpretable.
  • In addition, there are numerous design choices to simplify the function interface based on tidyverse design principles.

One notable difference between the two packages is that the ale package does not and will not implement partial dependence plots (PDP). The package is focused exclusively on accumulated local effects (ALE); users who need PDPs may use {ALEPlot} or other implementations.

In each section here, we cover an example from {ALEPlot} and then reimplement it with ale.

We begin by loading the necessary libraries.

library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union

Simulated data with numeric outcomes (ALEPlot Example 2)

We begin with the second code example directly from the {ALEPlot} package. (We skip the first example because it is a subset of the second, simply without interactions.) Here is the code from the example to create a simulated dataset and train a neural network on it:

## R code for Example 2
## Load relevant packages
library(ALEPlot)

## Generate some data and fit a neural network supervised learning model
set.seed(0)  # not in the original, but added for reproducibility
n = 5000
x1 <- runif(n, min = 0, max = 1)
x2 <- runif(n, min = 0, max = 1)
x3 <- runif(n, min = 0, max = 1)
x4 <- runif(n, min = 0, max = 1)
y = 4*x1 + 3.87*x2^2 + 2.97*exp(-5+10*x3)/(1+exp(-5+10*x3)) +
  13.86*(x1-0.5)*(x2-0.5) + rnorm(n, 0, 1)
DAT <- data.frame(y, x1, x2, x3, x4)
nnet.DAT <- nnet::nnet(
  y~., data = DAT, linout = TRUE, skip = FALSE, size = 6,
  decay = 0.1, maxit = 1000, trace = FALSE
)

For the demonstration, x1 has a linear relationship with y, x2 and x3 have non-linear relationships, and x4 is a random variable with no relationship with y. x1 and x2 interact with each other in their relationship with y.
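Because x1 to x4 are independent and uniform on [0, 1], the true ALE main effects of this simulated function can be worked out by hand, which gives a reference for reading the plots below. Writing the noise-free part of y as f, first-order ALE integrates the average partial derivative of f with respect to one feature, conditional on that feature's value, and then subtracts a centring constant c_j:

$$
f(x) = 4x_1 + 3.87x_2^2 + 2.97\,\frac{e^{10x_3-5}}{1+e^{10x_3-5}} + 13.86\,(x_1-0.5)(x_2-0.5)
$$

$$
f_{1,\mathrm{ALE}}(x_1) = \int_0^{x_1} \mathbb{E}\!\left[\frac{\partial f}{\partial x_1} \,\middle|\, x_1 = z\right] dz - c_1
= \int_0^{x_1} \bigl(4 + 13.86\,(\mathbb{E}[x_2] - 0.5)\bigr)\, dz - c_1 = 4x_1 - c_1
$$

since E[x2] = 0.5 and x2 does not depend on x1. The same argument gives 3.87x2^2 - c2 for x2 (the conditional mean of 7.74x2 + 13.86(x1 - 0.5) is 7.74x2), the logistic term minus c3 for x3, and zero for x4. The x1-x2 term therefore contributes nothing to either main effect and should appear only in the second-order ALE of {x1, x2}, approximately as 13.86(x1 - 0.5)(x2 - 0.5) up to a constant. The neural network only approximates f, so the estimated curves will deviate somewhat from these shapes.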

ALEPlot code

To create ALE data and plots, {ALEPlot} requires the creation of a custom prediction function:

## Define the predictive function
yhat <- function(X.model, newdata) as.numeric(predict(X.model, newdata,
type = "raw"))

Now the {ALEPlot} function can be called to create the ALE data and plot it. The function returns a specially formatted list with the ALE data; it can be saved for subsequent custom plotting.

## Calculate and plot the ALE main effects of x1, x2, x3, and x4
ALE.1 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 1, K = 500,
NA.plot = TRUE)

ALE.2 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 2, K = 500,
NA.plot = TRUE)

ALE.3 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 3, K = 500,
NA.plot = TRUE)

ALE.4 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = 4, K = 500,
NA.plot = TRUE)
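The ALEPlot() calls above hide the computation itself. As a rough illustration of what a first-order ALE for a numeric feature involves, here is a minimal base-R sketch: split the feature into quantile-based intervals, move each observation to the two edges of its own interval, average the prediction differences within each interval, accumulate them, and centre the result. The function name ale_1d_sketch and its one-argument pred_fun are hypothetical; this is not the source code of ALEPlot or ale, and it ignores categorical features and second-order effects.

```r
# Minimal first-order ALE for one numeric feature: a sketch of the idea, not ALEPlot's code.
# X: data frame of predictors; pred_fun(newdata): numeric predictions (hypothetical
# one-argument signature); j: column name or index; K: requested number of intervals.
ale_1d_sketch <- function(X, pred_fun, j, K = 40) {
  x <- X[[j]]
  # Interval boundaries at empirical quantiles; duplicates are dropped for discrete x
  z <- unique(stats::quantile(x, probs = seq(0, 1, length.out = K + 1),
                              type = 1, names = FALSE))
  K <- length(z) - 1
  if (K < 1) stop("feature needs at least two distinct values")
  # Interval index (1..K) of each observation; the minimum goes into interval 1
  bin <- cut(x, breaks = z, include.lowest = TRUE, labels = FALSE)
  # Move every observation to the lower and the upper edge of its own interval
  X_lo <- X
  X_lo[[j]] <- z[bin]
  X_hi <- X
  X_hi[[j]] <- z[bin + 1]
  delta <- pred_fun(X_hi) - pred_fun(X_lo)
  # Average local effect per interval, then accumulate across intervals
  n_k <- tabulate(bin, nbins = K)
  mean_delta <- vapply(seq_len(K), function(k) mean(delta[bin == k]), numeric(1))
  f <- c(0, cumsum(mean_delta))
  # Centre f so that its count-weighted average over intervals
  # (each interval represented by the mean of its two edge values) is zero
  f <- f - sum((f[-1] + f[-(K + 1)]) / 2 * n_k) / sum(n_k)
  list(x.values = z, f.values = f)
}
```

The returned x.values and f.values mirror the elements used in the manual plotting code further on. For example, ale_1d_sketch(DAT[, 2:5], function(d) yhat(nnet.DAT, d), "x1", K = 500) should give a curve comparable in shape to ALE.1, though we have not checked that the values match ALEPlot exactly.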

In the {ALEPlot} implementation, calling the function automatically prints a plot. While this is convenient if that is what the user wants, it is inconvenient if the user does not want to print a plot at the very point of ALE creation, and particularly so for script building. Although it is possible to configure R to suspend graphics output before ALEPlot() is called and then restart it after the function call, this is not so straightforward: the function itself does not give any option to control this behaviour.
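One way to suspend graphics output, sketched here with base R only, is to evaluate the call while a null PDF device is active and then close that device. pdf(NULL) opens a device that creates no file and draws nothing visible. without_plotting() is a hypothetical helper, not part of either package.

```r
# Hypothetical helper: evaluate an expression while sending base-graphics output
# to a null PDF device, so nothing is drawn on screen or written to disk
without_plotting <- function(expr) {
  grDevices::pdf(NULL)                              # throwaway device; becomes current
  null_dev <- grDevices::dev.cur()                  # remember which device to close
  on.exit(grDevices::dev.off(null_dev), add = TRUE) # close it even if expr fails
  expr                                              # forcing the lazy argument runs the plotting code
}

# Usage with the objects above (not run here):
# ALE.1 <- without_plotting(
#   ALEPlot(DAT[, 2:5], nnet.DAT, pred.fun = yhat, J = 1, K = 500, NA.plot = TRUE)
# )
```

Because R arguments are evaluated lazily, the ALEPlot() call only runs inside the helper, after the null device has been opened.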

ALE interactions can also be calculated and plotted:

## Calculate and plot the ALE second-order effects of {x1, x2} and {x1, x4}
ALE.12 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = c(1,2), K = 100,
NA.plot = TRUE)

ALE.14 = ALEPlot(DAT[,2:5], nnet.DAT, pred.fun = yhat, J = c(1,4), K = 100,
NA.plot = TRUE)

If the output of the {ALEPlot} has been saved to variables, then its contents can be plotted with finer user control using the generic R plot method:

## Manually plot the ALE main effects on the same scale for easier comparison
## of the relative importance of the four predictor variables
par(mfrow = c(3,2))
plot(ALE.1$x.values, ALE.1$f.values, type="l", xlab="x1",
ylab="ALE_main_x1", xlim = c(0,1), ylim = c(-2,2), main = "(a)")
plot(ALE.2$x.values, ALE.2$f.values, type="l", xlab="x2",
ylab="ALE_main_x2", xlim = c(0,1), ylim = c(-2,2), main = "(b)")
plot(ALE.3$x.values, ALE.3$f.values, type="l", xlab="x3",
ylab="ALE_main_x3", xlim = c(0,1), ylim = c(-2,2), main = "(c)")
plot(ALE.4$x.values, ALE.4$f.values, type="l", xlab="x4",
ylab="ALE_main_x4", xlim = c(0,1), ylim = c(-2,2), main = "(d)")
## Manually plot the ALE second-order effects of {x1, x2} and {x1, x4}
image(ALE.12$x.values[[1]], ALE.12$x.values[[2]], ALE.12$f.values, xlab = "x1",
ylab = "x2", main = "(e)")
contour(ALE.12$x.values[[1]], ALE.12$x.values[[2]], ALE.12$f.values, add=TRUE,
drawlabels=TRUE)
image(ALE.14$x.values[[1]], ALE.14$x.values[[2]], ALE.14$f.values, xlab = "x1",
ylab = "x4", main = "(f)")
contour(ALE.14$x.values[[1]], ALE.14$x.values[[2]], ALE.14$f.values, add=TRUE,
drawlabels=TRUE)

{ale} package equivalent

Now we demonstrate the same functionality with the ale package. We will work with the same model on the same data, so we will not create them again.

Before starting, we recommend that you enable progress bars to see how long procedures will take. Simply run the following code at the beginning of your R session:

# Run this in an R console; it will not work directly within an R Markdown or Quarto block
progressr::handlers(global = TRUE)
progressr::handlers('cli')

If you forget to do that, the ale package will do it automatically for you with a notification message.

To create the ALE data, we call ale(), which returns a list with various ALE elements.

library(ale)

nn_ale <- ale(DAT, nnet.DAT, pred_type = "raw")

Here are some notable differences compared to ALEPlot:

  • In tidyverse style, the first argument is the data and the second is the model.
  • Unlike {ALEPlot}, which works on only one variable at a time, ale generates ALE data for multiple variables in a dataset at once. By default, it generates ALE elements for all the predictor variables in the dataset that it is given; the user can specify a single variable or any subset of variables. We will cover more details in another vignette, but for our purposes here, we note that the data element holds a list of the ALE data for each variable and the plots element holds a list of ggplot plots.
  • ale creates a default generic predict function that matches most standard R models. When the prediction type is not the default “response”, as in our case, the user can set the desired type with the pred_type argument. However, for more complex or non-standard prediction functions, ale supports custom functions with the pred_fun argument.
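As a sketch of the pred_fun route, here is a custom function for the nnet model that mirrors the ALEPlot-style yhat defined earlier. It assumes, as the custom predict functions later in this vignette do, that ale calls pred_fun with object, newdata and type arguments; the name yhat_nnet is hypothetical. For this particular model, pred_type = "raw" should already make it unnecessary.

```r
# Hypothetical custom prediction function for an nnet model, using the
# (object, newdata, type) signature of the custom functions later in this vignette
yhat_nnet <- function(object, newdata, type = "raw") {
  # The type argument is accepted but ignored: "raw" is always requested.
  # nnet's predict method returns a one-column matrix; flatten it to a plain vector
  as.numeric(predict(object, newdata, type = "raw"))
}

# Usage (assumes DAT and nnet.DAT from the ALEPlot example above; not run here):
# nn_ale_custom <- ale(DAT, nnet.DAT, pred_fun = yhat_nnet)
```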

Since the plots are saved as a list, they can easily be printed out all at once:

# Print plots
nn_plots <- plot(nn_ale)
nn_1D_plots <- nn_plots$distinct$y$plots[[1]]
patchwork::wrap_plots(nn_1D_plots, ncol = 2)

The ale package plots have various features that enhance interpretability:

  • The outcome y is displayed on its full original scale.
  • A median band that shows the middle 5% of the y values is displayed. The idea is that any ALE values outside this band are at least somewhat significant.
  • Similarly, there are 25th and 75th percentile markers to show the middle 50% of the y values. Any ALE y value beyond these markers indicates that the x variable is so strong that it alone, at the values indicated, can shift the y value by that much.
  • Rug plots indicate the distribution of the data so that outliers are not over-interpreted.
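The reference bands described in this list can be approximated directly from the outcome values. The sketch below assumes the median band is the central 5% interval of y around the median; ale exposes a median_band_pct argument for customizing it, but we do not reproduce its exact form here. y_reference_bands() is a hypothetical helper, not part of ale.

```r
# Sketch (assumption): reference bands of the kind described above, computed
# directly from the outcome values with base R; not ale's internal code
y_reference_bands <- function(y, median_band = 0.05) {
  half <- median_band / 2
  list(
    median = stats::median(y),
    # central band holding roughly the middle `median_band` share of y
    median_band = stats::quantile(y, c(0.5 - half, 0.5 + half), names = FALSE),
    # 25th and 75th percentiles, bounding the middle 50% of y
    quartiles = stats::quantile(y, c(0.25, 0.75), names = FALSE)
  )
}

# Usage with the simulated data above (not run here):
# y_reference_bands(DAT$y)
```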

It might not be clear that the previous plots display exactly the same data as those shown above from ALEPlot. To make the comparison clearer, we can plot the ALEs on a zero-centred scale:

# Zero-centred ALE plots
nn_plots_zero <- plot(nn_ale, relative_y = 'zero')
nn_1D_plots_zero <- nn_plots_zero$distinct$y$plots[[1]]
patchwork::wrap_plots(nn_1D_plots_zero)

With these zero-centred plots, the full range of y values and the rug plots give some context that aids interpretation. (If the rugs look slightly different, it is because they are randomly jittered to avoid overplotting.)

The ale package also produces interaction plots; see the introductory vignette for details on how they are specified and created.

# Create and plot interactions
nn_ale_2D <- ale(DAT, nnet.DAT, pred_type = "raw", complete_d = 2)

# Print plots
nn_plots <- plot(nn_ale_2D)
nn_2D_plots <- nn_plots$distinct$y$plots[[2]]
nn_2D_plots |>
  # walk through the list of x1 elements
  purrr::walk(\(it.x1) {  
    # plot all the x2 plots in each it.x1 element
    patchwork::wrap_plots(it.x1, ncol = 2, nrow = 2) |> 
      print()
  })

These interaction plots are heat maps whose colours indicate the interaction regions that are above or below the average value of y. Grey indicates no meaningful interaction; blue indicates a positive interaction effect; red indicates a negative effect. We find these easier to interpret than the contour maps from ALEPlot, especially since the colours in each plot are on the same scale and so the plots are directly comparable with each other.

The range of outcome (y) values is divided into quantiles, deciles by default. However, the middle quantiles are modified: rather than showing the middle 10% or 20% of values, the middle band is much narrower and shows only the middle 5%. (This value is based on the notion of alpha of 0.05 for confidence intervals; it can be customized with the median_band_pct argument.)

The legend shows the midpoint y value of each quantile, which is usually the mean of the boundaries of the quantile. The exception is the special middle quantile, whose displayed midpoint value is the median of the entire dataset.
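The binning and legend midpoints described in the two paragraphs above can be sketched as follows. We assume, without having checked ale's source, that the 50th-percentile decile boundary is replaced by the 47.5th and 52.5th percentiles, so that the middle bin holds the central 5% of y. y_legend_bins() is a hypothetical helper.

```r
# Sketch (assumption): decile-style bins of y whose middle bin is narrowed to the
# central 5%, with legend midpoints as described in the text; not ale's code
y_legend_bins <- function(y, median_band = 0.05) {
  half <- median_band / 2
  below <- seq(0, 0.4, by = 0.1)   # decile boundaries below the median band
  above <- seq(0.6, 1, by = 0.1)   # decile boundaries above the median band
  probs <- c(below, 0.5 - half, 0.5 + half, above)
  bounds <- stats::quantile(y, probs, names = FALSE)
  lower <- bounds[-length(bounds)]
  upper <- bounds[-1]
  midpoint <- (lower + upper) / 2          # usual case: mean of the two boundaries
  middle <- length(below) + 1              # index of the narrowed median band
  midpoint[middle] <- stats::median(y)     # exception: median of the whole dataset
  data.frame(lower = lower, upper = upper, midpoint = midpoint)
}
```

With this layout there are 11 bins: five below the median band, the band itself, and five above it.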

The interpretation of these interaction plots is that in any given region, the interaction between x1 and x2 increases (blue) or decreases (red) y by the amount indicated over and above the separate individual direct effects of x1 and x2 shown in the one-way ALE plots above. It is not an indication of the total effect of both variables together but rather of the additional effect of their interaction beyond their individual effects. Thus, only the x1-x2 interaction shows any effect. For the interactions with x3, even though x3 indeed has a strong effect on y as we see in the one-way ALE plot above, it has no additional effect in interaction with the other variables, and so its interaction plots are entirely grey.

Real data with binary outcomes (ALEPlot Example 3)

The next code example from the {ALEPlot} package analyzes a real dataset with a binary outcome variable. Whereas the {ALEPlot} example has the user load a CSV file that might not be readily available, we make that dataset available in the ale package as the census dataset. We load it here with the adjustments necessary to run the {ALEPlot} example.

## R code for Example 3
## Load relevant packages
library(ALEPlot)
library(gbm, quietly = TRUE)
#> Loaded gbm 2.2.2
#> This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3

## Read data and fit a boosted tree supervised learning model
data(census, package = 'ale')  # load ale package version of the data
data <-  
  census |> 
  as.data.frame() |>   # ALEPlot is not compatible with the tibble format
  select(age:native_country, higher_income) |>  # Rearrange columns to match ALEPlot order
  na.omit()  # drop rows with missing values

Although gradient boosted trees generally perform quite well, they are rather slow to train. Rather than having you wait for it to run, the code here downloads a pretrained GBM model; the code used to generate it is provided in comments so that you can see it and run it yourself if you want to. Note that the model call is based on data[,-c(3,4)], which drops the third and fourth variables (fnlwgt and education, respectively).

# # To regenerate the model, uncomment the following lines.
# # But GBM training is slow, so this vignette loads a pre-created model object.
# set.seed(0)
# gbm.data <- gbm(higher_income ~ ., data= data[,-c(3,4)],
#                 distribution = "bernoulli", n.trees=6000, shrinkage=0.02,
#                 interaction.depth=3)
# saveRDS(gbm.data, file.choose())
gbm.data <- url('https://github.com/tripartio/ale/raw/main/download/gbm.data_model.rds') |> 
  readRDS()

gbm.data
#> gbm(formula = higher_income ~ ., distribution = "bernoulli", 
#>     data = data[, -c(3, 4)], n.trees = 6000, interaction.depth = 3, 
#>     shrinkage = 0.02)
#> A gradient boosted model with bernoulli loss function.
#> 6000 iterations were performed.
#> There were 12 predictors of which 12 had non-zero influence.

ALEPlot code

As before, we create a custom prediction function and then call the {ALEPlot} function to generate the plots. The prediction type here is “link”, which represents the log odds in the gbm package.

Creation of the ALE plots here is rather slow because the gbm predict function is slow. In this example, only age, education_num (number of years of education), and hours_per_week are plotted, along with the interaction between age and hours_per_week.

## Define the predictive function; note the additional arguments for the
## predict function in gbm
yhat <- function(X.model, newdata) as.numeric(predict(X.model, newdata,
n.trees = 6000, type="link"))

## Calculate and plot the ALE main and interaction effects for x_1, x_3,
## x_11, and {x_1, x_11}
par(mfrow = c(2,2), mar = c(4,4,2,2)+ 0.1)
ALE.1=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=1, K=500,
NA.plot = TRUE)
ALE.3=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=3, K=500,
NA.plot = TRUE)
ALE.11=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=11, K=500,
NA.plot = TRUE)
ALE.1and11=ALEPlot(data[,-c(3,4,15)], gbm.data, pred.fun=yhat, J=c(1,11),
K=50, NA.plot = FALSE)

{ale} package equivalent

Here is the analogous code using the ale package. In this case, we also need to define a custom predict function because of the particular n.trees = 6000 argument. To speed things up, we provide a pretrained ale object. This is possible because ale returns objects with data and plots bundled together with no side effects (like automatic printing of created plots). (It is probably possible to similarly cache {ALEPlot} ALE objects, but it is not quite as straightforward.)
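Because ale objects carry no side effects, the compute-once-then-reload pattern used in the commented code that follows can be wrapped in a small helper. cached_rds() and the file name in the usage comment are hypothetical; the helper relies on R's lazy evaluation so that the expensive expression runs only when no saved file exists yet.

```r
# Hypothetical helper: evaluate an expensive expression once, save the result
# to an .rds file, and reload that file on later runs instead of recomputing
cached_rds <- function(path, expr) {
  if (file.exists(path)) {
    return(readRDS(path))   # reuse the saved object; expr is never evaluated
  }
  obj <- expr               # lazy evaluation: the computation happens only here
  saveRDS(obj, path)
  obj
}

# Usage sketch with objects from this vignette (not run here):
# gbm_ale_link <- cached_rds("gbm_ale_link.rds",
#   ale(data, gbm.data, pred_fun = yhat, max_num_bins = 500, sample_size = 600))
```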

Log odds

We display all the plots because it is easy to do so with the ale package, but we focus on age, education_num, and hours_per_week for comparison with ALEPlot. If the shapes of these plots look different from those of ALEPlot, it is because ale tries as much as possible to display plots on the same y-axis coordinate scale for easy comparison across plots.

# Custom predict function that returns log odds
yhat <- function(object, newdata, type) {
    predict(object, newdata, type='link', n.trees = 6000) |>  # return log odds
    as.numeric()
}

# Generate ALE data for all variables

# # To regenerate the ALE object, uncomment the following lines.
# # But it is very slow because it calculates ALE for all variables,
# # so this vignette loads a pre-created model object.
# gbm_ale_link <- ale(
#   # data[,-c(3,4)], gbm.data,
#   data, gbm.data,
#   pred_fun = yhat,
#   max_num_bins = 500,
#   sample_size = 600  # technical issue: sample_size must be > max_num_bins + 1
# )
# saveRDS(gbm_ale_link, file.choose())
gbm_ale_link <- url('https://github.com/tripartio/ale/raw/main/download/gbm_ale_link.rds') |> 
  readRDS()

# Print plots
gbm_link_plots <- plot(gbm_ale_link)
gbm_1D_link_plots <- gbm_link_plots$distinct$higher_income$plots[[1]]
patchwork::wrap_plots(gbm_1D_link_plots, ncol = 2)

Now we generate ALE data for all two-way interactions and then plot them. Again, note the interaction between age and hours_per_week. The interaction is minimal except for the extremely high cases of hours per week.

# # To regenerate the ALE object, uncomment the following lines.
# # But it is very slow because it calculates ALE for all variables,
# # so this vignette loads a pre-created model object.
# gbm_ale_2D_link <- ale(
#   # data[,-c(3,4)], gbm.data,
#   data, gbm.data,
#   complete_d = 2,
#   pred_fun = yhat,
#   max_num_bins = 500,
#   sample_size = 600  # technical issue: sample_size must be > max_num_bins + 1
# )
# saveRDS(gbm_ale_2D_link, file.choose())
gbm_ale_2D_link <- url('https://github.com/tripartio/ale/raw/main/download/gbm_ale_2D_link.rds') |> 
  readRDS()

# Print plots
gbm_link_plots <- plot(gbm_ale_2D_link)
gbm_link_2D_plots <- gbm_link_plots$distinct$higher_income$plots[[2]]
gbm_link_2D_plots |>
  # walk through the list of x1 elements
  purrr::walk(\(it.x1) {  
    # plot all the x2 plots in each it.x1 element
    patchwork::wrap_plots(it.x1, ncol = 2) |> 
      print()
  })

In some of these plots, we can see some white spots. These are interaction zones for which there are no data in the dataset to estimate whether an interaction exists. For example, let’s focus on the interactions of age with education_num:

gbm_link_2D_plots$age$education_num

Here, the grey zones in the majority of the plot indicate that there are minimal interaction effects for most of the data range. However, in the small interacting zone of people younger than about 30 years old with 14 to 16 years of education, we see that the log odds of higher income are around 0.9 lower than average. For the several white zones, there are no data in the dataset to support an estimate. For example, there is no one 35 to 45 years old with 15 years of education and there is no one 49 to 60 years old with 14 years of education; so, the model can say nothing about such interactions.

Predicted probabilities

Log odds are not necessarily the most interpretable way to express probabilities (though we will show shortly that they are sometimes uniquely valuable). So, we repeat the ALE creation using the “response” prediction type for probabilities and the default median centring of the plots.

As we can see, the shapes of the plots are similar, but the y axes are more easily interpretable as the probability (from 0 to 1) that a census respondent is in the higher income category. The median of around 10% or so indicates the median prediction of the GBM model: half of the respondents were predicted to have higher than a 10% likelihood of being higher income and half were predicted to have lower likelihood. The y-axis rug plots indicate that the predictions were generally rather extreme, either relatively close to 0 or 1, with few predictions in the middle.

# Custom predict function that returns predicted probabilities
yhat <- function(object, newdata, type) {
  as.numeric(
    predict(
      object, newdata,  n.trees = 6000, 
      type = "response"  # return predicted probabilities
    )
  )
}

# Generate ALE data for all variables

# # To regenerate the ALE object, uncomment the following lines.
# # But it is slow because it calculates ALE for all variables,
# # so this vignette loads a pre-created model object.
# gbm_ale_prob <- ale(
#   # data[,-c(3,4)], gbm.data,
#   data, gbm.data,
#   pred_fun = yhat,
#   max_num_bins = 500,
#   sample_size = 600  # technical issue: sample_size must be > max_num_bins + 1
# )
# saveRDS(gbm_ale_prob, file.choose())
gbm_ale_prob <- url('https://github.com/tripartio/ale/raw/main/download/gbm_ale_prob.rds') |> 
  readRDS()

# Print plots
gbm_prob_plots <- plot(gbm_ale_prob)
gbm_1D_prob_plots <- gbm_prob_plots$distinct$higher_income$plots[[1]]
patchwork::wrap_plots(gbm_1D_prob_plots, ncol = 2)

Finally, we again generate two-way interactions, this time based on probabilities instead of on log odds. However, probabilities might not be the best choice for indicating interactions because, as we see from the rugs in the one-way ALE plots, the GBM model heavily concentrates its probabilities in the extremes near 0 and 1. Thus, the plots’ suggestions of strong interactions are likely exaggerated. In this case, the log odds ALEs shown above are probably more relevant.
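A small numeric illustration, unrelated to the census data, shows one general reason interaction plots on the probability scale can mislead: effects that are purely additive on the log-odds scale produce a nonzero interaction contrast once they are converted to probabilities by the logistic function. The effect sizes a and b are made up for illustration.

```r
# Additive effects on the log-odds scale, with no interaction by construction
a <- c(low = 0, high = 2)  # hypothetical effect of one feature on the log odds
b <- c(low = 0, high = 2)  # hypothetical effect of another feature on the log odds
log_odds <- outer(a, b, `+`)
prob <- plogis(log_odds)    # convert log odds to probabilities

# Interaction contrast: (high, high) - (high, low) - (low, high) + (low, low)
contrast <- function(m) m[2, 2] - m[2, 1] - m[1, 2] + m[1, 1]
contrast(log_odds)  # exactly 0: no interaction on the log-odds scale
contrast(prob)      # about -0.28: an apparent interaction on the probability scale
```

The apparent probability-scale interaction here comes entirely from the curvature of the logistic function, not from any interaction in the underlying model.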

# # To regenerate the ALE object, uncomment the following lines.
# # But it is slow because it calculates ALE for all variables,
# # so this vignette loads a pre-created model object.
# gbm_ale_2D_prob <- ale(
#   # data[,-c(3,4)], gbm.data,
#   data, gbm.data,
#   complete_d = 2,
#   pred_fun = yhat,
#   max_num_bins = 500,
#   sample_size = 600  # technical issue: sample_size must be > max_num_bins + 1
# )
# saveRDS(gbm_ale_2D_prob, file.choose())
gbm_ale_2D_prob <- url('https://github.com/tripartio/ale/raw/main/download/gbm_ale_2D_prob.rds') |> 
  readRDS()

# Print plots
gbm_prob_plots <- plot(gbm_ale_2D_prob)
gbm_prob_2D_plots <- gbm_prob_plots$distinct$higher_income$plots[[2]]
gbm_prob_2D_plots |>
  # walk through the list of x1 elements
  purrr::walk(\(it.x1) {  
    # plot all the x2 plots in each it.x1 element
    patchwork::wrap_plots(it.x1, ncol = 2) |> 
      print()
  })