diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 4727ff1885ad7..72fe06545910c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -62,9 +62,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
     val javaType = ctx.javaType(dataType)
     val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
     if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
-      ev.isNull = ctx.currentVars(ordinal).isNull
-      ev.value = ctx.currentVars(ordinal).value
-      ""
+      val oev = ctx.currentVars(ordinal)
+      ev.isNull = oev.isNull
+      ev.value = oev.value
+      oev.code
     } else if (nullable) {
       s"""
         boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 63e19564dd861..c4265a753933f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -37,6 +37,8 @@ import org.apache.spark.util.Utils
  * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
  *
  * @param code The sequence of statements required to evaluate the expression.
+ *             It should be an empty string if `isNull` and `value` already exist, or if no code
+ *             is needed to evaluate them (literals).
  * @param isNull A term that holds a boolean value representing whether the expression evaluated
  *               to null.
  * @param value A term for a (possibly primitive) value of the result of the evaluation. Not
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 36e656b8b6abf..0a132beca9a6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -151,9 +151,6 @@ private[sql] case class PhysicalRDD(
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
     val numOutputRows = metricTerm(ctx, "numOutputRows")
-    ctx.INPUT_ROW = row
-    ctx.currentVars = null
-    val columns = exprs.map(_.gen(ctx))

     // The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
     // by looking at the first value of the RDD and then calling the function which will process
@@ -161,7 +158,9 @@ private[sql] case class PhysicalRDD(
     // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
     // here which path to use. Fix this.
-
+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns1 = exprs.map(_.gen(ctx))
     val scanBatches = ctx.freshName("processBatches")
     ctx.addNewFunction(scanBatches,
       s"""
       | private void $scanBatches() throws java.io.IOException {
@@ -170,12 +169,11 @@ private[sql] case class PhysicalRDD(
       |     int numRows = $batch.numRows();
       |     if ($idx == 0) $numOutputRows.add(numRows);
       |
-      |     while ($idx < numRows) {
+      |     while (!shouldStop() && $idx < numRows) {
       |       InternalRow $row = $batch.getRow($idx++);
-      |       ${columns.map(_.code).mkString("\n").trim}
-      |       ${consume(ctx, columns).trim}
-      |       if (shouldStop()) return;
+      |       ${consume(ctx, columns1).trim}
       |     }
+      |     if (shouldStop()) return;
       |
       |     if (!$input.hasNext()) {
       |       $batch = null;
@@ -186,30 +184,37 @@
       |   }
       | }""".stripMargin)

+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns2 = exprs.map(_.gen(ctx))
+    val inputRow = if (isUnsafeRow) row else null
     val scanRows = ctx.freshName("processRows")
     ctx.addNewFunction(scanRows,
       s"""
       | private void $scanRows(InternalRow $row) throws java.io.IOException {
-      |   while (true) {
+      |   boolean firstRow = true;
+      |   while (!shouldStop() && (firstRow || $input.hasNext())) {
+      |     if (firstRow) {
+      |       firstRow = false;
+      |     } else {
+      |       $row = (InternalRow) $input.next();
+      |     }
       |     $numOutputRows.add(1);
-      |     ${columns.map(_.code).mkString("\n").trim}
-      |     ${consume(ctx, columns).trim}
-      |     if (shouldStop()) return;
-      |     if (!$input.hasNext()) break;
-      |     $row = (InternalRow)$input.next();
+      |     ${consume(ctx, columns2, inputRow).trim}
       |   }
       | }""".stripMargin)

+    val value = ctx.freshName("value")
     s"""
       | if ($batch != null) {
       |   $scanBatches();
       | } else if ($input.hasNext()) {
-      |   Object value = $input.next();
-      |   if (value instanceof $columnarBatchClz) {
-      |     $batch = ($columnarBatchClz)value;
+      |   Object $value = $input.next();
+      |   if ($value instanceof $columnarBatchClz) {
+      |     $batch = ($columnarBatchClz)$value;
       |     $scanBatches();
       |   } else {
-      |     $scanRows((InternalRow)value);
+      |     $scanRows((InternalRow) $value);
       |   }
       | }
     """.stripMargin
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 12998a38f59e7..524285bc87123 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -185,8 +185,10 @@ case class Expand(
     val numOutput = metricTerm(ctx, "numOutputRows")
     val i = ctx.freshName("i")
+    // these columns have to be declared before the loop.
+    val evaluate = evaluateVariables(outputColumns)
     s"""
-      |${outputColumns.map(_.code).mkString("\n").trim}
+      |$evaluate
      |for (int $i = 0; $i < ${projections.length}; $i ++) {
      |  switch ($i) {
      |    ${cases.mkString("\n").trim}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 6d231bf74a0e9..45578d50bfc0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -81,11 +84,14 @@ trait CodegenSupport extends SparkPlan {
     this.parent = parent
     ctx.freshNamePrefix = variablePrefix
     waitForSubqueries()
-    doProduce(ctx)
+    s"""
+       |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
+       |${doProduce(ctx)}
+     """.stripMargin
   }

   /**
-   * Generate the Java source code to process, should be overrided by subclass to support codegen.
+   * Generate the Java source code to process, should be overridden by subclass to support codegen.
    *
    * doProduce() usually generate the framework, for example, aggregation could generate this:
    *
@@ -94,11 +97,11 @@ trait CodegenSupport extends SparkPlan {
    *     # call child.produce()
    *     initialized = true;
    *   }
-   *   while (hashmap.hasNext()) {
+   *   while (!shouldStop() && hashmap.hasNext()) {
    *     row = hashmap.next();
    *     # build the aggregation results
-   *     # create varialbles for results
-   *     # call consume(), wich will call parent.doConsume()
+   *     # create variables for results
+   *     # call consume(), which will call parent.doConsume()
    *   }
    */
   protected def doProduce(ctx: CodegenContext): String
@@ -114,27 +117,71 @@ trait CodegenSupport extends SparkPlan {
   }

   /**
-   * Consume the columns generated from it's child, call doConsume() or emit the rows.
+   * Returns source code to evaluate all the variables and clear their code, to prevent them
+   * from being evaluated twice.
+   */
+  protected def evaluateVariables(variables: Seq[ExprCode]): String = {
+    val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
+    variables.foreach(_.code = "")
+    evaluate
+  }
+
+  /**
+   * Returns source code to evaluate the variables for the required attributes, and clear the code
+   * of the evaluated variables, to prevent them from being evaluated twice.
    */
+  protected def evaluateRequiredVariables(
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode],
+      required: AttributeSet): String = {
+    var evaluateVars = ""
+    variables.zipWithIndex.foreach { case (ev, i) =>
+      if (ev.code != "" && required.contains(attributes(i))) {
+        evaluateVars += ev.code.trim + "\n"
+        ev.code = ""
+      }
+    }
+    evaluateVars
+  }
+
+  /**
+   * The subset of inputSet that should be evaluated before this plan.
+   *
+   * We will use this to insert some code to access those columns that are actually used by the
+   * current plan before calling doConsume().
+   */
+  def usedInputs: AttributeSet = references
+
+  /**
+   * Consume the columns generated from its child, call doConsume() or emit the rows.
+   *
+   * An operator could generate variables for the output or a row; either one could be null.
+   *
+   * If the row is not null, we create variables to access the columns that are actually used by
+   * the current plan before calling doConsume().
+   */
   def consumeChild(
       ctx: CodegenContext,
       child: SparkPlan,
       input: Seq[ExprCode],
       row: String = null): String = {
     ctx.freshNamePrefix = variablePrefix
-    if (row != null) {
-      ctx.currentVars = null
-      ctx.INPUT_ROW = row
-      val evals = child.output.zipWithIndex.map { case (attr, i) =>
-        BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+    val inputVars =
+      if (row != null) {
+        ctx.currentVars = null
+        ctx.INPUT_ROW = row
+        child.output.zipWithIndex.map { case (attr, i) =>
+          BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+        }
+      } else {
+        input
       }
-      s"""
-         | ${evals.map(_.code).mkString("\n")}
-         | ${doConsume(ctx, evals)}
-      """.stripMargin
-    } else {
-      doConsume(ctx, input)
-    }
+    s"""
+       |
+       |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
+       |${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
+       |${doConsume(ctx, inputVars)}
+     """.stripMargin
   }

   /**
@@ -145,9 +192,8 @@ trait CodegenSupport extends SparkPlan {
    * For example, Filter will generate the code like this:
    *
    *   # code to evaluate the predicate expression, result is isNull1 and value2
-   *   if (isNull1 || value2) {
-   *     # call consume(), which will call parent.doConsume()
-   *   }
+   *   if (isNull1 || !value2) continue;
+   *   # call consume(), which will call parent.doConsume()
    */
   protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     throw new UnsupportedOperationException
@@ -190,13 +236,9 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
     s"""
-      | while ($input.hasNext()) {
+      | while (!shouldStop() && $input.hasNext()) {
       |   InternalRow $row = (InternalRow) $input.next();
-      |   ${columns.map(_.code).mkString("\n").trim}
       |   ${consume(ctx, columns).trim}
-      |   if (shouldStop()) {
-      |     return;
-      |   }
       | }
     """.stripMargin
   }
@@ -332,10 +374,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
     val colExprs = output.zipWithIndex.map { case (attr, i) =>
       BoundReference(i, attr.dataType, attr.nullable)
     }
+    val evaluateInputs = evaluateVariables(input)
     // generate the code to create a UnsafeRow
     ctx.currentVars = input
     val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
     s"""
+      |$evaluateInputs
      |${code.code.trim}
      |append(${code.value}.copy());
     """.stripMargin.trim
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a46722963a6e1..f07add83d5849 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -116,6 +116,8 @@ case class TungstenAggregate(
   // all the mode of aggregate expressions
   private val modes = aggregateExpressions.map(_.mode).distinct

+  override def usedInputs: AttributeSet = inputSet
+
   override def supportCodegen: Boolean = {
     // ImperativeAggregate is not supported right now
     !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
@@ -164,23 +166,24 @@ case class TungstenAggregate(
       """.stripMargin
       ExprCode(ev.code + initVars, isNull, value)
     }
+    val initBufVar = evaluateVariables(bufVars)

     // generate variables for output
-    val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
     val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
       // evaluate aggregate results
       ctx.currentVars = bufVars
       val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+        BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
       }
+      val evaluateAggResults = evaluateVariables(aggResults)
       // evaluate result expressions
       ctx.currentVars = aggResults
       val resultVars = resultExpressions.map { e =>
         BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
       }
       (resultVars, s"""
-        | ${aggResults.map(_.code).mkString("\n")}
-        | ${resultVars.map(_.code).mkString("\n")}
+        |$evaluateAggResults
+        |${evaluateVariables(resultVars)}
       """.stripMargin)
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
       // output the aggregate buffer directly
@@ -188,7 +191,7 @@ case class TungstenAggregate(
     } else {
       // no aggregate function, the result should be literals
       val resultVars = resultExpressions.map(_.gen(ctx))
-      (resultVars, resultVars.map(_.code).mkString("\n"))
+      (resultVars, evaluateVariables(resultVars))
     }

     val doAgg = ctx.freshName("doAggregateWithoutKey")
@@ -196,7 +199,7 @@ case class TungstenAggregate(
       s"""
         | private void $doAgg() throws java.io.IOException {
         |   // initialize aggregation buffer
-        |   ${bufVars.map(_.code).mkString("\n")}
+        |   $initBufVar
         |
         |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
         | }
@@ -204,7 +207,7 @@ case class TungstenAggregate(

     val numOutput = metricTerm(ctx, "numOutputRows")
     s"""
-      | if (!$initAgg) {
+      | while (!$initAgg) {
       |   $initAgg = true;
       |   $doAgg();
       |
@@ -241,7 +244,7 @@ case class TungstenAggregate(
     }
     s"""
       | // do aggregate
-      | ${aggVals.map(_.code).mkString("\n").trim}
+      | ${evaluateVariables(aggVals)}
      | // update aggregation buffer
      | ${updates.mkString("\n").trim}
     """.stripMargin
@@ -252,8 +255,7 @@ case class TungstenAggregate(
   private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
     .filter(_.isInstanceOf[DeclarativeAggregate])
     .map(_.asInstanceOf[DeclarativeAggregate])
-  private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
-  private val bufferSchema = StructType.fromAttributes(bufferAttributes)
+  private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)

   // The name for HashMap
   private var hashMapTerm: String = _
@@ -318,7 +320,7 @@ case class TungstenAggregate(
     val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
     val mergeProjection = newMutableProjection(
       mergeExpr,
-      bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
+      aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
       subexpressionEliminationEnabled)()
     val joinedRow = new JoinedRow()

@@ -380,15 +382,18 @@ case class TungstenAggregate(
     val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
       BoundReference(i, e.dataType, e.nullable).gen(ctx)
     }
+    val evaluateKeyVars = evaluateVariables(keyVars)
     ctx.INPUT_ROW = bufferTerm
-    val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
+    val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
       BoundReference(i, e.dataType, e.nullable).gen(ctx)
     }
+    val evaluateBufferVars = evaluateVariables(bufferVars)
     // evaluate the aggregation result
     ctx.currentVars = bufferVars
     val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-      BindReferences.bindReference(e, bufferAttributes).gen(ctx)
+      BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
     }
+    val evaluateAggResults = evaluateVariables(aggResults)
     // generate the final result
     ctx.currentVars = keyVars ++ aggResults
     val inputAttrs = groupingAttributes ++ aggregateAttributes
@@ -396,11 +401,9 @@ case class TungstenAggregate(
       BindReferences.bindReference(e, inputAttrs).gen(ctx)
     }
     s"""
-    ${keyVars.map(_.code).mkString("\n")}
-    ${bufferVars.map(_.code).mkString("\n")}
-    ${aggResults.map(_.code).mkString("\n")}
-    ${resultVars.map(_.code).mkString("\n")}
-
+    $evaluateKeyVars
+    $evaluateBufferVars
+    $evaluateAggResults
     ${consume(ctx, resultVars)}
     """
@@ -422,10 +425,7 @@ case class TungstenAggregate(
     val eval = resultExpressions.map{ e =>
       BindReferences.bindReference(e, groupingAttributes).gen(ctx)
     }
-    s"""
-    ${eval.map(_.code).mkString("\n")}
-    ${consume(ctx, eval)}
-    """
+    consume(ctx, eval)
   }
 }

@@ -508,8 +508,8 @@ case class TungstenAggregate(
     ctx.currentVars = input
     val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)

-    val inputAttr = bufferAttributes ++ child.output
-    ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
+    val inputAttr = aggregateBufferAttributes ++ child.output
+    ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
     ctx.INPUT_ROW = buffer
     // TODO: support subexpression elimination
     val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
@@ -557,7 +557,7 @@ case class TungstenAggregate(
       $incCounter

       // evaluate aggregate function
-      ${evals.map(_.code).mkString("\n").trim}
+      ${evaluateVariables(evals)}
       // update aggregate buffer
       ${updates.mkString("\n").trim}
     """
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index b2f443c0e9ae6..4a9e736f7abdb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -39,15 +39,26 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }

+  override def usedInputs: AttributeSet = {
+    // only the attributes that are used at least twice should be evaluated before this plan,
+    // otherwise we could defer the evaluation until the output attribute is actually used.
+    val usedExprIds = projectList.flatMap(_.collect {
+      case a: Attribute => a.exprId
+    })
+    val usedMoreThanOnce = usedExprIds.groupBy(id => id).filter(_._2.size > 1).keySet
+    references.filter(a => usedMoreThanOnce.contains(a.exprId))
+  }
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     val exprs = projectList.map(x =>
       ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
     ctx.currentVars = input
-    val output = exprs.map(_.gen(ctx))
+    val resultVars = exprs.map(_.gen(ctx))
+    // Evaluation of non-deterministic expressions can't be deferred.
+    val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
     s"""
-      | ${output.map(_.code).mkString("\n")}
-      |
-      | ${consume(ctx, output)}
+      |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
+      |${consume(ctx, resultVars)}
     """.stripMargin
   }
@@ -89,11 +100,10 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
       s""
     }
     s"""
-      | ${eval.code}
-      | if ($nullCheck ${eval.value}) {
-      |   $numOutput.add(1);
-      |   ${consume(ctx, ctx.currentVars)}
-      | }
+      |${eval.code}
+      |if (!($nullCheck ${eval.value})) continue;
+      |$numOutput.add(1);
+      |${consume(ctx, ctx.currentVars)}
     """.stripMargin
   }
@@ -228,15 +238,13 @@ case class Range(
       |   }
       | }
       |
-      | while (!$overflow && $checkEnd) {
+      | while (!$overflow && $checkEnd && !shouldStop()) {
       |   long $value = $number;
       |   $number += ${step}L;
       |   if ($number < $value ^ ${step}L < 0) {
       |     $overflow = true;
       |   }
       |   ${consume(ctx, Seq(ev))}
-      |
-      |   if (shouldStop()) return;
       | }
     """.stripMargin
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 6699dbafe7e74..c52662a61e7f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -190,40 +190,38 @@ case class BroadcastHashJoin(
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
-    val resultVars = buildSide match {
-      case BuildLeft => buildVars ++ input
-      case BuildRight => input ++ buildVars
-    }
     val numOutput = metricTerm(ctx, "numOutputRows")

-    val outputCode = if (condition.isDefined) {
+    val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from the build side that are used by the condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
       // filter the output via condition
-      ctx.currentVars = resultVars
-      val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+      ctx.currentVars = input ++ buildVars
+      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
       s"""
+        |$eval
        |${ev.code}
-        |if (!${ev.isNull} && ${ev.value}) {
-        |  $numOutput.add(1);
-        |  ${consume(ctx, resultVars)}
-        |}
+        |if (${ev.isNull} || !${ev.value}) continue;
       """.stripMargin
     } else {
-      s"""
-        |$numOutput.add(1);
-        |${consume(ctx, resultVars)}
-      """.stripMargin
+      ""
     }

+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
       s"""
        |// generate join key for stream side
        |${keyEv.code}
        |// find matches from HashedRelation
        |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-        |if ($matched != null) {
-        |  ${buildVars.map(_.code).mkString("\n")}
-        |  $outputCode
-        |}
+        |if ($matched == null) continue;
+        |$checkCondition
+        |$numOutput.add(1);
+        |${consume(ctx, resultVars)}
       """.stripMargin

     } else {
@@ -236,13 +234,13 @@
        |${keyEv.code}
        |// find matches from HashRelation
        |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
-        |if ($matches != null) {
-        |  int $size = $matches.size();
-        |  for (int $i = 0; $i < $size; $i++) {
-        |    UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
-        |    ${buildVars.map(_.code).mkString("\n")}
-        |    $outputCode
-        |  }
+        |if ($matches == null) continue;
+        |int $size = $matches.size();
+        |for (int $i = 0; $i < $size; $i++) {
+        |  UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+        |  $checkCondition
+        |  $numOutput.add(1);
+        |  ${consume(ctx, resultVars)}
        |}
       """.stripMargin
     }
@@ -257,21 +255,21 @@ case class BroadcastHashJoin(
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
-    val resultVars = buildSide match {
-      case BuildLeft => buildVars ++ input
-      case BuildRight => input ++ buildVars
-    }
     val numOutput = metricTerm(ctx, "numOutputRows")

     // filter the output via condition
     val conditionPassed = ctx.freshName("conditionPassed")
     val checkCondition = if (condition.isDefined) {
-      ctx.currentVars = resultVars
-      val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+      val expr = condition.get
+      // evaluate the variables from the build side that are used by the condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+      ctx.currentVars = input ++ buildVars
+      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
       s"""
        |boolean $conditionPassed = true;
+        |${eval.trim}
+        |${ev.code}
        |if ($matched != null) {
-        |  ${ev.code}
        |  $conditionPassed = !${ev.isNull} && ${ev.value};
        |}
       """.stripMargin
@@ -279,17 +277,21 @@ case class BroadcastHashJoin(
     } else {
       s"final boolean $conditionPassed = true;"
     }

+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
       s"""
        |// generate join key for stream side
        |${keyEv.code}
        |// find matches from HashedRelation
        |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-        |${buildVars.map(_.code).mkString("\n")}
        |${checkCondition.trim}
        |if (!$conditionPassed) {
-        |  // reset to null
-        |  ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
+        |  $matched = null;
+        |  // reset the variables that have already been evaluated.
+        |  ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")}
        |}
        |$numOutput.add(1);
        |${consume(ctx, resultVars)}
@@ -311,13 +313,11 @@ case class BroadcastHashJoin(
        |// the last iteration of this loop is to emit an empty row if there is no matched rows.
        |for (int $i = 0; $i <= $size; $i++) {
        |  UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
-        |  ${buildVars.map(_.code).mkString("\n")}
        |  ${checkCondition.trim}
-        |  if ($conditionPassed && ($i < $size || !$found)) {
-        |    $found = true;
-        |    $numOutput.add(1);
-        |    ${consume(ctx, resultVars)}
-        |  }
+        |  if (!$conditionPassed || ($i == $size && $found)) continue;
+        |  $found = true;
+        |  $numOutput.add(1);
+        |  ${consume(ctx, resultVars)}
        |}
       """.stripMargin
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 7ec4027188f14..cffd6f6032f2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -306,11 +306,11 @@ case class SortMergeJoin(
       val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
         condRefs.contains(a)
       }
-      val beforeCond = used.map(_._2.code).mkString("\n")
-      val afterCond = notUsed.map(_._2.code).mkString("\n")
+      val beforeCond = evaluateVariables(used.map(_._2))
+      val afterCond = evaluateVariables(notUsed.map(_._2))
       (beforeCond, afterCond)
     } else {
-      (variables.map(_.code).mkString("\n"), "")
+      (evaluateVariables(variables), "")
     }
   }

@@ -326,41 +326,48 @@ case class SortMergeJoin(
     val leftVars = createLeftVars(ctx, leftRow)
     val rightRow = ctx.freshName("rightRow")
     val rightVars = createRightVar(ctx, rightRow)
-    val resultVars = leftVars ++ rightVars
-
-    // Check condition
-    ctx.currentVars = resultVars
-    val cond = if (condition.isDefined) {
-      BindReferences.bindReference(condition.get, output).gen(ctx)
-    } else {
-      ExprCode("", "false", "true")
-    }
-    // Split the code of creating variables based on whether it's used by condition or not.
-    val loaded = ctx.freshName("loaded")
-    val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
-    val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
-
     val size = ctx.freshName("size")
     val i = ctx.freshName("i")
     val numOutput = metricTerm(ctx, "numOutputRows")
+    val (beforeLoop, condCheck) = if (condition.isDefined) {
+      // Split the code of creating variables based on whether it's used by condition or not.
+      val loaded = ctx.freshName("loaded")
+      val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
+      val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
+      // Generate code for condition
+      ctx.currentVars = leftVars ++ rightVars
+      val cond = BindReferences.bindReference(condition.get, output).gen(ctx)
+      // evaluate the columns that are used by the condition before the loop
+      val before = s"""
+        |boolean $loaded = false;
+        |$leftBefore
+      """.stripMargin
+
+      val checking = s"""
+        |$rightBefore
+        |${cond.code}
+        |if (${cond.isNull} || !${cond.value}) continue;
+        |if (!$loaded) {
+        |  $loaded = true;
+        |  $leftAfter
+        |}
+        |$rightAfter
+      """.stripMargin
+      (before, checking)
+    } else {
+      (evaluateVariables(leftVars), "")
+    }
+
     s"""
      |while (findNextInnerJoinRows($leftInput, $rightInput)) {
      |  int $size = $matches.size();
-     |  boolean $loaded = false;
-     |  $leftBefore
+     |  ${beforeLoop.trim}
      |  for (int $i = 0; $i < $size; $i ++) {
      |    InternalRow $rightRow = (InternalRow) $matches.get($i);
-     |    $rightBefore
-     |    ${cond.code}
-     |    if (${cond.isNull} || !${cond.value}) continue;
-     |    if (!$loaded) {
-     |      $loaded = true;
-     |      $leftAfter
-     |    }
-     |    $rightAfter
+     |    ${condCheck.trim}
      |    $numOutput.add(1);
-     |    ${consume(ctx, resultVars)}
+     |    ${consume(ctx, leftVars ++ rightVars)}
      |  }
      |  if (shouldStop()) return;
      |}
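Editor's note (illustration only, not part of the patch): the recurring change above is that an ExprCode's `code` is emitted exactly once, via evaluateVariables / evaluateRequiredVariables, and then cleared, so downstream operators can keep referring to its `isNull`/`value` terms without re-evaluating the expression. Below is a minimal, self-contained Scala sketch of that pattern; it only mirrors the shape of Spark's ExprCode and evaluateVariables, and the row accessors inside the strings are made up for the example.

object DeferredEvalSketch {
  // Simplified stand-in for Spark's ExprCode: `code` holds the statements that
  // compute `value`/`isNull`; it is mutable so it can be blanked once emitted.
  final case class ExprCode(var code: String, isNull: String, value: String)

  // Emit each variable's code once, then clear it so a later consumer
  // cannot evaluate the same expression a second time.
  def evaluateVariables(variables: Seq[ExprCode]): String = {
    val evaluated = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
    variables.foreach(_.code = "")
    evaluated
  }

  def main(args: Array[String]): Unit = {
    val columns = Seq(
      ExprCode("int value1 = row.getInt(0);", "false", "value1"),
      ExprCode("boolean isNull2 = row.isNullAt(1);", "isNull2", "value2"))
    println(evaluateVariables(columns)) // prints both declarations
    println(evaluateVariables(columns)) // prints nothing: the code was already emitted
  }
}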