
ale() is the central function that manages the creation of ALE data and plots for one-way ALE. For two-way interactions, see ale_ixn(). Internally, ale() calls ale_core (a non-exported function), which handles the detailed creation of the ALE data and plots. For details, see the introductory vignette for this package or the details and examples below.

Usage

ale(
  data,
  model,
  x_cols = NULL,
  y_col = NULL,
  ...,
  parallel = parallel::detectCores(logical = FALSE) - 1,
  model_packages = as.character(NA),
  output = c("plots", "data", "stats", "conf_regions"),
  pred_fun = function(object, newdata, type = pred_type) {
    stats::predict(object = object, newdata = newdata, type = type)
  },
  pred_type = "response",
  p_values = NULL,
  p_alpha = c(0.01, 0.05),
  x_intervals = 100,
  boot_it = 0,
  seed = 0,
  boot_alpha = 0.05,
  boot_centre = "mean",
  relative_y = "median",
  y_type = NULL,
  median_band_pct = c(0.05, 0.5),
  rug_sample_size = 500,
  min_rug_per_interval = 1,
  ale_xs = NULL,
  ale_ns = NULL,
  compact_plots = FALSE,
  silent = FALSE
)

Arguments

data

dataframe. Dataset from which to create predictions for the ALE.

model

model object. Model for which ALE should be calculated. May be any kind of R object that can make predictions from data.

x_cols

character. Vector of column names from data for which one-way ALE data is to be calculated (that is, simple ALE without interactions). If not provided, ALE will be created for all columns in data except y_col.

y_col

character length 1. Name of the outcome target label (y) variable. If not provided, ale() will try to detect it automatically. For non-standard models, y_col should be provided. For survival models, set y_col to the name of the binary event column; in that case, pred_type should also be specified.

...

not used. Inserted to require explicit naming of subsequent arguments.

parallel

non-negative integer length 1. Number of parallel threads (workers or tasks) for parallel execution of the function. See details.

model_packages

character. Character vector of names of packages that model depends on that might not be obvious. The {ale} package should be able to automatically recognize and load most packages that are needed, but with parallel processing enabled (which is the default), some packages might not be properly loaded. If you get an unexpected error message that mentions 'future', try adding the package for your model to this vector, especially if such errors appear after the progress bars begin displaying (assuming you did not disable progress bars with silent = TRUE).

output

character in c('plots', 'data', 'stats', 'conf_regions'). Vector of types of results to return. 'plots' will return ALE plots; 'data' will return the source ALE data; 'stats' will return ALE statistics; 'conf_regions' will return summaries of the confidence regions of the ALE statistics. Each option must be listed to return the specified component. By default, all are returned.

pred_fun, pred_type

function, character length 1. pred_fun is a function that returns a vector of predicted values of type pred_type from model on data. See details.

p_values

Instructions for calculating p-values and determining the median band. If NULL (default), no p-values are calculated and median_band_pct is used to determine the median band. To calculate p-values, an object generated by the create_p_funs() function must be provided here. If p_values is set to 'auto', ale() will try to create the p-values function automatically; this only works with standard R model types. An error message will be given if p-values cannot be generated. Any other input provided to this argument will result in an error. For more details about creating p-values, see the documentation for create_p_funs(). Note that p-values will not be generated if 'stats' is not included as an option in the output argument.

p_alpha

numeric length 2 from 0 to 1. Alpha for "confidence interval" ranges for printing bands around the median for single-variable plots. These are the default values used if p_values are provided. If p_values are not provided, then median_band_pct is used instead. The inner band range will be the median value of y ± p_alpha[2] of the relevant ALE statistic (usually ALE range or normalized ALE range). For plots with a second outer band, its range will be the median ± p_alpha[1]. For example, in the ALE plots, for the default p_alpha = c(0.01, 0.05), the inner band will be the median ± ALE minimum or maximum at p = 0.05 and the outer band will be the median ± ALE minimum or maximum at p = 0.01.

x_intervals

positive integer length 1. Maximum number of intervals on the x-axis for the ALE data for each column in x_cols. The number of intervals that the algorithm generates might be fewer than what the user specifies if the data values for a given x variable do not support that many intervals.
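The shrinking of intervals can be illustrated with a short base-R sketch. It assumes that interval boundaries are type-1 quantiles of x with duplicated boundaries dropped, as in the reference implementation ALEPlot::ALEPlot(); whether {ale} builds its intervals exactly this way is an assumption, and ale_breaks() is a hypothetical helper, not part of the package.

```r
# Hypothetical helper (not exported by {ale}): interval boundaries for one
# x variable, built from type-1 quantiles as in ALEPlot. Duplicated
# boundaries are dropped, so discrete variables yield fewer intervals.
ale_breaks <- function(x, x_intervals = 100) {
  probs <- seq(1 / x_intervals, 1, length.out = x_intervals)
  unique(c(min(x), as.numeric(stats::quantile(x, probs = probs, type = 1))))
}

# mtcars$cyl takes only the values 4, 6 and 8, so 100 requested
# intervals collapse to 2 (from 4 to 6 and from 6 to 8)
n_cyl <- length(ale_breaks(mtcars$cyl, x_intervals = 100)) - 1
n_cyl
#> [1] 2

# mtcars$wt is continuous but has only 32 rows, so it also ends up with
# far fewer than 100 intervals
n_wt <- length(ale_breaks(mtcars$wt, x_intervals = 100)) - 1
```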

boot_it

non-negative integer length 1. Number of bootstrap iterations for the ALE values. If boot_it = 0 (default), then ALE will be calculated on the entire dataset with no bootstrapping.

seed

integer length 1. Random seed. Supply the same seed between runs to ensure that identical random ALE data is generated each time.

boot_alpha

numeric length 1 from 0 to 1. Alpha for the percentile-based confidence interval range of the bootstrap intervals; the bootstrap confidence intervals will run from the boot_alpha / 2 percentile to the 1 - boot_alpha / 2 percentile. For example, if boot_alpha = 0.05 (default), the intervals will run from the 2.5 to the 97.5 percentiles.

boot_centre

character length 1 in c('mean', 'median'). When bootstrapping, the main estimate for ale_y is considered to be boot_centre. Regardless of the value specified here, both the mean and median will be available.
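To make boot_alpha and boot_centre concrete, here is a minimal base-R sketch of the percentile method. It assumes that the bootstrapped ale_y values for one variable are arranged as an iterations-by-intervals matrix; the simulated matrix is purely illustrative, and the code is not the package's internal implementation.

```r
set.seed(0)
boot_it <- 100      # number of bootstrap iterations
n_intervals <- 5    # number of ALE x intervals (illustrative)

# Simulated stand-in for bootstrapped ale_y values:
# one row per bootstrap iteration, one column per ALE x interval
boot_ale_y <- matrix(
  rnorm(boot_it * n_intervals, mean = rep(seq_len(n_intervals), each = boot_it)),
  nrow = boot_it
)

boot_alpha <- 0.05
boot_centre <- "mean"

# Percentile interval: the boot_alpha / 2 and 1 - boot_alpha / 2 quantiles
ale_summary <- data.frame(
  ale_y_mean   = colMeans(boot_ale_y),
  ale_y_median = apply(boot_ale_y, 2, stats::median),
  ale_y_lo     = apply(boot_ale_y, 2, stats::quantile, probs = boot_alpha / 2),
  ale_y_hi     = apply(boot_ale_y, 2, stats::quantile, probs = 1 - boot_alpha / 2)
)

# The main estimate follows boot_centre; both centres remain available
ale_summary$ale_y <- if (boot_centre == "mean") {
  ale_summary$ale_y_mean
} else {
  ale_summary$ale_y_median
}
```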

relative_y

character length 1 in c('median', 'mean', 'zero'). The ale_y values will be adjusted relative to this value. 'median' is the default. 'zero' will maintain the default of ALEPlot::ALEPlot(), which is not shifted.

y_type

character length 1. Datatype of the y (outcome) variable. Must be one of c('binary', 'numeric', 'multinomial', 'ordinal'). Normally determined automatically; only provide for complex non-standard models that require it.

median_band_pct

numeric length 2 from 0 to 1. Alpha for "confidence interval" ranges for printing bands around the median for single-variable plots. These are the default values used if p_values are not provided. If p_values are provided, then median_band_pct is ignored. The inner band range will be the median value of y ± median_band_pct[1]/2. For plots with a second outer band, its range will be the median ± median_band_pct[2]/2. For example, for the default median_band_pct = c(0.05, 0.5), the inner band will be the median ± 2.5% and the outer band will be the median ± 25%.
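The arithmetic for the default median_band_pct = c(0.05, 0.5) can be sketched in base R. The sketch assumes that "± 2.5%" and "± 25%" are measured in percentile points around the 50th percentile of y, which matches the y_summary note that quantiles are calculated from median_band_pct. The element names mirror those documented for y_summary, and mtcars$mpg is only a stand-in outcome.

```r
median_band_pct <- c(0.05, 0.5)

# Percentiles bounding the inner (median_band_pct[1]) and outer
# (median_band_pct[2]) bands around the median
band_probs <- c(
  med_lo_2 = 0.5 - median_band_pct[2] / 2,  # 0.25
  med_lo   = 0.5 - median_band_pct[1] / 2,  # 0.475
  med_hi   = 0.5 + median_band_pct[1] / 2,  # 0.525
  med_hi_2 = 0.5 + median_band_pct[2] / 2   # 0.75
)

# Corresponding y values for an illustrative outcome variable
y <- mtcars$mpg
bands <- stats::quantile(y, probs = band_probs)
names(bands) <- names(band_probs)
```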

rug_sample_size, min_rug_per_interval

non-negative integer length 1 each. Rug plots are normally down-sampled; otherwise they are too slow. rug_sample_size specifies the size of this sample. To prevent down-sampling, set it to Inf. To suppress rug plots, set it to 0. When down-sampling, the rug plots maintain the representativeness of the data by guaranteeing that each of the x_intervals intervals will retain at least min_rug_per_interval elements; this is usually set to just 1 or 2.
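The down-sampling guarantee described above can be sketched as stratified sampling: first keep min_rug_per_interval rows from every non-empty interval, then fill the rest of the quota at random. This is a minimal sketch assuming quantile-based intervals; rug_sample_idx() is a hypothetical helper, not the package's implementation.

```r
# Hypothetical helper, not part of {ale}: returns the row indices of x
# to keep for a rug plot
rug_sample_idx <- function(x, rug_sample_size = 500, min_rug_per_interval = 1,
                           x_intervals = 100) {
  n <- length(x)
  if (rug_sample_size == 0) return(integer(0))   # rug plot suppressed
  if (rug_sample_size >= n) return(seq_len(n))   # no down-sampling (e.g. Inf)
  # Assumed quantile-based intervals; duplicated boundaries are dropped
  breaks <- unique(stats::quantile(x, probs = seq(0, 1, length.out = x_intervals + 1)))
  interval <- cut(x, breaks = breaks, include.lowest = TRUE)
  # Keep at least min_rug_per_interval rows from every non-empty interval
  by_interval <- split(seq_len(n), interval, drop = TRUE)
  keep <- unlist(lapply(by_interval, function(idx) {
    idx[sample.int(length(idx), min(length(idx), min_rug_per_interval))]
  }), use.names = FALSE)
  # Fill the rest of the quota with a simple random sample of the other rows
  others <- setdiff(seq_len(n), keep)
  n_extra <- min(max(rug_sample_size - length(keep), 0), length(others))
  sort(c(keep, others[sample.int(length(others), n_extra)]))
}

set.seed(0)
x <- rexp(10000)
idx <- rug_sample_idx(x, rug_sample_size = 500, min_rug_per_interval = 1)
length(idx)
#> [1] 500
```

Sampling the guaranteed rows first means that sparse regions of x (here, the long right tail of the exponential sample) still appear in the rug even when a plain random sample would likely miss them.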

ale_xs, ale_ns

list of ale_x and ale_n vectors. If provided, these vectors will be used to set the intervals of the ALE x axis for each variable. By default (NULL), the function automatically calculates the ale_x intervals. ale_xs is normally used in advanced analyses where the ale_x intervals from a previous analysis are reused for subsequent analyses (for example, for full model bootstrapping; see the model_bootstrap() function).

compact_plots

logical length 1, default FALSE. When output includes 'plots', the returned ggplot objects each include the environments of the plots. This lets the user modify the plots with all the flexibility of ggplot, but it can result in very large return objects (sometimes hundreds of megabytes). To compact the plots to their bare minimum, set compact_plots = TRUE. However, the returned plots will then not be easily modifiable, so this should only be used if you do not intend to modify the plots afterwards.
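Why an object that carries its environment can grow so large can be shown with a base-R analogy: a closure keeps its enclosing environment alive, and serializing the closure (as saving it to disk does) writes that environment out too. This illustrates only the general mechanism; it is not how ggplot2 or {ale} store plot data, and make_plot_fn() is a hypothetical name.

```r
# Hypothetical factory: the returned function references big_data, so
# big_data stays alive in the function's enclosing environment
make_plot_fn <- function(n) {
  big_data <- rnorm(n)
  function() summary(big_data)
}

small_fn <- make_plot_fn(10)
large_fn <- make_plot_fn(1e6)

# serialize() also writes each closure's enclosing environment, so
# large_fn takes about 8 MB (1e6 doubles) while small_fn stays tiny
size_small <- length(serialize(small_fn, NULL))
size_large <- length(serialize(large_fn, NULL))
```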

silent

logical length 1, default FALSE. If TRUE, do not display any non-essential messages during execution (such as progress bars). Regardless, any warnings and errors will always display. See details for how to enable progress bars.

Value

list with the following elements:

  • data: a list whose elements, named by each requested x variable, are each a tibble with the following columns:

    • ale_x: the values of each of the ALE x intervals or categories.

    • ale_n: the number of rows of data in each ale_x interval or category.

    • ale_y: the ALE function value calculated for that interval or category. For bootstrapped ALE, this is the same as ale_y_mean by default or ale_y_median if the boot_centre = 'median' argument is specified. Regardless, both ale_y_mean and ale_y_median are returned as columns here.

    • ale_y_lo, ale_y_hi: the lower and upper bounds, respectively, of the confidence interval for the bootstrapped ale_y value. Note: regardless of what options are requested in the output argument, this data element is always returned.

  • stats: if stats are requested in the output argument (as is the default), returns a list. If not requested, returns NULL. The returned list provides ALE statistics of the data element duplicated and presented from various perspectives in the following elements:

    • by_term: a list named by each requested x variable, each of whose elements is a tibble with the following columns:

      • statistic: the ALE statistic specified in the row (see the by_statistic element below).

      • estimate: the bootstrapped mean or median of the statistic, depending on the boot_centre argument to the ale() function. Regardless, both mean and median are returned as columns here.

      • conf.low, conf.high: the lower and upper bounds, respectively, of the confidence interval for the bootstrapped estimate.

    • by_statistic: list named by each of the following ALE statistics: aled, aler_min, aler_max, naled, naler_min, naler_max. See vignette('ale-statistics') for details.

    • estimate: a tibble whose data consists of the estimate values from the by_term element above. The columns are term (the variable name) and the statistic for which the estimate is given: aled, aler_min, aler_max, naled, naler_min, naler_max.

    • effects_plot: a ggplot object which is the ALE effects plot for all the x variables.

  • plots: if plots are requested in the output argument (as is the default), returns a list whose elements, named by each requested x variable, are each a ggplot object of the ALE y values plotted against the x variable intervals. If plots is not included in output, this element is NULL.

  • conf_regions: if conf_regions are requested in the output argument (as is the default), returns a list. If not requested, returns NULL. The returned list provides summaries of the confidence regions of the relevant ALE statistics of the data element. The list has the following elements:

    • by_term: a list named by each requested x variable, each of whose elements is a tibble with the relevant data for the confidence regions. (See vignette('ale-statistics') for details about confidence regions.)

    • significant: a tibble that summarizes by_term to show only the confidence regions that are statistically significant. Its columns are those from by_term plus a term column that specifies which x variable each row refers to.

    • sig_criterion: a length-one character vector that reports which values were used to determine statistical significance: if p_values was provided to the ale() function, it will be used; otherwise, median_band_pct will be used.

  • Various values echoed from the original call to the ale() function, provided to document the key elements used to calculate the ALE data, statistics, and plots: y_col, x_cols, boot_it, seed, boot_alpha, boot_centre, relative_y, y_type, median_band_pct, rug_sample_size. These are either the values provided by the user or used by default if the user did not change them.

  • y_summary: summary statistics of the y values used for the ALE calculation. These statistics are based on the actual values of y_col unless y_type is a probability or other value that is constrained to the [0, 1] range. In that case, y_summary is based on the predicted values of y_col obtained by applying model to the data. y_summary is a named numeric vector. Most of the elements are percentiles of the y values; e.g., the '5%' element is the 5th percentile of y values. The following elements have special meanings:

    • The first element is named either p or q and its value is always 0. The value is not used; only the name of the element is meaningful. p means that the following special y_summary elements are based on the provided p_values object. q means that quantiles were calculated based on median_band_pct because p_values was not provided.

    • min, mean, max: the minimum, mean, and maximum y values, respectively. Note that the median is 50%, the 50th percentile.

    • med_lo_2, med_lo, med_hi, med_hi_2: med_lo and med_hi are the inner lower and upper confidence intervals of y values with respect to the median (50%); med_lo_2 and med_hi_2 are the outer confidence intervals. See the documentation for the p_alpha and median_band_pct arguments to understand how these are determined.

Details

Custom predict function

The calculation of ALE requires modifying several values of the original data. Thus, ale() needs direct access to a predict function that works on model. By default, ale() uses a generic default predict function of the form predict(object, newdata, type) with the default prediction type of 'response'. If, however, the desired prediction values are not generated with that format, the user must specify what they want. Most of the time, the only modification needed is to change the prediction type to some other value by setting the pred_type argument (e.g., to 'prob' to generate classification probabilities). But if the desired predictions need a different function signature, then the user must create a custom prediction function and pass it to pred_fun. The requirements for this custom function are:

  • It must take three required arguments and nothing else; the argument names follow the R convention for the generic stats::predict function:

    • object: a model

    • newdata: a dataframe or compatible table type

    • type: a string; it should usually be specified as type = pred_type.

  • It must return a vector of numeric values as the prediction.

You can see an example below of a custom prediction function.
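As a standalone sketch of these requirements (independent of the GAM example in the Examples section), consider a linear model whose predictions are requested together with confidence limits: stats::predict() then returns a matrix rather than a vector, so a custom pred_fun must extract the point predictions. The function name lm_fit_only is hypothetical, and the commented ale() call assumes the signature documented on this page.

```r
lm_mpg <- stats::lm(mpg ~ wt + hp, data = mtcars)

# Custom prediction function: exactly the three required arguments,
# returning a plain numeric vector
lm_fit_only <- function(object, newdata, type = pred_type) {
  # With interval = "confidence", predict.lm() returns a matrix with the
  # columns fit, lwr and upr; keep only the fit column
  preds <- stats::predict(object, newdata = newdata, type = type,
                          interval = "confidence")
  as.numeric(preds[, "fit"])
}

# pred_type is not defined outside ale(), so type is passed explicitly here
p <- lm_fit_only(lm_mpg, newdata = mtcars, type = "response")

# It could then be supplied to ale(), for example:
# ale(mtcars, lm_mpg, pred_fun = lm_fit_only, pred_type = "response")
```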

Note: survival models probably do not need a custom prediction function, but y_col must be set to the name of the binary event column and pred_type must be set to the desired prediction type.

ALE statistics

For details about the ALE-based statistics (ALED, ALER, NALED, and NALER), see vignette('ale-statistics').

Parallel processing

Parallel processing using the {furrr} library is enabled by default. By default, it will use all the available physical CPU cores (minus the core being used for the current R session) with the setting parallel = parallel::detectCores(logical = FALSE) - 1. Note that only physical cores are used (not logical cores or "hyperthreading") because machine learning can only take advantage of the floating point processors on physical cores, which are absent from logical cores. Trying to use logical cores will not speed up processing and might actually slow it down with useless data transfer. If you will dedicate the entire computer to running this function (and you don't mind everything else becoming very slow while it runs), you may use all cores by setting parallel = parallel::detectCores(logical = FALSE). To disable parallel processing, set parallel = 0.
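In scripts that run on unknown machines, the default can be computed defensively. A small sketch, assuming you want the documented default (physical cores minus one): parallel::detectCores() returns NA when it cannot determine the core count, in which case the default expression would also evaluate to NA, so this falls back to 0 (parallel processing disabled). The commented ale() call assumes the signature documented on this page.

```r
# detectCores() returns NA when the core count cannot be determined
physical_cores <- parallel::detectCores(logical = FALSE)

# Physical cores minus one, never negative; 0 disables parallel processing
n_workers <- if (is.na(physical_cores)) 0L else max(physical_cores - 1L, 0L)

# Then, for example:
# ale(diamonds_sample, gam_diamonds, parallel = n_workers)
```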

Progress bars

Progress bars are implemented with the {progressr} package, which lets the user fully control progress bars. To disable progress bars, set silent = TRUE. The first time a function is called in the {ale} package that requires progress bars, it checks if the user has activated the necessary {progressr} settings. If not, the {ale} package automatically enables {progressr} progress bars with the cli handler and prints a message notifying the user.

If you like the default progress bars and you want to make them permanent, then you can add the following lines of code to your .Rprofile configuration file and they will become your defaults for every R session; you will not see the message again:

progressr::handlers(global = TRUE)
progressr::handlers('cli')

For more details on formatting progress bars to your liking, see the introduction to the {progressr} package.

References

Okoli, Chitu. 2023. “Statistical Inference Using Machine Learning and Classical Techniques Based on Accumulated Local Effects (ALE).” arXiv. https://arxiv.org/abs/2310.09877.

Examples

set.seed(0)
diamonds_sample <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 1000), ]

# Create a GAM model with flexible curves to predict diamond price
# Smooth all numeric variables and include all other variables
gam_diamonds <- mgcv::gam(
  price ~ s(carat) + s(depth) + s(table) + s(x) + s(y) + s(z) +
    cut + color + clarity,
  data = diamonds_sample
)
summary(gam_diamonds)
#> 
#> Family: gaussian 
#> Link function: identity 
#> 
#> Formula:
#> price ~ s(carat) + s(depth) + s(table) + s(x) + s(y) + s(z) + 
#>     cut + color + clarity
#> 
#> Parametric coefficients:
#>              Estimate Std. Error t value Pr(>|t|)    
#> (Intercept)  3421.412     74.903  45.678  < 2e-16 ***
#> cut.L         261.339    171.630   1.523 0.128170    
#> cut.Q          53.684    129.990   0.413 0.679710    
#> cut.C         -71.942    103.804  -0.693 0.488447    
#> cut^4          -8.657     80.614  -0.107 0.914506    
#> color.L     -1778.903    113.669 -15.650  < 2e-16 ***
#> color.Q      -482.225    104.675  -4.607 4.64e-06 ***
#> color.C        58.724     95.983   0.612 0.540807    
#> color^4       125.640     87.111   1.442 0.149548    
#> color^5      -241.194     81.913  -2.945 0.003314 ** 
#> color^6       -49.305     74.435  -0.662 0.507883    
#> clarity.L    4141.841    226.713  18.269  < 2e-16 ***
#> clarity.Q   -2367.820    217.185 -10.902  < 2e-16 ***
#> clarity.C    1026.214    180.295   5.692 1.67e-08 ***
#> clarity^4    -602.066    137.258  -4.386 1.28e-05 ***
#> clarity^5     408.336    105.344   3.876 0.000113 ***
#> clarity^6     -82.379     88.434  -0.932 0.351815    
#> clarity^7       4.017     78.816   0.051 0.959362    
#> ---
#> Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
#> 
#> Approximate significance of smooth terms:
#>            edf Ref.df      F  p-value    
#> s(carat) 7.503  8.536  4.114 3.65e-05 ***
#> s(depth) 1.486  1.874  0.601 0.614753    
#> s(table) 2.929  3.738  1.294 0.240011    
#> s(x)     8.897  8.967  3.323 0.000542 ***
#> s(y)     3.875  5.118 11.075  < 2e-16 ***
#> s(z)     9.000  9.000  2.648 0.004938 ** 
#> ---
#> Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
#> 
#> R-sq.(adj) =   0.94   Deviance explained = 94.3%
#> GCV = 9.7669e+05  Scale est. = 9.262e+05  n = 1000


# \donttest{

# Simple ALE without bootstrapping
ale_gam_diamonds <- ale(
  diamonds_sample, gam_diamonds,
  parallel = 2  # CRAN limit (delete this line on your own computer)
)

# Plot the ALE data
ale_gam_diamonds$plots |>
  patchwork::wrap_plots()


# Bootstrapped ALE
# This can be slow, since bootstrapping runs the algorithm boot_it times

# Create ALE with 100 bootstrap samples
ale_gam_diamonds_boot <- ale(
  diamonds_sample, gam_diamonds, boot_it = 100,
  parallel = 2  # CRAN limit (delete this line on your own computer)
)

# Bootstrapped ALE plots include confidence intervals
ale_gam_diamonds_boot$plots |>
  patchwork::wrap_plots()



# If the predict function you want is non-standard, you may define a
# custom predict function. It must return a single numeric vector.
custom_predict <- function(object, newdata, type = pred_type) {
  predict(object, newdata, type = type, se.fit = TRUE)$fit
}

ale_gam_diamonds_custom <- ale(
  diamonds_sample, gam_diamonds,
  pred_fun = custom_predict, pred_type = 'link',
  parallel = 2  # CRAN limit (delete this line on your own computer)
)

# Plot the ALE data
ale_gam_diamonds_custom$plots |>
  patchwork::wrap_plots()


# }