Skip to content

Commit

Permalink
apacheGH-32276: [C++][FlightRPC] Align buffers from Flight
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed May 18, 2023
1 parent 9039ee2 commit cda6048
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 0 deletions.
75 changes: 75 additions & 0 deletions cpp/src/arrow/flight/test_definitions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include "arrow/flight/test_util.h"
#include "arrow/table.h"
#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/align_util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/config.h"
#include "arrow/util/logging.h"
Expand Down Expand Up @@ -241,6 +243,79 @@ void DataTest::TestDoGetLargeBatch() {
Ticket ticket{"ticket-large-batch-1"};
CheckDoGet(ticket, expected_batches);
}
int64_t BufferAlignment(Type::type ty, size_t buffer_index) {
if (buffer_index == 0) return 1;
switch (ty) {
case Type::INT16:
case Type::UINT16:
case Type::HALF_FLOAT:
return 2;
case Type::INT32:
case Type::UINT32:
case Type::FLOAT:
case Type::DATE32:
case Type::TIME32:
case Type::LIST: // Offsets may be cast to int32_t*, data is in child array
case Type::MAP: // This is a list array
case Type::DENSE_UNION: // Has an offsets buffer of int32_t*
case Type::INTERVAL_MONTHS: // Stored as int32_t*
return 4;
case Type::INT64:
case Type::UINT64:
case Type::DOUBLE:
case Type::LARGE_LIST: // Offsets may be cast to int64_t*
case Type::DATE64:
case Type::TIME64:
case Type::TIMESTAMP:
case Type::DURATION:
case Type::INTERVAL_DAY_TIME: // Stored as two contiguous 32-bit integers but may be
// cast to struct* containing both integers
return 8;
case Type::INTERVAL_MONTH_DAY_NANO: // Stored as two 32-bit integers and a 64-bit
// integer
return 16;
case Type::STRING:
case Type::BINARY: // Offsets may be cast to int32_t*, data is only uint8_t*
return (buffer_index == 1) ? 4 : 1;
case Type::LARGE_STRING:
case Type::LARGE_BINARY: // Offsets may be cast to int64_t*
return (buffer_index == 1) ? 8 : 1;
default:
// Everything else doesn't have buffers with non-trivial alignement requirements
return 1;
}
}
void AssertAligned(const Array& array) {
ARROW_SCOPED_TRACE(array);

size_t buffer_index = 0;
for (const auto& buf : array.data()->buffers) {
if (buf != nullptr) {
ASSERT_TRUE(
util::CheckAlignment(*buf, BufferAlignment(array.type()->id(), buffer_index)));
}
buffer_index++;
}
}
void DataTest::TestDoGetAlignment() {
// Regression test for GH-32276
ASSERT_OK_AND_ASSIGN(RecordBatchVector expected_batches, ExampleAlignmentBatches());
Ticket ticket{"ticket-alignment"};
CheckDoGet(ticket, expected_batches);

ASSERT_OK_AND_ASSIGN(auto stream, client_->DoGet(ticket));
for (size_t i = 0; i < expected_batches.size(); i++) {
ASSERT_OK_AND_ASSIGN(auto chunk, stream->Next());
ASSERT_NE(nullptr, chunk.data);
for (const auto& array : expected_batches[i]->columns()) {
EXPECT_NO_FATAL_FAILURE(AssertAligned(*array));
}

for (const auto& array : chunk.data->columns()) {
EXPECT_NO_FATAL_FAILURE(AssertAligned(*array));
}
}
}
// Ensure FlightDataStream/RecordBatchStream::Close errors are propagated
void DataTest::TestFlightDataStreamError() {
Ticket ticket{"ticket-stream-error"};
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/flight/test_definitions.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class ARROW_FLIGHT_EXPORT DataTest : public FlightTest {
void TestDoGetFloats();
void TestDoGetDicts();
void TestDoGetLargeBatch();
void TestDoGetAlignment();
void TestFlightDataStreamError();
void TestOverflowServerBatch();
void TestOverflowClientBatch();
Expand Down Expand Up @@ -108,6 +109,7 @@ class ARROW_FLIGHT_EXPORT DataTest : public FlightTest {
TEST_F(FIXTURE, TestDoGetFloats) { TestDoGetFloats(); } \
TEST_F(FIXTURE, TestDoGetDicts) { TestDoGetDicts(); } \
TEST_F(FIXTURE, TestDoGetLargeBatch) { TestDoGetLargeBatch(); } \
TEST_F(FIXTURE, TestDoGetAlignment) { TestDoGetAlignment(); } \
TEST_F(FIXTURE, TestFlightDataStreamError) { TestFlightDataStreamError(); } \
TEST_F(FIXTURE, TestOverflowServerBatch) { TestOverflowServerBatch(); } \
TEST_F(FIXTURE, TestOverflowClientBatch) { TestOverflowClientBatch(); } \
Expand Down
46 changes: 46 additions & 0 deletions cpp/src/arrow/flight/test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <limits>
#include <sstream>

// We need Windows fixes before including Boost
Expand All @@ -43,7 +44,9 @@
#include "arrow/ipc/test_common.h"
#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/testing/util.h"
#include "arrow/type_fwd.h"
#include "arrow/util/logging.h"

#include "arrow/flight/api.h"
Expand Down Expand Up @@ -192,6 +195,10 @@ Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr<RecordBatchReader
RETURN_NOT_OK(ExampleLargeBatches(&batches));
ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
return Status::OK();
} else if (ticket.ticket == "ticket-alignment") {
ARROW_ASSIGN_OR_RAISE(RecordBatchVector batches, ExampleAlignmentBatches());
ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches));
return Status::OK();
} else {
return Status::NotImplemented("no stream implemented for ticket: " + ticket.ticket);
}
Expand Down Expand Up @@ -679,6 +686,45 @@ Status ExampleLargeBatches(RecordBatchVector* out) {
return Status::OK();
}

