Skip to content

Commit

Permalink
Merge pull request #102 from mayer79/hstats_profile
Browse files Browse the repository at this point in the history
Update NEWS
  • Loading branch information
mayer79 authored Nov 8, 2023
2 parents 2d36df7 + 096e969 commit b9bf808
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 41 deletions.
6 changes: 3 additions & 3 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Enhancements

- {hstats} now also work for factor predictions. The levels are represented by one-hot-encoded columns.
- {hstats} now also works for factor predictions. The levels are represented by one-hot-encoded columns ([PR#101](https://github.com/mayer79/hstats/pull/101)).
- The plot method of a two-dimensional PDP has received the option `d2_geom = "line"`. Instead of a heatmap of the two features, one of the features is moved to color grouping. Combined with `swap_dim = TRUE`, you can swap the role of the two `v` variables without recalculating anything. The idea was proposed by [Roel Verbelen](https://github.com/RoelVerbelen) in [issue #91](https://github.com/mayer79/hstats/issues/91), see also [issue #94](https://github.com/mayer79/hstats/issues/94).

## Bug fixes
Expand All @@ -11,8 +11,8 @@

## Other changes

- Much faster one-hot-encoding, thanks to Mathias Ambühl.
- Most functions are slightly faster.
- Much faster one-hot-encoding, thanks to Mathias Ambühl ([PR#101](https://github.com/mayer79/hstats/pull/101)).
- Most functions are slightly faster ([PR#101](https://github.com/mayer79/hstats/pull/101)).
- Add unit tests to compare against {iml}.
- Made all examples "tibble" and "data.table" friendly.
- Revised input checks in loss functions (relevant for `perm_importance()` and `average_loss()`).
Expand Down
15 changes: 5 additions & 10 deletions R/utils_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ wcolMeans <- function(x, w = NULL) {
#' @noRd
#' @keywords internal
#'
#' @param x Factor.
#' @param x Factor-like.
#' @returns Named vector.
colMeans_factor <- function(x) {
x <- as.factor(x)
Expand Down Expand Up @@ -133,9 +133,6 @@ wrowmean <- function(x, ngroups = 1L, w = NULL) {
}

# General version
if (!is.matrix(x)) {
x <- as.matrix(x)
}
wrowmean_matrix(x, ngroups = ngroups, w = w)
}

Expand All @@ -146,13 +143,11 @@ wrowmean <- function(x, ngroups = 1L, w = NULL) {
#' @noRd
#' @keywords internal
#'
#' @param x Factor.
#' @param x Factor-like.
#' @param ngroups Number of subsequent, equal-sized groups.
#' @returns Matrix with column names.
rowmean_factor <- function(x, ngroups = 1L) {
if (!is.factor(x)) {
stop("x must be a factor.")
}
x <- as.factor(x)
lev <- levels(x)
n_bg <- length(x) %/% ngroups
dim(x) <- c(n_bg, ngroups)
Expand Down Expand Up @@ -192,13 +187,13 @@ wrowmean_vector <- function(x, ngroups = 1L, w = NULL) {
#' @noRd
#' @keywords internal
#'
#' @param x Matrix.
#' @param x Matrix-like.
#' @param ngroups Number of subsequent, equal-sized groups.
#' @param w Optional vector of case weights of length `NROW(x) / ngroups`.
#' @returns Matrix.
wrowmean_matrix <- function(x, ngroups = 1L, w = NULL) {
if (!is.matrix(x)) {
stop("x must be a matrix.")
x <- as.matrix(x)
}
n_bg <- nrow(x) %/% ngroups
g <- rep_each(ngroups, each = n_bg) # rep(seq_len(ngroups), each = n_bg)
Expand Down
57 changes: 30 additions & 27 deletions backlog/calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,12 @@ calibration <- function(object, ...) {

#' @describeIn calibration Default method.
#' @export
calibration.default <- function(object, v, X, y = NULL, pred_fun = stats::predict,
calibration.default <- function(object, v, X, y = NULL,
pred_fun = stats::predict,
BY = NULL, by_size = 4L,
grid_size = 17L,
breaks = 17L, trim = c(0.01, 0.99),
include.lowest = TRUE,
right = TRUE, na.rm = TRUE,
pred = NULL,
n_max = 1000L, w = NULL, ...) {
stopifnot(
Expand All @@ -70,11 +73,7 @@ calibration.default <- function(object, v, X, y = NULL, pred_fun = stats::predic
)

if (!is.null(y)) {
y <- prepare_y(y = y, X = X)[["y"]]
if (is.factor(y) || is.character(y)) {
y <- stats::model.matrix(~ as.factor(y) + 0)
}
y <- align_pred(y)
y <- prepare_y(y = y, X = X, ohe = TRUE)[["y"]]
}
if (!is.null(w)) {
w <- prepare_w(w = w, X = X)[["w"]]
Expand All @@ -83,49 +82,53 @@ calibration.default <- function(object, v, X, y = NULL, pred_fun = stats::predic
BY2 <- prepare_by(BY = BY, X = X, by_size = by_size)
BY <- BY2[["BY"]]
}
g <- v_grouped <- approx_vector(X[[v]], m = grid_size)
grid <- sort(unique(v_grouped), na.last = TRUE)

h <- hist2(
X[[v]],
breaks = breaks,
trim = trim,
include.lowest = include.lowest,
right = right,
na.rm = TRUE
)

if (!is.null(BY)) {
g <- paste(BY, g, sep = ":")
g <- paste(BY, h$x, sep = ":")
} else {
g <- h$x
}

# Average predicted
if (is.null(pred)) {
pred <- pred_fun(object, X, ...)
}
pred <- align_pred(pred)
tmp <- gwColMeans(pred, g = g, w = w, mean_only = FALSE)
avg_pred <- tmp[["mean"]]

# Exposure
exposure <- tmp[["denom"]]

pred <- prepare_pred(pred, ohe = TRUE)
pr <- gwColMeans(pred, g = g, w = w)

# Average observed
avg_obs <- if (!is.null(y)) gwColMeans(y, g = g, w = w)
avg_obs <- if (!is.null(y)) gwColMeans(y, g = g, w = w)$M

# Partial dependence
pd <- partial_dep(
object = object,
v = v,
X = X,
grid = grid,
grid = h$grid,
pred_fun = pred_fun,
BY = BY,
w = w,
...
)[["data"]]
)

out <- list(
v = v,
K = ncol(pred),
pred_names = colnames(pred),
grid = grid,
K = ncol(pr$M),
pred_names = colnames(pr$M),
grid = h[-1L],
BY,
avg_obs = avg_obs,
avg_pred = avg_pred,
pd = pd,
exposure = exposure
avg_pred = pr$M,
pd = pd[["data"]],
exposure = pr$w
)
return(structure(out, class = "calibration"))
}
Expand Down
8 changes: 8 additions & 0 deletions backlog/colMeans_factors.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,11 @@ gcolMeans_factor <- function(x, g = NULL) {
out
}

#' Grouped (Weighted) Column Means via Array Reshaping
#'
#' Treats the rows of `x` as `ngroups` subsequent blocks of equal size and
#' returns the (weighted) column means within each block.
#'
#' @noRd
#' @keywords internal
#'
#' @param x Matrix.
#' @param ngroups Number of subsequent, equal-sized groups.
#' @param w Optional vector of case weights of length `nrow(x) / ngroups`.
#'   The same weights are applied within every group.
#' @returns Matrix with `ngroups` rows and `ncol(x)` columns. Column names of
#'   `x` (if any) are kept.
wrowmean_matrix2 <- function(x, ngroups = 1L, w = NULL) {
  if (!is.matrix(x)) {
    stop("x must be a matrix.")
  }
  n_bg <- nrow(x) %/% ngroups
  if (is.na(n_bg) || n_bg * ngroups != nrow(x)) {
    stop("nrow(x) must be a multiple of ngroups.")
  }
  if (!is.null(w) && length(w) != n_bg) {
    stop("w must be of length nrow(x) / ngroups.")
  }

  # dim<- below drops dimnames, so remember the column names for the result
  nms <- colnames(x)
  p <- ncol(x)

  if (is.null(w)) {
    # Rows of the same group are contiguous, thus the first array dimension
    # runs over the observations within a group. colMeans() averages over it
    # and directly returns an (ngroups x p) matrix.
    dim(x) <- c(n_bg, ngroups, p)
    out <- colMeans(x)
  } else {
    # w is recycled down the rows. Since nrow(x) is a multiple of length(w),
    # each group receives the same weight vector.
    x <- x * w
    dim(x) <- c(n_bg, ngroups, p)
    out <- colSums(x) / sum(w)
  }

  colnames(out) <- nms
  out
}
1 change: 0 additions & 1 deletion tests/testthat/test_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ test_that("rowmean_factor() works for factor input", {
x <- factor(c("C", "A", "C", "C", "A", "A"))
out <- rowmean_factor(x, ngroups = 2L)

expect_error(rowmean_factor(1:3))
expect_true(is.matrix(out))
expect_equal(out, cbind(A = c(1/3, 2/3), C = c(2/3, 1/3)))
expect_equal(out, wrowmean_matrix(fdummy(x), ngroups = 2L))
Expand Down

0 comments on commit b9bf808

Please sign in to comment.