Skip to content

Commit

Permalink
use namedtuple
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed Apr 22, 2015
1 parent 5532e78 commit da8c404
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions python/pyspark/mllib/fpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy
from numpy import array
from collections import namedtuple

from pyspark import SparkContext
from pyspark.rdd import ignore_unicode_prefix
Expand All @@ -38,15 +39,15 @@ class FPGrowthModel(JavaModelWrapper):
>>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
>>> rdd = sc.parallelize(data, 2)
>>> model = FPGrowth.train(rdd, 0.6, 2)
>>> model.freqItemsets().collect()
[(array([u'a'], ...), 4), (array([u'c'], ...), 3), (array([u'c', u'a'], ...), 3)]
>>> sorted(model.freqItemsets().collect())
[FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
"""

def freqItemsets(self):
"""
Get the frequent itemsets of this model
Returns the frequent itemsets of this model.
"""
return self.call("getFreqItemsets").map(lambda x: (numpy.array(x[0]), x[1]))
return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1])))


class FPGrowth(object):
Expand All @@ -70,6 +71,11 @@ def train(cls, data, minSupport=0.3, numPartitions=-1):
model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions))
return FPGrowthModel(model)

class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])):
"""
Represents an (items, freq) tuple.
"""


def _test():
import doctest
Expand Down

0 comments on commit da8c404

Please sign in to comment.