Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

speed up filters construction #7934

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 47 additions & 47 deletions ydb/core/formats/arrow/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class TConstFunction : public IStepFunction<TAssign> {
using TBase = IStepFunction<TAssign>;
public:
using TBase::TBase;
arrow::Result<arrow::Datum> Call(const TAssign& assign, const TDatumBatch& batch) const override {
arrow::Result<arrow::Datum> Call(const TAssign& assign, const TDatumBatch& batch) const override {
Y_UNUSED(batch);
return assign.GetConstant();
}
Expand Down Expand Up @@ -531,7 +531,7 @@ class TFilterVisitor : public arrow::ArrayVisitor {


arrow::Status TDatumBatch::AddColumn(const std::string& name, arrow::Datum&& column) {
if (Schema->GetFieldIndex(name) != -1) {
if (HasColumn(name)) {
return arrow::Status::Invalid("Trying to add duplicate column '" + name + "'");
}

Expand All @@ -543,20 +543,27 @@ arrow::Status TDatumBatch::AddColumn(const std::string& name, arrow::Datum&& col
return arrow::Status::Invalid("Wrong column length.");
}

Schema = *Schema->AddField(Schema->num_fields(), field);
NewColumnIds.emplace(name, NewColumnsPtr.size());
NewColumnsPtr.emplace_back(field);

Datums.emplace_back(column);
return arrow::Status::OK();
}

arrow::Result<arrow::Datum> TDatumBatch::GetColumnByName(const std::string& name) const {
auto i = Schema->GetFieldIndex(name);
auto it = NewColumnIds.find(name);
if (it != NewColumnIds.end()) {
AFL_VERIFY(SchemaBase->num_fields() + it->second < Datums.size());
return Datums[SchemaBase->num_fields() + it->second];
}
auto i = SchemaBase->GetFieldIndex(name);
if (i < 0) {
return arrow::Status::Invalid("Not found column '" + name + "' or duplicate");
}
return Datums[i];
}

std::shared_ptr<arrow::Table> TDatumBatch::ToTable() const {
std::shared_ptr<arrow::Table> TDatumBatch::ToTable() {
std::vector<std::shared_ptr<arrow::ChunkedArray>> columns;
columns.reserve(Datums.size());
for (auto col : Datums) {
Expand All @@ -576,10 +583,10 @@ std::shared_ptr<arrow::Table> TDatumBatch::ToTable() const {
AFL_VERIFY(false);
}
}
return arrow::Table::Make(Schema, columns, Rows);
return arrow::Table::Make(GetSchema(), columns, Rows);
}

std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch() const {
std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch() {
std::vector<std::shared_ptr<arrow::Array>> columns;
columns.reserve(Datums.size());
for (auto col : Datums) {
Expand All @@ -594,7 +601,7 @@ std::shared_ptr<arrow::RecordBatch> TDatumBatch::ToRecordBatch() const {
AFL_VERIFY(false);
}
}
return arrow::RecordBatch::Make(Schema, Rows, columns);
return arrow::RecordBatch::Make(GetSchema(), Rows, columns);
}

std::shared_ptr<TDatumBatch> TDatumBatch::FromRecordBatch(const std::shared_ptr<arrow::RecordBatch>& batch) {
Expand All @@ -603,12 +610,7 @@ std::shared_ptr<TDatumBatch> TDatumBatch::FromRecordBatch(const std::shared_ptr<
for (int64_t i = 0; i < batch->num_columns(); ++i) {
datums.push_back(arrow::Datum(batch->column(i)));
}
return std::make_shared<TProgramStep::TDatumBatch>(
TProgramStep::TDatumBatch{
.Schema = std::make_shared<arrow::Schema>(*batch->schema()),
.Datums = std::move(datums),
.Rows = batch->num_rows()
});
return std::make_shared<TDatumBatch>(std::make_shared<arrow::Schema>(*batch->schema()), std::move(datums), batch->num_rows());
}

std::shared_ptr<TDatumBatch> TDatumBatch::FromTable(const std::shared_ptr<arrow::Table>& batch) {
Expand All @@ -617,12 +619,15 @@ std::shared_ptr<TDatumBatch> TDatumBatch::FromTable(const std::shared_ptr<arrow:
for (int64_t i = 0; i < batch->num_columns(); ++i) {
datums.push_back(arrow::Datum(batch->column(i)));
}
return std::make_shared<TProgramStep::TDatumBatch>(
TProgramStep::TDatumBatch{
.Schema = std::make_shared<arrow::Schema>(*batch->schema()),
.Datums = std::move(datums),
.Rows = batch->num_rows()
});
return std::make_shared<TDatumBatch>(std::make_shared<arrow::Schema>(*batch->schema()), std::move(datums), batch->num_rows());
}

// Builds a batch over an already materialized set of columns.
//   schema - base schema of the batch; must be non-null (verified below).
//   datums - column values, moved into the batch; exactly one per field of `schema` (verified below).
//   rows   - stored as the batch row count; it is not cross-checked against the datums here.
TDatumBatch::TDatumBatch(const std::shared_ptr<arrow::Schema>& schema, std::vector<arrow::Datum>&& datums, const i64 rows)
: SchemaBase(schema)
, Rows(rows)
, Datums(std::move(datums)) {
// The null check must come first: the size check below dereferences SchemaBase.
AFL_VERIFY(SchemaBase);
// Every schema field needs a matching datum; num_fields() is cast to unsigned for the comparison with size().
AFL_VERIFY(Datums.size() == (ui32)SchemaBase->num_fields());
}

TAssign TAssign::MakeTimestamp(const TColumnInfo& column, ui64 value) {
Expand Down Expand Up @@ -680,7 +685,7 @@ arrow::Status TProgramStep::ApplyAssignes(TDatumBatch& batch, arrow::compute::Ex
}
batch.Datums.reserve(batch.Datums.size() + Assignes.size());
for (auto& assign : Assignes) {
if (batch.GetColumnByName(assign.GetName()).ok()) {
if (batch.HasColumn(assign.GetName())) {
return arrow::Status::Invalid("Assign to existing column '" + assign.GetName() + "'.");
}

Expand All @@ -703,8 +708,9 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
}

ui32 numResultColumns = GroupBy.size() + GroupByKeys.size();
TDatumBatch res;
res.Datums.reserve(numResultColumns);
std::vector<arrow::Datum> datums;
datums.reserve(numResultColumns);
std::optional<ui32> resultRecordsCount;

arrow::FieldVector fields;
fields.reserve(numResultColumns);
Expand All @@ -715,13 +721,13 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
if (!funcResult.ok()) {
return funcResult.status();
}
res.Datums.push_back(*funcResult);
fields.emplace_back(std::make_shared<arrow::Field>(assign.GetName(), res.Datums.back().type()));
datums.push_back(*funcResult);
fields.emplace_back(std::make_shared<arrow::Field>(assign.GetName(), datums.back().type()));
}
res.Rows = 1;
resultRecordsCount = 1;
} else {
CH::GroupByOptions funcOpts;
funcOpts.schema = batch.Schema;
funcOpts.schema = batch.GetSchema();
funcOpts.assigns.reserve(numResultColumns);
funcOpts.has_nullable_key = false;

Expand Down Expand Up @@ -759,19 +765,18 @@ arrow::Status TProgramStep::ApplyAggregates(TDatumBatch& batch, arrow::compute::
return arrow::Status::Invalid("No expected column in GROUP BY result.");
}
fields.emplace_back(std::make_shared<arrow::Field>(assign.result_column, column->type()));
res.Datums.push_back(column);
datums.push_back(column);
}

res.Rows = gbBatch->num_rows();
resultRecordsCount = gbBatch->num_rows();
}

res.Schema = std::make_shared<arrow::Schema>(std::move(fields));
batch = std::move(res);
AFL_VERIFY(resultRecordsCount);
batch = TDatumBatch(std::make_shared<arrow::Schema>(std::move(fields)), std::move(datums), *resultRecordsCount);
return arrow::Status::OK();
}

arrow::Status TProgramStep::MakeCombinedFilter(TDatumBatch& batch, NArrow::TColumnFilter& result) const {
TFilterVisitor filterVisitor(batch.Rows);
TFilterVisitor filterVisitor(batch.GetRecordsCount());
for (auto& colName : Filters) {
auto column = batch.GetColumnByName(colName.GetColumnName());
if (!column.ok()) {
Expand Down Expand Up @@ -821,13 +826,13 @@ arrow::Status TProgramStep::ApplyFilters(TDatumBatch& batch) const {
}
}
std::vector<arrow::Datum*> filterDatums;
for (int64_t i = 0; i < batch.Schema->num_fields(); ++i) {
if (batch.Datums[i].is_arraylike() && (allColumns || neededColumns.contains(batch.Schema->field(i)->name()))) {
for (int64_t i = 0; i < batch.GetSchema()->num_fields(); ++i) {
if (batch.Datums[i].is_arraylike() && (allColumns || neededColumns.contains(batch.GetSchema()->field(i)->name()))) {
filterDatums.emplace_back(&batch.Datums[i]);
}
}
bits.Apply(batch.Rows, filterDatums);
batch.Rows = bits.GetFilteredCount().value_or(batch.Rows);
bits.Apply(batch.GetRecordsCount(), filterDatums);
batch.SetRecordsCount(bits.GetFilteredCount().value_or(batch.GetRecordsCount()));
return arrow::Status::OK();
}

Expand All @@ -838,15 +843,14 @@ arrow::Status TProgramStep::ApplyProjection(TDatumBatch& batch) const {
std::vector<std::shared_ptr<arrow::Field>> newFields;
std::vector<arrow::Datum> newDatums;
for (size_t i = 0; i < Projection.size(); ++i) {
int schemaFieldIndex = batch.Schema->GetFieldIndex(Projection[i].GetColumnName());
int schemaFieldIndex = batch.GetSchema()->GetFieldIndex(Projection[i].GetColumnName());
if (schemaFieldIndex == -1) {
return arrow::Status::Invalid("Could not find column " + Projection[i].GetColumnName() + " in record batch schema.");
}
newFields.push_back(batch.Schema->field(schemaFieldIndex));
newFields.push_back(batch.GetSchema()->field(schemaFieldIndex));
newDatums.push_back(batch.Datums[schemaFieldIndex]);
}
batch.Schema = std::make_shared<arrow::Schema>(std::move(newFields));
batch.Datums = std::move(newDatums);
batch = TDatumBatch(std::make_shared<arrow::Schema>(std::move(newFields)), std::move(newDatums), batch.GetRecordsCount());
return arrow::Status::OK();
}

Expand Down Expand Up @@ -919,14 +923,10 @@ std::set<std::string> TProgramStep::GetColumnsInUsage(const bool originalOnly/*
}

arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter(const std::shared_ptr<NArrow::TGeneralContainer>& t) const {
return BuildFilter(t->BuildTableVerified(GetColumnsInUsage(true)));
}

arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter(const std::shared_ptr<arrow::Table>& t) const {
if (Filters.empty()) {
return nullptr;
}
std::vector<std::shared_ptr<arrow::RecordBatch>> batches = NArrow::SliceToRecordBatches(t);
std::vector<std::shared_ptr<arrow::RecordBatch>> batches = NArrow::SliceToRecordBatches(t->BuildTableVerified(GetColumnsInUsage(true)));
NArrow::TColumnFilter fullLocal = NArrow::TColumnFilter::BuildAllowFilter();
for (auto&& rb : batches) {
auto datumBatch = TDatumBatch::FromRecordBatch(rb);
Expand All @@ -938,7 +938,7 @@ arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> TProgramStep::BuildFilter(
}
NArrow::TColumnFilter local = NArrow::TColumnFilter::BuildAllowFilter();
NArrow::TStatusValidator::Validate(MakeCombinedFilter(*datumBatch, local));
AFL_VERIFY(local.Size() == datumBatch->Rows)("local", local.Size())("datum", datumBatch->Rows);
AFL_VERIFY(local.Size() == datumBatch->GetRecordsCount())("local", local.Size())("datum", datumBatch->GetRecordsCount());
fullLocal.Append(local);
}
AFL_VERIFY(fullLocal.Size() == t->num_rows())("filter", fullLocal.Size())("t", t->num_rows());
Expand Down
43 changes: 37 additions & 6 deletions ydb/core/formats/arrow/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,47 @@ const char * GetHouseFunctionName(EAggregate op);
inline const char * GetHouseGroupByName() { return "ch.group_by"; }
EOperation ValidateOperation(EOperation op, ui32 argsSize);

struct TDatumBatch {
std::shared_ptr<arrow::Schema> Schema;
std::vector<arrow::Datum> Datums;
class TDatumBatch {
private:
std::shared_ptr<arrow::Schema> SchemaBase;
THashMap<std::string, ui32> NewColumnIds;
std::vector<std::shared_ptr<arrow::Field>> NewColumnsPtr;
int64_t Rows = 0;

public:
std::vector<arrow::Datum> Datums;

// Number of rows currently recorded for the batch.
// Rows is stored as a signed 64-bit value; the conversion to the unsigned return type is spelled out.
ui64 GetRecordsCount() const {
    return static_cast<ui64>(Rows);
}

// Overwrites the recorded row count. Only the counter changes; Datums are left untouched.
// Rows is stored as a signed 64-bit value; the conversion from the unsigned argument is spelled out.
void SetRecordsCount(const ui64 value) {
    Rows = static_cast<int64_t>(value);
}

TDatumBatch(const std::shared_ptr<arrow::Schema>& schema, std::vector<arrow::Datum>&& datums, const i64 rows);

// Returns the schema describing every column of the batch.
//
// Columns added after construction are tracked in NewColumnIds / NewColumnsPtr rather than in
// SchemaBase. On the first call after such additions they are appended to the base schema, in
// the order they were added, and the pending bookkeeping is cleared. Later calls return the
// cached SchemaBase until more columns are added. This mutation is why the method is not const.
//
// The rebuilt schema carries over the key/value metadata of the previous base schema, so
// materializing the added columns changes only the field list.
const std::shared_ptr<arrow::Schema>& GetSchema() {
    if (!NewColumnIds.empty()) {
        const auto& baseFields = SchemaBase->fields();

        std::vector<std::shared_ptr<arrow::Field>> fields;
        fields.reserve(baseFields.size() + NewColumnsPtr.size());
        fields.insert(fields.end(), baseFields.begin(), baseFields.end());
        fields.insert(fields.end(), NewColumnsPtr.begin(), NewColumnsPtr.end());

        // Take the metadata handle before SchemaBase is reassigned.
        auto metadata = SchemaBase->metadata();
        SchemaBase = std::make_shared<arrow::Schema>(std::move(fields), std::move(metadata));

        NewColumnIds.clear();
        NewColumnsPtr.clear();
    }
    return SchemaBase;
}

arrow::Status AddColumn(const std::string& name, arrow::Datum&& column);
arrow::Result<arrow::Datum> GetColumnByName(const std::string& name) const;
std::shared_ptr<arrow::Table> ToTable() const;
std::shared_ptr<arrow::RecordBatch> ToRecordBatch() const;
// Tells whether a column called `name` is present in the batch.
// Columns added since the schema was last materialized are checked first, then the base schema.
bool HasColumn(const std::string& name) const {
    // A negative index from GetFieldIndex means the name was not resolved in the base schema.
    return NewColumnIds.contains(name) || SchemaBase->GetFieldIndex(name) > -1;
}
std::shared_ptr<arrow::Table> ToTable();
std::shared_ptr<arrow::RecordBatch> ToRecordBatch();
static std::shared_ptr<TDatumBatch> FromRecordBatch(const std::shared_ptr<arrow::RecordBatch>& batch);
static std::shared_ptr<TDatumBatch> FromTable(const std::shared_ptr<arrow::Table>& batch);
};
Expand Down Expand Up @@ -405,7 +437,6 @@ class TProgramStep {
return Filters.size() && (!GroupBy.size() && !GroupByKeys.size());
}

[[nodiscard]] arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> BuildFilter(const std::shared_ptr<arrow::Table>& t) const;
[[nodiscard]] arrow::Result<std::shared_ptr<NArrow::TColumnFilter>> BuildFilter(const std::shared_ptr<NArrow::TGeneralContainer>& t) const;
};

Expand Down
Loading