Skip to content

Commit

Permalink
next iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
MrLolthe1st committed Feb 5, 2024
1 parent e8d78d0 commit b2b34c0
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 31 deletions.
38 changes: 27 additions & 11 deletions ydb/library/yql/core/type_ann/type_ann_blocks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,35 @@ IGraphTransformer::TStatus BlockExpandChunkedWrapper(const TExprNode::TPtr& inpu
}

TTypeAnnotationNode::TListType blockItemTypes;
if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}
if (input->Head().GetTypeAnn()->GetKind() == ETypeAnnotationKind::Stream) {
if (!EnsureWideStreamBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

auto flowItemTypes = input->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems();
bool allScalars = AllOf(flowItemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; });
if (allScalars) {
output = input->HeadPtr();
return IGraphTransformer::TStatus::Repeat;
}
auto streamItemTypes = input->Head().GetTypeAnn()->Cast<TStreamExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems();
bool allScalars = AllOf(streamItemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; });
if (allScalars) {
output = input->HeadPtr();
return IGraphTransformer::TStatus::Repeat;
}

input->SetTypeAnn(input->Head().GetTypeAnn());
return IGraphTransformer::TStatus::Ok;
input->SetTypeAnn(input->Head().GetTypeAnn());
return IGraphTransformer::TStatus::Ok;
} else {
if (!EnsureWideFlowBlockType(input->Head(), blockItemTypes, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

auto flowItemTypes = input->Head().GetTypeAnn()->Cast<TFlowExprType>()->GetItemType()->Cast<TMultiExprType>()->GetItems();
bool allScalars = AllOf(flowItemTypes, [](const TTypeAnnotationNode* item) { return item->GetKind() == ETypeAnnotationKind::Scalar; });
if (allScalars) {
output = input->HeadPtr();
return IGraphTransformer::TStatus::Repeat;
}

input->SetTypeAnn(input->Head().GetTypeAnn());
return IGraphTransformer::TStatus::Ok;
}
}

IGraphTransformer::TStatus BlockCoalesceWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
Expand Down
68 changes: 61 additions & 7 deletions ydb/library/yql/minikql/comp_nodes/mkql_blocks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,52 @@ using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockExpandChunkedW
const size_t WideFieldsIndex_;
};

// Computation node for BlockExpandChunked over a stream input.
// DoCalculate() evaluates the upstream stream node and wraps the resulting
// stream value into a TExpanderState, whose WideFetch() re-emits the upstream
// items slice by slice through a TBlockState buffer.
class TBlockExpandChunkedStreamWrapper : public TMutableComputationNode<TBlockExpandChunkedStreamWrapper> {
    using TBaseComputation = TMutableComputationNode<TBlockExpandChunkedStreamWrapper>;

    // Boxed stream value that owns the upstream stream together with the
    // TBlockState used to buffer the most recently fetched wide item.
    class TExpanderState : public TComputationValue<TExpanderState> {
        using TBase = TComputationValue<TExpanderState>;
    public:
        // memInfo - forwarded to the TComputationValue base.
        // ctx     - supplies the holder factory used both to create the block
        //           state and, later, to materialize output values.
        // stream  - upstream stream value; ownership is taken over (moved from).
        // width   - number of columns the block state is created for.
        TExpanderState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, NUdf::TUnboxedValue&& stream, size_t width)
            : TBase(memInfo)
            , HolderFactory_(ctx.HolderFactory)
            , State_(ctx.HolderFactory.Create<TBlockState>(width))
            // A named rvalue reference is an lvalue: without std::move this
            // would copy the stream handle instead of taking it over.
            , Stream_(std::move(stream))
        {}

        // Writes the next slice into output[0..width).
        // Returns the upstream status unchanged when the upstream fetch does
        // not yield Ok; otherwise returns Ok.
        NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) {
            auto& blockState = *static_cast<TBlockState*>(State_.AsBoxed().Get());

            // Refill from upstream only when the buffered state reports Count == 0.
            if (!blockState.Count) {
                blockState.ClearValues();
                const auto status = Stream_.WideFetch(blockState.Values.data(), width);
                if (NUdf::EFetchStatus::Ok != status) {
                    return status;
                }
                blockState.FillArrays();
            }

            // NOTE(review): Slice() presumably returns the length of the next
            // chunk-aligned slice and advances the state; TBlockState is
            // defined outside this view - confirm.
            const auto sliceSize = blockState.Slice();
            for (size_t column = 0; column < width; ++column) {
                output[column] = blockState.Get(sliceSize, HolderFactory_, column);
            }
            return NUdf::EFetchStatus::Ok;
        }

    private:
        const THolderFactory& HolderFactory_;
        NUdf::TUnboxedValue State_;
        NUdf::TUnboxedValue Stream_;
    };

public:
    // stream - node producing the upstream stream value.
    // width  - number of wide columns of that stream.
    TBlockExpandChunkedStreamWrapper(TComputationMutables& mutables, IComputationNode* stream, size_t width)
        : TBaseComputation(mutables, EValueRepresentation::Boxed)
        , Stream_(stream)
        , Width_(width)
    {}

    // Builds the expanding stream value on top of the upstream stream value.
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        // GetValue() already yields a temporary, so no std::move is needed
        // for it to bind to the TUnboxedValue&& constructor parameter.
        return ctx.HolderFactory.Create<TExpanderState>(ctx, Stream_->GetValue(ctx), Width_);
    }

    // The result is computed from Stream_, so the node has to declare that
    // dependency instead of registering nothing.
    void RegisterDependencies() const override {
        this->DependsOn(Stream_);
    }

