From 17048e0f78da19388fdb04c25cb61900708fd765 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Fri, 7 Jun 2024 08:32:23 +0200 Subject: [PATCH 1/5] use factor levels of data not RF --- R/forde.R | 4 ++-- tests/testthat/test-return_types.R | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/R/forde.R b/R/forde.R index e063b11..d53973e 100644 --- a/R/forde.R +++ b/R/forde.R @@ -152,8 +152,8 @@ forde <- function( classes <- sapply(x, class) x <- suppressWarnings(prep_x(x)) factor_cols <- sapply(x, is.factor) - lvls <- arf$forest$covariate.levels[factor_cols] - if (!is.null(lvls)) { + if (any(factor_cols)) { + lvls <- lapply(x[, factor_cols, drop = FALSE], levels) names(lvls) <- colnames_x[factor_cols] lvl_df <- rbindlist(lapply(seq_along(lvls), function(j) { melt(as.data.table(lvls[j]), measure.vars = names(lvls)[j], diff --git a/tests/testthat/test-return_types.R b/tests/testthat/test-return_types.R index 31a95d6..a60b34e 100644 --- a/tests/testthat/test-return_types.R +++ b/tests/testthat/test-return_types.R @@ -80,6 +80,20 @@ test_that("FORGE returns same column types", { expect_equal(classes, classes_synth) }) +test_that("FORGE returns factors with same levels (and order of levels)", { + arf <- adversarial_rf(iris, num_trees = 2, verbose = FALSE, parallel = FALSE) + psi <- forde(arf, iris, parallel = FALSE) + x_synth <- forge(psi, n_synth = 10, parallel = FALSE) + expect_equal(levels(x_synth$Species), levels(iris$Species)) +}) + +test_that("EXPCT returns factors with same levels (and order of levels)", { + arf <- adversarial_rf(iris, num_trees = 2, verbose = FALSE, parallel = FALSE) + psi <- forde(arf, iris, parallel = FALSE) + x_synth <- expct(psi, parallel = FALSE) + expect_equal(levels(x_synth$Species), levels(iris$Species)) +}) + # test_that("MAP returns proper column types", { # n <- 50 # dat <- data.frame(numeric = rnorm(n), From 72f9ced2d459fcca2efa358af048638add3aa257 Mon Sep 17 00:00:00 2001 From: Jan Kapar Date: Tue, 11 Jun 2024 17:21:40 +0200 Subject: [PATCH 2/5] use rf-levels internally again --- R/forde.R | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/R/forde.R b/R/forde.R index d53973e..a90a9c0 100644 --- a/R/forde.R +++ b/R/forde.R @@ -328,13 +328,7 @@ forde <- function( all_na[!grepl('\\.5', min), min := min + 0.5] all_na[!grepl('\\.5', max), max := max + 0.5] all_na[, min := min + 0.5][, max := max - 0.5] - all_na <- rbindlist(lapply(seq_len(nrow(all_na)), function(i) { - data.table( - leaf = all_na[i, leaf], variable = all_na[i, variable], - level = all_na[i, seq(min, max)], - NA_share = all_na[i, NA_share] - ) - })) + all_na <- all_na[, .(level = seq(min, max), NA_share), by = .(leaf, variable)] all_na <- merge(all_na, lvl_df, by = c('variable', 'level')) all_na[, level := NULL][, tree := tree] setcolorder(all_na, colnames(dt)) From 9978ea898b1b87e3b64c9ae78c2316e296bbd43f Mon Sep 17 00:00:00 2001 From: Jan Kapar Date: Tue, 11 Jun 2024 17:27:19 +0200 Subject: [PATCH 3/5] use rf-levels internally again --- R/forde.R | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/R/forde.R b/R/forde.R index a90a9c0..5c70fe7 100644 --- a/R/forde.R +++ b/R/forde.R @@ -153,14 +153,15 @@ forde <- function( x <- suppressWarnings(prep_x(x)) factor_cols <- sapply(x, is.factor) if (any(factor_cols)) { - lvls <- lapply(x[, factor_cols, drop = FALSE], levels) - names(lvls) <- colnames_x[factor_cols] - lvl_df <- rbindlist(lapply(seq_along(lvls), function(j) { - melt(as.data.table(lvls[j]), measure.vars = names(lvls)[j], - value.name = 'val')[, level := .I] - })) + # Store levels used in rf (used for internal calculations with all-NA leaves) + lvls_rf <- arf$forest$covariate.levels[factor_cols] + lvl_df_rf <- data.table(variable = names(lvls_rf), val = lvls_rf)[ + , .(val = unlist(val), levels = seq_len(length(unlist(val)))), by = variable] + # Store levels used in data (used for forde output to post-process synthetic data) + lvl_df_data <- data.table(x)[, .(variable = names(.SD), val = lapply(.SD, levels)) ,.SDcols = factor_cols][ + , .(val = unlist(val), levels = seq_len(length(unlist(val)))), by = variable] } else { - lvl_df <- data.table() + lvl_df_rf <- lvl_df_data <- data.table() } names(factor_cols) <- colnames_x deci <- rep(NA_integer_, d) @@ -323,13 +324,19 @@ forde <- function( by = c('tree', 'leaf', 'variable'), sort = FALSE) all_na[!is.finite(min), min := 0.5] for (j in names(which(factor_cols))) { - all_na[!is.finite(max) & variable == j, max := lvl_df[variable == j, max(level)]] + all_na[!is.finite(max) & variable == j, max := lvl_df_rf[variable == j, max(level)]] } all_na[!grepl('\\.5', min), min := min + 0.5] all_na[!grepl('\\.5', max), max := max + 0.5] all_na[, min := min + 0.5][, max := max - 0.5] - all_na <- all_na[, .(level = seq(min, max), NA_share), by = .(leaf, variable)] - all_na <- merge(all_na, lvl_df, by = c('variable', 'level')) + all_na <- rbindlist(lapply(seq_len(nrow(all_na)), function(i) { + data.table( + leaf = all_na[i, leaf], variable = all_na[i, variable], + level = all_na[i, seq(min, max)], + NA_share = all_na[i, NA_share] + ) + })) + all_na <- merge(all_na, lvl_df_rf, by = c('variable', 'level')) all_na[, level := NULL][, tree := tree] setcolorder(all_na, colnames(dt)) dt <- rbind(dt, all_na) @@ -343,14 +350,14 @@ forde <- function( } else { # Define the range of each variable in each leaf dt <- unique(dt[, val_count := .N, by = .(f_idx, variable, val)]) - dt <- merge(dt, lvl_df[, .(k = .N), by = variable], by = "variable") + dt <- merge(dt, lvl_df_rf[, .(k = .N), by = variable], by = "variable") dt[!is.finite(min), min := 0.5][!is.finite(max), max := k + 0.5] dt[!grepl('\\.5', min), min := min + 0.5][!grepl('\\.5', max), max := max + 0.5] dt[, k := max - min] # Enumerate each possible leaf-variable-value combo tmp <- dt[, seq(min[1] + 0.5, max[1] - 0.5), by = .(f_idx, variable)] setnames(tmp, 'V1', 'level') - tmp <- merge(tmp, lvl_df, by = c('variable', 'level'), + tmp <- merge(tmp, lvl_df_rf, by = c('variable', 'level'), sort = FALSE)[, level := NULL] # Populate count, k tmp <- merge(tmp, unique(dt[, .(f_idx, variable, count, k)]), @@ -375,7 +382,7 @@ forde <- function( psi_cat <- foreach(tree = seq_len(num_trees), .combine = rbind) %do% psi_cat_fn(tree) } - lvl_df[, level := NULL] + lvl_df_rf[, level := NULL] setkey(psi_cat, f_idx, variable) setcolorder(psi_cat, c('f_idx', 'variable')) } else { @@ -391,7 +398,7 @@ forde <- function( 'meta' = data.table('variable' = colnames_x, 'class' = classes, 'family' = fifelse(factor_cols, 'multinom', family), 'decimals' = deci), - 'levels' = lvl_df, + 'levels' = lvl_df_data, 'input_class' = input_class ) return(psi) From a52c744477aeb1b944d3be2720d9a78c8fd68992 Mon Sep 17 00:00:00 2001 From: Jan Kapar Date: Tue, 11 Jun 2024 18:39:12 +0200 Subject: [PATCH 4/5] annoying things --- R/forde.R | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/R/forde.R b/R/forde.R index 5c70fe7..02e68e6 100644 --- a/R/forde.R +++ b/R/forde.R @@ -155,11 +155,11 @@ forde <- function( if (any(factor_cols)) { # Store levels used in rf (used for internal calculations with all-NA leaves) lvls_rf <- arf$forest$covariate.levels[factor_cols] - lvl_df_rf <- data.table(variable = names(lvls_rf), val = lvls_rf)[ - , .(val = unlist(val), levels = seq_len(length(unlist(val)))), by = variable] + lvl_df_rf <- data.table(variable = colnames_x[factor_cols], val = lvls_rf)[ + , .(val = unlist(val), level = seq_len(length(unlist(val)))), by = variable] # Store levels used in data (used for forde output to post-process synthetic data) - lvl_df_data <- data.table(x)[, .(variable = names(.SD), val = lapply(.SD, levels)) ,.SDcols = factor_cols][ - , .(val = unlist(val), levels = seq_len(length(unlist(val)))), by = variable] + lvl_df_data <- data.table(x)[, .(variable = colnames_x[factor_cols], val = lapply(.SD, levels)) ,.SDcols = factor_cols][ + , .(val = unlist(val)), by = variable] } else { lvl_df_rf <- lvl_df_data <- data.table() } @@ -329,13 +329,7 @@ forde <- function( all_na[!grepl('\\.5', min), min := min + 0.5] all_na[!grepl('\\.5', max), max := max + 0.5] all_na[, min := min + 0.5][, max := max - 0.5] - all_na <- rbindlist(lapply(seq_len(nrow(all_na)), function(i) { - data.table( - leaf = all_na[i, leaf], variable = all_na[i, variable], - level = all_na[i, seq(min, max)], - NA_share = all_na[i, NA_share] - ) - })) + all_na <- all_na[, .(level = seq(min, max), NA_share), by = .(leaf, variable)] all_na <- merge(all_na, lvl_df_rf, by = c('variable', 'level')) all_na[, level := NULL][, tree := tree] setcolorder(all_na, colnames(dt)) From 05323c5e0effe6ea8d015eea3857bca0d556d856 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Wed, 12 Jun 2024 09:48:38 +0200 Subject: [PATCH 5/5] update docs --- man/expct.Rd | 3 +++ 1 file changed, 3 insertions(+) diff --git a/man/expct.Rd b/man/expct.Rd index 536e68a..75f91a2 100644 --- a/man/expct.Rd +++ b/man/expct.Rd @@ -58,6 +58,9 @@ provide a data frame with condition events. This supports inequalities and inter Alternatively, users may directly input a pre-calculated posterior distribution over leaves, with columns \code{f_idx} and \code{wt}. This may be preferable for complex constraints. See Examples. + +Please note that results for continuous features which are both included in \code{query} and in +\code{evidence} with an interval condition are currently inconsistent. } \examples{ # Train ARF and estimate leaf parameters