Skip to content

Commit

Permalink
Merge pull request #81 from mayer79/approximate
Browse files Browse the repository at this point in the history
Implementation of quantile approximation
  • Loading branch information
mayer79 authored Oct 19, 2023
2 parents de0d70d + 5d1b584 commit 1dcdfab
Show file tree
Hide file tree
Showing 16 changed files with 578 additions and 462 deletions.
7 changes: 4 additions & 3 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

## Major changes

- `hstats()` has received an argument `quant_approx` to speed-up calculations by quantile binning. Dense numeric variables are replaced by midpoints of `quant_approx + 1` uniform quantiles. By default, the value is `NULL` (no approximation). Even relatively high values like 50 will bring a massive speed-up for dense features, mainly for the one-way calculations. Use this option when calculations are slow, or when you want to increase `n_max`.
- `hstats()`: `n_max` has been increased from 300 to 500 rows. This will make estimates of H statistics more stable at the price of longer run time. Reduce to 300 for the old behaviour.
- `hstats()`: By default, three-way interactions are not calculated anymore. Set `threeway_m` to 5 for the old behaviour.
- Revised plots: The colors and color palettes have changed and can (also) be controlled via global options. For instance, to change the fill color of all bars, set `options(hstats.fill = new value)`. Value labels are more clear, and there are more options. Varying color/fill scales now use viridis (inferno). This can be modified on the fly or via `options(hstats.viridis_args = list(...))`.
- `hstats()`: Three-way interactions are not anymore calculated by default. Set `threeway_m` to 5 for the old behaviour.
- Revised plots: The colors and color palettes have changed and can now also be controlled via global options. For instance, to change the fill color of all bars, set `options(hstats.fill = new value)`. Value labels are more clear, and there are more options. Varying color/fill scales now use viridis (inferno). This can be modified on the fly or via `options(hstats.viridis_args = list(...))`.
- "hstats_matrix" object: All statistics functions, e.g., `h2_pairwise()` or `perm_importance()`, now return a "hstats_matrix". The values are stored in `$M` and can be plotted via `plot()`. Other methods include: `dimnames()`, `rownames()`, `colnames()`, `dim()`, `nrow()`, `ncol()`, `head()`, `tail()`, and subsetting like a normal matrix. This allows, e.g, to select and plot only one column of the results.
- `perm_importance()`: The `perms` argument has been changed to `m_rep`.
- `print()` and `summary()` methods have been revised.
Expand All @@ -20,7 +21,7 @@
- `average_loss()` is more flexible regarding the group `BY` argument. It can also be a variable *name*. Non-discrete `BY` variables are now automatically binned. Like `partial_dep()`, binning is controlled by the `by_size = 4` argument.
- `average_loss()` also returns a "hstats_matrix" object with `print()` and `plot()` method. The values can be extracted via `$M`.
- The default `v` of `hstats()` and `perm_importance()` is now `NULL`. Internally, it is set to `colnames(X)` (minus the column names of `w` and `y` if passed as name).
- Missing grid values: `partial_dep()` and `ice()` have received a `na.rm = TRUE` argument that controls if missing values are dropped during grid creation. The default is compatible with earlier releases.
- Missing grid values: `partial_dep()` and `ice()` have received a `na.rm` argument that controls if missing values are dropped during grid creation. The default `TRUE` is compatible with earlier releases.

# hstats 0.3.0

