Skip to content

Commit

Permalink
add get_struct_field
Browse files Browse the repository at this point in the history
  • Loading branch information
WangGuangxin committed Jan 20, 2025
1 parent 915f1d8 commit eeebd22
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 0 deletions.
3 changes: 3 additions & 0 deletions velox/functions/sparksql/registration/RegisterSpecialForm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "velox/expression/SpecialFormRegistry.h"
#include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h"
#include "velox/functions/sparksql/specialforms/DecimalRound.h"
#include "velox/functions/sparksql/specialforms/GetStructField.h"
#include "velox/functions/sparksql/specialforms/MakeDecimal.h"
#include "velox/functions/sparksql/specialforms/SparkCastExpr.h"

Expand All @@ -44,6 +45,8 @@ void registerSpecialFormGeneralFunctions(const std::string& prefix) {
"cast", std::make_unique<SparkCastCallToSpecialForm>());
registerFunctionCallToSpecialForm(
"try_cast", std::make_unique<SparkTryCastCallToSpecialForm>());
registerFunctionCallToSpecialForm(
"get_struct_field", std::make_unique<GetStructFieldCallToSpecialForm>());
}
} // namespace sparksql
} // namespace facebook::velox::functions
1 change: 1 addition & 0 deletions velox/functions/sparksql/specialforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ velox_add_library(
velox_functions_spark_specialforms
AtLeastNNonNulls.cpp
DecimalRound.cpp
GetStructField.cpp
MakeDecimal.cpp
SparkCastExpr.cpp
SparkCastHooks.cpp)
Expand Down
105 changes: 105 additions & 0 deletions velox/functions/sparksql/specialforms/GetStructField.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/sparksql/specialforms/GetStructField.h"
#include <type/Type.h>
#include "expression/Expr.h"
#include "vector/ComplexVector.h"
#include "vector/ConstantVector.h"
#include "velox/expression/PeeledEncoding.h"

using namespace facebook::velox::exec;

namespace facebook::velox::functions::sparksql {

void GetStructFieldExpr::evalSpecialForm(
const SelectivityVector& rows,
EvalCtx& context,
VectorPtr& result) {
VectorPtr input;
VectorPtr ordinalVector;
inputs_[0]->eval(rows, context, input);
inputs_[1]->eval(rows, context, ordinalVector);
VELOX_USER_CHECK(ordinalVector->isConstantEncoding());
auto ordinal =
ordinalVector->asUnchecked<ConstantVector<int32_t>>()->valueAt(0);
auto resultType = std::const_pointer_cast<const Type>(type_);

LocalSelectivityVector remainingRows(context, rows);
context.deselectErrors(*remainingRows);

LocalDecodedVector decoded(context, *input, *remainingRows);

auto* rawNulls = decoded->nulls();
if (rawNulls) {
remainingRows->deselectNulls(
rawNulls, remainingRows->begin(), remainingRows->end());
}

VectorPtr localResult;
if (!remainingRows->hasSelections()) {
localResult =
BaseVector::createNullConstant(resultType, rows.end(), context.pool());
} else {
auto rowData = decoded->base()->as<RowVector>();
if (decoded->isIdentityMapping()) {
localResult = rowData->childAt(ordinal);
} else {
localResult =
decoded->wrap(rowData->childAt(ordinal), *input, decoded->size());
}
}

context.moveOrCopyResult(localResult, *remainingRows, result);
context.releaseVector(localResult);

VELOX_CHECK_NOT_NULL(result);
if (rawNulls || context.errors()) {
EvalCtx::addNulls(
rows, remainingRows->asRange().bits(), context, resultType, result);
}

context.releaseVector(input);
context.releaseVector(ordinalVector);
}

TypePtr GetStructFieldCallToSpecialForm::resolveType(
const std::vector<TypePtr>& /*argTypes*/) {
VELOX_FAIL("get_struct_field function does not support type resolution.");
}

ExprPtr GetStructFieldCallToSpecialForm::constructSpecialForm(
const TypePtr& type,
std::vector<ExprPtr>&& args,
bool trackCpuUsage,
const core::QueryConfig& /*config*/) {
VELOX_USER_CHECK_EQ(args.size(), 2, "get_struct_field expects two argument.");

VELOX_USER_CHECK_EQ(
args[0]->type()->kind(),
TypeKind::ROW,
"The first argument of get_struct_field should be of row type.");

VELOX_USER_CHECK_EQ(
args[1]->type()->kind(),
TypeKind::INTEGER,
"The second argument of get_struct_field should be of integer type.");

return std::make_shared<GetStructFieldExpr>(
type, std::move(args), trackCpuUsage);
}

} // namespace facebook::velox::functions::sparksql
67 changes: 67 additions & 0 deletions velox/functions/sparksql/specialforms/GetStructField.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/expression/FunctionCallToSpecialForm.h"
#include "velox/expression/SpecialForm.h"

