
Add assert to testcase on cluster sizes
sboeschhuawei committed Jan 28, 2015
1 parent 24f438e · commit 88aacc8
Showing 1 changed file with 8 additions and 24 deletions.
@@ -17,6 +17,7 @@

package org.apache.spark.mllib.clustering

import org.apache.log4j.Logger
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._
import org.apache.spark.mllib.clustering.PICLinalg.DMatrix
@@ -27,6 +28,8 @@ import scala.util.Random

class PIClusteringSuite extends FunSuite with LocalSparkContext {

val logger = Logger.getLogger(getClass.getName)

import org.apache.spark.mllib.clustering.PIClusteringSuite._

val PIC = PIClustering
@@ -38,6 +41,7 @@ class PIClusteringSuite extends FunSuite with LocalSparkContext {
concentricCirclesTest()
}


def concentricCirclesTest() = {
val sigma = 1.0
val nIterations = 10
@@ -63,33 +67,13 @@ class PIClusteringSuite extends FunSuite with LocalSparkContext {
val (ccenters, estCollected) = PIC.run(sc, vertices, nClusters, nIterations)
println(s"Cluster centers: ${ccenters.mkString(",")} " +
s"\nEstimates: ${estCollected.mkString("[", ",", "]")}")
assert(ccenters.size == circleSpecs.length,"Did not get correct number of centers")
val clustGroupsList = estCollected.groupBy{ case ((vid, eigenV), clustNum) =>
clustNum
}.mapValues{
_.map{ case ((vid, eigenV), clustNum) =>
(vid, clustNum)
}}.toList.sortBy(_._1)


val ccentersOrdered = ccenters.sortBy(-1.0 * _._2(0))

// val joinedGroups = ccentersOrdered.(clustGroupsList.toMap)
//
// val clustValids = clustGroupsList.map{ case (clustNum, vidEigensList) =>
// (clustNum, vidEigensList.size, vidEigensList.map{ (_._1 / 1000).toLong }}
// assert(clustGroups.map{_._2.size} == circleSpecs.map{ p => p.nPoints },
// "Incorrect match on clusterGroupsSize")
// val matchedCentersAndPoints = ccentersOrdered.map{ case (groupId, loc) => groupId}.zip(clustGroups)
// assert(matchedCentersAndPoints.map{_._2.size} == circleSpecs.map{ p => p.nPoints },
// "Incorrect match on clusterGroupsSize
//
// assert(estCollected == circleSpecs.length,"Did not get correct number of centers")
assert(ccenters.size == circleSpecs.length, "Did not get correct number of centers")

}
}

def join[T <: Comparable[T]](a: Map[T,_], b: Map[T,_]) = {
(a.toSeq++b.toSeq).groupBy(_._1).mapValues(_.map(_._2).toList)
def join[T <: Comparable[T]](a: Map[T, _], b: Map[T, _]) = {
(a.toSeq ++ b.toSeq).groupBy(_._1).mapValues(_.map(_._2).toList)
}

ignore("irisData") {

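The commented-out checks removed in this commit were reaching for a stronger assertion: that each recovered cluster contains the expected number of points. A minimal sketch of that idea, not part of the commit, assuming estCollected yields ((vid, eigenV), clustNum) pairs and each circleSpecs entry carries an nPoints count as in the test above:

// Hypothetical sketch, not part of this commit: a per-cluster size check along
// the lines the removed commented-out code was attempting. Assumes estCollected
// holds ((vid, eigenV), clustNum) pairs and each element of circleSpecs exposes
// an nPoints count, as in the test body above.
val sizesByCluster = estCollected
  .groupBy { case ((vid, eigenV), clustNum) => clustNum } // points per assigned cluster
  .mapValues(_.size)                                       // count points in each cluster
  .values.toSeq.sorted

val expectedSizes = circleSpecs.map(_.nPoints).toSeq.sorted

assert(sizesByCluster == expectedSizes,
  s"Per-cluster sizes $sizesByCluster do not match expected $expectedSizes")

Sorting both sides before comparing sidesteps the fact that cluster numbering is arbitrary.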