Skip to content

Commit

Permalink
Merge pull request #12 from mrc-ide/mrc-5113
Browse files Browse the repository at this point in the history
Implement restartable chains
  • Loading branch information
richfitz authored Mar 4, 2024
2 parents b078f65 + c9c852e commit 9d6c55e
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 10 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export(mcstate_rng_pointer)
export(mcstate_runner_parallel)
export(mcstate_runner_serial)
export(mcstate_sample)
export(mcstate_sample_continue)
export(mcstate_sampler_random_walk)
importFrom(R6,R6Class)
useDynLib(mcstate2, .registration = TRUE)
3 changes: 2 additions & 1 deletion R/runner.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ mcstate_run_chain <- function(pars, model, sampler, n_steps, rng) {
## surface to the user in the final object (or summarise them in
## some particular way with no guarantees about the format). We
## might hold things like start and stop times here in future.
internal <- list(used_r_rng = !identical(get_r_rng_state(), r_rng_state))
internal <- list(used_r_rng = !identical(get_r_rng_state(), r_rng_state),
rng_state = rng$state())

list(pars = history_pars,
density = history_density,
Expand Down
117 changes: 109 additions & 8 deletions R/sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@
##' `parallel` package to run chains in parallel. If you only run
##' one chain then this argument is best left alone.
##'
##' @param restartable Logical, indicating if the chains should be
##' restartable. This will add additional data to the chains
##' object.
##'
##' @return A list of parameters and densities.
##'
##' @export
mcstate_sample <- function(model, sampler, n_steps, initial = NULL,
n_chains = 1L, runner = NULL) {
n_chains = 1L, runner = NULL,
restartable = FALSE) {
if (!inherits(model, "mcstate_model")) {
cli::cli_abort("Expected 'model' to be an 'mcstate_model'",
arg = "model")
Expand All @@ -55,7 +60,58 @@ mcstate_sample <- function(model, sampler, n_steps, initial = NULL,
rng <- initial_rng(n_chains)
pars <- initial_parameters(initial, model, rng, environment())
res <- runner$run(pars, model, sampler, n_steps, rng)
combine_chains(res)

samples <- combine_chains(res)
if (restartable) {
samples$restart <- restart_data(res, model, sampler, runner)
}
samples
}


##' Continue (restart) chains started by [mcstate_sample]. Requires
##' that the original chains were run with `restartable = TRUE`.
##' Running chains this way will result in the final state being
##' exactly the same as running for the total (original + continued)
##' number of steps in a single push.
##'
##' @title Continue sampling
##'
##' @param samples A `mcstate_samples` object created by
##' [mcstate_sample()]
##'
##' @param n_steps The number of new steps to run
##'
##' @inheritParams mcstate_sample
##'
##' @return A list of parameters and densities
##' @export
mcstate_sample_continue <- function(samples, n_steps, restartable = FALSE) {
  ## Guard clauses: we need a samples object, and it must carry the
  ## restart data that mcstate_sample() saves when 'restartable = TRUE'.
  if (!inherits(samples, "mcstate_samples")) {
    cli::cli_abort("Expected 'samples' to be an 'mcstate_samples' object")
  }
  restart <- samples$restart
  if (is.null(restart)) {
    hint <- paste("To work with 'mcstate_sample_continue', you must",
                  "use the argument 'restartable = TRUE' when calling",
                  "mcstate_sample()")
    cli::cli_abort(c("Your chains are not restartable", i = hint))
  }

  ## Rebuild one generator per chain from the saved rng states, then
  ## run the saved runner again starting from the saved parameters.
  rng <- lapply(restart$rng_state,
                function(state) mcstate_rng$new(seed = state))
  res <- restart$runner$run(restart$pars, restart$model, restart$sampler,
                            n_steps, rng)

  ## Join the new samples onto the old ones; restart data is only
  ## attached to the result if the caller asks for it again.
  ret <- append_chains(samples, combine_chains(res))
  if (restartable) {
    ret$restart <- restart_data(res, restart$model, restart$sampler,
                                restart$runner)
  }
  ret
}


Expand Down Expand Up @@ -156,16 +212,61 @@ combine_chains <- function(res) {
"runner")))
}

ret <- list(pars = pars,
density = density,
details = details,
chain = chain)
class(ret) <- "mcstate_samples"
ret
samples <- list(pars = pars,
density = density,
details = details,
chain = chain)
class(samples) <- "mcstate_samples"
samples
}


## This is absolutely terrible, but it will get there.
##
## Append the samples in `curr` (a continuation) onto those in `prev`,
## chain by chain.
##
## Both arguments are 'mcstate_samples'-style lists with elements
## `pars` (a matrix with one row per sample), `density`, `details` and
## `chain` (the chain index of each sample).  The first sample of
## every chain in `curr` is dropped before merging (presumably because
## it repeats the point at which `prev` stopped - TODO confirm against
## the runner).
##
## Returns a new 'mcstate_samples' object in which each chain's
## previous samples are immediately followed by its new samples.  No
## restart data is carried over.  Errors if either input has
## non-NULL `details`.
append_chains <- function(prev, curr) {
  if (!is.null(prev$details) || !is.null(curr$details)) {
    ## This needs to wait until hmc or adaptive sampling are merged to
    ## work with.
    cli::cli_abort("Can't yet merge chains with details")
  }

  ## Indices into the stacked (prev, then curr) samples, grouped by
  ## chain; curr's indices are offset by the number of prev samples.
  i <- split(seq_along(prev$chain), prev$chain)
  j <- split(seq_along(curr$chain) + length(prev$chain), curr$chain)
  j <- lapply(j, function(x) x[-1])
  ## rbind() on two lists gives a 2 x n_chains list-matrix; unlisting
  ## it column by column yields prev[chain 1], curr[chain 1],
  ## prev[chain 2], curr[chain 2], ...
  k <- unlist(rbind(i, j))

  samples <- list(pars = rbind(prev$pars, curr$pars)[k, , drop = FALSE],
                  density = c(prev$density, curr$density)[k],
                  details = NULL,
                  chain = c(prev$chain, curr$chain)[k])
  class(samples) <- "mcstate_samples"
  samples
}


