Skip to content

Commit

Permalink
Support min/max
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Oct 8, 2024
1 parent e4e3d9f commit ba0a37e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
8 changes: 4 additions & 4 deletions R/generate_dust_sexp.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ generate_dust_sexp_reduce <- function(expr, dat, options) {
dim <- paste0(
if (isFALSE(options$shared_exists)) "dim_" else "shared.dim.",
target)
stopifnot(fn == "sum")
stopifnot(fn %in% c("sum", "prod", "min", "max"))
if (is.null(index)) {
sprintf("dust2::array::sum<real_type>(%s, %s)", target_str, dim)
sprintf("dust2::array::%s<real_type>(%s, %s)", fn, target_str, dim)
} else {
index_str <- paste(vcapply(index, function(el) {
if (el$type == "single") {
Expand All @@ -217,7 +217,7 @@ generate_dust_sexp_reduce <- function(expr, dat, options) {
generate_dust_sexp(from, dat, options),
generate_dust_sexp(to, dat, options))
}), collapse = ", ")
sprintf("dust2::array::sum<real_type>(%s, %s, %s)",
target_str, dim, index_str)
sprintf("dust2::array::%s<real_type>(%s, %s, %s)",
fn, target_str, dim, index_str)
}
}
18 changes: 18 additions & 0 deletions tests/testthat/test-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -1660,3 +1660,21 @@ test_that("Generate conditional debug", {
" }",
"}"))
})


test_that("support min/max", {
dat <- odin_parse({
update(x) <- min(a) + max(b, c)
initial(x) <- 0
a[] <- i
dim(a) <- 10
b <- 20
c <- 30
})
dat <- generate_prepare(dat)
expect_equal(
generate_dust_system_update(dat),
c(method_args$update,
" state_next[0] = dust2::array::min<real_type>(shared.a, shared.dim.a) + std::max(shared.b, shared.c);",
"}"))
})

0 comments on commit ba0a37e

Please sign in to comment.