Skip to content

Commit

Permalink
Merge pull request #10 from mrc-ide/mrc-5106
Browse files Browse the repository at this point in the history
Improve interface around model specification
  • Loading branch information
richfitz authored Mar 1, 2024
2 parents 4a92bdf + 28d1658 commit b078f65
Show file tree
Hide file tree
Showing 8 changed files with 437 additions and 103 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Generated by roxygen2: do not edit by hand

export(mcstate_model)
export(mcstate_model_properties)
export(mcstate_rng)
export(mcstate_rng_distributed_pointer)
export(mcstate_rng_distributed_state)
Expand Down
255 changes: 209 additions & 46 deletions R/model.R
Original file line number Diff line number Diff line change
@@ -1,63 +1,226 @@
##' Create a basic `mcstate` model. Currently nothing here is
##' validated, and it's likely that users will never actually use this
##' directly. Contains data and methods that define a basic model
##' object, so that we can implement samplers against. Not all models
##' will support everything here, and we'll add additional
##' fields/traits over time to advertise what a model can do. For
##' example, models will need to advertise that they are capable of
##' being differentiated, or that they are stochastic in order to be
##' used with different methods.
##' Describe properties of a model. Use of this function is optional,
##' but you can pass the return value of this as the `properties`
##' argument of mcstate_model to enforce that your model does actually have these
##' properties.
##'
##' @title Create basic model
##' @title Describe model properties
##'
##' @param parameters Names of the parameters. Every parameter is
##' named, and for now every parameter is a scalar. We might relax
##' this later to support an `odin`-style structured parameter list,
##' but that might just generate a suitable vector of parameter
##' names perhaps? In any case, once we start doing inference it's
##' naturally in the R^n, and here n is defined as the length of
##' this vector of names.
##' @param has_gradient Logical, indicating if the model has a
##' `gradient` method. Use `NULL` (the default) to detect this from
##' the model.
##'
##' @param direct_sample A function to sample directly from the
##' parameter space, given an [mcstate_rng] object to sample from.
##' In the case where a model returns a posterior (e.g., in Bayesian
##' inference), this is assumed to be sampling from the prior.
##' We'll use this for generating initial conditions for MCMC where
##' those are not given, and possibly other uses.
##'
##' @param density Compute the model density for a vector of parameter
##' values; this is the posterior probability in the case of
##' Bayesian inference, but it could be anything really. Models can
##' return `-Inf` if things are impossible, and we'll try and cope
##' gracefully with that wherever possible.
##'
##' @param gradient Compute the gradient of `density` with respect to
##' the parameter vector; takes a parameter vector and returns a
##' vector the same length. For efficiency, the model may want to
##' be stateful so that gradients can be efficiently calculated
##' after a density calculation, or density after gradient, where
##' these are called with the same parameters.
##'
##' @param domain Information on the parameter domain. This is a two
##' @param has_direct_sample Logical, indicating if the model has a
##' `direct_sample` method. Use `NULL` (the default) to detect this
##' from the model.
##'
##' @return A list of class `mcstate_model_properties` which should
##' not be modified.
##'
##' @export
mcstate_model_properties <- function(has_gradient = NULL,
has_direct_sample = NULL) {
ret <- list(has_gradient = has_gradient,
has_direct_sample = has_direct_sample)
class(ret) <- "mcstate_model_properties"
ret
}