## Build a list of mcstate_rng objects, one for each state returned by
## mcstate_rng_distributed_state() when asked for `n_chains` nodes.
initial_rng <- function(n_chains, seed = NULL) {
  states <- mcstate_rng_distributed_state(n_nodes = n_chains, seed = seed)
  lapply(states, function(state) mcstate_rng$new(seed = state))
}


## Collect everything needed to continue a set of chains later.
##
## `res` is the list of per-chain results from the runner; each
## element must have a matrix `pars` (one row per sample) and an
## `internal$rng_state`.  `model`, `sampler` and `runner` are stored
## unchanged.
##
## Returns a list with elements `rng_state` (a list with one state per
## chain), `pars` (the last sampled parameters of each chain, as a
## matrix with one row per chain and one column per parameter),
## `model`, `sampler` and `runner`.
restart_data <- function(res, model, sampler, runner) {
  ## TODO: this is not actually enough; we also need the state from any
  ## stateful sampler (so that's the case for the adaptive sampler and
  ## for hmc with debug enabled)
  n_pars <- ncol(res[[1]]$pars)
  last <- vapply(res, function(x) x$pars[nrow(x$pars), ], numeric(n_pars))
  ## vapply() returns a plain vector when n_pars == 1 and an
  ## n_pars x n_chains matrix otherwise; in both cases the values are
  ## stored chain after chain, so filling by row always gives an
  ## n_chains x n_pars matrix.  Previously only the single-parameter
  ## case had this orientation and the multi-parameter case came out
  ## transposed.
  ## NOTE(review): assumes the runner wants chains in rows, as the
  ## single-parameter layout implies - confirm against runner$run().
  pars <- matrix(last, ncol = n_pars, byrow = TRUE)
  list(rng_state = lapply(res, function(x) x$internal$rng_state),
       pars = pars,
       model = model,
       sampler = sampler,
       runner = runner)
}
7 changes: 6 additions & 1 deletion man/mcstate_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions man/mcstate_sample_continue.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 60 additions & 0 deletions tests/testthat/test-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,63 @@ test_that("need a direct sample function in order to start sampling", {
mcstate_sample(model2, sampler, 10),
"'initial' must be provided with this model")
})


test_that("can continue chains", {
  m <- ex_simple_gamma1()
  s <- mcstate_sampler_random_walk(vcv = diag(1) * 0.01)

  ## Reference: all 100 steps in a single call
  set.seed(1)
  expected <- mcstate_sample(m, s, n_steps = 100, initial = 1, n_chains = 3)

  ## Same seed, but 50 steps followed by a 50 step continuation
  set.seed(1)
  first_half <- mcstate_sample(m, s, n_steps = 50, initial = 1, n_chains = 3,
                               restartable = TRUE)
  continued <- mcstate_sample_continue(first_half, n_steps = 50)

  expect_equal(continued, expected)
})


test_that("can continue continuable chains", {
  m <- ex_simple_gamma1()
  s <- mcstate_sampler_random_walk(vcv = diag(1) * 0.01)

  ## Reference: 30 steps in one go, keeping restart data
  set.seed(1)
  expected <- mcstate_sample(m, s, n_steps = 30, initial = 1, n_chains = 3,
                             restartable = TRUE)

  ## Same seed, but three pushes of 10 steps, each one restartable
  set.seed(1)
  part1 <- mcstate_sample(m, s, n_steps = 10, initial = 1, n_chains = 3,
                          restartable = TRUE)
  part2 <- mcstate_sample_continue(part1, n_steps = 10, restartable = TRUE)
  part3 <- mcstate_sample_continue(part2, n_steps = 10, restartable = TRUE)

  expect_equal(part3, expected)
})


test_that("can't restart chains that don't have restart information", {
  m <- ex_simple_gamma1()
  s <- mcstate_sampler_random_walk(vcv = diag(1) * 0.01)
  ## Sampled without 'restartable = TRUE', so there is nothing to
  ## continue from
  not_restartable <- mcstate_sample(m, s, n_steps = 5, initial = 1,
                                    n_chains = 3)
  expect_error(
    mcstate_sample_continue(not_restartable, n_steps = 50),
    "Your chains are not restartable")
})


test_that("continuing requires that we have a samples object", {
  ## A model is not an 'mcstate_samples' object, so this must be rejected
  not_samples <- ex_simple_gamma1()
  expect_error(
    mcstate_sample_continue(not_samples, n_steps = 50),
    "Expected 'samples' to be an 'mcstate_samples' object")
})


test_that("can't append chains that have details", {
  m <- ex_simple_gamma1()
  s <- mcstate_sampler_random_walk(vcv = diag(1) * 0.01)
  with_details <- mcstate_sample(m, s, n_steps = 5, initial = 1,
                                 restartable = TRUE)
  ## Fake the presence of sampler details; even an empty list is
  ## non-NULL and so should block the merge
  with_details$details <- list()
  expect_error(
    mcstate_sample_continue(with_details, n_steps = 5),
    "Can't yet merge chains with details")
})

0 comments on commit 9d6c55e

Please sign in to comment.