Skip to content

Commit

Permalink
rename LabelParser.apply to LabelParser.parse
Browse files Browse the repository at this point in the history
use extends for singleton
  • Loading branch information
mengxr committed Apr 7, 2014
1 parent 11c94e0 commit c2e571c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.mllib.util
/** Trait for label parsers. */
trait LabelParser extends Serializable {
/** Parses a string label into a double label. */
def apply(labelString: String): Double
def parse(labelString: String): Double
}

/**
Expand All @@ -32,24 +32,22 @@ class BinaryLabelParser extends LabelParser {
* Parses the input label into positive (1.0) if the value is greater than 0.5,
* or negative (0.0) otherwise.
*/
override def apply(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0
override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0
}

object BinaryLabelParser {
private lazy val instance = new BinaryLabelParser()
object BinaryLabelParser extends BinaryLabelParser {
/** Gets the default instance of BinaryLabelParser. */
def apply(): BinaryLabelParser = instance
def getInstance(): BinaryLabelParser = this
}

/**
* Label parser for multiclass labels, which converts the input label to double.
*/
class MulticlassLabelParser extends LabelParser {
override def apply(labelString: String): Double = labelString.toDouble
override def parse(labelString: String): Double = labelString.toDouble
}

object MulticlassLabelParser {
private lazy val instance = new MulticlassLabelParser()
object MulticlassLabelParser extends MulticlassLabelParser {
/** Gets the default instance of MulticlassLabelParser. */
def apply(): MulticlassLabelParser = instance
def getInstance(): MulticlassLabelParser = this
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ object MLUtils {
}.reduce(math.max)
}
parsed.map { items =>
val label = labelParser(items.head)
val label = labelParser.parse(items.head)
val (indices, values) = items.tail.map { item =>
val indexAndValue = item.split(':')
val index = indexAndValue(0).toInt - 1
Expand All @@ -96,7 +96,7 @@ object MLUtils {
* with number of features determined automatically and the default number of partitions.
*/
def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] =
loadLibSVMData(sc, path, BinaryLabelParser(), -1, sc.defaultMinSplits)
loadLibSVMData(sc, path, BinaryLabelParser, -1, sc.defaultMinSplits)

/**
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString

val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser(), 6).collect()
val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser, 6).collect()
val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()

for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
Expand All @@ -93,7 +93,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
}

val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser()).collect()
val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser).collect()
assert(multiclassPoints.length === 3)
assert(multiclassPoints(0).label === 1.0)
assert(multiclassPoints(1).label === -1.0)
Expand Down

0 comments on commit c2e571c

Please sign in to comment.