diff --git a/cpp/src/arrow/ipc/adapter.cc b/cpp/src/arrow/ipc/adapter.cc index bf6fa94dea7a4..34700080746e7 100644 --- a/cpp/src/arrow/ipc/adapter.cc +++ b/cpp/src/arrow/ipc/adapter.cc @@ -179,20 +179,13 @@ class RowBatchWriter { } // This must be called after invoking AssemblePayload - int64_t DataHeaderSize() { - // TODO(wesm): In case it is needed, compute the upper bound for the size - // of the buffer containing the flatbuffer data header. - return 0; - } - - // Total footprint of buffers. This must be called after invoking - // AssemblePayload - int64_t TotalBytes() { - int64_t total = 0; - for (const std::shared_ptr& buffer : buffers_) { - total += buffer->size(); - } - return total; + Status GetTotalSize(int64_t* size) { + // emulates the behavior of Write without actually writing + int64_t data_header_offset; + MockMemorySource source(0); + RETURN_NOT_OK(Write(&source, 0, &data_header_offset)); + *size = source.GetExtentBytesWritten(); + return Status::OK(); } private: @@ -211,6 +204,14 @@ Status WriteRowBatch(MemorySource* dst, const RowBatch* batch, int64_t position, RETURN_NOT_OK(serializer.AssemblePayload()); return serializer.Write(dst, position, header_offset); } + +Status GetRowBatchSize(const RowBatch* batch, int64_t* size) { + RowBatchWriter serializer(batch, kMaxIpcRecursionDepth); + RETURN_NOT_OK(serializer.AssemblePayload()); + RETURN_NOT_OK(serializer.GetTotalSize(size)); + return Status::OK(); +} + // ---------------------------------------------------------------------- // Row batch read path diff --git a/cpp/src/arrow/ipc/adapter.h b/cpp/src/arrow/ipc/adapter.h index 4c9a8a9d8ee39..0d2b77f5acefe 100644 --- a/cpp/src/arrow/ipc/adapter.h +++ b/cpp/src/arrow/ipc/adapter.h @@ -62,7 +62,7 @@ Status WriteRowBatch(MemorySource* dst, const RowBatch* batch, int64_t position, // Compute the precise number of bytes needed in a contiguous memory segment to // write the row batch. This involves generating the complete serialized // Flatbuffers metadata. -int64_t GetRowBatchSize(const RowBatch* batch); +Status GetRowBatchSize(const RowBatch* batch, int64_t* size); // ---------------------------------------------------------------------- // "Read" path; does not copy data if the MemorySource does not diff --git a/cpp/src/arrow/ipc/ipc-adapter-test.cc b/cpp/src/arrow/ipc/ipc-adapter-test.cc index c243cfba820cc..3b147343f772a 100644 --- a/cpp/src/arrow/ipc/ipc-adapter-test.cc +++ b/cpp/src/arrow/ipc/ipc-adapter-test.cc @@ -195,6 +195,34 @@ INSTANTIATE_TEST_CASE_P(RoundTripTests, TestWriteRowBatch, ::testing::Values(&MakeIntRowBatch, &MakeListRowBatch, &MakeNonNullRowBatch, &MakeZeroLengthRowBatch, &MakeDeeplyNestedList)); +void TestGetRowBatchSize(std::shared_ptr batch) { + MockMemorySource mock_source(1 << 16); + int64_t mock_header_location; + int64_t size; + ASSERT_OK(WriteRowBatch(&mock_source, batch.get(), 0, &mock_header_location)); + ASSERT_OK(GetRowBatchSize(batch.get(), &size)); + ASSERT_EQ(mock_source.GetExtentBytesWritten(), size); +} + +TEST_F(TestWriteRowBatch, IntegerGetRowBatchSize) { + std::shared_ptr batch; + + ASSERT_OK(MakeIntRowBatch(&batch)); + TestGetRowBatchSize(batch); + + ASSERT_OK(MakeListRowBatch(&batch)); + TestGetRowBatchSize(batch); + + ASSERT_OK(MakeZeroLengthRowBatch(&batch)); + TestGetRowBatchSize(batch); + + ASSERT_OK(MakeNonNullRowBatch(&batch)); + TestGetRowBatchSize(batch); + + ASSERT_OK(MakeDeeplyNestedList(&batch)); + TestGetRowBatchSize(batch); +} + class RecursionLimits : public ::testing::Test, public MemoryMapFixture { public: void SetUp() { pool_ = default_memory_pool(); } diff --git a/cpp/src/arrow/ipc/memory.cc b/cpp/src/arrow/ipc/memory.cc index 84cbc182cd26f..caff2c610b907 100644 --- a/cpp/src/arrow/ipc/memory.cc +++ b/cpp/src/arrow/ipc/memory.cc @@ -145,5 +145,30 @@ Status MemoryMappedSource::Write(int64_t position, const uint8_t* data, int64_t return Status::OK(); } +MockMemorySource::MockMemorySource(int64_t size) + : size_(size), extent_bytes_written_(0) {} + +Status MockMemorySource::Close() { + return Status::OK(); +} + +Status MockMemorySource::ReadAt( + int64_t position, int64_t nbytes, std::shared_ptr* out) { + return Status::OK(); +} + +Status MockMemorySource::Write(int64_t position, const uint8_t* data, int64_t nbytes) { + extent_bytes_written_ = std::max(extent_bytes_written_, position + nbytes); + return Status::OK(); +} + +int64_t MockMemorySource::Size() const { + return size_; +} + +int64_t MockMemorySource::GetExtentBytesWritten() const { + return extent_bytes_written_; +} + } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/memory.h b/cpp/src/arrow/ipc/memory.h index e529603dc6e2a..c6fd7a718991b 100644 --- a/cpp/src/arrow/ipc/memory.h +++ b/cpp/src/arrow/ipc/memory.h @@ -121,6 +121,28 @@ class MemoryMappedSource : public MemorySource { std::unique_ptr impl_; }; +// A MemorySource that tracks the size of allocations from a memory source +class MockMemorySource : public MemorySource { + public: + explicit MockMemorySource(int64_t size); + + Status Close() override; + + Status ReadAt(int64_t position, int64_t nbytes, std::shared_ptr* out) override; + + Status Write(int64_t position, const uint8_t* data, int64_t nbytes) override; + + int64_t Size() const override; + + // @return: the smallest number of bytes containing the modified region of the + // MockMemorySource + int64_t GetExtentBytesWritten() const; + + private: + int64_t size_; + int64_t extent_bytes_written_; +}; + } // namespace ipc } // namespace arrow