Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Spark get_struct_field function #12166

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions velox/functions/sparksql/registration/RegisterSpecialForm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "velox/expression/SpecialFormRegistry.h"
#include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h"
#include "velox/functions/sparksql/specialforms/DecimalRound.h"
#include "velox/functions/sparksql/specialforms/GetStructField.h"
#include "velox/functions/sparksql/specialforms/MakeDecimal.h"
#include "velox/functions/sparksql/specialforms/SparkCastExpr.h"

Expand All @@ -44,6 +45,8 @@ void registerSpecialFormGeneralFunctions(const std::string& prefix) {
"cast", std::make_unique<SparkCastCallToSpecialForm>());
registerFunctionCallToSpecialForm(
"try_cast", std::make_unique<SparkTryCastCallToSpecialForm>());
registerFunctionCallToSpecialForm(
"get_struct_field", std::make_unique<GetStructFieldCallToSpecialForm>());
}
} // namespace sparksql
} // namespace facebook::velox::functions
1 change: 1 addition & 0 deletions velox/functions/sparksql/specialforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ velox_add_library(
velox_functions_spark_specialforms
AtLeastNNonNulls.cpp
DecimalRound.cpp
GetStructField.cpp
MakeDecimal.cpp
SparkCastExpr.cpp
SparkCastHooks.cpp)
Expand Down
104 changes: 104 additions & 0 deletions velox/functions/sparksql/specialforms/GetStructField.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/sparksql/specialforms/GetStructField.h"
#include "expression/Expr.h"
#include "vector/ComplexVector.h"
#include "vector/ConstantVector.h"
#include "velox/expression/PeeledEncoding.h"

using namespace facebook::velox::exec;

namespace facebook::velox::functions::sparksql {

void GetStructFieldExpr::evalSpecialForm(
const SelectivityVector& rows,
EvalCtx& context,
VectorPtr& result) {
VectorPtr input;
VectorPtr ordinalVector;
inputs_[0]->eval(rows, context, input);
inputs_[1]->eval(rows, context, ordinalVector);
VELOX_USER_CHECK(ordinalVector->isConstantEncoding());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose we can move the checks inside constructSpecialForm and extract the constant value there. Some example in decimal round:

VELOX_USER_CHECK_EQ(
args[1]->type()->kind(),
TypeKind::INTEGER,
"The second argument of decimal_round should be of integer type.");
auto constantExpr = std::dynamic_pointer_cast<exec::ConstantExpr>(args[1]);
VELOX_USER_CHECK_NOT_NULL(
constantExpr,
"The second argument of decimal_round should be constant expression.");
VELOX_USER_CHECK(
constantExpr->value()->isConstantEncoding(),
"The second argument of decimal_round should be wrapped in constant vector.");
auto constantVector =
constantExpr->value()->asUnchecked<ConstantVector<int32_t>>();
VELOX_USER_CHECK(
!constantVector->isNullAt(0),
"The second argument of decimal_round is non-nullable.");
scale = constantVector->valueAt(0);

auto ordinal =
ordinalVector->asUnchecked<ConstantVector<int32_t>>()->valueAt(0);
auto resultType = std::const_pointer_cast<const Type>(type_);

LocalSelectivityVector remainingRows(context, rows);
context.deselectErrors(*remainingRows);

LocalDecodedVector decoded(context, *input, *remainingRows);

auto* rawNulls = decoded->nulls();
if (rawNulls) {
remainingRows->deselectNulls(
rawNulls, remainingRows->begin(), remainingRows->end());
}

VectorPtr localResult;
if (!remainingRows->hasSelections()) {
localResult =
BaseVector::createNullConstant(resultType, rows.end(), context.pool());
} else {
auto rowData = decoded->base()->as<RowVector>();
if (decoded->isIdentityMapping()) {
localResult = rowData->childAt(ordinal);
} else {
localResult =
decoded->wrap(rowData->childAt(ordinal), *input, decoded->size());
}
}

context.moveOrCopyResult(localResult, *remainingRows, result);
context.releaseVector(localResult);

VELOX_CHECK_NOT_NULL(result);
if (rawNulls || context.errors()) {
EvalCtx::addNulls(
rows, remainingRows->asRange().bits(), context, resultType, result);
}

context.releaseVector(input);
context.releaseVector(ordinalVector);
}

TypePtr GetStructFieldCallToSpecialForm::resolveType(
const std::vector<TypePtr>& /*argTypes*/) {
VELOX_FAIL("get_struct_field function does not support type resolution.");
}

ExprPtr GetStructFieldCallToSpecialForm::constructSpecialForm(
const TypePtr& type,
std::vector<ExprPtr>&& args,
bool trackCpuUsage,
const core::QueryConfig& /*config*/) {
VELOX_USER_CHECK_EQ(args.size(), 2, "get_struct_field expects two argument.");

VELOX_USER_CHECK_EQ(
args[0]->type()->kind(),
TypeKind::ROW,
"The first argument of get_struct_field should be of row type.");

VELOX_USER_CHECK_EQ(
args[1]->type()->kind(),
TypeKind::INTEGER,
"The second argument of get_struct_field should be of integer type.");

return std::make_shared<GetStructFieldExpr>(
type, std::move(args), trackCpuUsage);
}

} // namespace facebook::velox::functions::sparksql
67 changes: 67 additions & 0 deletions velox/functions/sparksql/specialforms/GetStructField.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/expression/FunctionCallToSpecialForm.h"
#include "velox/expression/SpecialForm.h"

