}
val stageDataOption = listener.stageIdToData.get(s.stageId)
// Too many nested map/flatMaps with options are just annoying to read. Do this imperatively.
if (stageDataOption.isDefined && stageDataOption.get.description.isDefined) {
val desc = stageDataOption.get.description
-
{desc}
{nameLink} {killLink}
+
{desc}
{killLink} {nameLink} {details}
} else {
{killLink} {nameLink} {details}
}
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 3448aaaf5724c..bb6079154aafe 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -257,7 +257,8 @@ private[spark] object JsonProtocol {
val reason = Utils.getFormattedClassName(taskEndReason)
val json = taskEndReason match {
case fetchFailed: FetchFailed =>
- val blockManagerAddress = blockManagerIdToJson(fetchFailed.bmAddress)
+ val blockManagerAddress = Option(fetchFailed.bmAddress).
+ map(blockManagerIdToJson).getOrElse(JNothing)
("Block Manager Address" -> blockManagerAddress) ~
("Shuffle ID" -> fetchFailed.shuffleId) ~
("Map ID" -> fetchFailed.mapId) ~
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 71ab2a3e3bef4..be8f6529f7a1c 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -106,6 +106,7 @@ class ExternalAppendOnlyMap[K, V, C](
private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()
+ private val threadId = Thread.currentThread().getId
/**
* Insert the given key and value into the map.
@@ -128,7 +129,6 @@ class ExternalAppendOnlyMap[K, V, C](
// Atomically check whether there is sufficient memory in the global pool for
// this map to grow and, if possible, allocate the required amount
shuffleMemoryMap.synchronized {
- val threadId = Thread.currentThread().getId
val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
val availableMemory = maxMemoryThreshold -
(shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
@@ -153,8 +153,8 @@ class ExternalAppendOnlyMap[K, V, C](
*/
private def spill(mapSize: Long) {
spillCount += 1
- logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)"
- .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+ logWarning("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
+ .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
val (blockId, file) = diskBlockManager.createTempBlock()
var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
var objectsWritten = 0
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 21c0f3c596a11..a301cbd48a0c3 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -120,6 +120,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
"--archives", "archive1.txt,archive2.txt",
"--num-executors", "6",
"--name", "beauty",
+ "--conf", "spark.shuffle.spill=false",
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
@@ -139,6 +140,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
mainClass should be ("org.apache.spark.deploy.yarn.Client")
classpath should have length (0)
sysProps("spark.app.name") should be ("beauty")
+ sysProps("spark.shuffle.spill") should be ("false")
sysProps("SPARK_SUBMIT") should be ("true")
}
@@ -156,6 +158,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
"--archives", "archive1.txt,archive2.txt",
"--num-executors", "6",
"--name", "trill",
+ "--conf", "spark.shuffle.spill=false",
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
@@ -176,6 +179,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt")
sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar")
sysProps("SPARK_SUBMIT") should be ("true")
+ sysProps("spark.shuffle.spill") should be ("false")
}
test("handles standalone cluster mode") {
@@ -186,6 +190,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
"--supervise",
"--driver-memory", "4g",
"--driver-cores", "5",
+ "--conf", "spark.shuffle.spill=false",
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
@@ -195,11 +200,13 @@ class SparkSubmitSuite extends FunSuite with Matchers {
childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2")
mainClass should be ("org.apache.spark.deploy.Client")
classpath should have size (0)
- sysProps should have size (4)
+ sysProps should have size (5)
+ sysProps.keys should contain ("SPARK_SUBMIT")
sysProps.keys should contain ("spark.master")
sysProps.keys should contain ("spark.app.name")
sysProps.keys should contain ("spark.jars")
- sysProps.keys should contain ("SPARK_SUBMIT")
+ sysProps.keys should contain ("spark.shuffle.spill")
+ sysProps("spark.shuffle.spill") should be ("false")
}
test("handles standalone client mode") {
@@ -210,6 +217,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
"--total-executor-cores", "5",
"--class", "org.SomeClass",
"--driver-memory", "4g",
+ "--conf", "spark.shuffle.spill=false",
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
@@ -220,6 +228,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
classpath(0) should endWith ("thejar.jar")
sysProps("spark.executor.memory") should be ("5g")
sysProps("spark.cores.max") should be ("5")
+ sysProps("spark.shuffle.spill") should be ("false")
}
test("handles mesos client mode") {
@@ -230,6 +239,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
"--total-executor-cores", "5",
"--class", "org.SomeClass",
"--driver-memory", "4g",
+ "--conf", "spark.shuffle.spill=false",
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
@@ -240,6 +250,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
classpath(0) should endWith ("thejar.jar")
sysProps("spark.executor.memory") should be ("5g")
sysProps("spark.cores.max") should be ("5")
+ sysProps("spark.shuffle.spill") should be ("false")
}
test("launch simple application with spark-submit") {
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
index 5dd8de319a654..a0483886f8db3 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -43,7 +43,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
test("seed distribution") {
val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
val sampler = new MockSampler
- val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
+ val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, false, 0L)
assert(sample.distinct().count == 2, "Seeds must be different.")
}
@@ -52,7 +52,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
// We want to make sure there are no concurrency issues.
val rdd = sc.parallelize(0 until 111, 10)
for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) {
- val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler)
+ val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler, true)
sampled.zip(sampled).count()
}
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 2924de112934c..6654ec2d7c656 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -523,6 +523,15 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sortedTopK === nums.sorted(ord).take(5))
}
+ test("sample preserves partitioner") {
+ val partitioner = new HashPartitioner(2)
+ val rdd = sc.parallelize(Seq((0, 1), (2, 3))).partitionBy(partitioner)
+ for (withReplacement <- Seq(true, false)) {
+ val sampled = rdd.sample(withReplacement, 1.0)
+ assert(sampled.partitioner === rdd.partitioner)
+ }
+ }
+
test("takeSample") {
val n = 1000000
val data = sc.parallelize(1 to n, 2)
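
The new "sample preserves partitioner" test above pins down the behavior this patch adds: sampling no longer discards the parent's partitioner. A usage sketch, assuming an existing `SparkContext` named `sc`:

```scala
import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._

val pairs = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "c"))).partitionBy(new HashPartitioner(4))
val sampled = pairs.sample(withReplacement = false, fraction = 0.5)

// Because the partitioner is preserved, a later reduceByKey/join on `sampled`
// can reuse the existing partitioning instead of shuffling again.
assert(sampled.partitioner == pairs.partitioner)
```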
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 86b443b18f2a6..c52368b5514db 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -475,6 +475,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
// Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY
assert(manager.myLocalityLevels.sameElements(
Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY)))
+ FakeRackUtil.cleanUp()
}
test("test RACK_LOCAL tasks") {
@@ -505,6 +506,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
// Offer host2
// Task 1 can be scheduled with RACK_LOCAL
assert(manager.resourceOffer("execB", "host2", RACK_LOCAL).get.index === 1)
+ FakeRackUtil.cleanUp()
}
test("do not emit warning when serialized task is small") {
diff --git a/docs/configuration.md b/docs/configuration.md
index a70007c165442..cb0c65e2d2200 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -42,13 +42,15 @@ val sc = new SparkContext(new SparkConf())
Then, you can supply configuration values at runtime:
{% highlight bash %}
-./bin/spark-submit --name "My fancy app" --master local[4] myApp.jar
+./bin/spark-submit --name "My app" --master local[4] --conf spark.shuffle.spill=false
+ --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar
{% endhighlight %}
The Spark shell and [`spark-submit`](cluster-overview.html#launching-applications-with-spark-submit)
tool support two ways to load configurations dynamically. The first are command line options,
-such as `--master`, as shown above. Running `./bin/spark-submit --help` will show the entire list
-of options.
+such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf`
+flag, but uses special flags for properties that play a part in launching the Spark application.
+Running `./bin/spark-submit --help` will show the entire list of these options.
`bin/spark-submit` will also read configuration options from `conf/spark-defaults.conf`, in which
each line consists of a key and a value separated by whitespace. For example:
@@ -388,6 +390,17 @@ Apart from these, the following properties are also available, and may be useful
case.
+<tr>
+  <td><code>spark.kryo.registrationRequired</code></td>
+  <td>false</td>
+  <td>
+    Whether to require registration with Kryo. If set to 'true', Kryo will throw an exception
+    if an unregistered class is serialized. If set to false (the default), Kryo will write
+    unregistered class names along with each object. Writing class names can cause
+    significant performance overhead, so enabling this option can enforce strictly that a
+    user has not omitted classes from registration.
+  </td>
+</tr>
<tr>
  <td><code>spark.kryoserializer.buffer.mb</code></td>
  <td>2</td>
@@ -497,9 +510,9 @@ Apart from these, the following properties are also available, and may be useful
<td><code>spark.hadoop.validateOutputSpecs</code></td>
<td>true</td>
-  <td>
-    If set to true, validates the output specification (e.g. checking if the output directory already exists)
-    used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing
-    output directories. We recommend that users do not disable this except if trying to achieve compatibility with
+  <td>
+    If set to true, validates the output specification (e.g. checking if the output directory already exists)
+    used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing
+    output directories. We recommend that users do not disable this except if trying to achieve compatibility with
    previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand.
@@ -861,7 +874,7 @@ Apart from these, the following properties are also available, and may be useful
#### Cluster Managers
-Each cluster manager in Spark has additional configuration options. Configurations
+Each cluster manager in Spark has additional configuration options. Configurations
can be found on the pages for each mode:
* [YARN](running-on-yarn.html#configuration)
diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md
index e05883072bfa8..45b70b1a5457a 100644
--- a/docs/submitting-applications.md
+++ b/docs/submitting-applications.md
@@ -33,6 +33,7 @@ dependencies, and can support different cluster managers and deploy modes that S
--class
--master \
--deploy-mode \
+ --conf = \
... # other options
\
[application-arguments]
@@ -43,6 +44,7 @@ Some of the commonly used options are:
* `--class`: The entry point for your application (e.g. `org.apache.spark.examples.SparkPi`)
* `--master`: The [master URL](#master-urls) for the cluster (e.g. `spark://23.195.26.187:7077`)
* `--deploy-mode`: Whether to deploy your driver on the worker nodes (`cluster`) or locally as an external client (`client`) (default: `client`)*
+* `--conf`: Arbitrary Spark configuration property in key=value format. For values that contain spaces, wrap "key=value" in quotes (as shown).
* `application-jar`: Path to a bundled jar including your application and all dependencies. The URL must be globally visible inside of your cluster, for instance, an `hdfs://` path or a `file://` path that is present on all nodes.
* `application-arguments`: Arguments passed to the main method of your main class, if any
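
As a companion to the `--conf` documentation above, a short sketch of how such a property becomes visible to application code (assuming the usual Spark entry points; the property name mirrors the examples in this patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// spark-submit turns "--conf spark.shuffle.spill=false" into a system property that
// SparkConf picks up automatically.
val conf = new SparkConf()
val sc = new SparkContext(conf)
val spillEnabled = conf.getBoolean("spark.shuffle.spill", defaultValue = true)
```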
diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala
index 5ea2e5549d7df..4eacc47da5699 100644
--- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala
+++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala
@@ -63,7 +63,8 @@ class TwitterReceiver(
storageLevel: StorageLevel
) extends Receiver[Status](storageLevel) with Logging {
- private var twitterStream: TwitterStream = _
+ @volatile private var twitterStream: TwitterStream = _
+ @volatile private var stopped = false
def onStart() {
try {
@@ -78,7 +79,9 @@ class TwitterReceiver(
def onScrubGeo(l: Long, l1: Long) {}
def onStallWarning(stallWarning: StallWarning) {}
def onException(e: Exception) {
- restart("Error receiving tweets", e)
+ if (!stopped) {
+ restart("Error receiving tweets", e)
+ }
}
})
@@ -91,12 +94,14 @@ class TwitterReceiver(
}
setTwitterStream(newTwitterStream)
logInfo("Twitter receiver started")
+ stopped = false
} catch {
case e: Exception => restart("Error starting Twitter stream", e)
}
}
def onStop() {
+ stopped = true
setTwitterStream(null)
logInfo("Twitter receiver stopped")
}
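
The receiver change above is a small race fix: a `@volatile` stop flag keeps the error callback from restarting a receiver that has already been stopped. A generic sketch of the pattern (not Twitter4J-specific, names are illustrative):

```scala
class FlakyReceiver(restart: String => Unit) {
  @volatile private var stopped = false

  def onStart(): Unit = {
    // open the underlying stream, register callbacks, ...
    stopped = false
  }

  def onStop(): Unit = {
    stopped = true
    // close the underlying stream; any in-flight error callback now becomes a no-op
  }

  def onError(e: Exception): Unit = {
    if (!stopped) {
      restart(s"Error receiving data: ${e.getMessage}")
    }
  }
}
```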
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index 3507f358bfb40..fa4b891754c40 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -344,7 +344,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
*
* {{{
* val rawGraph: Graph[_, _] = Graph.textFile("webgraph")
- * val outDeg: RDD[(VertexId, Int)] = rawGraph.outDegrees()
+ * val outDeg: RDD[(VertexId, Int)] = rawGraph.outDegrees
* val graph = rawGraph.outerJoinVertices(outDeg) {
* (vid, data, optDeg) => optDeg.getOrElse(0)
* }
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
index f97f329c0e832..1948c978c30bf 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
@@ -35,9 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator {
def registerClasses(kryo: Kryo) {
kryo.register(classOf[Edge[Object]])
- kryo.register(classOf[MessageToPartition[Object]])
- kryo.register(classOf[VertexBroadcastMsg[Object]])
- kryo.register(classOf[RoutingTableMessage])
kryo.register(classOf[(VertexId, Object)])
kryo.register(classOf[EdgePartition[Object, Object]])
kryo.register(classOf[BitSet])
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index edd5b79da1522..02afaa987d40d 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -198,10 +198,10 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
*
* {{{
* val rawGraph: Graph[Int, Int] = GraphLoader.edgeListFile(sc, "webgraph")
- * .mapVertices(v => 0)
- * val outDeg: RDD[(Int, Int)] = rawGraph.outDegrees
- * val graph = rawGraph.leftJoinVertices[Int,Int](outDeg,
- * (v, deg) => deg )
+ * .mapVertices((_, _) => 0)
+ * val outDeg = rawGraph.outDegrees
+ * val graph = rawGraph.joinVertices[Int](outDeg)
+ * ((_, _, outDeg) => outDeg)
* }}}
*
*/
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index ccdaa82eb9162..33f35cfb69a26 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -26,7 +26,6 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl._
-import org.apache.spark.graphx.impl.MsgRDDFunctions._
import org.apache.spark.graphx.util.BytecodeUtils
@@ -83,15 +82,13 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
val vdTag = classTag[VD]
val newEdges = edges.withPartitionsRDD(edges.map { e =>
val part: PartitionID = partitionStrategy.getPartition(e.srcId, e.dstId, numPartitions)
-
- // Should we be using 3-tuple or an optimized class
- new MessageToPartition(part, (e.srcId, e.dstId, e.attr))
+ (part, (e.srcId, e.dstId, e.attr))
}
.partitionBy(new HashPartitioner(numPartitions))
.mapPartitionsWithIndex( { (pid, iter) =>
val builder = new EdgePartitionBuilder[ED, VD]()(edTag, vdTag)
iter.foreach { message =>
- val data = message.data
+ val data = message._2
builder.add(data._1, data._2, data._3)
}
val edgePartition = builder.toEdgePartition
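
The GraphImpl hunk above drops the `MessageToPartition` wrapper in favor of a plain `(PartitionID, payload)` tuple, which `partitionBy` can shuffle directly. A simplified sketch of the same idea, assuming a `SparkContext` `sc`; the partition choice here is a toy, not GraphX's partition strategy:

```scala
import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

def routeEdges(edges: RDD[(Long, Long, Int)], numPartitions: Int): RDD[(Int, (Long, Long, Int))] =
  edges
    .map { case (src, dst, attr) =>
      val part = (src % numPartitions).toInt      // toy partition function
      (part, (src, dst, attr))                    // key on the target partition id
    }
    .partitionBy(new HashPartitioner(numPartitions))
```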
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
index d85afa45b1264..5318b8da6412a 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
@@ -25,82 +25,6 @@ import org.apache.spark.graphx.{PartitionID, VertexId}
import org.apache.spark.rdd.{ShuffledRDD, RDD}
-private[graphx]
-class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T](
- @transient var partition: PartitionID,
- var vid: VertexId,
- var data: T)
- extends Product2[PartitionID, (VertexId, T)] with Serializable {
-
- override def _1 = partition
-
- override def _2 = (vid, data)
-
- override def canEqual(that: Any): Boolean = that.isInstanceOf[VertexBroadcastMsg[_]]
-}
-
-
-/**
- * A message used to send a specific value to a partition.
- * @param partition index of the target partition.
- * @param data value to send
- */
-private[graphx]
-class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T](
- @transient var partition: PartitionID,
- var data: T)
- extends Product2[PartitionID, T] with Serializable {
-
- override def _1 = partition
-
- override def _2 = data
-
- override def canEqual(that: Any): Boolean = that.isInstanceOf[MessageToPartition[_]]
-}
-
-
-private[graphx]
-class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T]]) {
- def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = {
- val rdd = new ShuffledRDD[PartitionID, (VertexId, T), (VertexId, T), VertexBroadcastMsg[T]](
- self, partitioner)
-
- // Set a custom serializer if the data is of int or double type.
- if (classTag[T] == ClassTag.Int) {
- rdd.setSerializer(new IntVertexBroadcastMsgSerializer)
- } else if (classTag[T] == ClassTag.Long) {
- rdd.setSerializer(new LongVertexBroadcastMsgSerializer)
- } else if (classTag[T] == ClassTag.Double) {
- rdd.setSerializer(new DoubleVertexBroadcastMsgSerializer)
- }
- rdd
- }
-}
-
-
-private[graphx]
-class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) {
-
- /**
- * Return a copy of the RDD partitioned using the specified partitioner.
- */
- def partitionBy(partitioner: Partitioner): RDD[MessageToPartition[T]] = {
- new ShuffledRDD[PartitionID, T, T, MessageToPartition[T]](self, partitioner)
- }
-
-}
-
-private[graphx]
-object MsgRDDFunctions {
- implicit def rdd2PartitionRDDFunctions[T: ClassTag](rdd: RDD[MessageToPartition[T]]) = {
- new MsgRDDFunctions(rdd)
- }
-
- implicit def rdd2vertexMessageRDDFunctions[T: ClassTag](rdd: RDD[VertexBroadcastMsg[T]]) = {
- new VertexBroadcastMsgRDDFunctions(rdd)
- }
-}
-
private[graphx]
class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) {
def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index 502b112d31c2e..a565d3b28bf52 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -27,26 +27,13 @@ import org.apache.spark.util.collection.{BitSet, PrimitiveVector}
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
-/**
- * A message from the edge partition `pid` to the vertex partition containing `vid` specifying that
- * the edge partition references `vid` in the specified `position` (src, dst, or both).
-*/
-private[graphx]
-class RoutingTableMessage(
- var vid: VertexId,
- var pid: PartitionID,
- var position: Byte)
- extends Product2[VertexId, (PartitionID, Byte)] with Serializable {
- override def _1 = vid
- override def _2 = (pid, position)
- override def canEqual(that: Any): Boolean = that.isInstanceOf[RoutingTableMessage]
-}
+import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
private[graphx]
class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
/** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
- new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage](
+ new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage](
self, partitioner).setSerializer(new RoutingTableMessageSerializer)
}
}
@@ -62,6 +49,23 @@ object RoutingTableMessageRDDFunctions {
private[graphx]
object RoutingTablePartition {
+ /**
+ * A message from an edge partition to a vertex specifying the position in which the edge
+ * partition references the vertex (src, dst, or both). The edge partition is encoded in the lower
+ * 30 bits of the Int, and the position is encoded in the upper 2 bits of the Int.
+ */
+ type RoutingTableMessage = (VertexId, Int)
+
+ private def toMessage(vid: VertexId, pid: PartitionID, position: Byte): RoutingTableMessage = {
+ val positionUpper2 = position << 30
+ val pidLower30 = pid & 0x3FFFFFFF
+ (vid, positionUpper2 | pidLower30)
+ }
+
+ private def vidFromMessage(msg: RoutingTableMessage): VertexId = msg._1
+ private def pidFromMessage(msg: RoutingTableMessage): PartitionID = msg._2 & 0x3FFFFFFF
+ private def positionFromMessage(msg: RoutingTableMessage): Byte = (msg._2 >> 30).toByte
+
val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty)
/** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */
@@ -77,7 +81,9 @@ object RoutingTablePartition {
map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
}
map.iterator.map { vidAndPosition =>
- new RoutingTableMessage(vidAndPosition._1, pid, vidAndPosition._2)
+ val vid = vidAndPosition._1
+ val position = vidAndPosition._2
+ toMessage(vid, pid, position)
}
}
@@ -88,9 +94,12 @@ object RoutingTablePartition {
val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
for (msg <- iter) {
- pid2vid(msg.pid) += msg.vid
- srcFlags(msg.pid) += (msg.position & 0x1) != 0
- dstFlags(msg.pid) += (msg.position & 0x2) != 0
+ val vid = vidFromMessage(msg)
+ val pid = pidFromMessage(msg)
+ val position = positionFromMessage(msg)
+ pid2vid(pid) += vid
+ srcFlags(pid) += (position & 0x1) != 0
+ dstFlags(pid) += (position & 0x2) != 0
}
new RoutingTablePartition(pid2vid.zipWithIndex.map {
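
The routing-table change above packs a partition id and a position flag into one Int to avoid allocating a per-message object. A standalone sketch of the encoding, mirroring the helpers added above:

```scala
// Position goes in the top 2 bits, the partition id in the low 30 bits,
// which is why a PartitionID must stay below 2^30.
def pack(pid: Int, position: Byte): Int = (position << 30) | (pid & 0x3FFFFFFF)
def pidOf(packed: Int): Int = packed & 0x3FFFFFFF
// The arithmetic shift may sign-extend, but callers only ever test the low two bits.
def positionOf(packed: Int): Byte = (packed >> 30).toByte

val msg = pack(12345, 0x2.toByte)
assert(pidOf(msg) == 12345)
assert((positionOf(msg) & 0x2) != 0)
```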
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
index 033237f597216..3909efcdfc993 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
@@ -24,9 +24,11 @@ import java.nio.ByteBuffer
import scala.reflect.ClassTag
-import org.apache.spark.graphx._
import org.apache.spark.serializer._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
+
private[graphx]
class RoutingTableMessageSerializer extends Serializer with Serializable {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
@@ -35,10 +37,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
new ShuffleSerializationStream(s) {
def writeObject[T: ClassTag](t: T): SerializationStream = {
val msg = t.asInstanceOf[RoutingTableMessage]
- writeVarLong(msg.vid, optimizePositive = false)
- writeUnsignedVarInt(msg.pid)
- // TODO: Write only the bottom two bits of msg.position
- s.write(msg.position)
+ writeVarLong(msg._1, optimizePositive = false)
+ writeInt(msg._2)
this
}
}
@@ -47,10 +47,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
new ShuffleDeserializationStream(s) {
override def readObject[T: ClassTag](): T = {
val a = readVarLong(optimizePositive = false)
- val b = readUnsignedVarInt()
- val c = s.read()
- if (c == -1) throw new EOFException
- new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T]
+ val b = readInt()
+ (a, b).asInstanceOf[T]
}
}
}
@@ -76,78 +74,6 @@ class VertexIdMsgSerializer extends Serializer with Serializable {
}
}
-/** A special shuffle serializer for VertexBroadcastMessage[Int]. */
-private[graphx]
-class IntVertexBroadcastMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[VertexBroadcastMsg[Int]]
- writeVarLong(msg.vid, optimizePositive = false)
- writeInt(msg.data)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readInt()
- new VertexBroadcastMsg[Int](0, a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-/** A special shuffle serializer for VertexBroadcastMessage[Long]. */
-private[graphx]
-class LongVertexBroadcastMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[VertexBroadcastMsg[Long]]
- writeVarLong(msg.vid, optimizePositive = false)
- writeLong(msg.data)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readLong()
- new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-/** A special shuffle serializer for VertexBroadcastMessage[Double]. */
-private[graphx]
-class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[VertexBroadcastMsg[Double]]
- writeVarLong(msg.vid, optimizePositive = false)
- writeDouble(msg.data)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readDouble()
- new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T]
- }
- }
- }
-}
-
/** A special shuffle serializer for AggregationMessage[Int]. */
private[graphx]
class IntAggMsgSerializer extends Serializer with Serializable {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/package.scala
index ff17edeaf8f16..6aab28ff05355 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/package.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/package.scala
@@ -30,7 +30,7 @@ package object graphx {
*/
type VertexId = Long
- /** Integer identifer of a graph partition. */
+ /** Integer identifier of a graph partition. Must be less than 2^30. */
// TODO: Consider using Char.
type PartitionID = Int
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
index 91caa6b605a1e..864cb1fdf0022 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
@@ -26,75 +26,11 @@ import org.scalatest.FunSuite
import org.apache.spark._
import org.apache.spark.graphx.impl._
-import org.apache.spark.graphx.impl.MsgRDDFunctions._
import org.apache.spark.serializer.SerializationStream
class SerializerSuite extends FunSuite with LocalSparkContext {
- test("IntVertexBroadcastMsgSerializer") {
- val outMsg = new VertexBroadcastMsg[Int](3, 4, 5)
- val bout = new ByteArrayOutputStream
- val outStrm = new IntVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new IntVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject()
- val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject()
- assert(outMsg.vid === inMsg1.vid)
- assert(outMsg.vid === inMsg2.vid)
- assert(outMsg.data === inMsg1.data)
- assert(outMsg.data === inMsg2.data)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
- test("LongVertexBroadcastMsgSerializer") {
- val outMsg = new VertexBroadcastMsg[Long](3, 4, 5)
- val bout = new ByteArrayOutputStream
- val outStrm = new LongVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new LongVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject()
- val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject()
- assert(outMsg.vid === inMsg1.vid)
- assert(outMsg.vid === inMsg2.vid)
- assert(outMsg.data === inMsg1.data)
- assert(outMsg.data === inMsg2.data)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
- test("DoubleVertexBroadcastMsgSerializer") {
- val outMsg = new VertexBroadcastMsg[Double](3, 4, 5.0)
- val bout = new ByteArrayOutputStream
- val outStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject()
- val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject()
- assert(outMsg.vid === inMsg1.vid)
- assert(outMsg.vid === inMsg2.vid)
- assert(outMsg.data === inMsg1.data)
- assert(outMsg.data === inMsg2.data)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
test("IntAggMsgSerializer") {
val outMsg = (4: VertexId, 5)
val bout = new ByteArrayOutputStream
@@ -152,15 +88,6 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
}
}
- test("TestShuffleVertexBroadcastMsg") {
- withSpark { sc =>
- val bmsgs = sc.parallelize(0 until 100, 10).map { pid =>
- new VertexBroadcastMsg[Int](pid, pid, pid)
- }
- bmsgs.partitionBy(new HashPartitioner(3)).collect()
- }
- }
-
test("variable long encoding") {
def testVarLongEncoding(v: Long, optimizePositive: Boolean) {
val bout = new ByteArrayOutputStream
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index 079743742d86d..1af40de2c7fcf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -103,11 +103,11 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) extends
mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
).sortByKey(ascending = false)
- val agg = counts.values.mapPartitions({ iter =>
+ val agg = counts.values.mapPartitions { iter =>
val agg = new BinaryLabelCounter()
iter.foreach(agg += _)
Iterator(agg)
- }, preservesPartitioning = true).collect()
+ }.collect()
val partitionwiseCumulativeCounts =
agg.scanLeft(new BinaryLabelCounter())(
(agg: BinaryLabelCounter, c: BinaryLabelCounter) => agg.clone() += c)
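
The change above drops `preservesPartitioning = true` from a `mapPartitions` over a values-only RDD, where there is no key partitioning to preserve. A small sketch of when the flag is actually appropriate (assuming a `SparkContext` `sc`):

```scala
import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._

val byKey = sc.parallelize(Seq((1, 2.0), (2, 3.0))).partitionBy(new HashPartitioner(2))

// Keys are untouched, so telling Spark the partitioning survives is safe and
// lets downstream key-based operations skip a shuffle.
val shifted = byKey.mapPartitions(
  iter => iter.map { case (k, v) => (k, v + 1.0) },
  preservesPartitioning = true)

assert(shifted.partitioner == byKey.partitioner)
```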
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index f4c403bc7861c..8c2b044ea73f2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -377,9 +377,9 @@ class RowMatrix(
s"Only support dense matrix at this time but found ${B.getClass.getName}.")
val Bb = rows.context.broadcast(B.toBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray)
- val AB = rows.mapPartitions({ iter =>
+ val AB = rows.mapPartitions { iter =>
val Bi = Bb.value
- iter.map(row => {
+ iter.map { row =>
val v = BDV.zeros[Double](k)
var i = 0
while (i < k) {
@@ -387,8 +387,8 @@ class RowMatrix(
i += 1
}
Vectors.fromBreeze(v)
- })
- }, preservesPartitioning = true)
+ }
+ }
new RowMatrix(AB, nRows, B.numCols)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 15e8855db6ca7..5356790cb5339 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -430,7 +430,7 @@ class ALS private (
val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner)
val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner)
Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
- }, true)
+ }, preservesPartitioning = true)
val inLinks = links.mapValues(_._1)
val outLinks = links.mapValues(_._2)
inLinks.persist(StorageLevel.MEMORY_AND_DISK)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index aaf92a1a8869a..30de24ad89f98 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -264,8 +264,8 @@ object MLUtils {
(1 to numFolds).map { fold =>
val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
complement = false)
- val validation = new PartitionwiseSampledRDD(rdd, sampler, seed)
- val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed)
+ val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed)
+ val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed)
(training, validation)
}.toArray
}
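
For context on the `PartitionwiseSampledRDD` call sites updated above, a brief usage sketch of `MLUtils.kFold` (assuming a `SparkContext` `sc`): each fold pairs a training RDD with the complementary validation RDD.

```scala
import org.apache.spark.mllib.util.MLUtils

val data = sc.parallelize(1 to 100)
val folds = MLUtils.kFold(data, numFolds = 3, seed = 42)

folds.foreach { case (training, validation) =>
  // Every element lands in exactly one of the two complementary samples.
  assert(training.count() + validation.count() == data.count())
}
```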
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index 9d16182f9d8c4..94db1dc183230 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -20,8 +20,26 @@ package org.apache.spark.mllib.evaluation
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals
class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
+
+ // TODO: move utility functions to TestingUtils.
+
+ def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = {
+ actual.zip(expected).forall { case (x1, x2) =>
+ x1.almostEquals(x2)
+ }
+ }
+
+ def elementsAlmostEqual(
+ actual: Seq[(Double, Double)],
+ expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = {
+ actual.zip(expected).forall { case ((x1, y1), (x2, y2)) =>
+ x1.almostEquals(x2) && y1.almostEquals(y2)
+ }
+ }
+
test("binary evaluation metrics") {
val scoreAndLabels = sc.parallelize(
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)
@@ -41,14 +59,14 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
- assert(metrics.thresholds().collect().toSeq === threshold)
- assert(metrics.roc().collect().toSeq === rocCurve)
- assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve))
- assert(metrics.pr().collect().toSeq === prCurve)
- assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve))
- assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1))
- assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2))
- assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision))
- assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall))
+ assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold))
+ assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve))
+ assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve)))
+ assert(elementsAlmostEqual(metrics.pr().collect(), prCurve))
+ assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve)))
+ assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1)))
+ assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2)))
+ assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision)))
+ assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall)))
}
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 5e5ddd227aab6..e9220db6b1f9a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -32,108 +32,83 @@ import com.typesafe.tools.mima.core._
*/
object MimaExcludes {
- def excludes(version: String) = version match {
- case v if v.startsWith("1.1") =>
- Seq(
- MimaBuild.excludeSparkPackage("deploy"),
- MimaBuild.excludeSparkPackage("graphx")
- ) ++
- closures.map(method => ProblemFilters.exclude[MissingMethodProblem](method)) ++
- Seq(
- // Adding new method to JavaRDLike trait - we should probably mark this as a developer API.
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"),
- // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values
- // for countApproxDistinct* functions, which does not work in Java. We later removed
- // them, and use the following to tell Mima to not care about them.
- ProblemFilters.exclude[IncompatibleResultTypeProblem](
- "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem](
- "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.storage.MemoryStore.Entry"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$debugChildren$1"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$firstDebugString$1"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$shuffleDebugString$1"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$debugString$1"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$"
- + "createZero$1")
- ) ++
- Seq(
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this")
- ) ++
- Seq( // Ignore some private methods in ALS.
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"),
- ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments.
- "org.apache.spark.mllib.recommendation.ALS.this"),
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem](
- "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures")
- ) ++
- MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++
- MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
- MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
- MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++
- MimaBuild.excludeSparkClass("storage.Values") ++
- MimaBuild.excludeSparkClass("storage.Entry") ++
- MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++
- Seq(
- ProblemFilters.exclude[IncompatibleMethTypeProblem](
- "org.apache.spark.mllib.tree.impurity.Gini.calculate"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem](
- "org.apache.spark.mllib.tree.impurity.Entropy.calculate"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem](
- "org.apache.spark.mllib.tree.impurity.Variance.calculate")
- )
- case v if v.startsWith("1.0") =>
- Seq(
- MimaBuild.excludeSparkPackage("api.java"),
- MimaBuild.excludeSparkPackage("mllib"),
- MimaBuild.excludeSparkPackage("streaming")
- ) ++
- MimaBuild.excludeSparkClass("rdd.ClassTags") ++
- MimaBuild.excludeSparkClass("util.XORShiftRandom") ++
- MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++
- MimaBuild.excludeSparkClass("graphx.VertexRDD") ++
- MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++
- MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++
- MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++
- MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++
- MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++
- MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++
- MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++
- MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++
- MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD")
- case _ => Seq()
- }
-
- private val closures = Seq(
- "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$mergeMaps$1",
- "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$countPartition$1",
- "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$distributePartition$1",
- "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeValue$1",
- "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$writeToFile$1",
- "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$reducePartition$1",
- "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$writeShard$1",
- "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeCombiners$1",
- "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$process$1",
- "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$createCombiner$1",
- "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeMaps$1"
- )
+ def excludes(version: String) =
+ version match {
+ case v if v.startsWith("1.1") =>
+ Seq(
+ MimaBuild.excludeSparkPackage("deploy"),
+ MimaBuild.excludeSparkPackage("graphx")
+ ) ++
+ Seq(
+ // Adding new method to JavaRDDLike trait - we should probably mark this as a developer API.
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"),
+ // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values
+ // for countApproxDistinct* functions, which does not work in Java. We later removed
+ // them, and use the following to tell Mima to not care about them.
+ ProblemFilters.exclude[IncompatibleResultTypeProblem](
+ "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem](
+ "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.storage.MemoryStore.Entry")
+ ) ++
+ Seq(
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this")
+ ) ++
+ Seq( // Ignore some private methods in ALS.
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"),
+ ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments.
+ "org.apache.spark.mllib.recommendation.ALS.this"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures")
+ ) ++
+ MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++
+ MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
+ MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
+ MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++
+ MimaBuild.excludeSparkClass("storage.Values") ++
+ MimaBuild.excludeSparkClass("storage.Entry") ++
+ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++
+ Seq(
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.mllib.tree.impurity.Gini.calculate"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.mllib.tree.impurity.Entropy.calculate"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.mllib.tree.impurity.Variance.calculate")
+ )
+ case v if v.startsWith("1.0") =>
+ Seq(
+ MimaBuild.excludeSparkPackage("api.java"),
+ MimaBuild.excludeSparkPackage("mllib"),
+ MimaBuild.excludeSparkPackage("streaming")
+ ) ++
+ MimaBuild.excludeSparkClass("rdd.ClassTags") ++
+ MimaBuild.excludeSparkClass("util.XORShiftRandom") ++
+ MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++
+ MimaBuild.excludeSparkClass("graphx.VertexRDD") ++
+ MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++
+ MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++
+ MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++
+ MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++
+ MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++
+ MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++
+ MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++
+ MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++
+ MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD")
+ case _ => Seq()
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c7188469bfb86..02bdb64f308a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-
/**
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
* when all relations are already filled in and the analyser needs only to resolve attribute
@@ -54,6 +53,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
StarExpansion ::
ResolveFunctions ::
GlobalAggregates ::
+ UnresolvedHavingClauseAttributes ::
typeCoercionRules :_*),
Batch("Check Analysis", Once,
CheckResolution),
@@ -151,6 +151,31 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
}
}
+ /**
+ * This rule finds expressions in HAVING clause filters that depend on
+ * unresolved attributes. It pushes these expressions down to the underlying
+ * aggregates and then projects them away above the filter.
+ */
+ object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
+ if !filter.resolved && aggregate.resolved && containsAggregate(havingCondition) => {
+ val evaluatedCondition = Alias(havingCondition, "havingCondition")()
+ val aggExprsWithHaving = evaluatedCondition +: originalAggExprs
+
+ Project(aggregate.output,
+ Filter(evaluatedCondition.toAttribute,
+ aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
+ }
+
+ }
+
+ protected def containsAggregate(condition: Expression): Boolean =
+ condition
+ .collect { case ae: AggregateExpression => ae }
+ .nonEmpty
+ }
+
/**
* When a SELECT clause has only a single expression and that expression is a
* [[catalyst.expressions.Generator Generator]] we convert the
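
To make the new `UnresolvedHavingClauseAttributes` rule concrete, here is a rough illustration (hypothetical query, not taken from the patch) of the rewrite it performs:

```scala
// An aggregate that appears only in HAVING is pushed into the Aggregate as a named
// expression, filtered on, and then projected away again.
val original  = "SELECT key FROM src GROUP BY key HAVING max(value) > 0"
val rewritten =
  """SELECT key FROM (
    |  SELECT key, max(value) > 0 AS havingCondition FROM src GROUP BY key
    |) tmp WHERE havingCondition""".stripMargin
```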
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 76ddeba9cb312..9887856b9c1c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -231,10 +231,20 @@ trait HiveTypeCoercion {
* Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
*/
object BooleanComparisons extends Rule[LogicalPlan] {
+ val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, BigDecimal(1)).map(Literal(_))
+ val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, BigDecimal(0)).map(Literal(_))
+
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- // No need to change EqualTo operators as that actually makes sense for boolean types.
+
+ // Hive treats (true = 1) as true and (false = 0) as true.
+ case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l
+ case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r
+ case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l)
+ case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r)
+
+ // No need to change other EqualTo operators as that actually makes sense for boolean types.
case e: EqualTo => e
// Otherwise turn them to Byte types so that there exists and ordering.
case p: BinaryComparison
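
A quick illustration of the Hive-compatible semantics the new `BooleanComparisons` cases implement (hypothetical queries, not from the patch):

```scala
// Comparing a boolean column to a numeric literal is rewritten instead of
// coercing both sides to bytes:
//   boolCol = 1  ==>  boolCol
//   boolCol = 0  ==>  NOT boolCol
val examples = Seq(
  "SELECT * FROM t WHERE flag = 1",   // behaves like: WHERE flag
  "SELECT * FROM t WHERE flag = 0"    // behaves like: WHERE NOT flag
)
```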
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 1b503b957d146..15c98efbcabcf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -79,8 +79,24 @@ package object dsl {
def === (other: Expression) = EqualTo(expr, other)
def !== (other: Expression) = Not(EqualTo(expr, other))
+ def in(list: Expression*) = In(expr, list)
+
def like(other: Expression) = Like(expr, other)
def rlike(other: Expression) = RLike(expr, other)
+ def contains(other: Expression) = Contains(expr, other)
+ def startsWith(other: Expression) = StartsWith(expr, other)
+ def endsWith(other: Expression) = EndsWith(expr, other)
+ def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
+ Substring(expr, pos, len)
+ def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) =
+ Substring(expr, pos, len)
+
+ def isNull = IsNull(expr)
+ def isNotNull = IsNotNull(expr)
+
+ def getItem(ordinal: Expression) = GetItem(expr, ordinal)
+ def getField(fieldName: String) = GetField(expr, fieldName)
+
def cast(to: DataType) = Cast(expr, to)
def asc = SortOrder(expr, Ascending)
@@ -112,6 +128,7 @@ package object dsl {
def sumDistinct(e: Expression) = SumDistinct(e)
def count(e: Expression) = Count(e)
def countDistinct(e: Expression*) = CountDistinct(e)
+ def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd)
def avg(e: Expression) = Average(e)
def first(e: Expression) = First(e)
def min(e: Expression) = Min(e)
@@ -163,6 +180,18 @@ package object dsl {
/** Creates a new AttributeReference of type binary */
def binary = AttributeReference(s, BinaryType, nullable = true)()
+
+ /** Creates a new AttributeReference of type array */
+ def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)()
+
+ /** Creates a new AttributeReference of type map */
+ def map(keyType: DataType, valueType: DataType): AttributeReference =
+ map(MapType(keyType, valueType))
+ def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)()
+
+ /** Creates a new AttributeReference of type struct */
+ def struct(fields: StructField*): AttributeReference = struct(StructType(fields))
+ def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)()
}
implicit class DslAttribute(a: AttributeReference) {
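
A short usage sketch of the new DSL helpers (the attribute names here are made up; it assumes the catalyst dsl implicits and catalyst types are in scope):

    val name = 'name.string.at(0)
    val tags = 'tags.array(StringType).at(1)

    val pred     = name contains "spark"   // Contains(name, Literal("spark"))
    val nonNull  = name.isNotNull          // IsNotNull(name)
    val firstTag = tags.getItem(0)         // GetItem(tags, Literal(0))
    val prefix   = name.substr(0, 3)       // Substring(name, Literal(0), Literal(3))
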
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index db1ae29d400c6..c3f5c26fdbe59 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -301,17 +301,17 @@ class ExpressionEvaluationSuite extends FunSuite {
val c3 = 'a.boolean.at(2)
val c4 = 'a.boolean.at(3)
- checkEvaluation(IsNull(c1), false, row)
- checkEvaluation(IsNotNull(c1), true, row)
+ checkEvaluation(c1.isNull, false, row)
+ checkEvaluation(c1.isNotNull, true, row)
- checkEvaluation(IsNull(c2), true, row)
- checkEvaluation(IsNotNull(c2), false, row)
+ checkEvaluation(c2.isNull, true, row)
+ checkEvaluation(c2.isNotNull, false, row)
- checkEvaluation(IsNull(Literal(1, ShortType)), false)
- checkEvaluation(IsNotNull(Literal(1, ShortType)), true)
+ checkEvaluation(Literal(1, ShortType).isNull, false)
+ checkEvaluation(Literal(1, ShortType).isNotNull, true)
- checkEvaluation(IsNull(Literal(null, ShortType)), true)
- checkEvaluation(IsNotNull(Literal(null, ShortType)), false)
+ checkEvaluation(Literal(null, ShortType).isNull, true)
+ checkEvaluation(Literal(null, ShortType).isNotNull, false)
checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
@@ -326,11 +326,11 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(If(Literal(false, BooleanType),
Literal("a", StringType), Literal("b", StringType)), "b", row)
- checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
- checkEvaluation(In(Literal("^Ba*n", StringType),
- Literal("^Ba*n", StringType) :: Nil), true, row)
- checkEvaluation(In(Literal("^Ba*n", StringType),
- Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
+ checkEvaluation(c1 in (c1, c2), true, row)
+ checkEvaluation(
+ Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType)), true, row)
+ checkEvaluation(
+ Literal("^Ba*n", StringType) in (Literal("^Ba*n", StringType), c2), true, row)
}
test("case when") {
@@ -420,6 +420,10 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(GetField(Literal(null, typeS), "a").nullable === true)
assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true)
+
+ checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
+ checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
+ checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row)
}
test("arithmetic") {
@@ -472,20 +476,20 @@ class ExpressionEvaluationSuite extends FunSuite {
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
- checkEvaluation(Contains(c1, "b"), true, row)
- checkEvaluation(Contains(c1, "x"), false, row)
- checkEvaluation(Contains(c2, "b"), null, row)
- checkEvaluation(Contains(c1, Literal(null, StringType)), null, row)
+ checkEvaluation(c1 contains "b", true, row)
+ checkEvaluation(c1 contains "x", false, row)
+ checkEvaluation(c2 contains "b", null, row)
+ checkEvaluation(c1 contains Literal(null, StringType), null, row)
- checkEvaluation(StartsWith(c1, "a"), true, row)
- checkEvaluation(StartsWith(c1, "b"), false, row)
- checkEvaluation(StartsWith(c2, "a"), null, row)
- checkEvaluation(StartsWith(c1, Literal(null, StringType)), null, row)
+ checkEvaluation(c1 startsWith "a", true, row)
+ checkEvaluation(c1 startsWith "b", false, row)
+ checkEvaluation(c2 startsWith "a", null, row)
+ checkEvaluation(c1 startsWith Literal(null, StringType), null, row)
- checkEvaluation(EndsWith(c1, "c"), true, row)
- checkEvaluation(EndsWith(c1, "b"), false, row)
- checkEvaluation(EndsWith(c2, "b"), null, row)
- checkEvaluation(EndsWith(c1, Literal(null, StringType)), null, row)
+ checkEvaluation(c1 endsWith "c", true, row)
+ checkEvaluation(c1 endsWith "b", false, row)
+ checkEvaluation(c2 endsWith "b", null, row)
+ checkEvaluation(c1 endsWith Literal(null, StringType), null, row)
}
test("Substring") {
@@ -542,5 +546,10 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false)
assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true)
assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true)
+
+ checkEvaluation(s.substr(0, 2), "ex", row)
+ checkEvaluation(s.substr(0), "example", row)
+ checkEvaluation(s.substring(0, 2), "ex", row)
+ checkEvaluation(s.substring(0), "example", row)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 34b355e906695..34654447a5f4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -24,10 +24,10 @@ import scala.reflect.ClassTag
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Serializer, Kryo}
-import com.twitter.chill.AllScalaRegistrar
+import com.twitter.chill.{AllScalaRegistrar, ResourcePool}
import org.apache.spark.{SparkEnv, SparkConf}
-import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.serializer.{SerializerInstance, KryoSerializer}
import org.apache.spark.util.MutablePair
import org.apache.spark.util.Utils
@@ -48,22 +48,41 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
}
}
-private[sql] object SparkSqlSerializer {
- // TODO (lian) Using KryoSerializer here is workaround, needs further investigation
- // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization
- // related error.
- @transient lazy val ser: KryoSerializer = {
+private[execution] class KryoResourcePool(size: Int)
+ extends ResourcePool[SerializerInstance](size) {
+
+ val ser: KryoSerializer = {
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+ // TODO (lian) Using KryoSerializer here is a workaround, needs further investigation.
+ // Using SparkSqlSerializer here makes BasicQuerySuite fail because of a Kryo
+ // serialization related error.
new KryoSerializer(sparkConf)
}
- def serialize[T: ClassTag](o: T): Array[Byte] = {
- ser.newInstance().serialize(o).array()
- }
+ def newInstance() = ser.newInstance()
+}
- def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
- ser.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
+private[sql] object SparkSqlSerializer {
+ @transient lazy val resourcePool = new KryoResourcePool(30)
+
+ private[this] def acquireRelease[O](fn: SerializerInstance => O): O = {
+ val kryo = resourcePool.borrow
+ try {
+ fn(kryo)
+ } finally {
+ resourcePool.release(kryo)
+ }
}
+
+ def serialize[T: ClassTag](o: T): Array[Byte] =
+ acquireRelease { k =>
+ k.serialize(o).array()
+ }
+
+ def deserialize[T: ClassTag](bytes: Array[Byte]): T =
+ acquireRelease { k =>
+ k.deserialize[T](ByteBuffer.wrap(bytes))
+ }
}
private[sql] class BigDecimalSerializer extends Serializer[BigDecimal] {
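
A minimal usage sketch of the pooled serializer above (illustrative only; the caller would have to live in the org.apache.spark.sql package since the object is private[sql]):

    // Each call borrows a SerializerInstance from the KryoResourcePool and releases
    // it in a finally block, instead of building a new Kryo instance per call.
    val bytes: Array[Byte] = SparkSqlSerializer.serialize(Seq(1, 2, 3))
    val restored = SparkSqlSerializer.deserialize[Seq[Int]](bytes)
    assert(restored == Seq(1, 2, 3))
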
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index df80dfb98b93c..b48c70ee73a27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.json
-import scala.collection.JavaConversions._
+import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
import scala.math.BigDecimal
import com.fasterxml.jackson.databind.ObjectMapper
@@ -210,12 +210,12 @@ private[sql] object JsonRDD extends Logging {
case (k, dataType) => (s"$key.$k", dataType)
} ++ Set((key, StructType(Nil)))
}
- case (key: String, array: List[_]) => {
+ case (key: String, array: Seq[_]) => {
// The value associated with the key is an array.
typeOfArray(array) match {
case ArrayType(StructType(Nil)) => {
// The elements of this arrays are structs.
- array.asInstanceOf[List[Map[String, Any]]].flatMap {
+ array.asInstanceOf[Seq[Map[String, Any]]].flatMap {
element => allKeysWithValueTypes(element)
}.map {
case (k, dataType) => (s"$key.$k", dataType)
@@ -229,7 +229,7 @@ private[sql] object JsonRDD extends Logging {
}
/**
- * Converts a Java Map/List to a Scala Map/List.
+ * Converts a Java Map/List to a Scala Map/Seq.
* We do not use Jackson's scala module here because
* DefaultScalaModule in jackson-module-scala will make
* the parsing very slow.
@@ -239,9 +239,9 @@ private[sql] object JsonRDD extends Logging {
// .map(identity) is used as a workaround of non-serializable Map
// generated by .mapValues.
// This issue is documented at https://issues.scala-lang.org/browse/SI-7005
- map.toMap.mapValues(scalafy).map(identity)
+ JMapWrapper(map).mapValues(scalafy).map(identity)
case list: java.util.List[_] =>
- list.toList.map(scalafy)
+ JListWrapper(list).map(scalafy)
case atom => atom
}
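
An illustrative sketch, in plain Scala, of the SI-7005 workaround the comment above refers to:

    // mapValues returns a lazy, non-serializable view (MapLike#MappedValues);
    // forcing it with map(identity) materializes an ordinary immutable Map.
    val m            = Map("a" -> 1, "b" -> 2)
    val lazyView     = m.mapValues(_ + 1)               // not serializable
    val materialized = m.mapValues(_ + 1).map(identity) // plain, serializable Map
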
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index c8ea01c4e1b6a..1a6a6c17473a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test._
/* Implicits */
@@ -41,15 +40,15 @@ class DslQuerySuite extends QueryTest {
test("agg") {
checkAnswer(
- testData2.groupBy('a)('a, Sum('b)),
+ testData2.groupBy('a)('a, sum('b)),
Seq((1,3),(2,3),(3,3))
)
checkAnswer(
- testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)),
+ testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)),
9
)
checkAnswer(
- testData2.aggregate(Sum('b)),
+ testData2.aggregate(sum('b)),
9
)
}
@@ -104,19 +103,19 @@ class DslQuerySuite extends QueryTest {
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
checkAnswer(
- arrayData.orderBy(GetItem('data, 0).asc),
+ arrayData.orderBy('data.getItem(0).asc),
arrayData.collect().sortBy(_.data(0)).toSeq)
checkAnswer(
- arrayData.orderBy(GetItem('data, 0).desc),
+ arrayData.orderBy('data.getItem(0).desc),
arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
checkAnswer(
- mapData.orderBy(GetItem('data, 1).asc),
+ mapData.orderBy('data.getItem(1).asc),
mapData.collect().sortBy(_.data(1)).toSeq)
checkAnswer(
- mapData.orderBy(GetItem('data, 1).desc),
+ mapData.orderBy('data.getItem(1).desc),
mapData.collect().sortBy(_.data(1)).reverse.toSeq)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
similarity index 99%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
rename to sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index fd44325925cdd..8b451973a47a1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -291,6 +291,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"correlationoptimizer1",
"correlationoptimizer10",
"correlationoptimizer11",
+ "correlationoptimizer13",
"correlationoptimizer14",
"correlationoptimizer15",
"correlationoptimizer2",
@@ -299,6 +300,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"correlationoptimizer6",
"correlationoptimizer7",
"correlationoptimizer8",
+ "correlationoptimizer9",
"count",
"cp_mj_rc",
"create_insert_outputformat",
@@ -389,6 +391,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"groupby_sort_8",
"groupby_sort_9",
"groupby_sort_test_1",
+ "having",
+ "having1",
"implicit_cast1",
"innerjoin",
"inoutdriver",
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index f30ae28b81e06..1699ffe06ce15 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -102,6 +102,36 @@
      <scope>test</scope>
    </dependency>
  </dependencies>
+
+  <profiles>
+    <profile>
+      <id>hive</id>
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>org.codehaus.mojo</groupId>
+            <artifactId>build-helper-maven-plugin</artifactId>
+            <executions>
+              <execution>
+                <id>add-scala-test-sources</id>
+                <phase>generate-test-sources</phase>
+                <goals>
+                  <goal>add-test-source</goal>
+                </goals>
+                <configuration>
+                  <sources>
+                    <source>src/test/scala</source>
+                    <source>compatibility/src/test/scala</source>
+                  </sources>
+                </configuration>
+              </execution>
+            </executions>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
+  </profiles>
+
  <build>
    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
new file mode 100644
index 0000000000000..ad7dc0ecdb1bf
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.hive.serde2.objectinspector.primitive._
+import org.apache.hadoop.hive.serde2.{io => hiveIo}
+import org.apache.hadoop.{io => hadoopIo}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types
+import org.apache.spark.sql.catalyst.types._
+
+/* Implicit conversions */
+import scala.collection.JavaConversions._
+
+private[hive] trait HiveInspectors {
+
+ def javaClassToDataType(clz: Class[_]): DataType = clz match {
+ // writable
+ case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
+ case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
+ case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
+ case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
+ case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
+ case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType
+ case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
+ case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
+ case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
+ case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
+ case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
+ case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType
+
+ // java class
+ case c: Class[_] if c == classOf[java.lang.String] => StringType
+ case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
+ case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
+ case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
+ case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
+ case c: Class[_] if c == classOf[java.lang.Short] => ShortType
+ case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+ case c: Class[_] if c == classOf[java.lang.Long] => LongType
+ case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+ case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
+ case c: Class[_] if c == classOf[java.lang.Float] => FloatType
+ case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
+
+ // primitive type
+ case c: Class[_] if c == java.lang.Short.TYPE => ShortType
+ case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
+ case c: Class[_] if c == java.lang.Long.TYPE => LongType
+ case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
+ case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
+ case c: Class[_] if c == java.lang.Float.TYPE => FloatType
+ case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+
+ case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
+ }
+
+ /** Converts hive types to native catalyst types. */
+ def unwrap(a: Any): Any = a match {
+ case null => null
+ case i: hadoopIo.IntWritable => i.get
+ case t: hadoopIo.Text => t.toString
+ case l: hadoopIo.LongWritable => l.get
+ case d: hadoopIo.DoubleWritable => d.get
+ case d: hiveIo.DoubleWritable => d.get
+ case s: hiveIo.ShortWritable => s.get
+ case b: hadoopIo.BooleanWritable => b.get
+ case b: hiveIo.ByteWritable => b.get
+ case b: hadoopIo.FloatWritable => b.get
+ case b: hadoopIo.BytesWritable => {
+ val bytes = new Array[Byte](b.getLength)
+ System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
+ bytes
+ }
+ case t: hiveIo.TimestampWritable => t.getTimestamp
+ case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
+ case list: java.util.List[_] => list.map(unwrap)
+ case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
+ case array: Array[_] => array.map(unwrap).toSeq
+ case p: java.lang.Short => p
+ case p: java.lang.Long => p
+ case p: java.lang.Float => p
+ case p: java.lang.Integer => p
+ case p: java.lang.Double => p
+ case p: java.lang.Byte => p
+ case p: java.lang.Boolean => p
+ case str: String => str
+ case p: java.math.BigDecimal => p
+ case p: Array[Byte] => p
+ case p: java.sql.Timestamp => p
+ }
+
+ def unwrapData(data: Any, oi: ObjectInspector): Any = oi match {
+ case hvoi: HiveVarcharObjectInspector =>
+ if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
+ case hdoi: HiveDecimalObjectInspector =>
+ if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+ case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
+ case li: ListObjectInspector =>
+ Option(li.getList(data))
+ .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq)
+ .orNull
+ case mi: MapObjectInspector =>
+ Option(mi.getMap(data)).map(
+ _.map {
+ case (k,v) =>
+ (unwrapData(k, mi.getMapKeyObjectInspector),
+ unwrapData(v, mi.getMapValueObjectInspector))
+ }.toMap).orNull
+ case si: StructObjectInspector =>
+ val allRefs = si.getAllStructFieldRefs
+ new GenericRow(
+ allRefs.map(r =>
+ unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
+ }
+
+ /** Converts native catalyst types to the types expected by Hive */
+ def wrap(a: Any): AnyRef = a match {
+ case s: String => new hadoopIo.Text(s) // TODO why should this be Text?
+ case i: Int => i: java.lang.Integer
+ case b: Boolean => b: java.lang.Boolean
+ case f: Float => f: java.lang.Float
+ case d: Double => d: java.lang.Double
+ case l: Long => l: java.lang.Long
+ case l: Short => l: java.lang.Short
+ case l: Byte => l: java.lang.Byte
+ case b: BigDecimal => b.bigDecimal
+ case b: Array[Byte] => b
+ case t: java.sql.Timestamp => t
+ case s: Seq[_] => seqAsJavaList(s.map(wrap))
+ case m: Map[_,_] =>
+ mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
+ case null => null
+ }
+
+ def toInspector(dataType: DataType): ObjectInspector = dataType match {
+ case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
+ case MapType(keyType, valueType) =>
+ ObjectInspectorFactory.getStandardMapObjectInspector(
+ toInspector(keyType), toInspector(valueType))
+ case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
+ case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector
+ case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
+ case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
+ case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector
+ case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector
+ case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector
+ case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
+ case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
+ case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
+ case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
+ case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
+ case StructType(fields) =>
+ ObjectInspectorFactory.getStandardStructObjectInspector(
+ fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
+ }
+
+ def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
+ case s: StructObjectInspector =>
+ StructType(s.getAllStructFieldRefs.map(f => {
+ types.StructField(
+ f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true)
+ }))
+ case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector))
+ case m: MapObjectInspector =>
+ MapType(
+ inspectorToDataType(m.getMapKeyObjectInspector),
+ inspectorToDataType(m.getMapValueObjectInspector))
+ case _: WritableStringObjectInspector => StringType
+ case _: JavaStringObjectInspector => StringType
+ case _: WritableIntObjectInspector => IntegerType
+ case _: JavaIntObjectInspector => IntegerType
+ case _: WritableDoubleObjectInspector => DoubleType
+ case _: JavaDoubleObjectInspector => DoubleType
+ case _: WritableBooleanObjectInspector => BooleanType
+ case _: JavaBooleanObjectInspector => BooleanType
+ case _: WritableLongObjectInspector => LongType
+ case _: JavaLongObjectInspector => LongType
+ case _: WritableShortObjectInspector => ShortType
+ case _: JavaShortObjectInspector => ShortType
+ case _: WritableByteObjectInspector => ByteType
+ case _: JavaByteObjectInspector => ByteType
+ case _: WritableFloatObjectInspector => FloatType
+ case _: JavaFloatObjectInspector => FloatType
+ case _: WritableBinaryObjectInspector => BinaryType
+ case _: JavaBinaryObjectInspector => BinaryType
+ case _: WritableHiveDecimalObjectInspector => DecimalType
+ case _: JavaHiveDecimalObjectInspector => DecimalType
+ case _: WritableTimestampObjectInspector => TimestampType
+ case _: JavaTimestampObjectInspector => TimestampType
+ }
+
+ implicit class typeInfoConversions(dt: DataType) {
+ import org.apache.hadoop.hive.serde2.typeinfo._
+ import TypeInfoFactory._
+
+ def toTypeInfo: TypeInfo = dt match {
+ case BinaryType => binaryTypeInfo
+ case BooleanType => booleanTypeInfo
+ case ByteType => byteTypeInfo
+ case DoubleType => doubleTypeInfo
+ case FloatType => floatTypeInfo
+ case IntegerType => intTypeInfo
+ case LongType => longTypeInfo
+ case ShortType => shortTypeInfo
+ case StringType => stringTypeInfo
+ case DecimalType => decimalTypeInfo
+ case TimestampType => timestampTypeInfo
+ case NullType => voidTypeInfo
+ }
+ }
+}
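
A hedged sketch of how this trait is used (the object name is hypothetical, and it must sit in the org.apache.spark.sql.hive package because the trait is private[hive]):

    import org.apache.spark.sql.catalyst.types._

    object InspectorExample extends HiveInspectors {
      // Catalyst value -> Hive writable -> Catalyst value.
      val hiveText: AnyRef = wrap("spark")      // hadoopIo.Text("spark")
      val back: Any        = unwrap(hiveText)   // "spark"
      // ObjectInspector for a Catalyst array-of-strings type.
      val oi = toInspector(ArrayType(StringType))
    }
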
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 300e249f5b2e1..c4ca9f362a04d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -42,8 +42,6 @@ private[hive] case class ShellCommand(cmd: String) extends Command
private[hive] case class SourceCommand(filePath: String) extends Command
-private[hive] case class AddJar(jarPath: String) extends Command
-
private[hive] case class AddFile(filePath: String) extends Command
/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
@@ -229,7 +227,7 @@ private[hive] object HiveQl {
} else if (sql.trim.toLowerCase.startsWith("uncache table")) {
CacheCommand(sql.trim.drop(14).trim, false)
} else if (sql.trim.toLowerCase.startsWith("add jar")) {
- AddJar(sql.trim.drop(8))
+ NativeCommand(sql)
} else if (sql.trim.toLowerCase.startsWith("add file")) {
AddFile(sql.trim.drop(9))
} else if (sql.trim.toLowerCase.startsWith("dfs")) {
@@ -932,6 +930,7 @@ private[hive] object HiveQl {
/* Comparisons */
case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
+ case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index fc33c5b460d70..057eb60a02612 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -24,22 +24,19 @@ import org.apache.hadoop.hive.ql.exec.UDF
import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
-import org.apache.hadoop.hive.serde2.objectinspector._
-import org.apache.hadoop.hive.serde2.objectinspector.primitive._
-import org.apache.hadoop.hive.serde2.{io => hiveIo}
-import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.sql.Logging
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.util.Utils.getContextOrSparkClassLoader
/* Implicit conversions */
import scala.collection.JavaConversions._
-private[hive] object HiveFunctionRegistry
- extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors {
+private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors {
+
+ def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
// We only look it up to see if it exists, but do not include it in the HiveUDF since it is
@@ -47,111 +44,37 @@ private[hive] object HiveFunctionRegistry
val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse(
sys.error(s"Couldn't find function $name"))
+ val functionClassName = functionInfo.getFunctionClass.getName()
+
if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- val function = createFunction[UDF](name)
+ val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF]
val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
HiveSimpleUdf(
- name,
+ functionClassName,
children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) }
)
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdf(name, children)
+ HiveGenericUdf(functionClassName, children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdaf(name, children)
+ HiveGenericUdaf(functionClassName, children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdtf(name, Nil, children)
+ HiveGenericUdtf(functionClassName, Nil, children)
} else {
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
}
}
-
- def javaClassToDataType(clz: Class[_]): DataType = clz match {
- // writable
- case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
- case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
- case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
- case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
- case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
- case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType
- case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
- case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
- case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
- case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
- case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
- case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType
-
- // java class
- case c: Class[_] if c == classOf[java.lang.String] => StringType
- case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
- case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
- case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
- case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
- case c: Class[_] if c == classOf[java.lang.Short] => ShortType
- case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
- case c: Class[_] if c == classOf[java.lang.Long] => LongType
- case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
- case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
- case c: Class[_] if c == classOf[java.lang.Float] => FloatType
- case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
-
- // primitive type
- case c: Class[_] if c == java.lang.Short.TYPE => ShortType
- case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
- case c: Class[_] if c == java.lang.Long.TYPE => LongType
- case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
- case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
- case c: Class[_] if c == java.lang.Float.TYPE => FloatType
- case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
-
- case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
- }
}
private[hive] trait HiveFunctionFactory {
- def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
- def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass
- def createFunction[UDFType](name: String) =
- getFunctionClass(name).newInstance.asInstanceOf[UDFType]
-
- /** Converts hive types to native catalyst types. */
- def unwrap(a: Any): Any = a match {
- case null => null
- case i: hadoopIo.IntWritable => i.get
- case t: hadoopIo.Text => t.toString
- case l: hadoopIo.LongWritable => l.get
- case d: hadoopIo.DoubleWritable => d.get
- case d: hiveIo.DoubleWritable => d.get
- case s: hiveIo.ShortWritable => s.get
- case b: hadoopIo.BooleanWritable => b.get
- case b: hiveIo.ByteWritable => b.get
- case b: hadoopIo.FloatWritable => b.get
- case b: hadoopIo.BytesWritable => {
- val bytes = new Array[Byte](b.getLength)
- System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
- bytes
- }
- case t: hiveIo.TimestampWritable => t.getTimestamp
- case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
- case list: java.util.List[_] => list.map(unwrap)
- case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
- case array: Array[_] => array.map(unwrap).toSeq
- case p: java.lang.Short => p
- case p: java.lang.Long => p
- case p: java.lang.Float => p
- case p: java.lang.Integer => p
- case p: java.lang.Double => p
- case p: java.lang.Byte => p
- case p: java.lang.Boolean => p
- case str: String => str
- case p: java.math.BigDecimal => p
- case p: Array[Byte] => p
- case p: java.sql.Timestamp => p
- }
+ val functionClassName: String
+
+ def createFunction[UDFType]() =
+ getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
}
private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory {
@@ -160,19 +83,17 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu
type UDFType
type EvaluatedType = Any
- val name: String
-
def nullable = true
def references = children.flatMap(_.references).toSet
- // FunctionInfo is not serializable so we must look it up here again.
- lazy val functionInfo = getFunctionInfo(name)
- lazy val function = createFunction[UDFType](name)
+ lazy val function = createFunction[UDFType]()
- override def toString = s"$nodeName#${functionInfo.getDisplayName}(${children.mkString(",")})"
+ override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
}
-private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf {
+private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression])
+ extends HiveUdf {
+
import org.apache.spark.sql.hive.HiveFunctionRegistry._
type UDFType = UDF
@@ -226,7 +147,7 @@ private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression])
}
}
-private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])
+private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression])
extends HiveUdf with HiveInspectors {
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
@@ -277,131 +198,8 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])
}
}
-private[hive] trait HiveInspectors {
-
- def unwrapData(data: Any, oi: ObjectInspector): Any = oi match {
- case hvoi: HiveVarcharObjectInspector =>
- if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
- case hdoi: HiveDecimalObjectInspector =>
- if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
- case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
- case li: ListObjectInspector =>
- Option(li.getList(data))
- .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq)
- .orNull
- case mi: MapObjectInspector =>
- Option(mi.getMap(data)).map(
- _.map {
- case (k,v) =>
- (unwrapData(k, mi.getMapKeyObjectInspector),
- unwrapData(v, mi.getMapValueObjectInspector))
- }.toMap).orNull
- case si: StructObjectInspector =>
- val allRefs = si.getAllStructFieldRefs
- new GenericRow(
- allRefs.map(r =>
- unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
- }
-
- /** Converts native catalyst types to the types expected by Hive */
- def wrap(a: Any): AnyRef = a match {
- case s: String => new hadoopIo.Text(s) // TODO why should be Text?
- case i: Int => i: java.lang.Integer
- case b: Boolean => b: java.lang.Boolean
- case f: Float => f: java.lang.Float
- case d: Double => d: java.lang.Double
- case l: Long => l: java.lang.Long
- case l: Short => l: java.lang.Short
- case l: Byte => l: java.lang.Byte
- case b: BigDecimal => b.bigDecimal
- case b: Array[Byte] => b
- case t: java.sql.Timestamp => t
- case s: Seq[_] => seqAsJavaList(s.map(wrap))
- case m: Map[_,_] =>
- mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
- case null => null
- }
-
- def toInspector(dataType: DataType): ObjectInspector = dataType match {
- case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
- case MapType(keyType, valueType) =>
- ObjectInspectorFactory.getStandardMapObjectInspector(
- toInspector(keyType), toInspector(valueType))
- case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
- case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector
- case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
- case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
- case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector
- case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector
- case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector
- case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
- case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
- case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
- case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
- case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
- case StructType(fields) =>
- ObjectInspectorFactory.getStandardStructObjectInspector(
- fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
- }
-
- def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
- case s: StructObjectInspector =>
- StructType(s.getAllStructFieldRefs.map(f => {
- types.StructField(
- f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true)
- }))
- case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector))
- case m: MapObjectInspector =>
- MapType(
- inspectorToDataType(m.getMapKeyObjectInspector),
- inspectorToDataType(m.getMapValueObjectInspector))
- case _: WritableStringObjectInspector => StringType
- case _: JavaStringObjectInspector => StringType
- case _: WritableIntObjectInspector => IntegerType
- case _: JavaIntObjectInspector => IntegerType
- case _: WritableDoubleObjectInspector => DoubleType
- case _: JavaDoubleObjectInspector => DoubleType
- case _: WritableBooleanObjectInspector => BooleanType
- case _: JavaBooleanObjectInspector => BooleanType
- case _: WritableLongObjectInspector => LongType
- case _: JavaLongObjectInspector => LongType
- case _: WritableShortObjectInspector => ShortType
- case _: JavaShortObjectInspector => ShortType
- case _: WritableByteObjectInspector => ByteType
- case _: JavaByteObjectInspector => ByteType
- case _: WritableFloatObjectInspector => FloatType
- case _: JavaFloatObjectInspector => FloatType
- case _: WritableBinaryObjectInspector => BinaryType
- case _: JavaBinaryObjectInspector => BinaryType
- case _: WritableHiveDecimalObjectInspector => DecimalType
- case _: JavaHiveDecimalObjectInspector => DecimalType
- case _: WritableTimestampObjectInspector => TimestampType
- case _: JavaTimestampObjectInspector => TimestampType
- }
-
- implicit class typeInfoConversions(dt: DataType) {
- import org.apache.hadoop.hive.serde2.typeinfo._
- import TypeInfoFactory._
-
- def toTypeInfo: TypeInfo = dt match {
- case BinaryType => binaryTypeInfo
- case BooleanType => booleanTypeInfo
- case ByteType => byteTypeInfo
- case DoubleType => doubleTypeInfo
- case FloatType => floatTypeInfo
- case IntegerType => intTypeInfo
- case LongType => longTypeInfo
- case ShortType => shortTypeInfo
- case StringType => stringTypeInfo
- case DecimalType => decimalTypeInfo
- case TimestampType => timestampTypeInfo
- case NullType => voidTypeInfo
- }
- }
-}
-
private[hive] case class HiveGenericUdaf(
- name: String,
+ functionClassName: String,
children: Seq[Expression]) extends AggregateExpression
with HiveInspectors
with HiveFunctionFactory {
@@ -409,7 +207,7 @@ private[hive] case class HiveGenericUdaf(
type UDFType = AbstractGenericUDAFResolver
@transient
- protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name)
+ protected lazy val resolver: AbstractGenericUDAFResolver = createFunction()
@transient
protected lazy val objectInspector = {
@@ -426,9 +224,9 @@ private[hive] case class HiveGenericUdaf(
def references: Set[Attribute] = children.map(_.references).flatten.toSet
- override def toString = s"$nodeName#$name(${children.mkString(",")})"
+ override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
- def newInstance() = new HiveUdafFunction(name, children, this)
+ def newInstance() = new HiveUdafFunction(functionClassName, children, this)
}
/**
@@ -443,7 +241,7 @@ private[hive] case class HiveGenericUdaf(
* user defined aggregations, which have clean semantics even in a partitioned execution.
*/
private[hive] case class HiveGenericUdtf(
- name: String,
+ functionClassName: String,
aliasNames: Seq[String],
children: Seq[Expression])
extends Generator with HiveInspectors with HiveFunctionFactory {
@@ -451,7 +249,7 @@ private[hive] case class HiveGenericUdtf(
override def references = children.flatMap(_.references).toSet
@transient
- protected lazy val function: GenericUDTF = createFunction(name)
+ protected lazy val function: GenericUDTF = createFunction()
protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
@@ -506,11 +304,11 @@ private[hive] case class HiveGenericUdtf(
}
}
- override def toString = s"$nodeName#$name(${children.mkString(",")})"
+ override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
}
private[hive] case class HiveUdafFunction(
- functionName: String,
+ functionClassName: String,
exprs: Seq[Expression],
base: AggregateExpression)
extends AggregateFunction
@@ -519,7 +317,7 @@ private[hive] case class HiveUdafFunction(
def this() = this(null, null, null)
- private val resolver = createFunction[AbstractGenericUDAFResolver](functionName)
+ private val resolver = createFunction[AbstractGenericUDAFResolver]()
private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
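
A minimal sketch (not part of the patch) of the pattern this refactor relies on: UDFs are now carried around by class name and re-created on executors through the context class loader, so nothing non-serializable (the FunctionInfo or the UDF instance itself) has to ship with the expression:

    // Illustrative helper; mirrors HiveFunctionFactory.createFunction after this change.
    // Utils is private[spark], so this would have to live inside a Spark package.
    import org.apache.spark.util.Utils.getContextOrSparkClassLoader

    def instantiateUdf[T](functionClassName: String): T =
      getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[T]
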
diff --git a/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0 b/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0
new file mode 100644
index 0000000000000..4d1ebdcde2c71
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/boolean = number-0-6b6975fa1892cc48edd87dc0df48a7c0
@@ -0,0 +1 @@
+true true true true true true false false false false false false false false false false false false true true true true true true false false false false false false false false false false false false
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-0-efd135a811fa94760736a761d220b82 b/sql/hive/src/test/resources/golden/correlationoptimizer13-0-efd135a811fa94760736a761d220b82
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-1-32a82500cc28465fac6f64dde0c431c6 b/sql/hive/src/test/resources/golden/correlationoptimizer13-1-32a82500cc28465fac6f64dde0c431c6
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-2-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer13-2-777edd9d575f3480ca6cebe4be57b1f6
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/correlationoptimizer13-2-777edd9d575f3480ca6cebe4be57b1f6
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer13-3-bb61d9292434f37bd386e5bff683764d b/sql/hive/src/test/resources/golden/correlationoptimizer13-3-bb61d9292434f37bd386e5bff683764d
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-0-efd135a811fa94760736a761d220b82 b/sql/hive/src/test/resources/golden/correlationoptimizer9-0-efd135a811fa94760736a761d220b82
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-1-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer9-1-b1e2ade89ae898650f0be4f796d8947b
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-1-b1e2ade89ae898650f0be4f796d8947b
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-10-1190d82f88f7fb1f91968f6e2e03772a b/sql/hive/src/test/resources/golden/correlationoptimizer9-10-1190d82f88f7fb1f91968f6e2e03772a
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-11-bc2ae88b17ac2bdbd288e07194a40168 b/sql/hive/src/test/resources/golden/correlationoptimizer9-11-bc2ae88b17ac2bdbd288e07194a40168
new file mode 100644
index 0000000000000..17c838bb62b3b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-11-bc2ae88b17ac2bdbd288e07194a40168
@@ -0,0 +1,9 @@
+103 val_103 103 val_103 4 4
+104 val_104 104 val_104 4 4
+105 val_105 105 val_105 1 1
+111 val_111 111 val_111 1 1
+113 val_113 113 val_113 4 4
+114 val_114 114 val_114 1 1
+116 val_116 116 val_116 1 1
+118 val_118 118 val_118 4 4
+119 val_119 119 val_119 9 9
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-12-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer9-12-777edd9d575f3480ca6cebe4be57b1f6
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-12-777edd9d575f3480ca6cebe4be57b1f6
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-13-1190d82f88f7fb1f91968f6e2e03772a b/sql/hive/src/test/resources/golden/correlationoptimizer9-13-1190d82f88f7fb1f91968f6e2e03772a
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-14-bc2ae88b17ac2bdbd288e07194a40168 b/sql/hive/src/test/resources/golden/correlationoptimizer9-14-bc2ae88b17ac2bdbd288e07194a40168
new file mode 100644
index 0000000000000..17c838bb62b3b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-14-bc2ae88b17ac2bdbd288e07194a40168
@@ -0,0 +1,9 @@
+103 val_103 103 val_103 4 4
+104 val_104 104 val_104 4 4
+105 val_105 105 val_105 1 1
+111 val_111 111 val_111 1 1
+113 val_113 113 val_113 4 4
+114 val_114 114 val_114 1 1
+116 val_116 116 val_116 1 1
+118 val_118 118 val_118 4 4
+119 val_119 119 val_119 9 9
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-2-32a82500cc28465fac6f64dde0c431c6 b/sql/hive/src/test/resources/golden/correlationoptimizer9-2-32a82500cc28465fac6f64dde0c431c6
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-3-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer9-3-b9d963d24994c47c3776dda6f7d3881f
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-3-b9d963d24994c47c3776dda6f7d3881f
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-4-ec131bcf578dba99f20b16a7dc6b9b b/sql/hive/src/test/resources/golden/correlationoptimizer9-4-ec131bcf578dba99f20b16a7dc6b9b
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-5-b4e378104bb5ab8d8ba5f905aa1ff450 b/sql/hive/src/test/resources/golden/correlationoptimizer9-5-b4e378104bb5ab8d8ba5f905aa1ff450
new file mode 100644
index 0000000000000..248a14f1f4a9f
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-5-b4e378104bb5ab8d8ba5f905aa1ff450
@@ -0,0 +1,9 @@
+103 103 4 4
+104 104 4 4
+105 105 1 1
+111 111 1 1
+113 113 4 4
+114 114 1 1
+116 116 1 1
+118 118 4 4
+119 119 9 9
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-6-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer9-6-777edd9d575f3480ca6cebe4be57b1f6
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-6-777edd9d575f3480ca6cebe4be57b1f6
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-7-f952899d70bd718cbdbc44a5290938c9 b/sql/hive/src/test/resources/golden/correlationoptimizer9-7-f952899d70bd718cbdbc44a5290938c9
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-8-b4e378104bb5ab8d8ba5f905aa1ff450 b/sql/hive/src/test/resources/golden/correlationoptimizer9-8-b4e378104bb5ab8d8ba5f905aa1ff450
new file mode 100644
index 0000000000000..248a14f1f4a9f
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-8-b4e378104bb5ab8d8ba5f905aa1ff450
@@ -0,0 +1,9 @@
+103 103 4 4
+104 104 4 4
+105 105 1 1
+111 111 1 1
+113 113 4 4
+114 114 1 1
+116 116 1 1
+118 118 4 4
+119 119 9 9
diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer9-9-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer9-9-b9d963d24994c47c3776dda6f7d3881f
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/correlationoptimizer9-9-b9d963d24994c47c3776dda6f7d3881f
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/having-0-57f3f26c0203c29c2a91a7cca557ce55 b/sql/hive/src/test/resources/golden/having-0-57f3f26c0203c29c2a91a7cca557ce55
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873 b/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873
new file mode 100644
index 0000000000000..704f1e62f14c5
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/having-1-ef81808faeab6d212c3cf32abfc0d873
@@ -0,0 +1,10 @@
+4
+4
+5
+4
+5
+5
+4
+4
+5
+4
diff --git a/sql/hive/src/test/resources/golden/having-2-a2b4f52cb92f730ddb912b063636d6c1 b/sql/hive/src/test/resources/golden/having-2-a2b4f52cb92f730ddb912b063636d6c1
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e b/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e
new file mode 100644
index 0000000000000..b56757a60f780
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/having-3-3fa6387b6a4ece110ac340c7b893964e
@@ -0,0 +1,308 @@
+0 val_0
+2 val_2
+4 val_4
+5 val_5
+8 val_8
+9 val_9
+10 val_10
+11 val_11
+12 val_12
+15 val_15
+17 val_17
+18 val_18
+19 val_19
+20 val_20
+24 val_24
+26 val_26
+27 val_27
+28 val_28
+30 val_30
+33 val_33
+34 val_34
+35 val_35
+37 val_37
+41 val_41
+42 val_42
+43 val_43
+44 val_44
+47 val_47
+51 val_51
+53 val_53
+54 val_54
+57 val_57
+58 val_58
+64 val_64
+65 val_65
+66 val_66
+67 val_67
+69 val_69
+70 val_70
+72 val_72
+74 val_74
+76 val_76
+77 val_77
+78 val_78
+80 val_80
+82 val_82
+83 val_83
+84 val_84
+85 val_85
+86 val_86
+87 val_87
+90 val_90
+92 val_92
+95 val_95
+96 val_96
+97 val_97
+98 val_98
+100 val_100
+103 val_103
+104 val_104
+105 val_105
+111 val_111
+113 val_113
+114 val_114
+116 val_116
+118 val_118
+119 val_119
+120 val_120
+125 val_125
+126 val_126
+128 val_128
+129 val_129
+131 val_131
+133 val_133
+134 val_134
+136 val_136
+137 val_137
+138 val_138
+143 val_143
+145 val_145
+146 val_146
+149 val_149
+150 val_150
+152 val_152
+153 val_153
+155 val_155
+156 val_156
+157 val_157
+158 val_158
+160 val_160
+162 val_162
+163 val_163
+164 val_164
+165 val_165
+166 val_166
+167 val_167
+168 val_168
+169 val_169
+170 val_170
+172 val_172
+174 val_174
+175 val_175
+176 val_176
+177 val_177
+178 val_178
+179 val_179
+180 val_180
+181 val_181
+183 val_183
+186 val_186
+187 val_187
+189 val_189
+190 val_190
+191 val_191
+192 val_192
+193 val_193
+194 val_194
+195 val_195
+196 val_196
+197 val_197
+199 val_199
+200 val_200
+201 val_201
+202 val_202
+203 val_203
+205 val_205
+207 val_207
+208 val_208
+209 val_209
+213 val_213
+214 val_214
+216 val_216
+217 val_217
+218 val_218
+219 val_219
+221 val_221
+222 val_222
+223 val_223
+224 val_224
+226 val_226
+228 val_228
+229 val_229
+230 val_230
+233 val_233
+235 val_235
+237 val_237
+238 val_238
+239 val_239
+241 val_241
+242 val_242
+244 val_244
+247 val_247
+248 val_248
+249 val_249
+252 val_252
+255 val_255
+256 val_256
+257 val_257
+258 val_258
+260 val_260
+262 val_262
+263 val_263
+265 val_265
+266 val_266
+272 val_272
+273 val_273
+274 val_274
+275 val_275
+277 val_277
+278 val_278
+280 val_280
+281 val_281
+282 val_282
+283 val_283
+284 val_284
+285 val_285
+286 val_286
+287 val_287
+288 val_288
+289 val_289
+291 val_291
+292 val_292
+296 val_296
+298 val_298
+305 val_305
+306 val_306
+307 val_307
+308 val_308
+309 val_309
+310 val_310
+311 val_311
+315 val_315
+316 val_316
+317 val_317
+318 val_318
+321 val_321
+322 val_322
+323 val_323
+325 val_325
+327 val_327
+331 val_331
+332 val_332
+333 val_333
+335 val_335
+336 val_336
+338 val_338
+339 val_339
+341 val_341
+342 val_342
+344 val_344
+345 val_345
+348 val_348
+351 val_351
+353 val_353
+356 val_356
+360 val_360
+362 val_362
+364 val_364
+365 val_365
+366 val_366
+367 val_367
+368 val_368
+369 val_369
+373 val_373
+374 val_374
+375 val_375
+377 val_377
+378 val_378
+379 val_379
+382 val_382
+384 val_384
+386 val_386
+389 val_389
+392 val_392
+393 val_393
+394 val_394
+395 val_395
+396 val_396
+397 val_397
+399 val_399
+400 val_400
+401 val_401
+402 val_402
+403 val_403
+404 val_404
+406 val_406
+407 val_407
+409 val_409
+411 val_411
+413 val_413
+414 val_414
+417 val_417
+418 val_418
+419 val_419
+421 val_421
+424 val_424
+427 val_427
+429 val_429
+430 val_430
+431 val_431
+432 val_432
+435 val_435
+436 val_436
+437 val_437
+438 val_438
+439 val_439
+443 val_443
+444 val_444
+446 val_446
+448 val_448
+449 val_449
+452 val_452
+453 val_453
+454 val_454
+455 val_455
+457 val_457
+458 val_458
+459 val_459
+460 val_460
+462 val_462
+463 val_463
+466 val_466
+467 val_467
+468 val_468
+469 val_469
+470 val_470
+472 val_472
+475 val_475
+477 val_477
+478 val_478
+479 val_479
+480 val_480
+481 val_481
+482 val_482
+483 val_483
+484 val_484
+485 val_485
+487 val_487
+489 val_489
+490 val_490
+491 val_491
+492 val_492
+493 val_493
+494 val_494
+495 val_495
+496 val_496
+497 val_497
+498 val_498
diff --git a/sql/hive/src/test/resources/golden/having-4-e9918bd385cb35db4ebcbd4e398547f4 b/sql/hive/src/test/resources/golden/having-4-e9918bd385cb35db4ebcbd4e398547f4
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff b/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff
new file mode 100644
index 0000000000000..2d7022e386303
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/having-5-4a0c4e521b8a6f6146151c13a2715ff
@@ -0,0 +1,199 @@
+4
+5
+8
+9
+26
+27
+28
+30
+33
+34
+35
+37
+41
+42
+43
+44
+47
+51
+53
+54
+57
+58
+64
+65
+66
+67
+69
+70
+72
+74
+76
+77
+78
+80
+82
+83
+84
+85
+86
+87
+90
+92
+95
+96
+97
+98
+256
+257
+258
+260
+262
+263
+265
+266
+272
+273
+274
+275
+277
+278
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+291
+292
+296
+298
+302
+305
+306
+307
+308
+309
+310
+311
+315
+316
+317
+318
+321
+322
+323
+325
+327
+331
+332
+333
+335
+336
+338
+339
+341
+342
+344
+345
+348
+351
+353
+356
+360
+362
+364
+365
+366
+367
+368
+369
+373
+374
+375
+377
+378
+379
+382
+384
+386
+389
+392
+393
+394
+395
+396
+397
+399
+400
+401
+402
+403
+404
+406
+407
+409
+411
+413
+414
+417
+418
+419
+421
+424
+427
+429
+430
+431
+432
+435
+436
+437
+438
+439
+443
+444
+446
+448
+449
+452
+453
+454
+455
+457
+458
+459
+460
+462
+463
+466
+467
+468
+469
+470
+472
+475
+477
+478
+479
+480
+481
+482
+483
+484
+485
+487
+489
+490
+491
+492
+493
+494
+495
+496
+497
+498
diff --git a/sql/hive/src/test/resources/golden/having-6-9f50df5b5f31c7166b0396ab434dc095 b/sql/hive/src/test/resources/golden/having-6-9f50df5b5f31c7166b0396ab434dc095
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e b/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e
new file mode 100644
index 0000000000000..bd545ccf7430c
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/having-7-5ad96cb287df02080da1e2594f08d83e
@@ -0,0 +1,125 @@
+302
+305
+306
+307
+308
+309
+310
+311
+315
+316
+317
+318
+321
+322
+323
+325
+327
+331
+332
+333
+335
+336
+338
+339
+341
+342
+344
+345
+348
+351
+353
+356
+360
+362
+364
+365
+366
+367
+368
+369
+373
+374
+375
+377
+378
+379
+382
+384
+386
+389
+392
+393
+394
+395
+396
+397
+399
+400
+401
+402
+403
+404
+406
+407
+409
+411
+413
+414
+417
+418
+419
+421
+424
+427
+429
+430
+431
+432
+435
+436
+437
+438
+439
+443
+444
+446
+448
+449
+452
+453
+454
+455
+457
+458
+459
+460
+462
+463
+466
+467
+468
+469
+470
+472
+475
+477
+478
+479
+480
+481
+482
+483
+484
+485
+487
+489
+490
+491
+492
+493
+494
+495
+496
+497
+498
diff --git a/sql/hive/src/test/resources/golden/having-8-4aa7197e20b5a64461ca670a79488103 b/sql/hive/src/test/resources/golden/having-8-4aa7197e20b5a64461ca670a79488103
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff b/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff
new file mode 100644
index 0000000000000..d77586c12b6af
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/having-9-a79743372d86d77b0ff53a71adcb1cff
@@ -0,0 +1,199 @@
+4 val_4
+5 val_5
+8 val_8
+9 val_9
+26 val_26
+27 val_27
+28 val_28
+30 val_30
+33 val_33
+34 val_34
+35 val_35
+37 val_37
+41 val_41
+42 val_42
+43 val_43
+44 val_44
+47 val_47
+51 val_51
+53 val_53
+54 val_54
+57 val_57
+58 val_58
+64 val_64
+65 val_65
+66 val_66
+67 val_67
+69 val_69
+70 val_70
+72 val_72
+74 val_74
+76 val_76
+77 val_77
+78 val_78
+80 val_80
+82 val_82
+83 val_83
+84 val_84
+85 val_85
+86 val_86
+87 val_87
+90 val_90
+92 val_92
+95 val_95
+96 val_96
+97 val_97
+98 val_98
+256 val_256
+257 val_257
+258 val_258
+260 val_260
+262 val_262
+263 val_263
+265 val_265
+266 val_266
+272 val_272
+273 val_273
+274 val_274
+275 val_275
+277 val_277
+278 val_278
+280 val_280
+281 val_281
+282 val_282
+283 val_283
+284 val_284
+285 val_285
+286 val_286
+287 val_287
+288 val_288
+289 val_289
+291 val_291
+292 val_292
+296 val_296
+298 val_298
+302 val_302
+305 val_305
+306 val_306
+307 val_307
+308 val_308
+309 val_309
+310 val_310
+311 val_311
+315 val_315
+316 val_316
+317 val_317
+318 val_318
+321 val_321
+322 val_322
+323 val_323
+325 val_325
+327 val_327
+331 val_331
+332 val_332
+333 val_333
+335 val_335
+336 val_336
+338 val_338
+339 val_339
+341 val_341
+342 val_342
+344 val_344
+345 val_345
+348 val_348
+351 val_351
+353 val_353
+356 val_356
+360 val_360
+362 val_362
+364 val_364
+365 val_365
+366 val_366
+367 val_367
+368 val_368
+369 val_369
+373 val_373
+374 val_374
+375 val_375
+377 val_377
+378 val_378
+379 val_379
+382 val_382
+384 val_384
+386 val_386
+389 val_389
+392 val_392
+393 val_393
+394 val_394
+395 val_395
+396 val_396
+397 val_397
+399 val_399
+400 val_400
+401 val_401
+402 val_402
+403 val_403
+404 val_404
+406 val_406
+407 val_407
+409 val_409
+411 val_411
+413 val_413
+414 val_414
+417 val_417
+418 val_418
+419 val_419
+421 val_421
+424 val_424
+427 val_427
+429 val_429
+430 val_430
+431 val_431
+432 val_432
+435 val_435
+436 val_436
+437 val_437
+438 val_438
+439 val_439
+443 val_443
+444 val_444
+446 val_446
+448 val_448
+449 val_449
+452 val_452
+453 val_453
+454 val_454
+455 val_455
+457 val_457
+458 val_458
+459 val_459
+460 val_460
+462 val_462
+463 val_463
+466 val_466
+467 val_467
+468 val_468
+469 val_469
+470 val_470
+472 val_472
+475 val_475
+477 val_477
+478 val_478
+479 val_479
+480 val_480
+481 val_481
+482 val_482
+483 val_483
+484 val_484
+485 val_485
+487 val_487
+489 val_489
+490 val_490
+491 val_491
+492 val_492
+493 val_493
+494 val_494
+495 val_495
+496 val_496
+497 val_497
+498 val_498
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index eb7df717284ce..6f36a4f8cb905 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -30,6 +30,18 @@ case class TestData(a: Int, b: String)
*/
class HiveQuerySuite extends HiveComparisonTest {
+ createQueryTest("boolean = number",
+ """
+ |SELECT
+ | 1 = true, 1L = true, 1Y = true, true = 1, true = 1L, true = 1Y,
+ | 0 = true, 0L = true, 0Y = true, true = 0, true = 0L, true = 0Y,
+ | 1 = false, 1L = false, 1Y = false, false = 1, false = 1L, false = 1Y,
+ | 0 = false, 0L = false, 0Y = false, false = 0, false = 0L, false = 0Y,
+ | 2 = true, 2L = true, 2Y = true, true = 2, true = 2L, true = 2Y,
+ | 2 = false, 2L = false, 2Y = false, false = 2, false = 2L, false = 2Y
+ |FROM src LIMIT 1
+ """.stripMargin)
+
test("CREATE TABLE AS runs once") {
hql("CREATE TABLE foo AS SELECT 1 FROM src LIMIT 1").collect()
assert(hql("SELECT COUNT(*) FROM foo").collect().head.getLong(0) === 1,
diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
index 03a73f92b275e..566983675bff5 100644
--- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
@@ -99,9 +99,25 @@ object GenerateMIMAIgnore {
(ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet)
}
+ /** Scala reflection does not let us see inner functions even if they are promoted
+ * to public for some reason, so we resort to Java reflection to collect all inner
+ * functions whose names contain "$$".
+ */
+ def getInnerFunctions(classSymbol: unv.ClassSymbol): Seq[String] = {
+ try {
+ Class.forName(classSymbol.fullName, false, classLoader).getMethods.map(_.getName)
+ .filter(_.contains("$$")).map(classSymbol.fullName + "." + _)
+ } catch {
+ case t: Throwable =>
+ println("[WARN] Unable to detect inner functions for class: " + classSymbol.fullName)
+ Seq.empty[String]
+ }
+ }
+
private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = {
classSymbol.typeSignature.members
- .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName)
+ .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName) ++
+ getInnerFunctions(classSymbol)
}
def main(args: Array[String]) {
@@ -121,7 +137,8 @@ object GenerateMIMAIgnore {
name.endsWith("$class") ||
name.contains("$sp") ||
name.contains("hive") ||
- name.contains("Hive")
+ name.contains("Hive") ||
+ name.contains("repl")
}
/**
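
Note on the GenerateMIMAIgnore change above: Scala reflection hides compiler-generated inner functions, so the patch drops down to Java reflection and keeps any method whose name contains "$$". A minimal standalone sketch of that idea follows; the object name and method signature are illustrative assumptions, not code from this patch.

    object InnerFunctionProbe {
      // Illustrative only: list public methods whose names contain "$$", i.e. the
      // compiler-mangled inner/private functions that Scala reflection does not expose.
      def innerFunctionNames(className: String, loader: ClassLoader): Seq[String] =
        try {
          Class.forName(className, false, loader)  // load without initializing the class
            .getMethods                            // java.lang.reflect also surfaces synthetic methods
            .map(_.getName)
            .filter(_.contains("$$"))
            .map(className + "." + _)
            .toSeq
        } catch {
          case _: Throwable => Seq.empty[String]   // fail soft, as the patch does, and skip the class
        }
    }
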
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 3ec36487dcd26..62b5c3bc5f0f3 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -60,6 +60,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
private var yarnAllocator: YarnAllocationHandler = _
private var isFinished: Boolean = false
private var uiAddress: String = _
+ private var uiHistoryAddress: String = _
private val maxAppAttempts: Int = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES,
YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES)
private var isLastAMRetry: Boolean = true
@@ -237,6 +238,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
if (null != sparkContext) {
uiAddress = sparkContext.ui.appUIHostPort
+ uiHistoryAddress = YarnSparkHadoopUtil.getUIHistoryAddress(sparkContext, sparkConf)
this.yarnAllocator = YarnAllocationHandler.newAllocator(
yarnConf,
resourceManager,
@@ -360,7 +362,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
finishReq.setAppAttemptId(appAttemptId)
finishReq.setFinishApplicationStatus(status)
finishReq.setDiagnostics(diagnostics)
- finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", ""))
+ finishReq.setTrackingUrl(uiHistoryAddress)
resourceManager.finishApplicationMaster(finishReq)
}
}
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
index a86ad256dfa39..184e2ad6c82cd 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
@@ -28,7 +28,6 @@ import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import akka.actor._
import akka.remote._
-import akka.actor.Terminated
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
@@ -57,10 +56,17 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
private var yarnAllocator: YarnAllocationHandler = _
- private var driverClosed:Boolean = false
+
+ private var driverClosed: Boolean = false
+ private var isFinished: Boolean = false
+ private var registered: Boolean = false
+
+ // Default to numExecutors * 2, with a minimum of 3
+ private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures",
+ sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3)))
val securityManager = new SecurityManager(sparkConf)
- val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
+ val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
conf = sparkConf, securityManager = securityManager)._1
var actor: ActorRef = _
@@ -97,23 +103,26 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
appAttemptId = getApplicationAttemptId()
resourceManager = registerWithResourceManager()
- val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
-
- // Compute number of threads for akka
- val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
-
- if (minimumMemory > 0) {
- val mem = args.executorMemory + sparkConf.getInt("spark.yarn.executor.memoryOverhead",
- YarnAllocationHandler.MEMORY_OVERHEAD)
- val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
-
- if (numCore > 0) {
- // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
- // TODO: Uncomment when hadoop is on a version which has this fixed.
- // args.workerCores = numCore
+ synchronized {
+ if (!isFinished) {
+ val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
+ // Compute number of threads for akka
+ val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
+
+ if (minimumMemory > 0) {
+ val mem = args.executorMemory + sparkConf.getInt("spark.yarn.executor.memoryOverhead",
+ YarnAllocationHandler.MEMORY_OVERHEAD)
+ val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
+
+ if (numCore > 0) {
+ // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
+ // TODO: Uncomment when hadoop is on a version which has this fixed.
+ // args.workerCores = numCore
+ }
+ }
+ registered = true
}
}
-
waitForSparkMaster()
addAmIpFilter()
// Allocate all containers
@@ -243,11 +252,17 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) {
yarnAllocator.allocateContainers(
math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0))
+ checkNumExecutorsFailed()
Thread.sleep(100)
}
logInfo("All executors have launched.")
-
+ }
+ private def checkNumExecutorsFailed() {
+ if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) {
+ finishApplicationMaster(FinalApplicationStatus.FAILED,
+ "max number of executor failures reached")
+ }
}
// TODO: We might want to extend this to allocate more containers in case they die !
@@ -257,6 +272,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
val t = new Thread {
override def run() {
while (!driverClosed) {
+ checkNumExecutorsFailed()
val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning
if (missingExecutorCount > 0) {
logInfo("Allocating " + missingExecutorCount +
@@ -282,15 +298,23 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
yarnAllocator.allocateContainers(0)
}
- def finishApplicationMaster(status: FinalApplicationStatus) {
-
- logInfo("finish ApplicationMaster with " + status)
- val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
- .asInstanceOf[FinishApplicationMasterRequest]
- finishReq.setAppAttemptId(appAttemptId)
- finishReq.setFinishApplicationStatus(status)
- finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", ""))
- resourceManager.finishApplicationMaster(finishReq)
+ def finishApplicationMaster(status: FinalApplicationStatus, appMessage: String = "") {
+ synchronized {
+ if (isFinished) {
+ return
+ }
+ logInfo("Unregistering ApplicationMaster with " + status)
+ if (registered) {
+ val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
+ .asInstanceOf[FinishApplicationMasterRequest]
+ finishReq.setAppAttemptId(appAttemptId)
+ finishReq.setFinishApplicationStatus(status)
+ finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", ""))
+ finishReq.setDiagnostics(appMessage)
+ resourceManager.finishApplicationMaster(finishReq)
+ }
+ isFinished = true
+ }
}
}
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
index 4ac5ff5231d02..cb4cc7b119066 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
@@ -234,7 +234,8 @@ trait ClientBase extends Logging {
if (!ClientBase.LOCAL_SCHEME.equals(localURI.getScheme())) {
val setPermissions = if (destName.equals(ClientBase.APP_JAR)) true else false
val destPath = copyRemoteFile(dst, qualifyForLocal(localURI), replication, setPermissions)
- distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+ val destFs = FileSystem.get(destPath.toUri(), conf)
+ distCacheMgr.addResource(destFs, conf, destPath, localResources, LocalResourceType.FILE,
destName, statCache)
} else if (confKey != null) {
sparkConf.set(confKey, localPath)
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index 718cb19f57261..e98308cdbd74e 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -30,6 +30,9 @@ import org.apache.hadoop.util.StringInterner
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.api.ApplicationConstants
import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.deploy.SparkHadoopUtil
/**
@@ -132,4 +135,17 @@ object YarnSparkHadoopUtil {
}
}
+ def getUIHistoryAddress(sc: SparkContext, conf: SparkConf): String = {
+ val eventLogDir = sc.eventLogger match {
+ case Some(logger) => logger.getApplicationLogDir()
+ case None => ""
+ }
+ val historyServerAddress = conf.get("spark.yarn.historyServer.address", "")
+ if (historyServerAddress != "" && eventLogDir != "") {
+ historyServerAddress + HistoryServer.UI_PATH_PREFIX + s"/$eventLogDir"
+ } else {
+ ""
+ }
+ }
+
}
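
For context on getUIHistoryAddress: the tracking URL handed to YARN is only built when both spark.yarn.historyServer.address and the application's event log directory are known; otherwise an empty string is returned so no history link is advertised. A hedged sketch of that composition, with illustrative parameter names that are not part of the patch:

    // Compose <historyServerAddress><uiPathPrefix>/<appLogDir>, or "" when either piece is missing.
    def historyUiAddress(historyServerAddress: String, uiPathPrefix: String, appLogDir: String): String =
      if (historyServerAddress.nonEmpty && appLogDir.nonEmpty) {
        s"$historyServerAddress$uiPathPrefix/$appLogDir"  // e.g. http://history-host:18080/history/app-xyz
      } else {
        ""                                                // no usable tracking URL
      }
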
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index d8266f7b0c9a7..f8fb96b312f23 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster
import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
import org.apache.spark.{SparkException, Logging, SparkContext}
-import org.apache.spark.deploy.yarn.{Client, ClientArguments, ExecutorLauncher}
+import org.apache.spark.deploy.yarn.{Client, ClientArguments, ExecutorLauncher, YarnSparkHadoopUtil}
import org.apache.spark.scheduler.TaskSchedulerImpl
import scala.collection.mutable.ArrayBuffer
@@ -37,6 +37,8 @@ private[spark] class YarnClientSchedulerBackend(
var client: Client = null
var appId: ApplicationId = null
+ var checkerThread: Thread = null
+ var stopping: Boolean = false
private[spark] def addArg(optionName: String, envVar: String, sysProp: String,
arrayBuf: ArrayBuffer[String]) {
@@ -54,6 +56,7 @@ private[spark] class YarnClientSchedulerBackend(
val driverPort = conf.get("spark.driver.port")
val hostport = driverHost + ":" + driverPort
conf.set("spark.driver.appUIAddress", sc.ui.appUIHostPort)
+ conf.set("spark.driver.appUIHistoryAddress", YarnSparkHadoopUtil.getUIHistoryAddress(sc, conf))
val argsArrayBuf = new ArrayBuffer[String]()
argsArrayBuf += (
@@ -85,6 +88,7 @@ private[spark] class YarnClientSchedulerBackend(
client = new Client(args, conf)
appId = client.runApp()
waitForApp()
+ checkerThread = yarnApplicationStateCheckerThread()
}
def waitForApp() {
@@ -115,7 +119,32 @@ private[spark] class YarnClientSchedulerBackend(
}
}
+ private def yarnApplicationStateCheckerThread(): Thread = {
+ val t = new Thread {
+ override def run() {
+ while (!stopping) {
+ val report = client.getApplicationReport(appId)
+ val state = report.getYarnApplicationState()
+ if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.KILLED
+ || state == YarnApplicationState.FAILED) {
+ logError(s"Yarn application has already ended: $state")
+ sc.stop()
+ stopping = true
+ }
+ Thread.sleep(1000L)
+ }
+ checkerThread = null
+ Thread.currentThread().interrupt()
+ }
+ }
+ t.setName("Yarn Application State Checker")
+ t.setDaemon(true)
+ t.start()
+ t
+ }
+
override def stop() {
+ stopping = true
super.stop()
client.stop
logInfo("Stopped")
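
The YarnClientSchedulerBackend change introduces a daemon thread that polls the YARN application report once per second and stops the SparkContext as soon as the application reaches FINISHED, KILLED, or FAILED. A generic sketch of that polling pattern is below; the class and parameter names are assumptions for illustration only.

    // Poll a terminal condition on a daemon thread and run a callback once when it holds
    // (or until stop() is called).
    class TerminalStatePoller(isTerminal: () => Boolean, onTerminal: () => Unit) {
      @volatile private var stopping = false
      private val thread = new Thread("terminal-state-poller") {
        override def run(): Unit = {
          while (!stopping) {
            if (isTerminal()) {
              onTerminal()
              stopping = true
            }
            Thread.sleep(1000L)
          }
        }
      }
      thread.setDaemon(true)
      def start(): Unit = thread.start()
      def stop(): Unit = stopping = true
    }
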
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index eaf594c8b49b9..035356d390c80 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -59,6 +59,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
private var yarnAllocator: YarnAllocationHandler = _
private var isFinished: Boolean = false
private var uiAddress: String = _
+ private var uiHistoryAddress: String = _
private val maxAppAttempts: Int = conf.getInt(
YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
private var isLastAMRetry: Boolean = true
@@ -216,6 +217,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
if (sparkContext != null) {
uiAddress = sparkContext.ui.appUIHostPort
+ uiHistoryAddress = YarnSparkHadoopUtil.getUIHistoryAddress(sparkContext, sparkConf)
this.yarnAllocator = YarnAllocationHandler.newAllocator(
yarnConf,
amClient,
@@ -312,8 +314,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
logInfo("Unregistering ApplicationMaster with " + status)
if (registered) {
- val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "")
- amClient.unregisterApplicationMaster(status, diagnostics, trackingUrl)
+ amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress)
}
}
}
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
index 5ac95f3798723..fc7b8320d734d 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
@@ -19,15 +19,12 @@ package org.apache.spark.deploy.yarn
import java.net.Socket
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.net.NetUtils
-import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import akka.actor._
import akka.remote._
-import akka.actor.Terminated
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
@@ -57,10 +54,16 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
private var yarnAllocator: YarnAllocationHandler = _
- private var driverClosed:Boolean = false
+ private var driverClosed: Boolean = false
+ private var isFinished: Boolean = false
+ private var registered: Boolean = false
private var amClient: AMRMClient[ContainerRequest] = _
+ // Default to numExecutors * 2, with a minimum of 3
+ private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures",
+ sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3)))
+
val securityManager = new SecurityManager(sparkConf)
val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
conf = sparkConf, securityManager = securityManager)._1
@@ -101,7 +104,12 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
amClient.start()
appAttemptId = ApplicationMaster.getApplicationAttemptId()
- registerApplicationMaster()
+ synchronized {
+ if (!isFinished) {
+ registerApplicationMaster()
+ registered = true
+ }
+ }
waitForSparkMaster()
addAmIpFilter()
@@ -210,6 +218,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
yarnAllocator.addResourceRequests(args.numExecutors)
yarnAllocator.allocateResources()
while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) {
+ checkNumExecutorsFailed()
allocateMissingExecutor()
yarnAllocator.allocateResources()
Thread.sleep(100)
@@ -228,12 +237,20 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
}
}
+ private def checkNumExecutorsFailed() {
+ if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) {
+ finishApplicationMaster(FinalApplicationStatus.FAILED,
+ "max number of executor failures reached")
+ }
+ }
+
private def launchReporterThread(_sleepTime: Long): Thread = {
val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime
val t = new Thread {
override def run() {
while (!driverClosed) {
+ checkNumExecutorsFailed()
allocateMissingExecutor()
logDebug("Sending progress")
yarnAllocator.allocateResources()
@@ -248,10 +265,18 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
t
}
- def finishApplicationMaster(status: FinalApplicationStatus) {
- logInfo("Unregistering ApplicationMaster with " + status)
- val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "")
- amClient.unregisterApplicationMaster(status, "" /* appMessage */ , trackingUrl)
+ def finishApplicationMaster(status: FinalApplicationStatus, appMessage: String = "") {
+ synchronized {
+ if (isFinished) {
+ return
+ }
+ logInfo("Unregistering ApplicationMaster with " + status)
+ if (registered) {
+ val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "")
+ amClient.unregisterApplicationMaster(status, appMessage, trackingUrl)
+ }
+ isFinished = true
+ }
}
}
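
Both ExecutorLauncher variants now serialize AM registration and unregistration: registration only happens if the launcher has not already finished, and finishApplicationMaster unregisters at most once and only if registration actually occurred. A condensed, illustrative sketch of that guard (names are mine, not the patch's):

    // Register only while not finished; unregister at most once, and only if registered.
    class AmLifecycleGuard(register: () => Unit, unregister: () => Unit) {
      private var registered = false
      private var finished = false

      def tryRegister(): Unit = synchronized {
        if (!finished) {
          register()
          registered = true
        }
      }

      def finishOnce(): Unit = synchronized {
        if (!finished) {
          if (registered) unregister()
          finished = true
        }
      }
    }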