-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add Spark get_struct_field function #12166
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "velox/functions/sparksql/specialforms/GetStructField.h" | ||
#include "expression/Expr.h" | ||
#include "vector/ComplexVector.h" | ||
#include "vector/ConstantVector.h" | ||
#include "velox/expression/PeeledEncoding.h" | ||
|
||
using namespace facebook::velox::exec; | ||
|
||
namespace facebook::velox::functions::sparksql { | ||
|
||
void GetStructFieldExpr::evalSpecialForm( | ||
const SelectivityVector& rows, | ||
EvalCtx& context, | ||
VectorPtr& result) { | ||
VectorPtr input; | ||
VectorPtr ordinalVector; | ||
inputs_[0]->eval(rows, context, input); | ||
inputs_[1]->eval(rows, context, ordinalVector); | ||
VELOX_USER_CHECK(ordinalVector->isConstantEncoding()); | ||
auto ordinal = | ||
ordinalVector->asUnchecked<ConstantVector<int32_t>>()->valueAt(0); | ||
auto resultType = std::const_pointer_cast<const Type>(type_); | ||
|
||
LocalSelectivityVector remainingRows(context, rows); | ||
context.deselectErrors(*remainingRows); | ||
|
||
LocalDecodedVector decoded(context, *input, *remainingRows); | ||
|
||
auto* rawNulls = decoded->nulls(); | ||
if (rawNulls) { | ||
remainingRows->deselectNulls( | ||
rawNulls, remainingRows->begin(), remainingRows->end()); | ||
} | ||
|
||
VectorPtr localResult; | ||
if (!remainingRows->hasSelections()) { | ||
localResult = | ||
BaseVector::createNullConstant(resultType, rows.end(), context.pool()); | ||
} else { | ||
auto rowData = decoded->base()->as<RowVector>(); | ||
if (decoded->isIdentityMapping()) { | ||
localResult = rowData->childAt(ordinal); | ||
} else { | ||
localResult = | ||
decoded->wrap(rowData->childAt(ordinal), *input, decoded->size()); | ||
} | ||
} | ||
|
||
context.moveOrCopyResult(localResult, *remainingRows, result); | ||
context.releaseVector(localResult); | ||
|
||
VELOX_CHECK_NOT_NULL(result); | ||
if (rawNulls || context.errors()) { | ||
EvalCtx::addNulls( | ||
rows, remainingRows->asRange().bits(), context, resultType, result); | ||
} | ||
|
||
context.releaseVector(input); | ||
context.releaseVector(ordinalVector); | ||
} | ||
|
||
/// Type resolution is not supported for get_struct_field: this always throws,
/// regardless of the argument types passed in.
TypePtr GetStructFieldCallToSpecialForm::resolveType(
    const std::vector<TypePtr>& /*argTypes*/) {
  VELOX_FAIL("get_struct_field function does not support type resolution.");
}
|
||
/// Builds the get_struct_field special form after validating its arguments.
///
/// @param type Result type of the expression.
/// @param args Exactly two inputs: a ROW-typed expression followed by an
/// INTEGER-typed ordinal expression.
/// @param trackCpuUsage Whether to track CPU usage.
/// @return A GetStructFieldExpr that takes ownership of `args`.
/// @throws A user error if the argument count or the argument types do not
/// match the expectations above.
ExprPtr GetStructFieldCallToSpecialForm::constructSpecialForm(
    const TypePtr& type,
    std::vector<ExprPtr>&& args,
    bool trackCpuUsage,
    const core::QueryConfig& /*config*/) {
  VELOX_USER_CHECK_EQ(
      args.size(), 2, "get_struct_field expects two arguments.");

  VELOX_USER_CHECK_EQ(
      args[0]->type()->kind(),
      TypeKind::ROW,
      "The first argument of get_struct_field should be of row type.");

  VELOX_USER_CHECK_EQ(
      args[1]->type()->kind(),
      TypeKind::INTEGER,
      "The second argument of get_struct_field should be of integer type.");

  return std::make_shared<GetStructFieldExpr>(
      type, std::move(args), trackCpuUsage);
}
|
||
} // namespace facebook::velox::functions::sparksql |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include <utility>