using namespace facebook::velox::exec;

namespace facebook::velox::functions::sparksql {

class GetStructFieldExpr : public SpecialForm {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you need to extend class FunctionCallToSpecialForm rather than SpecialForm, so as the implementation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are different, one is for function register, the other is for function iteself

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we can implement the get_struct_field as a vector function then register it as a special form. What's the specific consideration here to implement it by extending SpecialForm?

public:
/// @param type The target type of the returned expression
/// @param expr The struct expression to get by field ordinal
/// @param trackCpuUsage Whether to track CPU usage
GetStructFieldExpr(
TypePtr type,
std::vector<ExprPtr>&& expr,
bool trackCpuUsage)
: SpecialForm(type, expr, "get_struct_field", false, trackCpuUsage) {}

void evalSpecialForm(
const SelectivityVector& rows,
EvalCtx& context,
VectorPtr& result) override;

void computePropagatesNulls() override {
propagatesNulls_ = inputs_[0]->propagatesNulls();
}
};

class GetStructFieldCallToSpecialForm : public exec::FunctionCallToSpecialForm {
public:
// Throws not supported exception.
TypePtr resolveType(const std::vector<TypePtr>& argTypes) override;

/// @brief Returns an expression for get_struct_field special form. The
/// expression is a regular expression based on a custom VectorFunction
/// implementation.
/// @param type Result type.
/// @param args Two inputs. First input should be a RowVector. Second
/// input is the ordinal to get from RowVector, and must be constant
/// INTEGER.
exec::ExprPtr constructSpecialForm(
const TypePtr& type,
std::vector<exec::ExprPtr>&& compiledChildren,
bool trackCpuUsage,
const core::QueryConfig& config) override;

static constexpr const char* kGetStructField = "get_struct_field";
};

} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ add_executable(
DecimalUtilTest.cpp
ElementAtTest.cpp
GetJsonObjectTest.cpp
GetStructFieldTest.cpp
HashTest.cpp
InTest.cpp
JsonObjectKeysTest.cpp
Expand Down
83 changes: 83 additions & 0 deletions velox/functions/sparksql/tests/GetStructFieldTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>

#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"
#include "velox/vector/tests/utils/VectorTestBase.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions::sparksql::test {
namespace {

class GetStructFieldTest : public SparkFunctionBaseTest {
WangGuangxin marked this conversation as resolved.
Show resolved Hide resolved
protected:
void testGetStructField(
const VectorPtr& input,
int ordinal,
const VectorPtr& expected) {
auto batchSize = expected->size();
auto ordinalVector = makeConstant<int32_t>(ordinal, batchSize);
std::vector<core::TypedExprPtr> inputs = {
std::make_shared<const core::FieldAccessTypedExpr>(input->type(), "c0"),
std::make_shared<const core::FieldAccessTypedExpr>(INTEGER(), "c1")};
auto resultType = expected->type();
auto expr = std::make_shared<const core::CallTypedExpr>(
resultType, std::move(inputs), "get_struct_field");
auto result =
evaluate(expr, makeRowVector({"c0", "c1"}, {input, ordinalVector}));
::facebook::velox::test::assertEqualVectors(expected, result);
}
};

TEST_F(GetStructFieldTest, simpleType) {
auto col0 = makeFlatVector<int32_t>({1, 2});
auto col1 = makeFlatVector<std::string>({"hello", "world"});
auto col2 = makeNullableFlatVector<int32_t>({std::nullopt, 12});
auto data = makeRowVector({col0, col1, col2});

// Get int field
testGetStructField(data, 0, col0);

// Get string field
testGetStructField(data, 1, col1);

// Get int field with null
testGetStructField(data, 2, col2);
}

TEST_F(GetStructFieldTest, complexType) {
auto col0 = makeArrayVector<int32_t>({{1, 2}, {3, 4}});
auto col1 = makeMapVector<std::string, int32_t>(
{{{"a", 0}, {"b", 1}}, {{"c", 3}, {"d", 4}}});
auto col2 = makeRowVector(
{makeArrayVector<int32_t>(
{{100, 101, 102}, {200, 201, 202}, {300, 301, 302}}),
makeFlatVector<std::string>({"a", "b", "c"})});
auto data = makeRowVector({col0, col1, col2});

// Get array field
testGetStructField(data, 0, col0);

// Get map field
testGetStructField(data, 1, col1);

// Get row field
testGetStructField(data, 2, col2);
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test
Loading