Skip to content

Commit

Permalink
YQL-18053: Add block implementation for Member callable (#3461)
Browse files Browse the repository at this point in the history
  • Loading branch information
igormunkin authored Apr 15, 2024
1 parent 26b8ef3 commit 502820a
Show file tree
Hide file tree
Showing 19 changed files with 315 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5468,7 +5468,7 @@ bool CollectBlockRewrites(const TMultiExprType* multiInputType, bool keepInputCo

TExprNode::TListType funcArgs;
std::string_view arrowFunctionName;
if (node->IsList() || node->IsCallable({"And", "Or", "Xor", "Not", "Coalesce", "Exists", "If", "Just", "Nth", "ToPg", "FromPg", "PgResolvedCall", "PgResolvedOp"}))
if (node->IsList() || node->IsCallable({"And", "Or", "Xor", "Not", "Coalesce", "Exists", "If", "Just", "Member", "Nth", "ToPg", "FromPg", "PgResolvedCall", "PgResolvedOp"}))
{
if (node->IsCallable() && !IsSupportedAsBlockType(node->Pos(), *node->GetTypeAnn(), ctx, types)) {
return true;
Expand Down
68 changes: 68 additions & 0 deletions ydb/library/yql/core/type_ann/type_ann_blocks.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "type_ann_blocks.h"
#include "type_ann_impl.h"
#include "type_ann_list.h"
#include "type_ann_wide.h"
#include "type_ann_pg.h"
Expand Down Expand Up @@ -429,6 +430,73 @@ IGraphTransformer::TStatus BlockAsTupleWrapper(const TExprNode::TPtr& input, TEx
return IGraphTransformer::TStatus::Ok;
}

IGraphTransformer::TStatus BlockMemberWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
Y_UNUSED(output);
if (!EnsureArgsCount(*input, 2, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

auto& child = input->Head();
if (!EnsureBlockOrScalarType(child, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

bool isScalar;
const TTypeAnnotationNode* blockItemType = GetBlockItemType(*child.GetTypeAnn(), isScalar);
const TTypeAnnotationNode* resultType;
if (IsNull(*blockItemType)) {
resultType = blockItemType;
} else {
const TStructExprType* structType;
bool isOptional;
if (blockItemType->GetKind() == ETypeAnnotationKind::Optional) {
auto itemType = blockItemType->Cast<TOptionalExprType>()->GetItemType();
if (!EnsureStructType(child.Pos(), *itemType, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

structType = itemType->Cast<TStructExprType>();
isOptional = true;
} else {
if (!EnsureStructType(child.Pos(), *blockItemType, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

structType = blockItemType->Cast<TStructExprType>();
isOptional = false;
}

if (!EnsureComputableType(input->Head().Pos(), *structType, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

if (!EnsureAtom(input->Tail(), ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

auto memberName = input->Tail().Content();
auto pos = FindOrReportMissingMember(memberName, input->Pos(), *structType, ctx.Expr);
if (!pos) {
return IGraphTransformer::TStatus::Error;
}

resultType = structType->GetItems()[*pos]->GetItemType();
if (isOptional && !resultType->IsOptionalOrNull()) {
resultType = ctx.Expr.MakeType<TOptionalExprType>(resultType);
}
}

if (isScalar) {
resultType = ctx.Expr.MakeType<TScalarExprType>(resultType);
} else {
resultType = ctx.Expr.MakeType<TBlockExprType>(resultType);
}

input->SetTypeAnn(resultType);
return IGraphTransformer::TStatus::Ok;
}


IGraphTransformer::TStatus BlockNthWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
Y_UNUSED(output);
if (!EnsureArgsCount(*input, 2, ctx.Expr)) {
Expand Down
1 change: 1 addition & 0 deletions ydb/library/yql/core/type_ann/type_ann_blocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace NTypeAnnImpl {
IGraphTransformer::TStatus BlockJustWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus BlockAsTupleWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus BlockNthWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus BlockMemberWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus BlockToPgWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus BlockFromPgWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx);
IGraphTransformer::TStatus BlockFuncWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TExtContext& ctx);
Expand Down
1 change: 1 addition & 0 deletions ydb/library/yql/core/type_ann/type_ann_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12239,6 +12239,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
Functions["BlockIf"] = &BlockIfWrapper;
Functions["BlockJust"] = &BlockJustWrapper;
Functions["BlockAsTuple"] = &BlockAsTupleWrapper;
Functions["BlockMember"] = &BlockMemberWrapper;
Functions["BlockNth"] = &BlockNthWrapper;
Functions["BlockToPg"] = &BlockToPgWrapper;
Functions["BlockFromPg"] = &BlockFromPgWrapper;
Expand Down
119 changes: 119 additions & 0 deletions ydb/library/yql/minikql/comp_nodes/mkql_block_getelem.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include "mkql_block_getelem.h"

#include <ydb/library/yql/minikql/computation/mkql_block_impl.h>
#include <ydb/library/yql/minikql/mkql_node_cast.h>
#include <ydb/library/yql/minikql/mkql_node_builder.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {

class TBlockGetElementExec {
public:
TBlockGetElementExec(const std::shared_ptr<arrow::DataType>& returnArrowType, ui32 index, bool isOptional, bool needExternalOptional)
: ReturnArrowType(returnArrowType)
, Index(index)
, IsOptional(isOptional)
, NeedExternalOptional(needExternalOptional)
{}

arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
arrow::Datum inputDatum = batch.values[0];
if (inputDatum.is_scalar()) {
if (inputDatum.scalar()->is_valid) {
const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*inputDatum.scalar());
*res = arrow::Datum(structScalar.value[Index]);
} else {
*res = arrow::Datum(arrow::MakeNullScalar(ReturnArrowType));
}
} else {
const auto& array = inputDatum.array();
auto child = array->child_data[Index];
if (NeedExternalOptional) {
auto newArrayData = arrow::ArrayData::Make(ReturnArrowType, array->length, { array->buffers[0] });
newArrayData->child_data.push_back(child);
*res = arrow::Datum(newArrayData);
} else if (!IsOptional || !array->buffers[0]) {
*res = arrow::Datum(child);
} else {
auto newArrayData = child->Copy();
if (!newArrayData->buffers[0]) {
newArrayData->buffers[0] = array->buffers[0];
} else {
auto buffer = AllocateBitmapWithReserve(array->length + array->offset, ctx->memory_pool());
arrow::internal::BitmapAnd(child->GetValues<uint8_t>(0, 0), child->offset, array->GetValues<uint8_t>(0, 0), array->offset, array->length, array->offset, buffer->mutable_data());
newArrayData->buffers[0] = buffer;
}

newArrayData->SetNullCount(arrow::kUnknownNullCount);
*res = arrow::Datum(newArrayData);
}
}

return arrow::Status::OK();
}

private:
const std::shared_ptr<arrow::DataType> ReturnArrowType;
const ui32 Index;
const bool IsOptional;
const bool NeedExternalOptional;
};

std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockGetElementKernel(const TVector<TType*>& argTypes, TType* resultType,
ui32 index, bool isOptional, bool needExternalOptional) {
std::shared_ptr<arrow::DataType> returnArrowType;
MKQL_ENSURE(ConvertArrowType(AS_TYPE(TBlockType, resultType)->GetItemType(), returnArrowType), "Unsupported arrow type");
auto exec = std::make_shared<TBlockGetElementExec>(returnArrowType, index, isOptional, needExternalOptional);
auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType),
[exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
return exec->Exec(ctx, batch, res);
});

kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
return kernel;
}

TType* GetElementType(const TStructType* structType, ui32 index) {
MKQL_ENSURE(index < structType->GetMembersCount(), "Bad member index");
return structType->GetMemberType(index);
}

TType* GetElementType(const TTupleType* tupleType, ui32 index) {
MKQL_ENSURE(index < tupleType->GetElementsCount(), "Bad tuple index");
return tupleType->GetElementType(index);
}

template<typename ObjectType>
IComputationNode* WrapBlockGetElement(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected two args.");
auto inputObject = callable.GetInput(0);
auto blockType = AS_TYPE(TBlockType, inputObject.GetStaticType());
bool isOptional;
auto objectType = AS_TYPE(ObjectType, UnpackOptional(blockType->GetItemType(), isOptional));
auto indexData = AS_VALUE(TDataLiteral, callable.GetInput(1));
auto index = indexData->AsValue().Get<ui32>();
auto childType = GetElementType(objectType, index);
bool needExternalOptional = isOptional && childType->IsVariant();

auto objectNode = LocateNode(ctx.NodeLocator, callable, 0);

TComputationNodePtrVector argsNodes = { objectNode };
TVector<TType*> argsTypes = { blockType };
auto kernel = MakeBlockGetElementKernel(argsTypes, callable.GetType()->GetReturnType(), index, isOptional, needExternalOptional);
return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
}

} // namespace

IComputationNode* WrapBlockMember(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
return WrapBlockGetElement<TStructType>(callable, ctx);
}

IComputationNode* WrapBlockNth(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
return WrapBlockGetElement<TTupleType>(callable, ctx);
}

} // namespace NMiniKQL
} // namespace NKikimr
11 changes: 11 additions & 0 deletions ydb/library/yql/minikql/comp_nodes/mkql_block_getelem.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once
#include <ydb/library/yql/minikql/computation/mkql_computation_node.h>

namespace NKikimr {
namespace NMiniKQL {

IComputationNode* WrapBlockMember(TCallable& callable, const TComputationNodeFactoryContext& ctx);
IComputationNode* WrapBlockNth(TCallable& callable, const TComputationNodeFactoryContext& ctx);

} // namespace NMiniKQL
} // namespace NKikimr
86 changes: 0 additions & 86 deletions ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,58 +66,6 @@ class TBlockAsTupleExec {
const std::shared_ptr<arrow::DataType> ReturnArrowType;
};

class TBlockNthExec {
public:
TBlockNthExec(const std::shared_ptr<arrow::DataType>& returnArrowType, ui32 index, bool isOptional, bool needExternalOptional)
: ReturnArrowType(returnArrowType)
, Index(index)
, IsOptional(isOptional)
, NeedExternalOptional(needExternalOptional)
{}

arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
arrow::Datum inputDatum = batch.values[0];
if (inputDatum.is_scalar()) {
if (inputDatum.scalar()->is_valid) {
const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*inputDatum.scalar());
*res = arrow::Datum(structScalar.value[Index]);
} else {
*res = arrow::Datum(arrow::MakeNullScalar(ReturnArrowType));
}
} else {
const auto& array = inputDatum.array();
auto child = array->child_data[Index];
if (NeedExternalOptional) {
auto newArrayData = arrow::ArrayData::Make(ReturnArrowType, array->length, { array->buffers[0] });
newArrayData->child_data.push_back(child);
*res = arrow::Datum(newArrayData);
} else if (!IsOptional || !array->buffers[0]) {
*res = arrow::Datum(child);
} else {
auto newArrayData = child->Copy();
if (!newArrayData->buffers[0]) {
newArrayData->buffers[0] = array->buffers[0];
} else {
auto buffer = AllocateBitmapWithReserve(array->length + array->offset, ctx->memory_pool());
arrow::internal::BitmapAnd(child->GetValues<uint8_t>(0, 0), child->offset, array->GetValues<uint8_t>(0, 0), array->offset, array->length, array->offset, buffer->mutable_data());
newArrayData->buffers[0] = buffer;
}

newArrayData->SetNullCount(arrow::kUnknownNullCount);
*res = arrow::Datum(newArrayData);
}
}

return arrow::Status::OK();
}

private:
const std::shared_ptr<arrow::DataType> ReturnArrowType;
const ui32 Index;
const bool IsOptional;
const bool NeedExternalOptional;
};

std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockAsTupleKernel(const TVector<TType*>& argTypes, TType* resultType) {
std::shared_ptr<arrow::DataType> returnArrowType;
MKQL_ENSURE(ConvertArrowType(AS_TYPE(TBlockType, resultType)->GetItemType(), returnArrowType), "Unsupported arrow type");
Expand All @@ -131,20 +79,6 @@ std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockAsTupleKernel(const TVect
return kernel;
}

std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockNthKernel(const TVector<TType*>& argTypes, TType* resultType, ui32 index,
bool isOptional, bool needExternalOptional) {
std::shared_ptr<arrow::DataType> returnArrowType;
MKQL_ENSURE(ConvertArrowType(AS_TYPE(TBlockType, resultType)->GetItemType(), returnArrowType), "Unsupported arrow type");
auto exec = std::make_shared<TBlockNthExec>(returnArrowType, index, isOptional, needExternalOptional);
auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType),
[exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
return exec->Exec(ctx, batch, res);
});

kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
return kernel;
}

} // namespace

IComputationNode* WrapBlockAsTuple(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
Expand All @@ -159,25 +93,5 @@ IComputationNode* WrapBlockAsTuple(TCallable& callable, const TComputationNodeFa
return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
}

IComputationNode* WrapBlockNth(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
MKQL_ENSURE(callable.GetInputsCount() == 2U, "Expected two args.");
auto input = callable.GetInput(0U);
auto blockType = AS_TYPE(TBlockType, input.GetStaticType());
bool isOptional;
auto tupleType = AS_TYPE(TTupleType, UnpackOptional(blockType->GetItemType(), isOptional));
auto indexData = AS_VALUE(TDataLiteral, callable.GetInput(1U));
auto index = indexData->AsValue().Get<ui32>();
MKQL_ENSURE(index < tupleType->GetElementsCount(), "Bad tuple index");
auto childType = tupleType->GetElementType(index);
bool needExternalOptional = isOptional && childType->IsVariant();

auto tuple = LocateNode(ctx.NodeLocator, callable, 0);

TComputationNodePtrVector argsNodes = { tuple };
TVector<TType*> argsTypes = { blockType };
auto kernel = MakeBlockNthKernel(argsTypes, callable.GetType()->GetReturnType(), index, isOptional, needExternalOptional);
return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
}

}
}
1 change: 0 additions & 1 deletion ydb/library/yql/minikql/comp_nodes/mkql_block_tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ namespace NKikimr {
namespace NMiniKQL {

IComputationNode* WrapBlockAsTuple(TCallable& callable, const TComputationNodeFactoryContext& ctx);
IComputationNode* WrapBlockNth(TCallable& callable, const TComputationNodeFactoryContext& ctx);

}
}
2 changes: 2 additions & 0 deletions ydb/library/yql/minikql/comp_nodes/mkql_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mkql_block_agg.h"
#include "mkql_block_coalesce.h"
#include "mkql_block_exists.h"
#include "mkql_block_getelem.h"
#include "mkql_block_if.h"
#include "mkql_block_just.h"
#include "mkql_block_logical.h"
Expand Down Expand Up @@ -297,6 +298,7 @@ struct TCallableComputationNodeBuilderFuncMapFiller {
{"BlockJust", &WrapBlockJust},
{"BlockCompress", &WrapBlockCompress},
{"BlockAsTuple", &WrapBlockAsTuple},
{"BlockMember", &WrapBlockMember},
{"BlockNth", &WrapBlockNth},
{"BlockExpandChunked", &WrapBlockExpandChunked},
{"BlockCombineAll", &WrapBlockCombineAll},
Expand Down
1 change: 1 addition & 0 deletions ydb/library/yql/minikql/comp_nodes/ya.make.inc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ SET(ORIG_SOURCES
mkql_block_agg_sum.cpp
mkql_block_coalesce.cpp
mkql_block_exists.cpp
mkql_block_getelem.cpp
mkql_block_if.cpp
mkql_block_just.cpp
mkql_block_logical.cpp
Expand Down
18 changes: 18 additions & 0 deletions ydb/library/yql/minikql/mkql_program_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,24 @@ TRuntimeNode TProgramBuilder::BlockExists(TRuntimeNode data) {
return TRuntimeNode(callableBuilder.Build(), false);
}

TRuntimeNode TProgramBuilder::BlockMember(TRuntimeNode structObj, const std::string_view& memberName) {
auto blockType = AS_TYPE(TBlockType, structObj.GetStaticType());
bool isOptional;
const auto type = AS_TYPE(TStructType, UnpackOptional(blockType->GetItemType(), isOptional));

const auto memberIndex = type->GetMemberIndex(memberName);
auto memberType = type->GetMemberType(memberIndex);
if (isOptional && !memberType->IsOptional() && !memberType->IsNull() && !memberType->IsPg()) {
memberType = NewOptionalType(memberType);
}

auto returnType = NewBlockType(memberType, blockType->GetShape());
TCallableBuilder callableBuilder(Env, __func__, returnType);
callableBuilder.Add(structObj);
callableBuilder.Add(NewDataLiteral<ui32>(memberIndex));
return TRuntimeNode(callableBuilder.Build(), false);
}

TRuntimeNode TProgramBuilder::BlockNth(TRuntimeNode tuple, ui32 index) {
auto blockType = AS_TYPE(TBlockType, tuple.GetStaticType());
bool isOptional;
Expand Down
Loading

0 comments on commit 502820a

Please sign in to comment.