Skip to content

Commit

Permalink
Merge 45e8479 into 2cb82f8
Browse files Browse the repository at this point in the history
  • Loading branch information
kardymonds authored Feb 21, 2024
2 parents 2cb82f8 + 45e8479 commit c1aa960
Show file tree
Hide file tree
Showing 10 changed files with 886 additions and 100 deletions.
219 changes: 196 additions & 23 deletions ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp

Large diffs are not rendered by default.

46 changes: 44 additions & 2 deletions ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_list.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#pragma once

#include "mkql_match_recognize_save_load.h"

#include <ydb/library/yql/minikql/defs.h>
#include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h>
#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
#include <ydb/library/yql/minikql/comp_nodes/mkql_saveload.h>
#include <ydb/library/yql/public/udf/udf_value.h>
#include <unordered_map>

Expand Down Expand Up @@ -131,15 +135,37 @@ class TSparseList {
}
}

void Save(TOutputSerializer& serializer) const {
serializer(Storage.size());
for (const auto& [key, item]: Storage) {
serializer(key, item.Value, item.LockCount);
}
}

void Load(TInputSerializer& serializer) {
auto size = serializer.Read<TStorage::size_type>();
Storage.reserve(size);
for (size_t i = 0; i < size; ++i) {
TStorage::key_type key;
NUdf::TUnboxedValue row;
decltype(TItem::LockCount) lockCount;
serializer(key, row, lockCount);
Storage.emplace(key, TItem{row, lockCount});
}
}

private:
//TODO consider to replace hash table with contiguous chunks
using TAllocator = TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>;
std::unordered_map<

using TStorage = std::unordered_map<
size_t,
TItem,
std::hash<size_t>,
std::equal_to<size_t>,
TAllocator> Storage;
TAllocator>;

TStorage Storage;
};
using TContainerPtr = TContainer::TPtr;

Expand Down Expand Up @@ -242,6 +268,14 @@ class TSparseList {
ToIndex = -1;
}

void Save(TOutputSerializer& serializer) const {
serializer(Container, FromIndex, ToIndex);
}

void Load(TInputSerializer& serializer) {
serializer(Container, FromIndex, ToIndex);
}

private:
TRange(TContainerPtr container, size_t index)
: Container(container)
Expand Down Expand Up @@ -297,6 +331,14 @@ class TSparseList {
return Size() == 0;
}

void Save(TOutputSerializer& serializer) const {
serializer(Container, ListSize);
}

void Load(TInputSerializer& serializer) {
serializer(Container, ListSize);
}

private:
TContainerPtr Container = MakeIntrusive<TContainer>();
size_t ListSize = 0; //impl: max index ever stored + 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace NKikimr::NMiniKQL::NMatchRecognize {

template<class R>
using TMatchedVar = std::vector<R, TMKQLAllocator<R>>;