Expand Down
136 changes: 131 additions & 5 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,19 @@
#' @param threeway_m Like `pairwise_m`, but controls the feature count for
#' three-way interactions. Cannot be larger than `pairwise_m`.
#' To save computation time, the default is 0.
#' @param quant_approx Integer. Dense numeric variables in `X` are replaced by midpoints
#' of `quant_approx + 1` uniform quantiles. By default, the value is `NULL`
#' (no approximation). Even relatively high values like 50 will bring a massive
#' speed-up for dense features, mainly for one-way statistics.
#' Note that the quantiles are calculated after subsampling to `n_max` rows.
#' @param eps Threshold below which numerator values are set to 0. Default is 1e-10.
#' @param verbose Should a progress bar be shown? The default is `TRUE`.
#' @param ... Additional arguments passed to `pred_fun(object, X, ...)`,
#' for instance `type = "response"` in a [glm()] model, or `reshape = TRUE` in a
#' multiclass XGBoost model.
#' @returns
#' An object of class "hstats" containing these elements:
#' - `X`: Input `X` (sampled to `n_max` rows).
#' - `X`: Input `X` (sampled to `n_max` rows, after optional quantile approximation).
#' - `w`: Case weight vector `w` (sampled to `n_max` values), or `NULL`.
#' - `v`: Vector of column names in `X` for which overall
#' H statistics have been calculated.
Expand Down Expand Up @@ -136,7 +141,7 @@ hstats <- function(object, ...) {
hstats.default <- function(object, X, v = NULL,
pred_fun = stats::predict, n_max = 500L,
w = NULL, pairwise_m = 5L, threeway_m = 0L,
eps = 1e-10, verbose = TRUE, ...) {
quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun)
Expand Down Expand Up @@ -173,6 +178,11 @@ hstats.default <- function(object, X, v = NULL,
}
}

# Quantile approximation to speedup things for dense features
if (!is.null(quant_approx)) {
X <- approx_matrix_or_df(X = X, v = v, m = quant_approx)
}

