Skip to content

Commit

Permalink
Properly import code from odin2
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Dec 19, 2024
1 parent 00d45a7 commit 9f04c9a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 75 deletions.
127 changes: 59 additions & 68 deletions R/dsl-differentiate-expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ derivative <- list(

## Assume odin's indexes for now
idx <- c("i", "j", "k", "l", "i5", "i6", "i7", "i8")[seq_along(index)]
i <- Map(is_same, index, lapply(idx, as.name))
i <- Map(maths$is_same, index, lapply(idx, as.name))

if (any(vlapply(i, isFALSE))) {
return(0)
Expand Down Expand Up @@ -465,6 +465,64 @@ maths <- local({
ret
}
}
as_sum_of_parts <- function(expr) {
if (rlang::is_call(expr, c("-", "+"), 2)) {
if (rlang::is_call(expr, "-")) {
parts <- lapply(expr[-1], as_sum_of_parts)
uminus <- monty::monty_differentiation()$uminus
parts[[2]] <- lapply(parts[[2]], uminus)
unlist(parts, FALSE)
} else {
unlist(lapply(expr[-1], as_sum_of_parts), FALSE)
}
} else {
list(expr)
}
}
factorise_parts <- function(parts) {
f <- function(el) {
if (is.numeric(el)) {
list(el, 1, "")
} else if (rlang::is_call(el, "-", 1)) {
ret <- f(el[[2]])
ret[[1]] <- -1 * ret[[1]]
ret
} else {
list(1, el, rlang::hash(el))
}
}
parts <- lapply(parts, f)
id <- vcapply(parts, "[[", 3)
ret <- lapply(unname(split(parts, id)), function(el) {
n <- sum(vnapply(el, "[[", 1))
times(n, el[[1]][[2]])
})
plus_fold(ret)
}
factorise <- function(x) {
factorise_parts(as_sum_of_parts(x))
}
is_same <- function(a, b) {
if (is.numeric(a) && is.numeric(b)) {
return(a == b)
}
if (identical(a, b)) {
return(TRUE)
}
if (!is.recursive(a) && !is.recursive(b)) {
return(call("==", a, b))
}

a_parts <- as_sum_of_parts(a)
b_parts <- lapply(as_sum_of_parts(b), uminus)
ab <- factorise_parts(c(a_parts, b_parts))

if (is.numeric(ab)) {
return(ab == 0)
}

call("==", ab, 0)
}
rewrite <- function(expr) {
if (is.recursive(expr)) {
fn <- as.character(expr[[1]])
Expand Down Expand Up @@ -494,70 +552,3 @@ maths <- local({
}
as.list(environment())
})


is_same <- function(a, b) {
if (is.numeric(a) && is.numeric(b)) {
return(a == b)
}
if (identical(a, b)) {
return(TRUE)
}
if (!is.recursive(a) && !is.recursive(b)) {
return(call("==", a, b))
}

a_parts <- expr_to_sum_of_parts(a)
b_parts <- lapply(expr_to_sum_of_parts(b), maths$uminus)
ab <- expr_factorise_parts(c(a_parts, b_parts))

if (is.numeric(ab)) {
return(ab == 0)
}

call("==", ab, 0)
}


## Duplicated from odin2:
expr_to_sum_of_parts <- function(expr) {
if (rlang::is_call(expr, c("-", "+"), 2)) {
if (rlang::is_call(expr, "-")) {
parts <- lapply(expr[-1], expr_to_sum_of_parts)
uminus <- monty::monty_differentiation()$maths$uminus
parts[[2]] <- lapply(parts[[2]], uminus)
unlist(parts, FALSE)
} else {
unlist(lapply(expr[-1], expr_to_sum_of_parts), FALSE)
}
} else {
list(expr)
}
}


expr_factorise <- function(x) {
expr_factorise_parts(expr_to_sum_of_parts(x))
}


expr_factorise_parts <- function(parts) {
f <- function(el) {
if (is.numeric(el)) {
list(el, 1, "")
} else if (rlang::is_call(el, "-", 1)) {
ret <- f(el[[2]])
ret[[1]] <- -1 * ret[[1]]
ret
} else {
list(1, el, rlang::hash(el))
}
}
parts <- lapply(parts, f)
id <- vcapply(parts, "[[", 3)
ret <- lapply(unname(split(parts, id)), function(el) {
n <- sum(vnapply(el, "[[", 1))
maths$times(n, el[[1]][[2]])
})
maths$plus_fold(ret)
}
39 changes: 32 additions & 7 deletions tests/testthat/test-dsl-differentiation.R
Original file line number Diff line number Diff line change
Expand Up @@ -473,11 +473,36 @@ test_that("differentiate expressions with arrays", {


test_that("test sameness", {
expect_true(is_same(1, 1))
expect_false(is_same(1, 0))
expect_true(is_same(quote(i), quote(i)))
expect_true(is_same(quote(j), quote(j)))
expect_equal(is_same(quote(i), quote(j)), quote(i == j))
expect_equal(is_same(quote(i), quote(2)), quote(i == 2))
expect_false(is_same(quote(i), quote(i + 1)))
expect_true(maths$is_same(1, 1))
expect_false(maths$is_same(1, 0))
expect_true(maths$is_same(quote(i), quote(i)))
expect_true(maths$is_same(quote(j), quote(j)))
expect_equal(maths$is_same(quote(i), quote(j)), quote(i == j))
expect_equal(maths$is_same(quote(i), quote(2)), quote(i == 2))
expect_false(maths$is_same(quote(i), quote(i + 1)))
})


test_that("decompose an expression into sum of parts", {
expect_equal(maths$as_sum_of_parts(1), list(1))
expect_equal(maths$as_sum_of_parts(quote(x)), list(quote(x)))
expect_equal(maths$as_sum_of_parts(quote(x + y)), list(quote(x), quote(y)))
expect_equal(maths$as_sum_of_parts(quote(x + y + z)),
list(quote(x), quote(y), quote(z)))
expect_equal(maths$as_sum_of_parts(quote(x + 2 * y + z)),
list(quote(x), quote(2 * y), quote(z)))
expect_equal(maths$as_sum_of_parts(quote(x - y)), list(quote(x), quote(-y)))
expect_equal(maths$as_sum_of_parts(quote(x - y - z)),
list(quote(x), quote(-y), quote(-z)))
})


## Lots that this does not cover yet, it's limited to support what
## tends to happen in odin index calculations which are necessarily
## simple
test_that("factorise an expression", {
expect_equal(maths$factorise(quote(1)), quote(1))
expect_equal(maths$factorise(quote(a)), quote(a))
expect_equal(maths$factorise(quote(a + a)), quote(2 * a))
expect_equal(maths$factorise(quote(1 + 2)), quote(3))
})

0 comments on commit 9f04c9a

Please sign in to comment.