arrow::Result<arrow::RecordBatchVector> ExampleAlignmentBatches() {
const double null_probability = 0.3;
auto schema = ::arrow::schema({
field("int8", int8()),
field("int16", int16()),
field("int32", int32()),
field("int64", int64()),
});

RecordBatchVector batches;
for (int i = 0; i < 5; ++i) {
int64_t length = i + 1;
arrow::random::RandomArrayGenerator rand(2 * i + 3);

std::shared_ptr<Array> int8s, int16s, int32s, int64s;
int8s = rand.Numeric<Int8Type>(length, std::numeric_limits<int8_t>::min(),
std::numeric_limits<int8_t>::max(), null_probability);
int16s =
rand.Numeric<Int16Type>(length, std::numeric_limits<int16_t>::min(),
std::numeric_limits<int16_t>::max(), null_probability);
int32s =
rand.Numeric<Int32Type>(length, std::numeric_limits<int32_t>::min(),
std::numeric_limits<int32_t>::max(), null_probability);
int64s =
rand.Numeric<Int64Type>(length, std::numeric_limits<int64_t>::min(),
std::numeric_limits<int64_t>::max(), null_probability);

std::shared_ptr<RecordBatch> batch = RecordBatch::Make(schema, length,
{
std::move(int8s),
std::move(int16s),
std::move(int32s),
std::move(int64s),
});
batches.push_back(batch);
}
return batches;
}

arrow::Result<std::shared_ptr<RecordBatch>> VeryLargeBatch() {
// In CI, some platforms don't let us allocate one very large
// buffer, so allocate a smaller buffer and repeat it a few times
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/flight/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ Status ExampleNestedBatches(RecordBatchVector* out);
ARROW_FLIGHT_EXPORT
Status ExampleLargeBatches(RecordBatchVector* out);

// Batches of data that previously Flight sent unaligned
ARROW_FLIGHT_EXPORT
arrow::Result<arrow::RecordBatchVector> ExampleAlignmentBatches();

ARROW_FLIGHT_EXPORT
arrow::Result<std::shared_ptr<RecordBatch>> VeryLargeBatch();

Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/flight/transport/grpc/serialization_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#include "arrow/flight/transport/grpc/util_internal.h"
#include "arrow/ipc/message.h"
#include "arrow/ipc/writer.h"
#include "arrow/util/align_util.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/logging.h"

Expand Down Expand Up @@ -380,6 +381,17 @@ ::grpc::Status FlightDataDeserialize(ByteBuffer* buffer,
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
"Unable to read FlightData body");
}
if (!util::CheckAlignment(*out->body, 8)) {
// XXX: due to where we sit, we can't use a custom allocator
// XXX: any error here will likely crash or hang gRPC!
auto buf = std::move(out->body);
auto status =
buf->CopySlice(/*start=*/0, /*nbytes=*/buf->size()).Value(&out->body);
if (!status.ok()) {
return {::grpc::StatusCode::INTERNAL, status.ToString()};
}
}
DCHECK(util::CheckAlignment(*out->body, 8)) << "FlightData body is unaligned";
} break;
default:
DCHECK(false) << "cannot happen";
Expand Down

0 comments on commit cda6048

Please sign in to comment.