# Predictions ("F" in Friedman and Popescu) always calculated (cheap)
f <- wcenter(align_pred(pred_fun(object, X, ...)), w = w)
mean_f2 <- wcolMeans(f^2, w = w) # A vector
Expand Down Expand Up @@ -266,7 +276,7 @@ hstats.default <- function(object, X, v = NULL,
hstats.ranger <- function(object, X, v = NULL,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
n_max = 500L, w = NULL, pairwise_m = 5L, threeway_m = 0L,
eps = 1e-10, verbose = TRUE, ...) {
quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
hstats.default(
object = object,
X = X,
Expand All @@ -276,6 +286,7 @@ hstats.ranger <- function(object, X, v = NULL,
w = w,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
verbose = verbose,
...
Expand All @@ -287,7 +298,7 @@ hstats.ranger <- function(object, X, v = NULL,
hstats.Learner <- function(object, X, v = NULL,
pred_fun = NULL,
n_max = 500L, w = NULL, pairwise_m = 5L, threeway_m = 0L,
eps = 1e-10, verbose = TRUE, ...) {
quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
Expand All @@ -300,6 +311,7 @@ hstats.Learner <- function(object, X, v = NULL,
w = w,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
verbose = verbose,
...
Expand All @@ -313,7 +325,7 @@ hstats.explainer <- function(object, X = object[["data"]],
pred_fun = object[["predict_function"]],
n_max = 500L, w = object[["weights"]],
pairwise_m = 5L, threeway_m = 0L,
eps = 1e-10, verbose = TRUE, ...) {
quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
hstats.default(
object = object[["model"]],
X = X,
Expand All @@ -323,6 +335,7 @@ hstats.explainer <- function(object, X = object[["data"]],
w = w,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
verbose = verbose,
...
Expand Down Expand Up @@ -461,3 +474,116 @@ plot.hstats <- function(x, which = 1:3, normalize = TRUE, squared = TRUE,
}
p
}

# Helper functions used only by hstats()

#' Pairwise or 3-Way Partial Dependencies
#'
#' Calculates centered partial dependence functions for selected pairwise or three-way
#' situations.
#'
#' @noRd
#' @keywords internal
#'
#' @param v Vector of column names to calculate `way` order interactions.
#' @inheritParams hstats
#' @param way Pairwise (`way = 2`) or three-way (`way = 3`) interactions.
#' @param verb Verbose (`TRUE`/`FALSE`).
#'
#' @returns
#' A list with a named list of feature combinations (pairs or triples), and
#' corresponding centered partial dependencies.
mway <- function(object, v, X, pred_fun = stats::predict, w = NULL,
way = 2L, verb = TRUE, ...) {
combs <- utils::combn(v, way, simplify = FALSE)
n_combs <- length(combs)
F_way <- vector("list", length = n_combs)
names(F_way) <- names(combs) <- sapply(combs, paste, collapse = ":")

if (verb) {
cat(way, "way calculations...\n", sep = "-")
pb <- utils::txtProgressBar(max = n_combs, style = 3)
}

for (i in seq_len(n_combs)) {
z <- combs[[i]]
F_way[[i]] <- wcenter(
pd_raw(object, v = z, X = X, grid = X[, z], pred_fun = pred_fun, w = w, ...),
w = w
)
if (verb) {
utils::setTxtProgressBar(pb, i)
}
}
if (verb) {
cat("\n")
}
list(combs, F_way)
}

#' Get Feature Names
#'
#' This function takes the unsorted and unnormalized H2_j matrix and extracts the top
#' m feature names (unsorted). If H2_j has multiple columns, this is done per column and
#' then the union is returned.
#'
#' @noRd
#' @keywords internal
#'
#' @param H Unnormalized, unsorted H2_j values.
#' @param m Number of features to pick per column.
#'
#' @returns A vector of the union of the m column-wise most important features.
get_v <- function(H, m) {
v <- rownames(H)
selector <- function(vv) names(utils::head(sort(-vv[vv > 0]), m))
if (NCOL(H) == 1L) {
v_cand <- selector(drop(H))
} else {
v_cand <- Reduce(union, lapply(asplit(H, MARGIN = 2L), FUN = selector))
}
v[v %in% v_cand]
}

#' Approximate Vector
#'
#' Internal function. Approximates values by the average of the two closest quantiles.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A vector or factor.
#' @param m Number of unique values.
#' @returns An approximation of `x` (or `x` if non-numeric or discrete).
approx_vector <- function(x, m = 25L) {
if (!is.numeric(x) || length(unique(x)) <= m) {
return(x)
}
p <- seq(0, 1, length.out = m + 1L)
q <- unique(stats::quantile(x, probs = p, names = FALSE, na.rm = TRUE))
mids <- (q[-length(q)] + q[-1L]) / 2
return(mids[findInterval(x, q, rightmost.closed = TRUE)])
}

#' Approximate df or Matrix
#'
#' Internal function. Calls `approx_vector()` to each column in matrix or data.frame.
#'
#' @noRd
#' @keywords internal
#'
#' @param X A matrix or data.frame.
#' @param m Number of unique values.
#' @returns An approximation of `X` (or `X` if non-numeric or discrete).
approx_matrix_or_df <- function(X, v = colnames(X), m = 25L) {
stopifnot(
m >= 2L,
is.data.frame(X) || is.matrix(X)
)
if (is.data.frame(X)) {
X[v] <- lapply(X[v], FUN = approx_vector, m = m)
} else { # Matrix
X[, v] <- apply(X[, v, drop = FALSE], MARGIN = 2L, FUN = approx_vector, m = m)
}
return(X)
}
72 changes: 72 additions & 0 deletions R/pd_raw.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,75 @@ ice_raw <- function(object, v, X, grid, pred_fun = stats::predict,
}
return(list(pred = pred, grid_pred = grid_pred))
}

# Helper functions used only within pd_raw()

#' Compresses X
#'
#' @description
#' Internal function to remove duplicated rows in `X` based on columns not in `v`.
#' Compensation is done by summing corresponding case weights `w`.
#' Currently implemented only for the case when there is a single non-`v` column in `X`.
#' Can later be generalized to multiple columns via [paste()].
#'
#' Notes:
#' - This function is important for interaction calculations.
#' - The initial check for having a single non-`v` column is very cheap.
#'
#' @noRd
#' @keywords internal
#'
#' @inheritParams pd_raw
#' @returns A list with `X` and `w`, potentially compressed.
.compress_X <- function(X, v, w = NULL) {
not_v <- setdiff(colnames(X), v)
if (length(not_v) != 1L) {
return(list(X = X, w = w)) # No optimization implemented for this case
}
x_not_v <- if (is.data.frame(X)) X[[not_v]] else X[, not_v]
X_dup <- duplicated(x_not_v)
if (!any(X_dup)) {
return(list(X = X, w = w)) # No optimization done
}

# Compress
if (is.null(w)) {
w <- rep(1.0, times = nrow(X))
}
list(
X = X[!X_dup, , drop = FALSE],
w = c(rowsum(w, group = x_not_v, reorder = FALSE)) # warning if missing in x_not_v
)
}

#' Compresses Grid
#'
#' Internal function used to remove duplicated grid rows. Re-indexing to original grid
#' rows needs to be later, but this function provides the re-index vector to do so.
#' Further note that checking for uniqueness can be costly for higher-dimensional grids.
#'
#' @noRd
#' @keywords internal
#'
#' @inheritParams pd_raw
#' @returns
#' A list with `grid` (possibly compressed) and the optional `reindex` vector
#' used to map compressed grid values back to the original grid rows. The original
#' grid equals the compressed grid at indices `reindex`.
.compress_grid <- function(grid) {
ugrid <- unique(grid)
if (NROW(ugrid) == NROW(grid)) {
# No optimization done
return(list(grid = grid, reindex = NULL))
}
out <- list(grid = ugrid)
if (NCOL(grid) >= 2L) { # Non-vector case
grid <- apply(grid, MARGIN = 1L, FUN = paste, collapse = "_:_")
ugrid <- apply(ugrid, MARGIN = 1L, FUN = paste, collapse = "_:_")
if (anyDuplicated(ugrid)) {
stop("String '_:_' found in grid values at unlucky position.")
}
}
out[["reindex"]] <- match(grid, ugrid)
out
}
70 changes: 0 additions & 70 deletions R/utils_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,76 +48,6 @@ wrowmean <- function(x, ngroups = 1L, w = NULL) {
out
}

#' Compresses X
#'
#' @description
#' Internal function to remove duplicated rows in `X` based on columns not in `v`.
#' Compensation is done by summing corresponding case weights `w`.
#' Currently implemented only for the case when there is a single non-`v` column in `X`.
#' Can later be generalized to multiple columns via [paste()].
#'
#' Notes:
#' - This function is important for interaction calculations.
#' - The initial check for having a single non-`v` column is very cheap.
#'
#' @noRd
#' @keywords internal
#'
#' @inheritParams pd_raw
#' @returns A list with `X` and `w`, potentially compressed.
.compress_X <- function(X, v, w = NULL) {
not_v <- setdiff(colnames(X), v)
if (length(not_v) != 1L) {
return(list(X = X, w = w)) # No optimization implemented for this case
}
x_not_v <- if (is.data.frame(X)) X[[not_v]] else X[, not_v]
X_dup <- duplicated(x_not_v)
if (!any(X_dup)) {
return(list(X = X, w = w)) # No optimization done
}

# Compress
if (is.null(w)) {
w <- rep(1.0, times = nrow(X))
}
list(
X = X[!X_dup, , drop = FALSE],
w = c(rowsum(w, group = x_not_v, reorder = FALSE)) # warning if missing in x_not_v
)
}

#' Compresses Grid
#'
#' Internal function used to remove duplicated grid rows. Re-indexing to original grid
#' rows needs to be later, but this function provides the re-index vector to do so.
#' Further note that checking for uniqueness can be costly for higher-dimensional grids.
#'
#' @noRd
#' @keywords internal
#'
#' @inheritParams pd_raw
#' @returns
#' A list with `grid` (possibly compressed) and the optional `reindex` vector
#' used to map compressed grid values back to the original grid rows. The original
#' grid equals the compressed grid at indices `reindex`.
.compress_grid <- function(grid) {
ugrid <- unique(grid)
if (NROW(ugrid) == NROW(grid)) {
# No optimization done
return(list(grid = grid, reindex = NULL))
}
out <- list(grid = ugrid)
if (NCOL(grid) >= 2L) { # Non-vector case
grid <- apply(grid, MARGIN = 1L, FUN = paste, collapse = "_:_")
ugrid <- apply(ugrid, MARGIN = 1L, FUN = paste, collapse = "_:_")
if (anyDuplicated(ugrid)) {
stop("String '_:_' found in grid values at unlucky position.")
}
}
out[["reindex"]] <- match(grid, ugrid)
out
}

#' Weighted Version of colMeans()
#'
#' Internal function used to calculate column-wise weighted means.
Expand Down
Loading

0 comments on commit 1dcdfab

Please sign in to comment.