From 5731af5be65ccac831445f351baf040a0d007687 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 31 Mar 2014 15:23:46 -0700 Subject: [PATCH 01/10] [SQL] Rewrite join implementation to allow streaming of one relation. Before we were materializing everything in memory. This also uses the projection interface so will be easier to plug in code gen (its ported from that branch). @rxin @liancheng Author: Michael Armbrust Closes #250 from marmbrus/hashJoin and squashes the following commits: 1ad873e [Michael Armbrust] Change hasNext logic back to the correct version. 8e6f2a2 [Michael Armbrust] Review comments. 1e9fb63 [Michael Armbrust] style bc0cb84 [Michael Armbrust] Rewrite join implementation to allow streaming of one relation. --- .../spark/sql/catalyst/expressions/Row.scala | 10 ++ .../sql/catalyst/expressions/predicates.scala | 6 + .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 6 +- .../apache/spark/sql/execution/joins.scala | 127 +++++++++++++----- .../apache/spark/sql/hive/HiveContext.scala | 2 +- 6 files changed, 116 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 31d42b9ee71a0..6f939e6c41f6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -44,6 +44,16 @@ trait Row extends Seq[Any] with Serializable { s"[${this.mkString(",")}]" def copy(): Row + + /** Returns true if there are any NULL values in this row. */ + def anyNull: Boolean = { + var i = 0 + while (i < length) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 722ff517d250e..02fedd16b8d4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.types.{BooleanType, StringType} +object InterpretedPredicate { + def apply(expression: Expression): (Row => Boolean) = { + (r: Row) => expression.apply(r).asInstanceOf[Boolean] + } +} + trait Predicate extends Expression { self: Product => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cf3c06acce5b0..f950ea08ec57a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -117,7 +117,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val strategies: Seq[Strategy] = TopK :: PartialAggregation :: - SparkEquiInnerJoin :: + HashJoin :: ParquetOperations :: BasicOperators :: CartesianProduct :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 86f9d3e0fa954..e35ac0b6ca95a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.parquet._ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => - object SparkEquiInnerJoin extends Strategy { + object HashJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) => logger.debug(s"Considering join: ${predicates ++ condition}") @@ -51,8 +51,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val leftKeys = joinKeys.map(_._1) val rightKeys = joinKeys.map(_._2) - val joinOp = execution.SparkEquiInnerJoin( - leftKeys, rightKeys, planLater(left), planLater(right)) + val joinOp = execution.HashJoin( + leftKeys, rightKeys, BuildRight, planLater(left), planLater(right)) // Make sure other conditions are met if present. if (otherPredicates.nonEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index f0d21143ba5d1..c89dae9358bf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -17,21 +17,22 @@ package org.apache.spark.sql.execution -import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, BitSet} -import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext -import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} -import org.apache.spark.rdd.PartitionLocalRDDFunctions._ +sealed abstract class BuildSide +case object BuildLeft extends BuildSide +case object BuildRight extends BuildSide -case class SparkEquiInnerJoin( +case class HashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + buildSide: BuildSide, left: SparkPlan, right: SparkPlan) extends BinaryNode { @@ -40,33 +41,93 @@ case class SparkEquiInnerJoin( override def requiredChildDistribution = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + val (buildPlan, streamedPlan) = buildSide match { + case BuildLeft => (left, right) + case BuildRight => (right, left) + } + + val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + def output = left.output ++ right.output - def execute() = attachTree(this, "execute") { - val leftWithKeys = left.execute().mapPartitions { iter => - val generateLeftKeys = new Projection(leftKeys, left.output) - iter.map(row => (generateLeftKeys(row), row.copy())) - } + @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val streamSideKeyGenerator = + () => new MutableProjection(streamedKeys, streamedPlan.output) - val rightWithKeys = right.execute().mapPartitions { iter => - val generateRightKeys = new Projection(rightKeys, right.output) - iter.map(row => (generateRightKeys(row), row.copy())) - } + def execute() = { - // Do the join. - val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys)) - // Drop join keys and merge input tuples. - joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) } - } + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + // TODO: Use Spark's HashMap implementation. + val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() + var currentRow: Row = null + + // Create a mapping of buildKeys -> rows + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if(!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new ArrayBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += currentRow.copy() + } + } + + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: ArrayBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 - /** - * Filters any rows where the any of the join keys is null, ensuring three-valued - * logic for the equi-join conditions. - */ - protected def filterNulls(rdd: RDD[(Row, Row)]) = - rdd.filter { - case (key: Seq[_], _) => !key.exists(_ == null) + // Mutable per row objects. + private[this] val joinRow = new JoinedRow + + private[this] val joinKeys = streamSideKeyGenerator() + + override final def hasNext: Boolean = + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + (streamIter.hasNext && fetchNext()) + + override final def next() = { + val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + currentMatchPosition += 1 + ret + } + + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashTable.get(joinKeys.currentValue) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + } } + } } case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { @@ -95,17 +156,19 @@ case class BroadcastNestedLoopJoin( def right = broadcast @transient lazy val boundCondition = - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true)) + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) def execute() = { val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new mutable.ArrayBuffer[Row] - val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size) + val matchedRows = new ArrayBuffer[Row] + // TODO: Use Spark's BitSet. + val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow streamedIter.foreach { streamedRow => @@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin( while (i < broadcastedRelation.value.size) { // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) { + if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { matchedRows += buildRow(streamedRow ++ broadcastedRow) matched = true includedBroadcastTuples += i diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fc5057b73fe24..197b557cba5f4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -194,7 +194,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { DataSinks, Scripts, PartialAggregation, - SparkEquiInnerJoin, + HashJoin, BasicOperators, CartesianProduct, BroadcastNestedLoopJoin From 33b3c2a8c6c71b89744834017a183ea855e1697c Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 31 Mar 2014 16:25:43 -0700 Subject: [PATCH 02/10] SPARK-1365 [HOTFIX] Fix RateLimitedOutputStream test This test needs to be fixed. It currently depends on Thread.sleep() having exact-timing semantics, which is not a valid assumption. Author: Patrick Wendell Closes #277 from pwendell/rate-limited-stream and squashes the following commits: 6c0ff81 [Patrick Wendell] SPARK-1365: Fix RateLimitedOutputStream test --- .../spark/streaming/util/RateLimitedOutputStreamSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 7d18a0fcf7ba8..9ebf7b484f421 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -36,8 +36,9 @@ class RateLimitedOutputStreamSuite extends FunSuite { val stream = new RateLimitedOutputStream(underlying, desiredBytesPerSec = 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } - // We accept anywhere from 4.0 to 4.99999 seconds since the value is rounded down. - assert(SECONDS.convert(elapsedNs, NANOSECONDS) === 4) + val seconds = SECONDS.convert(elapsedNs, NANOSECONDS) + assert(seconds >= 4, s"Seconds value ($seconds) is less than 4.") + assert(seconds <= 30, s"Took more than 30 seconds ($seconds) to write data.") assert(underlying.toString("UTF-8") === data) } } From 564f1c137caf07bd1f073ec6c93551dcad935ee5 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Tue, 1 Apr 2014 08:26:31 +0530 Subject: [PATCH 03/10] SPARK-1376. In the yarn-cluster submitter, rename "args" option to "arg" Author: Sandy Ryza Closes #279 from sryza/sandy-spark-1376 and squashes the following commits: d8aebfa [Sandy Ryza] SPARK-1376. In the yarn-cluster submitter, rename "args" option to "arg" --- docs/running-on-yarn.md | 7 ++++--- .../org/apache/spark/deploy/yarn/ClientArguments.scala | 9 ++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d8657c4bc7096..982514391ac00 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -61,7 +61,7 @@ The command to launch the Spark application on the cluster is as follows: SPARK_JAR= ./bin/spark-class org.apache.spark.deploy.yarn.Client \ --jar \ --class \ - --args \ + --arg \ --num-executors \ --driver-memory \ --executor-memory \ @@ -72,7 +72,7 @@ The command to launch the Spark application on the cluster is as follows: --files \ --archives -For example: +To pass multiple arguments the "arg" option can be specified multiple times. For example: # Build the Spark assembly JAR and the Spark examples JAR $ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_YARN=true sbt/sbt assembly @@ -85,7 +85,8 @@ For example: ./bin/spark-class org.apache.spark.deploy.yarn.Client \ --jar examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ --class org.apache.spark.examples.SparkPi \ - --args yarn-cluster \ + --arg yarn-cluster \ + --arg 5 \ --num-executors 3 \ --driver-memory 4g \ --executor-memory 2g \ diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index c565f2dde24fc..3e4c739e34fe9 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -63,7 +63,10 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { userClass = value args = tail - case ("--args") :: value :: tail => + case ("--args" | "--arg") :: value :: tail => + if (args(0) == "--args") { + println("--args is deprecated. Use --arg instead.") + } userArgsBuffer += value args = tail @@ -146,8 +149,8 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { "Options:\n" + " --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster mode)\n" + " --class CLASS_NAME Name of your application's main class (required)\n" + - " --args ARGS Arguments to be passed to your application's main class.\n" + - " Mutliple invocations are possible, each will be passed in order.\n" + + " --arg ARGS Argument to be passed to your application's main class.\n" + + " Multiple invocations are possible, each will be passed in order.\n" + " --num-executors NUM Number of executors to start (Default: 2)\n" + " --executor-cores NUM Number of cores for the executors (Default: 1).\n" + " --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb)\n" + From 94fe7fd4fa9749cb13e540e4f9caf28de47eaf32 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 31 Mar 2014 21:42:36 -0700 Subject: [PATCH 04/10] [SPARK-1377] Upgrade Jetty to 8.1.14v20131031 Previous version was 7.6.8v20121106. The only difference between Jetty 7 and Jetty 8 is that the former uses Servlet API 2.5, while the latter uses Servlet API 3.0. Author: Andrew Or Closes #280 from andrewor14/jetty-upgrade and squashes the following commits: dd57104 [Andrew Or] Merge github.com:apache/spark into jetty-upgrade e75fa85 [Andrew Or] Upgrade Jetty to 8.1.14v20131031 --- .../main/scala/org/apache/spark/ui/JettyUtils.scala | 3 ++- pom.xml | 8 ++++---- project/SparkBuild.scala | 12 ++++++------ 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 6e1736f6fbc23..e1a1f209c9282 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -18,13 +18,14 @@ package org.apache.spark.ui import java.net.{InetSocketAddress, URL} +import javax.servlet.DispatcherType import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.annotation.tailrec import scala.util.{Failure, Success, Try} import scala.xml.Node -import org.eclipse.jetty.server.{DispatcherType, Server} +import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.handler._ import org.eclipse.jetty.servlet._ import org.eclipse.jetty.util.thread.QueuedThreadPool diff --git a/pom.xml b/pom.xml index 72acf2b402703..09a449d81453f 100644 --- a/pom.xml +++ b/pom.xml @@ -192,22 +192,22 @@ org.eclipse.jetty jetty-util - 7.6.8.v20121106 + 8.1.14.v20131031 org.eclipse.jetty jetty-security - 7.6.8.v20121106 + 8.1.14.v20131031 org.eclipse.jetty jetty-plus - 7.6.8.v20121106 + 8.1.14.v20131031 org.eclipse.jetty jetty-server - 7.6.8.v20121106 + 8.1.14.v20131031 com.google.guava diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2549bc9710f1f..7457ff456ade4 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -248,13 +248,13 @@ object SparkBuild extends Build { */ libraryDependencies ++= Seq( - "io.netty" % "netty-all" % "4.0.17.Final", - "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106", - "org.eclipse.jetty" % "jetty-util" % "7.6.8.v20121106", - "org.eclipse.jetty" % "jetty-plus" % "7.6.8.v20121106", - "org.eclipse.jetty" % "jetty-security" % "7.6.8.v20121106", + "io.netty" % "netty-all" % "4.0.17.Final", + "org.eclipse.jetty" % "jetty-server" % "8.1.14.v20131031", + "org.eclipse.jetty" % "jetty-util" % "8.1.14.v20131031", + "org.eclipse.jetty" % "jetty-plus" % "8.1.14.v20131031", + "org.eclipse.jetty" % "jetty-security" % "8.1.14.v20131031", /** Workaround for SPARK-959. Dependency used by org.eclipse.jetty. Fixed in ivy 2.3.0. */ - "org.eclipse.jetty.orbit" % "javax.servlet" % "2.5.0.v201103041518" artifacts Artifact("javax.servlet", "jar", "jar"), + "org.eclipse.jetty.orbit" % "javax.servlet" % "3.0.0.v201112011016" artifacts Artifact("javax.servlet", "jar", "jar"), "org.scalatest" %% "scalatest" % "1.9.1" % "test", "org.scalacheck" %% "scalacheck" % "1.10.0" % "test", "com.novocode" % "junit-interface" % "0.10" % "test", From ada310a9d3d5419e101b24d9b41398f609da1ad3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 31 Mar 2014 23:01:14 -0700 Subject: [PATCH 05/10] [Hot Fix #42] Persisted RDD disappears on storage page if re-used If a previously persisted RDD is re-used, its information disappears from the Storage page. This is because the tasks associated with re-using the RDD do not report the RDD's blocks as updated (which is correct). On stage submit, however, we overwrite any existing information regarding that RDD with a fresh one, whether or not the information for the RDD already exists. Author: Andrew Or Closes #281 from andrewor14/ui-storage-fix and squashes the following commits: 408585a [Andrew Or] Fix storage UI bug --- .../main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index 4d8b01dbe6e1b..a7b24ff695214 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -84,7 +84,7 @@ private[ui] class BlockManagerListener(storageStatusListener: StorageStatusListe override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { val rddInfo = stageSubmitted.stageInfo.rddInfo - _rddInfoMap(rddInfo.id) = rddInfo + _rddInfoMap.getOrElseUpdate(rddInfo.id, rddInfo) } override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { From f5c418da044ef7f3d7185cc5bb1bef79d7f4e25c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 1 Apr 2014 14:45:44 -0700 Subject: [PATCH 06/10] [SQL] SPARK-1372 Support for caching and uncaching tables in a SQLContext. This doesn't yet support different databases in Hive (though you can probably workaround this by calling `USE `). However, given the time constraints for 1.0 I think its probably worth including this now and extending the functionality in the next release. Author: Michael Armbrust Closes #282 from marmbrus/cacheTables and squashes the following commits: 83785db [Michael Armbrust] Support for caching and uncaching tables in a SQLContext. --- .../spark/sql/catalyst/analysis/Catalog.scala | 11 +++- .../org/apache/spark/sql/SQLContext.scala | 32 +++++++++- .../apache/spark/sql/CachedTableSuite.scala | 61 +++++++++++++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 7 +++ ...d table-0-ce3797dc14a603cba2a5e58c8612de5b | 1 + ...d table-0-ce3797dc14a603cba2a5e58c8612de5b | 1 + .../spark/sql/hive/CachedTableSuite.scala | 58 ++++++++++++++++++ 7 files changed, 169 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala create mode 100644 sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b create mode 100644 sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index e09182dd8d5df..6b58b9322c4bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -31,6 +31,7 @@ trait Catalog { alias: Option[String] = None): LogicalPlan def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit + def unregisterTable(databaseName: Option[String], tableName: String): Unit } class SimpleCatalog extends Catalog { @@ -40,7 +41,7 @@ class SimpleCatalog extends Catalog { tables += ((tableName, plan)) } - def dropTable(tableName: String) = tables -= tableName + def unregisterTable(databaseName: Option[String], tableName: String) = { tables -= tableName } def lookupRelation( databaseName: Option[String], @@ -87,6 +88,10 @@ trait OverrideCatalog extends Catalog { plan: LogicalPlan): Unit = { overrides.put((databaseName, tableName), plan) } + + override def unregisterTable(databaseName: Option[String], tableName: String): Unit = { + overrides.remove((databaseName, tableName)) + } } /** @@ -104,4 +109,8 @@ object EmptyCatalog extends Catalog { def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } + + def unregisterTable(databaseName: Option[String], tableName: String): Unit = { + throw new UnsupportedOperationException + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f950ea08ec57a..69bbbdc8943fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -26,8 +26,9 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.execution._ /** @@ -111,6 +112,35 @@ class SQLContext(@transient val sparkContext: SparkContext) result } + /** Returns the specified table as a SchemaRDD */ + def table(tableName: String): SchemaRDD = + new SchemaRDD(this, catalog.lookupRelation(None, tableName)) + + /** Caches the specified table in-memory. */ + def cacheTable(tableName: String): Unit = { + val currentTable = catalog.lookupRelation(None, tableName) + val asInMemoryRelation = + InMemoryColumnarTableScan(currentTable.output, executePlan(currentTable).executedPlan) + + catalog.registerTable(None, tableName, SparkLogicalPlan(asInMemoryRelation)) + } + + /** Removes the specified table from the in-memory cache. */ + def uncacheTable(tableName: String): Unit = { + EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match { + // This is kind of a hack to make sure that if this was just an RDD registered as a table, + // we reregister the RDD as a table. + case SparkLogicalPlan(inMem @ InMemoryColumnarTableScan(_, e: ExistingRdd)) => + inMem.cachedColumnBuffers.unpersist() + catalog.unregisterTable(None, tableName) + catalog.registerTable(None, tableName, SparkLogicalPlan(e)) + case SparkLogicalPlan(inMem: InMemoryColumnarTableScan) => + inMem.cachedColumnBuffers.unpersist() + catalog.unregisterTable(None, tableName) + case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan") + } + } + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext = self.sparkContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala new file mode 100644 index 0000000000000..e5902c3cae381 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.FunSuite +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan + +class CachedTableSuite extends QueryTest { + TestData // Load test tables. + + test("read from cached table and uncache") { + TestSQLContext.cacheTable("testData") + + checkAnswer( + TestSQLContext.table("testData"), + testData.collect().toSeq + ) + + TestSQLContext.table("testData").queryExecution.analyzed match { + case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case noCache => fail(s"No cache node found in plan $noCache") + } + + TestSQLContext.uncacheTable("testData") + + checkAnswer( + TestSQLContext.table("testData"), + testData.collect().toSeq + ) + + TestSQLContext.table("testData").queryExecution.analyzed match { + case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + fail(s"Table still cached after uncache: $cachePlan") + case noCache => // Table uncached successfully + } + } + + test("correct error on uncache of non-cached table") { + intercept[IllegalArgumentException] { + TestSQLContext.uncacheTable("testData") + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4f8353666a12b..29834a11f41dc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -141,6 +141,13 @@ class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { */ override def registerTable( databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ??? + + /** + * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. + * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. + */ + override def unregisterTable( + databaseName: Option[String], tableName: String): Unit = ??? } object HiveMetastoreTypes extends RegexParsers { diff --git a/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b new file mode 100644 index 0000000000000..60878ffb77064 --- /dev/null +++ b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b @@ -0,0 +1 @@ +238 val_238 diff --git a/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b new file mode 100644 index 0000000000000..60878ffb77064 --- /dev/null +++ b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b @@ -0,0 +1 @@ +238 val_238 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala new file mode 100644 index 0000000000000..68d45e53cdf26 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.hive.execution.HiveComparisonTest + +class CachedTableSuite extends HiveComparisonTest { + TestHive.loadTestTable("src") + + test("cache table") { + TestHive.cacheTable("src") + } + + createQueryTest("read from cached table", + "SELECT * FROM src LIMIT 1") + + test("check that table is cached and uncache") { + TestHive.table("src").queryExecution.analyzed match { + case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case noCache => fail(s"No cache node found in plan $noCache") + } + TestHive.uncacheTable("src") + } + + createQueryTest("read from uncached table", + "SELECT * FROM src LIMIT 1") + + test("make sure table is uncached") { + TestHive.table("src").queryExecution.analyzed match { + case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + fail(s"Table still cached after uncache: $cachePlan") + case noCache => // Table uncached successfully + } + } + + test("correct error on uncache of non-cached table") { + intercept[IllegalArgumentException] { + TestHive.uncacheTable("src") + } + } +} From 764353d2c5162352781c273dd3d4af6a309190c7 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Tue, 1 Apr 2014 18:35:50 -0700 Subject: [PATCH 07/10] [SPARK-1342] Scala 2.10.4 Just a Scala version increment Author: Mark Hamstra Closes #259 from markhamstra/scala-2.10.4 and squashes the following commits: fbec547 [Mark Hamstra] [SPARK-1342] Bumped Scala version to 2.10.4 --- core/pom.xml | 2 +- dev/audit-release/README.md | 2 +- dev/audit-release/audit_release.py | 2 +- docker/spark-test/base/Dockerfile | 2 +- docs/_config.yml | 2 +- pom.xml | 4 ++-- project/SparkBuild.scala | 2 +- project/plugins.sbt | 2 +- project/project/SparkPluginBuild.scala | 2 +- sql/README.md | 2 +- 10 files changed, 11 insertions(+), 11 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index eb6cc4d3105e9..e4c32eff0cd77 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -150,7 +150,7 @@ json4s-jackson_${scala.binary.version} 3.2.6 diff --git a/dev/audit-release/README.md b/dev/audit-release/README.md index 2437a98672177..38becda0eae92 100644 --- a/dev/audit-release/README.md +++ b/dev/audit-release/README.md @@ -4,7 +4,7 @@ run them locally by setting appropriate environment variables. ``` $ cd sbt_app_core -$ SCALA_VERSION=2.10.3 \ +$ SCALA_VERSION=2.10.4 \ SPARK_VERSION=1.0.0-SNAPSHOT \ SPARK_RELEASE_REPOSITORY=file:///home/patrick/.ivy2/local \ sbt run diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 52c367d9b030d..fa2f02dfecc75 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -35,7 +35,7 @@ RELEASE_KEY = "9E4FE3AF" RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1006/" RELEASE_VERSION = "1.0.0" -SCALA_VERSION = "2.10.3" +SCALA_VERSION = "2.10.4" SCALA_BINARY_VERSION = "2.10" ## diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile index e543db6143e4d..5956d59130fbf 100644 --- a/docker/spark-test/base/Dockerfile +++ b/docker/spark-test/base/Dockerfile @@ -25,7 +25,7 @@ RUN apt-get update # install a few other useful packages plus Open Jdk 7 RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server -ENV SCALA_VERSION 2.10.3 +ENV SCALA_VERSION 2.10.4 ENV CDH_VERSION cdh4 ENV SCALA_HOME /opt/scala-$SCALA_VERSION ENV SPARK_HOME /opt/spark diff --git a/docs/_config.yml b/docs/_config.yml index aa5a5adbc1743..d585b8c5ea763 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -6,7 +6,7 @@ markdown: kramdown SPARK_VERSION: 1.0.0-SNAPSHOT SPARK_VERSION_SHORT: 1.0.0 SCALA_BINARY_VERSION: "2.10" -SCALA_VERSION: "2.10.3" +SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.13.0 SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/pom.xml b/pom.xml index 09a449d81453f..7d58060cba606 100644 --- a/pom.xml +++ b/pom.xml @@ -110,7 +110,7 @@ 1.6 - 2.10.3 + 2.10.4 2.10 0.13.0 org.spark-project.akka @@ -380,7 +380,7 @@ lift-json_${scala.binary.version} 2.5.1 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7457ff456ade4..c5c697e8e2427 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -152,7 +152,7 @@ object SparkBuild extends Build { def sharedSettings = Defaults.defaultSettings ++ MimaBuild.mimaSettings(file(sparkHome)) ++ Seq( organization := "org.apache.spark", version := SPARK_VERSION, - scalaVersion := "2.10.3", + scalaVersion := "2.10.4", scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation", "-target:" + SCALAC_JVM_VERSION), javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION), diff --git a/project/plugins.sbt b/project/plugins.sbt index 5aa8a1ec2409b..d787237ddc540 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,4 +1,4 @@ -scalaVersion := "2.10.3" +scalaVersion := "2.10.4" resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 5a307044ba123..0142256e90fb7 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -32,7 +32,7 @@ object SparkPluginDef extends Build { name := "spark-style", organization := "org.apache.spark", version := sparkVersion, - scalaVersion := "2.10.3", + scalaVersion := "2.10.4", scalacOptions := Seq("-unchecked", "-deprecation"), libraryDependencies ++= Dependencies.scalaStyle ) diff --git a/sql/README.md b/sql/README.md index 4192fecb92fb0..14d5555f0c713 100644 --- a/sql/README.md +++ b/sql/README.md @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.TestHive._ -Welcome to Scala version 2.10.3 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). +Welcome to Scala version 2.10.4 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). Type in expressions to have them evaluated. Type :help for more information. From afb5ea62786e3ca055e247176def3e7ecf0d2c9d Mon Sep 17 00:00:00 2001 From: Diana Carroll Date: Tue, 1 Apr 2014 19:29:26 -0700 Subject: [PATCH 08/10] [Spark-1134] only call ipython if no arguments are given; remove IPYTHONOPTS from call see comments on Pull Request https://github.com/apache/spark/pull/38 (i couldn't figure out how to modify an existing pull request, so I'm hoping I can withdraw that one and replace it with this one.) Author: Diana Carroll Closes #227 from dianacarroll/spark-1134 and squashes the following commits: ffe47f2 [Diana Carroll] [spark-1134] remove ipythonopts from ipython command b673bf7 [Diana Carroll] Merge branch 'master' of github.com:apache/spark 0309cf9 [Diana Carroll] SPARK-1134 bug with ipython prevents non-interactive use with spark; only call ipython if no command line arguments were supplied --- bin/pyspark | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 67e1f61eeb1e5..7932a247b54d0 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -55,8 +55,9 @@ if [ -n "$IPYTHON_OPTS" ]; then IPYTHON=1 fi -if [[ "$IPYTHON" = "1" ]] ; then - exec ipython $IPYTHON_OPTS +# Only use ipython if no command line arguments were provided [SPARK-1134] +if [[ "$IPYTHON" = "1" && $# = 0 ]] ; then + exec ipython else exec "$PYSPARK_PYTHON" "$@" fi From 45df9127365f8942794273b8ada004bf6ea3ef10 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 1 Apr 2014 19:31:50 -0700 Subject: [PATCH 09/10] Revert "[Spark-1134] only call ipython if no arguments are given; remove IPYTHONOPTS from call" This reverts commit afb5ea62786e3ca055e247176def3e7ecf0d2c9d. --- bin/pyspark | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 7932a247b54d0..67e1f61eeb1e5 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -55,9 +55,8 @@ if [ -n "$IPYTHON_OPTS" ]; then IPYTHON=1 fi -# Only use ipython if no command line arguments were provided [SPARK-1134] -if [[ "$IPYTHON" = "1" && $# = 0 ]] ; then - exec ipython +if [[ "$IPYTHON" = "1" ]] ; then + exec ipython $IPYTHON_OPTS else exec "$PYSPARK_PYTHON" "$@" fi From 8b3045ceab591a3f3ca18823c7e2c5faca38a06e Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 1 Apr 2014 21:40:49 -0700 Subject: [PATCH 10/10] MLI-1 Decision Trees Joint work with @hirakendu, @etrain, @atalwalkar and @harsha2010. Key features: + Supports binary classification and regression + Supports gini, entropy and variance for information gain calculation + Supports both continuous and categorical features The algorithm has gone through several development iterations over the last few months leading to a highly optimized implementation. Optimizations include: 1. Level-wise training to reduce passes over the entire dataset. 2. Bin-wise split calculation to reduce computation overhead. 3. Aggregation over partitions before combining to reduce communication overhead. Author: Manish Amde Author: manishamde Author: Xiangrui Meng Closes #79 from manishamde/tree and squashes the following commits: 1e8c704 [Manish Amde] remove numBins field in the Strategy class 7d54b4f [manishamde] Merge pull request #4 from mengxr/dtree f536ae9 [Xiangrui Meng] another pass on code style e1dd86f [Manish Amde] implementing code style suggestions 62dc723 [Manish Amde] updating javadoc and converting helper methods to package private to allow unit testing 201702f [Manish Amde] making some more methods private f963ef5 [Manish Amde] making methods private c487e6a [manishamde] Merge pull request #1 from mengxr/dtree 24500c5 [Xiangrui Meng] minor style updates 4576b64 [Manish Amde] documentation and for to while loop conversion ff363a7 [Manish Amde] binary search for bins and while loop for categorical feature bins 632818f [Manish Amde] removing threshold for classification predict method 2116360 [Manish Amde] removing dummy bin calculation for categorical variables 6068356 [Manish Amde] ensuring num bins is always greater than max number of categories 62c2562 [Manish Amde] fixing comment indentation ad1fc21 [Manish Amde] incorporated mengxr's code style suggestions d1ef4f6 [Manish Amde] more documentation 794ff4d [Manish Amde] minor improvements to docs and style eb8fcbe [Manish Amde] minor code style updates cd2c2b4 [Manish Amde] fixing code style based on feedback 63e786b [Manish Amde] added multiple train methods for java compatability d3023b3 [Manish Amde] adding more docs for nested methods 84f85d6 [Manish Amde] code documentation 9372779 [Manish Amde] code style: max line lenght <= 100 dd0c0d7 [Manish Amde] minor: some docs 0dd7659 [manishamde] basic doc 5841c28 [Manish Amde] unit tests for categorical features f067d68 [Manish Amde] minor cleanup c0e522b [Manish Amde] updated predict and split threshold logic b09dc98 [Manish Amde] minor refactoring 6b7de78 [Manish Amde] minor refactoring and tests d504eb1 [Manish Amde] more tests for categorical features dbb7ac1 [Manish Amde] categorical feature support 6df35b9 [Manish Amde] regression predict logic 53108ed [Manish Amde] fixing index for highest bin e23c2e5 [Manish Amde] added regression support c8f6d60 [Manish Amde] adding enum for feature type b0e3e76 [Manish Amde] adding enum for feature type 154aa77 [Manish Amde] enums for configurations 733d6dd [Manish Amde] fixed tests 02c595c [Manish Amde] added command line parsing 98ec8d5 [Manish Amde] tree building and prediction logic b0eb866 [Manish Amde] added logic to handle leaf nodes 80e8c66 [Manish Amde] working version of multi-level split calculation 4798aae [Manish Amde] added gain stats class dad0afc [Manish Amde] decison stump functionality working 03f534c [Manish Amde] some more tests 0012a77 [Manish Amde] basic stump working 8bca1e2 [Manish Amde] additional code for creating intermediate RDD 92cedce [Manish Amde] basic building blocks for intermediate RDD calculation. untested. cd53eae [Manish Amde] skeletal framework --- .../spark/mllib/tree/DecisionTree.scala | 1150 +++++++++++++++++ .../org/apache/spark/mllib/tree/README.md | 17 + .../spark/mllib/tree/configuration/Algo.scala | 26 + .../tree/configuration/FeatureType.scala | 26 + .../tree/configuration/QuantileStrategy.scala | 26 + .../mllib/tree/configuration/Strategy.scala | 43 + .../spark/mllib/tree/impurity/Entropy.scala | 47 + .../spark/mllib/tree/impurity/Gini.scala | 46 + .../spark/mllib/tree/impurity/Impurity.scala | 42 + .../spark/mllib/tree/impurity/Variance.scala | 37 + .../apache/spark/mllib/tree/model/Bin.scala | 33 + .../mllib/tree/model/DecisionTreeModel.scala | 49 + .../spark/mllib/tree/model/Filter.scala | 28 + .../tree/model/InformationGainStats.scala | 39 + .../apache/spark/mllib/tree/model/Node.scala | 90 ++ .../apache/spark/mllib/tree/model/Split.scala | 64 + .../spark/mllib/tree/DecisionTreeSuite.scala | 425 ++++++ 17 files changed, 2188 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/README.md create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala new file mode 100644 index 0000000000000..33205b919db8f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -0,0 +1,1150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import scala.util.control.Breaks._ + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.model._ +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom + +/** + * A class that implements a decision tree algorithm for classification and regression. It + * supports both continuous and categorical features. + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + */ +class DecisionTree private(val strategy: Strategy) extends Serializable with Logging { + + /** + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * @return a DecisionTreeModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + + // Cache input RDD for speedup during multiple passes. + input.cache() + logDebug("algo = " + strategy.algo) + + // Find the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + logDebug("numSplits = " + bins(0).length) + + // depth of the decision tree + val maxDepth = strategy.maxDepth + // the max number of nodes possible given the depth of the tree + val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 + // Initialize an array to hold filters applied to points for each node. + val filters = new Array[List[Filter]](maxNumNodes) + // The filter at the top node is an empty list. + filters(0) = List() + // Initialize an array to hold parent impurity calculations for each node. + val parentImpurities = new Array[Double](maxNumNodes) + // dummy value for top node (updated during first split calculation) + val nodes = new Array[Node](maxNumNodes) + + + /* + * The main idea here is to perform level-wise training of the decision tree nodes thus + * reducing the passes over the data from l to log2(l) where l is the total number of nodes. + * Each data sample is checked for validity w.r.t to each node at a given level -- i.e., + * the sample is only used for the split calculation at the node if the sampled would have + * still survived the filters of the parent nodes. + */ + + // TODO: Convert for loop to while loop + breakable { + for (level <- 0 until maxDepth) { + + logDebug("#####################################") + logDebug("level = " + level) + logDebug("#####################################") + + // Find best split for all nodes at a level. + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, + level, filters, splits, bins) + + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + // Extract info for nodes at the current level. + extractNodeInfo(nodeSplitStats, level, index, nodes) + // Extract info for nodes at the next lower level. + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, + filters) + logDebug("final best split = " + nodeSplitStats._1) + } + require(scala.math.pow(2, level) == splitsStatsForLevel.length) + // Check whether all the nodes at the current level at leaves. + val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) + logDebug("all leaf = " + allLeaf) + if (allLeaf) break // no more tree construction + } + } + + // Initialize the top or root node of the tree. + val topNode = nodes(0) + // Build the full tree using the node info calculated in the level-wise best split calculations. + topNode.build(nodes) + + new DecisionTreeModel(topNode, strategy.algo) + } + + /** + * Extract the decision tree node information for the given tree level and node index + */ + private def extractNodeInfo( + nodeSplitStats: (Split, InformationGainStats), + level: Int, + index: Int, + nodes: Array[Node]): Unit = { + val split = nodeSplitStats._1 + val stats = nodeSplitStats._2 + val nodeIndex = scala.math.pow(2, level).toInt - 1 + index + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) + val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) + logDebug("Node = " + node) + nodes(nodeIndex) = node + } + + /** + * Extract the decision tree node information for the children of the node + */ + private def extractInfoForLowerLevels( + level: Int, + index: Int, + maxDepth: Int, + nodeSplitStats: (Split, InformationGainStats), + parentImpurities: Array[Double], + filters: Array[List[Filter]]): Unit = { + // 0 corresponds to the left child node and 1 corresponds to the right child node. + // TODO: Convert to while loop + for (i <- 0 to 1) { + // Calculate the index of the node from the node level and the index at the current level. + val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i + if (level < maxDepth - 1) { + val impurity = if (i == 0) { + nodeSplitStats._2.leftImpurity + } else { + nodeSplitStats._2.rightImpurity + } + logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) + // noting the parent impurities + parentImpurities(nodeIndex) = impurity + // noting the parents filters for the child nodes + val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) + filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) + for (filter <- filters(nodeIndex)) { + logDebug("Filter = " + filter) + } + } + } + } +} + +object DecisionTree extends Serializable with Logging { + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. The parameters for the algorithm are specified using the strategy parameter. + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + * @return a DecisionTreeModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo algorithm, classification or regression + * @param impurity impurity criterion used for information gain calculation + * @param maxDepth maxDepth maximum depth of the tree + * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int): DecisionTreeModel = { + val strategy = new Strategy(algo,impurity,maxDepth) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The decision tree method supports binary classification and + * regression. For the binary classification, the label for each instance should either be 0 or + * 1 to denote the two classes. The method also supports categorical features inputs where the + * number of categories can specified using the categoricalFeaturesInfo option. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data for DecisionTree + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, + * an entry (n -> k) implies the feature n is categorical with k + * categories 0, 1, 2, ... , k-1. It's important to note that + * features are zero-indexed. + * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + maxBins: Int, + quantileCalculationStrategy: QuantileStrategy, + categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { + val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, + categoricalFeaturesInfo) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + private val InvalidBinIndex = -1 + + /** + * Returns an array of optimal splits for all nodes at a given level + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param parentImpurities Impurities for all parent nodes for the current level + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for construction the DecisionTree + * @param level Level of the tree + * @param filters Filters for all nodes at a given level + * @param splits possible splits for all features + * @param bins possible bins for all features + * @return array of splits with best splits for all nodes at a given level. + */ + protected[tree] def findBestSplits( + input: RDD[LabeledPoint], + parentImpurities: Array[Double], + strategy: Strategy, + level: Int, + filters: Array[List[Filter]], + splits: Array[Array[Split]], + bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + + /* + * The high-level description for the best split optimizations are noted here. + * + * *Level-wise training* + * We perform bin calculations for all nodes at the given level to avoid making multiple + * passes over the data. Thus, for a slightly increased computation and storage cost we save + * several iterations over the data especially at higher levels of the decision tree. + * + * *Bin-wise computation* + * We use a bin-wise best split computation strategy instead of a straightforward best split + * computation strategy. Instead of analyzing each sample for contribution to the left/right + * child node impurity of every split, we first categorize each feature of a sample into a + * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, + * is ordered (read ordering for categorical variables in the findSplitsBins method), + * we exploit this structure to calculate aggregates for bins and then use these aggregates + * to calculate information gain for each split. + * + * *Aggregation over partitions* + * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + * the number of splits in advance. Thus, we store the aggregates (at the appropriate + * indices) in a single array for all bins and rely upon the RDD aggregate method to + * drastically reduce the communication overhead. + */ + + // common calculations for multiple nested methods + val numNodes = scala.math.pow(2, level).toInt + logDebug("numNodes = " + numNodes) + // Find the number of features by looking at the first sample. + val numFeatures = input.first().features.length + logDebug("numFeatures = " + numFeatures) + val numBins = bins(0).length + logDebug("numBins = " + numBins) + + /** Find the filters used before reaching the current code. */ + def findParentFilters(nodeIndex: Int): List[Filter] = { + if (level == 0) { + List[Filter]() + } else { + val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + filters(nodeFilterIndex) + } + } + + /** + * Find whether the sample is valid input for the current node, i.e., whether it passes through + * all the filters for the current node. + */ + def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + // leaf + if ((level > 0) & (parentFilters.length == 0)) { + return false + } + + // Apply each filter and check sample validity. Return false when invalid condition found. + for (filter <- parentFilters) { + val features = labeledPoint.features + val featureIndex = filter.split.feature + val threshold = filter.split.threshold + val comparison = filter.comparison + val categories = filter.split.categories + val isFeatureContinuous = filter.split.featureType == Continuous + val feature = features(featureIndex) + if (isFeatureContinuous) { + comparison match { + case -1 => if (feature > threshold) return false + case 1 => if (feature <= threshold) return false + } + } else { + val containsFeature = categories.contains(feature) + comparison match { + case -1 => if (!containsFeature) return false + case 1 => if (containsFeature) return false + } + + } + } + + // Return true when the sample is valid for all filters. + true + } + + /** + * Find bin for one feature. + */ + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean): Int = { + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + + /** + * Binary search helper method for continuous feature. + */ + def binarySearchForBins(): Int = { + var left = 0 + var right = binForFeatures.length - 1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + if ((lowThreshold < feature) & (highThreshold >= feature)){ + return mid + } + else if (lowThreshold >= feature) { + right = mid - 1 + } + else { + left = mid + 1 + } + } + -1 + } + + /** + * Sequential search helper method to find bin for categorical feature. + */ + def sequentialBinSearchForCategoricalFeature(): Int = { + val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) + var binIndex = 0 + while (binIndex < numCategoricalBins) { + val bin = bins(featureIndex)(binIndex) + val category = bin.category + val features = labeledPoint.features + if (category == features(featureIndex)) { + return binIndex + } + binIndex += 1 + } + -1 + } + + if (isFeatureContinuous) { + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1){ + throw new UnknownError("no bin was found for continuous variable.") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = sequentialBinSearchForCategoricalFeature() + if (binIndex == -1){ + throw new UnknownError("no bin was found for categorical variable.") + } + binIndex + } + } + + /** + * Finds bins for all nodes (and all features) at a given level. + * For l nodes, k features the storage is as follows: + * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, + * where b_ij is an integer between 0 and numBins - 1. + * Invalid sample is denoted by noting bin for feature 1 as -1. + */ + def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { + // Calculate bin index and label per feature per node. + val arr = new Array[Double](1 + (numFeatures * numNodes)) + arr(0) = labeledPoint.label + var nodeIndex = 0 + while (nodeIndex < numNodes) { + val parentFilters = findParentFilters(nodeIndex) + // Find out whether the sample qualifies for the particular node. + val sampleValid = isSampleValid(parentFilters, labeledPoint) + val shift = 1 + numFeatures * nodeIndex + if (!sampleValid) { + // Mark one bin as -1 is sufficient. + arr(shift) = InvalidBinIndex + } else { + var featureIndex = 0 + while (featureIndex < numFeatures) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) + featureIndex += 1 + } + } + nodeIndex += 1 + } + arr + } + + /** + * Performs a sequential aggregation over a partition for classification. For l nodes, + * k features, either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. + * + * @param agg Array[Double] storing aggregate calculation of size + * 2 * numSplits * numFeatures*numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2 * numSplits * numFeatures * numNodes for classification + */ + def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = 2 * numBins * numFeatures * nodeIndex + val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 + label match { + case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 + case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + } + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Performs a sequential aggregation over a partition for regression. For l nodes, k features, + * the count, sum, sum of squares of one of the p bins is incremented. + * + * @param agg Array[Double] storing aggregate calculation of size + * 3 * numSplits * numFeatures * numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 3 * numSplits * numFeatures * numNodes for regression + */ + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update count, sum, and sum^2 for one bin. + val aggShift = 3 * numBins * numFeatures * nodeIndex + val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 + agg(aggIndex) = agg(aggIndex) + 1 + agg(aggIndex + 1) = agg(aggIndex + 1) + label + agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Performs a sequential aggregation over a partition. + */ + def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { + strategy.algo match { + case Classification => classificationBinSeqOp(arr, agg) + case Regression => regressionBinSeqOp(arr, agg) + } + agg + } + + // Calculate bin aggregate length for classification or regression. + val binAggregateLength = strategy.algo match { + case Classification => 2 * numBins * numFeatures * numNodes + case Regression => 3 * numBins * numFeatures * numNodes + } + logDebug("binAggregateLength = " + binAggregateLength) + + /** + * Combines the aggregates from partitions. + * @param agg1 Array containing aggregates from one or more partitions + * @param agg2 Array containing aggregates from one or more partitions + * @return Combined aggregate from agg1 and agg2 + */ + def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { + var index = 0 + val combinedAggregate = new Array[Double](binAggregateLength) + while (index < binAggregateLength) { + combinedAggregate(index) = agg1(index) + agg2(index) + index += 1 + } + combinedAggregate + } + + // Find feature bins for all nodes at a level. + val binMappedRDD = input.map(x => findBinsForLevel(x)) + + // Calculate bin aggregates. + val binAggregates = { + binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + } + logDebug("binAggregates.length = " + binAggregates.length) + + /** + * Calculates the information gain for all splits based upon left/right split aggregates. + * @param leftNodeAgg left node aggregates + * @param featureIndex feature index + * @param splitIndex split index + * @param rightNodeAgg right node aggregate + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + def calculateGainForSplit( + leftNodeAgg: Array[Array[Double]], + featureIndex: Int, + splitIndex: Int, + rightNodeAgg: Array[Array[Double]], + topImpurity: Double): InformationGainStats = { + strategy.algo match { + case Classification => + val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) + val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) + val leftCount = left0Count + left1Count + + val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex) + val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1) + val rightCount = right0Count + right1Count + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0) + } + + val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) + val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + val predict = (left1Count + right1Count) / (leftCount + rightCount) + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + case Regression => + val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) + val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) + val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) + + val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex) + val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1) + val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2) + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val count = leftCount + rightCount + val sum = leftSum + rightSum + val sumSquares = leftSumSquares + rightSumSquares + strategy.impurity.calculate(count, sum, sumSquares) + } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, + rightSum / rightCount) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity ,topImpurity, + Double.MinValue, leftSum / leftCount) + } + + val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + val predict = (leftSum + rightSum) / (leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + } + } + + /** + * Extracts left and right split aggregates. + * @param binData Array[Double] of size 2*numFeatures*numSplits + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], + * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) + */ + def extractLeftRightNodeAggregates( + binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + strategy.algo match { + case Classification => + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex + val shift = 2 * featureIndex * numBins + + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(2 * (numBins - 2)) + = binData(shift + (2 * (numBins - 1))) + rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) + = binData(shift + (2 * (numBins - 1)) + 1) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) + + leftNodeAgg(featureIndex)(2 * splitIndex - 2) + leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) + + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = + binData(shift + (2 *(numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = + binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) + + splitIndex += 1 + } + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + case Regression => + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex + val shift = 3 * featureIndex * numBins + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(2) = binData(shift + 2) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(3 * (numBins - 2)) = + binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = + binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = + binData(shift + (3 * (numBins - 1)) + 2) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3) + leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) + leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = + binData(shift + (3 * (numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + + splitIndex += 1 + } + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + } + } + + /** + * Calculates information gain for all nodes splits. + */ + def calculateGainsForAllNodeSplits( + leftNodeAgg: Array[Array[Double]], + rightNodeAgg: Array[Array[Double]], + nodeImpurity: Double): Array[Array[InformationGainStats]] = { + val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) + + for (featureIndex <- 0 until numFeatures) { + for (splitIndex <- 0 until numBins - 1) { + gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, + splitIndex, rightNodeAgg, nodeImpurity) + } + } + gains + } + + /** + * Find the best split for a node. + * @param binData Array[Double] of size 2 * numSplits * numFeatures + * @param nodeImpurity impurity of the top node + * @return tuple of split and information gain + */ + def binsToBestSplit( + binData: Array[Double], + nodeImpurity: Double): (Split, InformationGainStats) = { + + logDebug("node impurity = " + nodeImpurity) + + // Extract left right node aggregates. + val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) + + // Calculate gains for all splits. + val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) + + val (bestFeatureIndex,bestSplitIndex, gainStats) = { + // Initialize with infeasible values. + var bestFeatureIndex = Int.MinValue + var bestSplitIndex = Int.MinValue + var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) + // Iterate over features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Iterate over all splits. + var splitIndex = 0 + while (splitIndex < numBins - 1) { + val gainStats = gains(featureIndex)(splitIndex) + if (gainStats.gain > bestGainStats.gain) { + bestGainStats = gainStats + bestFeatureIndex = featureIndex + bestSplitIndex = splitIndex + } + splitIndex += 1 + } + featureIndex += 1 + } + (bestFeatureIndex, bestSplitIndex, bestGainStats) + } + + logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) + logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) + + (splits(bestFeatureIndex)(bestSplitIndex), gainStats) + } + + /** + * Get bin data for one node. + */ + def getBinDataForNode(node: Int): Array[Double] = { + strategy.algo match { + case Classification => + val shift = 2 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures) + binsForNode + case Regression => + val shift = 3 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) + binsForNode + } + } + + // Calculate best splits for all nodes at a given level + val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + // Iterating over all nodes at this level + var node = 0 + while (node < numNodes) { + val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + val binsForNode: Array[Double] = getBinDataForNode(node) + logDebug("nodeImpurityIndex = " + nodeImpurityIndex) + val parentNodeImpurity = parentImpurities(nodeImpurityIndex) + logDebug("node impurity = " + parentNodeImpurity) + bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) + node += 1 + } + + bestSplits + } + + /** + * Returns split and bins for decision tree calculation. + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for construction the DecisionTree + * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree + * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache + * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) + */ + protected[tree] def findSplitsBins( + input: RDD[LabeledPoint], + strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + val count = input.count() + + // Find the number of features by looking at the first sample + val numFeatures = input.take(1)(0).features.length + + val maxBins = strategy.maxBins + val numBins = if (maxBins <= count) maxBins else count.toInt + logDebug("numBins = " + numBins) + + /* + * TODO: Add a require statement ensuring #bins is always greater than the categories. + * It's a limitation of the current implementation but a reasonable trade-off since features + * with large number of categories get favored over continuous features. + */ + if (strategy.categoricalFeaturesInfo.size > 0) { + val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 + require(numBins >= maxCategoriesForFeatures) + } + + // Calculate the number of sample for approximate quantile calculation. + val requiredSamples = numBins*numBins + val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 + logDebug("fraction of data used for calculating quantiles = " + fraction) + + // sampled input for RDD calculation + val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() + val numSamples = sampledInput.length + + val stride: Double = numSamples.toDouble / numBins + logDebug("stride = " + stride) + + strategy.quantileCalculationStrategy match { + case Sort => + val splits = Array.ofDim[Split](numFeatures, numBins - 1) + val bins = Array.ofDim[Bin](numFeatures, numBins) + + // Find all splits. + + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures){ + // Check whether the feature is continuous. + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted + val stride: Double = numSamples.toDouble / numBins + logDebug("stride = " + stride) + for (index <- 0 until numBins - 1) { + val sampleIndex = (index + 1) * stride.toInt + val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) + splits(featureIndex)(index) = split + } + } else { + val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) + require(maxFeatureValue < numBins, "number of categories should be less than number " + + "of bins") + + // For categorical variables, each bin is a category. The bins are sorted and they + // are ordered by calculating the centroid of their corresponding labels. + val centroidForCategories = + sampledInput.map(lp => (lp.features(featureIndex),lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + + // Check for missing categorical variables and putting them last in the sorted list. + val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() + for (i <- 0 until maxFeatureValue) { + if (centroidForCategories.contains(i)) { + fullCentroidForCategories(i) = centroidForCategories(i) + } else { + fullCentroidForCategories(i) = Double.MaxValue + } + } + + // bins sorted by centroids + val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) + + logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + + var categoriesForSplit = List[Double]() + categoriesSortedByCentroid.iterator.zipWithIndex.foreach { + case ((key, value), index) => + categoriesForSplit = key :: categoriesForSplit + splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, + categoriesForSplit) + bins(featureIndex)(index) = { + if (index == 0) { + new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), Categorical, key) + } else { + new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Categorical, key) + } + } + } + } + featureIndex += 1 + } + + // Find all bins. + featureIndex = 0 + while (featureIndex < numFeatures) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { // Bins for categorical variables are already assigned. + bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), + splits(featureIndex)(0), Continuous, Double.MinValue) + for (index <- 1 until numBins - 1){ + val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Continuous, Double.MinValue) + bins(featureIndex)(index) = bin + } + bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2), + new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) + } + featureIndex += 1 + } + (splits,bins) + case MinMax => + throw new UnsupportedOperationException("minmax not supported yet.") + case ApproxHist => + throw new UnsupportedOperationException("approximate histogram not supported yet.") + } + } + + val usage = """ + Usage: DecisionTreeRunner [slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] + """ + + def main(args: Array[String]) { + + if (args.length < 2) { + System.err.println(usage) + System.exit(1) + } + + val sc = new SparkContext(args(0), "DecisionTree") + + val argList = args.toList.drop(1) + type OptionMap = Map[Symbol, Any] + + def nextOption(map : OptionMap, list: List[String]): OptionMap = { + list match { + case Nil => map + case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail) + case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) + case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) + case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) + case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string) + , tail) + case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), + tail) + case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) + case option :: tail => logError("Unknown option " + option) + sys.exit(1) + } + } + val options = nextOption(Map(), argList) + logDebug(options.toString()) + + // Load training data. + val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) + + // Identify the type of algorithm. + val algoStr = options.get('algo).get.toString + val algo = algoStr match { + case "Classification" => Classification + case "Regression" => Regression + } + + // Identify the type of impurity. + val impurityStr = options.getOrElse('impurity, + if (algo == Classification) "Gini" else "Variance").toString + val impurity = impurityStr match { + case "Gini" => Gini + case "Entropy" => Entropy + case "Variance" => Variance + } + + val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt + val maxBins = options.getOrElse('maxBins, "100").toString.toInt + + val strategy = new Strategy(algo, impurity, maxDepth, maxBins) + val model = DecisionTree.train(trainData, strategy) + + // Load test data. + val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) + + // Measure algorithm accuracy + if (algo == Classification) { + val accuracy = accuracyScore(model, testData) + logDebug("accuracy = " + accuracy) + } + + if (algo == Regression) { + val mse = meanSquaredError(model, testData) + logDebug("mean square error = " + mse) + } + + sc.stop() + } + + /** + * Load labeled data from a file. The data format used here is + * , ..., + * where , are feature values in Double and is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). + */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.trim().split(",") + val label = parts(0).toDouble + val features = parts.slice(1,parts.length).map(_.toDouble) + LabeledPoint(label, features) + } + } + + // TODO: Port this method to a generic metrics package. + /** + * Calculates the classifier accuracy. + */ + private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint], + threshold: Double = 0.5): Double = { + def predictedValue(features: Array[Double]) = { + if (model.predict(features) < threshold) 0.0 else 1.0 + } + val correctCount = data.filter(y => predictedValue(y.features) == y.label).count() + val count = data.count() + logDebug("correct prediction count = " + correctCount) + logDebug("data count = " + count) + correctCount.toDouble / count + } + + // TODO: Port this method to a generic metrics package + /** + * Calculates the mean squared error for regression. + */ + private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = tree.predict(y.features) - y.label + err * err + }.mean() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md new file mode 100644 index 0000000000000..0fd71aa9735bc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md @@ -0,0 +1,17 @@ +This package contains the default implementation of the decision tree algorithm. + +The decision tree algorithm supports: ++ Binary classification ++ Regression ++ Information loss calculation with entropy and gini for classification and variance for regression ++ Both continuous and categorical features + +# Tree improvements ++ Node model pruning ++ Printing to dot files + +# Future Ensemble Extensions + ++ Random forests ++ Boosting ++ Extremely randomized trees diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala new file mode 100644 index 0000000000000..2dd1f0f27b8f5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +/** + * Enum to select the algorithm for the decision tree + */ +object Algo extends Enumeration { + type Algo = Value + val Classification, Regression = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala new file mode 100644 index 0000000000000..09ee0586c58fa --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +/** + * Enum to describe whether a feature is "continuous" or "categorical" + */ +object FeatureType extends Enumeration { + type FeatureType = Value + val Continuous, Categorical = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala new file mode 100644 index 0000000000000..2457a480c2a14 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +/** + * Enum for selecting the quantile calculation strategy + */ +object QuantileStrategy extends Enumeration { + type QuantileStrategy = Value + val Sort, MinMax, ApproxHist = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala new file mode 100644 index 0000000000000..df565f3eb8859 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ + +/** + * Stores all the configuration options for tree construction + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + */ +class Strategy ( + val algo: Algo, + val impurity: Impurity, + val maxDepth: Int, + val maxBins: Int = 100, + val quantileCalculationStrategy: QuantileStrategy = Sort, + val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala new file mode 100644 index 0000000000000..b93995fcf9441 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during + * binary classification. + */ +object Entropy extends Impurity { + + def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + + /** + * entropy calculation + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return entropy value + */ + def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + -(f0 * log2(f0)) - (f1 * log2(f1)) + } + } + + def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new UnsupportedOperationException("Entropy.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala new file mode 100644 index 0000000000000..c0407554a91b3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Class for calculating the + * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] + * during binary classification. + */ +object Gini extends Impurity { + + /** + * Gini coefficient calculation + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return Gini coefficient value + */ + override def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + 1 - f0 * f0 - f1 * f1 + } + } + + def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new UnsupportedOperationException("Gini.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala new file mode 100644 index 0000000000000..a4069063af2ad --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Trait for calculating information gain. + */ +trait Impurity extends Serializable { + + /** + * information calculation for binary classification + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return information value + */ + def calculate(c0 : Double, c1 : Double): Double + + /** + * information calculation for regression + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + * @return information value + */ + def calculate(count: Double, sum: Double, sumSquares: Double): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala new file mode 100644 index 0000000000000..b74577dcec167 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Class for calculating variance during regression + */ +object Variance extends Impurity { + override def calculate(c0: Double, c1: Double): Double = + throw new UnsupportedOperationException("Variance.calculate") + + /** + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + */ + override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + val squaredLoss = sumSquares - (sum * sum) / count + squaredLoss / count + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala new file mode 100644 index 0000000000000..a57faa13745f7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.mllib.tree.configuration.FeatureType._ + +/** + * Used for "binning" the features bins for faster best split calculation. For a continuous + * feature, a bin is determined by a low and a high "split". For a categorical feature, + * the a bin is determined using a single label value (category). + * @param lowSplit signifying the lower threshold for the continuous feature to be + * accepted in the bin + * @param highSplit signifying the upper threshold for the continuous feature to be + * accepted in the bin + * @param featureType type of feature -- categorical or continuous + * @param category categorical label value accepted in the bin + */ +case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala new file mode 100644 index 0000000000000..a8bbf21daec01 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.rdd.RDD + +/** + * Model to store the decision tree parameters + * @param topNode root node + * @param algo algorithm type -- classification or regression + */ +class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return Double prediction from the trained model + */ + def predict(features: Array[Double]): Double = { + topNode.predictIfLeaf(features) + } + + /** + * Predict values for the given data set using the model trained. + * + * @param features RDD representing data points to be predicted + * @return RDD[Int] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Array[Double]]): RDD[Double] = { + features.map(x => predict(x)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala new file mode 100644 index 0000000000000..ebc9595eafef3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +/** + * Filter specifying a split and type of comparison to be applied on features + * @param split split specifying the feature index, type and threshold + * @param comparison integer specifying <,=,> + */ +case class Filter(split: Split, comparison: Int) { + // Comparison -1,0,1 signifies <.=,> + override def toString = " split = " + split + "comparison = " + comparison +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala new file mode 100644 index 0000000000000..99bf79cf12e45 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +/** + * Information gain statistics for each split + * @param gain information gain value + * @param impurity current node impurity + * @param leftImpurity left node impurity + * @param rightImpurity right node impurity + * @param predict predicted value + */ +class InformationGainStats( + val gain: Double, + val impurity: Double, + val leftImpurity: Double, + val rightImpurity: Double, + val predict: Double) extends Serializable { + + override def toString = { + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" + .format(gain, impurity, leftImpurity, rightImpurity, predict) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala new file mode 100644 index 0000000000000..ea4693c5c2f4e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.Logging +import org.apache.spark.mllib.tree.configuration.FeatureType._ + +/** + * Node in a decision tree + * @param id integer node id + * @param predict predicted value at the node + * @param isLeaf whether the leaf is a node + * @param split split to calculate left and right nodes + * @param leftNode left child + * @param rightNode right child + * @param stats information gain stats + */ +class Node ( + val id: Int, + val predict: Double, + val isLeaf: Boolean, + val split: Option[Split], + var leftNode: Option[Node], + var rightNode: Option[Node], + val stats: Option[InformationGainStats]) extends Serializable with Logging { + + override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + + "split = " + split + ", stats = " + stats + + /** + * build the left node and right nodes if not leaf + * @param nodes array of nodes + */ + def build(nodes: Array[Node]): Unit = { + + logDebug("building node " + id + " at level " + + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) + logDebug("id = " + id + ", split = " + split) + logDebug("stats = " + stats) + logDebug("predict = " + predict) + if (!isLeaf) { + val leftNodeIndex = id*2 + 1 + val rightNodeIndex = id*2 + 2 + leftNode = Some(nodes(leftNodeIndex)) + rightNode = Some(nodes(rightNodeIndex)) + leftNode.get.build(nodes) + rightNode.get.build(nodes) + } + } + + /** + * predict value if node is not leaf + * @param feature feature value + * @return predicted value + */ + def predictIfLeaf(feature: Array[Double]) : Double = { + if (isLeaf) { + predict + } else{ + if (split.get.featureType == Continuous) { + if (feature(split.get.feature) <= split.get.threshold) { + leftNode.get.predictIfLeaf(feature) + } else { + rightNode.get.predictIfLeaf(feature) + } + } else { + if (split.get.categories.contains(feature(split.get.feature))) { + leftNode.get.predictIfLeaf(feature) + } else { + rightNode.get.predictIfLeaf(feature) + } + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala new file mode 100644 index 0000000000000..4e64a81dda74e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType + +/** + * Split applied to a feature + * @param feature feature index + * @param threshold threshold for continuous feature + * @param featureType type of feature -- categorical or continuous + * @param categories accepted values for categorical variables + */ +case class Split( + feature: Int, + threshold: Double, + featureType: FeatureType, + categories: List[Double]){ + + override def toString = + "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + + ", categories = " + categories +} + +/** + * Split with minimum threshold for continuous features. Helps with the smallest bin creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyLowSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MinValue, featureType, List()) + +/** + * Split with maximum threshold for continuous features. Helps with the highest bin creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyHighSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MaxValue, featureType, List()) + +/** + * Split with no acceptable feature values for categorical features. Helps with the first bin + * creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyCategoricalSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MaxValue, featureType, List()) + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala new file mode 100644 index 0000000000000..4349c7000a0ae --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} +import org.apache.spark.mllib.tree.model.Filter +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.FeatureType._ + +class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { + + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + test("split and bin calculation") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + } + + test("split and bin calculation for categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + // Check splits. + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2) === null) + + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) + assert(splits(1)(1).categories.contains(1.0)) + assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2) === null) + + // Check bins. + + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2) === null) + + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2) === null) + } + + test("split and bin calculations for categorical variables with no sample for one category") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // Check splits. + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 3) + assert(splits(0)(2).categories.contains(1.0)) + assert(splits(0)(2).categories.contains(0.0)) + assert(splits(0)(2).categories.contains(2.0)) + + assert(splits(0)(3) === null) + + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) + assert(splits(1)(1).categories.contains(1.0)) + assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2).feature === 1) + assert(splits(1)(2).threshold === Double.MinValue) + assert(splits(1)(2).featureType === Categorical) + assert(splits(1)(2).categories.length === 3) + assert(splits(1)(2).categories.contains(1.0)) + assert(splits(1)(2).categories.contains(0.0)) + assert(splits(1)(2).categories.contains(2.0)) + + assert(splits(1)(3) === null) + + // Check bins. + + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2).category === 2.0) + assert(bins(0)(2).lowSplit.categories.length === 2) + assert(bins(0)(2).lowSplit.categories.contains(1.0)) + assert(bins(0)(2).lowSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.length === 3) + assert(bins(0)(2).highSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.contains(2.0)) + + assert(bins(0)(3) === null) + + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2).category === 2.0) + assert(bins(1)(2).lowSplit.categories.length === 2) + assert(bins(1)(2).lowSplit.categories.contains(0.0)) + assert(bins(1)(2).lowSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.length === 3) + assert(bins(1)(2).highSplit.categories.contains(0.0)) + assert(bins(1)(2).highSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.contains(2.0)) + + assert(bins(1)(3) === null) + } + + test("classification stump with all categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + + val split = bestSplits(0)._1 + assert(split.categories.length === 1) + assert(split.categories.contains(1.0)) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) + } + + test("regression stump with all categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Regression, + Variance, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + + val split = bestSplits(0)._1 + assert(split.categories.length === 1) + assert(split.categories.contains(1.0)) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) + } + + test("stump with fixed label 0 for Gini") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + } + + test("stump with fixed label 1 for Gini") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) + } + + test("stump with fixed label 0 for Entropy") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 0) + } + + test("stump with fixed label 1 for Entropy") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) + } +} + +object DecisionTreeSuite { + + def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i)) + arr(i) = lp + } + arr + } + + def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i)) + arr(i) = lp + } + arr + } + + def generateCategoricalDataPoints(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + if (i < 600){ + arr(i) = new LabeledPoint(1.0,Array(0.0,1.0)) + } else { + arr(i) = new LabeledPoint(0.0,Array(1.0,0.0)) + } + } + arr + } +}