[SPARK-19372][SQL] Fix throwing a Java exception at df.filter() due to 64KB bytecode size limit

When an expression for `df.filter()` has many nodes (e.g. 400), the Java bytecode generated for it exceeds the JVM's 64KB limit on method size, Janino throws an exception, and the query fails.
This PR catches that exception, disables code generation for the offending expression, and continues execution by falling back to the interpreted path, `Expression.eval()`.
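A minimal sketch of the kind of query that hits this limit (the local session setup and column name are illustrative, not part of the patch):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, lit}

    val spark = SparkSession.builder().master("local[*]").appName("SPARK-19372-repro").getOrCreate()

    // One filter expression with ~400 Or/comparison nodes: before this patch,
    // the code generated for it could exceed the JVM's 64KB-per-method bytecode
    // limit, and the resulting exception failed the query.
    val df = spark.range(1).toDF("c0")
    val wide = (0 until 400).foldLeft(lit(false))((e, i) => e.or(col("c0") =!= i))
    df.filter(wide).count()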

Add a test case to `DataFrameSuite`.

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes apache#17087 from kiszk/SPARK-19372.
kiszk authored and poplav committed Aug 17, 2017
1 parent 267aca5 commit afde7df
Showing 7 changed files with 90 additions and 11 deletions.
@@ -123,6 +123,38 @@ object ExternalCatalogUtils {
     }
     escapePathName(col) + "=" + partitionString
   }
+
+  def prunePartitionsByFilter(
+      catalogTable: CatalogTable,
+      inputPartitions: Seq[CatalogTablePartition],
+      predicates: Seq[Expression],
+      defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
+    if (predicates.isEmpty) {
+      inputPartitions
+    } else {
+      val partitionSchema = catalogTable.partitionSchema
+      val partitionColumnNames = catalogTable.partitionColumnNames.toSet
+
+      val nonPartitionPruningPredicates = predicates.filterNot {
+        _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
+      }
+      if (nonPartitionPruningPredicates.nonEmpty) {
+        throw new AnalysisException("Expected only partition pruning predicates: " +
+          nonPartitionPruningPredicates)
+      }
+
+      val boundPredicate =
+        InterpretedPredicate.create(predicates.reduce(And).transform {
+          case att: AttributeReference =>
+            val index = partitionSchema.indexWhere(_.name == att.name)
+            BoundReference(index, partitionSchema(index).dataType, nullable = true)
+        })
+
+      inputPartitions.filter { p =>
+        boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId))
+      }
+    }
+  }
 }
 
 object CatalogUtils {
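To make the new binding step concrete, here is a standalone sketch of how a predicate over one partition column is rewritten to a BoundReference and evaluated, as prunePartitionsByFilter does above (the schema and values are invented for illustration):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.types.UTF8String

    val partitionSchema = StructType(Seq(StructField("dt", StringType)))
    val pred: Expression = EqualTo(AttributeReference("dt", StringType)(), Literal("2017"))

    // Rewrite the attribute to its ordinal in the partition schema, then wrap
    // the bound expression in an InterpretedPredicate (no codegen involved).
    val bound = InterpretedPredicate.create(pred.transform {
      case a: AttributeReference =>
        val i = partitionSchema.indexWhere(_.name == a.name)
        BoundReference(i, partitionSchema(i).dataType, nullable = true)
    })

    assert(bound.eval(InternalRow(UTF8String.fromString("2017"))))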
@@ -26,7 +26,10 @@ import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
 import com.google.common.cache.{CacheBuilder, CacheLoader}
-import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler}
+import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
+import org.apache.commons.lang3.exception.ExceptionUtils
+import org.codehaus.commons.compiler.CompileException
+import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler}
 import org.codehaus.janino.util.ClassFile
 import scala.language.existentials

@@ -901,8 +904,14 @@ object CodeGenerator extends Logging {
   /**
    * Compile the Java source code into a Java class, using Janino.
    */
-  def compile(code: CodeAndComment): GeneratedClass = {
+  def compile(code: CodeAndComment): GeneratedClass = try {
     cache.get(code)
+  } catch {
+    // Cache.get() may wrap the original exception. See the following URL
+    // http://google.github.io/guava/releases/14.0/api/docs/com/google/common/cache/
+    // Cache.html#get(K,%20java.util.concurrent.Callable)
+    case e @ (_: UncheckedExecutionException | _: ExecutionError) =>
+      throw e.getCause
   }
 
   /**
@@ -950,10 +959,14 @@
       evaluator.cook("generated.java", code.body)
       recordCompilationStats(evaluator)
     } catch {
-      case e: Exception =>
+      case e: JaninoRuntimeException =>
         val msg = s"failed to compile: $e\n$formatted"
         logError(msg, e)
-        throw new Exception(msg, e)
+        throw new JaninoRuntimeException(msg, e)
+      case e: CompileException =>
+        val msg = s"failed to compile: $e\n$formatted"
+        logError(msg, e)
+        throw new CompileException(msg, e.getLocation)
     }
     evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]
   }
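The unwrap added to compile() above is needed because Guava's LoadingCache wraps exceptions thrown by its loader. A standalone sketch of that behavior (the key and value types are arbitrary):

    import com.google.common.cache.{CacheBuilder, CacheLoader}
    import com.google.common.util.concurrent.UncheckedExecutionException

    val cache = CacheBuilder.newBuilder().build(new CacheLoader[String, Integer] {
      // The loader throws an unchecked exception, as a failed Janino compilation does.
      override def load(key: String): Integer = throw new IllegalStateException(s"cannot load $key")
    })

    try {
      cache.get("some-key")
    } catch {
      // Guava wraps the loader's exception; the original sits in getCause, which
      // is why compile() rethrows e.getCause instead of the wrapper.
      case e: UncheckedExecutionException =>
        assert(e.getCause.isInstanceOf[IllegalStateException])
    }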
@@ -20,20 +20,22 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 
 object InterpretedPredicate {
-  def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) =
+  def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate =
     create(BindReferences.bindReference(expression, inputSchema))
 
-  def create(expression: Expression): (InternalRow => Boolean) = {
-    (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
-  }
+  def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression)
 }
 
+case class InterpretedPredicate(expression: Expression) extends BasePredicate {
+  override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean]
+}
+
 /**
  * An [[Expression]] that returns a boolean value.
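The change above turns create() from returning a bare (InternalRow => Boolean) function into returning an InterpretedPredicate, so call sites move from predicate(row) to predicate.eval(row). A minimal sketch of the new shape (the attribute and rows are invented for illustration):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.types.UTF8String

    val dt = AttributeReference("dt", StringType)()
    // create() binds the attribute to ordinal 0 of the given input schema.
    val p: InterpretedPredicate = InterpretedPredicate.create(EqualTo(dt, Literal("2017-08-17")), Seq(dt))

    assert(p.eval(InternalRow(UTF8String.fromString("2017-08-17"))))
    assert(!p.eval(InternalRow(UTF8String.fromString("2017-08-18"))))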
@@ -22,6 +22,9 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.ExecutionContext
 
+import org.codehaus.commons.compiler.CompileException
+import org.codehaus.janino.JaninoRuntimeException
+
 import org.apache.spark.{broadcast, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
@@ -353,9 +356,27 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination)
   }
 
+  private def genInterpretedPredicate(
+      expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = {
+    val str = expression.toString
+    val logMessage = if (str.length > 256) {
+      str.substring(0, 256 - 3) + "..."
+    } else {
+      str
+    }
+    logWarning(s"Codegen disabled for this expression:\n $logMessage")
+    InterpretedPredicate.create(expression, inputSchema)
+  }
+
   protected def newPredicate(
       expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = {
-    GeneratePredicate.generate(expression, inputSchema)
+    try {
+      GeneratePredicate.generate(expression, inputSchema)
+    } catch {
+      case e @ (_: JaninoRuntimeException | _: CompileException)
+          if sqlContext == null || sqlContext.conf.wholeStageFallback =>
+        genInterpretedPredicate(expression, inputSchema)
+    }
   }
 
   protected def newOrdering(
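The fallback above is gated by sqlContext.conf.wholeStageFallback. A sketch of toggling it; the conf key shown is the one I believe backs wholeStageFallback in this line of Spark, so treat it as an assumption:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // Default: on compile failure, log "Codegen disabled for this expression"
    // and evaluate the predicate with InterpretedPredicate instead.
    spark.conf.set("spark.sql.codegen.fallback", "true")

    // Disabled: the JaninoRuntimeException / CompileException propagates and
    // the query fails, as it did before this patch.
    spark.conf.set("spark.sql.codegen.fallback", "false")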
@@ -171,7 +171,7 @@ abstract class PartitioningAwareFileIndex(
       })
 
       val selected = partitions.filter {
-        case PartitionPath(values, _) => boundPredicate(values)
+        case PartitionPath(values, _) => boundPredicate.eval(values)
       }
       logInfo {
         val total = partitions.length
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1765,4 +1765,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       .filter($"x1".isNotNull || !$"y".isin("a!"))
       .count
   }
+
+  test("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") {
+    val N = 400
+    val rows = Seq(Row.fromSeq(Seq.fill(N)("string")))
+    val schema = StructType(Seq.tabulate(N)(i => StructField(s"_c$i", StringType)))
+    val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
+
+    val filter = (0 until N)
+      .foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string"))
+    df.filter(filter).count
+  }
 }
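For intuition, the foldLeft in this test chains one Or node per column; a scaled-down sketch of the expression it builds (N = 3 instead of 400, column names invented):

    import org.apache.spark.sql.functions.{col, lit}

    // Builds ((false OR (_c0 != "string")) OR (_c1 != "string")) OR (_c2 != "string").
    // At N = 400, this single expression is what overflows the 64KB method limit.
    val small = (0 until 3).foldLeft(lit(false))((e, i) => e.or(col(s"_c$i") =!= "string"))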
@@ -108,7 +108,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
           // `Cast`ed values are always of internal types (e.g. UTF8String instead of String)
           Cast(Literal(value), dataType).eval()
         })
-      }.filter(predicate).map(projection)
+      }.filter(predicate.eval).map(projection)
 
     // Appends partition values
     val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes