Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8305] [SPARK-8190] [SQL] improve codegen #6755

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,27 @@ public int fieldIndex(String name) {
throw new UnsupportedOperationException();
}

/**
* A generic version of Row.equals(Row), which is used for tests.
*/
@Override
public boolean equals(Object other) {
if (other instanceof Row) {
Row row = (Row) other;
int n = size();
if (n != row.size()) {
return false;
}
for (int i = 0; i < n; i ++) {
if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) {
return false;
}
}
return true;
}
return false;
}

@Override
public Row copy() {
final int n = size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case ByteType =>
buildCast[Byte](_, _ != 0)
case DecimalType() =>
buildCast[Decimal](_, _ != 0)
buildCast[Decimal](_, _ != Decimal(0))
case DoubleType =>
buildCast[Double](_, _ != 0)
case FloatType =>
Expand Down Expand Up @@ -454,7 +454,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (BooleanType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
case (dt: DecimalType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c.isZero()")
defineCodeGen(ctx, ev, c => s"!$c.isZero()")
case (dt: NumericType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c != 0")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,23 @@ class CodeGenContext {
}

/**
* Returns a function to generate equal expression in Java
* Generate code for equal expression in Java
*/
def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
case BinaryType => { case (eval1, eval2) =>
s"java.util.Arrays.equals($eval1, $eval2)" }
case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType =>
{ case (eval1, eval2) => s"$eval1 == $eval2" }
case other =>
{ case (eval1, eval2) => s"$eval1.equals($eval2)" }
def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
case other => s"$c1.equals($c2)"
}

/**
* Generate code for compare expression in Java
*/
def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
// Use signum() to keep any small difference bwteen float/double
case FloatType | DoubleType => s"(int)java.lang.Math.signum($c1 - $c2)"
case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 - $c2)"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(1.2 - 1.1).asInstanceOf[Int] => 0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first case will handle float and double, using java.lang.Math.signum

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add inline comment explaining why we have special case for float/double?

case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case other => s"$c1.compare($c2)"
}

/**
Expand All @@ -182,6 +190,16 @@ class CodeGenContext {
* Returns true if the data type has a special accessor and setter in [[Row]].
*/
def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt)

/**
* List of data types who's Java type is primitive type
*/
val primitiveTypes = nativeTypes ++ Seq(DateType, TimestampType)

