Skip to content

Commit

Permalink
Return no-yield guaranties to Collect (#7219)
Browse files Browse the repository at this point in the history
  • Loading branch information
lll-phill-lll authored Jul 30, 2024
1 parent 367b417 commit 7120c84
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 60 deletions.
6 changes: 2 additions & 4 deletions ydb/library/yql/minikql/comp_nodes/mkql_collect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ using TBaseComputation = TMutableCodegeneratorRootNode<TCollectFlowWrapper>;
if (item.IsFinish()) {
return list.Release();
}

if (!item.IsYield()) {
list = ctx.HolderFactory.Append(list.Release(), item.Release());
}
MKQL_ENSURE(!item.IsYield(), "Unexpected flow status!");
list = ctx.HolderFactory.Append(list.Release(), item.Release());
}
}
#ifndef MKQL_DISABLE_CODEGEN
Expand Down
92 changes: 36 additions & 56 deletions ydb/library/yql/minikql/comp_nodes/ut/mkql_wide_combine_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,24 @@ TRuntimeNode WideLastCombiner(TProgramBuilder& pb, TRuntimeNode flow, const TPro
pb.WideLastCombiner(flow, extractor, init, update, finish);
}

void CheckIfStreamHasExpectedStringValues(const NUdf::TUnboxedValue& streamValue, std::unordered_set<TString>& expected) {
NUdf::TUnboxedValue item;
NUdf::EFetchStatus fetchStatus;
while (!expected.empty()) {
fetchStatus = streamValue.Fetch(item);
UNIT_ASSERT_UNEQUAL(fetchStatus, NUdf::EFetchStatus::Finish);
if (fetchStatus == NYql::NUdf::EFetchStatus::Yield) continue;

const auto actual = TString(item.AsStringRef());

auto it = expected.find(actual);
UNIT_ASSERT(it != expected.end());
expected.erase(it);
}
fetchStatus = streamValue.Fetch(item);
UNIT_ASSERT_EQUAL(fetchStatus, NUdf::EFetchStatus::Finish);
}

} // unnamed

#if !defined(MKQL_RUNTIME_VERSION) || MKQL_RUNTIME_VERSION >= 18u
Expand Down Expand Up @@ -1049,7 +1067,7 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerTest) {

const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9});

