From 25d83a36c51be6771af9e201e87a18fb15be21d2 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 20 Oct 2022 20:58:51 -0700 Subject: [PATCH] ARROW-17980: [C++] As-of-Join Substrait extension (#17) * ARROW-17980: [C++] As-of-Join Substrait extension * add missing file * add missing proto * CI fixes * distinct keys per input table * CI fixes * resolve conflict * fix typo * ARROW-17980: Change extensions package from arrow::substrait to arrow::substrait_ext * ARROW-17980: Remove more instances of ::substrait Co-authored-by: Yaron Gvili --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 27 +++ cpp/proto/substrait/extension_rels.proto | 36 ++++ .../arrow/compute/exec/asof_join_benchmark.cc | 12 +- cpp/src/arrow/compute/exec/asof_join_node.cc | 107 ++++++++-- cpp/src/arrow/compute/exec/asof_join_node.h | 37 ++++ .../arrow/compute/exec/asof_join_node_test.cc | 27 ++- cpp/src/arrow/compute/exec/options.h | 41 ++-- cpp/src/arrow/engine/CMakeLists.txt | 1 + cpp/src/arrow/engine/substrait/options.cc | 111 ++++++++++ cpp/src/arrow/engine/substrait/options.h | 17 +- .../engine/substrait/relation_internal.cc | 30 +++ .../engine/substrait/relation_internal.h | 2 +- cpp/src/arrow/engine/substrait/serde_test.cc | 200 ++++++++++++++++++ cpp/src/arrow/engine/substrait/type_fwd.h | 1 + .../arrow/engine/substrait/type_internal.cc | 85 ++++---- 15 files changed, 643 insertions(+), 91 deletions(-) create mode 100644 cpp/proto/substrait/extension_rels.proto create mode 100644 cpp/src/arrow/compute/exec/asof_join_node.h create mode 100644 cpp/src/arrow/engine/substrait/options.cc diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index ef4026379518d..9b6cc4865f3d9 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1724,6 +1724,10 @@ macro(build_substrait) # Note: not all protos in Substrait actually matter to plan # consumption. No need to build the ones we don't need. set(SUBSTRAIT_PROTOS algebra extensions/extensions plan type) + set(ARROW_SUBSTRAIT_PROTOS extension_rels) + set(ARROW_SUBSTRAIT_PROTOS_DIR "${CMAKE_SOURCE_DIR}/proto") + message("SOURCE DIR IS ${SOURCE_DIR} AND ${CMAKE_SOURCE_DIR} AND ${ARROW_SUBSTRAIT_PROTOS_DIR}" + ) externalproject_add(substrait_ep CONFIGURE_COMMAND "" @@ -1774,6 +1778,29 @@ macro(build_substrait) list(APPEND SUBSTRAIT_SOURCES "${SUBSTRAIT_PROTO_GEN}.cc") endforeach() + message("SOURCE DIR2 IS ${SOURCE_DIR} AND ${CMAKE_SOURCE_DIR} AND ${ARROW_SUBSTRAIT_PROTOS_DIR}" + ) + foreach(ARROW_SUBSTRAIT_PROTO ${ARROW_SUBSTRAIT_PROTOS}) + set(ARROW_SUBSTRAIT_PROTO_GEN + "${SUBSTRAIT_CPP_DIR}/substrait/${ARROW_SUBSTRAIT_PROTO}.pb") + foreach(EXT h cc) + set_source_files_properties("${ARROW_SUBSTRAIT_PROTO_GEN}.${EXT}" + PROPERTIES COMPILE_OPTIONS + "${SUBSTRAIT_SUPPRESSED_FLAGS}" + GENERATED TRUE + SKIP_UNITY_BUILD_INCLUSION TRUE) + list(APPEND SUBSTRAIT_PROTO_GEN_ALL "${ARROW_SUBSTRAIT_PROTO_GEN}.${EXT}") + endforeach() + add_custom_command(OUTPUT "${ARROW_SUBSTRAIT_PROTO_GEN}.cc" + "${ARROW_SUBSTRAIT_PROTO_GEN}.h" + COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${SUBSTRAIT_LOCAL_DIR}/proto" + "-I${ARROW_SUBSTRAIT_PROTOS_DIR}" + "--cpp_out=${SUBSTRAIT_CPP_DIR}" + "${ARROW_SUBSTRAIT_PROTOS_DIR}/substrait/${ARROW_SUBSTRAIT_PROTO}.proto" + DEPENDS ${PROTO_DEPENDS} substrait_ep) + + list(APPEND SUBSTRAIT_SOURCES "${ARROW_SUBSTRAIT_PROTO_GEN}.cc") + endforeach() add_custom_target(substrait_gen ALL DEPENDS ${SUBSTRAIT_PROTO_GEN_ALL}) diff --git a/cpp/proto/substrait/extension_rels.proto b/cpp/proto/substrait/extension_rels.proto new file mode 100644 index 0000000000000..518412969f511 --- /dev/null +++ b/cpp/proto/substrait/extension_rels.proto @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +syntax = "proto3"; + +package arrow.substrait_ext; + +import "substrait/algebra.proto"; + +option csharp_namespace = "Arrow.Substrait"; +option go_package = "github.com/apache/arrow/substrait"; +option java_multiple_files = true; +option java_package = "io.arrow.substrait"; + +message AsOfJoinRel { + repeated AsOfJoinKeys input_keys = 1; + int64 tolerance = 2; + + message AsOfJoinKeys { + .substrait.Expression on = 1; + repeated .substrait.Expression by = 2; + } +} diff --git a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc index d510774aaf4bc..3c6b78d29f110 100644 --- a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc +++ b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc @@ -106,9 +106,19 @@ static void TableJoinOverhead(benchmark::State& state, benchmark::Counter(static_cast(default_memory_pool()->max_memory())); } +AsofJoinNodeOptions GetRepeatedOptions(size_t repeat, FieldRef on_key, + std::vector by_key, int64_t tolerance) { + std::vector input_keys(repeat); + for (size_t i = 0; i < repeat; i++) { + input_keys[i] = {on_key, by_key}; + } + return AsofJoinNodeOptions(input_keys, tolerance); +} + static void AsOfJoinOverhead(benchmark::State& state) { int64_t tolerance = 0; - AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, {kKeyCol}, tolerance); + AsofJoinNodeOptions options = + GetRepeatedOptions(int(state.range(4)), kTimeCol, {kKeyCol}, tolerance); TableJoinOverhead( state, TableGenerationProperties{int(state.range(0)), int(state.range(1)), diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index aef652e96627e..bf1bfdfc60cd3 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include "arrow/compute/exec/asof_join_node.h" + #include #include #include @@ -951,17 +953,16 @@ class AsofJoinNode : public ExecNode { } static arrow::Result> MakeOutputSchema( - const std::vector& inputs, + const std::vector> input_schema, const std::vector& indices_of_on_key, const std::vector>& indices_of_by_key) { std::vector> fields; - size_t n_by = indices_of_by_key[0].size(); + size_t n_by = indices_of_by_key.size() == 0 ? 0 : indices_of_by_key[0].size(); const DataType* on_key_type = NULLPTR; std::vector by_key_type(n_by, NULLPTR); // Take all non-key, non-time RHS fields - for (size_t j = 0; j < inputs.size(); ++j) { - const auto& input_schema = inputs[j]->output_schema(); + for (size_t j = 0; j < input_schema.size(); ++j) { const auto& on_field_ix = indices_of_on_key[j]; const auto& by_field_ix = indices_of_by_key[j]; @@ -969,10 +970,10 @@ class AsofJoinNode : public ExecNode { return Status::Invalid("Missing join key on table ", j); } - const auto& on_field = input_schema->fields()[on_field_ix]; + const auto& on_field = input_schema[j]->fields()[on_field_ix]; std::vector by_field(n_by); for (size_t k = 0; k < n_by; k++) { - by_field[k] = input_schema->fields()[by_field_ix[k]].get(); + by_field[k] = input_schema[j]->fields()[by_field_ix[k]].get(); } if (on_key_type == NULLPTR) { @@ -992,8 +993,8 @@ class AsofJoinNode : public ExecNode { } } - for (int i = 0; i < input_schema->num_fields(); ++i) { - const auto field = input_schema->field(i); + for (int i = 0; i < input_schema[j]->num_fields(); ++i) { + const auto field = input_schema[j]->field(i); if (i == on_field_ix) { ARROW_RETURN_NOT_OK(is_valid_on_field(field)); // Only add on field from the left table @@ -1030,6 +1031,56 @@ class AsofJoinNode : public ExecNode { return match.indices()[0]; } + static Result GetByKeySize( + const std::vector& input_keys) { + size_t n_by = 0; + for (size_t i = 0; i < input_keys.size(); ++i) { + const auto& by_key = input_keys[i].by_key; + if (i == 0) { + n_by = by_key.size(); + } else if (n_by != by_key.size()) { + return Status::Invalid("inconsistent size of by-key across inputs"); + } + } + return n_by; + } + + static Result> GetIndicesOfOnKey( + const std::vector>& input_schema, + const std::vector& input_keys) { + if (input_schema.size() != input_keys.size()) { + return Status::Invalid("mismatching number of input schema and keys"); + } + size_t n_input = input_schema.size(); + std::vector indices_of_on_key(n_input); + for (size_t i = 0; i < n_input; ++i) { + const auto& on_key = input_keys[i].on_key; + ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], + FindColIndex(*input_schema[i], on_key, "on")); + } + return indices_of_on_key; + } + + static Result>> GetIndicesOfByKey( + const std::vector>& input_schema, + const std::vector& input_keys) { + if (input_schema.size() != input_keys.size()) { + return Status::Invalid("mismatching number of input schema and keys"); + } + ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(input_keys)); + size_t n_input = input_schema.size(); + std::vector> indices_of_by_key( + n_input, std::vector(n_by)); + for (size_t i = 0; i < n_input; ++i) { + for (size_t k = 0; k < n_by; k++) { + const auto& by_key = input_keys[i].by_key; + ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], + FindColIndex(*input_schema[i], by_key[k], "by")); + } + } + return indices_of_by_key; + } + static arrow::Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; @@ -1040,24 +1091,21 @@ class AsofJoinNode : public ExecNode { join_options.tolerance); } - size_t n_input = inputs.size(), n_by = join_options.by_key.size(); + ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(join_options.input_keys)); + size_t n_input = inputs.size(); std::vector input_labels(n_input); - std::vector indices_of_on_key(n_input); - std::vector> indices_of_by_key( - n_input, std::vector(n_by)); + std::vector> input_schema(n_input); for (size_t i = 0; i < n_input; ++i) { input_labels[i] = i == 0 ? "left" : "right_" + std::to_string(i); - const Schema& input_schema = *inputs[i]->output_schema(); - ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], - FindColIndex(input_schema, join_options.on_key, "on")); - for (size_t k = 0; k < n_by; k++) { - ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], - FindColIndex(input_schema, join_options.by_key[k], "by")); - } + input_schema[i] = inputs[i]->output_schema(); } - - ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, - MakeOutputSchema(inputs, indices_of_on_key, indices_of_by_key)); + ARROW_ASSIGN_OR_RAISE(std::vector indices_of_on_key, + GetIndicesOfOnKey(input_schema, join_options.input_keys)); + ARROW_ASSIGN_OR_RAISE(std::vector> indices_of_by_key, + GetIndicesOfByKey(input_schema, join_options.input_keys)); + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr output_schema, + MakeOutputSchema(input_schema, indices_of_on_key, indices_of_by_key)); std::vector> key_hashers; for (size_t i = 0; i < n_input; i++) { @@ -1173,5 +1221,20 @@ void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { } } // namespace internal +namespace asofjoin { + +Result> MakeOutputSchema( + const std::vector>& input_schema, + const std::vector& input_keys) { + ARROW_ASSIGN_OR_RAISE(std::vector indices_of_on_key, + AsofJoinNode::GetIndicesOfOnKey(input_schema, input_keys)); + ARROW_ASSIGN_OR_RAISE(std::vector> indices_of_by_key, + AsofJoinNode::GetIndicesOfByKey(input_schema, input_keys)); + return AsofJoinNode::MakeOutputSchema(input_schema, indices_of_on_key, + indices_of_by_key); +} + +} // namespace asofjoin + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join_node.h b/cpp/src/arrow/compute/exec/asof_join_node.h new file mode 100644 index 0000000000000..27777090d3d47 --- /dev/null +++ b/cpp/src/arrow/compute/exec/asof_join_node.h @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/exec.h" +#include "arrow/compute/exec/options.h" +#include "arrow/type.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { +namespace asofjoin { + +using AsofJoinKeys = AsofJoinNodeOptions::Keys; + +ARROW_EXPORT Result> MakeOutputSchema( + const std::vector>& input_schema, + const std::vector& input_keys); + +} // namespace asofjoin +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index bde8d1fb0053c..7aec42f41c1ba 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -128,6 +128,15 @@ void BuildZeroBaseBinaryArray(std::shared_ptr& empty, int64_t length) { ASSERT_OK(builder.Finish(&empty)); } +AsofJoinNodeOptions GetRepeatedOptions(size_t repeat, FieldRef on_key, + std::vector by_key, int64_t tolerance) { + std::vector input_keys(repeat); + for (size_t i = 0; i < repeat; i++) { + input_keys[i] = {on_key, by_key}; + } + return AsofJoinNodeOptions(input_keys, tolerance); +} + // mutates by copying from_key into to_key and changing from_key to zero Result MutateByKey(BatchesWithSchema& batches, std::string from_key, std::string to_key, bool replace_key = false, @@ -246,7 +255,7 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, \ const FieldRef time, by_key_type key, const int64_t tolerance) { \ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, \ - AsofJoinNodeOptions(time, {key}, tolerance)); \ + GetRepeatedOptions(3, time, {key}, tolerance)); \ } EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT) @@ -298,7 +307,7 @@ void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema, int64_t tolerance, const std::string& expected_error_str) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("time", {"key"}, tolerance), + GetRepeatedOptions(2, "time", {"key"}, tolerance), expected_error_str); } @@ -321,27 +330,27 @@ void DoRunMissingKeysTest(const std::shared_ptr& l_schema, void DoRunMissingOnKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("invalid_time", {"key"}, 0), + GetRepeatedOptions(2, "invalid_time", {"key"}, 0), "Bad join key on table : No match"); } void DoRunMissingByKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("time", {"invalid_key"}, 0), + GetRepeatedOptions(2, "time", {"invalid_key"}, 0), "Bad join key on table : No match"); } void DoRunNestedOnKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { - DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"}, {"key"}, 0), + DoRunInvalidPlanTest(l_schema, r_schema, GetRepeatedOptions(2, {0, "time"}, {"key"}, 0), "Bad join key on table : No match"); } void DoRunNestedByKeyTest(const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunInvalidPlanTest(l_schema, r_schema, - AsofJoinNodeOptions("time", {FieldRef{0, 1}}, 0), + GetRepeatedOptions(2, "time", {FieldRef{0, 1}}, 0), "Bad join key on table : No match"); } @@ -402,7 +411,7 @@ void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered, const std::shared_ptr& l_schema, const std::shared_ptr& r_schema) { DoRunUnorderedPlanTest(l_unordered, r_unordered, l_schema, r_schema, - AsofJoinNodeOptions("time", {"key"}, 1000), + GetRepeatedOptions(2, "time", {"key"}, 1000), "out-of-order on-key values"); } @@ -499,7 +508,7 @@ struct BasicTest { ASSERT_OK_AND_ASSIGN(exp_nokey_batches, MutateByKey(exp_nokey_batches, "key", "key2", true, true)); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, - AsofJoinNodeOptions("time", {"key2"}, tolerance)); + GetRepeatedOptions(3, "time", {"key2"}, tolerance)); }); } static void DoMutateNullKey(BasicTest& basic_tests) { basic_tests.RunMutateNullKey(); } @@ -512,7 +521,7 @@ struct BasicTest { ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key", false, false, true)); CheckRunOutput(l_batches, r0_batches, r1_batches, exp_emptykey_batches, - AsofJoinNodeOptions("time", {}, tolerance)); + GetRepeatedOptions(3, "time", {}, tolerance)); }); } static void DoMutateEmptyKey(BasicTest& basic_tests) { diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 68178ea9f2113..edd4776e6345a 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -393,22 +393,35 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { /// This node will output one row for each row in the left table. class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { public: - AsofJoinNodeOptions(FieldRef on_key, std::vector by_key, int64_t tolerance) - : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {} - - /// \brief "on" key for the join. + /// \brief Keys for one input table of the AsofJoin operation /// - /// All inputs tables must be sorted by the "on" key. Must be a single field of a common - /// type. Inexact match is used on the "on" key. i.e., a row is considered match iff - /// left_on - tolerance <= right_on <= left_on. - /// Currently, the "on" key must be of an integer, date, or timestamp type. - FieldRef on_key; - /// \brief "by" key for the join. + /// The keys must be consistent across the input tables: + /// Each "on" key must refer to a field of the same type and units across the tables. + /// Each "by" key must refer to a list of fields of the same types across the tables. + struct Keys { + /// \brief "on" key for the join. + /// + /// The input table must be sorted by the "on" key. Must be a single field of a common + /// type. Inexact match is used on the "on" key. i.e., a row is considered match iff + /// left_on - tolerance <= right_on <= left_on. + /// Currently, the "on" key must be of an integer, date, or timestamp type. + FieldRef on_key; + /// \brief "by" key for the join. + /// + /// The input table must have each field of the "by" key. Exact equality is used for + /// each field of the "by" key. + /// Currently, each field of the "by" key must be of an integer, date, timestamp, or + /// base-binary type. + std::vector by_key; + }; + + AsofJoinNodeOptions(std::vector input_keys, int64_t tolerance) + : input_keys(std::move(input_keys)), tolerance(tolerance) {} + + /// \brief AsofJoin keys per input table. /// - /// All input tables must have the "by" key. Exact equality - /// is used for the "by" key. - /// Currently, the "by" key must be of an integer, date, timestamp, or base-binary type - std::vector by_key; + /// See `Keys` for details. + std::vector input_keys; /// \brief Tolerance for inexact "on" key matching. Must be non-negative. /// /// The tolerance is interpreted in the same units as the "on" key. diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index a8d5be90af872..91cbdc1dcd0f8 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -23,6 +23,7 @@ set(ARROW_SUBSTRAIT_SRCS substrait/expression_internal.cc substrait/extension_set.cc substrait/extension_types.cc + substrait/options.cc substrait/plan_internal.cc substrait/relation_internal.cc substrait/serde.cc diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc new file mode 100644 index 0000000000000..6614814771b4d --- /dev/null +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include + +#include "arrow/engine/substrait/options.h" + +#include +#include "arrow/compute/exec/asof_join_node.h" +#include "arrow/compute/exec/options.h" +#include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/relation_internal.h" +#include "substrait/extension_rels.pb.h" + +namespace arrow { +namespace engine { + +class DefaultExtensionProvider : public ExtensionProvider { + public: + Result MakeRel(const std::vector& inputs, + const google::protobuf::Any& rel, + const ExtensionSet& ext_set) override { + if (rel.Is()) { + substrait_ext::AsOfJoinRel as_of_join_rel; + rel.UnpackTo(&as_of_join_rel); + return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set); + } + return Status::NotImplemented("Unrecognized extension in Susbstrait plan: ", + rel.DebugString()); + } + + private: + Result MakeAsOfJoinRel( + const std::vector& inputs, + const substrait_ext::AsOfJoinRel& as_of_join_rel, const ExtensionSet& ext_set) { + if (inputs.size() < 2) { + return Status::Invalid("substrait::AsOfJoinNode too few input tables: ", + inputs.size()); + } + if (static_cast(as_of_join_rel.input_keys_size()) != inputs.size()) { + return Status::Invalid("substrait::AsOfJoinNode mismatched number of inputs"); + } + + size_t n_input = inputs.size(), i = 0; + std::vector input_keys(n_input); + for (const auto& keys : as_of_join_rel.input_keys()) { + // on-key + if (!keys.has_on()) { + return Status::Invalid("substrait::AsOfJoinNode missing on-key for input ", i); + } + ARROW_ASSIGN_OR_RAISE(auto on_key_expr, FromProto(keys.on(), ext_set, {})); + if (on_key_expr.field_ref() == NULLPTR) { + return Status::NotImplemented( + "substrait::AsOfJoinNode non-field-ref on-key for input ", i); + } + const FieldRef& on_key = *on_key_expr.field_ref(); + + // by-key + std::vector by_key; + for (const auto& by_item : keys.by()) { + ARROW_ASSIGN_OR_RAISE(auto by_key_expr, FromProto(by_item, ext_set, {})); + if (by_key_expr.field_ref() == NULLPTR) { + return Status::NotImplemented( + "substrait::AsOfJoinNode non-field-ref by-key for input ", i); + } + by_key.push_back(*by_key_expr.field_ref()); + } + + input_keys[i] = {std::move(on_key), std::move(by_key)}; + ++i; + } + + // schema + int64_t tolerance = as_of_join_rel.tolerance(); + std::vector> input_schema(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + input_schema[i] = inputs[i].output_schema; + } + ARROW_ASSIGN_OR_RAISE(auto schema, + compute::asofjoin::MakeOutputSchema(input_schema, input_keys)); + compute::AsofJoinNodeOptions asofjoin_node_opts{std::move(input_keys), tolerance}; + + // declaration + std::vector input_decls(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + input_decls[i] = inputs[i].declaration; + } + return DeclarationInfo{ + compute::Declaration("asofjoin", input_decls, std::move(asofjoin_node_opts)), + std::move(schema)}; + } +}; + +std::shared_ptr ExtensionProvider::kDefaultExtensionProvider = + std::make_shared(); + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/options.h b/cpp/src/arrow/engine/substrait/options.h index 014842f4d8f1b..2f46abbbfb883 100644 --- a/cpp/src/arrow/engine/substrait/options.h +++ b/cpp/src/arrow/engine/substrait/options.h @@ -23,7 +23,11 @@ #include #include +#include + #include "arrow/compute/type_fwd.h" +#include "arrow/engine/substrait/type_fwd.h" +#include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" namespace arrow { @@ -65,9 +69,18 @@ using NamedTableProvider = std::function(const std::vector&)>; static NamedTableProvider kDefaultNamedTableProvider; +class ARROW_ENGINE_EXPORT ExtensionProvider { + public: + static std::shared_ptr kDefaultExtensionProvider; + virtual ~ExtensionProvider() = default; + virtual Result MakeRel(const std::vector& inputs, + const google::protobuf::Any& rel, + const ExtensionSet& ext_set) = 0; +}; + /// Options that control the conversion between Substrait and Acero representations of a /// plan. -struct ConversionOptions { +struct ARROW_ENGINE_EXPORT ConversionOptions { /// \brief How strictly the converter should adhere to the structure of the input. ConversionStrictness strictness = ConversionStrictness::BEST_EFFORT; /// \brief A custom strategy to be used for providing named tables @@ -75,6 +88,8 @@ struct ConversionOptions { /// The default behavior will return an invalid status if the plan has any /// named table relations. NamedTableProvider named_table_provider = kDefaultNamedTableProvider; + std::shared_ptr extension_provider = + ExtensionProvider::kDefaultExtensionProvider; }; } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index b5de5c1a4bacd..d06cc7ef6f23f 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -609,6 +609,36 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::move(aggregate_schema)); } + case substrait::Rel::RelTypeCase::kExtensionLeaf: { + const auto& ext = rel.extension_leaf(); + ARROW_ASSIGN_OR_RAISE( + auto ext_leaf_decl, + conversion_options.extension_provider->MakeRel({}, ext.detail(), ext_set)); + return ProcessEmit(ext, std::move(ext_leaf_decl), ext_leaf_decl.output_schema); + } + case substrait::Rel::RelTypeCase::kExtensionSingle: { + const auto& ext = rel.extension_single(); + ARROW_ASSIGN_OR_RAISE(DeclarationInfo input, + FromProto(ext.input(), ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE( + auto ext_single_decl, + conversion_options.extension_provider->MakeRel({input}, ext.detail(), ext_set)); + return ProcessEmit(ext, std::move(ext_single_decl), ext_single_decl.output_schema); + } + case substrait::Rel::RelTypeCase::kExtensionMulti: { + const auto& ext = rel.extension_multi(); + std::vector inputs; + for (const auto& input : ext.inputs()) { + ARROW_ASSIGN_OR_RAISE(auto input_info, + FromProto(input, ext_set, conversion_options)); + inputs.push_back(std::move(input_info)); + } + ARROW_ASSIGN_OR_RAISE( + auto ext_multi_decl, + conversion_options.extension_provider->MakeRel(inputs, ext.detail(), ext_set)); + return ProcessEmit(ext, std::move(ext_multi_decl), ext_multi_decl.output_schema); + } + default: break; } diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 514f3f97fc053..c724149078342 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -32,7 +32,7 @@ namespace arrow { namespace engine { /// Information resulting from converting a Substrait relation. -struct DeclarationInfo { +struct ARROW_ENGINE_EXPORT DeclarationInfo { /// The compute declaration produced thus far. compute::Declaration declaration; diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index a74b491f872f5..09d260aadf60b 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -20,6 +20,7 @@ #include #include +#include "arrow/compute/exec/asof_join_node.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression_internal.h" #include "arrow/dataset/file_base.h" @@ -2254,6 +2255,16 @@ TEST(SubstraitRoundTrip, BasicPlanEndToEnd) { EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); } +NamedTableProvider ProvideMadeTable( + std::function>(const std::vector&)> make) { + return [make](const std::vector& names) -> Result { + ARROW_ASSIGN_OR_RAISE(auto table, make(names)); + std::shared_ptr options = + std::make_shared(table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; +} + TEST(SubstraitRoundTrip, ProjectRel) { #ifdef _WIN32 GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; @@ -3648,5 +3659,194 @@ TEST(Substrait, NestedEmitProjectWithMultiFieldExpressions) { buf, {}, conversion_options); } +TEST(Substrait, PlanWithExtension) { + // This demos an extension relation + std::string substrait_json = R"({ + "extensionUris": [], + "extensions": [], + "relations": [{ + "root": { + "input": { + "extension_multi": { + "common": { + "emit": { + "outputMapping": [0, 1, 2, 3] + } + }, + "inputs": [ + { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["time", "key", "value1"], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["T1"] + } + } + }, + { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["time", "key", "value2"], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["T2"] + } + } + } + ], + "detail": { + "@type": "/arrow.substrait_ext.AsOfJoinRel", + "input_keys" : [ + { + "on": { + "selection": { + "directReference": { + "structField": { + "field": 0, + } + }, + "rootReference": {} + } + }, + "by": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1, + } + }, + "rootReference": {} + } + } + ] + }, + { + "on": { + "selection": { + "directReference": { + "structField": { + "field": 0, + } + }, + "rootReference": {} + } + }, + "by": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1, + } + }, + "rootReference": {} + } + } + ] + } + ], + "tolerance": 1000 + } + } + }, + "names": ["time", "key", "value1", "value2"] + } + }], + "expectedTypeUrls": [] + })"; + + std::vector> input_schema = { + schema({field("time", int32()), field("key", int32()), field("value1", float64())}), + schema( + {field("time", int32()), field("key", int32()), field("value2", float64())})}; + NamedTableProvider table_provider = ProvideMadeTable( + [&input_schema]( + const std::vector& names) -> Result> { + if (names.size() != 1) { + return Status::Invalid("Multiple test table names"); + } + if (names[0] == "T1") { + return TableFromJSON(input_schema[0], + {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2, 3.1]]"}); + } + if (names[0] == "T2") { + return TableFromJSON(input_schema[1], + {"[[1, 1, 1.2], [3, 2, 2.2], [5, 2, 3.2]]"}); + } + return Status::Invalid("Unknown test table name ", names[0]); + }); + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + + ASSERT_OK_AND_ASSIGN( + auto out_schema, + compute::asofjoin::MakeOutputSchema( + input_schema, {{FieldRef(0), {FieldRef(1)}}, {FieldRef(0), {FieldRef(1)}}})); + auto expected_table = TableFromJSON( + out_schema, {"[[2, 1, 1.1, 1.2], [4, 1, 2.1, 1.2], [6, 2, 3.1, 3.2]]"}); + CheckRoundTripResult(std::move(out_schema), std::move(expected_table), + *compute::default_exec_context(), buf, {}, conversion_options); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/type_fwd.h b/cpp/src/arrow/engine/substrait/type_fwd.h index 235d9e82d1b87..6089d3f747a82 100644 --- a/cpp/src/arrow/engine/substrait/type_fwd.h +++ b/cpp/src/arrow/engine/substrait/type_fwd.h @@ -26,6 +26,7 @@ class ExtensionIdRegistry; class ExtensionSet; struct ConversionOptions; +struct DeclarationInfo; } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index 16032df67db63..50ed52fa937b7 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -36,7 +36,7 @@ namespace { template bool IsNullable(const TypeMessage& type) { // FIXME what can we do with NULLABILITY_UNSPECIFIED - return type.nullability() != ::substrait::Type::NULLABILITY_REQUIRED; + return type.nullability() != substrait::Type::NULLABILITY_REQUIRED; } template @@ -87,67 +87,67 @@ Result FieldsFromProto(int size, const Types& types, } // namespace Result, bool>> FromProto( - const ::substrait::Type& type, const ExtensionSet& ext_set, + const substrait::Type& type, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { switch (type.kind_case()) { - case ::substrait::Type::kBool: + case substrait::Type::kBool: return FromProtoImpl(type.bool_()); - case ::substrait::Type::kI8: + case substrait::Type::kI8: return FromProtoImpl(type.i8()); - case ::substrait::Type::kI16: + case substrait::Type::kI16: return FromProtoImpl(type.i16()); - case ::substrait::Type::kI32: + case substrait::Type::kI32: return FromProtoImpl(type.i32()); - case ::substrait::Type::kI64: + case substrait::Type::kI64: return FromProtoImpl(type.i64()); - case ::substrait::Type::kFp32: + case substrait::Type::kFp32: return FromProtoImpl(type.fp32()); - case ::substrait::Type::kFp64: + case substrait::Type::kFp64: return FromProtoImpl(type.fp64()); - case ::substrait::Type::kString: + case substrait::Type::kString: return FromProtoImpl(type.string()); - case ::substrait::Type::kBinary: + case substrait::Type::kBinary: return FromProtoImpl(type.binary()); - case ::substrait::Type::kTimestamp: + case substrait::Type::kTimestamp: return FromProtoImpl(type.timestamp(), TimeUnit::MICRO); - case ::substrait::Type::kTimestampTz: + case substrait::Type::kTimestampTz: return FromProtoImpl(type.timestamp_tz(), TimeUnit::MICRO, TimestampTzTimezoneString()); - case ::substrait::Type::kDate: + case substrait::Type::kDate: return FromProtoImpl(type.date()); - case ::substrait::Type::kTime: + case substrait::Type::kTime: return FromProtoImpl(type.time(), TimeUnit::MICRO); - case ::substrait::Type::kIntervalYear: + case substrait::Type::kIntervalYear: return FromProtoImpl(type.interval_year(), interval_year); - case ::substrait::Type::kIntervalDay: + case substrait::Type::kIntervalDay: return FromProtoImpl(type.interval_day(), interval_day); - case ::substrait::Type::kUuid: + case substrait::Type::kUuid: return FromProtoImpl(type.uuid(), uuid); - case ::substrait::Type::kFixedChar: + case substrait::Type::kFixedChar: return FromProtoImpl(type.fixed_char(), fixed_char, type.fixed_char().length()); - case ::substrait::Type::kVarchar: + case substrait::Type::kVarchar: return FromProtoImpl(type.varchar(), varchar, type.varchar().length()); - case ::substrait::Type::kFixedBinary: + case substrait::Type::kFixedBinary: return FromProtoImpl(type.fixed_binary(), type.fixed_binary().length()); - case ::substrait::Type::kDecimal: { + case substrait::Type::kDecimal: { const auto& decimal = type.decimal(); return FromProtoImpl(decimal, decimal.precision(), decimal.scale()); } - case ::substrait::Type::kStruct: { + case substrait::Type::kStruct: { const auto& struct_ = type.struct_(); ARROW_ASSIGN_OR_RAISE( @@ -158,7 +158,7 @@ Result, bool>> FromProto( return FromProtoImpl(struct_, std::move(fields)); } - case ::substrait::Type::kList: { + case substrait::Type::kList: { const auto& list = type.list(); if (!list.has_type()) { @@ -173,7 +173,7 @@ Result, bool>> FromProto( list, field("item", std::move(type_nullable.first), type_nullable.second)); } - case ::substrait::Type::kMap: { + case substrait::Type::kMap: { const auto& map = type.map(); static const std::array kMissing = {"key and value", "value", "key", @@ -199,7 +199,7 @@ Result, bool>> FromProto( field("value", std::move(value_nullable.first), value_nullable.second)); } - case ::substrait::Type::kUserDefined: { + case substrait::Type::kUserDefined: { const auto& user_defined = type.user_defined(); uint32_t anchor = user_defined.type_reference(); ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor)); @@ -313,8 +313,7 @@ struct DataTypeToProtoImpl { for (const auto& field : t.fields()) { if (field->metadata() != nullptr) { - return Status::Invalid( - "::substrait::Type::Struct does not support field metadata"); + return Status::Invalid("substrait::Type::Struct does not support field metadata"); } ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*field->type(), field->nullable(), ext_set_, conversion_options_)); @@ -378,8 +377,8 @@ struct DataTypeToProtoImpl { template Sub* SetWithThen(void (::substrait::Type::*set_allocated_sub)(Sub*)) { auto sub = std::make_unique(); - sub->set_nullability(nullable_ ? ::substrait::Type::NULLABILITY_NULLABLE - : ::substrait::Type::NULLABILITY_REQUIRED); + sub->set_nullability(nullable_ ? substrait::Type::NULLABILITY_NULLABLE + : substrait::Type::NULLABILITY_REQUIRED); auto out = sub.get(); (type_->*set_allocated_sub)(sub.release()); @@ -394,37 +393,37 @@ struct DataTypeToProtoImpl { template Status EncodeUserDefined(const T& t) { ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(t)); - auto user_defined = std::make_unique<::substrait::Type::UserDefined>(); + auto user_defined = std::make_unique(); user_defined->set_type_reference(anchor); - user_defined->set_nullability(nullable_ ? ::substrait::Type::NULLABILITY_NULLABLE - : ::substrait::Type::NULLABILITY_REQUIRED); + user_defined->set_nullability(nullable_ ? substrait::Type::NULLABILITY_NULLABLE + : substrait::Type::NULLABILITY_REQUIRED); type_->set_allocated_user_defined(user_defined.release()); return Status::OK(); } Status NotImplemented(const DataType& t) { - return Status::NotImplemented("conversion to ::substrait::Type from ", t.ToString()); + return Status::NotImplemented("conversion to substrait::Type from ", t.ToString()); } Status operator()(const DataType& type) { return VisitTypeInline(type, this); } - ::substrait::Type* type_; + substrait::Type* type_; bool nullable_; ExtensionSet* ext_set_; const ConversionOptions& conversion_options_; }; } // namespace -Result> ToProto( +Result> ToProto( const DataType& type, bool nullable, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { - auto out = std::make_unique<::substrait::Type>(); + auto out = std::make_unique(); RETURN_NOT_OK( (DataTypeToProtoImpl{out.get(), nullable, ext_set, conversion_options})(type)); return std::move(out); } -Result> FromProto(const ::substrait::NamedStruct& named_struct, +Result> FromProto(const substrait::NamedStruct& named_struct, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { if (!named_struct.has_struct_()) { @@ -468,26 +467,26 @@ void ToProtoGetDepthFirstNames(const FieldVector& fields, } } // namespace -Result> ToProto( +Result> ToProto( const Schema& schema, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { if (schema.metadata()) { - return Status::Invalid("::substrait::NamedStruct does not support schema metadata"); + return Status::Invalid("substrait::NamedStruct does not support schema metadata"); } - auto named_struct = std::make_unique<::substrait::NamedStruct>(); + auto named_struct = std::make_unique(); auto names = named_struct->mutable_names(); names->Reserve(schema.num_fields()); ToProtoGetDepthFirstNames(schema.fields(), names); - auto struct_ = std::make_unique<::substrait::Type::Struct>(); + auto struct_ = std::make_unique(); auto types = struct_->mutable_types(); types->Reserve(schema.num_fields()); for (const auto& field : schema.fields()) { if (field->metadata() != nullptr) { - return Status::Invalid("::substrait::NamedStruct does not support field metadata"); + return Status::Invalid("substrait::NamedStruct does not support field metadata"); } ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*field->type(), field->nullable(), ext_set,