[SPARK-20548][FLAKY-TEST] share one REPL instance among REPL test cases
## What changes were proposed in this pull request?

`ReplSuite.newProductSeqEncoder with REPL defined class` was flaky and frequently threw OOM exceptions. Analyzing the heap dump showed the reason: each test case in `ReplSuite` creates its own REPL instance, which in turn creates a classloader and loads a lot of classes related to `SparkContext`. For more details, see #17833 (comment).

In this PR, we create a new test suite, `SingletonReplSuite`, which shares one REPL instance among all of its test cases. We then move most of the tests from `ReplSuite` to `SingletonReplSuite`, so that we avoid creating many REPL instances and reduce the memory footprint.
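
For orientation, here is a minimal, hypothetical sketch of the shared-REPL pattern (not the actual `SingletonReplSuite` code): one interpreter thread is started in `beforeAll`, every test writes its commands into a shared pipe and reads the output back, and `afterAll` shuts the interpreter down. The class name, the `runCommands` helper, and the echo loop standing in for the real Spark REPL are all illustrative.

```scala
import java.io.{BufferedReader, PipedReader, PipedWriter, StringWriter}

import org.scalatest.{BeforeAndAfterAll, FunSuite}

// Hypothetical sketch: one interpreter shared by every test case in the suite.
// The echo loop below stands in for the real Spark REPL; only the wiring is shown.
class SharedReplSketchSuite extends FunSuite with BeforeAndAfterAll {

  private var in: PipedWriter = _
  private var out: StringWriter = _
  private var replThread: Thread = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    in = new PipedWriter()
    out = new StringWriter()
    val reader = new BufferedReader(new PipedReader(in))
    // Start the interpreter exactly once; every test writes commands into `in`
    // and reads the accumulated output from `out`.
    replThread = new Thread(new Runnable {
      override def run(): Unit = {
        // Placeholder for launching the real REPL against `reader`/`out`.
        var line = reader.readLine()
        while (line != null) {
          out.write(line + "\n")
          line = reader.readLine()
        }
      }
    })
    replThread.start()
  }

  override def afterAll(): Unit = {
    try {
      in.close() // EOF on the command pipe lets the interpreter thread exit
      replThread.join()
    } finally {
      super.afterAll()
    }
  }

  // Feed commands to the shared interpreter and return the output seen so far.
  // (A real suite would block until the interpreter finishes the commands.)
  private def runCommands(cmds: String): String = {
    in.write(cmds)
    in.flush()
    Thread.sleep(100) // crude synchronization, for illustration only
    out.toString
  }

  test("commands run against the single shared interpreter") {
    val output = runCommands(
      """
        |val accum = sc.longAccumulator
        |sc.parallelize(1 to 10).foreach(x => accum.add(x))
      """.stripMargin)
    assert(output.contains("accum"))
  }
}
```

The point is simply that the expensive setup (interpreter, classloader, `SparkContext`) happens once per suite instead of once per test.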

## How was this patch tested?

Test-only change.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17844 from cloud-fan/flaky-test.
cloud-fan committed May 9, 2017
1 parent 181261a commit f561a76
Showing 4 changed files with 412 additions and 279 deletions.
@@ -68,7 +68,7 @@ object Main extends Logging {

if (!hasErrors) {
interp.process(settings) // Repl starts and goes in loop of R.E.P.L
Option(sparkContext).map(_.stop)
Option(sparkContext).foreach(_.stop)
}
}

@@ -86,15 +86,8 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
echo("Type :help for more information.")
}

/** Add repl commands that needs to be blocked. e.g. reset */
private val blockedCommands = Set[String]()

/** Standard commands */
lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] =
standardCommands.filter(cmd => !blockedCommands(cmd.name))

/** Available commands */
override def commands: List[LoopCommand] = sparkStandardCommands
override def commands: List[LoopCommand] = standardCommands

/**
* We override `loadFiles` because we need to initialize Spark *before* the REPL
272 changes: 2 additions & 270 deletions repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -21,12 +21,12 @@ import java.io._
import java.net.URLClassLoader

import scala.collection.mutable.ArrayBuffer
import org.apache.commons.lang3.StringEscapeUtils

import org.apache.log4j.{Level, LogManager}

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.util.Utils

class ReplSuite extends SparkFunSuite {

@@ -148,71 +148,6 @@ class ReplSuite extends SparkFunSuite {
}
}

test("simple foreach with accumulator") {
val output = runInterpreter("local",
"""
|val accum = sc.longAccumulator
|sc.parallelize(1 to 10).foreach(x => accum.add(x))
|accum.value
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res1: Long = 55", output)
}

test("external vars") {
val output = runInterpreter("local",
"""
|var v = 7
|sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_)
|v = 10
|sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_)
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
}

test("external classes") {
val output = runInterpreter("local",
"""
|class C {
|def foo = 5
|}
|sc.parallelize(1 to 10).map(x => (new C).foo).collect().reduceLeft(_+_)
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 50", output)
}

test("external functions") {
val output = runInterpreter("local",
"""
|def double(x: Int) = x + x
|sc.parallelize(1 to 10).map(x => double(x)).collect().reduceLeft(_+_)
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 110", output)
}

test("external functions that access vars") {
val output = runInterpreter("local",
"""
|var v = 7
|def getV() = v
|sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
|v = 10
|sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
}

test("broadcast vars") {
// Test that the value that a broadcast var had when it was created is used,
// even if that variable is then modified in the driver program
@@ -231,124 +166,6 @@
assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
}

test("interacting with files") {
val tempDir = Utils.createTempDir()
val out = new FileWriter(tempDir + "/input")
out.write("Hello world!\n")
out.write("What's up?\n")
out.write("Goodbye\n")
out.close()
val output = runInterpreter("local",
"""
|var file = sc.textFile("%s").cache()
|file.count()
|file.count()
|file.count()
""".stripMargin.format(StringEscapeUtils.escapeJava(
tempDir.getAbsolutePath + File.separator + "input")))
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Long = 3", output)
assertContains("res1: Long = 3", output)
assertContains("res2: Long = 3", output)
Utils.deleteRecursively(tempDir)
}

test("local-cluster mode") {
val output = runInterpreter("local-cluster[1,1,1024]",
"""
|var v = 7
|def getV() = v
|sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
|v = 10
|sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
|var array = new Array[Int](5)
|val broadcastArray = sc.broadcast(array)
|sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
|array(0) = 5
|sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
}

test("SPARK-1199 two instances of same class don't type check.") {
val output = runInterpreter("local",
"""
|case class Sum(exp: String, exp2: String)
|val a = Sum("A", "B")
|def b(a: Sum): String = a match { case Sum(_, _) => "Found Sum" }
|b(a)
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("SPARK-2452 compound statements.") {
val output = runInterpreter("local",
"""
|val x = 4 ; def f() = x
|f()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("SPARK-2576 importing implicits") {
// We need to use local-cluster to test this case.
val output = runInterpreter("local-cluster[1,1,1024]",
"""
|import spark.implicits._
|case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
|
|// Test Dataset Serialization in the REPL
|Seq(TestCaseClass(1)).toDS().collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("Datasets and encoders") {
val output = runInterpreter("local",
"""
|import org.apache.spark.sql.functions._
|import org.apache.spark.sql.{Encoder, Encoders}
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|val simpleSum = new Aggregator[Int, Int, Int] {
| def zero: Int = 0 // The initial value.
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
| def finish(b: Int) = b // Return the final result.
| def bufferEncoder: Encoder[Int] = Encoders.scalaInt
| def outputEncoder: Encoder[Int] = Encoders.scalaInt
|}.toColumn
|
|val ds = Seq(1, 2, 3, 4).toDS()
|ds.select(simpleSum).collect
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("SPARK-2632 importing a method from non serializable class and not using it.") {
val output = runInterpreter("local-cluster[1,1,1024]",
"""
|class TestClass() { def testMethod = 3 }
|val t = new TestClass
|import t.testMethod
|case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

if (System.getenv("MESOS_NATIVE_JAVA_LIBRARY") != null) {
test("running on Mesos") {
val output = runInterpreter("localquiet",
@@ -373,52 +190,6 @@
}
}

test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
|case class Foo(i: Int)
|val ret = sc.parallelize((1 to 100).map(Foo), 10).collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("ret: Array[Foo] = Array(Foo(1),", output)
}

test("collecting objects of class defined in repl - shuffling") {
val output = runInterpreter("local-cluster[1,1,1024]",
"""
|case class Foo(i: Int)
|val list = List((1, Foo(1)), (1, Foo(2)))
|val ret = sc.parallelize(list).groupByKey().collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output)
}

test("replicating blocks of object with class defined in repl") {
val output = runInterpreter("local-cluster[2,1,1024]",
"""
|val timeout = 60000 // 60 seconds
|val start = System.currentTimeMillis
|while(sc.getExecutorStorageStatus.size != 3 &&
| (System.currentTimeMillis - start) < timeout) {
| Thread.sleep(10)
|}
|if (System.currentTimeMillis - start >= timeout) {
| throw new java.util.concurrent.TimeoutException("Executors were not up in 60 seconds")
|}
|import org.apache.spark.storage.StorageLevel._
|case class Foo(i: Int)
|val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_AND_DISK_2)
|ret.count()
|sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains(": Int = 20", output)
}

test("line wrapper only initialized once when used as encoder outer scope") {
val output = runInterpreter("local",
"""
@@ -446,43 +217,4 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("should clone and clean line object in ClosureCleaner") {
val output = runInterpreterInPasteMode("local-cluster[1,4,4096]",
"""
|import org.apache.spark.rdd.RDD
|
|val lines = sc.textFile("pom.xml")
|case class Data(s: String)
|val dataRDD = lines.map(line => Data(line.take(3)))
|dataRDD.cache.count
|val repartitioned = dataRDD.repartition(dataRDD.partitions.size)
|repartitioned.cache.count
|
|def getCacheSize(rdd: RDD[_]) = {
| sc.getRDDStorageInfo.filter(_.id == rdd.id).map(_.memSize).sum
|}
|val cacheSize1 = getCacheSize(dataRDD)
|val cacheSize2 = getCacheSize(repartitioned)
|
|// The cache size of dataRDD and the repartitioned one should be similar.
|val deviation = math.abs(cacheSize2 - cacheSize1).toDouble / cacheSize1
|assert(deviation < 0.2,
| s"deviation too large: $deviation, first size: $cacheSize1, second size: $cacheSize2")
""".stripMargin)
assertDoesNotContain("AssertionError", output)
assertDoesNotContain("Exception", output)
}

// TODO: [SPARK-20548] Fix and re-enable
ignore("newProductSeqEncoder with REPL defined class") {
val output = runInterpreterInPasteMode("local-cluster[1,4,4096]",
"""
|case class Click(id: Int)
|spark.implicits.newProductSeqEncoder[Click]
""".stripMargin)

assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}
}
