From 4dd1c8a2393b91dc1841c3b01dad7163371dd434 Mon Sep 17 00:00:00 2001 From: zhangjiajin Date: Wed, 15 Jul 2015 10:57:41 +0800 Subject: [PATCH] initialize file before rebase. --- .../apache/spark/mllib/fpm/PrefixSpan.scala | 75 +++---------------- 1 file changed, 10 insertions(+), 65 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 33e381e6d4d66..9d8c60ef0fc45 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -43,8 +43,6 @@ class PrefixSpan private ( private var minSupport: Double, private var maxPatternLength: Int) extends Logging with Serializable { - private val minPatternsBeforeShuffle: Int = 20 - /** * Constructs a default instance with default parameters * {minSupport: `0.1`, maxPatternLength: `10`}. @@ -88,69 +86,16 @@ class PrefixSpan private ( getFreqItemAndCounts(minCount, sequences).collect() val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase( lengthOnePatternsAndCounts.map(_._1), sequences) - - var patternsCount = lengthOnePatternsAndCounts.length - var allPatternAndCounts = sequences.sparkContext.parallelize( - lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))) - var currentProjectedDatabase = prefixAndProjectedDatabase - while (patternsCount <= minPatternsBeforeShuffle && - currentProjectedDatabase.count() != 0) { - val (nextPatternAndCounts, nextProjectedDatabase) = - getPatternCountsAndProjectedDatabase(minCount, currentProjectedDatabase) - patternsCount = nextPatternAndCounts.count().toInt - currentProjectedDatabase = nextProjectedDatabase - allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts - } - if (patternsCount > 0) { - val groupedProjectedDatabase = currentProjectedDatabase - .map(x => (x._1.toSeq, x._2)) - .groupByKey() - .map(x => (x._1.toArray, x._2.toArray)) - val nextPatternAndCounts = getPatternsInLocal(minCount, groupedProjectedDatabase) - allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts - } - allPatternAndCounts - } - - /** - * Get the pattern and counts, and projected database - * @param minCount minimum count - * @param prefixAndProjectedDatabase prefix and projected database, - * @return pattern and counts, and projected database - * (Array[pattern, count], RDD[prefix, projected database ]) - */ - private def getPatternCountsAndProjectedDatabase( - minCount: Long, - prefixAndProjectedDatabase: RDD[(Array[Int], Array[Int])]): - (RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = { - val prefixAndFreqentItemAndCounts = prefixAndProjectedDatabase.flatMap{ x => - x._2.distinct.map(y => ((x._1.toSeq, y), 1L)) - }.reduceByKey(_ + _) - .filter(_._2 >= minCount) - val patternAndCounts = prefixAndFreqentItemAndCounts - .map(x => (x._1._1.toArray ++ Array(x._1._2), x._2)) - val prefixlength = prefixAndProjectedDatabase.take(1)(0)._1.length - if (prefixlength + 1 >= maxPatternLength) { - (patternAndCounts, prefixAndProjectedDatabase.filter(x => false)) - } else { - val frequentItemsMap = prefixAndFreqentItemAndCounts - .keys.map(x => (x._1, x._2)) - .groupByKey() - .mapValues(_.toSet) - .collect - .toMap - val nextPrefixAndProjectedDatabase = prefixAndProjectedDatabase - .filter(x => frequentItemsMap.contains(x._1)) - .flatMap { x => - val frequentItemSet = frequentItemsMap(x._1) - val filteredSequence = x._2.filter(frequentItemSet.contains(_)) - val subProjectedDabase = frequentItemSet.map{ y => - (y, LocalPrefixSpan.getSuffix(y, filteredSequence)) - }.filter(_._2.nonEmpty) - subProjectedDabase.map(y => (x._1 ++ Array(y._1), y._2)) - } - (patternAndCounts, nextPrefixAndProjectedDatabase) - } + val groupedProjectedDatabase = prefixAndProjectedDatabase + .map(x => (x._1.toSeq, x._2)) + .groupByKey() + .map(x => (x._1.toArray, x._2.toArray)) + val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase) + val lengthOnePatternsAndCountsRdd = + sequences.sparkContext.parallelize( + lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))) + val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns + allPatterns } /**