diff --git a/DESCRIPTION b/DESCRIPTION index 0a825a6d..9cf19c1f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: monty Title: Monte Carlo Models -Version: 0.3.22 +Version: 0.3.23 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Wes", "Hinsley", role = "aut"), diff --git a/R/sample-manual.R b/R/sample-manual.R index c74340f4..d7896b22 100644 --- a/R/sample-manual.R +++ b/R/sample-manual.R @@ -193,7 +193,17 @@ monty_sample_manual_collect <- function(path, samples = NULL, prev <- sample_manual_collect_check_samples(inputs, samples, append) - observer <- inputs$model$observer + if (is.null(inputs$restart)) { + model <- inputs$model + sampler <- inputs$sampler + thinning_factor <- inputs$steps$thinning_factor + } else { + model <- inputs$restart$model + sampler <- inputs$restart$sampler + thinning_factor <- inputs$restart$thinning_factor + } + + observer <- model$observer res <- lapply(path$results, readRDS) samples <- combine_chains(res, observer) if (!is.null(prev)) { @@ -201,8 +211,9 @@ monty_sample_manual_collect <- function(path, samples = NULL, } if (restartable) { - samples$restart <- restart_data(res, inputs$model, inputs$sampler, NULL, - inputs$steps$thinning_factor) + runner <- NULL + samples$restart <- restart_data(res, model, sampler, runner, + thinning_factor) } samples } diff --git a/tests/testthat/test-sample-manual.R b/tests/testthat/test-sample-manual.R index 9f1aa5d0..5ea12f9c 100644 --- a/tests/testthat/test-sample-manual.R +++ b/tests/testthat/test-sample-manual.R @@ -355,3 +355,37 @@ test_that("can sample from models requiring restore", { res2 <- monty_sample_manual_collect(path) expect_equal(res2, res1) }) + + +test_that("can continue a manually sampled chain, twice", { + model <- ex_simple_gamma1() + sampler <- monty_sampler_random_walk(vcv = diag(1) * 0.01) + + path_a <- withr::local_tempdir() + path_b <- withr::local_tempdir() + path_c <- withr::local_tempdir() + + set.seed(1) + cmp_a <- monty_sample(model, sampler, 100, n_chains = 2, restartable = TRUE) + cmp_b <- monty_sample_continue(cmp_a, 50, restartable = TRUE) + cmp_c <- monty_sample_continue(cmp_b, 20, restartable = TRUE) + + set.seed(1) + monty_sample_manual_prepare(model, sampler, 100, path_a, n_chains = 2) + monty_sample_manual_run(1, path_a) + monty_sample_manual_run(2, path_a) + res_a <- monty_sample_manual_collect(path_a, restartable = TRUE) + + monty_sample_manual_prepare_continue(res_a, 50, path_b) + monty_sample_manual_run(1, path_b) + monty_sample_manual_run(2, path_b) + res_b <- monty_sample_manual_collect(path_b, samples = res_a, + restartable = TRUE) + + monty_sample_manual_prepare_continue(res_b, 20, path_c) + monty_sample_manual_run(1, path_c) + monty_sample_manual_run(2, path_c) + res_c <- monty_sample_manual_collect(path_c, samples = res_b) + + expect_equal(res_c$pars, cmp_c$pars) +})