-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-11445][DOCS]Replaced example code in mllib-ensembles.md using include_example #9407
Changes from all commits
d152cb5
a53a20d
870cbb3
a21b0ed
079b1de
24e74e1
a71e99b
29a8067
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.examples.mllib; | ||
|
||
// $example on$ | ||
import java.util.HashMap; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You need to add some blank lines in imports. See code style guide. |
||
import java.util.Map; | ||
|
||
import scala.Tuple2; | ||
|
||
import org.apache.spark.SparkConf; | ||
import org.apache.spark.api.java.JavaPairRDD; | ||
import org.apache.spark.api.java.JavaRDD; | ||
import org.apache.spark.api.java.JavaSparkContext; | ||
import org.apache.spark.api.java.function.Function; | ||
import org.apache.spark.api.java.function.PairFunction; | ||
import org.apache.spark.mllib.regression.LabeledPoint; | ||
import org.apache.spark.mllib.tree.GradientBoostedTrees; | ||
import org.apache.spark.mllib.tree.configuration.BoostingStrategy; | ||
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; | ||
import org.apache.spark.mllib.util.MLUtils; | ||
// $example off$ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. blank line below here |
||
|
||
public class JavaGradientBoostingClassificationExample { | ||
public static void main(String[] args) { | ||
// $example on$ | ||
SparkConf sparkConf = new SparkConf() | ||
.setAppName("JavaGradientBoostedTreesClassificationExample"); | ||
JavaSparkContext jsc = new JavaSparkContext(sparkConf); | ||
|
||
// Load and parse the data file. | ||
String datapath = "data/mllib/sample_libsvm_data.txt"; | ||
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); | ||
// Split the data into training and test sets (30% held out for testing) | ||
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3}); | ||
JavaRDD<LabeledPoint> trainingData = splits[0]; | ||
JavaRDD<LabeledPoint> testData = splits[1]; | ||
|
||
// Train a GradientBoostedTrees model. | ||
// The defaultParams for Classification use LogLoss by default. | ||
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification"); | ||
boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. | ||
boostingStrategy.getTreeStrategy().setNumClasses(2); | ||
boostingStrategy.getTreeStrategy().setMaxDepth(5); | ||
// Empty categoricalFeaturesInfo indicates all features are continuous. | ||
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); | ||
boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); | ||
|
||
final GradientBoostedTreesModel model = | ||
GradientBoostedTrees.train(trainingData, boostingStrategy); | ||
|
||
// Evaluate model on test instances and compute test error | ||
JavaPairRDD<Double, Double> predictionAndLabel = | ||
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { | ||
@Override | ||
public Tuple2<Double, Double> call(LabeledPoint p) { | ||
return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); | ||
} | ||
}); | ||
Double testErr = | ||
1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() { | ||
@Override | ||
public Boolean call(Tuple2<Double, Double> pl) { | ||
return !pl._1().equals(pl._2()); | ||
} | ||
}).count() / testData.count(); | ||
System.out.println("Test Error: " + testErr); | ||
System.out.println("Learned classification GBT model:\n" + model.toDebugString()); | ||
|
||
// Save and load model | ||
model.save(jsc.sc(), "target/tmp/myGradientBoostingClassificationModel"); | ||
GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), | ||
"target/tmp/myGradientBoostingClassificationModel"); | ||
// $example off$ | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.examples.mllib; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. blank line here |
||
|
||
// $example on$ | ||
import java.util.HashMap; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see comment in the previous code file. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see comment in the previous code file. |
||
import java.util.Map; | ||
|
||
import scala.Tuple2; | ||
|
||
import org.apache.spark.SparkConf; | ||
import org.apache.spark.api.java.function.Function2; | ||
import org.apache.spark.api.java.JavaPairRDD; | ||
import org.apache.spark.api.java.JavaRDD; | ||
import org.apache.spark.api.java.JavaSparkContext; | ||
import org.apache.spark.api.java.function.Function; | ||
import org.apache.spark.api.java.function.PairFunction; | ||
import org.apache.spark.mllib.regression.LabeledPoint; | ||
import org.apache.spark.mllib.tree.GradientBoostedTrees; | ||
import org.apache.spark.mllib.tree.configuration.BoostingStrategy; | ||
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; | ||
import org.apache.spark.mllib.util.MLUtils; | ||
// $example off$ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. blank line here |
||
|
||
public class JavaGradientBoostingRegressionExample { | ||
public static void main(String[] args) { | ||
// $example on$ | ||
SparkConf sparkConf = new SparkConf() | ||
.setAppName("JavaGradientBoostedTreesRegressionExample"); | ||
JavaSparkContext jsc = new JavaSparkContext(sparkConf); | ||
// Load and parse the data file. | ||
String datapath = "data/mllib/sample_libsvm_data.txt"; | ||
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); | ||
// Split the data into training and test sets (30% held out for testing) | ||
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3}); | ||
JavaRDD<LabeledPoint> trainingData = splits[0]; | ||
JavaRDD<LabeledPoint> testData = splits[1]; | ||
|
||
// Train a GradientBoostedTrees model. | ||
// The defaultParams for Regression use SquaredError by default. | ||
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); | ||
boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. | ||
boostingStrategy.getTreeStrategy().setMaxDepth(5); | ||
// Empty categoricalFeaturesInfo indicates all features are continuous. | ||
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); | ||
boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); | ||
|
||
final GradientBoostedTreesModel model = | ||
GradientBoostedTrees.train(trainingData, boostingStrategy); | ||
|
||
// Evaluate model on test instances and compute test error | ||
JavaPairRDD<Double, Double> predictionAndLabel = | ||
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { | ||
@Override | ||
public Tuple2<Double, Double> call(LabeledPoint p) { | ||
return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); | ||
} | ||
}); | ||
Double testMSE = | ||
predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() { | ||
@Override | ||
public Double call(Tuple2<Double, Double> pl) { | ||
Double diff = pl._1() - pl._2(); | ||
return diff * diff; | ||
} | ||
}).reduce(new Function2<Double, Double, Double>() { | ||
@Override | ||
public Double call(Double a, Double b) { | ||
return a + b; | ||
} | ||
}) / data.count(); | ||
System.out.println("Test Mean Squared Error: " + testMSE); | ||
System.out.println("Learned regression GBT model:\n" + model.toDebugString()); | ||
|
||
// Save and load model | ||
model.save(jsc.sc(), "target/tmp/myGradientBoostingRegressionModel"); | ||
GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), | ||
"target/tmp/myGradientBoostingRegressionModel"); | ||
// $example off$ | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.examples.mllib; | ||
|
||
// $example on$ | ||
import java.util.HashMap; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. blank lines according to spark scala style guide There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. blank lines according to spark scala style guide |
||
|
||
import scala.Tuple2; | ||
|
||
import org.apache.spark.SparkConf; | ||
import org.apache.spark.api.java.JavaPairRDD; | ||
import org.apache.spark.api.java.JavaRDD; | ||
import org.apache.spark.api.java.JavaSparkContext; | ||
import org.apache.spark.api.java.function.Function; | ||
import org.apache.spark.api.java.function.PairFunction; | ||
import org.apache.spark.mllib.regression.LabeledPoint; | ||
import org.apache.spark.mllib.tree.RandomForest; | ||
import org.apache.spark.mllib.tree.model.RandomForestModel; | ||
import org.apache.spark.mllib.util.MLUtils; | ||
// $example off$ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. blank line |
||
|
||
public class JavaRandomForestClassificationExample { | ||
public static void main(String[] args) { | ||
// $example on$ | ||
SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassificationExample"); | ||
JavaSparkContext jsc = new JavaSparkContext(sparkConf); | ||
// Load and parse the data file. | ||
String datapath = "data/mllib/sample_libsvm_data.txt"; | ||
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); | ||
// Split the data into training and test sets (30% held out for testing) | ||
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3}); | ||
JavaRDD<LabeledPoint> trainingData = splits[0]; | ||
JavaRDD<LabeledPoint> testData = splits[1]; | ||
|
||
// Train a RandomForest model. | ||
// Empty categoricalFeaturesInfo indicates all features are continuous. | ||
Integer numClasses = 2; | ||
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); | ||
Integer numTrees = 3; // Use more in practice. | ||
String featureSubsetStrategy = "auto"; // Let the algorithm choose. | ||
String impurity = "gini"; | ||
Integer maxDepth = 5; | ||
Integer maxBins = 32; | ||
Integer seed = 12345; | ||
|
||
final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, | ||
categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, | ||
seed); | ||
|
||
// Evaluate model on test instances and compute test error | ||
JavaPairRDD<Double, Double> predictionAndLabel = | ||
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { | ||
@Override | ||
public Tuple2<Double, Double> call(LabeledPoint p) { | ||
return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); | ||
} | ||
}); | ||
Double testErr = | ||
1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() { | ||
@Override | ||
public Boolean call(Tuple2<Double, Double> pl) { | ||
return !pl._1().equals(pl._2()); | ||
} | ||
}).count() / testData.count(); | ||
System.out.println("Test Error: " + testErr); | ||
System.out.println("Learned classification forest model:\n" + model.toDebugString()); | ||
|
||
// Save and load model | ||
model.save(jsc.sc(), "target/tmp/myRandomForestClassificationModel"); | ||
RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), | ||
"target/tmp/myRandomForestClassificationModel"); | ||
// $example off$ | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.examples.mllib; | ||
|
||
// $example on$ | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import scala.Tuple2; | ||
|
||
import org.apache.spark.api.java.function.Function2; | ||
import org.apache.spark.api.java.JavaPairRDD; | ||
import org.apache.spark.api.java.JavaRDD; | ||
import org.apache.spark.api.java.JavaSparkContext; | ||
import org.apache.spark.api.java.function.Function; | ||
import org.apache.spark.api.java.function.PairFunction; | ||
import org.apache.spark.mllib.regression.LabeledPoint; | ||
import org.apache.spark.mllib.tree.RandomForest; | ||
import org.apache.spark.mllib.tree.model.RandomForestModel; | ||
import org.apache.spark.mllib.util.MLUtils; | ||
import org.apache.spark.SparkConf; | ||
// $example off$ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same issue with previous code files. |
||
|
||
public class JavaRandomForestRegressionExample { | ||
public static void main(String[] args) { | ||
// $example on$ | ||
SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestRegressionExample"); | ||
JavaSparkContext jsc = new JavaSparkContext(sparkConf); | ||
// Load and parse the data file. | ||
String datapath = "data/mllib/sample_libsvm_data.txt"; | ||
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); | ||
// Split the data into training and test sets (30% held out for testing) | ||
JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3}); | ||
JavaRDD<LabeledPoint> trainingData = splits[0]; | ||
JavaRDD<LabeledPoint> testData = splits[1]; | ||
|
||
// Set parameters. | ||
// Empty categoricalFeaturesInfo indicates all features are continuous. | ||
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); | ||
Integer numTrees = 3; // Use more in practice. | ||
String featureSubsetStrategy = "auto"; // Let the algorithm choose. | ||
String impurity = "variance"; | ||
Integer maxDepth = 4; | ||
Integer maxBins = 32; | ||
Integer seed = 12345; | ||
// Train a RandomForest model. | ||
final RandomForestModel model = RandomForest.trainRegressor(trainingData, | ||
categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed); | ||
|
||
// Evaluate model on test instances and compute test error | ||
JavaPairRDD<Double, Double> predictionAndLabel = | ||
testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { | ||
@Override | ||
public Tuple2<Double, Double> call(LabeledPoint p) { | ||
return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); | ||
} | ||
}); | ||
Double testMSE = | ||
predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() { | ||
@Override | ||
public Double call(Tuple2<Double, Double> pl) { | ||
Double diff = pl._1() - pl._2(); | ||
return diff * diff; | ||
} | ||
}).reduce(new Function2<Double, Double, Double>() { | ||
@Override | ||
public Double call(Double a, Double b) { | ||
return a + b; | ||
} | ||
}) / testData.count(); | ||
System.out.println("Test Mean Squared Error: " + testMSE); | ||
System.out.println("Learned regression forest model:\n" + model.toDebugString()); | ||
|
||
// Save and load model | ||
model.save(jsc.sc(), "target/tmp/myRandomForestRegressionModel"); | ||
RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), | ||
"target/tmp/myRandomForestRegressionModel"); | ||
// $example off$ | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a blank line below here.