/**
* Returns true if the Java type is primitive type
*/
def isPrimitiveType(dt: DataType): Boolean = primitiveTypes.contains(dt)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
}
"""


logDebug(s"code for ${expressions.mkString(",")}:\n$code")

val c = compile(code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.Private
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{BinaryType, NumericType}

/**
* Inherits some default implementation for Java from `Ordering[Row]`
Expand Down Expand Up @@ -55,39 +54,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
val evalA = order.child.gen(ctx)
val evalB = order.child.gen(ctx)
val asc = order.direction == Ascending
val compare = order.child.dataType match {
case BinaryType =>
s"""
{
byte[] x = ${if (asc) evalA.primitive else evalB.primitive};
byte[] y = ${if (!asc) evalB.primitive else evalA.primitive};
int j = 0;
while (j < x.length && j < y.length) {
if (x[j] != y[j]) return x[j] - y[j];
j = j + 1;
}
int d = x.length - y.length;
if (d != 0) {
return d;
}
}"""
case _: NumericType =>
s"""
if (${evalA.primitive} != ${evalB.primitive}) {
if (${evalA.primitive} > ${evalB.primitive}) {
return ${if (asc) "1" else "-1"};
} else {
return ${if (asc) "-1" else "1"};
}
}"""
case _ =>
s"""
int comp = ${evalA.primitive}.compare(${evalB.primitive});
if (comp != 0) {
return ${if (asc) "comp" else "-comp"};
}"""
}

s"""
i = $a;
${evalA.code}
Expand All @@ -100,7 +66,10 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
} else if (${evalB.isNull}) {
return ${if (order.direction == Ascending) "1" else "-1"};
} else {
$compare
int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)};
if (comp != 0) {
return ${if (asc) "comp" else "-comp"};
}
}
"""
}.mkString("\n")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}.mkString("\n ")

val specificAccessorFunctions = ctx.nativeTypes.map { dataType =>
val cases = expressions.zipWithIndex.map {
case (e, i) if e.dataType == dataType
|| dataType == IntegerType && e.dataType == DateType
|| dataType == LongType && e.dataType == TimestampType =>
s"case $i: return c$i;"
case _ => ""
val cases = expressions.zipWithIndex.flatMap {
case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
List(s"case $i: return c$i;")
case _ => Nil
}.mkString("\n ")
if (cases.count(_ != '\n') > 0) {
if (cases.length > 0) {
s"""
@Override
public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
Expand All @@ -89,29 +87,30 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
switch (i) {
$cases
}
return ${ctx.defaultValue(dataType)};
throw new IllegalArgumentException("Invalid index: " + i
+ " in ${ctx.accessorForType(dataType)}");
}"""
} else {
""
}
}.mkString("\n")

val specificMutatorFunctions = ctx.nativeTypes.map { dataType =>
val cases = expressions.zipWithIndex.map {
case (e, i) if e.dataType == dataType
|| dataType == IntegerType && e.dataType == DateType
|| dataType == LongType && e.dataType == TimestampType =>
s"case $i: { c$i = value; return; }"
case _ => ""
}.mkString("\n")
if (cases.count(_ != '\n') > 0) {
val cases = expressions.zipWithIndex.flatMap {
case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) =>
List(s"case $i: { c$i = value; return; }")
case _ => Nil
}.mkString("\n ")
if (cases.length > 0) {
s"""
@Override
public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) {
nullBits[i] = false;
switch (i) {
$cases
}
throw new IllegalArgumentException("Invalid index: " + i +
" in ${ctx.mutatorForType(dataType)}");
}"""
} else {
""
Expand Down Expand Up @@ -139,9 +138,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {

val columnChecks = expressions.zipWithIndex.map { case (e, i) =>
s"""
if (isNullAt($i) != row.isNullAt($i) || !isNullAt($i) && !get($i).equals(row.get($i))) {
return false;
}
if (nullBits[$i] != row.nullBits[$i] ||
(!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) {
return false;
}
"""
}.mkString("\n")

Expand Down Expand Up @@ -174,7 +174,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}

public int size() { return ${expressions.length};}
private boolean[] nullBits = new boolean[${expressions.length}];
protected boolean[] nullBits = new boolean[${expressions.length}];
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }

Expand Down Expand Up @@ -207,9 +207,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {

@Override
public boolean equals(Object other) {
if (other instanceof Row) {
Row row = (Row) other;
if (row.length() != size()) return false;
if (other instanceof SpecificRow) {
SpecificRow row = (SpecificRow) other;
$columnChecks
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
${cond.code}
if (${keyEval.isNull} && ${cond.isNull} ||
!${keyEval.isNull} && !${cond.isNull}
&& ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) {
&& ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
$got = true;
${res.code}
${ev.isNull} = ${res.isNull};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
// change the isNull and primitive to consts, to inline them
if (value == null) {
ev.isNull = "true"
ev.primitive = ctx.defaultValue(dataType)
""
s"final ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};"
} else {
dataType match {
case BooleanType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,11 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
left.dataType match {
case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
(c1, c3) => s"$c1 $symbol $c3"
})
case DateType | TimestampType => defineCodeGen (ctx, ev, {
(c1, c3) => s"$c1 $symbol $c3"
})
case other => defineCodeGen (ctx, ev, {
(c1, c2) => s"$c1.compare($c2) $symbol 0"
})
if (ctx.isPrimitiveType(left.dataType)) {
// faster version
defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")
} else {
defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0")
}
}

Expand All @@ -280,8 +275,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
if (left.dataType != BinaryType) l == r
else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a blank line here

defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType))
defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
}
}

Expand All @@ -307,7 +303,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive)
val equalCode = ctx.genEqual(left.dataType, eval1.primitive, eval2.primitive)
ev.isNull = "false"
eval1.code + eval2.code + s"""
boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,12 @@ object TypeUtils {

def getOrdering(t: DataType): Ordering[Any] =
t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove one extra line

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems super slow ... after this PR we should create a ByteArrayUtils in unsafe.types

for (i <- 0 until x.length; if i < y.length) {
val res = x(i).compareTo(y(i))
if (res != 0) return res
}
x.length - y.length
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.util.TypeUtils


/**
Expand All @@ -43,11 +44,7 @@ class BinaryType private() extends AtomicType {

private[sql] val ordering = new Ordering[InternalType] {
def compare(x: Array[Byte], y: Array[Byte]): Int = {
for (i <- 0 until x.length; if i < y.length) {
val res = x(i).compareTo(y(i))
if (res != 0) return res
}
x.length - y.length
TypeUtils.compareBinary(x, y)
}
}

Expand Down
Loading