From 9f04c9ac1b2715665ede13ec9c658c706a4c5aea Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 19 Dec 2024 08:29:22 +0000 Subject: [PATCH] Properly import code from odin2 --- R/dsl-differentiate-expr.R | 127 ++++++++++------------ tests/testthat/test-dsl-differentiation.R | 39 +++++-- 2 files changed, 91 insertions(+), 75 deletions(-) diff --git a/R/dsl-differentiate-expr.R b/R/dsl-differentiate-expr.R index 780a2e0b..cbd05f8b 100644 --- a/R/dsl-differentiate-expr.R +++ b/R/dsl-differentiate-expr.R @@ -153,7 +153,7 @@ derivative <- list( ## Assume odin's indexes for now idx <- c("i", "j", "k", "l", "i5", "i6", "i7", "i8")[seq_along(index)] - i <- Map(is_same, index, lapply(idx, as.name)) + i <- Map(maths$is_same, index, lapply(idx, as.name)) if (any(vlapply(i, isFALSE))) { return(0) @@ -465,6 +465,64 @@ maths <- local({ ret } } + as_sum_of_parts <- function(expr) { + if (rlang::is_call(expr, c("-", "+"), 2)) { + if (rlang::is_call(expr, "-")) { + parts <- lapply(expr[-1], as_sum_of_parts) + uminus <- monty::monty_differentiation()$uminus + parts[[2]] <- lapply(parts[[2]], uminus) + unlist(parts, FALSE) + } else { + unlist(lapply(expr[-1], as_sum_of_parts), FALSE) + } + } else { + list(expr) + } + } + factorise_parts <- function(parts) { + f <- function(el) { + if (is.numeric(el)) { + list(el, 1, "") + } else if (rlang::is_call(el, "-", 1)) { + ret <- f(el[[2]]) + ret[[1]] <- -1 * ret[[1]] + ret + } else { + list(1, el, rlang::hash(el)) + } + } + parts <- lapply(parts, f) + id <- vcapply(parts, "[[", 3) + ret <- lapply(unname(split(parts, id)), function(el) { + n <- sum(vnapply(el, "[[", 1)) + times(n, el[[1]][[2]]) + }) + plus_fold(ret) + } + factorise <- function(x) { + factorise_parts(as_sum_of_parts(x)) + } + is_same <- function(a, b) { + if (is.numeric(a) && is.numeric(b)) { + return(a == b) + } + if (identical(a, b)) { + return(TRUE) + } + if (!is.recursive(a) && !is.recursive(b)) { + return(call("==", a, b)) + } + + a_parts <- as_sum_of_parts(a) + b_parts <- lapply(as_sum_of_parts(b), uminus) + ab <- factorise_parts(c(a_parts, b_parts)) + + if (is.numeric(ab)) { + return(ab == 0) + } + + call("==", ab, 0) + } rewrite <- function(expr) { if (is.recursive(expr)) { fn <- as.character(expr[[1]]) @@ -494,70 +552,3 @@ maths <- local({ } as.list(environment()) }) - - -is_same <- function(a, b) { - if (is.numeric(a) && is.numeric(b)) { - return(a == b) - } - if (identical(a, b)) { - return(TRUE) - } - if (!is.recursive(a) && !is.recursive(b)) { - return(call("==", a, b)) - } - - a_parts <- expr_to_sum_of_parts(a) - b_parts <- lapply(expr_to_sum_of_parts(b), maths$uminus) - ab <- expr_factorise_parts(c(a_parts, b_parts)) - - if (is.numeric(ab)) { - return(ab == 0) - } - - call("==", ab, 0) -} - - -## Duplicated from odin2: -expr_to_sum_of_parts <- function(expr) { - if (rlang::is_call(expr, c("-", "+"), 2)) { - if (rlang::is_call(expr, "-")) { - parts <- lapply(expr[-1], expr_to_sum_of_parts) - uminus <- monty::monty_differentiation()$maths$uminus - parts[[2]] <- lapply(parts[[2]], uminus) - unlist(parts, FALSE) - } else { - unlist(lapply(expr[-1], expr_to_sum_of_parts), FALSE) - } - } else { - list(expr) - } -} - - -expr_factorise <- function(x) { - expr_factorise_parts(expr_to_sum_of_parts(x)) -} - - -expr_factorise_parts <- function(parts) { - f <- function(el) { - if (is.numeric(el)) { - list(el, 1, "") - } else if (rlang::is_call(el, "-", 1)) { - ret <- f(el[[2]]) - ret[[1]] <- -1 * ret[[1]] - ret - } else { - list(1, el, rlang::hash(el)) - } - } - parts <- lapply(parts, f) - id <- vcapply(parts, "[[", 3) - ret <- lapply(unname(split(parts, id)), function(el) { - n <- sum(vnapply(el, "[[", 1)) - maths$times(n, el[[1]][[2]]) - }) - maths$plus_fold(ret) -} diff --git a/tests/testthat/test-dsl-differentiation.R b/tests/testthat/test-dsl-differentiation.R index 03c4786b..5b1327e7 100644 --- a/tests/testthat/test-dsl-differentiation.R +++ b/tests/testthat/test-dsl-differentiation.R @@ -473,11 +473,36 @@ test_that("differentiate expressions with arrays", { test_that("test sameness", { - expect_true(is_same(1, 1)) - expect_false(is_same(1, 0)) - expect_true(is_same(quote(i), quote(i))) - expect_true(is_same(quote(j), quote(j))) - expect_equal(is_same(quote(i), quote(j)), quote(i == j)) - expect_equal(is_same(quote(i), quote(2)), quote(i == 2)) - expect_false(is_same(quote(i), quote(i + 1))) + expect_true(maths$is_same(1, 1)) + expect_false(maths$is_same(1, 0)) + expect_true(maths$is_same(quote(i), quote(i))) + expect_true(maths$is_same(quote(j), quote(j))) + expect_equal(maths$is_same(quote(i), quote(j)), quote(i == j)) + expect_equal(maths$is_same(quote(i), quote(2)), quote(i == 2)) + expect_false(maths$is_same(quote(i), quote(i + 1))) +}) + + +test_that("decompose an expression into sum of parts", { + expect_equal(maths$as_sum_of_parts(1), list(1)) + expect_equal(maths$as_sum_of_parts(quote(x)), list(quote(x))) + expect_equal(maths$as_sum_of_parts(quote(x + y)), list(quote(x), quote(y))) + expect_equal(maths$as_sum_of_parts(quote(x + y + z)), + list(quote(x), quote(y), quote(z))) + expect_equal(maths$as_sum_of_parts(quote(x + 2 * y + z)), + list(quote(x), quote(2 * y), quote(z))) + expect_equal(maths$as_sum_of_parts(quote(x - y)), list(quote(x), quote(-y))) + expect_equal(maths$as_sum_of_parts(quote(x - y - z)), + list(quote(x), quote(-y), quote(-z))) +}) + + +## Lots that this does not cover yet, it's limited to support what +## tends to happen in odin index calculations which are necessarily +## simple +test_that("factorise an expression", { + expect_equal(maths$factorise(quote(1)), quote(1)) + expect_equal(maths$factorise(quote(a)), quote(a)) + expect_equal(maths$factorise(quote(a + a)), quote(2 * a)) + expect_equal(maths$factorise(quote(1 + 2)), quote(3)) })