diff --git a/NEWS.md b/NEWS.md
index 7e78e7519..cb66cb0a5 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,8 @@
 # Version 7.11.0.9000
 
+## Bug fixes
+
+* Restrict static transforms so they only use the upstream part of the plan (#1199, #1200, @bart1).
 
 # Version 7.11.0
 
diff --git a/R/igraph.R b/R/igraph.R
index 85eca2591..a0c94cb60 100644
--- a/R/igraph.R
+++ b/R/igraph.R
@@ -18,6 +18,18 @@ downstream_nodes <- function(graph, from) {
   )
 }
 
+# Look up vertices around `from` with mode = "in" and an order equal to the
+# number of vertices in `graph`. NOTE(review): presumably this returns every
+# ancestor of `from` (and `from` itself); confirm in nbhd_vertices().
+upstream_nodes <- function(graph, from) {
+  nbhd_vertices(
+    graph = graph,
+    vertices = from,
+    mode = "in",
+    order = igraph::gorder(graph)
+  )
+}
+
 nbhd_graph <- function(graph, vertices, mode, order) {
   vertices <- nbhd_vertices(
     graph = graph,
diff --git a/R/transform_plan.R b/R/transform_plan.R
index bcbb42205..8564d95f1 100644
--- a/R/transform_plan.R
+++ b/R/transform_plan.R
@@ -277,12 +277,22 @@ transform_plan_ <- function(
   plan$transform <- lapply(plan$transform, parse_transform)
   graph <- dsl_graph(plan)
   order <- igraph::topo_sort(graph)$name
+  # One subplan per target, kept in plan order: split() turns a character
+  # `f` into a factor with sorted levels, which would otherwise reorder the
+  # bound plan alphabetically by target name.
+  subplans <- split(
+    plan,
+    f = factor(plan$target, levels = unique(plan$target))
+  )
   for (target in order) {
-    index <- which(target == plan$target)
-    rows <- transform_row(index, plan, graph, max_expand)
-    plan <- sub_in_plan(plan, rows, index)
-    old_cols(plan) <- old_cols
+    # Transform each target against the subplans upstream of it (#1199).
+    upstream_plan <- dsl_upstream_plan(target, graph, subplans)
+    index <- which(target == upstream_plan$target)
+    old_cols(upstream_plan) <- old_cols
+    subplans[[target]] <- transform_row(index, upstream_plan, graph, max_expand)
   }
+  plan <- drake_bind_rows(subplans)
+  old_cols(plan) <- old_cols
   plan <- dsl_trace(plan = plan, trace = trace)
   old_cols(plan) <- plan$transform <- NULL
   plan <- dsl_tidy_eval(plan = plan, tidy_eval = tidy_eval, envir = envir)
@@ -290,6 +300,14 @@ transform_plan_ <- function(
   plan
 }
 
+# Bind together the subplans named by upstream_nodes(graph, target).
+# NOTE(review): transform_plan_() looks `target` up in the result, so the
+# upstream set presumably includes `target` itself; confirm.
+dsl_upstream_plan <- function(target, graph, subplans) {
+  upstream_targets <- upstream_nodes(graph, target)
+  drake_bind_rows(subplans[upstream_targets])
+}
+
 dsl_trace <- function(plan, trace) {
   if (!trace) {
     keep <- as.character(intersect(colnames(plan),
old_cols(plan)))
diff --git a/tests/testthat/test-7-dsl.R b/tests/testthat/test-7-dsl.R
index 6fee3738d..430ba077e 100644
--- a/tests/testthat/test-7-dsl.R
+++ b/tests/testthat/test-7-dsl.R
@@ -3078,3 +3078,134 @@ test_with_dir("NAs removed from old grouping vars grid (#1010)", {
   )
   equivalent_plans(out, exp)
 })
+
+test_with_dir("static transforms use only upstream part of plan (#1199)", {
+  skip_on_cran()
+  # Four transformed targets: `to_cross` combines `data` by radar, `problem`
+  # crosses `to_cross` with a new `season` grouping, and `separate` maps
+  # over a radar/season grid without naming any other target.
+  radars <- c("radar1", "radar2")
+  seasons <- c("season1", "season2")
+  months <- c(1, 2)
+  radar_seasons <- expand.grid(
+    radar = radars,
+    season = seasons,
+    stringsAsFactors = FALSE
+  )
+  out <- drake_plan(
+    data = target(
+      get_data(radar, month),
+      transform = cross(radar = !!radars, month = !!months)
+    ),
+    to_cross = target(
+      list(data),
+      transform = combine(data, .by = radar)
+    ),
+    problem = target(
+      list(to_cross, season),
+      transform = cross(to_cross, season = !!seasons)
+    ),
+    separate = target(
+      list(radar, season),
+      transform = map(.data = !!radar_seasons)
+    ),
+    trace = TRUE
+  )
+  # Expected plan with trace columns, compared through equivalent_plans().
+  # NOTE(review): the `problem_*` rows list a `separate` trace value, and
+  # the `to_cross_radar2` ones list radar1 although `to_cross_radar2` itself
+  # is traced to radar2 -- confirm this is the intended output.
+  exp <- drake_plan(
+    data_radar1_1 = target(
+      command = get_data("radar1", 1),
+      radar = "\"radar1\"",
+      month = "1",
+      data = "data_radar1_1"
+    ),
+    data_radar2_1 = target(
+      command = get_data("radar2", 1),
+      radar = "\"radar2\"",
+      month = "1",
+      data = "data_radar2_1"
+    ),
+    data_radar1_2 = target(
+      command = get_data("radar1", 2),
+      radar = "\"radar1\"",
+      month = "2",
+      data = "data_radar1_2"
+    ),
+    data_radar2_2 = target(
+      command = get_data("radar2", 2),
+      radar = "\"radar2\"",
+      month = "2",
+      data = "data_radar2_2"
+    ),
+    problem_season1_to_cross_radar1 = target(
+      command = list(to_cross_radar1, "season1"),
+      radar = "\"radar1\"",
+      season = "\"season1\"",
+      separate = "separate_radar1_season1",
+      to_cross = "to_cross_radar1",
+      problem = "problem_season1_to_cross_radar1"
+    ),
+    problem_season2_to_cross_radar1 = target(
+      command = list(to_cross_radar1, "season2"),
+      radar = "\"radar1\"",
+      season = "\"season2\"",
+      separate = "separate_radar1_season2",
+      to_cross = "to_cross_radar1",
+      problem = "problem_season2_to_cross_radar1"
+    ),
+    problem_season1_to_cross_radar2 = target(
+      command = list(to_cross_radar2, "season1"),
+      radar = "\"radar1\"",
+      season = "\"season1\"",
+      separate = "separate_radar1_season1",
+      to_cross = "to_cross_radar2",
+      problem = "problem_season1_to_cross_radar2"
+    ),
+    problem_season2_to_cross_radar2 = target(
+      command = list(to_cross_radar2, "season2"),
+      radar = "\"radar1\"",
+      season = "\"season2\"",
+      separate = "separate_radar1_season2",
+      to_cross = "to_cross_radar2",
+      problem = "problem_season2_to_cross_radar2"
+    ),
+    separate_radar1_season1 = target(
+      command = list("radar1", "season1"),
+      radar = "\"radar1\"",
+      season = "\"season1\"",
+      separate = "separate_radar1_season1"
+    ),
+    separate_radar2_season1 = target(
+      command = list("radar2", "season1"),
+      radar = "\"radar2\"",
+      season = "\"season1\"",
+      separate = "separate_radar2_season1"
+    ),
+    separate_radar1_season2 = target(
+      command = list("radar1", "season2"),
+      radar = "\"radar1\"",
+      season = "\"season2\"",
+      separate = "separate_radar1_season2"
+    ),
+    separate_radar2_season2 = target(
+      command = list("radar2", "season2"),
+      radar = "\"radar2\"",
+      season = "\"season2\"",
+      separate = "separate_radar2_season2"
+    ),
+    to_cross_radar1 = target(
+      command = list(data_radar1_1, data_radar1_2),
+      radar = "\"radar1\"",
+      to_cross = "to_cross_radar1"
+    ),
+    to_cross_radar2 = target(
+      command = list(data_radar2_1, data_radar2_2),
+      radar = "\"radar2\"",
+      to_cross = "to_cross_radar2"
+    )
+  )
+  equivalent_plans(out, exp)
+})