Skip to content

Commit

Permalink
Addressing reviewers comments @mengxr
Browse files Browse the repository at this point in the history
  • Loading branch information
avulanov committed Feb 2, 2015
1 parent a6ad82a commit 755d358
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,42 @@

package org.apache.spark.mllib.feature

import scala.collection.mutable.ArrayBuilder

import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD

import scala.collection.mutable.ArrayBuilder

/**
* :: Experimental ::
* Chi Squared selector model.
*
* @param indices list of indices to select (filter). Must be ordered asc
* @param selectedFeatures list of indices to select (filter). Must be ordered asc
*/
@Experimental
class ChiSqSelectorModel private[mllib] (indices: Array[Int]) extends VectorTransformer {
class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransformer {

require(isSorted(selectedFeatures), "Array has to be sorted asc")

protected def isSorted(array: Array[Int]): Boolean = {
var i = 1
while (i < array.length) {
if (array(i) < array(i-1)) return false
i += 1
}
true
}

/**
* Applies transformation on a vector.
*
* @param vector vector to be transformed.
* @return transformed vector.
*/
override def transform(vector: Vector): Vector = {
compress(vector, indices)
compress(vector, selectedFeatures)
}

/**
Expand All @@ -56,23 +68,27 @@ class ChiSqSelectorModel private[mllib] (indices: Array[Int]) extends VectorTra
val newSize = filterIndices.length
val newValues = new ArrayBuilder.ofDouble
val newIndices = new ArrayBuilder.ofInt
var i: Int = 0
var j: Int = 0
while(i < indices.length && j < filterIndices.length) {
if(indices(i) == filterIndices(j)) {
var i = 0
var j = 0
var indicesIdx = 0
var filterIndicesIdx = 0
while (i < indices.length && j < filterIndices.length) {
indicesIdx = indices(i)
filterIndicesIdx = filterIndices(j)
if (indicesIdx == filterIndicesIdx) {
newIndices += j
newValues += values(i)
j += 1
i += 1
} else {
if(indices(i) > filterIndices(j)) {
if (indicesIdx > filterIndicesIdx) {
j += 1
} else {
i += 1
}
}
}
/** Sparse representation might be ineffective if (newSize ~= newValues.size) */
// TODO: Sparse representation might be ineffective if (newSize ~= newValues.size)
Vectors.sparse(newSize, newIndices.result(), newValues.result())
case DenseVector(values) =>
val values = features.toArray
Expand All @@ -96,13 +112,15 @@ class ChiSqSelector (val numTopFeatures: Int) {
/**
* Returns a ChiSquared feature selector.
*
* @param data data used to compute the Chi Squared statistic.
* @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features.
* Real-valued features will be treated as categorical for each distinct value.
* Apply feature discretizer before using this function.
*/
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
val indices = Statistics.chiSqTest(data)
.zipWithIndex.sortBy { case(res, _) => -res.statistic }
.zipWithIndex.sortBy { case (res, _) => -res.statistic }
.take(numTopFeatures)
.map{ case(_, indices) => indices }
.map { case (_, indices) => indices }
.sorted
new ChiSqSelectorModel(indices)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,18 @@ class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext {

test("ChiSqSelector transform test (sparse & dense vector)") {
val labeledDiscreteData = sc.parallelize(
Seq(new LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
new LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
new LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
new LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))
), 2)
Seq(LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2)
val preFilteredData =
Set(new LabeledPoint(0.0, Vectors.dense(Array(0.0))),
new LabeledPoint(1.0, Vectors.dense(Array(6.0))),
new LabeledPoint(1.0, Vectors.dense(Array(8.0))),
new LabeledPoint(2.0, Vectors.dense(Array(5.0)))
)
Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
LabeledPoint(1.0, Vectors.dense(Array(6.0))),
LabeledPoint(1.0, Vectors.dense(Array(8.0))),
LabeledPoint(2.0, Vectors.dense(Array(5.0))))
val model = new ChiSqSelector(1).fit(labeledDiscreteData)
val filteredData = labeledDiscreteData.map { lp =>
new LabeledPoint(lp.label, model.transform(lp.features))
LabeledPoint(lp.label, model.transform(lp.features))
}.collect().toSet
assert(filteredData == preFilteredData)
}
Expand Down

0 comments on commit 755d358

Please sign in to comment.