Skip to content

Commit

Permalink
Merge acb38f1 into 4eceb69
Browse files Browse the repository at this point in the history
  • Loading branch information
Tony-Romanov authored Jan 25, 2024
2 parents 4eceb69 + acb38f1 commit 49e5063
Show file tree
Hide file tree
Showing 20 changed files with 219 additions and 131 deletions.
71 changes: 63 additions & 8 deletions ydb/library/yql/core/arrow_kernels/registry/ut/registry_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,57 @@ Y_UNIT_TEST_SUITE(TKernelRegistryTest) {
});
}

Y_UNIT_TEST(TestAddSubMulOps) {
for (const auto oper : {TKernelRequestBuilder::EBinaryOp::Add, TKernelRequestBuilder::EBinaryOp::Sub, TKernelRequestBuilder::EBinaryOp::Mul}) {
for (const auto slot : {EDataSlot::Int8, EDataSlot::Int16, EDataSlot::Int32, EDataSlot::Int64, EDataSlot::Uint8, EDataSlot::Uint16, EDataSlot::Uint32, EDataSlot::Uint64, EDataSlot::Float, EDataSlot::Double}) {
TestOne([slot, oper](auto& b,auto& ctx) {
const auto blockUint8Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Uint8));
const auto blockType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(slot));
return b.AddBinaryOp(oper, blockUint8Type, blockType, blockType);
});
TestOne([slot, oper](auto& b,auto& ctx) {
const auto blockUint8Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Uint8));
const auto blockType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(slot));
return b.AddBinaryOp(oper, blockType, blockUint8Type, blockType);
});
TestOne([slot, oper](auto& b,auto& ctx) {
const auto blockType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(slot));
return b.AddBinaryOp(oper, blockType, blockType, blockType);
});
}
}
}

Y_UNIT_TEST(TestDivModOps) {
for (const auto oper : {TKernelRequestBuilder::EBinaryOp::Div, TKernelRequestBuilder::EBinaryOp::Mod}) {
for (const auto slot : {EDataSlot::Int8, EDataSlot::Int16, EDataSlot::Int32, EDataSlot::Int64, EDataSlot::Uint8, EDataSlot::Uint16, EDataSlot::Uint32, EDataSlot::Uint64, EDataSlot::Float, EDataSlot::Double}) {
TestOne([slot, oper](auto& b,auto& ctx) {
const auto blockUint8Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Uint8));
const auto rawType = ctx.template MakeType<TDataExprType>(slot);
const auto blockType = ctx.template MakeType<TBlockExprType>(rawType);
const auto returnType = EDataSlot::Float != slot && EDataSlot::Double != slot ?
ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TOptionalExprType>(rawType)) : blockType;
return b.AddBinaryOp(oper, blockUint8Type, blockType, returnType);
});
TestOne([slot, oper](auto& b,auto& ctx) {
const auto blockUint8Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Uint8));
const auto rawType = ctx.template MakeType<TDataExprType>(slot);
const auto blockType = ctx.template MakeType<TBlockExprType>(rawType);
const auto returnType = EDataSlot::Float != slot && EDataSlot::Double != slot ?
ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TOptionalExprType>(rawType)) : blockType;
return b.AddBinaryOp(oper, blockType, blockUint8Type, returnType);
});
TestOne([slot, oper](auto& b,auto& ctx) {
const auto rawType = ctx.template MakeType<TDataExprType>(slot);
const auto blockType = ctx.template MakeType<TBlockExprType>(rawType);
const auto returnType = EDataSlot::Float != slot && EDataSlot::Double != slot ?
ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TOptionalExprType>(rawType)) : blockType;
return b.AddBinaryOp(oper, blockType, blockType, returnType);
});
}
}
}

Y_UNIT_TEST(TestSize) {
TestOne([](auto& b,auto& ctx) {
auto blockStrType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::String));
Expand All @@ -121,17 +172,21 @@ Y_UNIT_TEST_SUITE(TKernelRegistryTest) {
}

Y_UNIT_TEST(TestMinus) {
TestOne([](auto& b,auto& ctx) {
auto blockInt32Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Int32));
return b.AddUnaryOp(TKernelRequestBuilder::EUnaryOp::Minus, blockInt32Type, blockInt32Type);
});
for (const auto slot : {EDataSlot::Int8, EDataSlot::Int16, EDataSlot::Int32, EDataSlot::Int64, EDataSlot::Float, EDataSlot::Double}) {
TestOne([slot](auto& b,auto& ctx) {
const auto blockType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(slot));
return b.AddUnaryOp(TKernelRequestBuilder::EUnaryOp::Minus, blockType, blockType);
});
}
}

