
An ALE S7 object contains ALE data and statistics. For details, see vignette('ale-intro') or the details and examples below.

Usage

ALE(
  model,
  x_cols = list(d1 = TRUE),
  data = NULL,
  y_col = NULL,
  ...,
  exclude_cols = NULL,
  parallel = future::availableCores(logical = FALSE, omit = 1),
  model_packages = NULL,
  output_stats = TRUE,
  output_conf = TRUE,
  output_boot_data = FALSE,
  pred_fun = function(object, newdata, type = pred_type) {
    stats::predict(object = object, newdata = newdata, type = type)
  },
  pred_type = "response",
  p_values = NULL,
  p_alpha = c(0.01, 0.05),
  max_num_bins = 10,
  boot_it = 0,
  boot_alpha = 0.05,
  boot_centre = "mean",
  seed = 0,
  y_type = NULL,
  median_band_pct = c(0.05, 0.5),
  sample_size = 500,
  .bins = NULL,
  silent = FALSE
)

Arguments

model

model object. Required. Model for which ALE should be calculated. May be any kind of R object that can make predictions from data.

x_cols, exclude_cols

character, list, or formula. Column names from data, specified in one of the special x_cols formats, for which ALE data is to be calculated. Defaults to 1D ALE for all columns in data except y_col. See details in the documentation for resolve_x_cols().

data

dataframe. Dataset from which to create predictions for the ALE. It should normally be the same dataset on which model was trained. If not provided, ALE() will try to detect it automatically. For non-standard models, data should be provided.

y_col

character(1). Name of the outcome target label (y) variable. If not provided, ALE() will try to detect it automatically. For non-standard models, y_col should be provided. For survival models, set y_col to the name of the binary event column; in that case, pred_type should also be specified.

...

not used. Inserted to require explicit naming of subsequent arguments.

parallel

non-negative integer(1). Number of parallel threads (workers or tasks) for parallel execution of the function. Set parallel = 0 to disable parallel processing. See details.

model_packages

character. Character vector of names of packages that model depends on that might not be obvious with parallel processing. If you get weird error messages when parallel processing is enabled (which is the default) but they are resolved by setting parallel = 0, you might need to specify model_packages. See details.

output_stats

logical(1). If TRUE (default), return ALE statistics.

output_conf

logical(1). If TRUE (default), return ALE confidence regions. If output_stats is FALSE, output_conf is ignored since confidence regions cannot be produced without ALE statistics.

output_boot_data

logical(1). If TRUE, return the raw ALE data for each bootstrap iteration. Default is FALSE.

pred_fun, pred_type

function, character(1). pred_fun is a function that returns a vector of predicted values of type pred_type from model on data. See details.

p_values

instructions for calculating p-values and to determine the median band. If NULL (default), no p-values are calculated and median_band_pct is used to determine the median band. To calculate p-values, an ALEpDist() object must be provided here. If p_values is set to 'auto', this ALE() function will try to automatically create the p-values distribution; this only works with standard R model types. An error message will be given if p-values cannot be generated. Any other input provided to this argument will result in an error. For more details about creating p-values, see documentation for ALEpDist(). Note that p-values will not be generated if output_stats is FALSE.

p_alpha

numeric length 2 from 0 to 1. Alpha for "confidence interval" ranges for printing bands around the median for single-variable plots. These are the default values used if p_values are provided. If p_values are not provided, then median_band_pct is used instead. The inner band range will be the median value of y ± p_alpha[2] of the relevant ALE statistic (usually ALE range or normalized ALE range). For plots with a second outer band, its range will be the median ± p_alpha[1]. For example, in the ALE plots, for the default p_alpha = c(0.01, 0.05), the inner band will be the median ± ALE minimum or maximum at p = 0.05 and the outer band will be the median ± ALE minimum or maximum at p = 0.01.

max_num_bins

positive integer(1). Maximum number of bins for numeric x_cols variables. The final number of bins is the lower of the number of unique values of the numeric variable and max_num_bins.
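As a rough illustration of this rule (not the package's actual binning code), the upper limit on the bin count can be sketched in base R. The helper name bin_limit() is hypothetical, and the built-in mtcars dataset is used only for demonstration.

```r
# Illustration only: upper limit on the number of bins for a numeric variable.
# bin_limit() is a hypothetical helper, not part of the {ale} package.
bin_limit <- function(x, max_num_bins = 10) {
  min(length(unique(x)), max_num_bins)
}

bin_limit(mtcars$cyl)  # few unique values: limited by the data
bin_limit(mtcars$mpg)  # many unique values: limited by max_num_bins
```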

boot_it

non-negative integer(1). Number of bootstrap iterations for data-only bootstrapping on ALE data. This is appropriate for models that have been developed with cross-validation. For models that have not been validated, full-model bootstrapping should be used instead with the ModelBoot class. See details there. The default boot_it = 0 turns off bootstrapping.

boot_alpha

numeric(1) from 0 to 1. Alpha for percentile-based confidence interval range for the bootstrap intervals; the bootstrap confidence intervals will run from the boot_alpha / 2 quantile to the 1 - boot_alpha / 2 quantile. For example, if boot_alpha = 0.05 (default), the intervals will be from the 2.5 and 97.5 percentiles.
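The boot_alpha = 0.05 example can be reproduced with a minimal base R sketch of a percentile interval. The simulated boot_estimates vector is a hypothetical stand-in for bootstrapped ALE values; this is not the package's internal code.

```r
# Illustration only: percentile-based bootstrap interval in base R.
set.seed(0)
# Hypothetical bootstrap estimates standing in for bootstrapped ALE values
boot_estimates <- stats::rnorm(1000, mean = 5, sd = 1)
boot_alpha <- 0.05

# Lower and upper bounds at the boot_alpha / 2 and 1 - boot_alpha / 2 quantiles
boot_ci <- stats::quantile(
  boot_estimates,
  probs = c(boot_alpha / 2, 1 - boot_alpha / 2)
)
boot_ci  # elements are named "2.5%" and "97.5%"
```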

boot_centre

character(1) in c('mean', 'median'). When bootstrapping, the main estimate for the ALE y value is considered to be boot_centre. Regardless of the value specified here, both the mean and median will be available.

seed

integer(1). Random seed. Supply the same seed between runs to ensure that identical random ALE data is generated each time when bootstrapping. Without bootstrapping, ALE is a deterministic algorithm that should produce identical results each time regardless of the seed specified.

y_type

character(1) in c('binary', 'numeric', 'categorical', 'ordinal'). Datatype of the y (outcome) variable. Normally determined automatically; only provide if an error message for a complex non-standard model requires it.

median_band_pct

numeric length 2 from 0 to 1. Alpha for "confidence interval" ranges for printing bands around the median for single-variable plots. These are the default values used if p_values are not provided. If p_values are provided, then median_band_pct is ignored. The inner band range will be the median value of y ± median_band_pct[1]/2. For plots with a second outer band, its range will be the median ± median_band_pct[2]/2. For example, for the default median_band_pct = c(0.05, 0.5), the inner band will be the median ± 2.5% and the outer band will be the median ± 25%.
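One way to read the default bands is as quantiles centred on the 50th percentile. The sketch below assumes that "median ± 2.5%" refers to percentile ranks (the 47.5th to 52.5th percentiles) rather than to a fraction of the median value; the variable names and the mtcars data are illustrative only.

```r
# Illustration only: median bands read as quantiles around the 50th percentile.
y <- mtcars$mpg  # hypothetical outcome values
median_band_pct <- c(0.05, 0.5)

# Inner band: 47.5th to 52.5th percentiles for the default values
inner_band <- stats::quantile(y, probs = 0.5 + c(-1, 1) * median_band_pct[1] / 2)
# Outer band: 25th to 75th percentiles for the default values
outer_band <- stats::quantile(y, probs = 0.5 + c(-1, 1) * median_band_pct[2] / 2)

inner_band
outer_band
```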

sample_size

non-negative integer(1). Size of the sample of data to be returned with the ALE object. This is primarily used for rug plots. See the min_rug_per_interval argument.

.bins

Internal. List of ALE bin and n count vectors. If provided, these vectors will be used to set the intervals of the ALE x axis for each variable. By default (NULL), ALE() automatically calculates the bins. .bins is normally used in advanced analyses where the bins from a previous analysis are reused for subsequent analyses (for example, for full model bootstrapping with ModelBoot()).

silent

logical(1), default FALSE. If TRUE, do not display any non-essential messages during execution (such as progress bars). Regardless, any warnings and errors will always display. See details for how to enable progress bars.

Value

An object of class ALE with properties distinct and params.

Properties

distinct

Stores the optional ALE data, ALE statistics, and bootstrap data for one or more categories.

params

The parameters used to calculate the ALE data. These include most of the arguments used to construct the ALE object, stored either as the values provided by the user or as the defaults if the user did not change them. params also includes several objects that are created within the constructor. These extra objects are described here, as well as those parameters that are stored differently from the form in the arguments:

* `max_d`: the highest dimension of ALE data present. If only 1D ALE is present, then `max_d == 1`. If even one 2D ALE element is present (even with no 1D), then `max_d == 2`.
* `requested_x_cols`,`ordered_x_cols`: `requested_x_cols` is the resolved list of `x_cols` as requested by the user (that is, `x_cols` minus `exclude_cols`). `ordered_x_cols` is the same set of `x_cols` but arranged in the internal storage order.
* `y_cats`: categories for categorical classification models. For non-categorical models, this is the same as `y_col`.
* `y_type`: high-level datatype of the y outcome variable.
* `y_summary`: summary statistics of y values used for the ALE calculation. These statistics are based on the actual values of `y_col` unless `y_type` is a probability or other value that is constrained to the `[0, 1]` range. In that case, `y_summary` is based on the predicted values of `y_col` obtained by applying `model` to the `data`. `y_summary` is a named numeric vector. Most of the elements are percentiles of the y values. E.g., the '5%' element is the 5th percentile of y values. The following elements have special meanings:
  * The first element is named either `p` or `q` and its value is always 0.
    The value is not used; only the name of the element is meaningful.
    `p` means that the following special `y_summary` elements are based on
    the provided `ALEpDist` object. `q` means that quantiles were calculated
    based on `median_band_pct` because `p_values` was not provided.
  * `min`, `mean`, `max`: the minimum, mean, and maximum y values, respectively. Note that the median is `50%`, the 50th percentile.
  * `med_lo_2`, `med_lo`, `med_hi`, `med_hi_2`: `med_lo` and `med_hi` are the inner lower and upper confidence intervals of y values with respect to the median (`50%`); `med_lo_2` and `med_hi_2` are the outer confidence intervals. See the documentation for the `p_alpha` and `median_band_pct` arguments to understand how these are determined.
* `model`: same as `ALE@params$model` (see documentation there).
* `data`: same as `ALE@params$data` (see documentation there).

Custom predict function

The calculation of ALE requires modifying several values of the original data. Thus, ALE() needs direct access to the predict function for the model. By default, ALE() uses a generic default predict function of the form predict(object, newdata, type) with the default prediction type of 'response'. If, however, the desired prediction values are not generated with that format, the user must specify what they want. Very often, the only modification needed is to change the prediction type to some other value by setting the pred_type argument (e.g., to 'prob' to generate classification probabilities). But if the desired predictions need a different function signature, then the user must create a custom prediction function and pass it to pred_fun. The requirements for this custom function are:

  • It must take three required arguments and nothing else:

    • object: a model

    • newdata: a dataframe or compatible table type such as a tibble or data.table

    • type: a string; it should usually be specified as type = pred_type

    These argument names follow the R convention for the generic stats::predict() function.

  • It must return a vector or matrix of numeric values as the prediction.

You can see an example below of a custom prediction function.

Note: survival models probably do not need a custom prediction function but y_col must be set to the name of the binary event column and pred_type must be set to the desired prediction type.
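The requirements above can be checked outside of ALE() with a minimal base R sketch. The logistic regression on the built-in mtcars data is purely illustrative. Because pred_type is not defined outside of ALE(), this standalone version uses a literal default for type; when passing such a function to pred_fun, the default would normally be written as type = pred_type.

```r
# Illustration only: a custom prediction function tested on its own.
# Fit a simple logistic regression on a built-in dataset
model <- stats::glm(am ~ wt + hp, data = mtcars, family = stats::binomial())

# Takes exactly the three required arguments: object, newdata, type
custom_predict <- function(object, newdata, type = "response") {
  # With se.fit = TRUE, predict() returns a list; keep only the fitted values
  stats::predict(object, newdata = newdata, type = type, se.fit = TRUE)$fit
}

# The result should be a numeric vector with one value per row of newdata
preds <- custom_predict(model, mtcars)
is.numeric(preds)
length(preds) == nrow(mtcars)
```

Testing the function this way before passing it to pred_fun makes it easier to separate problems in the prediction function from problems in the ALE calculation.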

ALE statistics

For details about the ALE-based statistics (ALED, ALER, NALED, and NALER), see vignette('ale-statistics').

Parallel processing

Parallel processing using the {furrr} framework is enabled by default. By default, it will use all the available physical CPU cores (minus the core being used for the current R session) with the setting parallel = future::availableCores(logical = FALSE, omit = 1). Note that only physical cores are used (not logical cores or "hyperthreading") because machine learning can only take advantage of the floating point processors on physical cores, which are absent from logical cores. Trying to use logical cores will not speed up processing and might actually slow it down with useless data transfer. If you will dedicate the entire computer to running this function (and you don't mind everything else becoming very slow while it runs), you may use all cores by setting parallel = future::availableCores(logical = FALSE). To disable parallel processing, set parallel = 0.

The {ale} package should be able to automatically recognize and load most packages that are needed, but with parallel processing enabled (which is the default), some packages might not be properly loaded. This problem might be indicated if you get a strange error message that mentions something somewhere about "progress interrupted" or "future", especially if you see such errors after the progress bars begin displaying (assuming you did not disable progress bars with silent = TRUE). In that case, first try disabling parallel processing with parallel = 0. If that resolves the problem, then to get faster parallel processing to work, try adding the package names needed for the model to the model_packages argument, e.g., model_packages = c('tidymodels', 'mgcv').

Progress bars

Progress bars are implemented with the {progressr} package, which lets the user fully control progress bars. To disable progress bars, set silent = TRUE. The first time a function is called in the {ale} package that requires progress bars, it checks if the user has activated the necessary {progressr} settings. If not, the {ale} package automatically enables {progressr} progress bars with the cli handler and prints a message notifying the user.

If you like the default progress bars and you want to make them permanent, you can add the following lines of code to your .Rprofile configuration file and they will become your defaults for every R session; you will not see the message again:

progressr::handlers(global = TRUE)
progressr::handlers('cli')

This would apply not only to the {ale} package but to any package that uses {progressr} progress bars. For more details on formatting progress bars to your liking, see the introduction to the {progressr} package.

References

Okoli, Chitu. 2023. “Statistical Inference Using Machine Learning and Classical Techniques Based on Accumulated Local Effects (ALE).” arXiv. https://arxiv.org/abs/2310.09877.

Examples

# Sample 1000 rows from the ggplot2::diamonds dataset (for a simple example)
set.seed(0)
diamonds_sample <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 1000), ]

# Create a GAM model with flexible curves to predict diamond price
# Smooth all numeric variables and include all other variables
gam_diamonds <- mgcv::gam(
  price ~ s(carat) + s(depth) + s(table) + s(x) + s(y) + s(z) +
    cut + color + clarity,
  data = diamonds_sample
)
summary(gam_diamonds)
#> 
#> Family: gaussian 
#> Link function: identity 
#> 
#> Formula:
#> price ~ s(carat) + s(depth) + s(table) + s(x) + s(y) + s(z) + 
#>     cut + color + clarity
#> 
#> Parametric coefficients:
#>              Estimate Std. Error t value Pr(>|t|)    
#> (Intercept)  3421.412     74.903  45.678  < 2e-16 ***
#> cut.L         261.339    171.630   1.523 0.128170    
#> cut.Q          53.684    129.990   0.413 0.679710    
#> cut.C         -71.942    103.804  -0.693 0.488447    
#> cut^4          -8.657     80.614  -0.107 0.914506    
#> color.L     -1778.903    113.669 -15.650  < 2e-16 ***
#> color.Q      -482.225    104.675  -4.607 4.64e-06 ***
#> color.C        58.724     95.983   0.612 0.540807    
#> color^4       125.640     87.111   1.442 0.149548    
#> color^5      -241.194     81.913  -2.945 0.003314 ** 
#> color^6       -49.305     74.435  -0.662 0.507883    
#> clarity.L    4141.841    226.713  18.269  < 2e-16 ***
#> clarity.Q   -2367.820    217.185 -10.902  < 2e-16 ***
#> clarity.C    1026.214    180.295   5.692 1.67e-08 ***
#> clarity^4    -602.066    137.258  -4.386 1.28e-05 ***
#> clarity^5     408.336    105.344   3.876 0.000113 ***
#> clarity^6     -82.379     88.434  -0.932 0.351815    
#> clarity^7       4.017     78.816   0.051 0.959362    
#> ---
#> Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
#> 
#> Approximate significance of smooth terms:
#>            edf Ref.df      F  p-value    
#> s(carat) 7.503  8.536  4.114 3.65e-05 ***
#> s(depth) 1.486  1.874  0.601 0.614753    
#> s(table) 2.929  3.738  1.294 0.240011    
#> s(x)     8.897  8.967  3.323 0.000542 ***
#> s(y)     3.875  5.118 11.075  < 2e-16 ***
#> s(z)     9.000  9.000  2.648 0.004938 ** 
#> ---
#> Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
#> 
#> R-sq.(adj) =   0.94   Deviance explained = 94.3%
#> GCV = 9.7669e+05  Scale est. = 9.262e+05  n = 1000


# \donttest{

# Simple ALE without bootstrapping
# Pass data explicitly so that ALE() does not have to detect it from the model
ale_gam_diamonds <- ALE(gam_diamonds, data = diamonds_sample)

# Plot the ALE data
plot(ale_gam_diamonds)

# Bootstrapped ALE
# This can be slow, since bootstrapping runs the algorithm boot_it times

# Create ALE with 100 bootstrap samples
ale_gam_diamonds_boot <- ALE(
  gam_diamonds,
  data = diamonds_sample,  # pass data explicitly rather than relying on detection
  boot_it = 100
)

# Bootstrapped ALEs print with confidence intervals
plot(ale_gam_diamonds_boot)


# If the predict function you want is non-standard, you may define a
# custom predict function. It must return a single numeric vector.
custom_predict <- function(object, newdata, type = pred_type) {
  predict(object, newdata, type = type, se.fit = TRUE)$fit
}

ale_gam_diamonds_custom <- ALE(
  gam_diamonds,
  data = diamonds_sample,  # pass data explicitly rather than relying on detection
  pred_fun = custom_predict, pred_type = 'link'
)

# Plot the ALE data
plot(ale_gam_diamonds_custom)


# How to retrieve specific types of ALE data from an ALE object.
ale_diamonds_with__boot_data <- ALE(
  gam_diamonds,
  data = diamonds_sample,  # pass data explicitly rather than relying on detection
  # For detailed options for x_cols, see examples at resolve_x_cols()
  x_cols = ~ carat + cut + clarity + color:depth + x:y,
  output_boot_data = TRUE,
  boot_it = 10  # just for demonstration
)

# See ?get.ALE for details on the various kinds of data that may be retrieved.
get(ale_diamonds_with__boot_data, ~ carat + color:depth)  # default ALE data
get(ale_diamonds_with__boot_data, what = 'boot_data')
get(ale_diamonds_with__boot_data, stats = 'estimate')
get(ale_diamonds_with__boot_data, stats = 'aled')
get(ale_diamonds_with__boot_data, stats = 'all')
get(ale_diamonds_with__boot_data, stats = 'conf_regions')
get(ale_diamonds_with__boot_data, stats = 'conf_sig')
# }