Skip to content

Commit

Permalink
Forgot some lightGBMlibConstants
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jan 31, 2020
1 parent 37da33b commit 13059a7
Showing 1 changed file with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ import com.microsoft.ml.spark.lightgbm.LightGBMUtils.getBoosterPtrFromModelStrin
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.sql.{SaveMode, SparkSession}

class BoosterHandler(model: String) extends Serializable {
protected class BoosterHandler(model: String) extends Serializable {

LightGBMUtils.initializeNativeLibrary()

@transient
lazy val boosterPtr: SWIGTYPE_p_void = {
getBoosterPtrFromModelString(model)
}

lazy val rawScoreConstant = lightgbmlibConstants.C_API_PREDICT_RAW_SCORE
lazy val normalScoreConstant = lightgbmlibConstants.C_API_PREDICT_RAW_SCORE
lazy val leafIndexPredictConstant = lightgbmlibConstants.C_API_PREDICT_LEAF_INDEX

lazy val dataInt32bitType = lightgbmlibConstants.C_API_DTYPE_INT32
lazy val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64
Expand Down Expand Up @@ -113,7 +115,7 @@ class LightGBMBooster(val model: String) extends Serializable {
}

protected def predictScoreForMat(row: Array[Double], kind: Int, classification: Boolean): Array[Double] = {
val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64
val data64bitType = boosterHandler.data64bitType

val numCols = row.length
val isRowMajor = 1
Expand All @@ -134,15 +136,15 @@ class LightGBMBooster(val model: String) extends Serializable {
val numCols = sparseVector.size

val datasetParams = "max_bin=255 is_pre_partition=True"
val dataInt32bitType = lightgbmlibConstants.C_API_DTYPE_INT32
val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64
val dataInt32bitType = boosterHandler.dataInt32bitType
val data64bitType = boosterHandler.data64bitType

LightGBMUtils.validate(
lightgbmlib.LGBM_BoosterPredictForCSRSingle(
sparseVector.indices, sparseVector.values,
sparseVector.numNonzeros,
boosterHandler.boosterPtr, dataInt32bitType, data64bitType, 2, numCols,
lightgbmlibConstants.C_API_PREDICT_LEAF_INDEX, -1, datasetParams,
boosterHandler.leafIndexPredictConstant, -1, datasetParams,
leafIndexDataLengthLongPtr, leafIndexDataOutPtr), "Booster Predict Leaf")

predLeafToArray(leafIndexDataOutPtr)
Expand All @@ -160,7 +162,7 @@ class LightGBMBooster(val model: String) extends Serializable {
lightgbmlib.LGBM_BoosterPredictForMatSingle(
row, boosterHandler.boosterPtr, data64bitType,
numCols,
isRowMajor, lightgbmlibConstants.C_API_PREDICT_LEAF_INDEX,
isRowMajor, boosterHandler.leafIndexPredictConstant,
-1, datasetParams, leafIndexDataLengthLongPtr, leafIndexDataOutPtr),
"Booster Predict Leaf")

Expand Down

0 comments on commit 13059a7

Please sign in to comment.