Skip to content

Commit

Permalink
Merge pull request #99 from mayer79/hist2
Browse files Browse the repository at this point in the history
Hist2
  • Loading branch information
mayer79 authored Oct 30, 2023
2 parents 98a3f15 + 590a001 commit df418d7
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 59 deletions.
2 changes: 1 addition & 1 deletion R/average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ average_loss.default <- function(object, X, y,

# Real work
L <- as.matrix(loss(y, pred_fun(object, X, ...)))
M <- gwColMeans(L, g = BY, w = w)
M <- gwColMeans(L, g = BY, w = w)[["M"]]

if (agg_cols && ncol(M) > 1L) {
M <- cbind(rowSums(M))
Expand Down
35 changes: 16 additions & 19 deletions R/utils_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,39 +77,36 @@ wcolMeans <- function(x, w = NULL) {

#' Grouped wcolMeans()
#'
#' Internal function used to calculate grouped column-wise weighted means.
#'
#' Internal function used to calculate grouped column-wise weighted means along with
#' corresponding (weighted) counts.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A matrix-like object.
#' @param g Optional grouping variable.
#' @param w Optional case weights.
#' @param reorder Should groups be ordered, see [rowsum()]. Default is `TRUE`.
#' @param mean_only If `TRUE`, only the means are returned. If `FALSE`, a list
#' with "mean", "num" and "denom" is returned.
#' @returns A matrix with one row per group (if `mean_only = TRUE`), or a list
#' with slots "mean", "num", and "denom".
#' @returns A list with two elements: "M" represents a matrix of grouped (column)
#' means, and "w" is a vector of corresponding group counts/weights.
#' @examples
#' with(iris, gwColMeans(Sepal.Width, g = Species, w = Sepal.Length))
#' with(iris, gwColMeans(Sepal.Width, g = Species, w = Sepal.Length, mean_only = FALSE))
gwColMeans <- function(x, g = NULL, w = NULL, reorder = TRUE, mean_only = TRUE) {
gwColMeans <- function(x, g = NULL, w = NULL, reorder = TRUE) {
if (is.null(g)) {
return(rbind(wcolMeans(x, w = w)))
M <- rbind(wcolMeans(x, w = w))
denom <- if (is.null(w)) NROW(x) else sum(w)
return(list(M = M, w = denom))
}

# Now the interesting case
if (is.null(w)) {
num <- rowsum(x, group = g, reorder = reorder)
denom <- rowsum(rep.int(1, NROW(x)), group = g, reorder = reorder)
w <- rep.int(1, NROW(x))
} else {
num <- rowsum(x * w, group = g, reorder = reorder)
denom <- rowsum(w, group = g, reorder = reorder)
}
out <- num / matrix(denom, nrow = nrow(num), ncol = ncol(num), byrow = FALSE)

if (mean_only) {
return(out)
x <- x * w # w is correctly recycled over columns
}
list(mean = out, num = num, denom = denom)
num <- rowsum(x, group = g, reorder = reorder)
denom <- as.numeric(rowsum(w, group = g, reorder = reorder))
list(M = num / denom, w = denom)
}

#' Weighted Mean Centering
Expand Down
17 changes: 10 additions & 7 deletions R/utils_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#' If `strategy = "quantile"`, the evaluation points are quantiles over a regular grid
#' of probabilities from `trim[1]` to `trim[2]`.
#'
#' All quantiles are calculated via the inverse of the ECDF, i.e., via
#' Quantiles are calculated via the inverse of the ECDF, i.e., via
#' `stats::quantile(..., type = 1)`.
#'
#' @param z A vector or factor.
Expand All @@ -24,7 +24,7 @@
#' of grid values. Set to `0:1` for no trimming.
#' @param strategy How to find grid values of non-discrete numeric columns?
#' Either "uniform" or "quantile", see description of [univariate_grid()].
#' @param na.rm Should missing values be dropped from grid? Default is `TRUE`.
#' @param na.rm Should missing values be dropped from the grid? Default is `TRUE`.
#' @returns A vector or factor of evaluation points.
#' @seealso [multivariate_grid()]
#' @export
Expand All @@ -33,8 +33,8 @@
#' univariate_grid(rev(iris$Species)) # Same
#'
#' x <- iris$Sepal.Width
#' univariate_grid(x, grid_size = 5) # Quantile binning
#' univariate_grid(x, grid_size = 5, strategy = "uniform") # Uniform
#' univariate_grid(x, grid_size = 5) # Uniform binning
#' univariate_grid(x, grid_size = 5, strategy = "quantile") # Quantile
univariate_grid <- function(z, grid_size = 49L, trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), na.rm = TRUE) {
strategy <- match.arg(strategy)
Expand All @@ -50,9 +50,12 @@ univariate_grid <- function(z, grid_size = 49L, trim = c(0.01, 0.99),
g <- stats::quantile(z, probs = p, names = FALSE, type = 1L, na.rm = TRUE)
out <- unique(g)
} else {
# strategy = "uniform" (could use range() if trim = 0:1)
r <- stats::quantile(z, probs = trim, names = FALSE, type = 1L, na.rm = TRUE)
# pretty(r, n = grid_size) # Until version 0.2.0
# strategy = "uniform"
if (trim[1L] == 0 && trim[2L] == 1) {
r <- range(z, na.rm = TRUE)
} else {
r <- stats::quantile(z, probs = trim, names = FALSE, type = 1L, na.rm = TRUE)
}
out <- seq(r[1L], r[2L], length.out = grid_size)
}
if (!na.rm && anyNA(z)) {
Expand Down
39 changes: 39 additions & 0 deletions backlog/calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,42 @@ plot.calibration <- function(x,
}
p
}

#' Histogram Bin Construction
#'
#' Bins a vector/factor `x` into histogram bins without drawing anything.
#' In the discrete case (non-numeric `x`, or a numeric `x` with at most `breaks`
#' unique values when `breaks` is a single number), no binning is done and `x` is
#' returned unchanged. Otherwise, the values are optionally trimmed and then passed
#' to [graphics::hist()] (with `plot = FALSE`) to determine the breaks. Compared
#' with [graphics::hist()], the function also returns the binned values of `x`.
#'
#' Trimming only affects how the breaks are determined: every non-missing value of
#' `x`, including values outside the trimmed range, is assigned to one of the bins
#' (values beyond the outermost breaks go to the first/last bin).
#'
#' @param x A vector or factor to be binned.
#' @inheritParams graphics::hist
#' @inheritParams univariate_grid
#' @returns A list with binned "x", vector of "breaks", bin midpoints "grid", and a
#'   logical flag "discrete" indicating whether the values have not been binned.
#'   In the discrete case, "breaks" and "grid" both hold the sorted unique values
#'   of `x` (with `NA` kept last if `na.rm = FALSE`). In the binned case, "grid"
#'   gets a trailing `NA` if `na.rm = FALSE` and `x` contains missing values.
#' @seealso See [calibration()] for examples.
hist2 <- function(x, breaks = 17L, trim = c(0.01, 0.99),
                  include.lowest = TRUE, right = TRUE, na.rm = TRUE) {
  g <- unique(x)

  # Discrete case: nothing to bin, the unique values are the grid
  few_values <- length(breaks) == 1L && is.numeric(breaks) && length(g) <= breaks
  if (!is.numeric(x) || few_values) {
    g <- sort(g, na.last = if (na.rm) NA else TRUE)
    return(list(x = x, breaks = g, grid = g, discrete = TRUE))
  }

  # Trim outliers before histogram construction?
  if (trim[1L] == 0 && trim[2L] == 1) {
    xx <- x
  } else {
    r <- stats::quantile(x, probs = trim, names = FALSE, type = 1L, na.rm = TRUE)
    xx <- x[x >= r[1L] & x <= r[2L]]
  }

  # plot = FALSE: only the breaks/mids are needed, never draw as a side effect
  h <- graphics::hist(
    xx, breaks = breaks, include.lowest = include.lowest, right = right, plot = FALSE
  )
  b <- h$breaks

  # all.inside = TRUE maps values outside the (trimmed) breaks to the outer bins.
  # Missing values in x yield NA indices and thus NA in the binned output.
  ix <- findInterval(
    x, vec = b, left.open = right, rightmost.closed = include.lowest, all.inside = TRUE
  )
  g <- h$mids
  if (!na.rm && anyNA(x)) {
    g <- c(g, NA)
  }
  list(x = g[ix], breaks = b, grid = g, discrete = FALSE)
}
2 changes: 1 addition & 1 deletion man/H2.Rd → man/h2.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/H2_overall.Rd → man/h2_overall.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/H2_pairwise.Rd → man/h2_pairwise.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/H2_threeway.Rd → man/h2_threeway.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/ice.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/multivariate_grid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/partial_dep.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/univariate_grid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 21 additions & 21 deletions tests/testthat/test_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,31 +71,31 @@ test_that("gwcolMeans() works", {
w2 <- 1:6

# Ungrouped
expect_equal(gwColMeans(x), rbind(wcolMeans(x)))
expect_equal(gwColMeans(x, w = w1), rbind(wcolMeans(x, w = w1)))
expect_equal(gwColMeans(x, w = w2), rbind(wcolMeans(x, w = w2)))
r <- gwColMeans(x)
expect_equal(r$M, rbind(wcolMeans(x)))
expect_equal(r$w, nrow(x))

# Grouped
expect_equal(gwColMeans(x, g = g)[2L, ], wcolMeans(x[g == 2, ]))
expect_equal(gwColMeans(x, g = g, reorder = FALSE)[2L, ], wcolMeans(x[g == 1, ]))
r <- gwColMeans(x, w = w1)
expect_equal(r$M, rbind(wcolMeans(x, w = w1)))
expect_equal(r$w, sum(w1))

g1 <- gwColMeans(x, g = g)
g2 <- gwColMeans(x, g = g, mean_only = FALSE)
g2_d <- matrix(g2$denom, nrow = 2, ncol = 2, byrow = FALSE)
expect_equal(g1, g2$num / g2_d)
expect_equal(g2$mean, g1)
r <- gwColMeans(x, w = w2)
expect_equal(r$M, rbind(wcolMeans(x, w = w2)))
expect_equal(r$w, sum(w2))

# Grouped and weighted
expect_equal(
gwColMeans(x, g = g, w = w2)[2L, ],
wcolMeans(x[g == 2, ], w = w2[g == 2])
)
# Grouped
r <- gwColMeans(x, g = g)
expect_equal(r$M[2L, ], wcolMeans(x[g == 2, ]))
expect_equal(r$w, c(4, 2))

g1 <- gwColMeans(x, g = g, w = w2)
g2 <- gwColMeans(x, g = g, w = w2, mean_only = FALSE)
g2_d <- matrix(g2$denom, nrow = 2, ncol = 2, byrow = FALSE)
expect_equal(g1, g2$num / g2_d)
expect_equal(g2$mean, g1)
r <- gwColMeans(x, g = g, reorder = FALSE)
expect_equal(r$M[2L, ], wcolMeans(x[g == 1, ]))
expect_equal(r$w, c(2, 4))

# Grouped and weighted
r <- gwColMeans(x, g = g, w = w2)
expect_equal(r$M[2L, ], wcolMeans(x[g == 2, ], w = w2[g == 2]))
expect_equal(r$w, c(sum(3:6), sum(1:2)))
})

test_that("wcenter() works for matrices with > 1 columns", {
Expand Down

0 comments on commit df418d7

Please sign in to comment.