using namespace facebook::velox::exec;

namespace facebook::velox::functions::sparksql {

class GetStructFieldExpr : public SpecialForm {
public:
/// @param type The target type of the returned expression
/// @param expr The struct expression to get by field ordinal
/// @param trackCpuUsage Whether to track CPU usage
GetStructFieldExpr(
TypePtr type,
std::vector<ExprPtr>&& expr,
bool trackCpuUsage)
: SpecialForm(type, expr, "get_struct_field", false, trackCpuUsage) {}

void evalSpecialForm(
const SelectivityVector& rows,
EvalCtx& context,
VectorPtr& result) override;

void computePropagatesNulls() override {
propagatesNulls_ = inputs_[0]->propagatesNulls();
}
};

class GetStructFieldCallToSpecialForm : public exec::FunctionCallToSpecialForm {
public:
// Throws not supported exception.
TypePtr resolveType(const std::vector<TypePtr>& argTypes) override;

/// @brief Returns an expression for get_struct_field special form. The
/// expression is a regular expression based on a custom VectorFunction
/// implementation.
/// @param type Result type.
/// @param args Two inputs. First input should be a RowVector. Second
/// input is the ordinal to get from RowVector, and must be constant
/// INTEGER.
exec::ExprPtr constructSpecialForm(
const TypePtr& type,
std::vector<exec::ExprPtr>&& compiledChildren,
bool trackCpuUsage,
const core::QueryConfig& config) override;

static constexpr const char* kGetStructField = "get_struct_field";
};

} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ add_executable(
DecimalUtilTest.cpp
ElementAtTest.cpp
GetJsonObjectTest.cpp
GetStructFieldTest.cpp
HashTest.cpp
InTest.cpp
JsonObjectKeysTest.cpp
Expand Down
88 changes: 88 additions & 0 deletions velox/functions/sparksql/tests/GetStructFieldTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>

#include "functions/sparksql/specialforms/GetStructField.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"
#include "velox/vector/tests/utils/VectorTestBase.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions::sparksql::test {
namespace {

class GetStructFieldTest : public SparkFunctionBaseTest {
protected:
void GetStructFieldSimple(
const VectorPtr& parameter,
const VectorPtr& index,
const TypePtr& inputType,
const TypePtr& resultType,
const VectorPtr& expected = nullptr) {
core::TypedExprPtr data =
std::make_shared<const core::FieldAccessTypedExpr>(inputType, "c0");
core::TypedExprPtr ordinal =
std::make_shared<const core::FieldAccessTypedExpr>(INTEGER(), "c1");
auto getStructFieldExpr = std::make_shared<const core::CallTypedExpr>(
resultType,
std::vector<core::TypedExprPtr>{data, ordinal},
"get_struct_field");
auto result = evaluate(
getStructFieldExpr, makeRowVector({"c0", "c1"}, {parameter, index}));
if (expected) {
::facebook::velox::test::assertEqualVectors(expected, result);
}
}
};

TEST_F(GetStructFieldTest, simpleInteger) {
auto dataType = ROW({"k1", "k2", "a"}, {BIGINT(), BIGINT(), BIGINT()});
auto data = makeRowVector(
{"k1", "k2", "a"},
{makeNullableFlatVector<int32_t>({12}),
makeNullableFlatVector<int32_t>({2}),
makeNullableFlatVector<int32_t>({1})});
auto result = makeNullableFlatVector<int32_t>({2});
auto index = makeConstant<int32_t>(1, 1);
GetStructFieldSimple(data, index, dataType, INTEGER(), result);
}

TEST_F(GetStructFieldTest, simpleVarchar) {
auto dataType = ROW({"k1", "k2", "a"}, {BIGINT(), VARCHAR(), BIGINT()});
auto data = makeRowVector(
{"k1", "k2", "a"},
{makeNullableFlatVector<int32_t>({12}),
makeNullableFlatVector<std::string>({"Milly"}),
makeNullableFlatVector<int32_t>({1})});
auto result = makeNullableFlatVector<std::string>({"Milly"});
auto index = makeConstant<int32_t>(1, 1);
GetStructFieldSimple(data, index, dataType, VARCHAR(), result);
}

TEST_F(GetStructFieldTest, simpleNull) {
auto dataType = ROW({"k1", "k2", "a"}, {BIGINT(), BIGINT(), BIGINT()});
auto data = makeRowVector(
{"k1", "k2", "a"},
{makeNullableFlatVector<int32_t>({12}),
makeNullableFlatVector<int32_t>({std::nullopt}),
makeNullableFlatVector<int32_t>({1})});
auto result = makeNullableFlatVector<int32_t>({std::nullopt});
auto index = makeConstant<int32_t>(1, 1);
GetStructFieldSimple(data, index, dataType, INTEGER(), result);
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit eeebd22

Please sign in to comment.