template<class R>
void Extend(TMatchedVar<R>& var, const R& r) {
if (var.empty()) {
Expand Down Expand Up @@ -110,8 +111,7 @@ class TMatchedVarsValue : public TComputationValue<TMatchedVarsValue<R>> {
: TComputationValue<TMatchedVarsValue>(memInfo)
, HolderFactory(holderFactory)
, Vars(vars)
{
}
{}

NUdf::TUnboxedValue GetElement(ui32 index) const override {
return HolderFactory.Create<TRangeList>(HolderFactory, Vars[index]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class TRowForMeasureValue: public TComputationValue<TRowForMeasureValue>
, VarNames(varNames)
, MatchNumber(matchNumber)
{}

NUdf::TUnboxedValue GetElement(ui32 index) const override {
switch(ColumnOrder[index].first) {
case EMeasureInputDataSpecialColumns::Classifier: {
Expand Down
137 changes: 132 additions & 5 deletions ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "mkql_match_recognize_matched_vars.h"
#include "mkql_match_recognize_save_load.h"
#include "../computation/mkql_computation_node_holders.h"
#include "../computation/mkql_computation_node_impl.h"
#include <ydb/library/yql/core/sql_types/match_recognize.h>
Expand All @@ -12,20 +13,38 @@ namespace NKikimr::NMiniKQL::NMatchRecognize {
using namespace NYql::NMatchRecognize;

struct TVoidTransition {
friend bool operator==(const TVoidTransition&, const TVoidTransition&) {
return true;
}
};
using TEpsilonTransition = size_t; //to
using TEpsilonTransitions = std::vector<TEpsilonTransition, TMKQLAllocator<TEpsilonTransition>>;
using TMatchedVarTransition = std::pair<std::pair<ui32, bool>, size_t>; //{{varIndex, saveState}, to}
using TQuantityEnterTransition = size_t; //to
using TQuantityExitTransition = std::pair<std::pair<ui64, ui64>, std::pair<size_t, size_t>>; //{{min, max}, {foFindMore, toMatched}}
using TNfaTransition = std::variant<

template <typename... Ts>
struct TVariantHelper {
using TVariant = std::variant<Ts...>;
using TTuple = std::tuple<Ts...>;

static std::variant<Ts...> getVariantByIndex(size_t i) {
MKQL_ENSURE(i < sizeof...(Ts), "Wrong variant index");
static std::variant<Ts...> table[] = { Ts{ }... };
return table[i];
}
};

using TNfaTransitionHelper = TVariantHelper<
TVoidTransition,
TMatchedVarTransition,
TEpsilonTransitions,
TQuantityEnterTransition,
TQuantityExitTransition
>;

using TNfaTransition = TNfaTransitionHelper::TVariant;

struct TNfaTransitionDestinationVisitor {
std::function<size_t(size_t)> callback;

Expand Down Expand Up @@ -61,11 +80,42 @@ struct TNfaTransitionDestinationVisitor {
};

struct TNfaTransitionGraph {
std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>> Transitions;
using TTransitions = std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>>;

TTransitions Transitions;
size_t Input;
size_t Output;

using TPtr = std::shared_ptr<TNfaTransitionGraph>;

template<class>
inline constexpr static bool always_false_v = false;

void Save(TOutputSerializer& serializer) const {
serializer(Transitions.size());
for (ui64 i = 0; i < Transitions.size(); ++i) {
serializer.Write(Transitions[i].index());
std::visit(serializer, Transitions[i]);
}
serializer(Input, Output);
}

void Load(TInputSerializer& serializer) {
ui64 transitionSize = serializer.Read<TTransitions::size_type>();
Transitions.resize(transitionSize);
for (ui64 i = 0; i < transitionSize; ++i) {
size_t index = serializer.Read<std::size_t>();
Transitions[i] = TNfaTransitionHelper::getVariantByIndex(index);
std::visit(serializer, Transitions[i]);
}
serializer(Input, Output);
}

bool operator==(const TNfaTransitionGraph& other) {
return Transitions == other.Transitions
&& Input == other.Input
&& Output == other.Output;
}
};

class TNfaTransitionGraphOptimizer {
Expand All @@ -78,6 +128,7 @@ class TNfaTransitionGraphOptimizer {
EliminateSingleEpsilons();
CollectGarbage();
}

private:
void EliminateEpsilonChains() {
for (size_t node = 0; node != Graph->Transitions.size(); node++) {
Expand Down Expand Up @@ -250,14 +301,69 @@ class TNfaTransitionGraphBuilder {
class TNfa {
using TRange = TSparseList::TRange;
using TMatchedVars = TMatchedVars<TRange>;


struct TState {

TState() {}

TState(size_t index, const TMatchedVars& vars, std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>>&& quantifiers)
: Index(index)
, Vars(vars)
, Quantifiers(quantifiers) {}
const size_t Index;
size_t Index;
TMatchedVars Vars;
std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>> Quantifiers; //get rid of this

using TQuantifiersStdStack = std::stack<
ui64,
std::deque<ui64, TMKQLAllocator<ui64>>>; //get rid of this

struct TQuantifiersStack: public TQuantifiersStdStack {
template<typename...TArgs>
TQuantifiersStack(TArgs... args) : TQuantifiersStdStack(args...) {}

auto begin() const { return c.begin(); }
auto end() const { return c.end(); }
auto clear() { return c.clear(); }
};

TQuantifiersStack Quantifiers;

void Save(TOutputSerializer& serializer) const {
serializer.Write(Index);
serializer.Write(Vars.size());
for (const auto& vector : Vars) {
serializer.Write(vector.size());
for (const auto& range : vector) {
range.Save(serializer);
}
}
serializer.Write(Quantifiers.size());
for (ui64 qnt : Quantifiers) {
serializer.Write(qnt);
}
}

void Load(TInputSerializer& serializer) {
serializer.Read(Index);

auto varsSize = serializer.Read<TMatchedVars::size_type>();
Vars.clear();
Vars.resize(varsSize);
for (auto& subvec: Vars) {
ui64 vectorSize = serializer.Read<ui64>();
subvec.resize(vectorSize);
for (auto& item : subvec) {
item.Load(serializer);
}
}
Quantifiers.clear();
auto quantifiersSize = serializer.Read<ui64>();
for (size_t i = 0; i < quantifiersSize; ++i) {
ui64 qnt = serializer.Read<ui64>();
Quantifiers.push(qnt);
}
}

friend inline bool operator<(const TState& lhs, const TState& rhs) {
return std::tie(lhs.Index, lhs.Quantifiers, lhs.Vars) < std::tie(rhs.Index, rhs.Quantifiers, rhs.Vars);
Expand All @@ -267,13 +373,14 @@ class TNfa {
}
};
public:

TNfa(TNfaTransitionGraph::TPtr transitionGraph, IComputationExternalNode* matchedRangesArg, const TComputationNodePtrVector& defines)
: TransitionGraph(transitionGraph)
, MatchedRangesArg(matchedRangesArg)
, Defines(defines) {
}

void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
ActiveStates.emplace(TransitionGraph->Input, TMatchedVars(Defines.size()), std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>>{});
MakeEpsilonTransitions();
std::set<TState, std::less<TState>, TMKQLAllocator<TState>> newStates;
Expand Down Expand Up @@ -329,6 +436,25 @@ class TNfa {
return ActiveStates.size();
}

void Save(TOutputSerializer& serializer) const {
// TransitionGraph is not saved/loaded, passed in constructor.
serializer.Write(ActiveStates.size());
for (const auto& state : ActiveStates) {
state.Save(serializer);
}
serializer.Write(EpsilonTransitionsLastRow);
}

void Load(TInputSerializer& serializer) {
auto stateSize = serializer.Read<ui64>();
for (size_t i = 0; i < stateSize; ++i) {
TState state;
state.Load(serializer);
ActiveStates.emplace(state);
}
serializer.Read(EpsilonTransitionsLastRow);
}

private:
//TODO (zverevgeny): Consider to change to std::vector for the sake of perf
using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>;
Expand Down Expand Up @@ -376,6 +502,7 @@ class TNfa {
TStateSet& NewStates;
TStateSet& DeletedStates;
};

bool MakeEpsilonTransitionsImpl() {
TStateSet newStates;
TStateSet deletedStates;
Expand Down
Loading

0 comments on commit c1aa960

Please sign in to comment.