Skip to content

Commit

Permalink
Merge pull request ydb-platform#9 from MBkkt/BitIndex
Browse files Browse the repository at this point in the history
Better
  • Loading branch information
azevaykin authored May 31, 2024
2 parents 59805d3 + b412c9e commit c80646b
Show file tree
Hide file tree
Showing 22 changed files with 416 additions and 179 deletions.
20 changes: 13 additions & 7 deletions ydb/library/yql/udfs/common/knn/knn-defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,34 @@

enum EFormat: ui8 {
FloatVector = 1, // 4-byte per element
Uint8Vector = 2, // 1-byte per element
BitVector = 10, // 1-bit per element
Int8Vector = 2, // 1-byte per element
Uint8Vector = 3, // 1-byte per element, better than Int8 for positive-only Float
BitVector = 4, // 1-bit per element
};

template<typename T>
template <typename T>
struct TTypeToFormat;

template<>
template <>
struct TTypeToFormat<float> {
static constexpr auto Format = EFormat::FloatVector;
};

template<>
template <>
struct TTypeToFormat<i8> {
static constexpr auto Format = EFormat::Int8Vector;
};

template <>
struct TTypeToFormat<ui8> {
static constexpr auto Format = EFormat::Uint8Vector;
};

template<>
template <>
struct TTypeToFormat<bool> {
static constexpr auto Format = EFormat::BitVector;
};