#include "velox/expression/FunctionCallToSpecialForm.h"
#include "velox/expression/SpecialForm.h"
|
||
using namespace facebook::velox::exec; | ||
|
||
namespace facebook::velox::functions::sparksql { | ||
|
||
class GetStructFieldExpr : public SpecialForm { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe you need to extend class There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think they are different: one is for function registration, the other is for the function itself. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm wondering if we can implement get_struct_field as a vector function and then register it as a special form. What's the specific consideration here to implement it by extending |
||
public: | ||
/// @param type The result type of the expression, i.e. the type of the
/// extracted field.
/// @param expr The two input expressions: the struct expression followed by
/// the ordinal of the field to extract. Ownership is taken over.
/// @param trackCpuUsage Whether to track CPU usage.
GetStructFieldExpr(
    TypePtr type,
    std::vector<ExprPtr>&& expr,
    bool trackCpuUsage)
    : SpecialForm(
          std::move(type),
          // 'expr' is a named rvalue reference and therefore an lvalue here;
          // it has to be moved explicitly to be forwarded without a copy.
          std::move(expr),
          "get_struct_field",
          false,
          trackCpuUsage) {}
|
||
void evalSpecialForm( | ||
const SelectivityVector& rows, | ||
EvalCtx& context, | ||
VectorPtr& result) override; | ||
|
||
/// Null propagation is taken over from the first input (the struct
/// expression).
void computePropagatesNulls() override {
  const auto& structInput = inputs_[0];
  propagatesNulls_ = structInput->propagatesNulls();
}
}; | ||
|
||
class GetStructFieldCallToSpecialForm : public exec::FunctionCallToSpecialForm { | ||
public: | ||
// Throws not supported exception. | ||
TypePtr resolveType(const std::vector<TypePtr>& argTypes) override; | ||
|
||
/// @brief Returns an expression for get_struct_field special form. The | ||
/// expression is a regular expression based on a custom VectorFunction | ||
/// implementation. | ||
/// @param type Result type. | ||
/// @param args Two inputs. First input should be a RowVector. Second | ||
/// input is the ordinal to get from RowVector, and must be constant | ||
/// INTEGER. | ||
exec::ExprPtr constructSpecialForm( | ||
const TypePtr& type, | ||
std::vector<exec::ExprPtr>&& compiledChildren, | ||
bool trackCpuUsage, | ||
const core::QueryConfig& config) override; | ||
|
||
static constexpr const char* kGetStructField = "get_struct_field"; | ||
}; | ||
|
||
} // namespace facebook::velox::functions::sparksql |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include <gtest/gtest.h> | ||
|
||
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" | ||
#include "velox/vector/tests/utils/VectorTestBase.h" | ||
|
||
using namespace facebook::velox::test; | ||
|
||
namespace facebook::velox::functions::sparksql::test { | ||
namespace { | ||
|
||
class GetStructFieldTest : public SparkFunctionBaseTest { | ||
WangGuangxin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
protected: | ||
void testGetStructField( | ||
const VectorPtr& input, | ||
int ordinal, | ||
const VectorPtr& expected) { | ||
auto batchSize = expected->size(); | ||
auto ordinalVector = makeConstant<int32_t>(ordinal, batchSize); | ||
std::vector<core::TypedExprPtr> inputs = { | ||
std::make_shared<const core::FieldAccessTypedExpr>(input->type(), "c0"), | ||
std::make_shared<const core::FieldAccessTypedExpr>(INTEGER(), "c1")}; | ||
auto resultType = expected->type(); | ||
auto expr = std::make_shared<const core::CallTypedExpr>( | ||
resultType, std::move(inputs), "get_struct_field"); | ||
auto result = | ||
evaluate(expr, makeRowVector({"c0", "c1"}, {input, ordinalVector})); | ||
::facebook::velox::test::assertEqualVectors(expected, result); | ||
} | ||
}; | ||
|
||
// Extracts each scalar-typed field of a three-field struct by its ordinal.
TEST_F(GetStructFieldTest, simpleType) {
  auto intField = makeFlatVector<int32_t>({1, 2});
  auto stringField = makeFlatVector<std::string>({"hello", "world"});
  auto nullableIntField =
      makeNullableFlatVector<int32_t>({std::nullopt, 12});
  auto structVector =
      makeRowVector({intField, stringField, nullableIntField});

  // Ordinal 0: int field.
  testGetStructField(structVector, 0, intField);

  // Ordinal 1: string field.
  testGetStructField(structVector, 1, stringField);

  // Ordinal 2: int field containing a null.
  testGetStructField(structVector, 2, nullableIntField);
}
|
||
// Extracts each complex-typed field (array, map, row) of a three-field struct
// by its ordinal.
TEST_F(GetStructFieldTest, complexType) {
  // All children of the struct must have the same number of rows (two here),
  // otherwise the expected vector and the input batch disagree on size.
  auto col0 = makeArrayVector<int32_t>({{1, 2}, {3, 4}});
  auto col1 = makeMapVector<std::string, int32_t>(
      {{{"a", 0}, {"b", 1}}, {{"c", 3}, {"d", 4}}});
  auto col2 = makeRowVector(
      {makeArrayVector<int32_t>({{100, 101, 102}, {200, 201, 202}}),
       makeFlatVector<std::string>({"a", "b"})});
  auto data = makeRowVector({col0, col1, col2});

  // Get array field
  testGetStructField(data, 0, col0);

  // Get map field
  testGetStructField(data, 1, col1);

  // Get row field
  testGetStructField(data, 2, col2);
}
|
||
} // namespace | ||
} // namespace facebook::velox::functions::sparksql::test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose we can move the checks inside constructSpecialForm and extract the constant value there. See the example in decimal round: velox/velox/functions/sparksql/specialforms/DecimalRound.cpp
Lines 229 to 245 in 04bfdff