From 713b61d55d432b507ae8ecc10abc9d7dae07eb77 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Mon, 20 Jan 2025 23:06:09 +0800 Subject: [PATCH 1/5] add get_struct_field --- .../registration/RegisterSpecialForm.cpp | 3 + .../sparksql/specialforms/CMakeLists.txt | 1 + .../sparksql/specialforms/GetStructField.cpp | 105 ++++++++++++++++++ .../sparksql/specialforms/GetStructField.h | 67 +++++++++++ velox/functions/sparksql/tests/CMakeLists.txt | 1 + .../sparksql/tests/GetStructFieldTest.cpp | 88 +++++++++++++++ 6 files changed, 265 insertions(+) create mode 100644 velox/functions/sparksql/specialforms/GetStructField.cpp create mode 100644 velox/functions/sparksql/specialforms/GetStructField.h create mode 100644 velox/functions/sparksql/tests/GetStructFieldTest.cpp diff --git a/velox/functions/sparksql/registration/RegisterSpecialForm.cpp b/velox/functions/sparksql/registration/RegisterSpecialForm.cpp index d9f12abe4f80..356e40fdf34c 100644 --- a/velox/functions/sparksql/registration/RegisterSpecialForm.cpp +++ b/velox/functions/sparksql/registration/RegisterSpecialForm.cpp @@ -18,6 +18,7 @@ #include "velox/expression/SpecialFormRegistry.h" #include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h" #include "velox/functions/sparksql/specialforms/DecimalRound.h" +#include "velox/functions/sparksql/specialforms/GetStructField.h" #include "velox/functions/sparksql/specialforms/MakeDecimal.h" #include "velox/functions/sparksql/specialforms/SparkCastExpr.h" @@ -44,6 +45,8 @@ void registerSpecialFormGeneralFunctions(const std::string& prefix) { "cast", std::make_unique()); registerFunctionCallToSpecialForm( "try_cast", std::make_unique()); + registerFunctionCallToSpecialForm( + "get_struct_field", std::make_unique()); } } // namespace sparksql } // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/specialforms/CMakeLists.txt b/velox/functions/sparksql/specialforms/CMakeLists.txt index e141e0074bc8..0b573f4d6312 100644 --- 
a/velox/functions/sparksql/specialforms/CMakeLists.txt +++ b/velox/functions/sparksql/specialforms/CMakeLists.txt @@ -16,6 +16,7 @@ velox_add_library( velox_functions_spark_specialforms AtLeastNNonNulls.cpp DecimalRound.cpp + GetStructField.cpp MakeDecimal.cpp SparkCastExpr.cpp SparkCastHooks.cpp) diff --git a/velox/functions/sparksql/specialforms/GetStructField.cpp b/velox/functions/sparksql/specialforms/GetStructField.cpp new file mode 100644 index 000000000000..909d8c4865f1 --- /dev/null +++ b/velox/functions/sparksql/specialforms/GetStructField.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "velox/functions/sparksql/specialforms/GetStructField.h"
+#include
+#include "expression/Expr.h"
+#include "vector/ComplexVector.h"
+#include "vector/ConstantVector.h"
+#include "velox/expression/PeeledEncoding.h"
+
+using namespace facebook::velox::exec;
+
+namespace facebook::velox::functions::sparksql {
+
+// Evaluates get_struct_field(struct, ordinal): the result is the child of the
+// ROW-typed first input selected by the constant INTEGER second input. Rows
+// whose struct value is null, and rows that already carry an error, are set
+// to null in 'result'.
+void GetStructFieldExpr::evalSpecialForm(
+    const SelectivityVector& rows,
+    EvalCtx& context,
+    VectorPtr& result) {
+  VectorPtr input;
+  VectorPtr ordinalVector;
+  inputs_[0]->eval(rows, context, input);
+  inputs_[1]->eval(rows, context, ordinalVector);
+
+  VELOX_USER_CHECK(
+      ordinalVector->isConstantEncoding(),
+      "The ordinal argument of get_struct_field must be constant.");
+  VELOX_USER_CHECK(
+      !ordinalVector->isNullAt(0),
+      "The ordinal argument of get_struct_field must not be null.");
+  const auto ordinal =
+      ordinalVector->asUnchecked<ConstantVector<int32_t>>()->valueAt(0);
+
+  // Validate the ordinal against the declared struct type before it is used
+  // as a child index: a negative value would otherwise be converted to a very
+  // large unsigned index.
+  const auto numFields = inputs_[0]->type()->size();
+  VELOX_USER_CHECK_GE(
+      ordinal, 0, "The ordinal of get_struct_field must not be negative.");
+  VELOX_USER_CHECK_LT(
+      static_cast<uint32_t>(ordinal),
+      numFields,
+      "The ordinal of get_struct_field is out of range.");
+  auto resultType = std::const_pointer_cast<const Type>(type_);
+
+  // Only rows without a pending error and with a non-null struct value need
+  // the field extracted; every other row becomes null at the end.
+  LocalSelectivityVector remainingRows(context, rows);
+  context.deselectErrors(*remainingRows);
+
+  LocalDecodedVector decoded(context, *input, *remainingRows);
+
+  auto* rawNulls = decoded->nulls();
+  if (rawNulls) {
+    remainingRows->deselectNulls(
+        rawNulls, remainingRows->begin(), remainingRows->end());
+  }
+
+  VectorPtr localResult;
+  if (!remainingRows->hasSelections()) {
+    // Nothing left to extract: the whole result is null.
+    localResult =
+        BaseVector::createNullConstant(resultType, rows.end(), context.pool());
+  } else {
+    auto rowData = decoded->base()->as<RowVector>();
+    VELOX_CHECK_NOT_NULL(
+        rowData, "The first argument of get_struct_field is not a RowVector.");
+    if (decoded->isIdentityMapping()) {
+      // No wrapping on the input, so the child can be returned as is.
+      localResult = rowData->childAt(ordinal);
+    } else {
+      // Re-apply the input's wrapping to the selected child.
+      localResult =
+          decoded->wrap(rowData->childAt(ordinal), *input, decoded->size());
+    }
+  }
+
+  context.moveOrCopyResult(localResult, *remainingRows, result);
+  context.releaseVector(localResult);
+
+  VELOX_CHECK_NOT_NULL(result);
+  if (rawNulls || context.errors()) {
+    // Rows dropped from 'remainingRows' above are set to null here.
+    EvalCtx::addNulls(
+        rows, remainingRows->asRange().bits(), context, resultType, result);
+  }
+
+  context.releaseVector(input);
+  context.releaseVector(ordinalVector);
+}
+
+// The result type depends on the value of the ordinal, not only on the
+// argument types, so it cannot be resolved here and must be supplied by the
+// caller of constructSpecialForm.
+TypePtr GetStructFieldCallToSpecialForm::resolveType(
+    const std::vector<TypePtr>& /*argTypes*/) {
+  VELOX_FAIL("get_struct_field function does not support type resolution.");
+}
+
+// Checks the argument count and the argument type kinds (ROW, INTEGER) and
+// builds the GetStructFieldExpr with the caller-provided result type.
+ExprPtr GetStructFieldCallToSpecialForm::constructSpecialForm(
+    const TypePtr& type,
+    std::vector<ExprPtr>&& args,
+    bool trackCpuUsage,
+    const core::QueryConfig& /*config*/) {
+  VELOX_USER_CHECK_EQ(
+      args.size(), 2, "get_struct_field expects two arguments.");
+
+  VELOX_USER_CHECK_EQ(
+      args[0]->type()->kind(),
+      TypeKind::ROW,
+      "The first argument of get_struct_field should be of row type.");
+
+  VELOX_USER_CHECK_EQ(
+      args[1]->type()->kind(),
+      TypeKind::INTEGER,
+      "The second argument of get_struct_field should be of integer type.");
+
+  return std::make_shared<GetStructFieldExpr>(
+      type, std::move(args), trackCpuUsage);
+}
+
+} // namespace facebook::velox::functions::sparksql
diff --git a/velox/functions/sparksql/specialforms/GetStructField.h b/velox/functions/sparksql/specialforms/GetStructField.h
new file mode 100644
index 000000000000..e220c0f2e49b
--- /dev/null
+++ b/velox/functions/sparksql/specialforms/GetStructField.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <utility>
+#include <vector>
+
+#include "velox/expression/FunctionCallToSpecialForm.h"
+#include "velox/expression/SpecialForm.h"
+
+namespace facebook::velox::functions::sparksql {
+
+/// Special form behind get_struct_field(struct, ordinal). It takes two
+/// inputs: a ROW-typed expression and an INTEGER ordinal that selects the
+/// child to return.
+class GetStructFieldExpr : public exec::SpecialForm {
+ public:
+  /// @param type The result type, i.e. the type of the selected field.
+  /// @param expr The compiled inputs: the struct expression followed by the
+  /// ordinal expression.
+  /// @param trackCpuUsage Whether to track CPU usage.
+  GetStructFieldExpr(
+      TypePtr type,
+      std::vector<exec::ExprPtr>&& expr,
+      bool trackCpuUsage)
+      : exec::SpecialForm(
+            std::move(type),
+            // 'expr' is a named rvalue reference and therefore an lvalue;
+            // without std::move the vector of inputs would be copied.
+            std::move(expr),
+            "get_struct_field",
+            false,
+            trackCpuUsage) {}
+
+  void evalSpecialForm(
+      const SelectivityVector& rows,
+      exec::EvalCtx& context,
+      VectorPtr& result) override;
+
+  /// Null propagation is taken from the struct input only.
+  void computePropagatesNulls() override {
+    propagatesNulls_ = inputs_[0]->propagatesNulls();
+  }
+};
+
+class GetStructFieldCallToSpecialForm : public exec::FunctionCallToSpecialForm {
+ public:
+  /// Always throws: type resolution is not supported for get_struct_field,
+  /// the result type has to be passed to constructSpecialForm.
+  TypePtr resolveType(const std::vector<TypePtr>& argTypes) override;
+
+  /// @brief Returns a GetStructFieldExpr for the get_struct_field special
+  /// form.
+  /// @param type Result type.
+  /// @param compiledChildren Two inputs. The first input must be of ROW
+  /// type. The second input is the ordinal of the field to extract and must
+  /// be a constant INTEGER.
+  /// @param trackCpuUsage Whether to track CPU usage.
+  /// @param config Query configuration; not used by this special form.
+  exec::ExprPtr constructSpecialForm(
+      const TypePtr& type,
+      std::vector<exec::ExprPtr>&& compiledChildren,
+      bool trackCpuUsage,
+      const core::QueryConfig& config) override;
+
+  static constexpr const char* kGetStructField = "get_struct_field";
+};
+
+} // namespace facebook::velox::functions::sparksql
diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt
index 39087bd8adb5..1ccf8e3151d3 100644
--- a/velox/functions/sparksql/tests/CMakeLists.txt
+++ b/velox/functions/sparksql/tests/CMakeLists.txt
@@ -34,6 +34,7 @@ add_executable(
   DecimalUtilTest.cpp
   ElementAtTest.cpp
   GetJsonObjectTest.cpp
+  GetStructFieldTest.cpp
   HashTest.cpp
   InTest.cpp
   JsonObjectKeysTest.cpp
diff --git a/velox/functions/sparksql/tests/GetStructFieldTest.cpp b/velox/functions/sparksql/tests/GetStructFieldTest.cpp
new file mode 100644
index 000000000000..469cfac9614f
--- /dev/null
+++ b/velox/functions/sparksql/tests/GetStructFieldTest.cpp
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include + +#include "functions/sparksql/specialforms/GetStructField.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::velox::test; + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class GetStructFieldTest : public SparkFunctionBaseTest { + protected: + void GetStructFieldSimple( + const VectorPtr& parameter, + const VectorPtr& index, + const TypePtr& inputType, + const TypePtr& resultType, + const VectorPtr& expected = nullptr) { + core::TypedExprPtr data = + std::make_shared(inputType, "c0"); + core::TypedExprPtr ordinal = + std::make_shared(INTEGER(), "c1"); + auto getStructFieldExpr = std::make_shared( + resultType, + std::vector{data, ordinal}, + "get_struct_field"); + auto result = evaluate( + getStructFieldExpr, makeRowVector({"c0", "c1"}, {parameter, index})); + if (expected) { + ::facebook::velox::test::assertEqualVectors(expected, result); + } + } +}; + +TEST_F(GetStructFieldTest, simpleInteger) { + auto dataType = ROW({"k1", "k2", "a"}, {BIGINT(), BIGINT(), BIGINT()}); + auto data = makeRowVector( + {"k1", "k2", "a"}, + {makeNullableFlatVector({12}), + makeNullableFlatVector({2}), + makeNullableFlatVector({1})}); + auto result = makeNullableFlatVector({2}); + auto index = makeConstant(1, 1); + GetStructFieldSimple(data, index, dataType, INTEGER(), result); +} + +TEST_F(GetStructFieldTest, simpleVarchar) { + auto dataType = ROW({"k1", "k2", "a"}, {BIGINT(), VARCHAR(), BIGINT()}); + auto data = makeRowVector( + {"k1", "k2", "a"}, + {makeNullableFlatVector({12}), + makeNullableFlatVector({"Milly"}), + makeNullableFlatVector({1})}); + auto result = makeNullableFlatVector({"Milly"}); + auto index = makeConstant(1, 1); + GetStructFieldSimple(data, index, dataType, VARCHAR(), result); +} + +TEST_F(GetStructFieldTest, simpleNull) { + auto dataType = ROW({"k1", "k2", "a"}, {BIGINT(), BIGINT(), BIGINT()}); + auto data = 
makeRowVector( + {"k1", "k2", "a"}, + {makeNullableFlatVector({12}), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({1})}); + auto result = makeNullableFlatVector({std::nullopt}); + auto index = makeConstant(1, 1); + GetStructFieldSimple(data, index, dataType, INTEGER(), result); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test From 6de6bd5f7f82122c9944f1fc94bf1a11f9e6fc8c Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Fri, 7 Feb 2025 10:14:29 +0800 Subject: [PATCH 2/5] address comments --- .../sparksql/specialforms/GetStructField.cpp | 1 - .../sparksql/tests/GetStructFieldTest.cpp | 101 +++++++++--------- 2 files changed, 53 insertions(+), 49 deletions(-) diff --git a/velox/functions/sparksql/specialforms/GetStructField.cpp b/velox/functions/sparksql/specialforms/GetStructField.cpp index 909d8c4865f1..c9682ffc667d 100644 --- a/velox/functions/sparksql/specialforms/GetStructField.cpp +++ b/velox/functions/sparksql/specialforms/GetStructField.cpp @@ -15,7 +15,6 @@ */ #include "velox/functions/sparksql/specialforms/GetStructField.h" -#include #include "expression/Expr.h" #include "vector/ComplexVector.h" #include "vector/ConstantVector.h" diff --git a/velox/functions/sparksql/tests/GetStructFieldTest.cpp b/velox/functions/sparksql/tests/GetStructFieldTest.cpp index 469cfac9614f..ae2ba92a5420 100644 --- a/velox/functions/sparksql/tests/GetStructFieldTest.cpp +++ b/velox/functions/sparksql/tests/GetStructFieldTest.cpp @@ -15,7 +15,6 @@ */ #include -#include "functions/sparksql/specialforms/GetStructField.h" #include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" #include "velox/vector/tests/utils/VectorTestBase.h" @@ -26,62 +25,68 @@ namespace { class GetStructFieldTest : public SparkFunctionBaseTest { protected: - void GetStructFieldSimple( - const VectorPtr& parameter, - const VectorPtr& index, - const TypePtr& inputType, - const TypePtr& resultType, - const VectorPtr& expected = nullptr) { - 
core::TypedExprPtr data = - std::make_shared(inputType, "c0"); - core::TypedExprPtr ordinal = - std::make_shared(INTEGER(), "c1"); - auto getStructFieldExpr = std::make_shared( + void testGetStructField( + const VectorPtr& input, + int ordinal, + const VectorPtr& expected) { + auto batchSize = expected->size(); + auto ordinalVector = makeConstant(ordinal, batchSize); + std::vector inputs = { + std::make_shared(input->type(), "c0"), + std::make_shared(INTEGER(), "c1") + }; + auto resultType = expected->type(); + auto expr = std::make_shared( resultType, - std::vector{data, ordinal}, + std::move(inputs), "get_struct_field"); auto result = evaluate( - getStructFieldExpr, makeRowVector({"c0", "c1"}, {parameter, index})); - if (expected) { - ::facebook::velox::test::assertEqualVectors(expected, result); - } + expr, makeRowVector({"c0", "c1"}, {input, ordinalVector})); + ::facebook::velox::test::assertEqualVectors(expected, result); } }; -TEST_F(GetStructFieldTest, simpleInteger) { - auto dataType = ROW({"k1", "k2", "a"}, {BIGINT(), BIGINT(), BIGINT()}); - auto data = makeRowVector( - {"k1", "k2", "a"}, - {makeNullableFlatVector({12}), - makeNullableFlatVector({2}), - makeNullableFlatVector({1})}); - auto result = makeNullableFlatVector({2}); - auto index = makeConstant(1, 1); - GetStructFieldSimple(data, index, dataType, INTEGER(), result); -} +TEST_F(GetStructFieldTest, simpleType) { + auto col0 = makeFlatVector({1, 2}); + auto col1 = makeFlatVector({"hello", "world"}); + auto col2 = makeNullableFlatVector({std::nullopt, 12}); + auto data = makeRowVector({col0, col1, col2}); + + // Get int field + testGetStructField(data, 0, col0); + + // Get string field + testGetStructField(data, 1, col1); -TEST_F(GetStructFieldTest, simpleVarchar) { - auto dataType = ROW({"k1", "k2", "a"}, {BIGINT(), VARCHAR(), BIGINT()}); - auto data = makeRowVector( - {"k1", "k2", "a"}, - {makeNullableFlatVector({12}), - makeNullableFlatVector({"Milly"}), - makeNullableFlatVector({1})}); - auto 
result = makeNullableFlatVector({"Milly"}); - auto index = makeConstant(1, 1); - GetStructFieldSimple(data, index, dataType, VARCHAR(), result); + // Get int field with null + testGetStructField(data, 2, col2); } -TEST_F(GetStructFieldTest, simpleNull) { - auto dataType = ROW({"k1", "k2", "a"}, {BIGINT(), BIGINT(), BIGINT()}); - auto data = makeRowVector( - {"k1", "k2", "a"}, - {makeNullableFlatVector({12}), - makeNullableFlatVector({std::nullopt}), - makeNullableFlatVector({1})}); - auto result = makeNullableFlatVector({std::nullopt}); - auto index = makeConstant(1, 1); - GetStructFieldSimple(data, index, dataType, INTEGER(), result); +TEST_F(GetStructFieldTest, complexType) { + auto col0 = makeArrayVector({ + {1, 2}, + {3, 4} + }); + auto col1 = makeMapVector({ + {{"a", 0}, {"b", 1}}, + {{"c", 3}, {"d", 4}} + }); + auto col2 = makeRowVector({ + makeArrayVector({ + {100, 101}, {200, 202}, {300, 303} + }), + makeFlatVector({"a", "b"}) + }); + auto data = makeRowVector({col0, col1, col2}); + + // Get array field + testGetStructField(data, 0, col0); + + // Get map field + testGetStructField(data, 1, col1); + + // Get row field + testGetStructField(data, 2, col2); } } // namespace From 7629f7035b9c4fc947ad9df4e10528038786a425 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Fri, 7 Feb 2025 10:25:49 +0800 Subject: [PATCH 3/5] fix style --- .../sparksql/tests/GetStructFieldTest.cpp | 31 ++++++------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/velox/functions/sparksql/tests/GetStructFieldTest.cpp b/velox/functions/sparksql/tests/GetStructFieldTest.cpp index ae2ba92a5420..ee44552c26c0 100644 --- a/velox/functions/sparksql/tests/GetStructFieldTest.cpp +++ b/velox/functions/sparksql/tests/GetStructFieldTest.cpp @@ -33,15 +33,12 @@ class GetStructFieldTest : public SparkFunctionBaseTest { auto ordinalVector = makeConstant(ordinal, batchSize); std::vector inputs = { std::make_shared(input->type(), "c0"), - std::make_shared(INTEGER(), "c1") - }; 
+ std::make_shared(INTEGER(), "c1")}; auto resultType = expected->type(); auto expr = std::make_shared( - resultType, - std::move(inputs), - "get_struct_field"); - auto result = evaluate( - expr, makeRowVector({"c0", "c1"}, {input, ordinalVector})); + resultType, std::move(inputs), "get_struct_field"); + auto result = + evaluate(expr, makeRowVector({"c0", "c1"}, {input, ordinalVector})); ::facebook::velox::test::assertEqualVectors(expected, result); } }; @@ -63,20 +60,12 @@ TEST_F(GetStructFieldTest, simpleType) { } TEST_F(GetStructFieldTest, complexType) { - auto col0 = makeArrayVector({ - {1, 2}, - {3, 4} - }); - auto col1 = makeMapVector({ - {{"a", 0}, {"b", 1}}, - {{"c", 3}, {"d", 4}} - }); - auto col2 = makeRowVector({ - makeArrayVector({ - {100, 101}, {200, 202}, {300, 303} - }), - makeFlatVector({"a", "b"}) - }); + auto col0 = makeArrayVector({{1, 2}, {3, 4}}); + auto col1 = makeMapVector( + {{{"a", 0}, {"b", 1}}, {{"c", 3}, {"d", 4}}}); + auto col2 = makeRowVector( + {makeArrayVector({{100, 101}, {200, 202}, {300, 303}}), + makeFlatVector({"a", "b"})}); auto data = makeRowVector({col0, col1, col2}); // Get array field From 414d3ed49ac8277f6e53e1c3b621156df781a81e Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Sat, 8 Feb 2025 14:33:44 +0800 Subject: [PATCH 4/5] fix ut --- velox/functions/sparksql/tests/GetStructFieldTest.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/velox/functions/sparksql/tests/GetStructFieldTest.cpp b/velox/functions/sparksql/tests/GetStructFieldTest.cpp index ee44552c26c0..29e0d423c640 100644 --- a/velox/functions/sparksql/tests/GetStructFieldTest.cpp +++ b/velox/functions/sparksql/tests/GetStructFieldTest.cpp @@ -63,9 +63,9 @@ TEST_F(GetStructFieldTest, complexType) { auto col0 = makeArrayVector({{1, 2}, {3, 4}}); auto col1 = makeMapVector( {{{"a", 0}, {"b", 1}}, {{"c", 3}, {"d", 4}}}); - auto col2 = makeRowVector( - {makeArrayVector({{100, 101}, {200, 202}, {300, 303}}), - makeFlatVector({"a", 
"b"})}); + auto col2 = makeRowVector({ + makeArrayVector({{100, 101, 102}, {200, 201, 202}, {300, 301, 302}}), + makeFlatVector({"a", "b", "c"})}); auto data = makeRowVector({col0, col1, col2}); // Get array field From 437ea5c8407518d302c645d22612e5d09d312236 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Sat, 8 Feb 2025 14:56:58 +0800 Subject: [PATCH 5/5] fix --- velox/functions/sparksql/tests/GetStructFieldTest.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/velox/functions/sparksql/tests/GetStructFieldTest.cpp b/velox/functions/sparksql/tests/GetStructFieldTest.cpp index 29e0d423c640..887dbf02243f 100644 --- a/velox/functions/sparksql/tests/GetStructFieldTest.cpp +++ b/velox/functions/sparksql/tests/GetStructFieldTest.cpp @@ -63,9 +63,10 @@ TEST_F(GetStructFieldTest, complexType) { auto col0 = makeArrayVector({{1, 2}, {3, 4}}); auto col1 = makeMapVector( {{{"a", 0}, {"b", 1}}, {{"c", 3}, {"d", 4}}}); - auto col2 = makeRowVector({ - makeArrayVector({{100, 101, 102}, {200, 201, 202}, {300, 301, 302}}), - makeFlatVector({"a", "b", "c"})}); + auto col2 = makeRowVector( + {makeArrayVector( + {{100, 101, 102}, {200, 201, 202}, {300, 301, 302}}), + makeFlatVector({"a", "b", "c"})}); auto data = makeRowVector({col0, col1, col2}); // Get array field