Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve list_simplify() performance + errors #942

Merged
merged 7 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 47 additions & 23 deletions R/list-simplify.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#' Simplify a list to an atomic or S3 vector
#'
#' Simplification maintains a one-to-one correspondence between the input
#' and output, implying that each element of `x` must contain a vector of
#' length 1. If you don't want to maintain this correspondence, then you
#' probably want either [list_c()] or [list_flatten()].
#' and output, implying that each element of `x` must contain a one element
#' vector or a one-row data frame. If you don't want to maintain this
#' correspondence, then you probably want either [list_c()]/[list_rbind()] or
#' [list_flatten()].
#'
#' @param x A list.
#' @param strict What should happen if simplification fails? If `TRUE`,
Expand All @@ -16,8 +17,13 @@
#' @examples
#' list_simplify(list(1, 2, 3))
#'
#' try(list_simplify(list(1, 2, "x")))
#' # Only works when vectors are length one and have compatible types:
#' try(list_simplify(list(1, 2, 1:3)))
#' try(list_simplify(list(1, 2, "x")))
#'
#' # Unless you set strict = FALSE, in which case you get the input back:
#' list_simplify(list(1, 2, 1:3), strict = FALSE)
#' list_simplify(list(1, 2, "x"), strict = FALSE)
list_simplify <- function(x, strict = TRUE, ptype = NULL) {
if (!is_bool(strict)) {
cli::cli_abort(
Expand All @@ -33,6 +39,7 @@ list_simplify <- function(x, strict = TRUE, ptype = NULL) {
list_simplify_internal <- function(x,
simplify = NA,
ptype = NULL,
error_arg = caller_arg(x),
error_call = caller_env()) {
if (length(simplify) > 1 || !is.logical(simplify)) {
cli::cli_abort(
Expand All @@ -57,38 +64,55 @@ list_simplify_internal <- function(x,
x,
strict = !is.na(simplify),
ptype = ptype,
error_arg = error_arg,
error_call = error_call
)
}

simplify_impl <- function(x,
strict = TRUE,
ptype = NULL,
error_arg = caller_arg(x),
error_call = caller_env()) {
vec_check_list(x, call = error_call)
vec_check_list(x, arg = error_arg, call = error_call)

can_simplify <- every(x, vec_is, size = 1)
# Handle the cases where we definitely can't simplify
if (strict) {
list_check_all_vectors(x, arg = error_arg, call = error_call)
size_one <- list_sizes(x) == 1L
can_simplify <- all(size_one)

if (can_simplify) {
tryCatch(
# TODO: use `error_call` when available
list_unchop(x, ptype = ptype),
vctrs_error_incompatible_type = function(err) {
if (strict || !is.null(ptype)) {
cnd_signal(err)
} else {
x
}
}
)
} else {
if (strict) {
if (!can_simplify) {
bad <- which(!size_one)[[1]]
cli::cli_abort(
"All elements must be length-1 vectors.",
c(
"All elements must be size 1.",
i = "`{error_arg}[[{bad}]]` is size {vec_size(x[[bad]])}."
),
call = error_call
)
} else {
x
}
} else {
can_simplify <- list_all_vectors(x) && all(list_sizes(x) == 1L)

if (!can_simplify) {
return(x)
}
}

names <- vec_names(x)
x <- vec_set_names(x, NULL)

# TODO: use `error_call` when available
out <- tryCatch(
list_unchop(x, ptype = ptype),
vctrs_error_incompatible_type = function(err) {
if (strict || !is.null(ptype)) {
cnd_signal(err)
} else {
x
}
}
)
vec_set_names(out, names)
}
11 changes: 10 additions & 1 deletion R/list-transpose.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,23 @@ list_transpose <- function(x, template = NULL, simplify = NA, ptype = NULL, defa
res <- map(x, idx, .default = default[[i]])
res <- list_simplify_internal(res,
simplify = simplify[[i]] %||% NA,
ptype = ptype[[i]]
ptype = ptype[[i]],
error_arg = result_index(idx)
)
out[[i]] <- res
}

out
}

# Build the label used to refer to one component of the transposed result
# in error messages: `result$<name>` when `idx` is a character index,
# `result[[<idx>]]` for any other kind of index.
result_index <- function(idx) {
  if (!is.character(idx)) {
    return(paste0("result[[", idx, "]]"))
  }

  paste0("result$", idx)
}

match_template <- function(x, template, error_arg = caller_arg(x), error_call = caller_env()) {
if (is.character(template)) {
if (is_bare_list(x) && is_named(x)) {
Expand Down
14 changes: 10 additions & 4 deletions man/list_simplify.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 24 additions & 12 deletions tests/testthat/_snaps/list-simplify.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@

# strict simplification will error

Code
list_simplify(list(mean))
Condition
Error in `list_simplify()`:
! `x[[1]]` must be a vector, not a function.
Code
list_simplify(list(1, "a"))
Condition
Expand All @@ -25,23 +30,38 @@
list_simplify(list(1, 1:2))
Condition
Error in `list_simplify()`:
! All elements must be length-1 vectors.
! All elements must be size 1.
i `x[[2]]` is size 2.
Code
list_simplify(list(data.frame(x = 1), data.frame(x = 1:2)))
Condition
Error in `list_simplify()`:
! All elements must be size 1.
i `x[[2]]` is size 2.
Code
list_simplify(list(1, 2), ptype = character())
Condition
Error:
! Can't convert <double> to <character>.

# validates inputs
# list_simplify() validates inputs

Code
list_simplify_internal(1:5)
list_simplify(1:5)
Condition
Error:
Error in `list_simplify()`:
! `x` must be a list, not an integer vector.

---

Code
list_simplify(list(), strict = NA)
Condition
Error in `list_simplify()`:
! `strict` must be `TRUE` or `FALSE`, not `NA`.

# list_simplify_internal() validates inputs

Code
list_simplify_internal(list(), simplify = 1)
Condition
Expand All @@ -56,11 +76,3 @@
Error:
! Can't specify `ptype` when `simplify = FALSE`.

---

Code
list_simplify(list(), strict = NA)
Condition
Error in `list_simplify()`:
! `strict` must be `TRUE` or `FALSE`, not `NA`.

3 changes: 2 additions & 1 deletion tests/testthat/_snaps/list-transpose.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
list_transpose(list(list(x = 1), list(x = 2:3)), simplify = TRUE)
Condition
Error in `list_transpose()`:
! All elements must be length-1 vectors.
! All elements must be size 1.
i `result$x[[2]]` is size 2.

# can supply `simplify` globally or individually

Expand Down
17 changes: 13 additions & 4 deletions tests/testthat/test-list-simplify.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ test_that("simplifies using vctrs principles", {
expect_equal(list_simplify(x), data.frame(x = c(1, NA), y = c(NA, 2)))
})

test_that("only uses outer names", {
  # The second and third elements carry inner names (`b`, `d`); the expected
  # names are those of the list itself, with "" for the unnamed element.
  simplified <- list_simplify(list(a = 1, c(b = 1), c = c(d = 1)))
  expect_named(simplified, c("a", "", "c"))
})

test_that("ptype is enforced", {
expect_equal(list_simplify(list(1, 2), ptype = double()), c(1, 2))
expect_snapshot(list_simplify(list(1, 2), ptype = character()), error = TRUE)
Expand All @@ -15,8 +20,10 @@ test_that("ptype is enforced", {

test_that("strict simplification will error", {
# Snapshot the error raised by each way strict simplification can fail:
# a non-vector element, elements with incompatible types, an atomic element
# whose size is not 1, a data frame with more than one row, and a `ptype`
# that the elements cannot be converted to.
expect_snapshot(error = TRUE, {
list_simplify(list(mean))
list_simplify(list(1, "a"))
list_simplify(list(1, 1:2))
list_simplify(list(data.frame(x = 1), data.frame(x = 1:2)))
list_simplify(list(1, 2), ptype = character())
})
})
Expand All @@ -29,10 +36,12 @@ test_that("simplification requires length-1 vectors with common type", {

# argument checking -------------------------------------------------------

test_that("validates inputs", {
expect_snapshot(list_simplify_internal(1:5), error = TRUE)
test_that("list_simplify() validates inputs", {
expect_snapshot(list_simplify(1:5), error = TRUE)
expect_snapshot(list_simplify(list(), strict = NA), error = TRUE)
})

test_that("list_simplify_internal() validates inputs", {
expect_snapshot(list_simplify_internal(list(), simplify = 1), error = TRUE)
expect_snapshot(list_simplify_internal(list(), simplify = FALSE, ptype = integer()), error = TRUE)

expect_snapshot(list_simplify(list(), strict = NA), error = TRUE)
})