##' Create a basic `mcstate` model. This takes a user-supplied object
##' that minimally can compute a probability density (via a `density`
##' function) and information about parameters; with this we can
##' sample from the model using `MCMC` using [mcstate_sample]. We
##' don't imagine that many users will call this function directly,
##' but that this will be glue used by packages.
##'
##' The `model` argument can be a list or environment (something
##' indexable by `$`) and have elements:
##'
##' * `density`: A function that will compute some probability
##' density. It must take an argument representing a parameter
##' vector (a numeric vector) and return a single value. This is
##' the posterior probability density in Bayesian inference, but it
##' could be anything really. Models can return `-Inf` if things
##' are impossible, and we'll try and cope gracefully with that
##' wherever possible.
##'
##' * `parameters`: A character vector of parameter names. This
##' vector is the source of truth for the length of the parameter
##' vector.
##'
##' * `domain`: Information on the parameter domain. This is a two
##' column matrix with `length(parameters)` rows representing each
##' parameter. The parameter minimum and maximum bounds are given
##' as the first and second column. Infinite values (`-Inf` or
##' `Inf`) should be used where the parameter has infinite domain up
##' or down. Currently used to translate from a bounded to
##' unbounded space for HMC, but we might also use this for
##' reflecting proposals in MCMC too.
##' reflecting proposals in MCMC too. If not present we assume that
##' the model is valid everywhere (i.e., that all parameters are
##' valid from `-Inf` to `Inf`.
##'
##' @return An object of class `mcstate_model`, which can be used with
##' a sampler.
##' * `direct_sample`: A function to sample directly from the
##' parameter space, given an [mcstate_rng] object to sample from.
##' In the case where a model returns a posterior (e.g., in Bayesian
##' inference), this is assumed to be sampling from the prior.
##' We'll use this for generating initial conditions for MCMC where
##' those are not given, and possibly other uses. If not given then
##' when using [mcstate_sample()] the user will have to provide a
##' vector of initial states.
##'
##' * `gradient`: A function to compute the gradient of `density` with
##' respect to the parameter vector; takes a parameter vector and
##' returns a vector the same length. For efficiency, the model may
##' want to be stateful so that gradients can be efficiently
##' calculated after a density calculation, or density after
##' gradient, where these are called with the same parameters. This
##' function is optional (and may not be well defined or possible to
##' define).
##'
##' @title Create basic model
##'
##' @param model A list or environment with elements as described in
##' Details.
##'
##' @param properties Optionally, a [mcstate_model_properties] object,
##' used to enforce or clarify properties of the model.
##'
##' @return An object of class `mcstate_model`. This will have elements:
##'
##' * `model`: The model as provided
##' * `parameters`: The parameter name vector
##' * `domain`: The parameter domain matrix
##' * `direct_sample`: The `direct_sample` function, if provided by the model
##' * `gradient`: The `gradient` function, if provided by the model
##' * `properties`: A list of properties of the model;
##' see [mcstate_model_properties()]. Currently this contains:
##' * `has_gradient`: the model can compute its gradient
##' * `has_direct_sample`: the model can sample from parameters space
##'
##' @export
mcstate_model <- function(parameters, direct_sample, density, gradient,
domain) {
ret <- list(parameters = parameters,
direct_sample = direct_sample,
mcstate_model <- function(model, properties = NULL) {
call <- environment() # for nicer stack traces
parameters <- validate_model_parameters(model, call)
domain <- validate_model_domain(model, call)
density <- validate_model_density(model, call)

properties <- validate_model_properties(properties, call)
gradient <- validate_model_gradient(model, properties, call)
direct_sample <- validate_model_direct_sample(model, properties, call)

## Update properties based on what we found:
properties$has_gradient <- !is.null(gradient)
properties$has_direct_sample <- !is.null(direct_sample)

ret <- list(model = model,
parameters = parameters,
domain = domain,
density = density,
gradient = gradient,
domain = domain)
direct_sample = direct_sample,
properties = properties)
class(ret) <- "mcstate_model"
ret
}



validate_model_properties <- function(properties, call = NULL) {
if (is.null(properties)) {
return(mcstate_model_properties())
}

if (!inherits(properties, "mcstate_model_properties")) {
cli::cli_abort(
"Expected 'properties' to be a 'mcstate_model_properties' object",
arg = "properties", call = call)
}

properties
}


validate_model_parameters <- function(model, call = NULL) {
if (!is.character(model$parameters)) {
cli::cli_abort("Expected 'model$parameters' to be a character vector",
arg = "model", call = call)
}
model$parameters
}


validate_model_domain <- function(model, call = NULL) {
n_pars <- length(model$parameters)
if (is.null(model$domain)) {
domain <- cbind(rep(-Inf, n_pars), rep(Inf, n_pars))
} else {
domain <- model$domain
if (!is.matrix(domain)) {
cli::cli_abort("Expected 'model$domain' to be a matrix if non-NULL")
}
if (nrow(domain) != n_pars) {
cli::cli_abort(paste(
"Expected 'model$domain' to have {n_pars} row{?s},",
"but it had {nrow(domain)}"))
}
if (ncol(domain) != 2) {
cli::cli_abort(paste(
"Expected 'model$domain' to have 2 columns,",
"but it had {ncol(domain)}"))
}
}
domain
}


validate_model_density <- function(model, call = NULL) {
if (!is.function(model$density)) {
cli::cli_abort("Expected 'model$density' to be a function",
arg = "model", call = call)
}
model$density
}


validate_model_optional_method <- function(model, properties,
method_name, property_name,
call) {
if (isFALSE(properties[[property_name]])) {
return(NULL)
}
value <- model[[method_name]]
if (isTRUE(properties[[property_name]]) && !is.function(value)) {
cli::cli_abort(
paste("Did not find a function '{method_name}' within your model, but",
"your properties say that it should do"),
arg = "model", call = call)
}
if (!is.null(value) && !is.function(value)) {
cli::cli_abort(
"Expected 'model${method_name}' to be a function if non-NULL",
arg = "model", call = call)
}
value
}


validate_model_gradient <- function(model, properties, call) {
validate_model_optional_method(
model, properties, "gradient", "has_gradient", call)
}


validate_model_direct_sample <- function(model, properties, call) {
validate_model_optional_method(
model, properties, "direct_sample", "has_direct_sample", call)
}


require_direct_sample <- function(model, message, ...) {
if (!model$properties$has_direct_sample) {
cli::cli_abort(
c(message,
i = paste("This model does not provide 'direct_sample()', so we",
"cannot directly sample from its parameter space")),
...)
}
}
3 changes: 3 additions & 0 deletions R/sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ initial_parameters <- function(initial, model, rng, call = NULL) {
n_pars <- length(model$parameters)
n_chains <- length(rng)
if (is.null(initial)) {
require_direct_sample(model,
"'initial' must be provided with this model",
arg = "initial", call = environment())
## Really this would just be from the prior; we can't directly
## sample from the posterior!
initial <- lapply(rng, function(r) model$direct_sample(r))
Expand Down
Loading

0 comments on commit b078f65

Please sign in to comment.