Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stops augsynth from segfaulting with na values or unbalanced panels #64

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 70 additions & 52 deletions R/augsynth.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


#' Fit Augmented SCM
#'
#'
#' @param form outcome ~ treatment | auxillary covariates
#' @param unit Name of unit column
#' @param time Name of time column
Expand All @@ -14,11 +14,11 @@
#' ridge=Ridge regression (allows for standard errors),
#' none=No outcome model,
#' en=Elastic Net, RF=Random Forest, GSYN=gSynth,
#' mcp=MCPanel,
#' mcp=MCPanel,
#' cits=Comparitive Interuppted Time Series
#' causalimpact=Bayesian structural time series with CausalImpact
#' @param scm Whether the SCM weighting function is used
#' @param fixedeff Whether to include a unit fixed effect, default F
#' @param fixedeff Whether to include a unit fixed effect, default F
#' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted
#' @param ... optional arguments for outcome model
#'
Expand All @@ -43,33 +43,51 @@ single_augsynth <- function(form, unit, time, t_int, data,
unit <- enquo(unit)
time <- enquo(time)

## validate input data

# Check for NA values
if( sum(is.na(data %>%
select(!!unit, !!time, any_of(as.character(form)))
)
) > 0 ) {
stop("Missing values detected.")
}

# Check whether there are omitted rows
full_data <- data %>%
tidyr::expand({{unit}}, {{time}})

if( nrow(data) != nrow(full_data) ) {
stop("There are missing rows in the input data set. Panel must be balanced.")
}

## format data
outcome <- terms(formula(form, rhs=1))[[2]]
trt <- terms(formula(form, rhs=1))[[3]]

wide <- format_data(outcome, trt, unit, time, t_int, data)
synth_data <- do.call(format_synth, wide)

treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit)
control_units <- data %>% filter(!(!!unit %in% treated_units)) %>%
control_units <- data %>% filter(!(!!unit %in% treated_units)) %>%
distinct(!!unit) %>% arrange(!!unit) %>% pull(!!unit)
## add covariates
if(length(form)[2] == 2) {
Z <- extract_covariates(form, unit, time, t_int, data, cov_agg)
} else {
Z <- NULL
}

# fit augmented SCM
augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc,
augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc,
scm, fixedeff, ...)

# add some extra data
augsynth$data$time <- data %>% distinct(!!time) %>%
arrange(!!time) %>% pull(!!time)
augsynth$call <- call_name
augsynth$t_int <- t_int
augsynth$t_int <- t_int

augsynth$weights <- matrix(augsynth$weights)
rownames(augsynth$weights) <- control_units