Y_UNIT_TEST(TestAbs) {
TestOne([](auto& b,auto& ctx) {
auto blockInt32Type = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(EDataSlot::Int32));
return b.AddUnaryOp(TKernelRequestBuilder::EUnaryOp::Abs, blockInt32Type, blockInt32Type);
});
for (const auto slot : {EDataSlot::Int8, EDataSlot::Int16, EDataSlot::Int32, EDataSlot::Int64, EDataSlot::Float, EDataSlot::Double}) {
TestOne([slot](auto& b,auto& ctx) {
const auto blockType = ctx.template MakeType<TBlockExprType>(ctx.template MakeType<TDataExprType>(slot));
return b.AddUnaryOp(TKernelRequestBuilder::EUnaryOp::Abs, blockType, blockType);
});
}
}

Y_UNIT_TEST(TestCoalesece) {
Expand Down
8 changes: 4 additions & 4 deletions ydb/library/yql/minikql/arrow/mkql_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ bool FindArrowFunction(TStringBuf name, const TArrayRef<TType*>& inputTypes, TTy
}

bool match = false;
switch (kernel->Family.NullMode) {
case TKernelFamily::ENullMode::Default:
switch (kernel->NullMode) {
case TKernel::ENullMode::Default:
match = returnIsOptional == hasOptionals;
break;
case TKernelFamily::ENullMode::AlwaysNull:
case TKernel::ENullMode::AlwaysNull:
match = returnIsOptional;
break;
case TKernelFamily::ENullMode::AlwaysNotNull:
case TKernel::ENullMode::AlwaysNotNull:
match = !returnIsOptional;
break;
}
Expand Down
59 changes: 0 additions & 59 deletions ydb/library/yql/minikql/invoke_builtins/mkql_builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,65 +19,6 @@ namespace NMiniKQL {

namespace {

class TForeignKernel : public TKernel {
public:
TForeignKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType,
const std::shared_ptr<arrow::compute::Function>& function)
: TKernel(family, argTypes, returnType)
, Function(function)
, ArrowKernel(ResolveKernel(Function, argTypes))
{}

const arrow::compute::ScalarKernel& GetArrowKernel() const final {
return ArrowKernel;
}

private:
static const arrow::compute::ScalarKernel& ResolveKernel(const std::shared_ptr<arrow::compute::Function>& function,
const std::vector<NUdf::TDataTypeId>& argTypes) {
std::vector<arrow::ValueDescr> args;
for (const auto& t : argTypes) {
args.emplace_back();
auto slot = NUdf::FindDataSlot(t);
MKQL_ENSURE(slot, "Unexpected data type");
MKQL_ENSURE(ConvertArrowType(*slot, args.back().type), "Can't get arrow type");
}

const auto kernel = ARROW_RESULT(function->DispatchExact(args));
return *static_cast<const arrow::compute::ScalarKernel*>(kernel);
}

private:
const std::shared_ptr<arrow::compute::Function> Function;
const arrow::compute::ScalarKernel& ArrowKernel;
};

template <typename TInput1, typename TOutput>
void RegisterUnary(const arrow::compute::FunctionRegistry& registry, std::string_view name, TKernelFamilyMap& kernelFamilyMap) {
auto func = ARROW_RESULT(registry.GetFunction(std::string(name)));

std::vector<NUdf::TDataTypeId> argTypes({ NUdf::TDataType<TInput1>::Id });
NUdf::TDataTypeId returnType = NUdf::TDataType<TOutput>::Id;

auto family = std::make_unique<TKernelFamilyBase>();
family->Adopt(argTypes, returnType, std::make_unique<TForeignKernel>(*family, argTypes, returnType, func));

Y_ENSURE(kernelFamilyMap.emplace(TString(name), std::move(family)).second);
}

template <typename TInput1, typename TInput2, typename TOutput>
void RegisterBinary(const arrow::compute::FunctionRegistry& registry, std::string_view name, TKernelFamilyMap& kernelFamilyMap) {
auto func = ARROW_RESULT(registry.GetFunction(std::string(name)));

std::vector<NUdf::TDataTypeId> argTypes({ NUdf::TDataType<TInput1>::Id, NUdf::TDataType<TInput2>::Id });
NUdf::TDataTypeId returnType = NUdf::TDataType<TOutput>::Id;

auto family = std::make_unique<TKernelFamilyBase>();
family->Adopt(argTypes, returnType, std::make_unique<TForeignKernel>(*family, argTypes, returnType, func));

Y_ENSURE(kernelFamilyMap.emplace(TString(name), std::move(family)).second);
}

void RegisterDefaultOperations(IBuiltinFunctionRegistry& registry, TKernelFamilyMap& kernelFamilyMap) {
RegisterAdd(registry);
RegisterAdd(kernelFamilyMap);
Expand Down
2 changes: 2 additions & 0 deletions ydb/library/yql/minikql/invoke_builtins/mkql_builtins_abs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ inline T Abs(T v) {

template<typename TInput, typename TOutput>
struct TAbs : public TSimpleArithmeticUnary<TInput, TOutput, TAbs<TInput, TOutput>> {
static constexpr auto NullMode = TKernel::ENullMode::Default;

static TOutput Do(TInput val)
{
return Abs<TInput>(val);
Expand Down
4 changes: 2 additions & 2 deletions ydb/library/yql/minikql/invoke_builtins/mkql_builtins_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace {

template<typename TLeft, typename TRight, typename TOutput>
struct TAdd : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TAdd<TLeft, TRight, TOutput>> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;

static TOutput Do(TOutput left, TOutput right)
{
Expand Down Expand Up @@ -193,7 +193,7 @@ void RegisterAdd(IBuiltinFunctionRegistry& registry) {
}

void RegisterAdd(TKernelFamilyMap& kernelFamilyMap) {
kernelFamilyMap["Add"] = std::make_unique<TBinaryNumericKernelFamily<TAdd>>();
kernelFamilyMap["Add"] = std::make_unique<TBinaryNumericKernelFamily<TAdd, TAdd>>();
}

void RegisterAggrAdd(IBuiltinFunctionRegistry& registry) {
Expand Down
8 changes: 5 additions & 3 deletions ydb/library/yql/minikql/invoke_builtins/mkql_builtins_div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ template<typename TLeft, typename TRight, typename TOutput>
struct TDiv : public TSimpleArithmeticBinary<TLeft, TRight, TOutput, TDiv<TLeft, TRight, TOutput>> {
static_assert(std::is_floating_point<TOutput>::value, "expected floating point");

static constexpr auto NullMode = TKernel::ENullMode::Default;

static TOutput Do(TOutput left, TOutput right)
{
return left / right;
Expand All @@ -29,7 +31,7 @@ template <typename TLeft, typename TRight, typename TOutput>
struct TIntegralDiv {
static_assert(std::is_integral<TOutput>::value, "integral type expected");

static constexpr bool DefaultNulls = false;
static constexpr auto NullMode = TKernel::ENullMode::AlwaysNull;

static NUdf::TUnboxedValuePod Execute(const NUdf::TUnboxedValuePod& left, const NUdf::TUnboxedValuePod& right)
{
Expand Down Expand Up @@ -60,7 +62,7 @@ struct TIntegralDiv {
const auto result = PHINode::Create(type, 2, "result", done);
result->addIncoming(zero, block);

if (std::is_signed<TOutput>() && sizeof(TOutput) <= sizeof(TLeft)) {
if constexpr (std::is_signed<TOutput>() && sizeof(TOutput) <= sizeof(TLeft)) {
const auto min = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, lv, ConstantInt::get(lv->getType(), Min<TOutput>()), "min", block);
const auto one = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rv, ConstantInt::get(rv->getType(), -1), "one", block);
const auto two = BinaryOperator::CreateAnd(min, one, "two", block);
Expand Down Expand Up @@ -167,7 +169,7 @@ void RegisterDiv(IBuiltinFunctionRegistry& registry) {
}

void RegisterDiv(TKernelFamilyMap& kernelFamilyMap) {
kernelFamilyMap["Div"] = std::make_unique<TBinaryNumericKernelFamily<TIntegralDiv>>(TKernelFamily::ENullMode::AlwaysNull);
kernelFamilyMap["Div"] = std::make_unique<TBinaryNumericKernelFamily<TIntegralDiv, TDiv>>();
}

} // namespace NMiniKQL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ struct TEqualsOp;

template<typename TLeft, typename TRight>
struct TEqualsOp<TLeft, TRight, bool> : public TEquals<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template<typename TLeft, typename TRight, bool Aggr>
Expand Down Expand Up @@ -190,7 +190,7 @@ struct TDiffDateEqualsOp;

template<typename TLeft, typename TRight>
struct TDiffDateEqualsOp<TLeft, TRight, NUdf::TDataType<bool>> : public TDiffDateEquals<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template <typename TLeft, typename TRight, bool Aggr>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ struct TGreaterOp;

template<typename TLeft, typename TRight>
struct TGreaterOp<TLeft, TRight, bool> : public TGreater<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template<typename TLeft, typename TRight, bool Aggr>
Expand Down Expand Up @@ -183,7 +183,7 @@ struct TDiffDateGreaterOp;

template<typename TLeft, typename TRight>
struct TDiffDateGreaterOp<TLeft, TRight, NUdf::TDataType<bool>> : public TDiffDateGreater<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template<typename TLeft, typename TRight, bool Aggr>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ struct TGreaterOrEqualOp;

template<typename TLeft, typename TRight>
struct TGreaterOrEqualOp<TLeft, TRight, bool> : public TGreaterOrEqual<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template<typename TLeft, typename TRight, bool Aggr>
Expand Down Expand Up @@ -183,7 +183,7 @@ struct TDiffDateGreaterOrEqualOp;

template<typename TLeft, typename TRight>
struct TDiffDateGreaterOrEqualOp<TLeft, TRight, NUdf::TDataType<bool>> : public TDiffDateGreaterOrEqual<TLeft, TRight, false> {
static constexpr bool DefaultNulls = true;
static constexpr auto NullMode = TKernel::ENullMode::Default;
};

template<typename TLeft, typename TRight, bool Aggr>
Expand Down
Loading

0 comments on commit 49e5063

Please sign in to comment.