// Review summary of the patch below. The patch teaches BlockExpandChunked to
// accept a wide Stream input in addition to the existing wide Flow input.
//
// NOTE(review): this copy of the patch appears to have lost every "<...>"
// span (bare `Cast()`, `dynamic_cast(`, `std::vector ValidateBlockStreamType`,
// `Build(ctx, ...)`), and with it the header line of the class that owns
// TExpanderState. Check the original commit before applying or editing.
//
// 1. type_ann_blocks.cpp, BlockExpandChunkedWrapper:
//    Branches on whether the head's type annotation kind is
//    ETypeAnnotationKind::Stream. The Stream branch validates with
//    EnsureWideStreamBlockType, the other branch with EnsureWideFlowBlockType.
//    Both branches then do the same thing: if every item type is Scalar the
//    node is replaced by its head and Repeat is returned; otherwise the node
//    takes the head's type annotation and Ok is returned.
//    NOTE(review): the two branches differ only in the validator (and
//    presumably the Cast target, not visible here) - candidate for a shared
//    helper.
//
// 2. mkql_blocks.cpp, TExpanderState / TBlockExpandChunkedStreamWrapper /
//    WrapBlockExpandChunked:
//    TExpanderState::WideFetch: when s.Count is zero it clears the values,
//    fetches `width` values from Stream_ into s.Values, returns any non-Ok
//    status unchanged, and otherwise calls FillArrays. It then takes a slice
//    and writes `width` outputs obtained through s.Get.
//    DoCalculate creates a holder from Stream_->GetValue(ctx) and Width_.
//    WrapBlockExpandChunked picks the stream wrapper when the input static
//    type IsStream(), otherwise the previous wide-flow wrapper.
//    NOTE(review): the constructor takes `NUdf::TUnboxedValue&& stream` but
//    initializes `Stream_(stream)`; a named rvalue reference is an lvalue, so
//    this copies - confirm whether std::move(stream) was intended.
//    NOTE(review): RegisterDependencies() is empty although DoCalculate reads
//    Stream_ - confirm whether a dependency on Stream_ should be registered.
//    NOTE(review): `std::move(Stream_->GetValue(ctx))` - if GetValue returns
//    by value the std::move is redundant; verify.
//
// 3. mkql_program_builder.cpp, ValidateBlockStreamType / BlockExpandChunked:
//    ValidateBlockStreamType requires at least one wide component, casts each
//    component with AS_TYPE(TBlockType, ...), collects the block item types,
//    then requires the last column to be scalar and checks its scheme type
//    (error text "Expected Uint64").
//    `isScalar` is declared without an initializer; it is assigned in the
//    loop, which runs at least once because of the size() > 0 check above it.
//    TProgramBuilder::BlockExpandChunked renames its parameter flow -> comp
//    and chooses the stream or flow validator via IsStream().
//
// 4. dqs_opt.cpp, DqsRewritePhyBlockReadOnDqIntegration:
//    Re-nests the builder calls so that the "BlockExpandChunked" callable is
//    applied directly to the node rebuilt from readWideWrap (Input / Flags /
//    Token), and the result is wrapped by one more Build(...) level.
//    NOTE(review): the concrete node types of the Build calls are not visible
//    in this text - presumably the flow conversion moved outside
//    BlockExpandChunked; confirm against the original commit.
diff --git a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp index be0b2b7edc35..679bd0b56cc1 100644 --- a/ydb/library/yql/core/type_ann/type_ann_blocks.cpp +++ b/ydb/library/yql/core/type_ann/type_ann_blocks.cpp @@ -116,19 +116,35 @@ IGraphTransformer::TStatus BlockExpandChunkedWrapper(const TExprNode::TPtr& inpu } TTypeAnnotationNode::TListType blockItemTypes; - if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { - return IGraphTransformer::TStatus::Error; - } + if (input->Head().GetTypeAnn()->GetKind() == ETypeAnnotationKind::Stream) { + if (!EnsureWideStreamBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } - auto flowItemTypes = input->Head().GetTypeAnn()->Cast()->GetItemType()->Cast()->GetItems(); - bool allScalars = AllOf(flowItemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; }); - if (allScalars) { - output = input->HeadPtr(); - return IGraphTransformer::TStatus::Repeat; - } + auto streamItemTypes = input->Head().GetTypeAnn()->Cast()->GetItemType()->Cast()->GetItems(); + bool allScalars = AllOf(streamItemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; }); + if (allScalars) { + output = input->HeadPtr(); + return IGraphTransformer::TStatus::Repeat; + } - input->SetTypeAnn(input->Head().GetTypeAnn()); - return IGraphTransformer::TStatus::Ok; + input->SetTypeAnn(input->Head().GetTypeAnn()); + return IGraphTransformer::TStatus::Ok; + } else { + if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) { + return IGraphTransformer::TStatus::Error; + } + + auto flowItemTypes = input->Head().GetTypeAnn()->Cast()->GetItemType()->Cast()->GetItems(); + bool allScalars = AllOf(flowItemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; }); + if (allScalars) { + output = 
input->HeadPtr(); + return IGraphTransformer::TStatus::Repeat; + } + + input->SetTypeAnn(input->Head().GetTypeAnn()); + return IGraphTransformer::TStatus::Ok; + } } IGraphTransformer::TStatus BlockCoalesceWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { diff --git a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp index a719c116f9ab..9e86008d66c3 100644 --- a/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp +++ b/ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp @@ -1120,6 +1120,52 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode { +using TBaseComputation = TMutableComputationNode; +class TExpanderState : public TComputationValue { +using TBase = TComputationValue; +public: + TExpanderState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, NUdf::TUnboxedValue&& stream, size_t width) + : TBase(memInfo), HolderFactory_(ctx.HolderFactory), State_(ctx.HolderFactory.Create(width)), Stream_(stream) {} + + NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) { + auto& s = *static_cast(State_.AsBoxed().Get()); + if (!s.Count) { + s.ClearValues(); + auto result = Stream_.WideFetch(s.Values.data(), width); + if (NUdf::EFetchStatus::Ok != result) { + return result; + } + s.FillArrays(); + } + + const auto sliceSize = s.Slice(); + for (size_t i = 0; i < width; ++i) { + output[i] = s.Get(sliceSize, HolderFactory_, i); + } + return NUdf::EFetchStatus::Ok; + } + +private: + const THolderFactory& HolderFactory_; + NUdf::TUnboxedValue State_; + NUdf::TUnboxedValue Stream_; +}; +public: + TBlockExpandChunkedStreamWrapper(TComputationMutables& mutables, IComputationNode* stream, size_t width) + : TBaseComputation(mutables, EValueRepresentation::Boxed) + , Stream_(stream) + , Width_(width) {} + + NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { + return ctx.HolderFactory.Create(ctx, std::move(Stream_->GetValue(ctx)), Width_); + } + void 
RegisterDependencies() const override {} +private: + IComputationNode* const Stream_; + const size_t Width_; +}; + } // namespace IComputationNode* WrapToBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) { @@ -1184,13 +1230,21 @@ IComputationNode* WrapReplicateScalar(TCallable& callable, const TComputationNod IComputationNode* WrapBlockExpandChunked(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount()); - - const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); - const auto wideComponents = GetWideComponents(flowType); - - const auto wideFlow = dynamic_cast(LocateNode(ctx.NodeLocator, callable, 0)); - MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); - return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size()); + if (callable.GetInput(0).GetStaticType()->IsStream()) { + const auto streamType = AS_TYPE(TStreamType, callable.GetInput(0).GetStaticType()); + const auto wideComponents = GetWideComponents(streamType); + const auto computation = dynamic_cast(LocateNode(ctx.NodeLocator, callable, 0)); + + MKQL_ENSURE(computation != nullptr, "Expected computation node"); + return new TBlockExpandChunkedStreamWrapper(ctx.Mutables, computation, wideComponents.size()); + } else { + const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType()); + const auto wideComponents = GetWideComponents(flowType); + + const auto wideFlow = dynamic_cast(LocateNode(ctx.NodeLocator, callable, 0)); + MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node"); + return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size()); + } } } diff --git a/ydb/library/yql/minikql/mkql_program_builder.cpp b/ydb/library/yql/minikql/mkql_program_builder.cpp index c12742d9a606..c57e118ab126 100644 --- a/ydb/library/yql/minikql/mkql_program_builder.cpp +++ 
b/ydb/library/yql/minikql/mkql_program_builder.cpp @@ -228,6 +228,24 @@ bool ReduceOptionalElements(const TType* type, const TArrayRef& test return multiOptional; } +std::vector ValidateBlockStreamType(const TType* streamType) { + const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType)); + MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column"); + std::vector streamItems; + streamItems.reserve(wideComponents.size()); + bool isScalar; + for (size_t i = 0; i < wideComponents.size(); ++i) { + auto blockType = AS_TYPE(TBlockType, wideComponents[i]); + isScalar = blockType->GetShape() == TBlockType::EShape::Scalar; + auto withoutBlock = blockType->GetItemType(); + streamItems.push_back(withoutBlock); + } + + MKQL_ENSURE(isScalar, "Last column should be scalar"); + MKQL_ENSURE(AS_TYPE(TDataType, streamItems.back())->GetSchemeType() == NUdf::TDataType::Id, "Expected Uint64"); + return streamItems; +} + std::vector ValidateBlockFlowType(const TType* flowType) { const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType)); MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column"); @@ -1550,10 +1568,14 @@ TRuntimeNode TProgramBuilder::BlockCompress(TRuntimeNode flow, ui32 bitmapIndex) return TRuntimeNode(callableBuilder.Build(), false); } -TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode flow) { - ValidateBlockFlowType(flow.GetStaticType()); - TCallableBuilder callableBuilder(Env, __func__, flow.GetStaticType()); - callableBuilder.Add(flow); +TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode comp) { + if (comp.GetStaticType()->IsStream()) { + ValidateBlockStreamType(comp.GetStaticType()); + } else { + ValidateBlockFlowType(comp.GetStaticType()); + } + TCallableBuilder callableBuilder(Env, __func__, comp.GetStaticType()); + callableBuilder.Add(comp); return TRuntimeNode(callableBuilder.Build(), false); } diff --git a/ydb/library/yql/providers/dq/opt/dqs_opt.cpp 
b/ydb/library/yql/providers/dq/opt/dqs_opt.cpp index 2e9b71350a47..c0f666745dd5 100644 --- a/ydb/library/yql/providers/dq/opt/dqs_opt.cpp +++ b/ydb/library/yql/providers/dq/opt/dqs_opt.cpp @@ -94,15 +94,18 @@ namespace NYql::NDqs { YQL_CLOG(INFO, ProviderDq) << "DqsRewritePhyBlockReadOnDqIntegration"; return Build(ctx, node->Pos()) - .Input(ctx.Builder(node->Pos()) - .Callable("BlockExpandChunked").Add(0, Build(ctx, node->Pos()) - .Input(Build(ctx, node->Pos()) - .Input(readWideWrap.Input()) - .Flags(readWideWrap.Flags()) - .Token(readWideWrap.Token()) - .Done()) - .Done().Ptr()) - .Seal().Build()) + .Input( + Build(ctx, node->Pos()) + .Input(ctx.Builder(node->Pos()).Callable("BlockExpandChunked") + .Add(0, Build(ctx, node->Pos()) + .Input(readWideWrap.Input()) + .Flags(readWideWrap.Flags()) + .Token(readWideWrap.Token()) + .Done().Ptr()) + .Seal().Build() + ) + .Done() + ) .Done().Ptr(); }, ctx, optSettings); });