Skip to content

Commit

Permalink
SPARKNLP-868: Make `spark.driver.cores` override `local[*]` in start functions (
Browse files Browse the repository at this point in the history
#13894)

* Allow users to change Driver's cores at start

* Add unit test for driver cores in start function
  • Loading branch information
maziyarpanahi authored Jul 18, 2023
1 parent 31e5bde commit 893b693
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
13 changes: 12 additions & 1 deletion python/sparknlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,24 @@ def start(gpu=False,

if params is None:
params = {}
else:
if not isinstance(params, dict):
raise TypeError('params must be a dictionary like {"spark.executor.memory": "8G"}')

if '_instantiatedSession' in dir(SparkSession) and SparkSession._instantiatedSession is not None:
print('Warning::Spark Session already created, some configs may not take.')

# Determine the driver core count for the local master URL.
# Spark ignores "spark.driver.cores" when the master is "local[*]",
# so the requested count must be baked into the master string itself
# (used later as "local[{driver_cores}]"); "*" means "all available".
#
# Using dict.get avoids iterating over every key: the previous loop
# reset driver_cores back to "*" for each key that was NOT
# "spark.driver.cores", so the user's setting only took effect when it
# happened to be the last key in the params dict.
driver_cores = str(params.get("spark.driver.cores", "*"))

class SparkNLPConfig:

def __init__(self):
self.master, self.app_name = "local[*]", "Spark NLP"
self.master, self.app_name = "local[{}]".format(driver_cores), "Spark NLP"
self.serializer, self.serializer_max_buffer = "org.apache.spark.serializer.KryoSerializer", "2000M"
self.driver_max_result_size = "0"
# Spark NLP on CPU or GPU
Expand Down
8 changes: 7 additions & 1 deletion src/main/scala/com/johnsnowlabs/nlp/SparkNLP.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,18 @@ object SparkNLP {
val builder = SparkSession
.builder()
.appName("Spark NLP")
.master("local[*]")
.config("spark.driver.memory", memory)
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.kryoserializer.buffer.max", "2000M")
.config("spark.driver.maxResultSize", "0")

// get the set cores by users since local[*] will override spark.driver.cores if set
if (params.contains("spark.driver.cores")) {
builder.master("local[" + params("spark.driver.cores") + "]")
} else {
builder.master("local[*]")
}

val sparkNlpJar =
if (apple_silicon) MavenSparkSilicon
else if (aarch64) MavenSparkAarch64
Expand Down
8 changes: 5 additions & 3 deletions src/test/scala/com/johnsnowlabs/nlp/SparkNLPTestSpec.scala
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
package com.johnsnowlabs.nlp

import com.johnsnowlabs.tags.SlowTest
import com.johnsnowlabs.tags.FastTest
import com.johnsnowlabs.util.ConfigHelper.{awsJavaSdkVersion, hadoopAwsVersion}
import org.scalatest.flatspec.AnyFlatSpec

class SparkNLPTestSpec extends AnyFlatSpec {

behavior of "SparkNLPTestSpec"

it should "start with extra parameters" taggedAs SlowTest ignore {
it should "start with extra parameters" taggedAs FastTest in {
val extraParams: Map[String, String] = Map(
"spark.jars.packages" -> ("org.apache.hadoop:hadoop-aws:" + hadoopAwsVersion + ",com.amazonaws:aws-java-sdk:" + awsJavaSdkVersion),
"spark.hadoop.fs.s3a.path.style.access" -> "true")
"spark.hadoop.fs.s3a.path.style.access" -> "true",
"spark.driver.cores" -> "2")

val spark = SparkNLP.start(params = extraParams)

assert(spark.conf.get("spark.hadoop.fs.s3a.path.style.access") == "true")
assert(spark.conf.get("spark.master") == "local[2]")

Seq(
"com.johnsnowlabs.nlp:spark-nlp",
Expand Down

0 comments on commit 893b693

Please sign in to comment.