diff --git a/core/pom.xml b/core/pom.xml index 7c60cf10c3dc2..6d8be37037729 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -150,7 +150,7 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.6 + 3.2.10 colt diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index a76a070b5b863..8947e66f4577c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -96,17 +96,23 @@ class JdbcRDD[T: ClassTag]( override def close() { try { - if (null != rs && ! rs.isClosed()) rs.close() + if (null != rs && ! rs.isClosed()) { + rs.close() + } } catch { case e: Exception => logWarning("Exception closing resultset", e) } try { - if (null != stmt && ! stmt.isClosed()) stmt.close() + if (null != stmt && ! stmt.isClosed()) { + stmt.close() + } } catch { case e: Exception => logWarning("Exception closing statement", e) } try { - if (null != conn && ! stmt.isClosed()) conn.close() + if (null != conn && ! conn.isClosed()) { + conn.close() + } logInfo("closed connection") } catch { case e: Exception => logWarning("Exception closing connection", e) @@ -120,3 +126,4 @@ object JdbcRDD { Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) } } + diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala index eb920ab0c0b67..f176d09816f5e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala @@ -22,7 +22,7 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi object TaskLocality extends Enumeration { // Process local is expected to be used ONLY within TaskSetManager for now. - val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value + val PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY = Value type TaskLocality = Value diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index d2f764fc22f54..6c0d1b2752a81 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -89,11 +89,11 @@ private[spark] class TaskSchedulerImpl( // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host - private val executorsByHost = new HashMap[String, HashSet[String]] + protected val executorsByHost = new HashMap[String, HashSet[String]] protected val hostsByRack = new HashMap[String, HashSet[String]] - private val executorIdToHost = new HashMap[String, String] + protected val executorIdToHost = new HashMap[String, String] // Listener object to pass upcalls into var dagScheduler: DAGScheduler = null @@ -249,6 +249,7 @@ private[spark] class TaskSchedulerImpl( // Take each TaskSet in our scheduling order, and then offer it each node in increasing order // of locality levels so that it gets a chance to launch local tasks on all of them. 
+ // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY var launchedTask = false for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) { do { @@ -265,7 +266,7 @@ private[spark] class TaskSchedulerImpl( activeExecutorIds += execId executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK - assert (availableCpus(i) >= 0) + assert(availableCpus(i) >= 0) launchedTask = true } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 8b5e8cb802a45..20a4bd12f93f6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -79,6 +79,7 @@ private[spark] class TaskSetManager( private val numFailures = new Array[Int](numTasks) // key is taskId, value is a Map of executor id to when it failed private val failedExecutors = new HashMap[Int, HashMap[String, Long]]() + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) var tasksSuccessful = 0 @@ -179,26 +180,17 @@ private[spark] class TaskSetManager( } } - var hadAliveLocations = false for (loc <- tasks(index).preferredLocations) { for (execId <- loc.executorId) { addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) } - if (sched.hasExecutorsAliveOnHost(loc.host)) { - hadAliveLocations = true - } addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) for (rack <- sched.getRackForHost(loc.host)) { addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - if(sched.hasHostAliveOnRack(rack)){ - hadAliveLocations = true - } } } - if (!hadAliveLocations) { - // Even though the task might've had preferred locations, all of those hosts or executors - // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. 
+ if (tasks(index).preferredLocations == Nil) { addTo(pendingTasksWithNoPrefs) } @@ -239,7 +231,6 @@ private[spark] class TaskSetManager( */ private def findTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = { var indexOffset = list.size - while (indexOffset > 0) { indexOffset -= 1 val index = list(indexOffset) @@ -288,12 +279,12 @@ private[spark] class TaskSetManager( !hasAttemptOnHost(index, host) && !executorIsBlacklisted(execId, index) if (!speculatableTasks.isEmpty) { - // Check for process-local or preference-less tasks; note that tasks can be process-local + // Check for process-local tasks; note that tasks can be process-local // on multiple nodes when we replicate cached blocks, as in Spark Streaming for (index <- speculatableTasks if canRunOnHost(index)) { val prefs = tasks(index).preferredLocations val executors = prefs.flatMap(_.executorId) - if (prefs.size == 0 || executors.contains(execId)) { + if (executors.contains(execId)) { speculatableTasks -= index return Some((index, TaskLocality.PROCESS_LOCAL)) } @@ -310,6 +301,17 @@ private[spark] class TaskSetManager( } } + // Check for no-preference tasks + if (TaskLocality.isAllowed(locality, TaskLocality.NO_PREF)) { + for (index <- speculatableTasks if canRunOnHost(index)) { + val locations = tasks(index).preferredLocations + if (locations.size == 0) { + speculatableTasks -= index + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + } + } + // Check for rack-local tasks if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { for (rack <- sched.getRackForHost(host)) { @@ -341,20 +343,27 @@ private[spark] class TaskSetManager( * * @return An option containing (task index within the task set, locality, is speculative?) */ - private def findTask(execId: String, host: String, locality: TaskLocality.Value) + private def findTask(execId: String, host: String, maxLocality: TaskLocality.Value) : Option[(Int, TaskLocality.Value, Boolean)] = { for (index <- findTaskFromList(execId, getPendingTasksForExecutor(execId))) { return Some((index, TaskLocality.PROCESS_LOCAL, false)) } - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { + if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) { for (index <- findTaskFromList(execId, getPendingTasksForHost(host))) { return Some((index, TaskLocality.NODE_LOCAL, false)) } } - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) { + // Look for noPref tasks after NODE_LOCAL to minimize cross-rack traffic + for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) { + return Some((index, TaskLocality.PROCESS_LOCAL, false)) + } + } + + if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) { for { rack <- sched.getRackForHost(host) index <- findTaskFromList(execId, getPendingTasksForRack(rack)) @@ -363,25 +372,27 @@ private[spark] class TaskSetManager( } } - // Look for no-pref tasks after rack-local tasks since they can run anywhere.
- for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) { - return Some((index, TaskLocality.PROCESS_LOCAL, false)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) { for (index <- findTaskFromList(execId, allPendingTasks)) { return Some((index, TaskLocality.ANY, false)) } } - // Finally, if all else has failed, find a speculative task - findSpeculativeTask(execId, host, locality).map { case (taskIndex, allowedLocality) => - (taskIndex, allowedLocality, true) - } + // find a speculative task if all other tasks have been scheduled + findSpeculativeTask(execId, host, maxLocality).map { + case (taskIndex, allowedLocality) => (taskIndex, allowedLocality, true)} } /** * Respond to an offer of a single executor from the scheduler by finding a task + * + * NOTE: this function is either called with a maxLocality which + * would be adjusted by the delay scheduling algorithm, or with a special + * NO_PREF locality which will not be modified + * + * @param execId the executor Id of the offered resource + * @param host the host Id of the offered resource + * @param maxLocality the maximum locality we want to schedule the tasks at */ def resourceOffer( execId: String, @@ -392,9 +403,14 @@ private[spark] class TaskSetManager( if (!isZombie) { val curTime = clock.getTime() - var allowedLocality = getAllowedLocalityLevel(curTime) - if (allowedLocality > maxLocality) { - allowedLocality = maxLocality // We're not allowed to search for farther-away tasks + var allowedLocality = maxLocality + + if (maxLocality != TaskLocality.NO_PREF) { + allowedLocality = getAllowedLocalityLevel(curTime) + if (allowedLocality > maxLocality) { + // We're not allowed to search for farther-away tasks + allowedLocality = maxLocality + } } findTask(execId, host, allowedLocality) match { @@ -410,8 +426,11 @@ private[spark] class TaskSetManager( taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) // Update our locality level for delay scheduling - currentLocalityIndex = getLocalityIndex(taskLocality) - lastLaunchTime = curTime + // NO_PREF will not affect the variables related to delay scheduling + if (maxLocality != TaskLocality.NO_PREF) { + currentLocalityIndex = getLocalityIndex(taskLocality) + lastLaunchTime = curTime + } // Serialize and return the task val startTime = clock.getTime() // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here @@ -639,8 +658,7 @@ private[spark] class TaskSetManager( override def executorLost(execId: String, host: String) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a - // task that used to have locations on only this host might now go to the no-prefs list. Note + // Re-enqueue pending tasks for this host based on the status of the cluster. Note // that it's okay if we add a task to the same queue twice (if it had multiple preferred // locations), because findTaskFromList will skip already-running tasks.
for (index <- getPendingTasksForExecutor(execId)) { @@ -671,6 +689,9 @@ private[spark] class TaskSetManager( for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) } + // recalculate valid locality levels and waits when executor is lost + myLocalityLevels = computeValidLocalityLevels() + localityWaits = myLocalityLevels.map(getLocalityWait) } /** @@ -722,17 +743,17 @@ private[spark] class TaskSetManager( conf.get("spark.locality.wait.node", defaultWait).toLong case TaskLocality.RACK_LOCAL => conf.get("spark.locality.wait.rack", defaultWait).toLong - case TaskLocality.ANY => - 0L + case _ => 0L } } /** * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been * added to queues using addPendingTask. + * */ private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { - import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} + import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY} val levels = new ArrayBuffer[TaskLocality.TaskLocality] if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 && pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) { @@ -742,6 +763,9 @@ private[spark] class TaskSetManager( pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) { levels += NODE_LOCAL } + if (!pendingTasksWithNoPrefs.isEmpty) { + levels += NO_PREF + } if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0 && pendingTasksForRack.keySet.exists(sched.hasHostAliveOnRack(_))) { levels += RACK_LOCAL @@ -751,20 +775,7 @@ private[spark] class TaskSetManager( levels.toArray } - // Re-compute pendingTasksWithNoPrefs since new preferred locations may become available def executorAdded() { - def newLocAvail(index: Int): Boolean = { - for (loc <- tasks(index).preferredLocations) { - if (sched.hasExecutorsAliveOnHost(loc.host) || - (sched.getRackForHost(loc.host).isDefined && - sched.hasHostAliveOnRack(sched.getRackForHost(loc.host).get))) { - return true - } - } - false - } - logInfo("Re-computing pending task lists.") - pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.filter(!newLocAvail(_)) myLocalityLevels = computeValidLocalityLevels() localityWaits = myLocalityLevels.map(getLocalityWait) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 9a356d0dbaf17..24db2f287a47b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -40,7 +40,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val ser = Serializer.getSerializer(dep.serializer.orNull) private val conf = SparkEnv.get.conf - private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 private var sorter: ExternalSorter[K, V, _] = null private var outputFile: File = null diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 28aa35bc7e147..f9fdffae8bd8f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -73,7 +73,7 @@ class ShuffleBlockManager(blockManager: 
BlockManager) extends Logging { val sortBasedShuffle = conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName - private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 /** * Contains all the state related to a particular shuffle. This includes a pool of unused diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index cc0423856cefb..260a5c3888aa7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -101,7 +101,7 @@ class ExternalAppendOnlyMap[K, V, C]( private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L - private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 101c83b264f63..3f93afd57b3ad 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -84,7 +84,7 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 // Size of object batches when reading/writing from serializers. 
// diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index c52368b5514db..ffd23380a886f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -85,14 +85,31 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex val finishedManagers = new ArrayBuffer[TaskSetManager] val taskSetsFailed = new ArrayBuffer[String] - val executors = new mutable.HashMap[String, String] ++ liveExecutors + val executors = new mutable.HashMap[String, String] + for ((execId, host) <- liveExecutors) { + addExecutor(execId, host) + } + for ((execId, host) <- liveExecutors; rack <- getRackForHost(host)) { hostsByRack.getOrElseUpdate(rack, new mutable.HashSet[String]()) += host } dagScheduler = new FakeDAGScheduler(sc, this) - def removeExecutor(execId: String): Unit = executors -= execId + def removeExecutor(execId: String) { + executors -= execId + val host = executorIdToHost.get(execId) + assert(host != None) + val hostId = host.get + val executorsOnHost = executorsByHost(hostId) + executorsOnHost -= execId + for (rack <- getRackForHost(hostId); hosts <- hostsByRack.get(rack)) { + hosts -= hostId + if (hosts.isEmpty) { + hostsByRack -= rack + } + } + } override def taskSetFinished(manager: TaskSetManager): Unit = finishedManagers += manager @@ -100,8 +117,15 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) + override def hasHostAliveOnRack(rack: String): Boolean = { + hostsByRack.get(rack) != None + } + def addExecutor(execId: String, host: String) { executors.put(execId, host) + val executorsOnHost = executorsByHost.getOrElseUpdate(host, new mutable.HashSet[String]) + executorsOnHost += execId + executorIdToHost += execId -> host for (rack <- getRackForHost(host)) { hostsByRack.getOrElseUpdate(rack, new mutable.HashSet[String]()) += host } @@ -123,7 +147,7 @@ class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) { } class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { - import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} + import TaskLocality.{ANY, PROCESS_LOCAL, NO_PREF, NODE_LOCAL, RACK_LOCAL} private val conf = new SparkConf @@ -134,18 +158,13 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) - // Offer a host with process-local as the constraint; this should work because the TaskSet - // above won't have any locality preferences - val taskOption = manager.resourceOffer("exec1", "host1", TaskLocality.PROCESS_LOCAL) + // Offer a host with NO_PREF as the constraint, + // we should get a nopref task immediately since that's what we only have + var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) - val task = taskOption.get - assert(task.executorId === "exec1") - assert(sched.startedTasks.contains(0)) - - // Re-offer the host -- now we should get no more tasks - assert(manager.resourceOffer("exec1", 
"host1", PROCESS_LOCAL) === None) // Tell it the task has finished manager.handleSuccessfulTask(0, createTaskResult(0)) @@ -161,7 +180,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // First three offers should all find tasks for (i <- 0 until 3) { - val taskOption = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) + var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) val task = taskOption.get assert(task.executorId === "exec1") @@ -169,7 +188,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(sched.startedTasks.toSet === Set(0, 1, 2)) // Re-offer the host -- now we should get no more tasks - assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) === None) + assert(manager.resourceOffer("exec1", "host1", NO_PREF) === None) // Finish the first two tasks manager.handleSuccessfulTask(0, createTaskResult(0)) @@ -211,37 +230,40 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { ) val clock = new FakeClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) - // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) - - // Offer host1, exec1 again: the last task, which has no prefs, should be chosen - assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 3) - - // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) === None) + assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) == None) clock.advance(LOCALITY_WAIT) - - // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) === None) - - // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2 + // Offer host1, exec1 again, at NODE_LOCAL level: the node local (task 2) should + // get chosen before the noPref task assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index == 2) - // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL) === None) - - // Offer host1, exec1 again, at ANY level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", ANY) === None) + // Offer host2, exec3 again, at NODE_LOCAL level: we should choose task 2 + assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL).get.index == 1) + // Offer host2, exec3 again, at NODE_LOCAL level: we should get noPref task + // after failing to find a node_Local task + assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL) == None) clock.advance(LOCALITY_WAIT) + assert(manager.resourceOffer("exec2", "host2", NO_PREF).get.index == 3) + } - // Offer host1, exec1 again, at ANY level: task 1 should get chosen - assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 1) - - // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks - assert(manager.resourceOffer("exec1", "host1", ANY) === None) + test("we do not need to delay scheduling when we only have noPref tasks in the queue") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec3", "host2")) + val taskSet = FakeTask.createTaskSet(3, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host2", "exec3")), + Seq() // Last task has no locality prefs + 
) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + // First offer host1, exec1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).get.index === 0) + assert(manager.resourceOffer("exec3", "host2", PROCESS_LOCAL).get.index === 1) + assert(manager.resourceOffer("exec3", "host2", NODE_LOCAL) == None) + assert(manager.resourceOffer("exec3", "host2", NO_PREF).get.index === 2) } test("delay scheduling with fallback") { @@ -298,20 +320,24 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) - // Offer host1 again: third task should be chosen immediately because host3 is not up - assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 2) - - // After this, nothing should get chosen + // After this, nothing should get chosen, because we have separated tasks with unavailable preference + // from the noPrefPendingTasks assert(manager.resourceOffer("exec1", "host1", ANY) === None) // Now mark host2 as dead sched.removeExecutor("exec2") manager.executorLost("exec2", "host2") - // Task 1 should immediately be launched on host1 because its original host is gone + // nothing should be chosen + assert(manager.resourceOffer("exec1", "host1", ANY) === None) + + clock.advance(LOCALITY_WAIT * 2) + + // task 1 and 2 would be scheduled as nonLocal task assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 1) + assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 2) - // Now that all tasks have launched, nothing new should be launched anywhere else + // all finished assert(manager.resourceOffer("exec1", "host1", ANY) === None) assert(manager.resourceOffer("exec2", "host2", ANY) === None) } @@ -373,7 +399,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val manager = new TaskSetManager(sched, taskSet, 4, clock) { - val offerResult = manager.resourceOffer("exec1", "host1", TaskLocality.PROCESS_LOCAL) + val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) assert(offerResult.isDefined, "Expect resource offer to return a task") assert(offerResult.get.index === 0) @@ -384,15 +410,15 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(!sched.taskSetsFailed.contains(taskSet.id)) // Ensure scheduling on exec1 fails after failure 1 due to blacklist - assert(manager.resourceOffer("exec1", "host1", TaskLocality.PROCESS_LOCAL).isEmpty) - assert(manager.resourceOffer("exec1", "host1", TaskLocality.NODE_LOCAL).isEmpty) - assert(manager.resourceOffer("exec1", "host1", TaskLocality.RACK_LOCAL).isEmpty) - assert(manager.resourceOffer("exec1", "host1", TaskLocality.ANY).isEmpty) + assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1", "host1", RACK_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1", "host1", ANY).isEmpty) } // Run the task on exec1.1 - should work, and then fail it on exec1.1 { - val offerResult = manager.resourceOffer("exec1.1", "host1", TaskLocality.NODE_LOCAL) + val offerResult = manager.resourceOffer("exec1.1", "host1", NODE_LOCAL) assert(offerResult.isDefined, "Expect resource offer to return a task for exec1.1, offerResult = " + offerResult) @@ -404,12 +430,12 @@ class TaskSetManagerSuite extends FunSuite 
with LocalSparkContext with Logging { assert(!sched.taskSetsFailed.contains(taskSet.id)) // Ensure scheduling on exec1.1 fails after failure 2 due to blacklist - assert(manager.resourceOffer("exec1.1", "host1", TaskLocality.NODE_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1.1", "host1", NODE_LOCAL).isEmpty) } // Run the task on exec2 - should work, and then fail it on exec2 { - val offerResult = manager.resourceOffer("exec2", "host2", TaskLocality.ANY) + val offerResult = manager.resourceOffer("exec2", "host2", ANY) assert(offerResult.isDefined, "Expect resource offer to return a task") assert(offerResult.get.index === 0) @@ -420,20 +446,20 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(!sched.taskSetsFailed.contains(taskSet.id)) // Ensure scheduling on exec2 fails after failure 3 due to blacklist - assert(manager.resourceOffer("exec2", "host2", TaskLocality.ANY).isEmpty) + assert(manager.resourceOffer("exec2", "host2", ANY).isEmpty) } // After reschedule delay, scheduling on exec1 should be possible. clock.advance(rescheduleDelay) { - val offerResult = manager.resourceOffer("exec1", "host1", TaskLocality.PROCESS_LOCAL) + val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) assert(offerResult.isDefined, "Expect resource offer to return a task") assert(offerResult.get.index === 0) assert(offerResult.get.executorId === "exec1") - assert(manager.resourceOffer("exec1", "host1", TaskLocality.PROCESS_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).isEmpty) // Cause exec1 to fail : failure 4 manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) @@ -443,7 +469,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(sched.taskSetsFailed.contains(taskSet.id)) } - test("new executors get added") { + test("new executors get added and lost") { // Assign host2 to rack2 FakeRackUtil.cleanUp() FakeRackUtil.assignHostToRack("host2", "rack2") @@ -456,26 +482,25 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq()) val clock = new FakeClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) - // All tasks added to no-pref list since no preferred location is available - assert(manager.pendingTasksWithNoPrefs.size === 4) // Only ANY is valid - assert(manager.myLocalityLevels.sameElements(Array(ANY))) + assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) // Add a new executor sched.addExecutor("execD", "host1") manager.executorAdded() - // Task 0 and 1 should be removed from no-pref list - assert(manager.pendingTasksWithNoPrefs.size === 2) // Valid locality should contain NODE_LOCAL and ANY - assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, ANY))) + assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, NO_PREF, ANY))) // Add another executor sched.addExecutor("execC", "host2") manager.executorAdded() - // No-pref list now only contains task 3 - assert(manager.pendingTasksWithNoPrefs.size === 1) // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY - assert(manager.myLocalityLevels.sameElements( - Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) - FakeRackUtil.cleanUp() + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) + // test if the valid locality is recomputed when the executor is lost + sched.removeExecutor("execC") + manager.executorLost("execC", "host2") 
+ assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, NO_PREF, ANY))) + sched.removeExecutor("execD") + manager.executorLost("execD", "host1") + assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) } test("test RACK_LOCAL tasks") { @@ -506,7 +531,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Offer host2 // Task 1 can be scheduled with RACK_LOCAL assert(manager.resourceOffer("execB", "host2", RACK_LOCAL).get.index === 1) - FakeRackUtil.cleanUp() } test("do not emit warning when serialized task is small") { @@ -536,6 +560,53 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.emittedTaskSizeWarning) } + test("speculative and noPref task should be scheduled after node-local") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host2"), TaskLocation("host1")), + Seq(), + Seq(TaskLocation("host3", "execC"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 0) + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None) + assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index == 1) + + manager.speculatableTasks += 1 + clock.advance(LOCALITY_WAIT) + // schedule the nonPref task + assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index === 2) + // schedule the speculative task + assert(manager.resourceOffer("execB", "host2", NO_PREF).get.index === 1) + clock.advance(LOCALITY_WAIT * 3) + // schedule non-local tasks + assert(manager.resourceOffer("execB", "host2", ANY).get.index === 3) + } + + test("node-local tasks should be scheduled right away when there are only node-local and no-preference tasks") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(), + Seq(TaskLocation("host3"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + // node-local tasks are scheduled without delay + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 0) + assert(manager.resourceOffer("execA", "host2", NODE_LOCAL).get.index === 1) + assert(manager.resourceOffer("execA", "host3", NODE_LOCAL).get.index === 3) + assert(manager.resourceOffer("execA", "host3", NODE_LOCAL) === None) + + // schedule no-preference after node local ones + assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) + } + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) diff --git a/docs/configuration.md b/docs/configuration.md index 5e7556c08ee36..7cd7f4124db7e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -266,7 +266,7 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.file.buffer.kb - 100 + 32 Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers reduce the number of disk seeks and system calls made in creating intermediate shuffle files. 
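For context on the buffer-size change above, here is a minimal sketch (not part of this patch) of how the lowered default is picked up. SparkConf and the getInt call mirror what the shuffle writers in the hunks above already do; the explicit .set(...) line is only a hypothetical user override.

import org.apache.spark.SparkConf

// Hypothetical override of the new 32 KB default; omit it to use the documented default.
val conf = new SparkConf().set("spark.shuffle.file.buffer.kb", "64")
// Each shuffle writer sizes its output-stream buffer the same way as the code above:
val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 // bytes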
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index a4a4a6381a85c..9d88a38831d89 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -269,6 +269,7 @@ class PythonMLLibAPI extends Serializable { .setNumIterations(numIterations) .setRegParam(regParam) .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) if (regType == "l2") { lrAlg.optimizer.setUpdater(new SquaredL2Updater) } else if (regType == "l1") { @@ -339,16 +340,27 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val SVMAlg = new SVMWithSGD() + SVMAlg.setIntercept(intercept) + SVMAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) + if (regType == "l2") { + SVMAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + SVMAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - SVMWithSGD.train( - data, - numIterations, - stepSize, - regParam, - miniBatchFraction, - initialWeights), + SVMAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -361,15 +373,28 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regParam: Double, + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val LogRegAlg = new LogisticRegressionWithSGD() + LogRegAlg.setIntercept(intercept) + LogRegAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) + if (regType == "l2") { + LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + LogRegAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." 
+ + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - LogisticRegressionWithSGD.train( - data, - numIterations, - stepSize, - miniBatchFraction, - initialWeights), + LogRegAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 87c81e7b0bd2f..3bf44ad7c44e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -19,16 +19,17 @@ package org.apache.spark.mllib.feature import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.{HashPartitioner, Logging} + +import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom /** * Entry in vocabulary @@ -58,29 +59,63 @@ private case class VocabWord( * Efficient Estimation of Word Representations in Vector Space * and * Distributed Representations of Words and Phrases and their Compositionality. - * @param size vector dimension - * @param startingAlpha initial learning rate - * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) - * @param numIterations number of iterations to run, should be smaller than or equal to parallelism */ @Experimental -class Word2Vec( - val size: Int, - val startingAlpha: Double, - val parallelism: Int, - val numIterations: Int) extends Serializable with Logging { +class Word2Vec extends Serializable with Logging { + + private var vectorSize = 100 + private var startingAlpha = 0.025 + private var numPartitions = 1 + private var numIterations = 1 + private var seed = Utils.random.nextLong() + + /** + * Sets vector size (default: 100). + */ + def setVectorSize(vectorSize: Int): this.type = { + this.vectorSize = vectorSize + this + } + + /** + * Sets initial learning rate (default: 0.025). + */ + def setLearningRate(learningRate: Double): this.type = { + this.startingAlpha = learningRate + this + } /** - * Word2Vec with a single thread. + * Sets number of partitions (default: 1). Use a small number for accuracy. */ - def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1) + def setNumPartitions(numPartitions: Int): this.type = { + require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions") + this.numPartitions = numPartitions + this + } + + /** + * Sets number of iterations (default: 1), which should be smaller than or equal to number of + * partitions. + */ + def setNumIterations(numIterations: Int): this.type = { + this.numIterations = numIterations + this + } + + /** + * Sets random seed (default: a random long integer). 
+ */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 private val MAX_SENTENCE_LENGTH = 1000 - private val layer1Size = size - private val modelPartitionNum = 100 + private val layer1Size = vectorSize /** context words from [-window, window] */ private val window = 5 @@ -94,12 +129,12 @@ class Word2Vec( private var vocabHash = mutable.HashMap.empty[String, Int] private var alpha = startingAlpha - private def learnVocab(words:RDD[String]): Unit = { + private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) .reduceByKey(_ + _) .map(x => VocabWord( - x._1, - x._2, + x._1, + x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0)) @@ -245,23 +280,24 @@ class Word2Vec( } } - val newSentences = sentences.repartition(parallelism).cache() + val newSentences = sentences.repartition(numPartitions).cache() + val initRandom = new XORShiftRandom(seed) var syn0Global = - Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size) + Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size) var syn1Global = new Array[Float](vocabSize * layer1Size) - - for(iter <- 1 to numIterations) { - val (aggSyn0, aggSyn1, _, _) = - // TODO: broadcast temp instead of serializing it directly - // or initialize the model in each executor - newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))( - seqOp = (c, v) => (c, v) match { + + for (k <- 1 to numIterations) { + val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => + val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) + val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount - var wc = wordCount + var wc = wordCount if (wordCount - lastWordCount > 10000) { lwc = wordCount - alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1)) + // TODO: discount by iteration? 
+ alpha = + startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 logInfo("wordCount = " + wordCount + ", alpha = " + alpha) } @@ -269,8 +305,7 @@ class Word2Vec( var pos = 0 while (pos < sentence.size) { val word = sentence(pos) - // TODO: fix random seed - val b = Random.nextInt(window) + val b = random.nextInt(window) // Train Skip-gram var a = b while (a < window * 2 + 1 - b) { @@ -280,7 +315,7 @@ class Word2Vec( val lastWord = sentence(c) val l1 = lastWord * layer1Size val neu1e = new Array[Float](layer1Size) - // Hierarchical softmax + // Hierarchical softmax var d = 0 while (d < bcVocab.value(word).codeLen) { val l2 = bcVocab.value(word).point(d) * layer1Size @@ -303,44 +338,44 @@ class Word2Vec( pos += 1 } (syn0, syn1, lwc, wc) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => - val n = syn0_1.length - val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) - val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) - blas.sscal(n, weight1, syn0_1, 1) - blas.sscal(n, weight1, syn1_1, 1) - blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) - blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) - (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) - }) + } + Iterator(model) + } + val (aggSyn0, aggSyn1, _, _) = + partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => + val n = syn0_1.length + val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) + val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) + blas.sscal(n, weight1, syn0_1, 1) + blas.sscal(n, weight1, syn1_1, 1) + blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) + blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) + (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) + } syn0Global = aggSyn0 syn1Global = aggSyn1 } newSentences.unpersist() - val wordMap = new Array[(String, Array[Float])](vocabSize) + val word2VecMap = mutable.HashMap.empty[String, Array[Float]] var i = 0 while (i < vocabSize) { val word = bcVocab.value(i).word val vector = new Array[Float](layer1Size) Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) - wordMap(i) = (word, vector) + word2VecMap += word -> vector i += 1 } - val modelRDD = sc.parallelize(wordMap, modelPartitionNum) - .partitionBy(new HashPartitioner(modelPartitionNum)) - .persist(StorageLevel.MEMORY_AND_DISK) - - new Word2VecModel(modelRDD) + + new Word2VecModel(word2VecMap.toMap) } } /** * Word2Vec model -*/ -class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable { + */ +class Word2VecModel private[mllib] ( + private val model: Map[String, Array[Float]]) extends Serializable { private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { require(v1.length == v2.length, "Vectors should have the same length") @@ -357,11 +392,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri * @return vector representation of word */ def transform(word: String): Vector = { - val result = model.lookup(word) - if (result.isEmpty) { - throw new IllegalStateException(s"$word not in vocabulary") + model.get(word) match { + case Some(vec) => + Vectors.dense(vec.map(_.toDouble)) + case None => + throw new IllegalStateException(s"$word not in vocabulary") } - else Vectors.dense(result(0).map(_.toDouble)) } /** @@ -392,33 +428,13 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar 
words should > 0") - val topK = model.map { case(w, vec) => - (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) } - .sortByKey(ascending = false) - .take(num + 1) - .map(_.swap) - .tail - - topK - } -} - -object Word2Vec{ - /** - * Train Word2Vec model - * @param input RDD of words - * @param size vector dimension - * @param startingAlpha initial learning rate - * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) - * @param numIterations number of iterations, should be smaller than or equal to parallelism - * @return Word2Vec model - */ - def train[S <: Iterable[String]]( - input: RDD[S], - size: Int, - startingAlpha: Double, - parallelism: Int = 1, - numIterations:Int = 1): Word2VecModel = { - new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input) + // TODO: optimize top-k + val fVector = vector.toArray.map(_.toFloat) + model.mapValues(vec => cosineSimilarity(fVector, vec)) + .toSeq + .sortBy(- _._2) + .take(num + 1) + .tail + .toArray } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index b5db39b68a223..e34335d89eb75 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -30,29 +30,22 @@ class Word2VecSuite extends FunSuite with LocalSparkContext { val localDoc = Seq(sentence, sentence) val doc = sc.parallelize(localDoc) .map(line => line.split(" ").toSeq) - val size = 10 - val startingAlpha = 0.025 - val window = 2 - val minCount = 2 - val num = 2 - - val model = Word2Vec.train(doc, size, startingAlpha) + val model = new Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) val syms = model.findSynonyms("a", 2) - assert(syms.length == num) + assert(syms.length == 2) assert(syms(0)._1 == "b") assert(syms(1)._1 == "c") } - test("Word2VecModel") { val num = 2 - val localModel = Seq( + val word2VecMap = Map( ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) ) - val model = new Word2VecModel(sc.parallelize(localModel, 2)) + val model = new Word2VecModel(word2VecMap) val syms = model.findSynonyms("china", num) assert(syms.length == num) assert(syms(0)._1 == "taiwan") diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 2bbb9c3fca315..5ec1a8084d269 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -73,11 +73,36 @@ def predict(self, x): class LogisticRegressionWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None): - """Train a logistic regression model on the given data.""" + def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, + initialWeights=None, regParam=1.0, regType=None, intercept=False): + """ + Train a logistic regression model on the given data. + + @param data: The training data. + @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regParam: The regularizer parameter (default: 1.0). + @param regType: The type of regularizer used for training + our model. 
+ Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD( - d._jrdd, iterations, step, miniBatchFraction, i) + d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data, initialWeights) @@ -115,11 +140,35 @@ def predict(self, x): class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, - miniBatchFraction=1.0, initialWeights=None): - """Train a support vector machine on the given data.""" + miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False): + """ + Train a support vector machine on the given data. + + @param data: The training data. + @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param regParam: The regularizer parameter (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regType: The type of regularizer used for training + our model. + Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD( - d._jrdd, iterations, step, regParam, miniBatchFraction, i) + d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept) return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 1a829c6fafe03..f1093701ddc89 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -672,12 +672,12 @@ def _infer_schema_type(obj, dataType): ByteType: (int, long), ShortType: (int, long), IntegerType: (int, long), - LongType: (int, long), + LongType: (long,), FloatType: (float,), DoubleType: (float,), DecimalType: (decimal.Decimal,), StringType: (str, unicode), - TimestampType: (datetime.datetime, datetime.time, datetime.date), + TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), StructType: (tuple, list), @@ -1042,12 +1042,15 @@ def applySchema(self, rdd, schema): [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] >>> from datetime import datetime - >>> rdd = sc.parallelize([(127, -32768, 1.0, + >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), ... {"a": 1}, (2,), [1, 2, 3], None)]) >>> schema = StructType([ - ... StructField("byte", ByteType(), False), - ... StructField("short", ShortType(), False), + ... StructField("byte1", ByteType(), False), + ... StructField("byte2", ByteType(), False), + ... StructField("short1", ShortType(), False), + ... StructField("short2", ShortType(), False), + ... StructField("int", IntegerType(), False), ... 
StructField("float", FloatType(), False), ... StructField("time", TimestampType(), False), ... StructField("map", @@ -1056,11 +1059,19 @@ def applySchema(self, rdd, schema): ... StructType([StructField("b", ShortType(), False)]), False), ... StructField("list", ArrayType(ByteType(), False), False), ... StructField("null", DoubleType(), True)]) - >>> srdd = sqlCtx.applySchema(rdd, schema).map( - ... lambda x: (x.byte, x.short, x.float, x.time, + >>> srdd = sqlCtx.applySchema(rdd, schema) + >>> results = srdd.map( + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time, ... x.map["a"], x.struct.b, x.list, x.null)) - >>> srdd.collect()[0] - (127, -32768, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + >>> results.collect()[0] + (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + + >>> srdd.registerTempTable("table2") + >>> sqlCtx.sql( + ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + + ... "float + 1.1 as float FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1)] >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2ba68cab115fb..c18d7858f0a43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -48,6 +48,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool Batch("Resolution", fixedPoint, ResolveReferences :: ResolveRelations :: + ResolveSortReferences :: NewRelationInstances :: ImplicitGenerate :: StarExpansion :: @@ -113,13 +114,58 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = q.resolve(name).getOrElse(u) + val result = q.resolveChildren(name).getOrElse(u) logDebug(s"Resolving $u to $result") result } } } + /** + * In many dialects of SQL is it valid to sort by attributes that are not present in the SELECT + * clause. This rule detects such queries and adds the required attributes to the original + * projection, so that they will be available during sorting. Another projection is added to + * remove these attributes after sorting. + */ + object ResolveSortReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved => + val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + val resolved = unresolved.flatMap(child.resolveChildren) + val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet + + val missingInProject = requiredAttributes -- p.output + if (missingInProject.nonEmpty) { + // Add missing attributes and then project them away after the sort. + Project(projectList, + Sort(ordering, + Project(projectList ++ missingInProject, child))) + } else { + s // Nothing we can do here. Return original plan. 
+ } + case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved => + val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + // A small hack to create an object that will allow us to resolve any references that + // refer to named expressions that are present in the grouping expressions. + val groupingRelation = LocalRelation( + grouping.collect { case ne: NamedExpression => ne.toAttribute } + ) + + logDebug(s"Grouping expressions: $groupingRelation") + val resolved = unresolved.flatMap(groupingRelation.resolve).toSet + val missingInAggs = resolved -- a.outputSet + logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") + if (missingInAggs.nonEmpty) { + // Add missing grouping exprs and then project them away after the sort. + Project(a.output, + Sort(ordering, + Aggregate(grouping, aggs ++ missingInAggs, child))) + } else { + s // Nothing we can do here. Return original plan. + } + } + } + /** * Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 888cb08e95f06..278569f0cb14a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -72,16 +72,29 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { def childrenResolved: Boolean = !children.exists(!_.resolved) /** - * Optionally resolves the given string to a [[NamedExpression]]. The attribute is expressed as + * Optionally resolves the given string to a [[NamedExpression]] using the input from all child + * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ - def resolve(name: String): Option[NamedExpression] = { + def resolveChildren(name: String): Option[NamedExpression] = + resolve(name, children.flatMap(_.output)) + + /** + * Optionally resolves the given string to a [[NamedExpression]] based on the output of this + * LogicalPlan. The attribute is expressed as string in the following form: + * `[scope].AttributeName.[nested].[fields]...`. + */ + def resolve(name: String): Option[NamedExpression] = + resolve(name, output) + + /** Performs attribute resolution given a name and a sequence of possible attributes. */ + protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = { val parts = name.split("\\.") // Collect all attributes that are output by this nodes children where either the first part // matches the name or where the first part matches the scope and the second part matches the // name. Return these matches along with any remaining parts, which represent dotted access to // struct fields. - val options = children.flatMap(_.output).flatMap { option => + val options = input.flatMap { option => // If the first part of the desired name matches a qualifier for this possible match, drop it. val remainingParts = if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts @@ -89,15 +102,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { } options.distinct match { - case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it. + case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 888cb08e95f06..278569f0cb14a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -72,16 +72,29 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
   def childrenResolved: Boolean = !children.exists(!_.resolved)

   /**
-   * Optionally resolves the given string to a [[NamedExpression]]. The attribute is expressed as
+   * Optionally resolves the given string to a [[NamedExpression]] using the input from all child
+   * nodes of this LogicalPlan. The attribute is expressed as
    * as string in the following form: `[scope].AttributeName.[nested].[fields]...`.
    */
-  def resolve(name: String): Option[NamedExpression] = {
+  def resolveChildren(name: String): Option[NamedExpression] =
+    resolve(name, children.flatMap(_.output))
+
+  /**
+   * Optionally resolves the given string to a [[NamedExpression]] based on the output of this
+   * LogicalPlan. The attribute is expressed as string in the following form:
+   * `[scope].AttributeName.[nested].[fields]...`.
+   */
+  def resolve(name: String): Option[NamedExpression] =
+    resolve(name, output)
+
+  /** Performs attribute resolution given a name and a sequence of possible attributes. */
+  protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = {
     val parts = name.split("\\.")
     // Collect all attributes that are output by this nodes children where either the first part
     // matches the name or where the first part matches the scope and the second part matches the
     // name. Return these matches along with any remaining parts, which represent dotted access to
     // struct fields.
-    val options = children.flatMap(_.output).flatMap { option =>
+    val options = input.flatMap { option =>
       // If the first part of the desired name matches a qualifier for this possible match, drop it.
       val remainingParts =
         if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
@@ -89,15 +102,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
     }

     options.distinct match {
-      case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it.
+      case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it.
       // One match, but we also need to extract the requested nested field.
-      case (a, nestedFields) :: Nil =>
+      case Seq((a, nestedFields)) =>
         a.dataType match {
           case StructType(fields) =>
             Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
           case _ => None // Don't know how to resolve these field references
         }
-      case Nil => None // No matches.
+      case Seq() => None // No matches.
       case ambiguousReferences =>
         throw new TreeNodeException(
           this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
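Reviewer note: the qualifier handling in resolve() is worth a small standalone model. The helper below is not Catalyst code, only a simplified sketch of how a leading qualifier is peeled off a dotted name before the remaining parts are treated as nested struct fields.

    // Simplified, standalone model of the name splitting used by LogicalPlan.resolve.
    // `qualifiers` stands in for the candidate attribute's qualifiers; all names are illustrative.
    def splitName(name: String, qualifiers: Set[String]): (Seq[String], Seq[String]) = {
      val parts = name.split("\\.").toSeq
      // Drop the leading part when it merely names the scope, e.g. "t.value" -> "value".
      val remaining =
        if (parts.size > 1 && qualifiers.contains(parts.head)) parts.drop(1) else parts
      // The first remaining part should match the attribute; the rest are nested field accesses.
      (remaining.take(1), remaining.drop(1))
    }

    // e.g. splitName("t.person.address.city", Set("t")) == (Seq("person"), Seq("address", "city"))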
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 40bfd55e95a12..0fd7aaaa36eb8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -17,17 +17,17 @@ package org.apache.spark.sql

+import scala.collection.immutable
+import scala.collection.JavaConversions._
+
 import java.util.Properties

-import scala.collection.JavaConverters._

-object SQLConf {
+private[spark] object SQLConf {
   val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed"
   val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
   val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
-  val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size"
   val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
-  val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables"
   val CODEGEN_ENABLED = "spark.sql.codegen"
   val DIALECT = "spark.sql.dialect"

@@ -66,13 +66,13 @@ trait SQLConf {
    * Note that the choice of dialect does not affect things like what tables are available or
    * how query execution is performed.
    */
-  private[spark] def dialect: String = get(DIALECT, "sql")
+  private[spark] def dialect: String = getConf(DIALECT, "sql")

   /** When true tables cached using the in-memory columnar caching will be compressed. */
-  private[spark] def useCompression: Boolean = get(COMPRESS_CACHED, "false").toBoolean
+  private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean

   /** Number of partitions to use for shuffle operators. */
-  private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt
+  private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt

   /**
    * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
@@ -84,7 +84,7 @@ trait SQLConf {
    * Defaults to false as this feature is currently experimental.
    */
   private[spark] def codegenEnabled: Boolean =
-    if (get(CODEGEN_ENABLED, "false") == "true") true else false
+    if (getConf(CODEGEN_ENABLED, "false") == "true") true else false

   /**
    * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
@@ -94,7 +94,7 @@ trait SQLConf {
    * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is also 10000.
    */
   private[spark] def autoBroadcastJoinThreshold: Int =
-    get(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt
+    getConf(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt

   /**
    * The default size in bytes to assign to a logical operator's estimation statistics. By default,
@@ -102,41 +102,40 @@ trait SQLConf {
    * properly implemented estimation of this statistic will not be incorrectly broadcasted in joins.
    */
   private[spark] def defaultSizeInBytes: Long =
-    getOption(DEFAULT_SIZE_IN_BYTES).map(_.toLong).getOrElse(autoBroadcastJoinThreshold + 1)
+    getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong

   /** ********************** SQLConf functionality methods ************ */

-  def set(props: Properties): Unit = {
-    settings.synchronized {
-      props.asScala.foreach { case (k, v) => settings.put(k, v) }
-    }
+  /** Set Spark SQL configuration properties. */
+  def setConf(props: Properties): Unit = settings.synchronized {
+    props.foreach { case (k, v) => settings.put(k, v) }
   }

-  def set(key: String, value: String): Unit = {
+  /** Set the given Spark SQL configuration property. */
+  def setConf(key: String, value: String): Unit = {
     require(key != null, "key cannot be null")
     require(value != null, s"value cannot be null for key: $key")
     settings.put(key, value)
   }

-  def get(key: String): String = {
+  /** Return the value of Spark SQL configuration property for the given key. */
+  def getConf(key: String): String = {
     Option(settings.get(key)).getOrElse(throw new NoSuchElementException(key))
   }

-  def get(key: String, defaultValue: String): String = {
+  /**
+   * Return the value of Spark SQL configuration property for the given key. If the key is not set
+   * yet, return `defaultValue`.
+   */
+  def getConf(key: String, defaultValue: String): String = {
     Option(settings.get(key)).getOrElse(defaultValue)
   }

-  def getAll: Array[(String, String)] = settings.synchronized { settings.asScala.toArray }
-
-  def getOption(key: String): Option[String] = Option(settings.get(key))
-
-  def contains(key: String): Boolean = settings.containsKey(key)
-
-  def toDebugString: String = {
-    settings.synchronized {
-      settings.asScala.toArray.sorted.map{ case (k, v) => s"$k=$v" }.mkString("\n")
-    }
-  }
+  /**
+   * Return all the configuration properties that have been set (i.e. not the default).
+   * This creates a new copy of the config properties in the form of a Map.
+   */
+  def getAllConfs: immutable.Map[String, String] = settings.synchronized { settings.toMap }

   private[spark] def clear() {
     settings.clear()
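Reviewer note: a quick usage sketch of the renamed configuration API, assuming an already constructed SQLContext called sqlContext; the property keys are the string constants defined above.

    // Illustrative only: the renamed SQLConf methods as seen from an existing SQLContext.
    sqlContext.setConf("spark.sql.shuffle.partitions", "10")
    assert(sqlContext.getConf("spark.sql.shuffle.partitions") == "10")

    // The two-argument form falls back to the default when the key has not been set.
    val codegen = sqlContext.getConf("spark.sql.codegen", "false")

    // getAllConfs returns an immutable snapshot of the explicitly set properties.
    sqlContext.getAllConfs.foreach { case (k, v) => println(s"$k=$v") }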
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ecd5fbaa0b094..71d338d21d0f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -491,7 +491,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
         new java.sql.Timestamp(c.getTime().getTime())

       case (c: Int, ByteType) => c.toByte
+      case (c: Long, ByteType) => c.toByte
       case (c: Int, ShortType) => c.toShort
+      case (c: Long, ShortType) => c.toShort
+      case (c: Long, IntegerType) => c.toInt
       case (c: Double, FloatType) => c.toFloat
       case (c, StringType) if !c.isInstanceOf[String] => c.toString

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index c416a745739b3..7e7bb2859bbcd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -118,7 +118,7 @@ private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY)
 private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC)

 private[sql] object ColumnBuilder {
-  val DEFAULT_INITIAL_BUFFER_SIZE = 10 * 1024 * 104
+  val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024

   private[columnar] def ensureFreeSpace(orig: ByteBuffer, size: Int) = {
     if (orig.remaining >= size) {
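Reviewer note: the old initial buffer size looks like a typo rather than a deliberate choice; the arithmetic below (purely illustrative) shows the difference.

    // 10 * 1024 * 104 = 1,064,960 bytes, an odd and presumably unintended value.
    // 1024 * 1024     = 1,048,576 bytes, a clean 1 MiB initial buffer.
    val oldSize = 10 * 1024 * 104   // 1064960
    val newSize = 1024 * 1024       // 1048576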
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index d008806eedbe1..f631ee76fcd78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -36,9 +36,9 @@ import org.apache.spark.sql.Row
  * }}}
  */
 private[sql] trait NullableColumnBuilder extends ColumnBuilder {
-  private var nulls: ByteBuffer = _
+  protected var nulls: ByteBuffer = _
+  protected var nullCount: Int = _
   private var pos: Int = _
-  private var nullCount: Int = _

   abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) {
     nulls = ByteBuffer.allocate(1024)
@@ -78,4 +78,9 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
     buffer.rewind()
     buffer
   }
+
+  protected def buildNonNulls(): ByteBuffer = {
+    nulls.limit(nulls.position()).rewind()
+    super.build()
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
index 6ad12a0dcb64d..a5826bb033e41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
@@ -46,8 +46,6 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]

   this: NativeColumnBuilder[T] with WithCompressionSchemes =>

-  import CompressionScheme._
-
   var compressionEncoders: Seq[Encoder[T]] = _

   abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) {
@@ -81,28 +79,32 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
     }
   }

-  abstract override def build() = {
-    val rawBuffer = super.build()
+  override def build() = {
+    val nonNullBuffer = buildNonNulls()
+    val typeId = nonNullBuffer.getInt()
     val encoder: Encoder[T] = {
       val candidate = compressionEncoders.minBy(_.compressionRatio)
       if (isWorthCompressing(candidate)) candidate else PassThrough.encoder
     }

-    val headerSize = columnHeaderSize(rawBuffer)
+    // Header = column type ID + null count + null positions
+    val headerSize = 4 + 4 + nulls.limit()
     val compressedSize = if (encoder.compressedSize == 0) {
-      rawBuffer.limit - headerSize
+      nonNullBuffer.remaining()
     } else {
       encoder.compressedSize
     }

-    // Reserves 4 bytes for compression scheme ID
     val compressedBuffer = ByteBuffer
+      // Reserves 4 bytes for compression scheme ID
       .allocate(headerSize + 4 + compressedSize)
       .order(ByteOrder.nativeOrder)
-
-    copyColumnHeader(rawBuffer, compressedBuffer)
+      // Write the header
+      .putInt(typeId)
+      .putInt(nullCount)
+      .put(nulls)

     logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
-    encoder.compress(rawBuffer, compressedBuffer, columnType)
+    encoder.compress(nonNullBuffer, compressedBuffer, columnType)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
index ba1810dd2ae66..7797f75177893 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -67,22 +67,6 @@ private[sql] object CompressionScheme {
         s"Unrecognized compression scheme type ID: $typeId"))
   }

-  def copyColumnHeader(from: ByteBuffer, to: ByteBuffer) {
-    // Writes column type ID
-    to.putInt(from.getInt())
-
-    // Writes null count
-    val nullCount = from.getInt()
-    to.putInt(nullCount)
-
-    // Writes null positions
-    var i = 0
-    while (i < nullCount) {
-      to.putInt(from.getInt())
-      i += 1
-    }
-  }
-
   def columnHeaderSize(columnBuffer: ByteBuffer): Int = {
     val header = columnBuffer.duplicate().order(ByteOrder.nativeOrder)
     val nullCount = header.getInt(4)
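Reviewer note: the new build() makes the layout of a compressed column buffer explicit: column type ID, null count, null positions, compression scheme ID, then the payload. The reader below is only a sketch of that layout written against plain java.nio; it is not the accessor code used by Spark SQL.

    import java.nio.{ByteBuffer, ByteOrder}

    // Sketch of walking the header produced by CompressibleColumnBuilder.build().
    // Layout (ints are 4 bytes, native order):
    //   [column type ID][null count][null positions...][compression scheme ID][column data]
    def describeColumnBuffer(buffer: ByteBuffer): Unit = {
      val buf = buffer.duplicate().order(ByteOrder.nativeOrder)
      val typeId = buf.getInt()
      val nullCount = buf.getInt()
      val nullPositions = Array.fill(nullCount)(buf.getInt())
      val schemeId = buf.getInt()
      println(s"type=$typeId nulls=$nullCount at [${nullPositions.mkString(",")}] scheme=$schemeId")
      // buf now points at the start of the (possibly compressed) non-null data.
    }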
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 9293239131d52..38f37564f1788 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -53,10 +53,10 @@ case class SetCommand(
       if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
         logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
           s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.")
-        context.set(SQLConf.SHUFFLE_PARTITIONS, v)
+        context.setConf(SQLConf.SHUFFLE_PARTITIONS, v)
         Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")
       } else {
-        context.set(k, v)
+        context.setConf(k, v)
         Array(s"$k=$v")
       }

@@ -77,14 +77,14 @@ case class SetCommand(
           "system:sun.java.command=shark.SharkServer2")
       } else {
-        Array(s"$k=${context.getOption(k).getOrElse("")}")
+        Array(s"$k=${context.getConf(k, "")}")
       }

     // Query all key-value pairs that are set in the SQLConf of the context.
     case (None, None) =>
-      context.getAll.map { case (k, v) =>
+      context.getAllConfs.map { case (k, v) =>
         s"$k=$v"
-      }
+      }.toSeq

     case _ =>
       throw new IllegalArgumentException()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 1a58d73d9e7f4..584f71b3c13d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -29,21 +29,18 @@ class SQLConfSuite extends QueryTest {

   test("programmatic ways of basic setting and getting") {
     clear()
-    assert(getOption(testKey).isEmpty)
-    assert(getAll.toSet === Set())
+    assert(getAllConfs.size === 0)

-    set(testKey, testVal)
-    assert(get(testKey) == testVal)
-    assert(get(testKey, testVal + "_") == testVal)
-    assert(getOption(testKey) == Some(testVal))
-    assert(contains(testKey))
+    setConf(testKey, testVal)
+    assert(getConf(testKey) == testVal)
+    assert(getConf(testKey, testVal + "_") == testVal)
+    assert(getAllConfs.contains(testKey))

     // Tests SQLConf as accessed from a SQLContext is mutable after
     // the latter is initialized, unlike SparkConf inside a SparkContext.
-    assert(TestSQLContext.get(testKey) == testVal)
-    assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
-    assert(TestSQLContext.getOption(testKey) == Some(testVal))
-    assert(TestSQLContext.contains(testKey))
+    assert(TestSQLContext.getConf(testKey) == testVal)
+    assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal)
+    assert(TestSQLContext.getAllConfs.contains(testKey))

     clear()
   }
@@ -51,21 +48,21 @@ class SQLConfSuite extends QueryTest {
   test("parse SQL set commands") {
     clear()
     sql(s"set $testKey=$testVal")
-    assert(get(testKey, testVal + "_") == testVal)
-    assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
+    assert(getConf(testKey, testVal + "_") == testVal)
+    assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal)

     sql("set some.property=20")
-    assert(get("some.property", "0") == "20")
+    assert(getConf("some.property", "0") == "20")
     sql("set some.property = 40")
-    assert(get("some.property", "0") == "40")
+    assert(getConf("some.property", "0") == "40")

     val key = "spark.sql.key"
     val vs = "val0,val_1,val2.3,my_table"
     sql(s"set $key=$vs")
-    assert(get(key, "0") == vs)
+    assert(getConf(key, "0") == vs)

     sql(s"set $key=")
-    assert(get(key, "0") == "")
+    assert(getConf(key, "0") == "")

     clear()
   }
@@ -73,6 +70,6 @@ class SQLConfSuite extends QueryTest {
   test("deprecated property") {
     clear()
     sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
-    assert(get(SQLConf.SHUFFLE_PARTITIONS) == "10")
+    assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10")
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
index 6d688ea95cfc0..72c19fa31d980 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -42,4 +42,3 @@ object TestCompressibleColumnBuilder {
     builder
   }
 }
-
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 7fac90fdc596d..c6f60c18804a4 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -29,7 +29,7 @@
   org.apache.spark
   spark-hive-thriftserver_2.10
   jar
-  Spark Project Hive
+  Spark Project Hive Thrift Server
   http://spark.apache.org/
   hive-thriftserver
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index d8e7a5943daa5..53f3dc11dbb9f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -60,9 +60,9 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) {

   /** Sets up the system initially or after a RESET command */
   protected def configure() {
-    set("javax.jdo.option.ConnectionURL",
+    setConf("javax.jdo.option.ConnectionURL",
       s"jdbc:derby:;databaseName=$metastorePath;create=true")
-    set("hive.metastore.warehouse.dir", warehousePath)
+    setConf("hive.metastore.warehouse.dir", warehousePath)
   }

   configure() // Must be called before initializing the catalog below.
@@ -76,7 +76,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   self =>

   // Change the default SQL dialect to HiveQL
-  override private[spark] def dialect: String = get(SQLConf.DIALECT, "hiveql")
+  override private[spark] def dialect: String = getConf(SQLConf.DIALECT, "hiveql")

   override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
     new this.QueryExecution { val logical = plan }
@@ -224,15 +224,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState])
   @transient protected[hive] lazy val sessionState = {
     val ss = new SessionState(hiveconf)
-    set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf.
+    setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf.
     ss
   }

   sessionState.err = new PrintStream(outputBuffer, true, "UTF-8")
   sessionState.out = new PrintStream(outputBuffer, true, "UTF-8")

-  override def set(key: String, value: String): Unit = {
-    super.set(key, value)
+  override def setConf(key: String, value: String): Unit = {
+    super.setConf(key, value)
     runSqlHive(s"SET $key=$value")
   }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index c605e8adcfb0f..d890df866fbe5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -65,9 +65,9 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {

   /** Sets up the system initially or after a RESET command */
   protected def configure() {
-    set("javax.jdo.option.ConnectionURL",
+    setConf("javax.jdo.option.ConnectionURL",
       s"jdbc:derby:;databaseName=$metastorePath;create=true")
-    set("hive.metastore.warehouse.dir", warehousePath)
+    setConf("hive.metastore.warehouse.dir", warehousePath)
   }

   configure() // Must be called before initializing the catalog below.
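Reviewer note: because HiveContext overrides setConf to also run a Hive SET, changes made programmatically or via SQL end up in the same place. A hedged usage sketch, assuming an existing HiveContext named hiveContext; the property names are only examples.

    // Illustrative only: both paths land in SQLConf, and on HiveContext the overridden
    // setConf additionally forwards "SET key=value" to Hive via runSqlHive.
    hiveContext.sql("SET spark.sql.shuffle.partitions=10")
    assert(hiveContext.getConf("spark.sql.shuffle.partitions", "200") == "10")

    hiveContext.setConf("hive.exec.dynamic.partition", "true") // also issues a Hive SET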
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 2f0be49b6a6d7..fdb2f41f5a5b6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -75,9 +75,9 @@ class HiveQuerySuite extends HiveComparisonTest {
     "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1")

   test("Query expressed in SQL") {
-    set("spark.sql.dialect", "sql")
+    setConf("spark.sql.dialect", "sql")
     assert(sql("SELECT 1").collect() === Array(Seq(1)))
-    set("spark.sql.dialect", "hiveql")
+    setConf("spark.sql.dialect", "hiveql")
   }

@@ -436,18 +436,18 @@ class HiveQuerySuite extends HiveComparisonTest {
     val testVal = "val0,val_1,val2.3,my_table"

     sql(s"set $testKey=$testVal")
-    assert(get(testKey, testVal + "_") == testVal)
+    assert(getConf(testKey, testVal + "_") == testVal)

     sql("set some.property=20")
-    assert(get("some.property", "0") == "20")
+    assert(getConf("some.property", "0") == "20")
     sql("set some.property = 40")
-    assert(get("some.property", "0") == "40")
+    assert(getConf("some.property", "0") == "40")

     sql(s"set $testKey=$testVal")
-    assert(get(testKey, "0") == testVal)
+    assert(getConf(testKey, "0") == testVal)

     sql(s"set $testKey=")
-    assert(get(testKey, "0") == "")
+    assert(getConf(testKey, "0") == "")
   }

   test("SET commands semantics for a HiveContext") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
new file mode 100644
index 0000000000000..635a9fb0d56cb
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.{SQLConf, QueryTest}
+import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin}
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+
+/**
+ * A collection of hive query tests where we generate the answers ourselves instead of depending on
+ * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is
+ * valid, but Hive currently cannot execute it.
+ */
+class SQLQuerySuite extends QueryTest {
+  test("ordering not in select") {
+    checkAnswer(
+      sql("SELECT key FROM src ORDER BY value"),
+      sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq)
+  }
+
+  test("ordering not in agg") {
+    checkAnswer(
+      sql("SELECT key FROM src GROUP BY key, value ORDER BY value"),
+      sql("""
+        SELECT key
+        FROM (
+          SELECT key, value
+          FROM src
+          GROUP BY key, value
+          ORDER BY value) a""").collect().toSeq)
+  }
+}
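Reviewer note: one further variation the new rule should also cover, sketched only as a hypothetical follow-up test (it is not part of this patch): ordering by several columns that the SELECT clause drops, including a DESC direction.

    // Hypothetical follow-up test in the same checkAnswer style, not included in the patch.
    test("ordering not in select, multiple sort columns") {
      checkAnswer(
        sql("SELECT key FROM src ORDER BY value DESC, key"),
        sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value DESC, key) a")
          .collect().toSeq)
    }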