Skip to content

Commit

Permalink
[ML-187] Support spark 3.1.3 and 3.2.1 and support CDH (#197)
Browse files Browse the repository at this point in the history
* support spark 3.1.3

Signed-off-by: minmingzhu <minming.zhu@intel.com>

* support spark 3.2.1

Signed-off-by: minmingzhu <minming.zhu@intel.com>

* support CDH spark

Signed-off-by: minmingzhu <minming.zhu@intel.com>

* update

Signed-off-by: minmingzhu <minming.zhu@intel.com>
  • Loading branch information
minmingzhu authored Apr 7, 2022
1 parent 0f2f4bc commit a249d28
Show file tree
Hide file tree
Showing 19 changed files with 61 additions and 38 deletions.
15 changes: 13 additions & 2 deletions mllib-dal/src/main/scala/com/intel/oap/mllib/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ package com.intel.oap.mllib
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}

import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
import java.net.InetAddress

object Utils {
Expand Down Expand Up @@ -155,4 +154,16 @@ object Utils {
// Return executor number (exclude driver)
executorInfos.length - 1
}
def getSparkVersion(): String = {
  // CDH builds report versions like "3.1.1.3.1.7290.5-2"; the upstream
  // Spark version is just the first three dot-separated components.
  // Vanilla Spark versions ("3.2.1") are returned unchanged.
  val parts = SPARK_VERSION.split("\\.")
  if (parts.length > 3) parts.take(3).mkString(".") else SPARK_VERSION
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package com.intel.oap.mllib.classification

import com.intel.oap.mllib.Utils

import org.apache.spark.internal.Logging
import org.apache.spark.ml.classification.NaiveBayesModel
import org.apache.spark.ml.classification.spark320.{NaiveBayes => NaiveBayesSpark320}
import org.apache.spark.ml.classification.spark321.{NaiveBayes => NaiveBayesSpark321}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset
import org.apache.spark.{SPARK_VERSION, SparkException}
Expand All @@ -31,8 +33,9 @@ trait NaiveBayesShim extends Logging {
object NaiveBayesShim extends Logging {
def create(uid: String): NaiveBayesShim = {
logInfo(s"Loading NaiveBayes for Spark $SPARK_VERSION")
val shim = SPARK_VERSION match {
case "3.1.1" | "3.1.2" | "3.2.0" => new NaiveBayesSpark320(uid)

val shim = Utils.getSparkVersion() match {
case "3.1.1" | "3.1.2" | "3.1.3" | "3.2.0" | "3.2.1" => new NaiveBayesSpark321(uid)
case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
}
shim
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package com.intel.oap.mllib.clustering

import com.intel.oap.mllib.Utils

import org.apache.spark.internal.Logging
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.clustering.spark320.{KMeans => KMeansSpark320}
import org.apache.spark.ml.clustering.spark321.{KMeans => KMeansSpark321}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset
import org.apache.spark.{SPARK_VERSION, SparkException}
Expand All @@ -31,8 +33,8 @@ trait KMeansShim extends Logging {
object KMeansShim extends Logging {
def create(uid: String): KMeansShim = {
logInfo(s"Loading KMeans for Spark $SPARK_VERSION")
val kmeans = SPARK_VERSION match {
case "3.1.1" | "3.1.2" | "3.2.0" => new KMeansSpark320(uid)
val kmeans = Utils.getSparkVersion() match {
case "3.1.1" | "3.1.2" | "3.1.3" | "3.2.0" | "3.2.1" => new KMeansSpark321(uid)
case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
}
kmeans
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package com.intel.oap.mllib.feature

import com.intel.oap.mllib.Utils

import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.PCAModel
import org.apache.spark.ml.feature.spark320.{PCA => PCASpark320}
import org.apache.spark.ml.feature.spark321.{PCA => PCASpark321}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset
import org.apache.spark.{SPARK_VERSION, SparkException}
Expand All @@ -31,8 +33,8 @@ trait PCAShim extends Logging {
object PCAShim extends Logging {
def create(uid: String): PCAShim = {
logInfo(s"Loading PCA for Spark $SPARK_VERSION")
val pca = SPARK_VERSION match {
case "3.1.1" | "3.1.2" | "3.2.0" => new PCASpark320(uid)
val pca = Utils.getSparkVersion() match {
case "3.1.1" | "3.1.2" | "3.1.3" | "3.2.0" | "3.2.1" => new PCASpark321(uid)
case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
}
pca
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

package com.intel.oap.mllib.recommendation

import com.intel.oap.mllib.Utils

import org.apache.spark.internal.Logging
import org.apache.spark.ml.recommendation.ALS.Rating
import org.apache.spark.ml.recommendation.spark312.{ALS => ALSSpark312}
import org.apache.spark.ml.recommendation.spark320.{ALS => ALSSpark320}
import org.apache.spark.ml.recommendation.spark313.{ALS => ALSSpark313}
import org.apache.spark.ml.recommendation.spark321.{ALS => ALSSpark321}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.{SPARK_VERSION, SparkException}
Expand All @@ -46,9 +48,9 @@ trait ALSShim extends Serializable with Logging {
object ALSShim extends Logging {
def create(): ALSShim = {
logInfo(s"Loading ALS for Spark $SPARK_VERSION")
val als = SPARK_VERSION match {
case "3.1.1" | "3.1.2" => new ALSSpark312()
case "3.2.0" => new ALSSpark320()
val als = Utils.getSparkVersion() match {
case "3.1.1" | "3.1.2" | "3.1.3" => new ALSSpark313()
case "3.2.0" | "3.2.1" => new ALSSpark321()
case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
}
als
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

package com.intel.oap.mllib.regression

import com.intel.oap.mllib.Utils

import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.LinearRegressionModel
import org.apache.spark.ml.regression.spark312.{LinearRegression => LinearRegressionSpark312}
import org.apache.spark.ml.regression.spark320.{LinearRegression => LinearRegressionSpark320}
import org.apache.spark.ml.regression.spark313.{LinearRegression => LinearRegressionSpark313}
import org.apache.spark.ml.regression.spark321.{LinearRegression => LinearRegressionSpark321}
import org.apache.spark.sql.Dataset
import org.apache.spark.{SPARK_VERSION, SparkException}

Expand All @@ -32,9 +34,9 @@ trait LinearRegressionShim extends Serializable with Logging {
object LinearRegressionShim extends Logging {
def create(uid: String): LinearRegressionShim = {
logInfo(s"Loading ALS for Spark $SPARK_VERSION")
val linearRegression = SPARK_VERSION match {
case "3.1.1" | "3.1.2" => new LinearRegressionSpark312(uid)
case "3.2.0" => new LinearRegressionSpark320(uid)
val linearRegression = Utils.getSparkVersion() match {
case "3.1.1" | "3.1.2" | "3.1.3" => new LinearRegressionSpark313(uid)
case "3.2.0" | "3.2.1" => new LinearRegressionSpark321(uid)
case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
}
linearRegression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.intel.oap.mllib.stat

import com.intel.oap.mllib.Utils
import org.apache.spark.{SPARK_VERSION, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.recommendation.ALS.Rating
Expand All @@ -24,8 +25,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.storage.StorageLevel

import scala.reflect.ClassTag

import org.apache.spark.ml.stat.spark320.{Correlation => CorrelationSpark320 }
import org.apache.spark.ml.stat.spark321.{Correlation => CorrelationSpark321}

trait CorrelationShim extends Serializable with Logging {
def corr(dataset: Dataset[_], column: String, method: String): DataFrame
Expand All @@ -34,8 +34,8 @@ trait CorrelationShim extends Serializable with Logging {
object CorrelationShim extends Logging {
def create(): CorrelationShim = {
logInfo(s"Loading Correlation for Spark $SPARK_VERSION")
val als = SPARK_VERSION match {
case "3.1.1" | "3.1.2" | "3.2.0" => new CorrelationSpark320()
val als = Utils.getSparkVersion() match {
case "3.1.1" | "3.1.2" | "3.1.3" | "3.2.0" | "3.2.1" => new CorrelationSpark321()
case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
}
als
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

package com.intel.oap.mllib.stat

import com.intel.oap.mllib.Utils

import org.apache.spark.{SPARK_VERSION, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}

import org.apache.spark.mllib.stat.spark320.{Statistics => SummarizerSpark320 }
import org.apache.spark.mllib.stat.spark321.{Statistics => SummarizerSpark321}

trait SummarizerShim extends Serializable with Logging {
def colStats(X: RDD[Vector]): MultivariateStatisticalSummary
Expand All @@ -33,8 +34,8 @@ trait SummarizerShim extends Serializable with Logging {
object SummarizerShim extends Logging {
def create(): SummarizerShim = {
logInfo(s"Loading Summarizer for Spark $SPARK_VERSION")
val summarizer = SPARK_VERSION match {
case "3.1.1" | "3.1.2" | "3.2.0" => new SummarizerSpark320()
val summarizer = Utils.getSparkVersion() match {
case "3.1.1" | "3.1.2" | "3.1.3" | "3.2.0" | "3.2.1" => new SummarizerSpark321()
case _ => throw new SparkException(s"Unsupported Spark version $SPARK_VERSION")
}
summarizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
// scalastyle:on

package org.apache.spark.ml.classification.spark320
package org.apache.spark.ml.classification.spark321

import com.intel.oap.mllib.Utils
import com.intel.oap.mllib.classification.{NaiveBayesDALImpl, NaiveBayesShim}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
// scalastyle:on

package org.apache.spark.ml.clustering.spark320
package org.apache.spark.ml.clustering.spark321

import com.intel.oap.mllib.Utils
import com.intel.oap.mllib.clustering.{KMeansDALImpl, KMeansShim}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
// scalastyle:on

package org.apache.spark.ml.feature.spark320
package org.apache.spark.ml.feature.spark321

import com.intel.oap.mllib.Utils
import com.intel.oap.mllib.feature.{PCADALImpl, PCAShim}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
// scalastyle:on

package org.apache.spark.ml.recommendation.spark312
package org.apache.spark.ml.recommendation.spark313

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.intel.oap.mllib.{Utils => DALUtils}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
// scalastyle:on

package org.apache.spark.ml.recommendation.spark320
package org.apache.spark.ml.recommendation.spark321

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.intel.oap.mllib.{Utils => DALUtils}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
// scalastyle:on

package org.apache.spark.ml.regression.spark312
package org.apache.spark.ml.regression.spark313

import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
// scalastyle:on

package org.apache.spark.ml.regression.spark320
package org.apache.spark.ml.regression.spark321

import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
// scalastyle:on

package org.apache.spark.ml.stat.spark320
package org.apache.spark.ml.stat.spark321

import com.intel.oap.mllib.Utils
import com.intel.oap.mllib.stat.{CorrelationDALImpl, CorrelationShim}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package org.apache.spark.mllib.stat.spark320
package org.apache.spark.mllib.stat.spark321

import com.intel.oap.mllib.Utils
import com.intel.oap.mllib.stat.{SummarizerDALImpl, SummarizerShim}
Expand Down

0 comments on commit a249d28

Please sign in to comment.