template<typename T>
template <typename T>
inline constexpr auto Format = TTypeToFormat<T>::Format;
inline constexpr auto HeaderLen = sizeof(ui8);
24 changes: 22 additions & 2 deletions ydb/library/yql/udfs/common/knn/knn-distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ inline TDistanceResult VectorFuncImpl(const auto* v1, const auto* v2, auto len1,

template <typename T, typename Func>
inline auto VectorFunc(const TStringRef& str1, const TStringRef& str2, Func&& func) {
const TArrayRef<const T> v1 = TKnnSerializerFacade::GetArray<T>(str1);
const TArrayRef<const T> v2 = TKnnSerializerFacade::GetArray<T>(str2);
const TArrayRef<const T> v1 = TKnnVectorSerializer<T>::GetArray(str1);
const TArrayRef<const T> v2 = TKnnVectorSerializer<T>::GetArray(str2);
return VectorFuncImpl(v1.data(), v2.data(), v1.size(), v2.size(), std::forward<Func>(func));
}

Expand All @@ -97,6 +97,10 @@ inline TDistanceResult KnnManhattanDistance(const TStringRef& str1, const TStrin
return VectorFunc<float>(str1, str2, [](const float* v1, const float* v2, size_t len) {
return ::L1Distance(v1, v2, len);
});
case EFormat::Int8Vector:
return VectorFunc<i8>(str1, str2, [](const i8* v1, const i8* v2, size_t len) {
return ::L1Distance(v1, v2, len);
});
case EFormat::Uint8Vector:
return VectorFunc<ui8>(str1, str2, [](const ui8* v1, const ui8* v2, size_t len) {
return ::L1Distance(v1, v2, len);
Expand Down Expand Up @@ -125,6 +129,10 @@ inline TDistanceResult KnnEuclideanDistance(const TStringRef& str1, const TStrin
return VectorFunc<float>(str1, str2, [](const float* v1, const float* v2, size_t len) {
return ::L2Distance(v1, v2, len);
});
case EFormat::Int8Vector:
return VectorFunc<i8>(str1, str2, [](const i8* v1, const i8* v2, size_t len) {
return ::L2Distance(v1, v2, len);
});
case EFormat::Uint8Vector:
return VectorFunc<ui8>(str1, str2, [](const ui8* v1, const ui8* v2, size_t len) {
return ::L2Distance(v1, v2, len);
Expand Down Expand Up @@ -153,6 +161,10 @@ inline TDistanceResult KnnDotProduct(const TStringRef& str1, const TStringRef& s
return VectorFunc<float>(str1, str2, [](const float* v1, const float* v2, size_t len) {
return ::DotProduct(v1, v2, len);
});
case EFormat::Int8Vector:
return VectorFunc<i8>(str1, str2, [](const i8* v1, const i8* v2, size_t len) {
return ::DotProduct(v1, v2, len);
});
case EFormat::Uint8Vector:
return VectorFunc<ui8>(str1, str2, [](const ui8* v1, const ui8* v2, size_t len) {
return ::DotProduct(v1, v2, len);
Expand Down Expand Up @@ -188,6 +200,14 @@ inline TDistanceResult KnnCosineSimilarity(const TStringRef& str1, const TString
const auto res = ::TriWayDotProduct(v1, v2, len);
return compute(res.LL, res.LR, res.RR);
});
case EFormat::Int8Vector:
return VectorFunc<i8>(str1, str2, [&](const i8* v1, const i8* v2, size_t len) {
// TODO We can optimize it if we will iterate over both vector at the same time, look to the float implementation
const i64 ll = ::DotProduct(v1, v1, len);
const i64 lr = ::DotProduct(v1, v2, len);
const i64 rr = ::DotProduct(v2, v2, len);
return compute(ll, lr, rr);
});
case EFormat::Uint8Vector:
return VectorFunc<ui8>(str1, str2, [&](const ui8* v1, const ui8* v2, size_t len) {
// TODO We can optimize it if we will iterate over both vector at the same time, look to the float implementation
Expand Down
6 changes: 3 additions & 3 deletions ydb/library/yql/udfs/common/knn/knn-enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
using namespace NYql;
using namespace NYql::NUdf;

template <typename TCallback>
template <typename T, typename TCallback>
void EnumerateVector(const TUnboxedValuePod vector, TCallback&& callback) {
const auto* elements = vector.GetElements();
if (elements) {
for (auto& value : TArrayRef{elements, vector.GetListLength()}) {
callback(value.Get<float>());
callback(value.Get<T>());
}
} else {
TUnboxedValue value;
const auto it = vector.GetListIterator();
while (it.Next(value)) {
callback(value.Get<float>());
callback(value.Get<T>());
}
}
}
45 changes: 15 additions & 30 deletions ydb/library/yql/udfs/common/knn/knn-serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@
using namespace NYql;
using namespace NYql::NUdf;

template <typename T>
template <typename TTo, typename TFrom = TTo>
class TKnnVectorSerializer {
public:
static TUnboxedValue Serialize(const IValueBuilder* valueBuilder, const TUnboxedValue x) {
auto serialize = [&](IOutputStream& outStream) {
EnumerateVector(x, [&](float from) {
T to = static_cast<T>(from);
outStream.Write(&to, sizeof(T));
EnumerateVector<TFrom>(x, [&](TFrom from) {
TTo to = static_cast<TTo>(from);
outStream.Write(&to, sizeof(TTo));
});
const auto format = Format<T>;
const auto format = Format<TTo>;
outStream.Write(&format, HeaderLen);
};

if (x.HasFastListLength()) {
auto str = valueBuilder->NewStringNotFilled(x.GetListLength() * sizeof(T) + HeaderLen);
auto str = valueBuilder->NewStringNotFilled(x.GetListLength() * sizeof(TTo) + HeaderLen);
auto strRef = str.AsStringRef();
TMemoryOutput memoryOutput(strRef.Data(), strRef.Size());

Expand All @@ -50,22 +50,22 @@ class TKnnVectorSerializer {
auto res = valueBuilder->NewArray(vector.size(), items);

for (auto element : vector) {
*items++ = TUnboxedValuePod{static_cast<float>(element)};
*items++ = TUnboxedValuePod{static_cast<TFrom>(element)};
}

return res.Release();
}

static TArrayRef<const T> GetArray(const TStringRef& str) {
static TArrayRef<const TTo> GetArray(const TStringRef& str) {
const char* buf = str.Data();
const size_t len = str.Size() - HeaderLen;

if (Y_UNLIKELY(len % sizeof(T) != 0))
if (Y_UNLIKELY(len % sizeof(TTo) != 0))
return {};

const ui32 count = len / sizeof(T);
const ui32 count = len / sizeof(TTo);

return {reinterpret_cast<const T*>(buf), count};
return {reinterpret_cast<const TTo*>(buf), count};
}
};

Expand All @@ -81,7 +81,7 @@ class TKnnBitVectorSerializer {
ui64 accumulator = 0;
ui8 filledBits = 0;

EnumerateVector(x, [&](float element) {
EnumerateVector<float>(x, [&](float element) {
if (element > 0)
accumulator |= 1;

Expand Down Expand Up @@ -161,25 +161,10 @@ class TKnnSerializerFacade {
switch (format) {
case EFormat::FloatVector:
return TKnnVectorSerializer<float>::Deserialize(valueBuilder, str);
case EFormat::Int8Vector:
return TKnnVectorSerializer<i8, float>::Deserialize(valueBuilder, str);
case EFormat::Uint8Vector:
return TKnnVectorSerializer<ui8>::Deserialize(valueBuilder, str);
case EFormat::BitVector:
default:
return {};
}
}

template <typename T>
static const TArrayRef<const T> GetArray(const TStringRef& str) {
if (Y_UNLIKELY(str.Size() == 0))
return {};

const ui8 format = str.Data()[str.Size() - HeaderLen];
switch (format) {
case EFormat::FloatVector:
return TKnnVectorSerializer<T>::GetArray(str);
case EFormat::Uint8Vector:
return TKnnVectorSerializer<T>::GetArray(str);
return TKnnVectorSerializer<ui8, float>::Deserialize(valueBuilder, str);
case EFormat::BitVector:
default:
return {};
Expand Down
31 changes: 21 additions & 10 deletions ydb/library/yql/udfs/common/knn/knn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@ static constexpr const char TagStoredVector[] = "StoredVector";

static constexpr const char TagFloatVector[] = "FloatVector";
using TFloatVector = TTagged<const char*, TagFloatVector>;
static constexpr const char TagByteVector[] = "ByteVector";
using TByteVector = TTagged<const char*, TagByteVector>;
static constexpr const char TagInt8Vector[] = "Int8Vector";
using TInt8Vector = TTagged<const char*, TagInt8Vector>;
static constexpr const char TagUint8Vector[] = "Uint8Vector";
using TUint8Vector = TTagged<const char*, TagUint8Vector>;
static constexpr const char TagBitVector[] = "BitVector";
using TBitVector = TTagged<const char*, TagBitVector>;

SIMPLE_STRICT_UDF(TToBinaryStringFloat, TFloatVector(TAutoMap<TListType<float>>)) {
return TKnnVectorSerializer<float>::Serialize(valueBuilder, args[0]);
}

SIMPLE_STRICT_UDF(TToBinaryStringByte, TByteVector(TAutoMap<TListType<float>>)) {
SIMPLE_STRICT_UDF(TToBinaryStringInt8, TInt8Vector(TAutoMap<TListType<i8>>)) {
return TKnnVectorSerializer<i8>::Serialize(valueBuilder, args[0]);
}

SIMPLE_STRICT_UDF(TToBinaryStringUint8, TUint8Vector(TAutoMap<TListType<ui8>>)) {
return TKnnVectorSerializer<ui8>::Serialize(valueBuilder, args[0]);
}

Expand Down Expand Up @@ -121,14 +127,18 @@ class TFloatFromBinaryString: public TMultiSignatureBase<TFloatFromBinaryString>

auto argType = argsTuple.GetElementType(0);
auto argTag = GetArg(*typeInfoHelper, argType, builder);
if (!ValidTag(argTag, {TagStoredVector, TagFloatVector, TagByteVector})) {
builder.SetError("Expected argument is string from ToBinaryString[Float|Byte]");
if (!ValidTag(argTag, {TagStoredVector, TagFloatVector, TagInt8Vector, TagUint8Vector})) {
builder.SetError("Expected argument is string from ToBinaryString[Float|Int8|Uint8]");
return true;
}

builder.UserType(userType);
builder.Args(1)->Add(argType).Flags(ICallablePayload::TArgumentFlags::AutoMap);
builder.Returns<TOptional<TListType<float>>>().IsStrict();
if (ValidTag(argTag, {TagFloatVector, TagInt8Vector, TagUint8Vector}) && argType == argsTuple.GetElementType(0)) {
builder.Returns<TListType<float>>().IsStrict();
} else {
builder.Returns<TOptional<TListType<float>>>().IsStrict();
}

if (!typesOnly) {
builder.Implementation(new TFloatFromBinaryString(builder));
Expand Down Expand Up @@ -166,9 +176,9 @@ class TDistanceBase: public TMultiSignatureBase<Derived> {
auto arg1Type = argsTuple.GetElementType(1);
auto arg1Tag = Base::GetArg(*typeInfoHelper, arg1Type, builder);

if (!Base::ValidTag(arg0Tag, {TagStoredVector, TagFloatVector, TagByteVector, TagBitVector}) ||
!Base::ValidTag(arg1Tag, {TagStoredVector, TagFloatVector, TagByteVector, TagBitVector})) {
builder.SetError("Expected arguments are strings from ToBinaryString[Float|Byte|Bit]");
if (!Base::ValidTag(arg0Tag, {TagStoredVector, TagFloatVector, TagInt8Vector, TagUint8Vector, TagBitVector}) ||
!Base::ValidTag(arg1Tag, {TagStoredVector, TagFloatVector, TagInt8Vector, TagUint8Vector, TagBitVector})) {
builder.SetError("Expected arguments are strings from ToBinaryString[Float|Int8|Uint8|Bit]");
return true;
}

Expand Down Expand Up @@ -282,7 +292,8 @@ class TEuclideanDistance: public TDistanceBase<TEuclideanDistance> {

SIMPLE_MODULE(TKnnModule,
TToBinaryStringFloat,
TToBinaryStringByte,
TToBinaryStringInt8,
TToBinaryStringUint8,
TToBinaryStringBit,
TFloatFromBinaryString,
TInnerProductSimilarity,
Expand Down
5 changes: 5 additions & 0 deletions ydb/library/yql/udfs/common/knn/test/canondata/result.json
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@
"uri": "file://test.test_InnerProductSimilarity_/results.txt"
}
],
"test.test[Int8Serialization]": [
{
"uri": "file://test.test_Int8Serialization_/results.txt"
}
],
"test.test[LazyListSerialization]": [
{
"uri": "file://test.test_LazyListSerialization_/results.txt"
Expand Down
Loading

0 comments on commit c80646b

Please sign in to comment.