Skip to content

Commit

Permalink
Merge pull request #24 from bips-hb/expct
Browse files Browse the repository at this point in the history
Vectorize expct function
  • Loading branch information
mnwright authored May 27, 2024
2 parents 9ff0f5c + 203e8b9 commit 4d9c6b0
Show file tree
Hide file tree
Showing 9 changed files with 395 additions and 92 deletions.
209 changes: 151 additions & 58 deletions R/expct.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,32 @@
#' @param params Circuit parameters learned via \code{\link{forde}}.
#' @param query Optional character vector of variable names. Estimates will be
#' computed for each. If \code{NULL}, all variables other than those in
#' \code{evidence} will be estimated.
#' \code{evidence} will be estimated. If evidence contains \code{NA}s, those
#' variables will be estimated and a full dataset is returned.
#' @param evidence Optional set of conditioning events. This can take one of
#' three forms: (1) a partial sample, i.e. a single row of data with some but
#' not all columns; (2) a data frame of conditioning events, which allows for
#' inequalities; or (3) a posterior distribution over leaves. See Details.
#'
#' three forms: (1) a partial sample, i.e. a single row of data with
#' some but not all columns; (2) a data frame of conditioning events,
#' which allows for inequalities and intervals; or (3) a posterior distribution over leaves;
#' see Details and Examples.
#' @param evidence_row_mode Interpretation of rows in multi-row evidence. If \code{'separate'},
#' each row in \code{evidence} is a separate conditioning event for which \code{n_synth} synthetic samples
#' are generated. If \code{'or'}, the rows are combined with a logical or; see Examples.
#' @param stepsize Stepsize defining number of evidence rows handled in one for each step.
#' Defaults to nrow(evidence)/num_registered_workers for \code{parallel == TRUE}.
#' @param parallel Compute in parallel? Must register backend beforehand, e.g.
#' via \code{doParallel}.
#'
#' @details
#' This function computes expected values for any subset of features, optionally
#' conditioned on some event(s).
#'
#' There are three methods for (optionally) encoding conditioning events via the
#' \code{evidence} argument. The first is to provide a partial sample, where
#' some columns from the training data are missing or set to \code{NA}. The second is to
#' provide a data frame with condition events. This supports inequalities and intervals.
#' Alternatively, users may directly input a pre-calculated posterior
#' distribution over leaves, with columns \code{f_idx} and \code{wt}. This may
#' be preferable for complex constraints. See Examples.
#'
#' @return
#' A one row data frame with values for all query variables.
Expand Down Expand Up @@ -44,6 +59,12 @@
#' # Compute expectations for all features other than Species
#' expct(psi, evidence = evi)
#'
#' # Condition on first two data rows with some missing values
#' evi <- iris[1:2,]
#' evi[1, 1] <- NA_real_
#' evi[1, 5] <- NA_character_
#' evi[2, 2] <- NA_real_
#' x_synth <- expct(psi, evidence = evi)
#'
#' @seealso
#' \code{\link{adversarial_rf}}, \code{\link{forde}}, \code{\link{lik}}
Expand All @@ -57,77 +78,149 @@
expct <- function(
params,
query = NULL,
evidence = NULL) {
evidence = NULL,
evidence_row_mode = c("separate", "or"),
stepsize = 0,
parallel = TRUE) {

evidence_row_mode <- match.arg(evidence_row_mode)

# To avoid data.table check issues
variable <- tree <- f_idx <- cvg <- wt <- V1 <- value <- val <- family <-
mu <- sigma <- obs <- prob <- f_idx_uncond <- . <- NULL

# Prep evidence
conj <- FALSE
if (!is.null(evidence) && !(ncol(evidence) == 2 && all(c("f_idx", "wt") %in% colnames(evidence)))) {
evidence_variable <- prep_cond(evidence, params, "or")$variable
conj <- TRUE
}
mu <- sigma <- obs <- prob <- f_idx_uncond <- step <- c_idx <- idx <-
NA_share <- . <- NULL

# Prepare evidence and stepsize
if (is.null(evidence)) {
step_no <- 1
} else {
evidence <- as.data.table(evidence)
if (ncol(evidence) == 2 && all(colnames(evidence) == c("f_idx", "wt"))) {
stepsize <- nrow(evidence)
step_no <- 1
} else if (parallel & evidence_row_mode == "separate") {
# For "separate", parallelize in forge (not in cforde)
if (stepsize == 0) {
stepsize <- ceiling(nrow(evidence)/foreach::getDoParWorkers())
}
stepsize_cforde <- 0
parallel_cforde = FALSE
step_no <- ceiling(nrow(evidence)/stepsize)
} else {
# For "or", parallelize in cforde (not in forge)
if (stepsize == 0) {
stepsize <- nrow(evidence)
}
stepsize_cforde <- stepsize
parallel_cforde <- parallel
stepsize <- nrow(evidence)
step_no <- 1
}
}

# Check query
if (is.null(query)) {
if (isTRUE(conj)) {
query <- setdiff(params$meta$variable, evidence_variable)
} else {
if (any(is.na(evidence))) {
query <- params$meta$variable
if (!is.null(evidence)) {
warning('Computing expectations for all variables. To avoid this ',
'for conditioning variables, consider passing evidence in the ',
'form of a partial sample or data frame of events.')
}
} else {
query <- setdiff(params$meta$variable, colnames(evidence))
}
} else if (any(!query %in% params$meta$variable)) {
err <- setdiff(query, params$meta$variable)
stop('Unrecognized feature(s) in query: ', err)
}
factor_cols <- params$meta[variable %in% query, family == 'multinom']

# PMF over leaves
if (is.null(evidence)) {
num_trees <- params$forest[, max(tree)]
omega <- params$forest[, .(f_idx, cvg)]
omega[, wt := cvg / num_trees]
omega[, cvg := NULL]
} else if (conj) {
omega <- cforde(params, evidence, "or")$forest[, .(f_idx = f_idx_uncond, wt = cvg)]
} else {
omega <- evidence
}
omega <- omega[wt > 0]

psi_cnt <- psi_cat <- NULL
# Continuous data
if (any(!factor_cols)) {
tmp <- merge(params$cnt[variable %in% query], omega, by = 'f_idx', sort = FALSE)
# tmp[, expct := truncnorm::etruncnorm(min, max, mu, sigma)]
# psi_cnt <- tmp[, crossprod(wt, expct), by = variable]
psi_cnt <- tmp[, crossprod(wt, mu), by = variable]
psi_cnt <- dcast(psi_cnt, . ~ variable, value.var = 'V1')[, . := NULL]
# Run in parallel for each step
par_fun <- function(step_) {

# Prepare the event space
if (is.null(evidence) || ( ncol(evidence) == 2 && all(colnames(evidence) == c("f_idx", "wt")))) {
cparams <- NULL
} else {
# Call cforde with part of the evidence for this step
index_start <- (step_-1)*stepsize + 1
index_end <- min(step_*stepsize, nrow(evidence))
evidence_part <- evidence[index_start:index_end,]
cparams <- cforde(params, evidence_part, evidence_row_mode, stepsize_cforde, parallel_cforde)
}

# omega contains the weight (wt) for each leaf (f_idx) for each condition (c_idx)
if (is.null(cparams)) {
if (is.null(evidence)) {
num_trees <- params$forest[, max(tree)]
omega <- params$forest[, .(f_idx, f_idx_uncond = f_idx, cvg)]
omega[, `:=` (c_idx = 1, wt = cvg / num_trees)]
omega[, cvg := NULL]
} else {
omega <- copy(evidence)
omega[, f_idx_uncond := f_idx]
omega[, c_idx := 1]
}
} else {
omega <- cparams$forest[, .(c_idx, f_idx, f_idx_uncond, wt = cvg)]
}
omega <- omega[wt > 0, ]
omega[, idx := .I]

synth_cnt <- synth_cat <- NULL
# Continuous data
if (any(!factor_cols)) {
if (is.null(cparams) || nrow(cparams$cnt) == 0){
psi_cond <- data.table()
} else {
psi_cond <- merge(omega, cparams$cnt[variable %in% query, -c("cvg_factor", "f_idx_uncond")], by = c('c_idx', 'f_idx'),
sort = FALSE, allow.cartesian = TRUE)[prob > 0,]
# draw sub-leaf areas (resulting from within-row or-conditions)
if(any(psi_cond[,prob != 1])) {
psi_cond[, I := .I]
psi_cond <- psi_cond[sort(c(psi_cond[prob == 1, I],
psi_cond[prob > 0 & prob < 1, fifelse(.N > 1, resample(I, 1, prob = prob), 0), by = .(variable, idx)][,V1])), -"I"]
}
psi_cond[, prob := NULL]
}
psi <- unique(rbind(psi_cond,
merge(omega, params$cnt[variable %in% query, ], by.x = 'f_idx_uncond', by.y = 'f_idx',
sort = FALSE, allow.cartesian = TRUE)[,val := NA_real_]), by = c("c_idx", "f_idx", "variable"))
psi[NA_share == 1, wt := 0]
cnt <- psi[is.na(val), val := sum(wt * mu)/sum(wt), by = .(c_idx, variable)]
cnt <- unique(cnt[, .(c_idx, variable, val)])
synth_cnt <- dcast(cnt, c_idx ~ variable, value.var = 'val')[, c_idx := NULL]
}


# Categorical data
if (any(factor_cols)) {
if (is.null(cparams) || nrow(cparams$cat) == 0) {
psi <- merge(omega, params$cat[variable %in% query, ], by.x = 'f_idx_uncond', by.y = 'f_idx', sort = FALSE, allow.cartesian = TRUE)
} else {
psi_cond <- merge(omega, cparams$cat[variable %in% query, -c("cvg_factor", "f_idx_uncond")], by = c('c_idx', 'f_idx'),
sort = FALSE, allow.cartesian = TRUE)
psi_uncond <- merge(omega, params$cat[variable %in% query, ], by.x = 'f_idx_uncond', by.y = 'f_idx',
sort = FALSE, allow.cartesian = TRUE)
psi_uncond_relevant <- psi_uncond[!psi_cond[,.(idx, variable)], on = .(idx, variable), all = FALSE]
psi <- rbind(psi_cond, psi_uncond_relevant)
}
psi[NA_share == 1, wt := 0]
psi[prob < 1, prob := sum(wt * prob)/sum(wt), by = .(c_idx, variable, val)]
cat <- setDT(psi)[, .SD[which.max.random(prob)], by = .(c_idx, variable)]
cat <- unique(cat[, .(c_idx, variable, val)])
synth_cat <- dcast(cat, c_idx ~ variable, value.var = 'val')[, c_idx := NULL]
}

# Create dataset with expectations
x_synth <- cbind(synth_cnt, synth_cat)
x_synth <- post_x(x_synth, params)

x_synth
}

# Categorical data
if (any(factor_cols)) {
tmp <- merge(params$cat[variable %in% query], omega, by = 'f_idx', sort = FALSE)
tmp <- tmp[, crossprod(prob, wt), by = .(variable, val)]
tmp <- tmp[order(match(variable, query[factor_cols]))]
vals <- tmp[tmp[, .I[which.max(V1)], by = variable]$V1]$val
psi_cat <- setDT(lapply(seq_along(vals), function(j) vals[j]))
setnames(psi_cat, query[factor_cols])
if (isTRUE(parallel)) {
x_synth_ <- foreach(step = 1:step_no, .combine = "rbind") %dopar% par_fun(step)
} else {
x_synth_ <- foreach(step = 1:step_no, .combine = "rbind") %do% par_fun(step)
}

# Clean up, export
out <- cbind(psi_cnt, psi_cat)
out <- post_x(out, params)
return(out)
return(x_synth_)
}




13 changes: 9 additions & 4 deletions R/forde.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,15 @@ forde <- function(
x <- suppressWarnings(prep_x(x))
factor_cols <- sapply(x, is.factor)
lvls <- arf$forest$covariate.levels[factor_cols]
lvl_df <- rbindlist(lapply(seq_along(lvls), function(j) {
melt(as.data.table(lvls[j]), measure.vars = names(lvls)[j],
value.name = 'val')[, level := .I]
}))
if (!is.null(lvls)) {
names(lvls) <- colnames_x[factor_cols]
lvl_df <- rbindlist(lapply(seq_along(lvls), function(j) {
melt(as.data.table(lvls[j]), measure.vars = names(lvls)[j],
value.name = 'val')[, level := .I]
}))
} else {
lvl_df <- data.table()
}
names(factor_cols) <- colnames_x
deci <- rep(NA_integer_, d)
if (any(!factor_cols)) {
Expand Down
69 changes: 47 additions & 22 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#' Adaptive column renaming
#'
#' This function renames columns in case the input data.frame includes any
#' This function renames columns in case the input colnames includes any
#' colnames required by internal functions (e.g., \code{"y"}).
#'
#' @param df Input data.frame.
#' @param cn Column names.
#' @param old_name Name of column to be renamed.
#'

col_rename <- function(df, old_name) {
col_rename <- function(cn, old_name) {
k <- 1L
converged <- FALSE
while (!isTRUE(converged)) {
new_name <- paste0(old_name, k)
if (!new_name %in% colnames(df)) {
if (!new_name %in% cn) {
converged <- TRUE
} else {
k <- k + 1L
Expand All @@ -21,6 +21,35 @@ col_rename <- function(df, old_name) {
return(new_name)
}

#' Rename all problematic columns with col_rename().
#'
#' @param cn Old column names.
#'
#' @return New columns names.

col_rename_all <- function(cn) {

if ('y' %in% cn) {
cn[which(cn == 'y')] <- col_rename(cn, 'y')
}
if ('obs' %in% cn) {
cn[which(cn == 'obs')] <- col_rename(cn, 'obs')
}
if ('tree' %in% cn) {
cn[which(cn == 'tree')] <- col_rename(cn, 'tree')
}
if ('leaf' %in% cn) {
cn[which(cn == 'leaf')] <- col_rename(cn, 'leaf')
}
if ('cnt' %in% cn) {
cn[which(cn == 'cnt')] <- col_rename(cn, 'cnt')
}
if ('N' %in% cn) {
cn[which(cn == 'N')] <- col_rename(cn, 'N')
}
cn
}

#' Safer version of sample()
#'
#' @param x A vector of one or more elements from which to choose.
Expand All @@ -32,6 +61,19 @@ resample <- function(x, ...) {
x[sample.int(length(x), ...)]
}

#' which.max() with random at ties
#'
#' @param x A numeric vector.
#'
#' @return Index of maximum value in x, with random tie-breaking.

which.max.random <- function(x) {
if (all(is.na(x))) {
return(NA)
}
which(rank(x, ties.method = "random", na.last = FALSE) == length(x))
}

#' Preprocess input data
#'
#' This function prepares input data for ARFs.
Expand Down Expand Up @@ -74,24 +116,7 @@ prep_x <- function(x) {
}
}
# Rename annoying columns
if ('y' %in% colnames(x)) {
colnames(x)[which(colnames(x) == 'y')] <- col_rename(x, 'y')
}
if ('obs' %in% colnames(x)) {
colnames(x)[which(colnames(x) == 'obs')] <- col_rename(x, 'obs')
}
if ('tree' %in% colnames(x)) {
colnames(x)[which(colnames(x) == 'tree')] <- col_rename(x, 'tree')
}
if ('leaf' %in% colnames(x)) {
colnames(x)[which(colnames(x) == 'leaf')] <- col_rename(x, 'leaf')
}
if ('cnt' %in% colnames(x)) {
colnames(x)[which(colnames(x) == 'cnt')] <- col_rename(x, 'cnt')
}
if ('N' %in% colnames(x)) {
colnames(x)[which(colnames(x) == 'N')] <- col_rename(x, 'N')
}
colnames(x) <- col_rename_all(colnames(x))
return(x)
}

Expand Down
6 changes: 3 additions & 3 deletions man/col_rename.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 4d9c6b0

Please sign in to comment.