Skip to content

Commit

Permalink
Merge pull request #96 from mayer79/gwcolmeans-extension
Browse files Browse the repository at this point in the history
Add new argument mean_only to gwColMeans()
  • Loading branch information
mayer79 authored Oct 29, 2023
2 parents d7e9697 + 41a5651 commit 3dd2df2
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 24 deletions.
19 changes: 15 additions & 4 deletions R/utils_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,15 @@ wcolMeans <- function(x, w = NULL) {
#' @param x A matrix-like object.
#' @param g Optional grouping variable.
#' @param w Optional case weights.
#' @param reorder Should groups be ordered, see [rowsum()]. Default is `TRUE`.
#' @returns A matrix with one row per group.
gwColMeans <- function(x, g = NULL, w = NULL, reorder = TRUE) {
#' @param reorder Should groups be ordered, see [rowsum()]. Default is `TRUE`.
#' @param mean_only If `TRUE`, only the means are returned. If `FALSE`, a list
#' with "mean", "num" and "denom" is returned.
#' @returns A matrix with one row per group (if `mean_only = TRUE`), or a list
#' with slots "mean", "num", and "denom".
#' @examples
#' with(iris, gwColMeans(Sepal.Width, g = Species, w = Sepal.Length))
#' with(iris, gwColMeans(Sepal.Width, g = Species, w = Sepal.Length, mean_only = FALSE))
gwColMeans <- function(x, g = NULL, w = NULL, reorder = TRUE, mean_only = TRUE) {
if (is.null(g)) {
return(rbind(wcolMeans(x, w = w)))
}
Expand All @@ -98,7 +104,12 @@ gwColMeans <- function(x, g = NULL, w = NULL, reorder = TRUE) {
num <- rowsum(x * w, group = g, reorder = reorder)
denom <- rowsum(w, group = g, reorder = reorder)
}
num / matrix(denom, nrow = nrow(num), ncol = ncol(num), byrow = FALSE)
out <- num / matrix(denom, nrow = nrow(num), ncol = ncol(num), byrow = FALSE)

if (mean_only) {
return(out)
}
list(mean = out, num = num, denom = denom)
}

#' Weighted Mean Centering
Expand Down
49 changes: 29 additions & 20 deletions backlog/calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
#' (calib <- calibration(fit, v = "Petal.Length", X = iris, y = "Sepal.Length"))
#' plot(calib)
#'
#' (calib <- calibration(fit, v = "Petal.Length", X = iris, y = "Sepal.Length", BY = "Species"))
#' plot(calib)
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' calib <- calibration(fit, v = "Petal.Width", X = iris, y = iris[1:2])
Expand All @@ -56,7 +59,9 @@ calibration <- function(object, ...) {
#' @describeIn calibration Default method.
#' @export
calibration.default <- function(object, v, X, y = NULL, pred_fun = stats::predict,
grid_size = 17L, pred = NULL,
BY = NULL, by_size = 4L,
grid_size = 17L,
pred = NULL,
n_max = 1000L, w = NULL, ...) {
stopifnot(
is.matrix(X) || is.data.frame(X),
Expand All @@ -67,48 +72,52 @@ calibration.default <- function(object, v, X, y = NULL, pred_fun = stats::predic
if (!is.null(y)) {
y <- align_pred(prepare_y(y = y, X = X)[["y"]])
}

if (!is.null(w)) {
w <- prepare_w(w = w, X = X)[["w"]]
}

v_grouped <- approx_vector(X[[v]], m = grid_size)
if (!is.null(BY)) {
BY2 <- prepare_by(BY = BY, X = X, by_size = by_size)
BY <- BY2[["BY"]]
}
g <- v_grouped <- approx_vector(X[[v]], m = grid_size)
grid <- sort(unique(v_grouped), na.last = TRUE)

if (!is.null(BY)) {
g <- paste(BY, g, sep = ":")
}

# Average predicted
if (is.null(pred)) {
pred <- pred_fun(object, X, ...)
}
pred <- align_pred(pred)
avg_pred <- gwColMeans(pred, g = v_grouped, w = w)
avg_pred <- gwColMeans(pred, g = g, w = w)

# Average observed
avg_obs <- if (!is.null(y)) gwColMeans(y, g = v_grouped, w = w)
avg_obs <- if (!is.null(y)) gwColMeans(y, g = g, w = w)

# Exposure
ww <- if (is.null(w)) rep.int(1, NROW(X)) else w
exposure <- rowsum(ww, group = v_grouped)
exposure <- rowsum(ww, group = g)

# Partial dependence
if (n_max > 0L) {
if (nrow(X) > n_max) {
ix <- sample(nrow(X), n_max)
X <- X[ix, , drop = FALSE]
w <- if (!is.null(w)) w[ix]
}
pd <- pd_raw(
object = object, v = v, X = X, grid = grid, pred_fun = pred_fun, w = w, ...
)
rownames(pd) <- grid
} else {
pd <- NULL
}
pd <- partial_dep(
object = object,
v = v,
X = X,
grid = grid,
pred_fun = pred_fun,
BY = BY,
w = w,
...
)[["data"]]

out <- list(
v = v,
K = ncol(pred),
pred_names = colnames(pred),
grid = grid,
BY,
avg_obs = avg_obs,
avg_pred = avg_pred,
pd = pd,
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,23 @@ test_that("gwcolMeans() works", {
expect_equal(gwColMeans(x, g = g)[2L, ], wcolMeans(x[g == 2, ]))
expect_equal(gwColMeans(x, g = g, reorder = FALSE)[2L, ], wcolMeans(x[g == 1, ]))

g1 <- gwColMeans(x, g = g)
g2 <- gwColMeans(x, g = g, mean_only = FALSE)
g2_d <- matrix(g2$denom, nrow = 2, ncol = 2, byrow = FALSE)
expect_equal(g1, g2$num / g2_d)
expect_equal(g2$mean, g1)

# Grouped and weighted
expect_equal(
gwColMeans(x, g = g, w = w2)[2L, ],
wcolMeans(x[g == 2, ], w = w2[g == 2])
)

g1 <- gwColMeans(x, g = g, w = w2)
g2 <- gwColMeans(x, g = g, w = w2, mean_only = FALSE)
g2_d <- matrix(g2$denom, nrow = 2, ncol = 2, byrow = FALSE)
expect_equal(g1, g2$num / g2_d)
expect_equal(g2$mean, g1)
})

test_that("wcenter() works for matrices with > 1 columns", {
Expand Down

0 comments on commit 3dd2df2

Please sign in to comment.