const auto pgmReturn = pb.Collect(pb.NarrowMap(WideLastCombiner<SPILLING>(pb, pb.ExpandMap(pb.ToFlow(list),
const auto pgmReturn = pb.FromFlow(pb.NarrowMap(WideLastCombiner<SPILLING>(pb, pb.ExpandMap(pb.ToFlow(list),
[&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }),
[&](TRuntimeNode::TList items) -> TRuntimeNode::TList { return {items.front()}; },
[&](TRuntimeNode::TList keys, TRuntimeNode::TList items) -> TRuntimeNode::TList {
Expand All @@ -1076,26 +1094,16 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerTest) {
if (SPILLING) {
graph->GetContext().SpillerFactory = std::make_shared<TMockSpillerFactory>();
}
const auto iterator = graph->GetValue().GetListIterator();

const auto streamVal = graph->GetValue();
std::unordered_set<TString> expected {
"key one",
"very long value 2 / key two",
"very long key one",
"very long value 8 / very long value 7 / very long value 6"
};

NUdf::TUnboxedValue item;
while (!expected.empty()) {
UNIT_ASSERT(iterator.Next(item));
const auto actual = TString(item.AsStringRef());

auto it = expected.find(actual);
UNIT_ASSERT(it != expected.end());
expected.erase(it);
}
UNIT_ASSERT(!iterator.Next(item));
UNIT_ASSERT(!iterator.Next(item));
CheckIfStreamHasExpectedStringValues(streamVal, expected);
}

Y_UNIT_TEST_LLVM_SPILLING(TestLongStringsPasstroughtRefCounting) {
Expand Down Expand Up @@ -1140,7 +1148,7 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerTest) {

const auto list = pb.NewList(tupleType, {data1, data2, data3, data4, data5, data6, data7, data8, data9});

const auto pgmReturn = pb.Collect(pb.NarrowMap(WideLastCombiner<SPILLING>(pb, pb.ExpandMap(pb.ToFlow(list),
const auto pgmReturn = pb.FromFlow(pb.NarrowMap(WideLastCombiner<SPILLING>(pb, pb.ExpandMap(pb.ToFlow(list),
[&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U)}; }),
[&](TRuntimeNode::TList items) -> TRuntimeNode::TList { return {items.front()}; },
[&](TRuntimeNode::TList keys, TRuntimeNode::TList items) -> TRuntimeNode::TList {
Expand All @@ -1166,26 +1174,16 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerTest) {
if (SPILLING) {
graph->GetContext().SpillerFactory = std::make_shared<TMockSpillerFactory>();
}
const auto iterator = graph->GetValue().GetListIterator();

const auto streamVal = graph->GetValue();
std::unordered_set<TString> expected {
"very long value 1 / key one / very long value 1 / key one",
"very long value 3 / key two / very long value 2 / key two",
"very long value 4 / very long key one / very long value 4 / very long key one",
"very long value 9 / very long key two / very long value 5 / very long key two"
};

NUdf::TUnboxedValue item;
while (!expected.empty()) {
UNIT_ASSERT(iterator.Next(item));
const auto actual = TString(item.AsStringRef());

auto it = expected.find(actual);
UNIT_ASSERT(it != expected.end());
expected.erase(it);
}
UNIT_ASSERT(!iterator.Next(item));
UNIT_ASSERT(!iterator.Next(item));
CheckIfStreamHasExpectedStringValues(streamVal, expected);
}

Y_UNIT_TEST_LLVM_SPILLING(TestDoNotCalculateUnusedInput) {
Expand Down Expand Up @@ -1230,7 +1228,7 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerTest) {

const auto landmine = pb.NewDataLiteral<NUdf::EDataSlot::String>("ACHTUNG MINEN!");

const auto pgmReturn = pb.Collect(pb.NarrowMap(WideLastCombiner<SPILLING>(pb, pb.ExpandMap(pb.ToFlow(list),
const auto pgmReturn = pb.FromFlow(pb.NarrowMap(WideLastCombiner<SPILLING>(pb, pb.ExpandMap(pb.ToFlow(list),
[&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Unwrap(pb.Nth(item, 1U), landmine, __FILE__, __LINE__, 0), pb.Nth(item, 2U)}; }),
[&](TRuntimeNode::TList items) -> TRuntimeNode::TList { return {items.front()}; },
[&](TRuntimeNode::TList keys, TRuntimeNode::TList items) -> TRuntimeNode::TList {
Expand All @@ -1257,23 +1255,14 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerTest) {
if (SPILLING) {
graph->GetContext().SpillerFactory = std::make_shared<TMockSpillerFactory>();
}

const auto streamVal = graph->GetValue();
std::unordered_set<TString> expected {
"key one / value 2 / value 1 / value 5 / value 4",
"key two / value 4 / value 3 / value 3 / value 2"
};

const auto iterator = graph->GetValue().GetListIterator();
NUdf::TUnboxedValue item;
while (!expected.empty()) {
UNIT_ASSERT(iterator.Next(item));
const auto actual = TString(item.AsStringRef());

auto it = expected.find(actual);
UNIT_ASSERT(it != expected.end());
expected.erase(it);
}
UNIT_ASSERT(!iterator.Next(item));
UNIT_ASSERT(!iterator.Next(item));
CheckIfStreamHasExpectedStringValues(streamVal, expected);
}

Y_UNIT_TEST_LLVM_SPILLING(TestDoNotCalculateUnusedOutput) {
Expand Down Expand Up @@ -1315,7 +1304,7 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerTest) {

const auto landmine = pb.NewDataLiteral<NUdf::EDataSlot::String>("ACHTUNG MINEN!");

const auto pgmReturn = pb.Collect(pb.NarrowMap(WideLastCombiner<SPILLING>(pb, pb.ExpandMap(pb.ToFlow(list),
const auto pgmReturn = pb.FromFlow(pb.NarrowMap(WideLastCombiner<SPILLING>(pb, pb.ExpandMap(pb.ToFlow(list),
[&](TRuntimeNode item) -> TRuntimeNode::TList { return {pb.Nth(item, 0U), pb.Nth(item, 1U), pb.Nth(item, 2U)}; }),
[&](TRuntimeNode::TList items) -> TRuntimeNode::TList { return {items.front()}; },
[&](TRuntimeNode::TList, TRuntimeNode::TList items) -> TRuntimeNode::TList {
Expand All @@ -1334,23 +1323,14 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerTest) {
if (SPILLING) {
graph->GetContext().SpillerFactory = std::make_shared<TMockSpillerFactory>();
}

const auto streamVal = graph->GetValue();
std::unordered_set<TString> expected {
"key one: value 1, value 4, value 5, value 1, value 2",
"key two: value 2, value 3, value 3, value 4"
};

const auto iterator = graph->GetValue().GetListIterator();
NUdf::TUnboxedValue item;
while (!expected.empty()) {
UNIT_ASSERT(iterator.Next(item));
const auto actual = TString(item.AsStringRef());

auto it = expected.find(actual);
UNIT_ASSERT(it != expected.end());
expected.erase(it);
}
UNIT_ASSERT(!iterator.Next(item));
UNIT_ASSERT(!iterator.Next(item));
CheckIfStreamHasExpectedStringValues(streamVal, expected);
}

Y_UNIT_TEST_LLVM_SPILLING(TestThinAllLambdas) {
Expand All @@ -1366,7 +1346,7 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerTest) {

const auto list = pb.NewList(tupleType, {data, data, data, data});

const auto pgmReturn = pb.Collect(pb.NarrowMap(WideLastCombiner<SPILLING>(pb, pb.ExpandMap(pb.ToFlow(list),
const auto pgmReturn = pb.FromFlow(pb.NarrowMap(WideLastCombiner<SPILLING>(pb, pb.ExpandMap(pb.ToFlow(list),
[](TRuntimeNode) -> TRuntimeNode::TList { return {}; }),
[](TRuntimeNode::TList items) { return items; },
[](TRuntimeNode::TList, TRuntimeNode::TList items) { return items; },
Expand All @@ -1376,10 +1356,10 @@ Y_UNIT_TEST_SUITE(TMiniKQLWideLastCombinerTest) {
));

const auto graph = setup.BuildGraph(pgmReturn);
const auto iterator = graph->GetValue().GetListIterator();
const auto streamVal = graph->GetValue();
NUdf::TUnboxedValue item;
UNIT_ASSERT(!iterator.Next(item));
UNIT_ASSERT(!iterator.Next(item));
const auto fetchStatus = streamVal.Fetch(item);
UNIT_ASSERT_EQUAL(fetchStatus, NUdf::EFetchStatus::Finish);
}
}

Expand Down

0 comments on commit 7120c84

Please sign in to comment.