[SPARK-2557] fix LOCAL_N_REGEX in createTaskScheduler and make local-n and local-n-failures consistent #1464

Closed · wants to merge 2 commits

10 changes: 7 additions & 3 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1451,9 +1451,9 @@ object SparkContext extends Logging {
   /** Creates a task scheduler based on a given master URL. Extracted for testing. */
   private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = {
     // Regular expression used for local[N] and local[*] master formats
-    val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r
+    val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
     // Regular expression for local[N, maxRetries], used in tests with failing tasks
-    val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
+    val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
     // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
     val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
     // Regular expression for connecting to Spark deploy clusters
@@ -1483,8 +1483,12 @@ object SparkContext extends Logging {
       scheduler

     case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
+      def localCpuCount = Runtime.getRuntime.availableProcessors()
+      // local[*, M] means the number of cores on the computer with M failures
+      // local[N, M] means exactly N threads with M failures
+      val threadCount = if (threads == "*") localCpuCount else threads.toInt
       val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
-      val backend = new LocalBackend(scheduler, threads.toInt)
+      val backend = new LocalBackend(scheduler, threadCount)
       scheduler.initialize(backend)
       scheduler

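Why the regex change matters: the old character class [0-9\*]+ also matched mixed tokens such as 2* or 3*1, so a malformed master URL like local[2*] was silently accepted as local mode instead of being rejected. A minimal standalone sketch contrasting the two patterns (the object and value names are illustrative, not part of the patch):

object LocalRegexCheck extends App {
  val oldLocalN = """local\[([0-9\*]+)\]""".r   // before this patch
  val newLocalN = """local\[([0-9]+|\*)\]""".r  // after this patch

  // Old pattern: digits and '*' may mix, so "local[2*]" wrongly parses.
  assert(oldLocalN.unapplySeq("local[2*]").isDefined)

  // New pattern: only all-digits or a lone '*' is accepted.
  assert(newLocalN.unapplySeq("local[2*]").isEmpty)
  assert(newLocalN.unapplySeq("local[8]") == Some(List("8")))
  assert(newLocalN.unapplySeq("local[*]") == Some(List("*")))
}
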
23 changes: 23 additions & 0 deletions core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -68,6 +68,15 @@ class SparkContextSchedulerCreationSuite
     }
   }

+  test("local-*-n-failures") {
+    val sched = createTaskScheduler("local[* ,2]")
+    assert(sched.maxTaskFailures === 2)
+    sched.backend match {
+      case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors())
+      case _ => fail()
+    }
+  }
+
   test("local-n-failures") {
     val sched = createTaskScheduler("local[4, 2]")
     assert(sched.maxTaskFailures === 2)
@@ -77,6 +86,20 @@ class SparkContextSchedulerCreationSuite
     }
   }

+  test("bad-local-n") {
+    val e = intercept[SparkException] {
+      createTaskScheduler("local[2*]")
+    }
+    assert(e.getMessage.contains("Could not parse Master URL"))
+  }
+
+  test("bad-local-n-failures") {
+    val e = intercept[SparkException] {
+      createTaskScheduler("local[2*,4]")
+    }
+    assert(e.getMessage.contains("Could not parse Master URL"))
+  }
+
   test("local-default-parallelism") {
     val defaultParallelism = System.getProperty("spark.default.parallelism")
     System.setProperty("spark.default.parallelism", "16")
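Taken together, the change makes local[*, M] behave like local[*] with a retry budget of M failures per task. A hypothetical driver-side usage sketch (the application name and job are made up; assumes a Spark build that includes this patch):

import org.apache.spark.{SparkConf, SparkContext}

object LocalStarFailuresExample extends App {
  // Run on all available cores; each task may fail up to 2 times before
  // the job is aborted. Before this patch, that combination required an
  // explicit thread count, e.g. "local[4, 2]".
  val conf = new SparkConf()
    .setAppName("local-star-failures-demo")
    .setMaster("local[*, 2]")
  val sc = new SparkContext(conf)

  // Trivial job just to exercise the local scheduler.
  println(sc.parallelize(1 to 100).sum())

  sc.stop()
}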