diff --git a/velox/functions/sparksql/registration/RegisterSpecialForm.cpp b/velox/functions/sparksql/registration/RegisterSpecialForm.cpp index d9f12abe4f80..356e40fdf34c 100644 --- a/velox/functions/sparksql/registration/RegisterSpecialForm.cpp +++ b/velox/functions/sparksql/registration/RegisterSpecialForm.cpp @@ -18,6 +18,7 @@ #include "velox/expression/SpecialFormRegistry.h" #include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h" #include "velox/functions/sparksql/specialforms/DecimalRound.h" +#include "velox/functions/sparksql/specialforms/GetStructField.h" #include "velox/functions/sparksql/specialforms/MakeDecimal.h" #include "velox/functions/sparksql/specialforms/SparkCastExpr.h" @@ -44,6 +45,8 @@ void registerSpecialFormGeneralFunctions(const std::string& prefix) { "cast", std::make_unique()); registerFunctionCallToSpecialForm( "try_cast", std::make_unique()); + registerFunctionCallToSpecialForm( + "get_struct_field", std::make_unique()); } } // namespace sparksql } // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/specialforms/CMakeLists.txt b/velox/functions/sparksql/specialforms/CMakeLists.txt index e141e0074bc8..0b573f4d6312 100644 --- a/velox/functions/sparksql/specialforms/CMakeLists.txt +++ b/velox/functions/sparksql/specialforms/CMakeLists.txt @@ -16,6 +16,7 @@ velox_add_library( velox_functions_spark_specialforms AtLeastNNonNulls.cpp DecimalRound.cpp + GetStructField.cpp MakeDecimal.cpp SparkCastExpr.cpp SparkCastHooks.cpp) diff --git a/velox/functions/sparksql/specialforms/GetStructField.cpp b/velox/functions/sparksql/specialforms/GetStructField.cpp new file mode 100644 index 000000000000..c9682ffc667d --- /dev/null +++ b/velox/functions/sparksql/specialforms/GetStructField.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/sparksql/specialforms/GetStructField.h" +#include "expression/Expr.h" +#include "vector/ComplexVector.h" +#include "vector/ConstantVector.h" +#include "velox/expression/PeeledEncoding.h" + +using namespace facebook::velox::exec; + +namespace facebook::velox::functions::sparksql { + +void GetStructFieldExpr::evalSpecialForm( + const SelectivityVector& rows, + EvalCtx& context, + VectorPtr& result) { + VectorPtr input; + VectorPtr ordinalVector; + inputs_[0]->eval(rows, context, input); + inputs_[1]->eval(rows, context, ordinalVector); + VELOX_USER_CHECK(ordinalVector->isConstantEncoding()); + auto ordinal = + ordinalVector->asUnchecked>()->valueAt(0); + auto resultType = std::const_pointer_cast(type_); + + LocalSelectivityVector remainingRows(context, rows); + context.deselectErrors(*remainingRows); + + LocalDecodedVector decoded(context, *input, *remainingRows); + + auto* rawNulls = decoded->nulls(); + if (rawNulls) { + remainingRows->deselectNulls( + rawNulls, remainingRows->begin(), remainingRows->end()); + } + + VectorPtr localResult; + if (!remainingRows->hasSelections()) { + localResult = + BaseVector::createNullConstant(resultType, rows.end(), context.pool()); + } else { + auto rowData = decoded->base()->as(); + if (decoded->isIdentityMapping()) { + localResult = rowData->childAt(ordinal); + } else { + localResult = + decoded->wrap(rowData->childAt(ordinal), *input, decoded->size()); + } + } + + context.moveOrCopyResult(localResult, *remainingRows, result); + context.releaseVector(localResult); + + VELOX_CHECK_NOT_NULL(result); + if (rawNulls || context.errors()) { + EvalCtx::addNulls( + rows, remainingRows->asRange().bits(), context, resultType, result); + } + + context.releaseVector(input); + context.releaseVector(ordinalVector); +} + +TypePtr GetStructFieldCallToSpecialForm::resolveType( + const std::vector& /*argTypes*/) { + VELOX_FAIL("get_struct_field function does not support type resolution."); +} + +ExprPtr GetStructFieldCallToSpecialForm::constructSpecialForm( + const TypePtr& type, + std::vector&& args, + bool trackCpuUsage, + const core::QueryConfig& /*config*/) { + VELOX_USER_CHECK_EQ(args.size(), 2, "get_struct_field expects two argument."); + + VELOX_USER_CHECK_EQ( + args[0]->type()->kind(), + TypeKind::ROW, + "The first argument of get_struct_field should be of row type."); + + VELOX_USER_CHECK_EQ( + args[1]->type()->kind(), + TypeKind::INTEGER, + "The second argument of get_struct_field should be of integer type."); + + return std::make_shared( + type, std::move(args), trackCpuUsage); +} + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/specialforms/GetStructField.h b/velox/functions/sparksql/specialforms/GetStructField.h new file mode 100644 index 000000000000..e220c0f2e49b --- /dev/null +++ b/velox/functions/sparksql/specialforms/GetStructField.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/expression/FunctionCallToSpecialForm.h" +#include "velox/expression/SpecialForm.h" + +using namespace facebook::velox::exec; + +namespace facebook::velox::functions::sparksql { + +class GetStructFieldExpr : public SpecialForm { + public: + /// @param type The target type of the returned expression + /// @param expr The struct expression to get by field ordinal + /// @param trackCpuUsage Whether to track CPU usage + GetStructFieldExpr( + TypePtr type, + std::vector&& expr, + bool trackCpuUsage) + : SpecialForm(type, expr, "get_struct_field", false, trackCpuUsage) {} + + void evalSpecialForm( + const SelectivityVector& rows, + EvalCtx& context, + VectorPtr& result) override; + + void computePropagatesNulls() override { + propagatesNulls_ = inputs_[0]->propagatesNulls(); + } +}; + +class GetStructFieldCallToSpecialForm : public exec::FunctionCallToSpecialForm { + public: + // Throws not supported exception. + TypePtr resolveType(const std::vector& argTypes) override; + + /// @brief Returns an expression for get_struct_field special form. The + /// expression is a regular expression based on a custom VectorFunction + /// implementation. + /// @param type Result type. + /// @param args Two inputs. First input should be a RowVector. Second + /// input is the ordinal to get from RowVector, and must be constant + /// INTEGER. + exec::ExprPtr constructSpecialForm( + const TypePtr& type, + std::vector&& compiledChildren, + bool trackCpuUsage, + const core::QueryConfig& config) override; + + static constexpr const char* kGetStructField = "get_struct_field"; +}; + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index 39087bd8adb5..1ccf8e3151d3 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -34,6 +34,7 @@ add_executable( DecimalUtilTest.cpp ElementAtTest.cpp GetJsonObjectTest.cpp + GetStructFieldTest.cpp HashTest.cpp InTest.cpp JsonObjectKeysTest.cpp diff --git a/velox/functions/sparksql/tests/GetStructFieldTest.cpp b/velox/functions/sparksql/tests/GetStructFieldTest.cpp new file mode 100644 index 000000000000..887dbf02243f --- /dev/null +++ b/velox/functions/sparksql/tests/GetStructFieldTest.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::velox::test; + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class GetStructFieldTest : public SparkFunctionBaseTest { + protected: + void testGetStructField( + const VectorPtr& input, + int ordinal, + const VectorPtr& expected) { + auto batchSize = expected->size(); + auto ordinalVector = makeConstant(ordinal, batchSize); + std::vector inputs = { + std::make_shared(input->type(), "c0"), + std::make_shared(INTEGER(), "c1")}; + auto resultType = expected->type(); + auto expr = std::make_shared( + resultType, std::move(inputs), "get_struct_field"); + auto result = + evaluate(expr, makeRowVector({"c0", "c1"}, {input, ordinalVector})); + ::facebook::velox::test::assertEqualVectors(expected, result); + } +}; + +TEST_F(GetStructFieldTest, simpleType) { + auto col0 = makeFlatVector({1, 2}); + auto col1 = makeFlatVector({"hello", "world"}); + auto col2 = makeNullableFlatVector({std::nullopt, 12}); + auto data = makeRowVector({col0, col1, col2}); + + // Get int field + testGetStructField(data, 0, col0); + + // Get string field + testGetStructField(data, 1, col1); + + // Get int field with null + testGetStructField(data, 2, col2); +} + +TEST_F(GetStructFieldTest, complexType) { + auto col0 = makeArrayVector({{1, 2}, {3, 4}}); + auto col1 = makeMapVector( + {{{"a", 0}, {"b", 1}}, {{"c", 3}, {"d", 4}}}); + auto col2 = makeRowVector( + {makeArrayVector( + {{100, 101, 102}, {200, 201, 202}, {300, 301, 302}}), + makeFlatVector({"a", "b", "c"})}); + auto data = makeRowVector({col0, col1, col2}); + + // Get array field + testGetStructField(data, 0, col0); + + // Get map field + testGetStructField(data, 1, col1); + + // Get row field + testGetStructField(data, 2, col2); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test