diff --git a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBooster.scala b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBooster.scala index e930cb1bec..7f626bcf63 100644 --- a/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBooster.scala +++ b/src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBooster.scala @@ -8,16 +8,18 @@ import com.microsoft.ml.spark.lightgbm.LightGBMUtils.getBoosterPtrFromModelStrin import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.sql.{SaveMode, SparkSession} -class BoosterHandler(model: String) extends Serializable { +protected class BoosterHandler(model: String) extends Serializable { LightGBMUtils.initializeNativeLibrary() + @transient lazy val boosterPtr: SWIGTYPE_p_void = { getBoosterPtrFromModelString(model) } lazy val rawScoreConstant = lightgbmlibConstants.C_API_PREDICT_RAW_SCORE lazy val normalScoreConstant = lightgbmlibConstants.C_API_PREDICT_RAW_SCORE + lazy val leafIndexPredictConstant = lightgbmlibConstants.C_API_PREDICT_LEAF_INDEX lazy val dataInt32bitType = lightgbmlibConstants.C_API_DTYPE_INT32 lazy val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64 @@ -113,7 +115,7 @@ class LightGBMBooster(val model: String) extends Serializable { } protected def predictScoreForMat(row: Array[Double], kind: Int, classification: Boolean): Array[Double] = { - val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64 + val data64bitType = boosterHandler.data64bitType val numCols = row.length val isRowMajor = 1 @@ -134,15 +136,15 @@ class LightGBMBooster(val model: String) extends Serializable { val numCols = sparseVector.size val datasetParams = "max_bin=255 is_pre_partition=True" - val dataInt32bitType = lightgbmlibConstants.C_API_DTYPE_INT32 - val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64 + val dataInt32bitType = boosterHandler.dataInt32bitType + val data64bitType = boosterHandler.data64bitType LightGBMUtils.validate( lightgbmlib.LGBM_BoosterPredictForCSRSingle( sparseVector.indices, sparseVector.values, sparseVector.numNonzeros, boosterHandler.boosterPtr, dataInt32bitType, data64bitType, 2, numCols, - lightgbmlibConstants.C_API_PREDICT_LEAF_INDEX, -1, datasetParams, + boosterHandler.leafIndexPredictConstant, -1, datasetParams, leafIndexDataLengthLongPtr, leafIndexDataOutPtr), "Booster Predict Leaf") predLeafToArray(leafIndexDataOutPtr) @@ -160,7 +162,7 @@ class LightGBMBooster(val model: String) extends Serializable { lightgbmlib.LGBM_BoosterPredictForMatSingle( row, boosterHandler.boosterPtr, data64bitType, numCols, - isRowMajor, lightgbmlibConstants.C_API_PREDICT_LEAF_INDEX, + isRowMajor, boosterHandler.leafIndexPredictConstant, -1, datasetParams, leafIndexDataLengthLongPtr, leafIndexDataOutPtr), "Booster Predict Leaf")