Skip to content

Commit

Permalink
Code reorg and cleanup for SparkR linear SVM.
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed May 19, 2017
1 parent 5d2750a commit 1ed3ba0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 26 deletions.
40 changes: 16 additions & 24 deletions R/pkg/R/mllib_classification.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,16 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"
#' @note NaiveBayesModel since 2.0.0
setClass("NaiveBayesModel", representation(jobj = "jobj"))

#' linear SVM Model
#' Linear SVM Model
#'
#' Fits an linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in glmnet package
#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package.
#' Currently only supports binary classification model with linear kernel.
#' Users can print, make predictions on the produced model and save the model to the input path.
#'
#' @param data SparkDataFrame for training.
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param regParam The regularization parameter.
#' @param regParam The regularization parameter. Only supports L2 regularization currently.
#' @param maxIter Maximum iteration number.
#' @param tol Convergence tolerance of iterations.
#' @param standardization Whether to standardize the training features before fitting the model. The coefficients
Expand Down Expand Up @@ -111,10 +112,10 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
new("LinearSVCModel", jobj = jobj)
})

# Predicted values based on an LinearSVCModel model
# Predicted values based on a linear SVM model.

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns the predicted values based on an LinearSVCModel.
#' @return \code{predict} returns the predicted values based on a linear SVM model.
#' @rdname spark.svmLinear
#' @aliases predict,LinearSVCModel,SparkDataFrame-method
#' @export
Expand All @@ -124,39 +125,30 @@ setMethod("predict", signature(object = "LinearSVCModel"),
predict_internal(object, newData)
})

# Get the summary of an LinearSVCModel
# Get the summary of a linear SVM model.

#' @param object an LinearSVCModel fitted by \code{spark.svmLinear}.
#' @param object a linear SVM model fitted by \code{spark.svmLinear}.
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#' The list includes \code{coefficients} (coefficients of the fitted model),
#' \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes),
#' \code{numFeatures} (number of features).
#' \code{numClasses} (number of classes), \code{numFeatures} (number of features).
#' @rdname spark.svmLinear
#' @aliases summary,LinearSVCModel-method
#' @export
#' @note summary(LinearSVCModel) since 2.2.0
setMethod("summary", signature(object = "LinearSVCModel"),
function(object) {
jobj <- object@jobj
features <- callJMethod(jobj, "features")
labels <- callJMethod(jobj, "labels")
coefficients <- callJMethod(jobj, "coefficients")
nCol <- length(coefficients) / length(features)
coefficients <- matrix(unlist(coefficients), ncol = nCol)
intercept <- callJMethod(jobj, "intercept")
features <- callJMethod(jobj, "rFeatures")
coefficients <- callJMethod(jobj, "rCoefficients")
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
numClasses <- callJMethod(jobj, "numClasses")
numFeatures <- callJMethod(jobj, "numFeatures")
if (nCol == 1) {
colnames(coefficients) <- c("Estimate")
} else {
colnames(coefficients) <- unlist(labels)
}
rownames(coefficients) <- unlist(features)
list(coefficients = coefficients, intercept = intercept,
numClasses = numClasses, numFeatures = numFeatures)
list(coefficients = coefficients, numClasses = numClasses, numFeatures = numFeatures)
})

# Save fitted LinearSVCModel to the input path
# Save fitted linear SVM model to the input path.

#' @param path The directory where the model is saved.
#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE
Expand Down
12 changes: 10 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,17 @@ private[r] class LinearSVCWrapper private (
private val svcModel: LinearSVCModel =
pipeline.stages(1).asInstanceOf[LinearSVCModel]

lazy val coefficients: Array[Double] = svcModel.coefficients.toArray
// Feature names in the form handed to the R side: when the underlying model was
// fitted with an intercept, the label "(Intercept)" is placed first, followed by
// `features`; otherwise `features` is returned as-is.
lazy val rFeatures: Array[String] = if (svcModel.getFitIntercept) {
Array("(Intercept)") ++ features
} else {
features
}

lazy val intercept: Double = svcModel.intercept
// Coefficient values in the form handed to the R side: when the underlying model
// was fitted with an intercept, the intercept value is placed first, followed by
// the model coefficients; otherwise only the coefficients are returned. The
// intercept-first ordering mirrors the "(Intercept)"-first ordering of rFeatures.
lazy val rCoefficients: Array[Double] = if (svcModel.getFitIntercept) {
Array(svcModel.intercept) ++ svcModel.coefficients.toArray
} else {
svcModel.coefficients.toArray
}

lazy val numClasses: Int = svcModel.numClasses

Expand Down

0 comments on commit 1ed3ba0

Please sign in to comment.