diff --git a/docs/ml-features.md b/docs/ml-features.md
index 06f1ac196b39d..970033503820f 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -18,30 +18,57 @@ This section covers algorithms for working with features, roughly divided into t
# Feature Extractors
-## Hashing Term-Frequency (HashingTF)
+## TF-IDF (HashingTF and IDF)
-`HashingTF` is a `Transformer` which takes sets of terms (e.g., `String` terms can be sets of words) and converts those sets into fixed-length feature vectors.
-The algorithm combines [Term Frequency (TF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) counts with the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing) for dimensionality reduction. Please refer to the [MLlib user guide on TF-IDF](mllib-feature-extraction.html#tf-idf) for more details on Term-Frequency.
+[Term Frequency-Inverse Document Frequency (TF-IDF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) is a common text pre-processing step. In Spark ML, TF-IDF is separated into two parts: TF (+hashing) and IDF.
-HashingTF is implemented in
-[HashingTF](api/scala/index.html#org.apache.spark.ml.feature.HashingTF).
-In the following code segment, we start with a set of sentences. We split each sentence into words using `Tokenizer`. For each sentence (bag of words), we hash it into a feature vector. This feature vector could then be passed to a learning algorithm.
+**TF**: `HashingTF` is a `Transformer` which takes sets of terms and converts those sets into fixed-length feature vectors. In text processing, a "set of terms" might be a bag of words.
+The algorithm combines Term Frequency (TF) counts with the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing) for dimensionality reduction.
+
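+The hashing trick is what keeps the vectors fixed-length: each term is mapped straight to a column index by a hash function, so the dimension does not grow with the vocabulary. The following is a minimal illustrative sketch of that idea (not the actual `HashingTF` implementation):
+
+{% highlight scala %}
+// Illustrative only: hash each term into one of numFeatures buckets.
+val numFeatures = 20
+def indexOf(term: String): Int =
+  ((term.hashCode % numFeatures) + numFeatures) % numFeatures  // non-negative mod
+// Bucket counts play the role of term frequencies; distinct terms may collide.
+val termFrequencies = Seq("hi", "i", "heard", "about", "spark")
+  .groupBy(indexOf).mapValues(_.size)
+{% endhighlight %}
+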
+**IDF**: `IDF` is an `Estimator` which is fit on a dataset and produces an `IDFModel`. The `IDFModel` takes feature vectors (generally created by `HashingTF`) and scales each column. Intuitively, it down-weights columns corresponding to terms which appear frequently in the corpus.
+
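+The rescaling follows the inverse document frequency weighting described in the MLlib user guide linked below; with `$|D|$` the number of documents and `$DF(t, D)$` the number of documents containing term `$t$`, the smoothed weight is
+`\[
+IDF(t, D) = \log \frac{|D| + 1}{DF(t, D) + 1},
+\]`
+so a term appearing in every document receives weight 0.
+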
+Please refer to the [MLlib user guide on TF-IDF](mllib-feature-extraction.html#tf-idf) for more details on Term Frequency and Inverse Document Frequency.
+For API details, refer to the [HashingTF API docs](api/scala/index.html#org.apache.spark.ml.feature.HashingTF) and the [IDF API docs](api/scala/index.html#org.apache.spark.ml.feature.IDF).
+
+In the following code segment, we start with a set of sentences. We split each sentence into words using `Tokenizer`. For each sentence (bag of words), we use `HashingTF` to hash the sentence into a feature vector. We use `IDF` to rescale the feature vectors; this generally improves performance when using text as features. Our feature vectors could then be passed to a learning algorithm.
{% highlight scala %}
-import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
-val sentenceDataFrame = sqlContext.createDataFrame(Seq(
+val sentenceData = sqlContext.createDataFrame(Seq(
(0, "Hi I heard about Spark"),
(0, "I wish Java could use case classes"),
(1, "Logistic regression models are neat")
)).toDF("label", "sentence")
val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
-val wordsDataFrame = tokenizer.transform(sentenceDataFrame)
-val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features").setNumFeatures(20)
-val featurized = hashingTF.transform(wordsDataFrame)
-featurized.select("features", "label").take(3).foreach(println)
+val wordsData = tokenizer.transform(sentenceData)
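+// 20 features is tiny and chosen for readable output; larger dimensions reduce hash collisions.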
+val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(20)
+val featurizedData = hashingTF.transform(wordsData)
+val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
+val idfModel = idf.fit(featurizedData)
+val rescaledData = idfModel.transform(featurizedData)
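+// Each row now holds a sparse vector of length 20 with IDF-scaled term weights.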
+rescaledData.select("features", "label").take(3).foreach(println)
{% endhighlight %}
@@ -51,6 +59,8 @@ import com.google.common.collect.Lists;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.HashingTF;
+import org.apache.spark.ml.feature.IDF;
+import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
@@ -70,16 +79,19 @@ StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
-DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema);
+DataFrame sentenceData = sqlContext.createDataFrame(jrdd, schema);
Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
-DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame);
+DataFrame wordsData = tokenizer.transform(sentenceData);
int numFeatures = 20;
HashingTF hashingTF = new HashingTF()
.setInputCol("words")
- .setOutputCol("features")
+ .setOutputCol("rawFeatures")
.setNumFeatures(numFeatures);
-DataFrame featurized = hashingTF.transform(wordsDataFrame);
-for (Row r : featurized.select("features", "label").take(3)) {
+DataFrame featurizedData = hashingTF.transform(wordsData);
+IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
+IDFModel idfModel = idf.fit(featurizedData);
+DataFrame rescaledData = idfModel.transform(featurizedData);
+for (Row r : rescaledData.select("features", "label").take(3)) {
Vector features = r.getAs(0);
Double label = r.getDouble(1);
System.out.println(features);
@@ -89,18 +101,21 @@ for (Row r : featurized.select("features", "label").take(3)) {
{% highlight python %}
-from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.ml.feature import HashingTF, IDF, Tokenizer
-sentenceDataFrame = sqlContext.createDataFrame([
+sentenceData = sqlContext.createDataFrame([
(0, "Hi I heard about Spark"),
(0, "I wish Java could use case classes"),
(1, "Logistic regression models are neat")
], ["label", "sentence"])
tokenizer = Tokenizer(inputCol="sentence", outputCol="words")
-wordsDataFrame = tokenizer.transform(sentenceDataFrame)
-hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
-featurized = hashingTF.transform(wordsDataFrame)
-for features_label in featurized.select("features", "label").take(3):
+wordsData = tokenizer.transform(sentenceData)
+hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=20)
+featurizedData = hashingTF.transform(wordsData)
+idf = IDF(inputCol="rawFeatures", outputCol="features")
+idfModel = idf.fit(featurizedData)
+rescaledData = idfModel.transform(featurizedData)
+for features_label in rescaledData.select("features", "label").take(3):
print features_label
{% endhighlight %}
@@ -618,5 +633,6 @@ indexedData = indexerModel.transform(data)
+
# Feature Selectors
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 23463ab5fe848..da2218056307e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -63,17 +63,23 @@ public void hashingTF() {
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
- DataFrame sentenceDataFrame = jsql.createDataFrame(jrdd, schema);
- Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
- DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame);
+ DataFrame sentenceData = jsql.createDataFrame(jrdd, schema);
+ Tokenizer tokenizer = new Tokenizer()
+ .setInputCol("sentence")
+ .setOutputCol("words");
+ DataFrame wordsData = tokenizer.transform(sentenceData);
int numFeatures = 20;
HashingTF hashingTF = new HashingTF()
.setInputCol("words")
- .setOutputCol("features")
+ .setOutputCol("rawFeatures")
.setNumFeatures(numFeatures);
- DataFrame featurized = hashingTF.transform(wordsDataFrame);
- for (Row r : featurized.select("features", "words", "label").take(3)) {
+ DataFrame featurizedData = hashingTF.transform(wordsData);
+ IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
+ IDFModel idfModel = idf.fit(featurizedData);
+ DataFrame rescaledData = idfModel.transform(featurizedData);
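+    // IDF rescaling preserves dimensionality: each vector still has numFeatures entries.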
+ for (Row r : rescaledData.select("features", "label").take(3)) {
Vector features = r.getAs(0);
Assert.assertEquals(features.size(), numFeatures);
}