merged ColumnarBatch changes from arrow-ColumnarBatch-support-SPARK-2…
BryanCutler committed Aug 4, 2017
1 parent e49758b commit 912143e
Showing 4 changed files with 192 additions and 16 deletions.
ColumnarBatch.java (org.apache.spark.sql.execution.vectorized)
@@ -65,15 +65,42 @@ public final class ColumnarBatch {
final Row row;

public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) {
- return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode);
+ return allocate(schema, memMode, DEFAULT_BATCH_SIZE);
}

public static ColumnarBatch allocate(StructType type) {
- return new ColumnarBatch(type, DEFAULT_BATCH_SIZE, DEFAULT_MEMORY_MODE);
+ return allocate(type, DEFAULT_MEMORY_MODE, DEFAULT_BATCH_SIZE);
}

public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int maxRows) {
- return new ColumnarBatch(schema, maxRows, memMode);
+ ColumnVector[] columns = allocateVectors(schema, maxRows, memMode);
+ return create(schema, columns, maxRows);
}

private static ColumnVector[] allocateVectors(StructType schema, int maxRows, MemoryMode memMode) {
ColumnVector[] columns = new ColumnVector[schema.size()];
for (int i = 0; i < schema.fields().length; ++i) {
StructField field = schema.fields()[i];
columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode);
}
return columns;
}

public static ColumnarBatch createReadOnly(
StructType schema,
ReadOnlyColumnVector[] columns,
int numRows) {
for (ReadOnlyColumnVector c: columns) {
assert(c.capacity >= numRows);
}
ColumnarBatch batch = create(schema, columns, numRows);
batch.setNumRows(numRows);
return batch;
}

private static ColumnarBatch create(StructType schema, ColumnVector[] columns, int capacity) {
assert(schema.length() == columns.length);
return new ColumnarBatch(schema, columns, capacity);
}

/**
@@ -505,18 +532,12 @@ public void filterNullsInColumn(int ordinal) {
nullFilteredColumns.add(ordinal);
}

- private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) {
+ private ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) {
this.schema = schema;
- this.capacity = maxRows;
- this.columns = new ColumnVector[schema.size()];
+ this.columns = columns;
+ this.capacity = capacity;
this.nullFilteredColumns = new HashSet<>();
- this.filteredRows = new boolean[maxRows];
-
- for (int i = 0; i < schema.fields().length; ++i) {
- StructField field = schema.fields()[i];
- columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode);
- }
-
+ this.filteredRows = new boolean[this.capacity];
this.row = new Row(this);
}
}
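For orientation, a minimal usage sketch, not part of the commit, contrasting the two construction paths introduced above. It assumes the pre-existing writable ColumnVector/ColumnarBatch API (column(), putInt(), getRow(), setNumRows(), close()) behaves as in Spark's vectorized code at the time:

```scala
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.execution.vectorized.ColumnarBatch
import org.apache.spark.sql.types.{IntegerType, StructType}

val schema = new StructType().add("i", IntegerType)

// Writable path: allocate() now builds the vectors via allocateVectors()
// and hands them to the shared create() helper.
val batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP, 4)
(0 until 4).foreach(r => batch.column(0).putInt(r, r))
batch.setNumRows(4)
assert(batch.getRow(2).getInt(0) == 2)
batch.close()

// Read-only path: createReadOnly() wraps vectors that are already populated
// (e.g. ArrowColumnVector) and only asserts capacity >= numRows:
//   ColumnarBatch.createReadOnly(schema, readOnlyVectors, numRows)
```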
ArrowConverters.scala (org.apache.spark.sql.execution.arrow)
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.arrow
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.file._
@@ -28,14 +30,15 @@ import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ReadOnlyColumnVector}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


/**
* Store Arrow data in a form that can be serialized by Spark and served to a Python process.
*/
- private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
+ private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable {

/**
* Convert the ArrowPayload to an ArrowRecordBatch.
@@ -110,6 +113,67 @@ private[sql] object ArrowConverters {
}
}

private[sql] def fromPayloadIterator(
payloadIter: Iterator[ArrowPayload],
schema: StructType,
context: TaskContext): Iterator[InternalRow] = {

val allocator =
ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue)
var reader: ArrowFileReader = null

new Iterator[InternalRow] {

context.addTaskCompletionListener { _ =>
close()
}

private var _batch: ColumnarBatch = _
private var _rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty

override def hasNext: Boolean = _rowIter.hasNext || {
if (payloadIter.hasNext) {
_rowIter = nextBatch()
true
} else {
close()
false
}
}

override def next(): InternalRow = _rowIter.next()

def close(): Unit = {
closeReader()
allocator.close()
}

private def closeReader(): Unit = {
if (reader != null) {
reader.close()
reader = null
}
}

private def nextBatch(): Iterator[InternalRow] = {
closeReader()
val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable)
reader = new ArrowFileReader(in, allocator)
reader.loadNextBatch() // throws IOException
val root = reader.getVectorSchemaRoot

assert(schema.equals(ArrowUtils.fromArrowSchema(root.getSchema)),
s"$schema \n!=\n ${ArrowUtils.fromArrowSchema(root.getSchema)}")

val columns = root.getFieldVectors.asScala.map { vector =>
new ArrowColumnVector(vector).asInstanceOf[ReadOnlyColumnVector]
}.toArray

ColumnarBatch.createReadOnly(schema, columns, root.getRowCount).rowIterator().asScala
}
}
}

/**
* Convert a byte array to an ArrowRecordBatch.
*/
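For context, a hedged sketch, not part of the commit, of how the new fromPayloadIterator pairs with the existing toPayloadIterator; it mirrors the "roundtrip payloads" test added below. The call signatures are taken from that test; the third argument to toPayloadIterator is assumed to be the per-batch record limit, and both methods are private[sql], so this only compiles from within Spark's sql packages:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Seq(StructField("int", IntegerType)))
val context = TaskContext.empty()  // as in the test; a real task context in production

def roundtrip(rows: Iterator[InternalRow]): Iterator[InternalRow] = {
  // Serialize rows to Arrow payloads, then decode them back to rows.
  val payloads = ArrowConverters.toPayloadIterator(rows, schema, 0, context)
  // fromPayloadIterator loads one Arrow batch per payload and closes its
  // reader/allocator once the payloads are exhausted or the task completes.
  ArrowConverters.fromPayloadIterator(payloads, schema, context)
}
```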
ArrowConvertersSuite.scala
@@ -22,15 +22,18 @@ import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Locale

import scala.collection.JavaConverters._

import com.google.common.io.Files
import org.apache.arrow.memory.RootAllocator
- import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
+ import org.apache.arrow.vector.{NullableIntVector, VectorLoader, VectorSchemaRoot}
import org.apache.arrow.vector.file.json.JsonFileReader
import org.apache.arrow.vector.util.Validator
import org.scalatest.BeforeAndAfterAll

- import org.apache.spark.SparkException
+ import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ReadOnlyColumnVector}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -1629,6 +1632,40 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
}
}

test("roundtrip payloads") {
val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true)
.createVector(allocator).asInstanceOf[NullableIntVector]
vector.allocateNew()
val mutator = vector.getMutator()

(0 until 10).foreach { i =>
mutator.setSafe(i, i)
}
mutator.setNull(10)
mutator.setValueCount(11)

val schema = StructType(Seq(StructField("int", IntegerType)))

val columnarBatch = ColumnarBatch.createReadOnly(
schema, Array[ReadOnlyColumnVector](new ArrowColumnVector(vector)), 11)

val context = TaskContext.empty()

val payloadIter = ArrowConverters.toPayloadIterator(
columnarBatch.rowIterator().asScala, schema, 0, context)

val rowIter = ArrowConverters.fromPayloadIterator(payloadIter, schema, context)

rowIter.zipWithIndex.foreach { case (row, i) =>
if (i == 10) {
assert(row.isNullAt(0))
} else {
assert(row.getInt(0) == i)
}
}
}

/** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */
private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = {
// NOTE: coalesce to single partition because can only load 1 batch in validator
ColumnarBatchSuite.scala
@@ -25,10 +25,13 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random

import org.apache.arrow.vector.NullableIntVector

import org.apache.spark.SparkFunSuite
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.CalendarInterval
@@ -1248,4 +1251,55 @@ class ColumnarBatchSuite extends SparkFunSuite {
s"vectorized reader"))
}
}

test("create read-only batch") {
val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true)
.createVector(allocator).asInstanceOf[NullableIntVector]
vector1.allocateNew()
val mutator1 = vector1.getMutator()
val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true)
.createVector(allocator).asInstanceOf[NullableIntVector]
vector2.allocateNew()
val mutator2 = vector2.getMutator()

(0 until 10).foreach { i =>
mutator1.setSafe(i, i)
mutator2.setSafe(i + 1, i)
}
mutator1.setNull(10)
mutator1.setValueCount(11)
mutator2.setNull(0)
mutator2.setValueCount(11)

val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2))

val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType)))
val batch = ColumnarBatch.createReadOnly(
schema, columnVectors.toArray[ReadOnlyColumnVector], 11)

assert(batch.numCols() == 2)
assert(batch.numRows() == 11)

val rowIter = batch.rowIterator().asScala
rowIter.zipWithIndex.foreach { case (row, i) =>
if (i == 10) {
assert(row.isNullAt(0))
} else {
assert(row.getInt(0) == i)
}
if (i == 0) {
assert(row.isNullAt(1))
} else {
assert(row.getInt(1) == i - 1)
}
}

intercept[java.lang.AssertionError] {
batch.getRow(100)
}

columnVectors.foreach(_.close())
allocator.close()
}
}