Expand All @@ -86,9 +104,9 @@ single_augsynth <- function(form, unit, time, t_int, data,
#' @param fixedeff Whether to de-mean synth
#' @param V V matrix for Synth, default NULL
#' @param ... Extra args for outcome model
#'
#'
#' @noRd
#'
#'
fit_augsynth_internal <- function(wide, synth_data, Z, progfunc,
scm, fixedeff, V = NULL, ...) {

Expand Down Expand Up @@ -119,23 +137,23 @@ fit_augsynth_internal <- function(wide, synth_data, Z, progfunc,
} else if(progfunc == "none") {
## Just SCM
augsynth <- do.call(fit_ridgeaug_formatted,
c(list(wide_data = fit_wide,
c(list(wide_data = fit_wide,
synth_data = fit_synth_data,
Z = Z, ridge = F, scm = T, V = V, ...)))
} else {
## Other outcome models
progfuncs = c("ridge", "none", "en", "rf", "gsyn", "mcp",
"cits", "causalimpact", "seq2seq")
if (progfunc %in% progfuncs) {
augsynth <- fit_augsyn(fit_wide, fit_synth_data,
augsynth <- fit_augsyn(fit_wide, fit_synth_data,
progfunc, scm, ...)
} else {
stop("progfunc must be one of 'EN', 'RF', 'GSYN', 'MCP', 'CITS', 'CausalImpact', 'seq2seq', 'None'")
}

}

augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0),
augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0),
augsynth$mhat)
augsynth$data <- wide
augsynth$data$Z <- Z
Expand Down Expand Up @@ -169,13 +187,13 @@ predict.augsynth <- function(object, att = F, ...) {
# att <- F
# }
augsynth <- object

X <- augsynth$data$X
y <- augsynth$data$y
comb <- cbind(X, y)
trt <- augsynth$data$trt
mhat <- augsynth$mhat

m1 <- colMeans(mhat[trt==1,,drop=F])

resid <- (comb[trt==0,,drop=F] - mhat[trt==0,drop=F])
Expand All @@ -198,7 +216,7 @@ predict.augsynth <- function(object, att = F, ...) {
#' @export
print.augsynth <- function(x, ...) {
augsynth <- x

## straight from lm
cat("\nCall:\n", paste(deparse(augsynth$call), sep="\n", collapse="\n"), "\n\n", sep="")

Expand All @@ -214,7 +232,7 @@ print.augsynth <- function(x, ...) {

#' Plot function for augsynth
#' @importFrom graphics plot
#'
#'
#' @param x Augsynth object to be plotted
#' @param inf Boolean, whether to get confidence intervals around the point estimates
#' @param cv If True, plot cross validation MSE against hyper-parameter, otherwise plot effects
Expand All @@ -228,22 +246,22 @@ plot.augsynth <- function(x, inf = T, cv = F, ...) {
# }

augsynth <- x

if (cv == T) {
errors = data.frame(lambdas = augsynth$lambdas,
errors = augsynth$lambda_errors,
errors_se = augsynth$lambda_errors_se)
p <- ggplot2::ggplot(errors, ggplot2::aes(x = lambdas, y = errors)) +
ggplot2::geom_point(size = 2) +
ggplot2::geom_point(size = 2) +
ggplot2::geom_errorbar(
ggplot2::aes(ymin = errors,
ymax = errors + errors_se),
width=0.2, size = 0.5)
width=0.2, size = 0.5)
p <- p + ggplot2::labs(title = bquote("Cross Validation MSE over " ~ lambda),
x = expression(lambda), y = "Cross Validation MSE",
x = expression(lambda), y = "Cross Validation MSE",
parse = TRUE)
p <- p + ggplot2::scale_x_log10()

# find minimum and min + 1se lambda to plot
min_lambda <- choose_lambda(augsynth$lambdas,
augsynth$lambda_errors,
Expand All @@ -257,7 +275,7 @@ plot.augsynth <- function(x, inf = T, cv = F, ...) {
min_1se_lambda_index <- which(augsynth$lambdas == min_1se_lambda)

p <- p + ggplot2::geom_point(
ggplot2::aes(x = min_lambda,
ggplot2::aes(x = min_lambda,
y = augsynth$lambda_errors[min_lambda_index]),
color = "gold")
p + ggplot2::geom_point(
Expand Down Expand Up @@ -299,8 +317,8 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) {
# } else {
# inf_type <- "conformal"
# }


summ <- list()

t0 <- ncol(augsynth$data$X)
Expand Down Expand Up @@ -382,8 +400,8 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) {
} else {
summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w
}


summ$inf_type <- if(inf) inf_type else "None"
class(summ) <- "summary.augsynth"
return(summ)
Expand All @@ -395,7 +413,7 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", ...) {
#' @export
print.summary.augsynth <- function(x, ...) {
summ <- x

## straight from lm
cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n\n", sep="")

Expand All @@ -405,7 +423,7 @@ print.summary.augsynth <- function(x, ...) {
att_est <- summ$att$Estimate
t_total <- length(att_est)
t_int <- summ$att %>% filter(Time <= summ$t_int) %>% nrow()

att_pre <- att_est[1:(t_int-1)]
att_post <- att_est[t_int:t_total]

Expand All @@ -420,14 +438,14 @@ print.summary.augsynth <- function(x, ...) {
se_avg <- summ$average_att$Std.Error

out_msg <- paste("Average ATT Estimate (Jackknife Std. Error): ",
format(round(att_post,3), nsmall=3),
format(round(att_post,3), nsmall=3),
" (",
format(round(se_avg,3)), ")\n")
inf_type <- "Jackknife over units"
} else if(summ$inf_type == "conformal") {
p_val <- summ$average_att$p_val
out_msg <- paste("Average ATT Estimate (p Value for Joint Null): ",
format(round(att_post,3), nsmall=3),
format(round(att_post,3), nsmall=3),
" (",
format(round(p_val,3)), ")\n")
inf_type <- "Conformal inference"
Expand All @@ -442,7 +460,7 @@ print.summary.augsynth <- function(x, ...) {
}


out_msg <- paste(out_msg,
out_msg <- paste(out_msg,
"L2 Imbalance: ",
format(round(summ$l2_imbalance,3), nsmall=3), "\n",
"Percent improvement from uniform weights: ",
Expand All @@ -452,16 +470,16 @@ print.summary.augsynth <- function(x, ...) {

out_msg <- paste(out_msg,
"Covariate L2 Imbalance: ",
format(round(summ$covariate_l2_imbalance,3),
format(round(summ$covariate_l2_imbalance,3),
nsmall=3),
"\n",
"Percent improvement from uniform weights: ",
format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100),
format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100),
"%\n\n",
sep="")

}
out_msg <- paste(out_msg,
out_msg <- paste(out_msg,
"Avg Estimated Bias: ",
format(round(mean(summ$bias_est), 3),nsmall=3), "\n\n",
"Inference type: ",
Expand All @@ -471,30 +489,30 @@ print.summary.augsynth <- function(x, ...) {
cat(out_msg)

if(summ$inf_type == "jackknife") {
out_att <- summ$att[t_int:t_final,] %>%
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate, Std.Error)
} else if(summ$inf_type == "conformal") {
out_att <- summ$att[t_int:t_final,] %>%
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate, lower_bound, upper_bound, p_val)
names(out_att) <- c("Time", "Estimate",
names(out_att) <- c("Time", "Estimate",
paste0((1 - summ$alpha) * 100, "% CI Lower Bound"),
paste0((1 - summ$alpha) * 100, "% CI Upper Bound"),
paste0("p Value"))
} else if(summ$inf_type == "jackknife+") {
out_att <- summ$att[t_int:t_final,] %>%
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate, lower_bound, upper_bound)
names(out_att) <- c("Time", "Estimate",
names(out_att) <- c("Time", "Estimate",
paste0((1 - summ$alpha) * 100, "% CI Lower Bound"),
paste0((1 - summ$alpha) * 100, "% CI Upper Bound"))
} else {
out_att <- summ$att[t_int:t_final,] %>%
out_att <- summ$att[t_int:t_final,] %>%
select(Time, Estimate)
}
out_att %>%
mutate_at(vars(-Time), ~ round(., 3)) %>%
print(row.names = F)


}

#' Plot function for summary function for augsynth
Expand All @@ -509,7 +527,7 @@ plot.summary.augsynth <- function(x, inf = T, ...) {
# } else {
# inf <- T
# }

p <- summ$att %>%
ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate))
if(inf) {
Expand All @@ -526,15 +544,15 @@ plot.summary.augsynth <- function(x, inf = T, ...) {
}
p + ggplot2::geom_line() +
ggplot2::geom_vline(xintercept=summ$t_int, lty=2) +
ggplot2::geom_hline(yintercept=0, lty=2) +
ggplot2::geom_hline(yintercept=0, lty=2) +
ggplot2::theme_bw()

}



#' augsynth
#'
#'
#' @description A package implementing the Augmented Synthetic Controls Method
#' @docType package
#' @name augsynth-package
Expand All @@ -545,9 +563,9 @@ plot.summary.augsynth <- function(x, inf = T, ...) {
#' @import tidyr
#' @importFrom stats terms
#' @importFrom stats formula
#' @importFrom stats update
#' @importFrom stats delete.response
#' @importFrom stats model.matrix
#' @importFrom stats model.frame
#' @importFrom stats update
#' @importFrom stats delete.response
#' @importFrom stats model.matrix
#' @importFrom stats model.frame
#' @importFrom stats na.omit
NULL
Loading