An ALE S7 object contains ALE data and statistics. For details, see vignette('ale-intro') or the details and examples below.
Usage
ALE(
  model,
  x_cols = list(d1 = TRUE),
  data = NULL,
  y_col = NULL,
  ...,
  exclude_cols = NULL,
  parallel = future::availableCores(logical = FALSE, omit = 1),
  model_packages = NULL,
  output_stats = TRUE,
  output_conf = TRUE,
  output_boot_data = FALSE,
  pred_fun = function(object, newdata, type = pred_type) {
    stats::predict(object = object, newdata = newdata, type = type)
  },
  pred_type = "response",
  p_values = NULL,
  p_alpha = c(0.01, 0.05),
  max_num_bins = 10,
  boot_it = 0,
  boot_alpha = 0.05,
  boot_centre = "mean",
  seed = 0,
  y_type = NULL,
  median_band_pct = c(0.05, 0.5),
  sample_size = 500,
  .bins = NULL,
  silent = FALSE
)
Arguments
- model
model object. Required. Model for which ALE should be calculated. May be any kind of R object that can make predictions from data.
- x_cols, exclude_cols
character, list, or formula. Column names from data requested in one of the special x_cols formats for which ALE data is to be calculated. Defaults to 1D ALE for all columns in data except y_col. See details in the documentation for resolve_x_cols().
- data
dataframe. Dataset from which to create predictions for the ALE. It should normally be the same dataset on which model was trained. If not provided, ALE() will try to detect it automatically. For non-standard models, data should be provided.
- y_col
character(1). Name of the outcome target label (y) variable. If not provided, ALE() will try to detect it automatically. For non-standard models, y_col should be provided. For survival models, set y_col to the name of the binary event column; in that case, pred_type should also be specified.
- ...
not used. Inserted to require explicit naming of subsequent arguments.
- parallel
non-negative integer(1). Number of parallel threads (workers or tasks) for parallel execution of the function. Set parallel = 0 to disable parallel processing. See details.
- model_packages
character. Character vector of names of packages that model depends on that might not be obvious with parallel processing. If you get unexpected error messages when parallel processing is enabled (which is the default) but they are resolved by setting parallel = 0, you might need to specify model_packages. See details.
- output_stats
logical(1). If TRUE (default), return ALE statistics.
- output_conf
logical(1). If TRUE (default), return ALE confidence regions. If output_stats is FALSE, output_conf is ignored since confidence regions cannot be produced without ALE statistics.
- output_boot_data
logical(1). If TRUE, return the raw ALE data for each bootstrap iteration. Default is FALSE.
- pred_fun, pred_type
function, character(1). pred_fun is a function that returns a vector of predicted values of type pred_type from model on data. See details.
- p_values
instructions for calculating p-values and to determine the median band. If NULL (default), no p-values are calculated and median_band_pct is used to determine the median band. To calculate p-values, an ALEpDist() object must be provided here. If p_values is set to 'auto', this ALE() function will try to automatically create the p-values distribution; this only works with standard R model types. An error message will be given if p-values cannot be generated. Any other input provided to this argument will result in an error. For more details about creating p-values, see documentation for ALEpDist(). Note that p-values will not be generated if output_stats is FALSE.
- p_alpha
numeric length 2 from 0 to 1. Alpha for "confidence interval" ranges for printing bands around the median for single-variable plots. These values are used only if p_values is provided; otherwise, median_band_pct is used instead. The inner band range will be the median value of y ± p_alpha[2] of the relevant ALE statistic (usually ALE range or normalized ALE range). For plots with a second outer band, its range will be the median ± p_alpha[1]. For example, in the ALE plots, for the default p_alpha = c(0.01, 0.05), the inner band will be the median ± ALE minimum or maximum at p = 0.05 and the outer band will be the median ± ALE minimum or maximum at p = 0.01.
- max_num_bins
positive integer(1). Maximum number of bins for numeric x_cols variables. The actual number of bins is the lower of the number of unique values of a numeric variable and max_num_bins.
- boot_it
non-negative integer(1). Number of bootstrap iterations for data-only bootstrapping on ALE data. This is appropriate for models that have been developed with cross-validation. For models that have not been validated, full-model bootstrapping should be used instead with the ModelBoot class. See details there. The default boot_it = 0 turns off bootstrapping.
- boot_alpha
numeric(1) from 0 to 1. Alpha for percentile-based confidence interval range for the bootstrap intervals; the bootstrap confidence intervals will span from the boot_alpha / 2 quantile to the 1 - boot_alpha / 2 quantile. For example, if boot_alpha = 0.05 (default), the intervals will be from the 2.5 and 97.5 percentiles.
- boot_centre
character(1) in c('mean', 'median'). When bootstrapping, the main estimate for the ALE y value is considered to be boot_centre. Regardless of the value specified here, both the mean and median will be available.
- seed
integer(1). Random seed. Supply the same value between runs to ensure that identical random ALE data is generated each time when bootstrapping. Without bootstrapping, ALE is a deterministic algorithm that should give identical results each time regardless of the seed specified.
- y_type
character(1) in c('binary', 'numeric', 'categorical', 'ordinal'). Datatype of the y (outcome) variable. Normally determined automatically; only provide if an error message for a complex non-standard model requires it.
- median_band_pct
numeric length 2 from 0 to 1. Alpha for "confidence interval" ranges for printing bands around the median for single-variable plots. These values are used only if p_values is not provided; if p_values is provided, median_band_pct is ignored. The inner band range will be the median value of y ± median_band_pct[1]/2. For plots with a second outer band, its range will be the median ± median_band_pct[2]/2. For example, for the default median_band_pct = c(0.05, 0.5), the inner band will be the median ± 2.5% and the outer band will be the median ± 25%.
- sample_size
non-negative integer(1). Size of the sample of data to be returned with the ALE object. This is primarily used for rug plots. See the min_rug_per_interval argument.
- .bins
Internal. List of ALE bin and n count vectors. If provided, these vectors will be used to set the intervals of the ALE x axis for each variable. By default (NULL), ALE() automatically calculates the bins. .bins is normally used in advanced analyses where the bins from a previous analysis are reused for subsequent analyses (for example, for full model bootstrapping with ModelBoot()).
- silent
logical(1), default FALSE. If TRUE, do not display any non-essential messages during execution (such as progress bars). Regardless, any warnings and errors will always display. See details for how to enable progress bars.
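The percentile arithmetic described for boot_alpha and median_band_pct can be sketched in base R. This is only an illustration on simulated values, not the package's internal code: the names boot_estimates and y are hypothetical, and reading "median ± 2.5%" as the 47.5th and 52.5th percentiles is an assumption based on the percentile-named y_summary elements.

```r
set.seed(0)

# Hypothetical bootstrap estimates of a single ALE y value (simulated)
boot_estimates <- stats::rnorm(1000, mean = 10, sd = 2)

# boot_alpha = 0.05: percentile interval from the 2.5th to the 97.5th percentile
boot_alpha <- 0.05
boot_ci <- stats::quantile(
  boot_estimates,
  probs = c(boot_alpha / 2, 1 - boot_alpha / 2)
)

# median_band_pct = c(0.05, 0.5): bands centred on the 50th percentile
# (assumption: the bands are expressed as percentiles of the y values)
median_band_pct <- c(0.05, 0.5)
y <- stats::rexp(1000)  # hypothetical outcome values
inner_band <- stats::quantile(y, probs = 0.5 + c(-1, 1) * median_band_pct[1] / 2)
outer_band <- stats::quantile(y, probs = 0.5 + c(-1, 1) * median_band_pct[2] / 2)

print(boot_ci)     # 2.5% and 97.5% percentiles
print(inner_band)  # 47.5% and 52.5% percentiles
print(outer_band)  # 25% and 75% percentiles
```

With these defaults the inner band is narrow (5% of the observations around the median) while the outer band is the interquartile range.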
Properties
- distinct
Stores the optional ALE data, ALE statistics, and bootstrap data for one or more categories.
- params
The parameters used to calculate the ALE data. These include most of the arguments used to construct the ALE object. These are either the values provided by the user or the defaults if the user did not change them; params also includes several objects that are created within the constructor. These extra objects are described here, as well as those parameters that are stored differently from the form in the arguments:
* `max_d`: the highest dimension of ALE data present. If only 1D ALE is present, then `max_d == 1`. If even one 2D ALE element is present (even with no 1D), then `max_d == 2`.
* `requested_x_cols`, `ordered_x_cols`: `requested_x_cols` is the resolved list of `x_cols` as requested by the user (that is, `x_cols` minus `exclude_cols`). `ordered_x_cols` is the same set of `x_cols` but arranged in the internal storage order.
* `y_cats`: categories for categorical classification models. For non-categorical models, this is the same as `y_col`.
* `y_type`: high-level datatype of the y outcome variable.
* `y_summary`: summary statistics of y values used for the ALE calculation. These statistics are based on the actual values of `y_col` unless `y_type` is a probability or other value that is constrained in the `[0, 1]` range. In that case, `y_summary` is based on the predicted values of `y_col` by applying `model` to the `data`. `y_summary` is a named numeric vector. Most of the elements are the percentile of the y values. E.g., the '5%' element is the 5th percentile of y values. The following elements have special meanings:
  * The first element is named either `p` or `q` and its value is always 0. The value is not used; only the name of the element is meaningful. `p` means that the following special `y_summary` elements are based on the provided `ALEpDist` object. `q` means that quantiles were calculated based on `median_band_pct` because `p_values` was not provided.
  * `min`, `mean`, `max`: the minimum, mean, and maximum y values, respectively. Note that the median is `50%`, the 50th percentile.
  * `med_lo_2`, `med_lo`, `med_hi`, `med_hi_2`: `med_lo` and `med_hi` are the inner lower and upper confidence intervals of y values with respect to the median (`50%`); `med_lo_2` and `med_hi_2` are the outer confidence intervals. See the documentation for the `p_alpha` and `median_band_pct` arguments to understand how these are determined.
* `model`: same as `ALE@params$model` (see documentation there).
* `data`: same as `ALE@params$data` (see documentation there).
Custom predict function
The calculation of ALE requires modifying several values of the original data. Thus, ALE() needs direct access to the predict function for the model. By default, ALE() uses a generic default predict function of the form predict(object, newdata, type) with the default prediction type of 'response'. If, however, the desired prediction values are not generated with that format, the user must specify what they want. Very often, the only modification needed is to change the prediction type to some other value by setting the pred_type argument (e.g., to 'prob' to generate classification probabilities). But if the desired predictions need a different function signature, then the user must create a custom prediction function and pass it to pred_fun. The requirements for this custom function are:
- It must take three required arguments and nothing else:
  - object: a model
  - newdata: a dataframe or compatible table type such as a tibble or data.table
  - type: a string; it should usually be specified as type = pred_type
  These argument names follow the R convention for the generic stats::predict() function.
- It must return a vector or matrix of numeric values as the prediction.
You can see an example below of a custom prediction function.
Note: survival models probably do not need a custom prediction function but y_col must be set to the name of the binary event column and pred_type must be set to the desired prediction type.
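The requirements above can be checked outside of ALE() before passing a function to pred_fun. The following is a minimal sketch using only base R and the built-in mtcars data; the names fit and custom_predict_sketch are hypothetical, and the type argument is passed explicitly here because pred_type only exists inside an ALE() call.

```r
# Small linear model on a built-in dataset (illustration only)
fit <- stats::lm(mpg ~ wt + hp, data = mtcars)

# Hypothetical custom prediction function with exactly three arguments.
# With se.fit = TRUE, predict.lm() returns a list, so the numeric
# predictions must be extracted from its $fit element.
custom_predict_sketch <- function(object, newdata, type) {
  out <- stats::predict(object, newdata = newdata, type = type, se.fit = TRUE)
  as.numeric(out$fit)
}

# Check that the function returns one numeric prediction per row
preds <- custom_predict_sketch(fit, mtcars, type = "response")
stopifnot(is.numeric(preds), length(preds) == nrow(mtcars))
```

A function that passes this kind of check returns a plain numeric vector with one value per row, which is the shape that the requirements list asks for.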
ALE statistics
For details about the ALE-based statistics (ALED, ALER, NALED, and NALER), see vignette('ale-statistics').
Parallel processing
Parallel processing using the {furrr} framework is enabled by default. By default, it will use all the available physical CPU cores (minus the core being used for the current R session) with the setting parallel = future::availableCores(logical = FALSE, omit = 1). Note that only physical cores are used (not logical cores or "hyperthreading") because machine learning can only take advantage of the floating point processors on physical cores, which are absent from logical cores. Trying to use logical cores will not speed up processing and might actually slow it down with useless data transfer. If you will dedicate the entire computer to running this function (and you don't mind everything else becoming very slow while it runs), you may use all cores by setting parallel = future::availableCores(logical = FALSE). To disable parallel processing, set parallel = 0.
The {ale} package should be able to automatically recognize and load most packages that are needed, but with parallel processing enabled (which is the default), some packages might not be properly loaded. This problem might be indicated if you get a strange error message that mentions something about "progress interrupted" or "future", especially if you see such errors after the progress bars begin displaying (assuming you did not disable progress bars with silent = TRUE). In that case, first try disabling parallel processing with parallel = 0. If that resolves the problem, then to get faster parallel processing to work, try adding the package names needed for the model to the model_packages argument, e.g., model_packages = c('tidymodels', 'mgcv').
Progress bars
Progress bars are implemented with the {progressr} package, which lets the user fully control progress bars. To disable progress bars, set silent = TRUE. The first time a function is called in the {ale} package that requires progress bars, it checks if the user has activated the necessary {progressr} settings. If not, the {ale} package automatically enables {progressr} progress bars with the cli handler and prints a message notifying the user.
If you like the default progress bars and you want to make them permanent, you can add the following lines of code to your .Rprofile configuration file and they will become your defaults for every R session; you will not see the message again:
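A minimal sketch of such lines, assuming the global cli handler described above. These calls are standard {progressr} usage rather than text taken from this page, and the requireNamespace() guard is an added precaution so that R still starts if {progressr} is not installed.

```r
# In .Rprofile: enable progressr progress bars globally with the cli handler
if (requireNamespace("progressr", quietly = TRUE)) {
  progressr::handlers(global = TRUE)
  progressr::handlers("cli")
}
```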
This would apply not only to the {ale} package but to any package that uses {progressr} progress bars. For more details on formatting progress bars to your liking, see the introduction to the {progressr} package.
References
Okoli, Chitu. 2023. “Statistical Inference Using Machine Learning and Classical Techniques Based on Accumulated Local Effects (ALE).” arXiv. https://arxiv.org/abs/2310.09877.
Examples
# Sample 1000 rows from the ggplot2::diamonds dataset (for a simple example)
set.seed(0)
diamonds_sample <- ggplot2::diamonds[sample(nrow(ggplot2::diamonds), 1000), ]
# Create a GAM model with flexible curves to predict diamond price
# Smooth all numeric variables and include all other variables
gam_diamonds <- mgcv::gam(
  price ~ s(carat) + s(depth) + s(table) + s(x) + s(y) + s(z) +
    cut + color + clarity,
  data = diamonds_sample
)
summary(gam_diamonds)
#>
#> Family: gaussian
#> Link function: identity
#>
#> Formula:
#> price ~ s(carat) + s(depth) + s(table) + s(x) + s(y) + s(z) +
#> cut + color + clarity
#>
#> Parametric coefficients:
#> Estimate Std. Error t value Pr(>|t|)
#> (Intercept) 3421.412 74.903 45.678 < 2e-16 ***
#> cut.L 261.339 171.630 1.523 0.128170
#> cut.Q 53.684 129.990 0.413 0.679710
#> cut.C -71.942 103.804 -0.693 0.488447
#> cut^4 -8.657 80.614 -0.107 0.914506
#> color.L -1778.903 113.669 -15.650 < 2e-16 ***
#> color.Q -482.225 104.675 -4.607 4.64e-06 ***
#> color.C 58.724 95.983 0.612 0.540807
#> color^4 125.640 87.111 1.442 0.149548
#> color^5 -241.194 81.913 -2.945 0.003314 **
#> color^6 -49.305 74.435 -0.662 0.507883
#> clarity.L 4141.841 226.713 18.269 < 2e-16 ***
#> clarity.Q -2367.820 217.185 -10.902 < 2e-16 ***
#> clarity.C 1026.214 180.295 5.692 1.67e-08 ***
#> clarity^4 -602.066 137.258 -4.386 1.28e-05 ***
#> clarity^5 408.336 105.344 3.876 0.000113 ***
#> clarity^6 -82.379 88.434 -0.932 0.351815
#> clarity^7 4.017 78.816 0.051 0.959362
#> ---
#> Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
#>
#> Approximate significance of smooth terms:
#> edf Ref.df F p-value
#> s(carat) 7.503 8.536 4.114 3.65e-05 ***
#> s(depth) 1.486 1.874 0.601 0.614753
#> s(table) 2.929 3.738 1.294 0.240011
#> s(x) 8.897 8.967 3.323 0.000542 ***
#> s(y) 3.875 5.118 11.075 < 2e-16 ***
#> s(z) 9.000 9.000 2.648 0.004938 **
#> ---
#> Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
#>
#> R-sq.(adj) = 0.94 Deviance explained = 94.3%
#> GCV = 9.7669e+05 Scale est. = 9.262e+05 n = 1000
# \donttest{
# Simple ALE without bootstrapping
# data is passed explicitly so that ALE() does not have to detect it
ale_gam_diamonds <- ALE(gam_diamonds, data = diamonds_sample)
# Plot the ALE data
plot(ale_gam_diamonds)
# Bootstrapped ALE
# This can be slow, since bootstrapping runs the algorithm boot_it times
# Create ALE with 100 bootstrap samples
ale_gam_diamonds_boot <- ALE(
  gam_diamonds,
  data = diamonds_sample,  # passed explicitly to avoid automatic detection
  boot_it = 100
)
# Bootstrapped ALEs print with confidence intervals
plot(ale_gam_diamonds_boot)
# If the predict function you want is non-standard, you may define a
# custom predict function. It must return a single numeric vector.
custom_predict <- function(object, newdata, type = pred_type) {
  # With se.fit = TRUE, predict() returns a list; keep only the predictions
  predict(object, newdata, type = type, se.fit = TRUE)$fit
}
ale_gam_diamonds_custom <- ALE(
  gam_diamonds,
  data = diamonds_sample,  # passed explicitly to avoid automatic detection
  pred_fun = custom_predict, pred_type = 'link'
)
# Plot the ALE data
plot(ale_gam_diamonds_custom)
# How to retrieve specific types of ALE data from an ALE object.
ale_diamonds_with__boot_data <- ALE(
  gam_diamonds,
  data = diamonds_sample,  # passed explicitly to avoid automatic detection
  # For detailed options for x_cols, see examples at resolve_x_cols()
  x_cols = ~ carat + cut + clarity + color:depth + x:y,
  output_boot_data = TRUE,
  boot_it = 10  # just for demonstration
)
# See ?get.ALE for details on the various kinds of data that may be retrieved.
get(ale_diamonds_with__boot_data, ~ carat + color:depth)  # default ALE data
get(ale_diamonds_with__boot_data, what = 'boot_data')
get(ale_diamonds_with__boot_data, stats = 'estimate')
get(ale_diamonds_with__boot_data, stats = 'aled')
get(ale_diamonds_with__boot_data, stats = 'all')
get(ale_diamonds_with__boot_data, stats = 'conf_regions')
get(ale_diamonds_with__boot_data, stats = 'conf_sig')
# }