Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[SPARK-22682][SQL] HashExpression does not need to create global variables #19878

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,36 @@ abstract class HashExpression[E] extends Expression {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.isNull = "false"
val childrenHash = ctx.splitExpressions(children.map { child =>

val childrenHash = children.map { child =>
val childGen = child.genCode(ctx)
childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
computeHash(childGen.value, child.dataType, ev.value, ctx)
}
})
}

val hashResultType = ctx.javaType(dataType)
val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern appears many times in the code base, we may need to create a ctx.splitExpressionsWithCurrentInput for it later.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @kiszk is doing this

Copy link
Member

@gatorsmile gatorsmile Dec 4, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That one has been merged, but this one is still different.

childrenHash.mkString("\n")
} else {
ctx.splitExpressions(
expressions = childrenHash,
funcName = "computeHash",
arguments = Seq("InternalRow" -> ctx.INPUT_ROW, hashResultType -> ev.value),
returnType = hashResultType,
makeSplitFunction = body =>
s"""
|$body
|return ${ev.value};
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
}

ctx.addMutableState(ctx.javaType(dataType), ev.value)
ev.copy(code = s"""
${ev.value} = $seed;
$childrenHash""")
ev.copy(code =
s"""
|$hashResultType ${ev.value} = $seed;
|$codes
""".stripMargin)
}

protected def nullSafeElementHash(
Expand Down Expand Up @@ -389,13 +408,21 @@ abstract class HashExpression[E] extends Expression {
input: String,
result: String,
fields: Array[StructField]): String = {
val hashes = fields.zipWithIndex.map { case (field, index) =>
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
}
val hashResultType = ctx.javaType(dataType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is done also in line 281. Can we do this only once? maybe with a lazy val?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx is only available inside doGenCode

ctx.splitExpressions(
expressions = hashes,
funcName = "getHash",
arguments = Seq("InternalRow" -> input))
expressions = fieldsHash,
funcName = "computeHashForStruct",
arguments = Seq("InternalRow" -> input, hashResultType -> result),
returnType = hashResultType,
makeSplitFunction = body =>
s"""
|$body
|return $result;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
}

@tailrec
Expand Down Expand Up @@ -610,25 +637,44 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.isNull = "false"

val childHash = ctx.freshName("childHash")
val childrenHash = ctx.splitExpressions(children.map { child =>
val childrenHash = children.map { child =>
val childGen = child.genCode(ctx)
val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) {
computeHash(childGen.value, child.dataType, childHash, ctx)
}
s"""
|${childGen.code}
|$childHash = 0;
|$codeToComputeHash
|${ev.value} = (31 * ${ev.value}) + $childHash;
|$childHash = 0;
""".stripMargin
})
}

ctx.addMutableState(ctx.javaType(dataType), ev.value)
ctx.addMutableState(ctx.JAVA_INT, childHash, s"$childHash = 0;")
ev.copy(code = s"""
${ev.value} = $seed;
$childrenHash""")
val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
childrenHash.mkString("\n")
} else {
ctx.splitExpressions(
expressions = childrenHash,
funcName = "computeHash",
arguments = Seq("InternalRow" -> ctx.INPUT_ROW, ctx.JAVA_INT -> ev.value),
returnType = ctx.JAVA_INT,
makeSplitFunction = body =>
s"""
|${ctx.JAVA_INT} $childHash = 0;
|$body
|return ${ev.value};
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
}

ev.copy(code =
s"""
|${ctx.JAVA_INT} ${ev.value} = $seed;
|${ctx.JAVA_INT} $childHash = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: childHash is only needed to declare here when we don't split functions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind — splitExpressions may not actually split the code when there is only a single block, so the declaration is still needed here.

|$codes
""".stripMargin)
}

override def eval(input: InternalRow = null): Int = {
Expand Down Expand Up @@ -730,23 +776,29 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
input: String,
result: String,
fields: Array[StructField]): String = {
val localResult = ctx.freshName("localResult")
val childResult = ctx.freshName("childResult")
fields.zipWithIndex.map { case (field, index) =>
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
val computeFieldHash = nullSafeElementHash(
input, index.toString, field.nullable, field.dataType, childResult, ctx)
s"""
$childResult = 0;
${nullSafeElementHash(input, index.toString, field.nullable, field.dataType,
childResult, ctx)}
$localResult = (31 * $localResult) + $childResult;
"""
}.mkString(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We forgot to split the code for computing hive hash of struct, it's fixed now.

s"""
int $localResult = 0;
int $childResult = 0;
""",
"",
s"$result = (31 * $result) + $localResult;"
)
|$childResult = 0;
|$computeFieldHash
|$result = (31 * $result) + $childResult;
""".stripMargin
}

s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to check ctx.INPUT_ROW == null || ctx.currentVars != null here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The input here is a row that may be produced by row.getStruct rather than ctx.INPUT_ROW, so we don't need this check: the input will never be ctx.currentVars.

expressions = fieldsHash,
funcName = "computeHashForStruct",
arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result),
returnType = ctx.JAVA_INT,
makeSplitFunction = body =>
s"""
|${ctx.JAVA_INT} $childResult = 0;
|$body
|return $result;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.scalatest.exceptions.TestFailedException

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
Expand Down Expand Up @@ -620,23 +621,30 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("SPARK-18207: Compute hash for a lot of expressions") {
def checkResult(schema: StructType, input: InternalRow): Unit = {
val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
BoundReference(i, f.dataType, true)
}
val murmur3HashExpr = Murmur3Hash(exprs, 42)
val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr))
val murmursHashEval = Murmur3Hash(exprs, 42).eval(input)
assert(murmur3HashPlan(input).getInt(0) == murmursHashEval)

val hiveHashExpr = HiveHash(exprs)
val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr))
val hiveHashEval = HiveHash(exprs).eval(input)
assert(hiveHashPlan(input).getInt(0) == hiveHashEval)
}

val N = 1000
val wideRow = new GenericInternalRow(
Seq.tabulate(N)(i => UTF8String.fromString(i.toString)).toArray[Any])
val schema = StructType((1 to N).map(i => StructField("", StringType)))

val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
BoundReference(i, f.dataType, true)
}
val murmur3HashExpr = Murmur3Hash(exprs, 42)
val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr))
val murmursHashEval = Murmur3Hash(exprs, 42).eval(wideRow)
assert(murmur3HashPlan(wideRow).getInt(0) == murmursHashEval)
val schema = StructType((1 to N).map(i => StructField(i.toString, StringType)))
checkResult(schema, wideRow)

val hiveHashExpr = HiveHash(exprs)
val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr))
val hiveHashEval = HiveHash(exprs).eval(wideRow)
assert(hiveHashPlan(wideRow).getInt(0) == hiveHashEval)
val nestedRow = InternalRow(wideRow)
val nestedSchema = new StructType().add("nested", schema)
checkResult(nestedSchema, nestedRow)
}

test("SPARK-22284: Compute hash for nested structs") {
Expand Down