Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-18016][SQL][FOLLOW-UP] Code Generation: Constant Pool Limit - reduce entries for mutable state #20036

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class CodegenContext {

/**
* Returns the reference of next available slot in current compacted array. The size of each
* compacted array is controlled by the config `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
* compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
* Once reaching the threshold, new compacted array is created.
*/
def getNextSlot(): String = {
Expand Down Expand Up @@ -299,7 +299,7 @@ class CodegenContext {
def initMutableStates(): String = {
// It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
// `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
val initCodes = mutableStateInitCode.distinct
val initCodes = mutableStateInitCode.distinct.map(_ + "\n")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, good catch!


// The generated initialization code may exceed 64kb function size limit in JVM if there are too
// many mutable states, so split it into multiple functions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,8 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
if (rVal != null) {
val regexStr =
StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
// inline mutable state since not many Like operations in a task
val pattern = ctx.addMutableState(patternClass, "patternLike",
v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true)
v => s"""$v = $patternClass.compile("$regexStr");""")

// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
Expand All @@ -143,9 +142,9 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
val rightStr = ctx.freshName("rightStr")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
String $rightStr = ${eval2}.toString();
${patternClass} $pattern = ${patternClass}.compile($escapeFunc($rightStr));
${ev.value} = $pattern.matcher(${eval1}.toString()).matches();
String $rightStr = $eval2.toString();
$patternClass $pattern = $patternClass.compile($escapeFunc($rightStr));
${ev.value} = $pattern.matcher($eval1.toString()).matches();
"""
})
}
Expand Down Expand Up @@ -194,9 +193,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
if (rVal != null) {
val regexStr =
StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString())
// inline mutable state since not many RLike operations in a task
val pattern = ctx.addMutableState(patternClass, "patternRLike",
v => s"""$v = ${patternClass}.compile("$regexStr");""", forceInline = true)
v => s"""$v = $patternClass.compile("$regexStr");""")

// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
Expand All @@ -219,9 +217,9 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
val pattern = ctx.freshName("pattern")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
String $rightStr = ${eval2}.toString();
${patternClass} $pattern = ${patternClass}.compile($rightStr);
${ev.value} = $pattern.matcher(${eval1}.toString()).find(0);
String $rightStr = $eval2.toString();
$patternClass $pattern = $patternClass.compile($rightStr);
${ev.value} = $pattern.matcher($eval1.toString()).find(0);
"""
})
}
Expand Down Expand Up @@ -338,25 +336,25 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio

nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => {
s"""
if (!$regexp.equals(${termLastRegex})) {
if (!$regexp.equals($termLastRegex)) {
// regex value changed
${termLastRegex} = $regexp.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
$termLastRegex = $regexp.clone();
$termPattern = $classNamePattern.compile($termLastRegex.toString());
}
if (!$rep.equals(${termLastReplacementInUTF8})) {
if (!$rep.equals($termLastReplacementInUTF8)) {
// replacement string changed
${termLastReplacementInUTF8} = $rep.clone();
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
$termLastReplacementInUTF8 = $rep.clone();
$termLastReplacement = $termLastReplacementInUTF8.toString();
}
$classNameStringBuffer ${termResult} = new $classNameStringBuffer();
java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString());
$classNameStringBuffer $termResult = new $classNameStringBuffer();
java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString());

while (${matcher}.find()) {
${matcher}.appendReplacement(${termResult}, ${termLastReplacement});
while ($matcher.find()) {
$matcher.appendReplacement($termResult, $termLastReplacement);
}
${matcher}.appendTail(${termResult});
${ev.value} = UTF8String.fromString(${termResult}.toString());
${termResult} = null;
$matcher.appendTail($termResult);
${ev.value} = UTF8String.fromString($termResult.toString());
$termResult = null;
$setEvNotNull
"""
})
Expand Down Expand Up @@ -425,19 +423,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio

nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s"""
if (!$regexp.equals(${termLastRegex})) {
if (!$regexp.equals($termLastRegex)) {
// regex value changed
${termLastRegex} = $regexp.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
$termLastRegex = $regexp.clone();
$termPattern = $classNamePattern.compile($termLastRegex.toString());
}
java.util.regex.Matcher ${matcher} =
${termPattern}.matcher($subject.toString());
if (${matcher}.find()) {
java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult();
if (${matchResult}.group($idx) == null) {
java.util.regex.Matcher $matcher =
$termPattern.matcher($subject.toString());
if ($matcher.find()) {
java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
if ($matchResult.group($idx) == null) {
${ev.value} = UTF8String.EMPTY_UTF8;
} else {
${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
${ev.value} = UTF8String.fromString($matchResult.group($idx));
}
$setEvNotNull
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ case class SortExec(
// Initialize the class member variables. This includes the instance of the Sorter and
// the iterator to return sorted rows.
val thisPlan = ctx.addReferenceObj("plan", this)
// inline mutable state since not many Sort operations in a task
// Inline mutable state since not many Sort operations in a task
sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter",
v => s"$v = $thisPlan.createSorter();", forceInline = true)
val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp

override def doProduce(ctx: CodegenContext): String = {
// Right now, InputAdapter is only used when there is one input RDD.
// inline mutable state since an inputAdaptor in a task
// Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen
val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
forceInline = true)
val row = ctx.freshName("row")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -587,31 +587,35 @@ case class HashAggregateExec(
fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
ctx.addInnerClass(generatedMap)

// Inline mutable state since not many aggregation operations in a task
fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "vectorizedHastHashMap",
v => s"$v = new $fastHashMapClassName();")
ctx.addMutableState(s"java.util.Iterator<InternalRow>", "vectorizedFastHashMapIter")
v => s"$v = new $fastHashMapClassName();", forceInline = true)
ctx.addMutableState(s"java.util.Iterator<InternalRow>", "vectorizedFastHashMapIter",
forceInline = true)
} else {
val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
ctx.addInnerClass(generatedMap)

// Inline mutable state since not many aggregation operations in a task
fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, "fastHashMap",
v => s"$v = new $fastHashMapClassName(" +
s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());")
s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());",
forceInline = true)
ctx.addMutableState(
"org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
"fastHashMapIter")
"fastHashMapIter", forceInline = true)
}
}

// Create a name for the iterator from the regular hash map.
// inline mutable state since not many aggregation operations in a task
// Inline mutable state since not many aggregation operations in a task
val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName,
"mapIter", forceInline = true)
// create hashMap
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
hashMapTerm = ctx.addMutableState(hashMapClassName, "hashMap",
v => s"$v = $thisPlan.createHashMap();")
v => s"$v = $thisPlan.createHashMap();", forceInline = true)
sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, "sorter",
forceInline = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ case class SampleExec(
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
val initSampler = ctx.freshName("initSampler")

// inline mutable state since not many Sample operations in a task
// Inline mutable state since not many Sample operations in a task
val sampler = ctx.addMutableState(s"$samplerClass<UnsafeRow>", "sampleReplace",
v => {
val initSamplerFuncName = ctx.addNewFunction(initSampler,
Expand Down Expand Up @@ -371,7 +371,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val ev = ExprCode("", "false", value)
val BigInt = classOf[java.math.BigInteger].getName

// inline mutable state since not many Range operations in a task
// Inline mutable state since not many Range operations in a task
val taskContext = ctx.addMutableState("TaskContext", "taskContext",
v => s"$v = TaskContext.get();", forceInline = true)
val inputMetrics = ctx.addMutableState("InputMetrics", "inputMetrics",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ case class BroadcastHashJoinExec(
// At the end of the task, we update the avg hash probe.
val avgHashProbe = metricTerm(ctx, "avgHashProbe")

// inline mutable state since not many join operations in a task
// Inline mutable state since not many join operations in a task
val relationTerm = ctx.addMutableState(clsName, "relation",
v => s"""
| $v = (($clsName) $broadcast.value()).asReadOnlyCopy();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ case class SortMergeJoinExec(
*/
private def genScanner(ctx: CodegenContext): (String, String) = {
// Create class member for next row from both sides.
// inline mutable state since not many join operations in a task
// Inline mutable state since not many join operations in a task
val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true)
val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true)

Expand All @@ -440,8 +440,9 @@ case class SortMergeJoinExec(
val spillThreshold = getSpillThreshold
val inMemoryThreshold = getInMemoryThreshold

// Inline mutable state since not many join operations in a task
val matches = ctx.addMutableState(clsName, "matches",
v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);")
v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
// Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars)

Expand Down Expand Up @@ -576,7 +577,7 @@ case class SortMergeJoinExec(
override def needCopyResult: Boolean = true

override def doProduce(ctx: CodegenContext): String = {
// inline mutable state since not many join operations in a task
// Inline mutable state since not many join operations in a task
val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput",
v => s"$v = inputs[0];", forceInline = true)
val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",
Expand Down