Skip to content

Commit

Permalink
Merge pull request #234 from mrc-ide/mrc-2529
Browse files Browse the repository at this point in the history
Allow if/else rewriting
  • Loading branch information
weshinsley authored Aug 3, 2021
2 parents b824dd2 + b209499 commit 8c9dff7
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin
Title: ODE Generation and Integration
Version: 1.2.1
Version: 1.2.2
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Thibaut", "Jombart", role = "ctb"),
Expand Down
21 changes: 20 additions & 1 deletion R/opt.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ static_eval <- function(expr) {
return(expr)
}

fn <- expr[[1]]
if (is_call(expr, "+") || is_call(expr, "*")) {
expr <- static_eval_assoc(expr)
} else if (is_call(expr, "if")) {
expr <- static_eval_if(expr)
} else {
expr[-1] <- lapply(expr[-1], static_eval)
}
Expand Down Expand Up @@ -76,6 +77,24 @@ static_eval_assoc <- function(expr) {
}


static_eval_if <- function(expr) {
args <- lapply(expr[-1], static_eval)

cond <- args[[1L]]
if (is.recursive(cond) && all(vlapply(cond[-1L], is.numeric))) {
cond <- eval(cond, baseenv())
}

if (!is.recursive(cond)) {
expr <- if (as.logical(cond)) args[[2L]] else args[[3L]]
} else {
expr[-1L] <- args
}

expr
}


order_args <- function(args) {
i <- viapply(args, function(x) is.language(x) + is.recursive(x))
args[order(-i, vcapply(args, deparse_str))]
Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test-parse2-rewrite.R
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,24 @@ test_that("collapse complex constants into expressions", {
dat$equations$deriv_x$rhs$value,
list("+", "a", 8))
})


test_that("collapse if/else expressions", {
code <- c(
"a <- user()",
"b <- if (a == 1) 2 else 3",
"initial(x) <- 1",
"deriv(x) <- x * b")

ir1 <- odin_parse_(code,
options = odin_options(rewrite_constants = TRUE,
substitutions = list(a = 1)))
ir2 <- odin_parse_(code,
options = odin_options(rewrite_constants = TRUE,
substitutions = list(a = 2)))

expect_equal(ir_deserialise(ir1)$equations$deriv_x$rhs$value,
list("*", "x", 2))
expect_equal(ir_deserialise(ir2)$equations$deriv_x$rhs$value,
list("*", "x", 3))
})

0 comments on commit 8c9dff7

Please sign in to comment.