From b209499588e3da67ab439d07b5e252579ca15d26 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 21 Jul 2021 09:45:05 +0100 Subject: [PATCH] Allow if/else rewriting --- DESCRIPTION | 2 +- R/opt.R | 21 ++++++++++++++++++++- tests/testthat/test-parse2-rewrite.R | 21 +++++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 390c7918..20d71e41 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: odin Title: ODE Generation and Integration -Version: 1.2.1 +Version: 1.2.2 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Thibaut", "Jombart", role = "ctb"), diff --git a/R/opt.R b/R/opt.R index ce2c28a3..5c856c63 100644 --- a/R/opt.R +++ b/R/opt.R @@ -12,9 +12,10 @@ static_eval <- function(expr) { return(expr) } - fn <- expr[[1]] if (is_call(expr, "+") || is_call(expr, "*")) { expr <- static_eval_assoc(expr) + } else if (is_call(expr, "if")) { + expr <- static_eval_if(expr) } else { expr[-1] <- lapply(expr[-1], static_eval) } @@ -76,6 +77,24 @@ static_eval_assoc <- function(expr) { } +static_eval_if <- function(expr) { + args <- lapply(expr[-1], static_eval) + + cond <- args[[1L]] + if (is.recursive(cond) && all(vlapply(cond[-1L], is.numeric))) { + cond <- eval(cond, baseenv()) + } + + if (!is.recursive(cond)) { + expr <- if (as.logical(cond)) args[[2L]] else args[[3L]] + } else { + expr[-1L] <- args + } + + expr +} + + order_args <- function(args) { i <- viapply(args, function(x) is.language(x) + is.recursive(x)) args[order(-i, vcapply(args, deparse_str))] diff --git a/tests/testthat/test-parse2-rewrite.R b/tests/testthat/test-parse2-rewrite.R index f70f0fcb..ea545115 100644 --- a/tests/testthat/test-parse2-rewrite.R +++ b/tests/testthat/test-parse2-rewrite.R @@ -151,3 +151,24 @@ test_that("collapse complex constants into expressions", { dat$equations$deriv_x$rhs$value, list("+", "a", 8)) }) + + +test_that("collapse if/else expressions", { + code <- c( + "a <- user()", + "b <- if (a == 1) 2 else 3", + "initial(x) <- 1", + "deriv(x) <- x * b") + + ir1 <- odin_parse_(code, + options = odin_options(rewrite_constants = TRUE, + substitutions = list(a = 1))) + ir2 <- odin_parse_(code, + options = odin_options(rewrite_constants = TRUE, + substitutions = list(a = 2))) + + expect_equal(ir_deserialise(ir1)$equations$deriv_x$rhs$value, + list("*", "x", 2)) + expect_equal(ir_deserialise(ir2)$equations$deriv_x$rhs$value, + list("*", "x", 3)) +})