[SPARK-20548][FLAKY-TEST] share one REPL instance among REPL test cases #17844

Closed · wants to merge 2 commits
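For orientation, here is a minimal, hypothetical sketch of the pattern the title describes: start one REPL for the whole suite in `beforeAll`, feed every test case to that same instance, and tear it down once in `afterAll`, instead of spawning a fresh interpreter and SparkContext per test. Sharing the interpreter trades per-test isolation for speed and avoids repeatedly starting and stopping REPL/SparkContext pairs. Names such as `SharedReplSuiteSketch`, `FakeRepl`, and `run` are illustrative stand-ins, not identifiers from this PR.

```scala
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite

// Stand-in for a SparkILoop-backed interpreter so the sketch is self-contained.
class FakeRepl {
  def interpret(code: String): String = s"interpreted ${code.split('\n').length} lines"
  def close(): Unit = ()
}

class SharedReplSuiteSketch extends AnyFunSuite with BeforeAndAfterAll {

  // Started exactly once and shared by every test case in the suite.
  private var repl: FakeRepl = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    repl = new FakeRepl()
  }

  override def afterAll(): Unit = {
    try repl.close() // shuts down the shared interpreter (and, in the real suite, its SparkContext)
    finally super.afterAll()
  }

  private def run(code: String): String = repl.interpret(code)

  test("simple foreach with accumulator") {
    val out = run(
      """
        |val accum = sc.longAccumulator
        |sc.parallelize(1 to 10).foreach(x => accum.add(x))
        |accum.value
      """.stripMargin)
    assert(!out.contains("error:"))
  }
}
```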
@@ -68,7 +68,7 @@ object Main extends Logging {

if (!hasErrors) {
interp.process(settings) // Repl starts and goes in loop of R.E.P.L
Option(sparkContext).map(_.stop)
Option(sparkContext).foreach(_.stop)
}
}
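An editorial aside on the one-line change above (not part of the diff): `Option.map` wraps the closure's result in a new `Option` that is immediately discarded here, whereas `foreach` exists purely for side effects and states the intent directly. A small self-contained illustration:

```scala
// map allocates an Option[Unit] that nobody reads; foreach just runs the effect.
val maybeCtx: Option[String] = Some("sparkContext")

val ignored: Option[Unit] = maybeCtx.map(ctx => println(s"stopping $ctx")) // result discarded
maybeCtx.foreach(ctx => println(s"stopping $ctx"))                         // side effect only
```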

@@ -86,15 +86,8 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
echo("Type :help for more information.")
}

/** Add repl commands that needs to be blocked. e.g. reset */
private val blockedCommands = Set[String]()

/** Standard commands */
lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] =
standardCommands.filter(cmd => !blockedCommands(cmd.name))

/** Available commands */
override def commands: List[LoopCommand] = sparkStandardCommands
override def commands: List[LoopCommand] = standardCommands

/**
* We override `loadFiles` because we need to initialize Spark *before* the REPL
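A brief note on the hunk above (my reading, not text from the PR): `blockedCommands` was an empty set, so the filter over `standardCommands` never removed anything, and overriding `commands` with `standardCommands` directly is behaviour-preserving. A tiny stand-alone check of that equivalence, using simplified stand-in types:

```scala
// With nothing blocked, the filter is the identity, so the indirection can go.
case class LoopCommand(name: String)

val standardCommands = List(LoopCommand("help"), LoopCommand("reset"), LoopCommand("quit"))
val blockedCommands  = Set.empty[String]

val sparkStandardCommands = standardCommands.filter(cmd => !blockedCommands(cmd.name))
assert(sparkStandardCommands == standardCommands) // identical when the blocked set is empty
```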
272 changes: 2 additions & 270 deletions repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -21,12 +21,12 @@ import java.io._
import java.net.URLClassLoader

import scala.collection.mutable.ArrayBuffer
import org.apache.commons.lang3.StringEscapeUtils

import org.apache.log4j.{Level, LogManager}

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.util.Utils

class ReplSuite extends SparkFunSuite {

@@ -148,71 +148,6 @@ class ReplSuite extends SparkFunSuite {
}
}

test("simple foreach with accumulator") {
val output = runInterpreter("local",
"""
|val accum = sc.longAccumulator
|sc.parallelize(1 to 10).foreach(x => accum.add(x))
|accum.value
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res1: Long = 55", output)
}

test("external vars") {
val output = runInterpreter("local",
"""
|var v = 7
|sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_)
|v = 10
|sc.parallelize(1 to 10).map(x => v).collect().reduceLeft(_+_)
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
}

test("external classes") {
val output = runInterpreter("local",
"""
|class C {
|def foo = 5
|}
|sc.parallelize(1 to 10).map(x => (new C).foo).collect().reduceLeft(_+_)
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 50", output)
}

test("external functions") {
val output = runInterpreter("local",
"""
|def double(x: Int) = x + x
|sc.parallelize(1 to 10).map(x => double(x)).collect().reduceLeft(_+_)
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 110", output)
}

test("external functions that access vars") {
val output = runInterpreter("local",
"""
|var v = 7
|def getV() = v
|sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
|v = 10
|sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
}

test("broadcast vars") {
// Test that the value that a broadcast var had when it was created is used,
// even if that variable is then modified in the driver program
@@ -231,124 +166,6 @@ class ReplSuite extends SparkFunSuite {
assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
}

test("interacting with files") {
val tempDir = Utils.createTempDir()
val out = new FileWriter(tempDir + "/input")
out.write("Hello world!\n")
out.write("What's up?\n")
out.write("Goodbye\n")
out.close()
val output = runInterpreter("local",
"""
|var file = sc.textFile("%s").cache()
|file.count()
|file.count()
|file.count()
""".stripMargin.format(StringEscapeUtils.escapeJava(
tempDir.getAbsolutePath + File.separator + "input")))
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Long = 3", output)
assertContains("res1: Long = 3", output)
assertContains("res2: Long = 3", output)
Utils.deleteRecursively(tempDir)
}

test("local-cluster mode") {
val output = runInterpreter("local-cluster[1,1,1024]",
"""
|var v = 7
|def getV() = v
|sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
|v = 10
|sc.parallelize(1 to 10).map(x => getV()).collect().reduceLeft(_+_)
|var array = new Array[Int](5)
|val broadcastArray = sc.broadcast(array)
|sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
|array(0) = 5
|sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
}

test("SPARK-1199 two instances of same class don't type check.") {
val output = runInterpreter("local",
"""
|case class Sum(exp: String, exp2: String)
|val a = Sum("A", "B")
|def b(a: Sum): String = a match { case Sum(_, _) => "Found Sum" }
|b(a)
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("SPARK-2452 compound statements.") {
val output = runInterpreter("local",
"""
|val x = 4 ; def f() = x
|f()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("SPARK-2576 importing implicits") {
// We need to use local-cluster to test this case.
val output = runInterpreter("local-cluster[1,1,1024]",
"""
|import spark.implicits._
|case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
|
|// Test Dataset Serialization in the REPL
|Seq(TestCaseClass(1)).toDS().collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("Datasets and encoders") {
val output = runInterpreter("local",
"""
|import org.apache.spark.sql.functions._
|import org.apache.spark.sql.{Encoder, Encoders}
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|val simpleSum = new Aggregator[Int, Int, Int] {
| def zero: Int = 0 // The initial value.
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
| def finish(b: Int) = b // Return the final result.
| def bufferEncoder: Encoder[Int] = Encoders.scalaInt
| def outputEncoder: Encoder[Int] = Encoders.scalaInt
|}.toColumn
|
|val ds = Seq(1, 2, 3, 4).toDS()
|ds.select(simpleSum).collect
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("SPARK-2632 importing a method from non serializable class and not using it.") {
val output = runInterpreter("local-cluster[1,1,1024]",
"""
|class TestClass() { def testMethod = 3 }
|val t = new TestClass
|import t.testMethod
|case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

if (System.getenv("MESOS_NATIVE_JAVA_LIBRARY") != null) {
test("running on Mesos") {
val output = runInterpreter("localquiet",
@@ -373,52 +190,6 @@ class ReplSuite extends SparkFunSuite {
}
}

test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
|case class Foo(i: Int)
|val ret = sc.parallelize((1 to 100).map(Foo), 10).collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("ret: Array[Foo] = Array(Foo(1),", output)
}

test("collecting objects of class defined in repl - shuffling") {
val output = runInterpreter("local-cluster[1,1,1024]",
"""
|case class Foo(i: Int)
|val list = List((1, Foo(1)), (1, Foo(2)))
|val ret = sc.parallelize(list).groupByKey().collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output)
}

test("replicating blocks of object with class defined in repl") {
val output = runInterpreter("local-cluster[2,1,1024]",
"""
|val timeout = 60000 // 60 seconds
|val start = System.currentTimeMillis
|while(sc.getExecutorStorageStatus.size != 3 &&
| (System.currentTimeMillis - start) < timeout) {
| Thread.sleep(10)
|}
|if (System.currentTimeMillis - start >= timeout) {
| throw new java.util.concurrent.TimeoutException("Executors were not up in 60 seconds")
|}
|import org.apache.spark.storage.StorageLevel._
|case class Foo(i: Int)
|val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_AND_DISK_2)
|ret.count()
|sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains(": Int = 20", output)
}

test("line wrapper only initialized once when used as encoder outer scope") {
Member commented: Why don't you move this test?

Contributor Author replied: This test only works in local mode...

val output = runInterpreter("local",
"""
@@ -446,43 +217,4 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("should clone and clean line object in ClosureCleaner") {
val output = runInterpreterInPasteMode("local-cluster[1,4,4096]",
"""
|import org.apache.spark.rdd.RDD
|
|val lines = sc.textFile("pom.xml")
|case class Data(s: String)
|val dataRDD = lines.map(line => Data(line.take(3)))
|dataRDD.cache.count
|val repartitioned = dataRDD.repartition(dataRDD.partitions.size)
|repartitioned.cache.count
|
|def getCacheSize(rdd: RDD[_]) = {
| sc.getRDDStorageInfo.filter(_.id == rdd.id).map(_.memSize).sum
|}
|val cacheSize1 = getCacheSize(dataRDD)
|val cacheSize2 = getCacheSize(repartitioned)
|
|// The cache size of dataRDD and the repartitioned one should be similar.
|val deviation = math.abs(cacheSize2 - cacheSize1).toDouble / cacheSize1
|assert(deviation < 0.2,
| s"deviation too large: $deviation, first size: $cacheSize1, second size: $cacheSize2")
""".stripMargin)
assertDoesNotContain("AssertionError", output)
assertDoesNotContain("Exception", output)
}

// TODO: [SPARK-20548] Fix and re-enable
ignore("newProductSeqEncoder with REPL defined class") {
val output = runInterpreterInPasteMode("local-cluster[1,4,4096]",
"""
|case class Click(id: Int)
|spark.implicits.newProductSeqEncoder[Click]
""".stripMargin)

assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}
}