private:
    IComputationNode* const Stream_;
    const size_t Width_;
};

} // namespace

IComputationNode* WrapToBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
Expand Down Expand Up @@ -1184,13 +1230,21 @@ IComputationNode* WrapReplicateScalar(TCallable& callable, const TComputationNod

// Factory for the BlockExpandChunked callable.
// Expects exactly one input. A stream-typed input is wrapped into
// TBlockExpandChunkedStreamWrapper; any other input is treated as a wide flow
// and wrapped into TBlockExpandChunkedWrapper. The wide column count of the
// input type is passed to the wrapper in both cases.
IComputationNode* WrapBlockExpandChunked(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());

    // The input kind must be inspected before any cast: casting a stream input
    // to TFlowType up front would make the stream branch unreachable.
    if (callable.GetInput(0).GetStaticType()->IsStream()) {
        const auto streamType = AS_TYPE(TStreamType, callable.GetInput(0).GetStaticType());
        const auto wideComponents = GetWideComponents(streamType);

        const auto computation = dynamic_cast<IComputationNode*>(LocateNode(ctx.NodeLocator, callable, 0));
        MKQL_ENSURE(computation != nullptr, "Expected computation node");
        return new TBlockExpandChunkedStreamWrapper(ctx.Mutables, computation, wideComponents.size());
    }

    const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
    const auto wideComponents = GetWideComponents(flowType);

    const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
    MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
    return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size());
}

}
Expand Down
30 changes: 26 additions & 4 deletions ydb/library/yql/minikql/mkql_program_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,24 @@ bool ReduceOptionalElements(const TType* type, const TArrayRef<const ui32>& test
return multiOptional;
}

// Validates the type of a wide block stream and returns the item type of every
// column with the block wrapper removed (same order as the wide components).
//
// Checks, in this order:
//   - the stream has at least one wide column;
//   - every column is cast with AS_TYPE(TBlockType, ...);
//     NOTE(review): AS_TYPE presumably fails on a non-block column - the macro
//     is defined outside this view, confirm;
//   - the last column has scalar shape;
//   - the item type of the last column is Uint64.
std::vector<TType*> ValidateBlockStreamType(const TType* streamType) {
    const auto wideComponents = GetWideComponents(AS_TYPE(TStreamType, streamType));
    MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");

    std::vector<TType*> streamItems;
    streamItems.reserve(wideComponents.size());

    // Overwritten on every iteration, so after the loop it describes the last
    // column. Initialized explicitly rather than relying on the size check
    // above to guarantee a write before the read.
    bool lastColumnIsScalar = false;
    for (const auto component : wideComponents) {
        const auto blockType = AS_TYPE(TBlockType, component);
        lastColumnIsScalar = blockType->GetShape() == TBlockType::EShape::Scalar;
        streamItems.push_back(blockType->GetItemType());
    }

    MKQL_ENSURE(lastColumnIsScalar, "Last column should be scalar");
    MKQL_ENSURE(AS_TYPE(TDataType, streamItems.back())->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected Uint64");
    return streamItems;
}

std::vector<TType*> ValidateBlockFlowType(const TType* flowType) {
const auto wideComponents = GetWideComponents(AS_TYPE(TFlowType, flowType));
MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");
Expand Down Expand Up @@ -1550,10 +1568,14 @@ TRuntimeNode TProgramBuilder::BlockCompress(TRuntimeNode flow, ui32 bitmapIndex)
return TRuntimeNode(callableBuilder.Build(), false);
}

// Builds a BlockExpandChunked callable over a wide block stream or flow.
// The input type is validated with the validator matching its kind
// (ValidateBlockStreamType for streams, ValidateBlockFlowType otherwise), and
// the callable is built with the input's own static type as its return type.
TRuntimeNode TProgramBuilder::BlockExpandChunked(TRuntimeNode comp) {
    if (comp.GetStaticType()->IsStream()) {
        ValidateBlockStreamType(comp.GetStaticType());
    } else {
        ValidateBlockFlowType(comp.GetStaticType());
    }

    TCallableBuilder callableBuilder(Env, __func__, comp.GetStaticType());
    callableBuilder.Add(comp);
    return TRuntimeNode(callableBuilder.Build(), false);
}

Expand Down
21 changes: 12 additions & 9 deletions ydb/library/yql/providers/dq/opt/dqs_opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,18 @@ namespace NYql::NDqs {

YQL_CLOG(INFO, ProviderDq) << "DqsRewritePhyBlockReadOnDqIntegration";
return Build<TCoWideFromBlocks>(ctx, node->Pos())
.Input(ctx.Builder(node->Pos())
.Callable("BlockExpandChunked").Add(0, Build<TCoToFlow>(ctx, node->Pos())
.Input(Build<TDqReadBlockWideWrap>(ctx, node->Pos())
.Input(readWideWrap.Input())
.Flags(readWideWrap.Flags())
.Token(readWideWrap.Token())
.Done())
.Done().Ptr())
.Seal().Build())
.Input(
Build<TCoToFlow>(ctx, node->Pos())
.Input(ctx.Builder(node->Pos()).Callable("BlockExpandChunked")
.Add(0, Build<TDqReadBlockWideWrap>(ctx, node->Pos())
.Input(readWideWrap.Input())
.Flags(readWideWrap.Flags())
.Token(readWideWrap.Token())
.Done().Ptr())
.Seal().Build()
)
.Done()
)
.Done().Ptr();
}, ctx, optSettings);
});
Expand Down

0 comments on commit b2b34c0

Please sign in to comment.