diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/Utils.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/Utils.scala
index d3e3308e3..de1beaea8 100644
--- a/mllib-dal/src/main/scala/com/intel/oap/mllib/Utils.scala
+++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/Utils.scala
@@ -28,18 +28,24 @@ object Utils {
 
   def isOAPEnabled(): Boolean = {
     val sc = SparkSession.active.sparkContext
-    return sc.getConf.getBoolean("spark.oap.mllib.enabled", true)
+    val dynamic = sc.getConf.getBoolean("spark.dynamicAllocation.enabled", false)
+    val isoap = sc.getConf.getBoolean("spark.oap.mllib.enabled", true)
+    if (isoap && dynamic) {
+      throw new Exception(
+        s"OAP MLlib does not support dynamic allocation, " +
+          s"spark.dynamicAllocation.enabled should be set to false")
+    }
+    isoap
   }
 
   def getOneCCLIPPort(data: RDD[_]): String = {
     val executorIPAddress = Utils.sparkFirstExecutorIP(data.sparkContext)
-    val kvsIP = data.sparkContext.getConf.get("spark.oap.mllib.oneccl.kvs.ip",
-      executorIPAddress)
+    val kvsIP = data.sparkContext.getConf.get("spark.oap.mllib.oneccl.kvs.ip", executorIPAddress)
     // TODO: right now we use a configured port, will optimize auto port detection
     // val kvsPortDetected = Utils.checkExecutorAvailPort(data, kvsIP)
     val kvsPortDetected = 3000
-    val kvsPort = data.sparkContext.getConf.getInt("spark.oap.mllib.oneccl.kvs.port",
-      kvsPortDetected)
+    val kvsPort =
+      data.sparkContext.getConf.getInt("spark.oap.mllib.oneccl.kvs.port", kvsPortDetected)
     kvsIP + "_" + kvsPort
   }
 
@@ -96,20 +102,23 @@ object Utils {
 
   def checkExecutorAvailPort(data: RDD[_], localIP: String): Int = {
     if (localIP == "127.0.0.1" || localIP == "127.0.1.1") {
-      println(s"\nOneCCL: Error: doesn't support loopback IP ${localIP}, " +
-        s"please assign IP address to your host.\n")
+      println(
+        s"\nOneCCL: Error: doesn't support loopback IP ${localIP}, " +
+          s"please assign IP address to your host.\n")
       System.exit(-1)
     }
 
     val sc = data.sparkContext
-    val result = data.mapPartitions { p =>
-      val port = OneCCL.getAvailPort(localIP)
-      if (port != -1) {
-        Iterator(port)
-      } else {
-        Iterator()
+    val result = data
+      .mapPartitions { p =>
+        val port = OneCCL.getAvailPort(localIP)
+        if (port != -1) {
+          Iterator(port)
+        } else {
+          Iterator()
+        }
       }
-    }.collect()
+      .collect()
 
     result(0)
   }
@@ -127,9 +136,11 @@ object Utils {
     // check workers' platform compatibility
     val executor_num = Utils.sparkExecutorNum(sc)
     val data = sc.parallelize(1 to executor_num, executor_num)
-    val result = data.mapPartitions { p =>
-      Iterator(OneDAL.cCheckPlatformCompatibility())
-    }.collect()
+    val result = data
+      .mapPartitions { p =>
+        Iterator(OneDAL.cCheckPlatformCompatibility())
+      }
+      .collect()
 
     result.forall(_ == true)
   }
diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 74de47a85..e2c61ef35 100644
--- a/mllib-dal/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib-dal/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -1,3 +1,4 @@
+// scalastyle:off
 /*
  * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
@@ -14,10 +15,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+// scalastyle:on
 
 package org.apache.spark.ml.classification
 
 import com.intel.oap.mllib.classification.NaiveBayesShim
+
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.ParamMap