-
Notifications
You must be signed in to change notification settings - Fork 28.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-11445][DOCS] Replaced example code in mllib-ensembles.md using…
… include_example I have made the required changes and tested. Kindly review the changes. Author: Rishabh Bhardwaj <rbnext29@gmail.com> Closes #9407 from rishabhbhardwaj/SPARK-11445. (cherry picked from commit 61a2848) Signed-off-by: Xiangrui Meng <meng@databricks.com>
- Loading branch information
1 parent
9fa9ad0
commit 98c614d
Showing
13 changed files
with
885 additions
and
514 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
92 changes: 92 additions & 0 deletions
92
.../main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.examples.mllib; | ||
|
||
// $example on$ | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import scala.Tuple2; | ||
|
||
import org.apache.spark.SparkConf; | ||
import org.apache.spark.api.java.JavaPairRDD; | ||
import org.apache.spark.api.java.JavaRDD; | ||
import org.apache.spark.api.java.JavaSparkContext; | ||
import org.apache.spark.api.java.function.Function; | ||
import org.apache.spark.api.java.function.PairFunction; | ||
import org.apache.spark.mllib.regression.LabeledPoint; | ||
import org.apache.spark.mllib.tree.GradientBoostedTrees; | ||
import org.apache.spark.mllib.tree.configuration.BoostingStrategy; | ||
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; | ||
import org.apache.spark.mllib.util.MLUtils; | ||
// $example off$ | ||
|
||
public class JavaGradientBoostingClassificationExample { | ||
public static void main(String[] args) { | ||
// $example on$ | ||
SparkConf sparkConf = new SparkConf() | ||
.setAppName("JavaGradientBoostedTreesClassificationExample"); | ||
JavaSparkContext jsc = new JavaSparkContext(sparkConf); | ||
|
||
// Load and parse the data file. | ||
String datapath = "data/mllib/sample_libsvm_data.txt"; | ||
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); | ||
// Split the data into training and test sets (30% held out for testing) | ||
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3}); | ||
JavaRDD<LabeledPoint> trainingData = splits[0]; | ||
JavaRDD<LabeledPoint> testData = splits[1]; | ||
|
||
// Train a GradientBoostedTrees model. | ||
// The defaultParams for Classification use LogLoss by default. | ||
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification"); | ||
boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. | ||
boostingStrategy.getTreeStrategy().setNumClasses(2); | ||
boostingStrategy.getTreeStrategy().setMaxDepth(5); | ||
// Empty categoricalFeaturesInfo indicates all features are continuous. | ||
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); | ||
boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); | ||
|
||
final GradientBoostedTreesModel model = | ||
GradientBoostedTrees.train(trainingData, boostingStrategy); | ||
|
||
// Evaluate model on test instances and compute test error | ||
JavaPairRDD<Double, Double> predictionAndLabel = | ||
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { | ||
@Override | ||
public Tuple2<Double, Double> call(LabeledPoint p) { | ||
return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); | ||
} | ||
}); | ||
Double testErr = | ||
1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() { | ||
@Override | ||
public Boolean call(Tuple2<Double, Double> pl) { | ||
return !pl._1().equals(pl._2()); | ||
} | ||
}).count() / testData.count(); | ||
System.out.println("Test Error: " + testErr); | ||
System.out.println("Learned classification GBT model:\n" + model.toDebugString()); | ||
|
||
// Save and load model | ||
model.save(jsc.sc(), "target/tmp/myGradientBoostingClassificationModel"); | ||
GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), | ||
"target/tmp/myGradientBoostingClassificationModel"); | ||
// $example off$ | ||
} | ||
|
||
} |
96 changes: 96 additions & 0 deletions
96
.../src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.examples.mllib; | ||
|
||
// $example on$ | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import scala.Tuple2; | ||
|
||
import org.apache.spark.SparkConf; | ||
import org.apache.spark.api.java.function.Function2; | ||
import org.apache.spark.api.java.JavaPairRDD; | ||
import org.apache.spark.api.java.JavaRDD; | ||
import org.apache.spark.api.java.JavaSparkContext; | ||
import org.apache.spark.api.java.function.Function; | ||
import org.apache.spark.api.java.function.PairFunction; | ||
import org.apache.spark.mllib.regression.LabeledPoint; | ||
import org.apache.spark.mllib.tree.GradientBoostedTrees; | ||
import org.apache.spark.mllib.tree.configuration.BoostingStrategy; | ||
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; | ||
import org.apache.spark.mllib.util.MLUtils; | ||
// $example off$ | ||
|
||
public class JavaGradientBoostingRegressionExample { | ||
public static void main(String[] args) { | ||
// $example on$ | ||
SparkConf sparkConf = new SparkConf() | ||
.setAppName("JavaGradientBoostedTreesRegressionExample"); | ||
JavaSparkContext jsc = new JavaSparkContext(sparkConf); | ||
// Load and parse the data file. | ||
String datapath = "data/mllib/sample_libsvm_data.txt"; | ||
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); | ||
// Split the data into training and test sets (30% held out for testing) | ||
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3}); | ||
JavaRDD<LabeledPoint> trainingData = splits[0]; | ||
JavaRDD<LabeledPoint> testData = splits[1]; | ||
|
||
// Train a GradientBoostedTrees model. | ||
// The defaultParams for Regression use SquaredError by default. | ||
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); | ||
boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. | ||
boostingStrategy.getTreeStrategy().setMaxDepth(5); | ||
// Empty categoricalFeaturesInfo indicates all features are continuous. | ||
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); | ||
boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); | ||
|
||
final GradientBoostedTreesModel model = | ||
GradientBoostedTrees.train(trainingData, boostingStrategy); | ||
|
||
// Evaluate model on test instances and compute test error | ||
JavaPairRDD<Double, Double> predictionAndLabel = | ||
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { | ||
@Override | ||
public Tuple2<Double, Double> call(LabeledPoint p) { | ||
return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); | ||
} | ||
}); | ||
Double testMSE = | ||
predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() { | ||
@Override | ||
public Double call(Tuple2<Double, Double> pl) { | ||
Double diff = pl._1() - pl._2(); | ||
return diff * diff; | ||
} | ||
}).reduce(new Function2<Double, Double, Double>() { | ||
@Override | ||
public Double call(Double a, Double b) { | ||
return a + b; | ||
} | ||
}) / data.count(); | ||
System.out.println("Test Mean Squared Error: " + testMSE); | ||
System.out.println("Learned regression GBT model:\n" + model.toDebugString()); | ||
|
||
// Save and load model | ||
model.save(jsc.sc(), "target/tmp/myGradientBoostingRegressionModel"); | ||
GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), | ||
"target/tmp/myGradientBoostingRegressionModel"); | ||
// $example off$ | ||
} | ||
} |
89 changes: 89 additions & 0 deletions
89
.../src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.examples.mllib; | ||
|
||
// $example on$ | ||
import java.util.HashMap; | ||
|
||
import scala.Tuple2; | ||
|
||
import org.apache.spark.SparkConf; | ||
import org.apache.spark.api.java.JavaPairRDD; | ||
import org.apache.spark.api.java.JavaRDD; | ||
import org.apache.spark.api.java.JavaSparkContext; | ||
import org.apache.spark.api.java.function.Function; | ||
import org.apache.spark.api.java.function.PairFunction; | ||
import org.apache.spark.mllib.regression.LabeledPoint; | ||
import org.apache.spark.mllib.tree.RandomForest; | ||
import org.apache.spark.mllib.tree.model.RandomForestModel; | ||
import org.apache.spark.mllib.util.MLUtils; | ||
// $example off$ | ||
|
||
public class JavaRandomForestClassificationExample { | ||
public static void main(String[] args) { | ||
// $example on$ | ||
SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassificationExample"); | ||
JavaSparkContext jsc = new JavaSparkContext(sparkConf); | ||
// Load and parse the data file. | ||
String datapath = "data/mllib/sample_libsvm_data.txt"; | ||
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); | ||
// Split the data into training and test sets (30% held out for testing) | ||
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3}); | ||
JavaRDD<LabeledPoint> trainingData = splits[0]; | ||
JavaRDD<LabeledPoint> testData = splits[1]; | ||
|
||
// Train a RandomForest model. | ||
// Empty categoricalFeaturesInfo indicates all features are continuous. | ||
Integer numClasses = 2; | ||
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); | ||
Integer numTrees = 3; // Use more in practice. | ||
String featureSubsetStrategy = "auto"; // Let the algorithm choose. | ||
String impurity = "gini"; | ||
Integer maxDepth = 5; | ||
Integer maxBins = 32; | ||
Integer seed = 12345; | ||
|
||
final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, | ||
categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, | ||
seed); | ||
|
||
// Evaluate model on test instances and compute test error | ||
JavaPairRDD<Double, Double> predictionAndLabel = | ||
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { | ||
@Override | ||
public Tuple2<Double, Double> call(LabeledPoint p) { | ||
return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); | ||
} | ||
}); | ||
Double testErr = | ||
1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() { | ||
@Override | ||
public Boolean call(Tuple2<Double, Double> pl) { | ||
return !pl._1().equals(pl._2()); | ||
} | ||
}).count() / testData.count(); | ||
System.out.println("Test Error: " + testErr); | ||
System.out.println("Learned classification forest model:\n" + model.toDebugString()); | ||
|
||
// Save and load model | ||
model.save(jsc.sc(), "target/tmp/myRandomForestClassificationModel"); | ||
RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), | ||
"target/tmp/myRandomForestClassificationModel"); | ||
// $example off$ | ||
} | ||
} |
95 changes: 95 additions & 0 deletions
95
...ples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.examples.mllib; | ||
|
||
// $example on$ | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import scala.Tuple2; | ||
|
||
import org.apache.spark.api.java.function.Function2; | ||
import org.apache.spark.api.java.JavaPairRDD; | ||
import org.apache.spark.api.java.JavaRDD; | ||
import org.apache.spark.api.java.JavaSparkContext; | ||
import org.apache.spark.api.java.function.Function; | ||
import org.apache.spark.api.java.function.PairFunction; | ||
import org.apache.spark.mllib.regression.LabeledPoint; | ||
import org.apache.spark.mllib.tree.RandomForest; | ||
import org.apache.spark.mllib.tree.model.RandomForestModel; | ||
import org.apache.spark.mllib.util.MLUtils; | ||
import org.apache.spark.SparkConf; | ||
// $example off$ | ||
|
||
public class JavaRandomForestRegressionExample { | ||
public static void main(String[] args) { | ||
// $example on$ | ||
SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestRegressionExample"); | ||
JavaSparkContext jsc = new JavaSparkContext(sparkConf); | ||
// Load and parse the data file. | ||
String datapath = "data/mllib/sample_libsvm_data.txt"; | ||
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); | ||
// Split the data into training and test sets (30% held out for testing) | ||
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3}); | ||
JavaRDD<LabeledPoint> trainingData = splits[0]; | ||
JavaRDD<LabeledPoint> testData = splits[1]; | ||
|
||
// Set parameters. | ||
// Empty categoricalFeaturesInfo indicates all features are continuous. | ||
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); | ||
Integer numTrees = 3; // Use more in practice. | ||
String featureSubsetStrategy = "auto"; // Let the algorithm choose. | ||
String impurity = "variance"; | ||
Integer maxDepth = 4; | ||
Integer maxBins = 32; | ||
Integer seed = 12345; | ||
// Train a RandomForest model. | ||
final RandomForestModel model = RandomForest.trainRegressor(trainingData, | ||
categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed); | ||
|
||
// Evaluate model on test instances and compute test error | ||
JavaPairRDD<Double, Double> predictionAndLabel = | ||
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { | ||
@Override | ||
public Tuple2<Double, Double> call(LabeledPoint p) { | ||
return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); | ||
} | ||
}); | ||
Double testMSE = | ||
predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() { | ||
@Override | ||
public Double call(Tuple2<Double, Double> pl) { | ||
Double diff = pl._1() - pl._2(); | ||
return diff * diff; | ||
} | ||
}).reduce(new Function2<Double, Double, Double>() { | ||
@Override | ||
public Double call(Double a, Double b) { | ||
return a + b; | ||
} | ||
}) / testData.count(); | ||
System.out.println("Test Mean Squared Error: " + testMSE); | ||
System.out.println("Learned regression forest model:\n" + model.toDebugString()); | ||
|
||
// Save and load model | ||
model.save(jsc.sc(), "target/tmp/myRandomForestRegressionModel"); | ||
RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), | ||
"target/tmp/myRandomForestRegressionModel"); | ||
// $example off$ | ||
} | ||
} |
Oops, something went wrong.