Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hist2 #99

Merged
merged 4 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ average_loss.default <- function(object, X, y,

# Real work
L <- as.matrix(loss(y, pred_fun(object, X, ...)))
M <- gwColMeans(L, g = BY, w = w)
M <- gwColMeans(L, g = BY, w = w)[["M"]]

if (agg_cols && ncol(M) > 1L) {
M <- cbind(rowSums(M))
Expand Down
35 changes: 16 additions & 19 deletions R/utils_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,39 +77,36 @@ wcolMeans <- function(x, w = NULL) {

#' Grouped wcolMeans()
#'
#' Internal function used to calculate grouped column-wise weighted means.
#'
#' Internal function used to calculate grouped column-wise weighted means along with
#' corresponding (weighted) counts.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A matrix-like object.
#' @param g Optional grouping variable.
#' @param w Optional case weights.
#' @param reorder Should groups be ordered, see [rowsum()]. Default is `TRUE`.
#' @param mean_only If `TRUE`, only the means are returned. If `FALSE`, a list
#' with "mean", "num" and "denom" is returned.
#' @returns A matrix with one row per group (if `mean_only = TRUE`), or a list
#' with slots "mean", "num", and "denom".
#' @returns A list with two elements: "M" represents a matrix of grouped (column)
#' means, and "w" is a vector of corresponding group counts/weights.
#' @examples
#' with(iris, gwColMeans(Sepal.Width, g = Species, w = Sepal.Length))
#' with(iris, gwColMeans(Sepal.Width, g = Species, w = Sepal.Length, mean_only = FALSE))
gwColMeans <- function(x, g = NULL, w = NULL, reorder = TRUE, mean_only = TRUE) {
gwColMeans <- function(x, g = NULL, w = NULL, reorder = TRUE) {
if (is.null(g)) {
return(rbind(wcolMeans(x, w = w)))
M <- rbind(wcolMeans(x, w = w))
denom <- if (is.null(w)) NROW(x) else sum(w)
return(list(M = M, w = denom))
}

# Now the interesting case
if (is.null(w)) {
num <- rowsum(x, group = g, reorder = reorder)
denom <- rowsum(rep.int(1, NROW(x)), group = g, reorder = reorder)
w <- rep.int(1, NROW(x))
} else {
num <- rowsum(x * w, group = g, reorder = reorder)
denom <- rowsum(w, group = g, reorder = reorder)
}
out <- num / matrix(denom, nrow = nrow(num), ncol = ncol(num), byrow = FALSE)

if (mean_only) {
return(out)
x <- x * w # w is correctly recycled over columns
}
list(mean = out, num = num, denom = denom)
num <- rowsum(x, group = g, reorder = reorder)
denom <- as.numeric(rowsum(w, group = g, reorder = reorder))
list(M = num / denom, w = denom)
}

#' Weighted Mean Centering
Expand Down
17 changes: 10 additions & 7 deletions R/utils_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#' If `strategy = "quantile"`, the evaluation points are quantiles over a regular grid
#' of probabilities from `trim[1]` to `trim[2]`.
#'
#' All quantiles are calculated via the inverse of the ECDF, i.e., via
#' Quantiles are calculated via the inverse of the ECDF, i.e., via
#' `stats::quantile(..., type = 1`).
#'
#' @param z A vector or factor.
Expand All @@ -24,7 +24,7 @@
#' of grid values. Set to `0:1` for no trimming.
#' @param strategy How to find grid values of non-discrete numeric columns?
#' Either "uniform" or "quantile", see description of [univariate_grid()].
#' @param na.rm Should missing values be dropped from grid? Default is `TRUE`.
#' @param na.rm Should missing values be dropped from the grid? Default is `TRUE`.
#' @returns A vector or factor of evaluation points.
#' @seealso [multivariate_grid()]
#' @export
Expand All @@ -33,8 +33,8 @@
#' univariate_grid(rev(iris$Species)) # Same
#'
#' x <- iris$Sepal.Width
#' univariate_grid(x, grid_size = 5) # Quantile binning
#' univariate_grid(x, grid_size = 5, strategy = "uniform") # Uniform
#' univariate_grid(x, grid_size = 5) # Uniform binning
#' univariate_grid(x, grid_size = 5, strategy = "quantile") # Quantile
univariate_grid <- function(z, grid_size = 49L, trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), na.rm = TRUE) {
strategy <- match.arg(strategy)
Expand All @@ -50,9 +50,12 @@ univariate_grid <- function(z, grid_size = 49L, trim = c(0.01, 0.99),
g <- stats::quantile(z, probs = p, names = FALSE, type = 1L, na.rm = TRUE)
out <- unique(g)
} else {
# strategy = "uniform" (could use range() if trim = 0:1)
r <- stats::quantile(z, probs = trim, names = FALSE, type = 1L, na.rm = TRUE)
# pretty(r, n = grid_size) # Until version 0.2.0
# strategy = "uniform"
if (trim[1L] == 0 && trim[2L] == 1) {
r <- range(z, na.rm = TRUE)
} else {
r <- stats::quantile(z, probs = trim, names = FALSE, type = 1L, na.rm = TRUE)
}
out <- seq(r[1L], r[2L], length.out = grid_size)
}
if (!na.rm && anyNA(z)) {
Expand Down
39 changes: 39 additions & 0 deletions backlog/calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,42 @@ plot.calibration <- function(x,
}
p
}

#' Histogram Bin Construction
#'
#' Creates histogram of vector/factor `x`. In the discrete case, no binning is done.
#' Otherwise, the values are optionally trimmed and then passed to [hist()]. Compared
#' with [hist()], the function also returns the binned values of `x`.
#'
#' @param x A vector or factor to be binned.
#' @inheritParams hist
#' @inheritParams univariate_grid
#' @returns A list with binned "x", vector of "breaks", bin midpoints "grid", and a
#' logical flag "discrete" indicating whether the values have not been binned.
#' @seealso See [calibration()] for examples.
hist2 <- function(x, breaks = 17L, trim = c(0.01, 0.99),
include.lowest = TRUE, right = TRUE, na.rm = TRUE) {
g <- unique(x)
if (!is.numeric(x) || (length(breaks) == 1L && is.numeric(breaks) && length(g) <= breaks)) {
g <- sort(g, na.last = if (na.rm) NA else TRUE)
return(list(x = x, breaks = g, grid = g, discrete = TRUE))
}

# Trim outliers before histogram construction?
if (trim[1L] == 0 && trim[2L] == 1) {
xx <- x
} else {
r <- stats::quantile(x, probs = trim, names = FALSE, type = 1L, na.rm = TRUE)
xx <- x[x >= r[1L] & x <= r[2L]]
}
h <- hist(xx, breaks = breaks, include.lowest = include.lowest, right = right)
b <- h$breaks
ix <- findInterval(
x, vec = b, left.open = right, rightmost.closed = include.lowest, all.inside = TRUE
)
g <- h$mids
if (!na.rm && anyNA(x)) {
g <- c(g, NA)
}
list(x = g[ix], breaks = b, grid = g, discrete = FALSE)
}
2 changes: 1 addition & 1 deletion man/H2.Rd → man/h2.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/H2_overall.Rd → man/h2_overall.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/H2_pairwise.Rd → man/h2_pairwise.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/H2_threeway.Rd → man/h2_threeway.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/ice.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/multivariate_grid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/partial_dep.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/univariate_grid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 21 additions & 21 deletions tests/testthat/test_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,31 +71,31 @@ test_that("gwcolMeans() works", {
w2 <- 1:6

# Ungrouped
expect_equal(gwColMeans(x), rbind(wcolMeans(x)))
expect_equal(gwColMeans(x, w = w1), rbind(wcolMeans(x, w = w1)))
expect_equal(gwColMeans(x, w = w2), rbind(wcolMeans(x, w = w2)))
r <- gwColMeans(x)
expect_equal(r$M, rbind(wcolMeans(x)))
expect_equal(r$w, nrow(x))

# Grouped
expect_equal(gwColMeans(x, g = g)[2L, ], wcolMeans(x[g == 2, ]))
expect_equal(gwColMeans(x, g = g, reorder = FALSE)[2L, ], wcolMeans(x[g == 1, ]))
r <- gwColMeans(x, w = w1)
expect_equal(r$M, rbind(wcolMeans(x, w = w1)))
expect_equal(r$w, sum(w1))

g1 <- gwColMeans(x, g = g)
g2 <- gwColMeans(x, g = g, mean_only = FALSE)
g2_d <- matrix(g2$denom, nrow = 2, ncol = 2, byrow = FALSE)
expect_equal(g1, g2$num / g2_d)
expect_equal(g2$mean, g1)
r <- gwColMeans(x, w = w2)
expect_equal(r$M, rbind(wcolMeans(x, w = w2)))
expect_equal(r$w, sum(w2))

# Grouped and weighted
expect_equal(
gwColMeans(x, g = g, w = w2)[2L, ],
wcolMeans(x[g == 2, ], w = w2[g == 2])
)
# Grouped
r <- gwColMeans(x, g = g)
expect_equal(r$M[2L, ], wcolMeans(x[g == 2, ]))
expect_equal(r$w, c(4, 2))

g1 <- gwColMeans(x, g = g, w = w2)
g2 <- gwColMeans(x, g = g, w = w2, mean_only = FALSE)
g2_d <- matrix(g2$denom, nrow = 2, ncol = 2, byrow = FALSE)
expect_equal(g1, g2$num / g2_d)
expect_equal(g2$mean, g1)
r <- gwColMeans(x, g = g, reorder = FALSE)
expect_equal(r$M[2L, ], wcolMeans(x[g == 1, ]))
expect_equal(r$w, c(2, 4))

# Grouped and weighted
r <- gwColMeans(x, g = g, w = w2)
expect_equal(r$M[2L, ], wcolMeans(x[g == 2, ], w = w2[g == 2]))
expect_equal(r$w, c(sum(3:6), sum(1:2)))
})

test_that("wcenter() works for matrices with > 1 columns", {
Expand Down