# [SPARK-23090][SQL] polish ColumnVector
## What changes were proposed in this pull request?

Several improvements:
* provide a default implementation for the batch get methods (`getBooleans`, `getInts`, etc.; see the sketch after this list)
* rename `getChildColumn` to `getChild`, which is more concise
* remove `getStruct(int, int)`; it exists only to simplify codegen, which is an internal concern, so we should not add a public API for that purpose.
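To make the first item concrete, here is a minimal sketch (not the exact Spark source; `BaseVector` is a stand-in name) of how a base class can provide a default batch getter on top of the scalar getter, so subclasses only override it when a bulk copy is actually faster:

```java
// Sketch only: `BaseVector` stands in for the real ColumnVector class.
public abstract class BaseVector {
  public abstract int getInt(int rowId);

  // Default batch getter: delegates to the scalar getter in a loop.
  // Implementations backed by contiguous memory can override this
  // with a single bulk copy instead.
  public int[] getInts(int rowId, int count) {
    int[] res = new int[count];
    for (int i = 0; i < count; i++) {
      res[i] = getInt(rowId + i);
    }
    return res;
  }
}
```

This is the same loop that the ORC column vector below previously spelled out once per type.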

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #20277 from cloud-fan/column-vector.
cloud-fan committed Jan 22, 2018
1 parent 896e45a commit 5d680ca
Showing 17 changed files with 164 additions and 296 deletions.

```diff
@@ -688,17 +688,13 @@ class CodegenContext {
   /**
    * Returns the specialized code to access a value from a column vector for a given `DataType`.
    */
-  def getValue(vector: String, rowId: String, dataType: DataType): String = {
-    val jt = javaType(dataType)
-    dataType match {
-      case _ if isPrimitiveType(jt) =>
-        s"$vector.get${primitiveTypeName(jt)}($rowId)"
-      case t: DecimalType =>
-        s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})"
-      case StringType =>
-        s"$vector.getUTF8String($rowId)"
-      case _ =>
-        throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
+  def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = {
+    if (dataType.isInstanceOf[StructType]) {
+      // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an
+      // `ordinal` parameter.
+      s"$vector.getStruct($rowId)"
+    } else {
+      getValue(vector, dataType, rowId)
     }
   }
```
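
As an illustration (the names and concrete types here are assumptions, not part of the patch), the four shapes of accessor code this helper can emit look like:

```java
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.unsafe.types.UTF8String;

class GeneratedAccessSketch {
  // One vector would not hold all four types at once; this only shows the
  // accessor shapes the helper can emit for a hypothetical vector `col`.
  static void shapes(ColumnVector col, int rowId) {
    int i = col.getInt(rowId);                  // primitive types, via getValue
    Decimal d = col.getDecimal(rowId, 38, 18);  // DecimalType(38, 18)
    UTF8String s = col.getUTF8String(rowId);    // StringType
    // StructType takes the new branch: ColumnVector.getStruct needs only the
    // row ordinal, unlike InternalRow.getStruct(ordinal, numFields).
    InternalRow nested = col.getStruct(rowId);
  }
}
```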

```diff
@@ -110,57 +110,21 @@ public boolean getBoolean(int rowId) {
     return longData.vector[getRowIndex(rowId)] == 1;
   }
 
-  @Override
-  public boolean[] getBooleans(int rowId, int count) {
-    boolean[] res = new boolean[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getBoolean(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public byte getByte(int rowId) {
     return (byte) longData.vector[getRowIndex(rowId)];
   }
 
-  @Override
-  public byte[] getBytes(int rowId, int count) {
-    byte[] res = new byte[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getByte(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public short getShort(int rowId) {
     return (short) longData.vector[getRowIndex(rowId)];
   }
 
-  @Override
-  public short[] getShorts(int rowId, int count) {
-    short[] res = new short[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getShort(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public int getInt(int rowId) {
     return (int) longData.vector[getRowIndex(rowId)];
   }
 
-  @Override
-  public int[] getInts(int rowId, int count) {
-    int[] res = new int[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getInt(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public long getLong(int rowId) {
     int index = getRowIndex(rowId);
@@ -171,43 +135,16 @@ public long getLong(int rowId) {
     }
   }
 
-  @Override
-  public long[] getLongs(int rowId, int count) {
-    long[] res = new long[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getLong(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public float getFloat(int rowId) {
     return (float) doubleData.vector[getRowIndex(rowId)];
   }
 
-  @Override
-  public float[] getFloats(int rowId, int count) {
-    float[] res = new float[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getFloat(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public double getDouble(int rowId) {
     return doubleData.vector[getRowIndex(rowId)];
   }
 
-  @Override
-  public double[] getDoubles(int rowId, int count) {
-    double[] res = new double[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getDouble(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public int getArrayLength(int rowId) {
     throw new UnsupportedOperationException();
@@ -245,7 +182,7 @@ public org.apache.spark.sql.vectorized.ColumnVector arrayData() {
   }
 
   @Override
-  public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) {
+  public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) {
     throw new UnsupportedOperationException();
   }
 }
```

```diff
@@ -289,10 +289,9 @@ private void putRepeatingValues(
       toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]);
     } else if (type instanceof StringType || type instanceof BinaryType) {
       BytesColumnVector data = (BytesColumnVector)fromColumn;
-      WritableColumnVector arrayData = toColumn.getChildColumn(0);
       int size = data.vector[0].length;
-      arrayData.reserve(size);
-      arrayData.putBytes(0, size, data.vector[0], 0);
+      toColumn.arrayData().reserve(size);
+      toColumn.arrayData().putBytes(0, size, data.vector[0], 0);
       for (int index = 0; index < batchSize; index++) {
         toColumn.putArray(index, 0, size);
       }
@@ -352,7 +351,7 @@ private void putNonNullValues(
       toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0);
     } else if (type instanceof StringType || type instanceof BinaryType) {
       BytesColumnVector data = ((BytesColumnVector)fromColumn);
-      WritableColumnVector arrayData = toColumn.getChildColumn(0);
+      WritableColumnVector arrayData = toColumn.arrayData();
       int totalNumBytes = IntStream.of(data.length).sum();
       arrayData.reserve(totalNumBytes);
       for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) {
@@ -363,8 +362,7 @@
       DecimalType decimalType = (DecimalType)type;
       DecimalColumnVector data = ((DecimalColumnVector)fromColumn);
       if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) {
-        WritableColumnVector arrayData = toColumn.getChildColumn(0);
-        arrayData.reserve(batchSize * 16);
+        toColumn.arrayData().reserve(batchSize * 16);
       }
       for (int index = 0; index < batchSize; index++) {
         putDecimalWritable(
@@ -459,7 +457,7 @@ private void putValues(
       }
     } else if (type instanceof StringType || type instanceof BinaryType) {
       BytesColumnVector vector = (BytesColumnVector)fromColumn;
-      WritableColumnVector arrayData = toColumn.getChildColumn(0);
+      WritableColumnVector arrayData = toColumn.arrayData();
       int totalNumBytes = IntStream.of(vector.length).sum();
       arrayData.reserve(totalNumBytes);
       for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) {
@@ -474,8 +472,7 @@
       DecimalType decimalType = (DecimalType)type;
       HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector;
       if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) {
-        WritableColumnVector arrayData = toColumn.getChildColumn(0);
-        arrayData.reserve(batchSize * 16);
+        toColumn.arrayData().reserve(batchSize * 16);
       }
       for (int index = 0; index < batchSize; index++) {
         if (fromColumn.isNull[index]) {
@@ -521,8 +518,7 @@ private static void putDecimalWritable(
       toColumn.putLong(index, value.toUnscaledLong());
     } else {
       byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
-      WritableColumnVector arrayData = toColumn.getChildColumn(0);
-      arrayData.putBytes(index * 16, bytes.length, bytes, 0);
+      toColumn.arrayData().putBytes(index * 16, bytes.length, bytes, 0);
       toColumn.putArray(index, index * 16, bytes.length);
     }
   }
@@ -547,9 +543,8 @@ private static void putDecimalWritables(
       toColumn.putLongs(0, size, value.toUnscaledLong());
     } else {
       byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
-      WritableColumnVector arrayData = toColumn.getChildColumn(0);
-      arrayData.reserve(bytes.length);
-      arrayData.putBytes(0, bytes.length, bytes, 0);
+      toColumn.arrayData().reserve(bytes.length);
+      toColumn.arrayData().putBytes(0, bytes.length, bytes, 0);
       for (int index = 0; index < size; index++) {
         toColumn.putArray(index, 0, bytes.length);
       }
```
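
One detail worth calling out in `putDecimalWritable`/`putDecimalWritables`: a decimal whose precision exceeds `Decimal.MAX_LONG_DIGITS()` cannot be stored as a single long, so its unscaled bytes go into the vector's `arrayData()` child in fixed 16-byte slots (hence the `batchSize * 16` reserve), and the parent vector records each row's offset and length with `putArray`. A condensed restatement of that else-branch, with illustrative names:

```java
import java.math.BigDecimal;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;

class WideDecimalSketch {
  // Write one high-precision decimal into row `index`, mirroring the
  // fixed-slot layout above: 16 bytes reserved per row in the child vector.
  static void putWideDecimal(WritableColumnVector toColumn, int index, BigDecimal value) {
    byte[] bytes = value.unscaledValue().toByteArray();  // <= 16 bytes for precision <= 38
    toColumn.arrayData().putBytes(index * 16, bytes.length, bytes, 0);
    toColumn.putArray(index, index * 16, bytes.length);
  }
}
```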

```diff
@@ -85,8 +85,8 @@ public static void populate(WritableColumnVector col, InternalRow row, int field
       }
     } else if (t instanceof CalendarIntervalType) {
       CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t);
-      col.getChildColumn(0).putInts(0, capacity, c.months);
-      col.getChildColumn(1).putLongs(0, capacity, c.microseconds);
+      col.getChild(0).putInts(0, capacity, c.months);
+      col.getChild(1).putLongs(0, capacity, c.microseconds);
     } else if (t instanceof DateType) {
       col.putInts(0, capacity, row.getInt(fieldIdx));
     } else if (t instanceof TimestampType) {
@@ -149,8 +149,8 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o)
     } else if (t instanceof CalendarIntervalType) {
       CalendarInterval c = (CalendarInterval)o;
       dst.appendStruct(false);
-      dst.getChildColumn(0).appendInt(c.months);
-      dst.getChildColumn(1).appendLong(c.microseconds);
+      dst.getChild(0).appendInt(c.months);
+      dst.getChild(1).appendLong(c.microseconds);
     } else if (t instanceof DateType) {
       dst.appendInt(DateTimeUtils.fromJavaDate((Date)o));
     } else {
@@ -179,7 +179,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i
       dst.appendStruct(false);
       Row c = src.getStruct(fieldIdx);
       for (int i = 0; i < st.fields().length; i++) {
-        appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i);
+        appendValue(dst.getChild(i), st.fields()[i].dataType(), c, i);
       }
     }
   } else {
```

```diff
@@ -146,8 +146,8 @@ public byte[] getBinary(int ordinal) {
   @Override
   public CalendarInterval getInterval(int ordinal) {
     if (columns[ordinal].isNullAt(rowId)) return null;
-    final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
-    final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+    final int months = columns[ordinal].getChild(0).getInt(rowId);
+    final long microseconds = columns[ordinal].getChild(1).getLong(rowId);
     return new CalendarInterval(months, microseconds);
   }
```
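
The interval changes above all follow one layout contract: a `CalendarInterval` column is a two-child struct, with months stored as ints in child 0 and microseconds as longs in child 1. The writer and the reader just need to agree on it; a sketch with illustrative method names:

```java
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.unsafe.types.CalendarInterval;

class IntervalLayoutSketch {
  // Append one interval: months into child 0, microseconds into child 1.
  static void writeInterval(WritableColumnVector dst, CalendarInterval c) {
    dst.appendStruct(false);
    dst.getChild(0).appendInt(c.months);
    dst.getChild(1).appendLong(c.microseconds);
  }

  // Read it back assuming the same two-child layout.
  static CalendarInterval readInterval(ColumnVector col, int rowId) {
    return new CalendarInterval(col.getChild(0).getInt(rowId), col.getChild(1).getLong(rowId));
  }
}
```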

```diff
@@ -599,17 +599,13 @@ public final int appendStruct(boolean isNull) {
     return elementsAppended;
   }
 
-  /**
-   * Returns the data for the underlying array.
-   */
+  // `WritableColumnVector` puts the data of array in the first child column vector, and puts the
+  // array offsets and lengths in the current column vector.
   @Override
   public WritableColumnVector arrayData() { return childColumns[0]; }
 
-  /**
-   * Returns the ordinal's child data column.
-   */
   @Override
-  public WritableColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; }
+  public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; }
 
   /**
    * Returns the elements appended.
```
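
The new comment above documents the invariant the ORC reader changes rely on: for array-like data, the raw bytes live in the first child vector (returned by `arrayData()`), while the parent vector stores each row's offset and length. Condensed into one illustrative helper:

```java
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;

class ArrayDataSketch {
  // Write `bytes` as the value of row `rowId`, starting at `offset` in the
  // child vector; the parent records where the row's bytes live.
  static void putBytesAt(WritableColumnVector col, int rowId, int offset, byte[] bytes) {
    col.arrayData().reserve(offset + bytes.length);            // grow the child
    col.arrayData().putBytes(offset, bytes.length, bytes, 0);  // raw bytes -> child
    col.putArray(rowId, offset, bytes.length);                 // offset/length -> parent
  }
}
```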