diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 4628cef6e255c..0384c39b5d5ab 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -26,6 +26,22 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
                          HasRegParam):
     """
     Logistic regression.
+
+    >>> from pyspark.sql import Row
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \
+        Row(label=1.0, features=Vectors.dense(1.0)), \
+        Row(label=0.0, features=Vectors.sparse(1, [], []))]))
+    >>> lr = LogisticRegression() \
+        .setMaxIter(5) \
+        .setRegParam(0.01)
+    >>> model = lr.fit(dataset)
+    >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
+    >>> print model.transform(test0).first().prediction
+    0.0
+    >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
+    >>> print model.transform(test1).first().prediction
+    1.0
     """

     def __init__(self):
@@ -52,3 +68,21 @@ def __init__(self, java_model):
     @property
     def _java_class(self):
         return "org.apache.spark.ml.classification.LogisticRegressionModel"
+
+
+if __name__ == "__main__":
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql import SQLContext
+    globs = globals().copy()
+    # The small batch size here ensures that we see multiple batches,
+    # even in these small test examples:
+    sc = SparkContext("local[2]", "ml.feature tests")
+    sqlCtx = SQLContext(sc)
+    globs['sc'] = sc
+    globs['sqlCtx'] = sqlCtx
+    (failure_count, test_count) = doctest.testmod(
+        globs=globs, optionflags=doctest.ELLIPSIS)
+    sc.stop()
+    if failure_count:
+        exit(-1)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index dee574774e83e..45191c8fcb3dc 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -73,7 +73,7 @@ def _java_class(self):
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext
-    from pyspark.sql import Row, SQLContext
+    from pyspark.sql import SQLContext
     globs = globals().copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
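
For readers who want to exercise the new docstring example outside of `doctest.testmod`, below is a minimal standalone sketch assembled from the snippets in the patch. It assumes the same Spark-1.x-era API surface the patch targets (Python 2 `print`, `SQLContext.inferSchema`, and the `pyspark.ml` pipeline classes); the application name is illustrative, not part of the patch.

```python
from pyspark.context import SparkContext
from pyspark.sql import SQLContext, Row
from pyspark.mllib.linalg import Vectors
from pyspark.ml.classification import LogisticRegression

# Local two-core context, mirroring the test harness added in the patch.
# The app name here is a placeholder, not taken from the diff.
sc = SparkContext("local[2]", "lr doctest sketch")
sqlCtx = SQLContext(sc)

# Two-row training set: one dense positive example and one sparse
# negative example, exactly as in the new doctest.
training = sqlCtx.inferSchema(sc.parallelize([
    Row(label=1.0, features=Vectors.dense(1.0)),
    Row(label=0.0, features=Vectors.sparse(1, [], []))]))

lr = LogisticRegression().setMaxIter(5).setRegParam(0.01)
model = lr.fit(training)

# A negative feature value should map to class 0.0, a positive one to 1.0.
test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
print model.transform(test0).first().prediction  # expected: 0.0
print model.transform(test1).first().prediction  # expected: 1.0

sc.stop()
```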