Skip to content

Commit

Permalink
Replace uint32 with uint64 for cluster id in vector index (#14576)
Browse files Browse the repository at this point in the history
  • Loading branch information
MBkkt authored Feb 19, 2025
1 parent 78556e0 commit 02fadfa
Show file tree
Hide file tree
Showing 18 changed files with 123 additions and 108 deletions.
7 changes: 7 additions & 0 deletions ydb/core/base/table_index.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <ydb/public/api/protos/ydb_value.pb.h>
#include <ydb/public/lib/scheme_types/scheme_type_id.h>
#include <ydb/core/protos/flat_scheme_op.pb.h>

#include <util/generic/hash_set.h>
Expand Down Expand Up @@ -35,5 +37,10 @@ std::span<const std::string_view> GetImplTables(NKikimrSchemeOp::EIndexType inde
bool IsImplTable(std::string_view tableName);
bool IsBuildImplTable(std::string_view tableName);

using TClusterId = ui64;

inline constexpr auto ClusterIdType = Ydb::Type::UINT64;
inline constexpr const char* ClusterIdTypeName = "Uint64";

}
}
4 changes: 2 additions & 2 deletions ydb/core/kqp/opt/logical/kqp_opt_log_indexes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,10 @@ TExprBase DoRewriteTopSortOverKMeansTree(

// TODO(mbkkt) How can the construction of these constants be inlined into the construction of readLevel0?
auto fromValues = ctx.Builder(pos)
.Callable("Uint32").Atom(0, "0", TNodeFlags::Default).Seal()
.Callable(NTableIndex::ClusterIdTypeName).Atom(0, "0", TNodeFlags::Default).Seal()
.Build();
auto toValues = ctx.Builder(pos)
.Callable("Uint32").Atom(0, "1", TNodeFlags::Default).Seal()
.Callable(NTableIndex::ClusterIdTypeName).Atom(0, "1", TNodeFlags::Default).Seal()
.Build();

auto levelLambda = [&] {
Expand Down
10 changes: 5 additions & 5 deletions ydb/core/protos/tx_datashard.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1548,10 +1548,10 @@ message TEvLocalKMeansRequest {
optional uint32 NeedsRounds = 14;

// range of parent cluster ids (ParentFrom through ParentTo)
optional uint32 ParentFrom = 15;
optional uint32 ParentTo = 21;
optional uint64 ParentFrom = 15;
optional uint64 ParentTo = 21;
// [Child ... Child + K * (ParentTo - ParentFrom + 1)) ids reserved for the clusters of this kmeans
optional uint32 Child = 16;
optional uint64 Child = 16;

optional string LevelName = 17;
optional string PostingName = 18;
Expand Down Expand Up @@ -1599,9 +1599,9 @@ message TEvReshuffleKMeansRequest {
optional TEvLocalKMeansRequest.EState Upload = 9;

// id of parent cluster
optional uint32 Parent = 10;
optional uint64 Parent = 10;
// [Child ... Child + ClustersSize) ids of the clusters of this kmeans
optional uint32 Child = 11;
optional uint64 Child = 11;
// centroids of clusters
repeated string Clusters = 12;

Expand Down
11 changes: 6 additions & 5 deletions ydb/core/tx/datashard/datashard_ut_local_kmeans.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <ydb/core/base/table_index.h>
#include <ydb/core/testlib/test_client.h>
#include <ydb/core/tx/datashard/ut_common/datashard_ut_common.h>
#include <ydb/core/tx/schemeshard/schemeshard.h>
Expand Down Expand Up @@ -91,7 +92,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
}

static std::tuple<TString, TString> DoLocalKMeans(
Tests::TServer::TPtr server, TActorId sender, ui32 parent, ui64 seed, ui64 k,
Tests::TServer::TPtr server, TActorId sender, NTableIndex::TClusterId parent, ui64 seed, ui64 k,
NKikimrTxDataShard::TEvLocalKMeansRequest::EState upload, VectorIndexSettings::VectorType type,
VectorIndexSettings::Metric metric)
{
Expand Down Expand Up @@ -185,8 +186,8 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
{
options.AllowSystemColumnNames(true);
options.Columns({
{ParentColumn, "Uint32", true, true},
{IdColumn, "Uint32", true, true},
{ParentColumn, NTableIndex::ClusterIdTypeName, true, true},
{IdColumn, NTableIndex::ClusterIdTypeName, true, true},
{CentroidColumn, "String", false, true},
});
CreateShardedTable(server, sender, "/Root", "table-level", options);
Expand All @@ -196,7 +197,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
{
options.AllowSystemColumnNames(true);
options.Columns({
{ParentColumn, "Uint32", true, true},
{ParentColumn, NTableIndex::ClusterIdTypeName, true, true},
{"key", "Uint32", true, true},
{"data", "String", false, false},
});
Expand All @@ -208,7 +209,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardLocalKMeansScan) {
{
options.AllowSystemColumnNames(true);
options.Columns({
{ParentColumn, "Uint32", true, true},
{ParentColumn, NTableIndex::ClusterIdTypeName, true, true},
{"key", "Uint32", true, true},
{"embedding", "String", false, false},
{"data", "String", false, false},
Expand Down
7 changes: 4 additions & 3 deletions ydb/core/tx/datashard/datashard_ut_reshuffle_kmeans.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <ydb/core/base/table_index.h>
#include <ydb/core/testlib/test_client.h>
#include <ydb/core/tx/datashard/ut_common/datashard_ut_common.h>
#include <ydb/core/tx/schemeshard/schemeshard.h>
Expand Down Expand Up @@ -84,7 +85,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardReshuffleKMeansScan) {
}
}

static TString DoReshuffleKMeans(Tests::TServer::TPtr server, TActorId sender, ui32 parent,
static TString DoReshuffleKMeans(Tests::TServer::TPtr server, TActorId sender, NTableIndex::TClusterId parent,
const std::vector<TString>& level,
NKikimrTxDataShard::TEvLocalKMeansRequest::EState upload,
VectorIndexSettings::VectorType type, VectorIndexSettings::Metric metric)
Expand Down Expand Up @@ -171,7 +172,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardReshuffleKMeansScan) {
{
options.AllowSystemColumnNames(true);
options.Columns({
{ParentColumn, "Uint32", true, true},
{ParentColumn, NTableIndex::ClusterIdTypeName, true, true},
{"key", "Uint32", true, true},
{"data", "String", false, false},
});
Expand All @@ -183,7 +184,7 @@ Y_UNIT_TEST_SUITE (TTxDataShardReshuffleKMeansScan) {
{
options.AllowSystemColumnNames(true);
options.Columns({
{ParentColumn, "Uint32", true, true},
{ParentColumn, NTableIndex::ClusterIdTypeName, true, true},
{"key", "Uint32", true, true},
{"embedding", "String", false, false},
{"data", "String", false, false},
Expand Down
12 changes: 6 additions & 6 deletions ydb/core/tx/datashard/kmeans_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace NKikimr::NDataShard::NKMeans {

TTableRange CreateRangeFrom(const TUserTable& table, ui32 parent, TCell& from, TCell& to) {
TTableRange CreateRangeFrom(const TUserTable& table, NTableIndex::TClusterId parent, TCell& from, TCell& to) {
if (parent == 0) {
return table.GetTableRange();
}
Expand All @@ -28,15 +28,15 @@ NTable::TLead CreateLeadFrom(const TTableRange& range) {
return lead;
}

void AddRowMain2Build(TBufferData& buffer, ui32 parent, TArrayRef<const TCell> key, const NTable::TRowState& row) {
void AddRowMain2Build(TBufferData& buffer, NTableIndex::TClusterId parent, TArrayRef<const TCell> key, const NTable::TRowState& row) {
std::array<TCell, 1> cells;
cells[0] = TCell::Make(parent);
auto pk = TSerializedCellVec::Serialize(cells);
TSerializedCellVec::UnsafeAppendCells(key, pk);
buffer.AddRow(TSerializedCellVec{key}, TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize(*row));
}

void AddRowMain2Posting(TBufferData& buffer, ui32 parent, TArrayRef<const TCell> key, const NTable::TRowState& row,
void AddRowMain2Posting(TBufferData& buffer, NTableIndex::TClusterId parent, TArrayRef<const TCell> key, const NTable::TRowState& row,
ui32 dataPos)
{
std::array<TCell, 1> cells;
Expand All @@ -47,15 +47,15 @@ void AddRowMain2Posting(TBufferData& buffer, ui32 parent, TArrayRef<const TCell>
TSerializedCellVec::Serialize((*row).Slice(dataPos)));
}

void AddRowBuild2Build(TBufferData& buffer, ui32 parent, TArrayRef<const TCell> key, const NTable::TRowState& row) {
void AddRowBuild2Build(TBufferData& buffer, NTableIndex::TClusterId parent, TArrayRef<const TCell> key, const NTable::TRowState& row) {
std::array<TCell, 1> cells;
cells[0] = TCell::Make(parent);
auto pk = TSerializedCellVec::Serialize(cells);
TSerializedCellVec::UnsafeAppendCells(key.Slice(1), pk);
buffer.AddRow(TSerializedCellVec{key}, TSerializedCellVec{std::move(pk)}, TSerializedCellVec::Serialize(*row));
}

void AddRowBuild2Posting(TBufferData& buffer, ui32 parent, TArrayRef<const TCell> key, const NTable::TRowState& row,
void AddRowBuild2Posting(TBufferData& buffer, NTableIndex::TClusterId parent, TArrayRef<const TCell> key, const NTable::TRowState& row,
ui32 dataPos)
{
std::array<TCell, 1> cells;
Expand Down Expand Up @@ -96,7 +96,7 @@ MakeUploadTypes(const TUserTable& table, NKikimrTxDataShard::TEvLocalKMeansReque
uploadTypes->reserve(1 + 1 + std::min(table.KeyColumnTypes.size() + data.size(), types.size()));

Ydb::Type type;
type.set_type_id(Ydb::Type::UINT32);
type.set_type_id(NTableIndex::ClusterIdType);
uploadTypes->emplace_back(NTableIndex::NTableVectorKmeansTreeIndex::ParentColumn, type);

auto addType = [&](const auto& column) {
Expand Down
11 changes: 6 additions & 5 deletions ydb/core/tx/datashard/kmeans_helper.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <ydb/core/base/table_index.h>
#include <ydb/core/tx/datashard/buffer_data.h>
#include <ydb/core/tx/datashard/datashard_user_table.h>
#include <ydb/core/tx/datashard/range_ops.h>
Expand Down Expand Up @@ -48,7 +49,7 @@ Y_PURE_FUNCTION TTriWayDotProduct<TRes> CosineImpl(const ui8* lhs, const ui8* rh
return {static_cast<TRes>(ll), static_cast<TRes>(lr), static_cast<TRes>(rr)};
}

TTableRange CreateRangeFrom(const TUserTable& table, ui32 parent, TCell& from, TCell& to);
TTableRange CreateRangeFrom(const TUserTable& table, NTableIndex::TClusterId parent, TCell& from, TCell& to);

NTable::TLead CreateLeadFrom(const TTableRange& range);

Expand Down Expand Up @@ -200,14 +201,14 @@ ui32 FeedEmbedding(const TCalculation<TMetric>& calculation, std::span<const TSt
return calculation.FindClosest(clusters, embedding);
}

void AddRowMain2Build(TBufferData& buffer, ui32 parent, TArrayRef<const TCell> key, const NTable::TRowState& row);
void AddRowMain2Build(TBufferData& buffer, NTableIndex::TClusterId parent, TArrayRef<const TCell> key, const NTable::TRowState& row);

void AddRowMain2Posting(TBufferData& buffer, ui32 parent, TArrayRef<const TCell> key, const NTable::TRowState& row,
void AddRowMain2Posting(TBufferData& buffer, NTableIndex::TClusterId parent, TArrayRef<const TCell> key, const NTable::TRowState& row,
ui32 dataPos);

void AddRowBuild2Build(TBufferData& buffer, ui32 parent, TArrayRef<const TCell> key, const NTable::TRowState& row);
void AddRowBuild2Build(TBufferData& buffer, NTableIndex::TClusterId parent, TArrayRef<const TCell> key, const NTable::TRowState& row);

void AddRowBuild2Posting(TBufferData& buffer, ui32 parent, TArrayRef<const TCell> key, const NTable::TRowState& row,
void AddRowBuild2Posting(TBufferData& buffer, NTableIndex::TClusterId parent, TArrayRef<const TCell> key, const NTable::TRowState& row,
ui32 dataPos);

TTags MakeUploadTags(const TUserTable& table, const TProtoStringType& embedding,
Expand Down
10 changes: 5 additions & 5 deletions ydb/core/tx/datashard/local_kmeans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ class TLocalKMeansScanBase: public TActor<TLocalKMeansScanBase>, public NTable::
protected:
using EState = NKikimrTxDataShard::TEvLocalKMeansRequest;

ui32 Parent = 0;
ui32 Child = 0;
NTableIndex::TClusterId Parent = 0;
NTableIndex::TClusterId Child = 0;

ui32 Round = 0;
ui32 MaxRounds = 0;
Expand Down Expand Up @@ -156,7 +156,7 @@ class TLocalKMeansScanBase: public TActor<TLocalKMeansScanBase>, public NTable::
return NKikimrServices::TActivity::LOCAL_KMEANS_SCAN_ACTOR;
}

TLocalKMeansScanBase(ui64 buildId, const TUserTable& table, TLead&& lead, ui32 parent, ui32 child,
TLocalKMeansScanBase(ui64 buildId, const TUserTable& table, TLead&& lead, NTableIndex::TClusterId parent, NTableIndex::TClusterId child,
const NKikimrTxDataShard::TEvLocalKMeansRequest& request,
std::shared_ptr<TResult> result)
: TActor{&TThis::StateWork}
Expand All @@ -180,7 +180,7 @@ class TLocalKMeansScanBase: public TActor<TLocalKMeansScanBase>, public NTable::
// upload types
if (Ydb::Type type; State <= EState::KMEANS) {
TargetTypes = std::make_shared<NTxProxy::TUploadTypes>(3);
type.set_type_id(Ydb::Type::UINT32);
type.set_type_id(NTableIndex::ClusterIdType);
(*TargetTypes)[0] = {NTableIndex::NTableVectorKmeansTreeIndex::ParentColumn, type};
(*TargetTypes)[1] = {NTableIndex::NTableVectorKmeansTreeIndex::IdColumn, type};
type.set_type_id(Ydb::Type::STRING);
Expand Down Expand Up @@ -382,7 +382,7 @@ class TLocalKMeansScan final: public TLocalKMeansScanBase, private TCalculation<
std::vector<TAggregatedCluster> AggregatedClusters;

public:
TLocalKMeansScan(ui64 buildId, const TUserTable& table, TLead&& lead, ui32 parent, ui32 child, NKikimrTxDataShard::TEvLocalKMeansRequest& request,
TLocalKMeansScan(ui64 buildId, const TUserTable& table, TLead&& lead, NTableIndex::TClusterId parent, NTableIndex::TClusterId child, NKikimrTxDataShard::TEvLocalKMeansRequest& request,
std::shared_ptr<TResult> result)
: TLocalKMeansScanBase{buildId, table, std::move(lead), parent, child, request, std::move(result)}
{
Expand Down
4 changes: 2 additions & 2 deletions ydb/core/tx/datashard/reshuffle_kmeans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class TReshuffleKMeansScanBase: public TActor<TReshuffleKMeansScanBase>, public
protected:
using EState = NKikimrTxDataShard::TEvLocalKMeansRequest;

ui32 Parent = 0;
ui32 Child = 0;
NTableIndex::TClusterId Parent = 0;
NTableIndex::TClusterId Child = 0;

ui32 K = 0;

Expand Down
10 changes: 5 additions & 5 deletions ydb/core/tx/schemeshard/schemeshard__init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4523,20 +4523,20 @@ struct TSchemeShard::TTxInit : public TTransactionBase<TSchemeShard> {

// read kmeans tree state
{
auto rowset = db.Table<Schema::KMeansTreeState>().Range().Select();
auto rowset = db.Table<Schema::KMeansTreeProgress>().Range().Select();
if (!rowset.IsReady()) {
return false;
}

while (!rowset.EndOfSet()) {
TIndexBuildId id = rowset.GetValue<Schema::KMeansTreeState::Id>();
TIndexBuildId id = rowset.GetValue<Schema::KMeansTreeProgress::Id>();
const auto* buildInfoPtr = Self->IndexBuilds.FindPtr(id);
Y_VERIFY_S(buildInfoPtr, "BuildIndex not found: id# " << id);
auto& buildInfo = *buildInfoPtr->Get();
buildInfo.KMeans.Set(
rowset.GetValue<Schema::KMeansTreeState::Level>(),
rowset.GetValue<Schema::KMeansTreeState::Parent>(),
rowset.GetValue<Schema::KMeansTreeState::State>()
rowset.GetValue<Schema::KMeansTreeProgress::Level>(),
rowset.GetValue<Schema::KMeansTreeProgress::Parent>(),
rowset.GetValue<Schema::KMeansTreeProgress::State>()
);
buildInfo.Sample.Rows.reserve(buildInfo.KMeans.K * 2);

Expand Down
2 changes: 1 addition & 1 deletion ydb/core/tx/schemeshard/schemeshard_build_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ void TSchemeShard::PersistBuildIndexForget(NIceDb::TNiceDb& db, const TIndexBuil
}

if (info.IsBuildVectorIndex()) {
db.Table<Schema::KMeansTreeState>().Key(info.Id).Delete();
db.Table<Schema::KMeansTreeProgress>().Key(info.Id).Delete();
PersistBuildIndexSampleForget(db, info);
}
}
Expand Down
26 changes: 13 additions & 13 deletions ydb/core/tx/schemeshard/schemeshard_build_index__progress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ static constexpr const char* Name(TIndexBuildInfo::EState state) noexcept {
}

// return count, parts, step
static std::tuple<ui32, ui32, ui32> ComputeKMeansBoundaries(const NSchemeShard::TTableInfo& tableInfo, const TIndexBuildInfo& buildInfo) {
static std::tuple<NTableIndex::TClusterId, NTableIndex::TClusterId, NTableIndex::TClusterId> ComputeKMeansBoundaries(const NSchemeShard::TTableInfo& tableInfo, const TIndexBuildInfo& buildInfo) {
const auto& kmeans = buildInfo.KMeans;
Y_ASSERT(kmeans.K != 0);
const auto count = TIndexBuildInfo::TKMeans::BinPow(kmeans.K, kmeans.Level);
ui32 step = 1;
NTableIndex::TClusterId step = 1;
auto parts = count;
auto shards = tableInfo.GetShard2PartitionIdx().size();
if (!buildInfo.KMeans.NeedsAnotherLevel() || count <= 1 || shards <= 1) {
Expand Down Expand Up @@ -97,8 +97,8 @@ class TUploadSampleK: public TActorBootstrapped<TUploadSampleK> {
TActorId Uploader;
ui32 RetryCount = 0;
ui32 RowsBytes = 0;
ui32 Parent = 0;
ui32 Child = 0;
NTableIndex::TClusterId Parent = 0;
NTableIndex::TClusterId Child = 0;

NDataShard::TUploadStatus UploadStatus;

Expand All @@ -108,8 +108,8 @@ class TUploadSampleK: public TActorBootstrapped<TUploadSampleK> {
const TActorId& responseActorId,
ui64 buildIndexId,
TIndexBuildInfo::TSample::TRows init,
ui32 parent,
ui32 child)
NTableIndex::TClusterId parent,
NTableIndex::TClusterId child)
: TargetTable(std::move(targetTable))
, ResponseActorId(responseActorId)
, BuildIndexId(buildIndexId)
Expand Down Expand Up @@ -159,7 +159,7 @@ class TUploadSampleK: public TActorBootstrapped<TUploadSampleK> {

Types = std::make_shared<NTxProxy::TUploadTypes>(3);
Ydb::Type type;
type.set_type_id(Ydb::Type::UINT32);
type.set_type_id(NTableIndex::ClusterIdType);
(*Types)[0] = {NTableIndex::NTableVectorKmeansTreeIndex::ParentColumn, type};
(*Types)[1] = {NTableIndex::NTableVectorKmeansTreeIndex::IdColumn, type};
type.set_type_id(Ydb::Type::STRING);
Expand Down Expand Up @@ -766,7 +766,7 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil
InitMultiKMeans(buildInfo);
return false;
}
std::array<NScheme::TTypeInfo, 1> typeInfos{NScheme::NTypeIds::Uint32};
std::array<NScheme::TTypeInfo, 1> typeInfos{ClusterIdTypeId};
auto range = ParentRange(buildInfo.KMeans.Parent);
auto addRestricted = [&] (const auto& idx) {
const auto& status = buildInfo.Shards.at(idx);
Expand Down Expand Up @@ -858,10 +858,10 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil

void PersistKMeansState(TTransactionContext& txc, TIndexBuildInfo& buildInfo) {
NIceDb::TNiceDb db{txc.DB};
db.Table<Schema::KMeansTreeState>().Key(buildInfo.Id).Update(
NIceDb::TUpdate<Schema::KMeansTreeState::Level>(buildInfo.KMeans.Level),
NIceDb::TUpdate<Schema::KMeansTreeState::Parent>(buildInfo.KMeans.Parent),
NIceDb::TUpdate<Schema::KMeansTreeState::State>(buildInfo.KMeans.State)
db.Table<Schema::KMeansTreeProgress>().Key(buildInfo.Id).Update(
NIceDb::TUpdate<Schema::KMeansTreeProgress::Level>(buildInfo.KMeans.Level),
NIceDb::TUpdate<Schema::KMeansTreeProgress::State>(buildInfo.KMeans.State),
NIceDb::TUpdate<Schema::KMeansTreeProgress::Parent>(buildInfo.KMeans.Parent)
);
}

Expand Down Expand Up @@ -1184,7 +1184,7 @@ struct TSchemeShard::TIndexBuilder::TTxProgress: public TSchemeShard::TIndexBuil
return TSerializedTableRange(TSerializedCellVec::Serialize(cells), "", true, false);
}

static TSerializedTableRange ParentRange(ui32 parent) {
static TSerializedTableRange ParentRange(NTableIndex::TClusterId parent) {
if (parent == 0) {
return {}; // empty
}
Expand Down
Loading

0 comments on commit 02fadfa

Please sign in to comment.