diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index af4456c05b0a1..b153a7b08e590 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
val key = if (!isLocal && scheme == "file") {
- env.httpFileServer.addFile(new File(uri.getPath))
+ env.rpcEnv.fileServer.addFile(new File(uri.getPath))
} else {
schemeCorrectedPath
}
@@ -1630,7 +1630,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
var key = ""
if (path.contains("\\")) {
// For local paths with backslashes on Windows, URI throws an exception
- key = env.httpFileServer.addJar(new File(path))
+ key = env.rpcEnv.fileServer.addJar(new File(path))
} else {
val uri = new URI(path)
key = uri.getScheme match {
@@ -1644,7 +1644,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// of the AM to make it show up in the current working directory.
val fileName = new Path(uri.getPath).getName()
try {
- env.httpFileServer.addJar(new File(fileName))
+ env.rpcEnv.fileServer.addJar(new File(fileName))
} catch {
case e: Exception =>
// For now just log an error but allow to go through so spark examples work.
@@ -1655,7 +1655,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
} else {
try {
- env.httpFileServer.addJar(new File(uri.getPath))
+ env.rpcEnv.fileServer.addJar(new File(uri.getPath))
} catch {
case exc: FileNotFoundException =>
logError(s"Jar not found at $path")
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 88df27f733f2a..84230e32a4462 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -66,7 +66,6 @@ class SparkEnv (
val blockTransferService: BlockTransferService,
val blockManager: BlockManager,
val securityManager: SecurityManager,
- val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val memoryManager: MemoryManager,
@@ -91,7 +90,6 @@ class SparkEnv (
if (!isStopped) {
isStopped = true
pythonWorkers.values.foreach(_.stop())
- Option(httpFileServer).foreach(_.stop())
mapOutputTracker.stop()
shuffleManager.stop()
broadcastManager.stop()
@@ -367,17 +365,6 @@ object SparkEnv extends Logging {
val cacheManager = new CacheManager(blockManager)
- val httpFileServer =
- if (isDriver) {
- val fileServerPort = conf.getInt("spark.fileserver.port", 0)
- val server = new HttpFileServer(conf, securityManager, fileServerPort)
- server.initialize()
- conf.set("spark.fileserver.uri", server.serverUri)
- server
- } else {
- null
- }
-
val metricsSystem = if (isDriver) {
// Don't start metrics system right now for Driver.
// We need to wait for the task scheduler to give us an app ID.
@@ -422,7 +409,6 @@ object SparkEnv extends Logging {
blockTransferService,
blockManager,
securityManager,
- httpFileServer,
sparkFilesDir,
metricsSystem,
memoryManager,
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index a560fd10cdf76..3d7d281b0dd66 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -17,6 +17,9 @@
package org.apache.spark.rpc
+import java.io.File
+import java.nio.channels.ReadableByteChannel
+
import scala.concurrent.Future
import org.apache.spark.{SecurityManager, SparkConf}
@@ -132,8 +135,51 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
* that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method.
*/
def deserialize[T](deserializationAction: () => T): T
+
+ /**
+ * Return the instance of the file server used to serve files. This may be `null` if the
+ * RpcEnv is not operating in server mode.
+ */
+ def fileServer: RpcEnvFileServer
+
+ /**
+ * Open a channel to download a file from the given URI. If the URIs returned by the
+ * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to
+ * retrieve the files.
+ *
+ * @param uri URI with location of the file.
+ */
+ def openChannel(uri: String): ReadableByteChannel
+
}
+/**
+ * A server used by the RpcEnv to serve files to other processes owned by the application.
+ *
+ * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or
+ * it can return "spark" URIs which will be handled by `RpcEnv#openChannel`.
+ */
+private[spark] trait RpcEnvFileServer {
+
+ /**
+ * Adds a file to be served by this RpcEnv. This is used to serve files from the driver
+ * to executors when they're stored on the driver's local file system.
+ *
+ * @param file Local file to serve.
+ * @return A URI for the location of the file.
+ */
+ def addFile(file: File): String
+
+ /**
+ * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using
+ * `SparkContext.addJar`.
+ *
+ * @param file Local file to serve.
+ * @return A URI for the location of the file.
+ */
+ def addJar(file: File): String
+
+}
private[spark] case class RpcEnvConfig(
conf: SparkConf,
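
For context, a minimal driver-side sketch of the contract above, mirroring what `SparkContext.addFile` now does. It assumes code with access to the (package-private) `SparkEnv.rpcEnv`, and the local path is illustrative:

```scala
import java.io.File
import org.apache.spark.SparkEnv

// Sketch only: register a driver-local file with the RpcEnv's file server and hand the
// returned URI to executors, which later fetch it with Utils.fetchFile. Depending on the
// backend the URI is "http://..." (Akka) or "spark://host:port/files/..." (Netty).
val rpcEnv = SparkEnv.get.rpcEnv
val servedUri: String = rpcEnv.fileServer.addFile(new File("/tmp/app.conf"))
println(s"Serving /tmp/app.conf at $servedUri")
```
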
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 059a7e10ec12f..94dbec593c315 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rpc.akka
+import java.io.File
+import java.nio.channels.ReadableByteChannel
import java.util.concurrent.ConcurrentHashMap
import scala.concurrent.Future
@@ -30,7 +32,7 @@ import akka.pattern.{ask => akkaAsk}
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
import akka.serialization.JavaSerializer
-import org.apache.spark.{SparkException, Logging, SparkConf}
+import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.rpc._
import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
@@ -41,7 +43,10 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}
* remove Akka from the dependencies.
*/
private[spark] class AkkaRpcEnv private[akka] (
- val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int)
+ val actorSystem: ActorSystem,
+ val securityManager: SecurityManager,
+ conf: SparkConf,
+ boundPort: Int)
extends RpcEnv(conf) with Logging {
private val defaultAddress: RpcAddress = {
@@ -64,6 +69,8 @@ private[spark] class AkkaRpcEnv private[akka] (
*/
private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]()
+ private val _fileServer = new AkkaFileServer(conf, securityManager)
+
private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = {
endpointToRef.put(endpoint, endpointRef)
refToEndpoint.put(endpointRef, endpoint)
@@ -223,6 +230,7 @@ private[spark] class AkkaRpcEnv private[akka] (
override def shutdown(): Unit = {
actorSystem.shutdown()
+ _fileServer.shutdown()
}
override def stop(endpoint: RpcEndpointRef): Unit = {
@@ -241,6 +249,52 @@ private[spark] class AkkaRpcEnv private[akka] (
deserializationAction()
}
}
+
+ override def openChannel(uri: String): ReadableByteChannel = {
+ throw new UnsupportedOperationException(
+ "AkkaRpcEnv's files should be retrieved using an HTTP client.")
+ }
+
+ override def fileServer: RpcEnvFileServer = _fileServer
+
+}
+
+private[akka] class AkkaFileServer(
+ conf: SparkConf,
+ securityManager: SecurityManager) extends RpcEnvFileServer {
+
+ @volatile private var httpFileServer: HttpFileServer = _
+
+ override def addFile(file: File): String = {
+ getFileServer().addFile(file)
+ }
+
+ override def addJar(file: File): String = {
+ getFileServer().addJar(file)
+ }
+
+ def shutdown(): Unit = {
+ if (httpFileServer != null) {
+ httpFileServer.stop()
+ }
+ }
+
+ private def getFileServer(): HttpFileServer = {
+ if (httpFileServer == null) synchronized {
+ if (httpFileServer == null) {
+ httpFileServer = startFileServer()
+ }
+ }
+ httpFileServer
+ }
+
+ private def startFileServer(): HttpFileServer = {
+ val fileServerPort = conf.getInt("spark.fileserver.port", 0)
+ val server = new HttpFileServer(conf, securityManager, fileServerPort)
+ server.initialize()
+ server
+ }
+
}
private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
@@ -249,7 +303,7 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
config.name, config.host, config.port, config.conf, config.securityManager)
actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor")
- new AkkaRpcEnv(actorSystem, config.conf, boundPort)
+ new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort)
}
}
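
The `AkkaFileServer` above starts the Jetty-based `HttpFileServer` lazily. A standalone sketch of that double-checked locking pattern, with hypothetical names not taken from the patch:

```scala
// Illustrative sketch of the lazy-initialization pattern used by AkkaFileServer: the
// @volatile write publishes the fully constructed server to other threads, and the
// second null check inside synchronized avoids starting it twice.
class LazyServer(start: () => AutoCloseable) {
  @volatile private var server: AutoCloseable = _

  def get(): AutoCloseable = {
    if (server == null) synchronized {
      if (server == null) {
        server = start()
      }
    }
    server
  }

  def shutdown(): Unit = {
    if (server != null) {
      server.close()
    }
  }
}
```
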
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 3ce359868039b..68701f609f77a 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -20,6 +20,7 @@ import java.io._
import java.lang.{Boolean => JBoolean}
import java.net.{InetSocketAddress, URI}
import java.nio.ByteBuffer
+import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.Nullable
@@ -45,27 +46,39 @@ private[netty] class NettyRpcEnv(
host: String,
securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
- private val transportConf = SparkTransportConf.fromSparkConf(
+ private[netty] val transportConf = SparkTransportConf.fromSparkConf(
conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
"rpc",
conf.getInt("spark.rpc.io.threads", 0))
private val dispatcher: Dispatcher = new Dispatcher(this)
+ private val streamManager = new NettyStreamManager(this)
+
private val transportContext = new TransportContext(transportConf,
- new NettyRpcHandler(dispatcher, this))
+ new NettyRpcHandler(dispatcher, this, streamManager))
- private val clientFactory = {
- val bootstraps: java.util.List[TransportClientBootstrap] =
- if (securityManager.isAuthenticationEnabled()) {
- java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
- securityManager.isSaslEncryptionEnabled()))
- } else {
- java.util.Collections.emptyList[TransportClientBootstrap]
- }
- transportContext.createClientFactory(bootstraps)
+ private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = {
+ if (securityManager.isAuthenticationEnabled()) {
+ java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
+ securityManager.isSaslEncryptionEnabled()))
+ } else {
+ java.util.Collections.emptyList[TransportClientBootstrap]
+ }
}
+ private val clientFactory = transportContext.createClientFactory(createClientBootstraps())
+
+ /**
+ * A separate client factory for file downloads. This avoids using the same RPC handler as
+ * the main RPC context, so that events caused by these clients are kept isolated from the
+ * main RPC traffic.
+ *
+ * It also allows for different configuration of certain properties, such as the number of
+ * connections per peer.
+ */
+ @volatile private var fileDownloadFactory: TransportClientFactory = _
+
val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
// Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
@@ -292,6 +305,9 @@ private[netty] class NettyRpcEnv(
if (clientConnectionExecutor != null) {
clientConnectionExecutor.shutdownNow()
}
+ if (fileDownloadFactory != null) {
+ fileDownloadFactory.close()
+ }
}
override def deserialize[T](deserializationAction: () => T): T = {
@@ -300,6 +316,96 @@ private[netty] class NettyRpcEnv(
}
}
+ override def fileServer: RpcEnvFileServer = streamManager
+
+ override def openChannel(uri: String): ReadableByteChannel = {
+ val parsedUri = new URI(uri)
+ require(parsedUri.getHost() != null, "Host name must be defined.")
+ require(parsedUri.getPort() > 0, "Port must be defined.")
+ require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.")
+
+ val pipe = Pipe.open()
+ val source = new FileDownloadChannel(pipe.source())
+ try {
+ val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
+ val callback = new FileDownloadCallback(pipe.sink(), source, client)
+ client.stream(parsedUri.getPath(), callback)
+ } catch {
+ case e: Exception =>
+ pipe.sink().close()
+ source.close()
+ throw e
+ }
+
+ source
+ }
+
+ private def downloadClient(host: String, port: Int): TransportClient = {
+ if (fileDownloadFactory == null) synchronized {
+ if (fileDownloadFactory == null) {
+ val module = "files"
+ val prefix = "spark.rpc.io."
+ val clone = conf.clone()
+
+ // Copy any RPC configuration that is not overridden in the spark.files namespace.
+ conf.getAll.foreach { case (key, value) =>
+ if (key.startsWith(prefix)) {
+ val opt = key.substring(prefix.length())
+ clone.setIfMissing(s"spark.$module.io.$opt", value)
+ }
+ }
+
+ val ioThreads = clone.getInt("spark.files.io.threads", 1)
+ val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads)
+ val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true)
+ fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps())
+ }
+ }
+ fileDownloadFactory.createClient(host, port)
+ }
+
+ private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
+
+ @volatile private var error: Throwable = _
+
+ def setError(e: Throwable): Unit = error = e
+
+ override def read(dst: ByteBuffer): Int = {
+ if (error != null) {
+ throw error
+ }
+ source.read(dst)
+ }
+
+ override def close(): Unit = source.close()
+
+ override def isOpen(): Boolean = source.isOpen()
+
+ }
+
+ private class FileDownloadCallback(
+ sink: WritableByteChannel,
+ source: FileDownloadChannel,
+ client: TransportClient) extends StreamCallback {
+
+ override def onData(streamId: String, buf: ByteBuffer): Unit = {
+ while (buf.remaining() > 0) {
+ sink.write(buf)
+ }
+ }
+
+ override def onComplete(streamId: String): Unit = {
+ sink.close()
+ }
+
+ override def onFailure(streamId: String, cause: Throwable): Unit = {
+ logError(s"Error downloading stream $streamId.", cause)
+ source.setError(cause)
+ sink.close()
+ }
+
+ }
+
}
private[netty] object NettyRpcEnv extends Logging {
@@ -420,7 +526,7 @@ private[netty] class NettyRpcEndpointRef(
override def toString: String = s"NettyRpcEndpointRef(${_address})"
- def toURI: URI = new URI(s"spark://${_address}")
+ def toURI: URI = new URI(_address.toString)
final override def equals(that: Any): Boolean = that match {
case other: NettyRpcEndpointRef => _address == other._address
@@ -471,7 +577,9 @@ private[netty] case class RpcFailure(e: Throwable)
* with different `RpcAddress` information).
*/
private[netty] class NettyRpcHandler(
- dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
+ dispatcher: Dispatcher,
+ nettyEnv: NettyRpcEnv,
+ streamManager: StreamManager) extends RpcHandler with Logging {
// TODO: Can we add connection callback (channel registered) to the underlying framework?
// A variable to track whether we should dispatch the RemoteProcessConnected message.
@@ -498,7 +606,7 @@ private[netty] class NettyRpcHandler(
dispatcher.postRemoteMessage(messageToDispatch, callback)
}
- override def getStreamManager: StreamManager = new OneForOneStreamManager
+ override def getStreamManager: StreamManager = streamManager
override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
@@ -516,8 +624,8 @@ private[netty] class NettyRpcHandler(
override def connectionTerminated(client: TransportClient): Unit = {
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
- val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
clients.remove(client)
+ val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
nettyEnv.removeOutbox(clientAddr)
dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
} else {
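
The comment in `downloadClient` above describes how `spark.rpc.io.*` settings fall through to the `spark.files.io.*` namespace unless explicitly overridden. A small standalone sketch of that behavior; the `clientThreads` property name is only illustrative:

```scala
import org.apache.spark.SparkConf

// Copy any "spark.rpc.io.*" setting into "spark.files.io.*" unless already set there.
val conf = new SparkConf(false)
  .set("spark.rpc.io.clientThreads", "4")
  .set("spark.files.io.clientThreads", "8") // explicit override wins

val clone = conf.clone()
val prefix = "spark.rpc.io."
conf.getAll.foreach { case (key, value) =>
  if (key.startsWith(prefix)) {
    clone.setIfMissing("spark.files.io." + key.substring(prefix.length), value)
  }
}

assert(clone.get("spark.files.io.clientThreads") == "8")
```
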
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
new file mode 100644
index 0000000000000..eb1d2604fb235
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rpc.netty
+
+import java.io.File
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.server.StreamManager
+import org.apache.spark.rpc.RpcEnvFileServer
+
+/**
+ * StreamManager implementation for serving files from a NettyRpcEnv.
+ */
+private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
+ extends StreamManager with RpcEnvFileServer {
+
+ private val files = new ConcurrentHashMap[String, File]()
+ private val jars = new ConcurrentHashMap[String, File]()
+
+ override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = {
+ throw new UnsupportedOperationException()
+ }
+
+ override def openStream(streamId: String): ManagedBuffer = {
+ val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
+ val file = ftype match {
+ case "files" => files.get(fname)
+ case "jars" => jars.get(fname)
+ case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype")
+ }
+
+ require(file != null, s"File not found: $streamId")
+ new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length())
+ }
+
+ override def addFile(file: File): String = {
+ require(files.putIfAbsent(file.getName(), file) == null,
+ s"File ${file.getName()} already registered.")
+ s"${rpcEnv.address.toSparkURL}/files/${file.getName()}"
+ }
+
+ override def addJar(file: File): String = {
+ require(jars.putIfAbsent(file.getName(), file) == null,
+ s"JAR ${file.getName()} already registered.")
+ s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}"
+ }
+
+}
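
For reference, the URIs produced by `addFile`/`addJar` above look like `spark://<host>:<port>/files/<name>` (or `/jars/<name>`), and `openStream` receives the path part as the stream ID. A tiny sketch of that parsing; the file name is illustrative:

```scala
// Split "/files/<name>" into the file type and the registered name, exactly as
// openStream does above; only "files" and "jars" are valid types.
val streamId = "/files/log4j.properties"
val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
assert(ftype == "files")
assert(fname == "log4j.properties")
```
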
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 1b3acb8ef7f51..af632349c9cae 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,6 +21,7 @@ import java.io._
import java.lang.management.ManagementFactory
import java.net._
import java.nio.ByteBuffer
+import java.nio.channels.Channels
import java.util.concurrent._
import java.util.{Locale, Properties, Random, UUID}
import javax.net.ssl.HttpsURLConnection
@@ -535,6 +536,14 @@ private[spark] object Utils extends Logging {
val uri = new URI(url)
val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
Option(uri.getScheme).getOrElse("file") match {
+ case "spark" =>
+ if (SparkEnv.get == null) {
+ throw new IllegalStateException(
+ "Cannot retrieve files with 'spark' scheme without an active SparkEnv.")
+ }
+ val source = SparkEnv.get.rpcEnv.openChannel(url)
+ val is = Channels.newInputStream(source)
+ downloadFile(url, is, targetFile, fileOverwrite)
case "http" | "https" | "ftp" =>
var uc: URLConnection = null
if (securityMgr.isAuthenticationEnabled()) {
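
A hedged sketch of how the new `spark` scheme branch is consumed: the RpcEnv opens a byte channel for the URI and the bytes are read like any other stream. The URI below is illustrative, and `SparkEnv.rpcEnv` is package-private, so this mirrors internal code rather than a public API:

```scala
import java.nio.channels.Channels
import org.apache.spark.SparkEnv

// Open a channel for a "spark://" URI through the driver's RpcEnv and wrap it as an
// InputStream, as doFetchFile does above before copying to the target file.
val url = "spark://driver-host:12345/files/app.conf"
val env = SparkEnv.get
require(env != null, "An active SparkEnv is required to fetch 'spark' URIs.")
val in = Channels.newInputStream(env.rpcEnv.openChannel(url))
try {
  // ... copy `in` to the destination, e.g. with Utils.copyStream ...
} finally {
  in.close()
}
```
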
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 2f55006420ce1..2b664c6313efa 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.rpc
-import java.io.NotSerializableException
+import java.io.{File, NotSerializableException}
+import java.util.UUID
+import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException}
import scala.collection.mutable
@@ -25,10 +27,14 @@ import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
+import com.google.common.io.Files
+import org.mockito.Mockito.{mock, when}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.util.Utils
/**
* Common tests for an RpcEnv implementation.
@@ -40,12 +46,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
override def beforeAll(): Unit = {
val conf = new SparkConf()
env = createRpcEnv(conf, "local", 0)
+
+ val sparkEnv = mock(classOf[SparkEnv])
+ when(sparkEnv.rpcEnv).thenReturn(env)
+ SparkEnv.set(sparkEnv)
}
override def afterAll(): Unit = {
if (env != null) {
env.shutdown()
}
+ SparkEnv.set(null)
}
def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv
@@ -713,6 +724,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1)
}
+ test("file server") {
+ val conf = new SparkConf()
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir, "file")
+ Files.write(UUID.randomUUID().toString(), file, UTF_8)
+ val jar = new File(tempDir, "jar")
+ Files.write(UUID.randomUUID().toString(), jar, UTF_8)
+
+ val fileUri = env.fileServer.addFile(file)
+ val jarUri = env.fileServer.addJar(jar)
+
+ val destDir = Utils.createTempDir()
+ val destFile = new File(destDir, file.getName())
+ val destJar = new File(destDir, jar.getName())
+
+ val sm = new SecurityManager(conf)
+ val hc = SparkHadoopUtil.get.conf
+ Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false)
+ Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false)
+
+ assert(Files.equal(file, destFile))
+ assert(Files.equal(jar, destJar))
+ }
+
}
class UnserializableClass
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index f9d8e80c98b66..ccca795683da3 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -25,17 +25,19 @@ import org.mockito.Matchers._
import org.apache.spark.SparkFunSuite
import org.apache.spark.network.client.{TransportResponseHandler, TransportClient}
+import org.apache.spark.network.server.StreamManager
import org.apache.spark.rpc._
class NettyRpcHandlerSuite extends SparkFunSuite {
val env = mock(classOf[NettyRpcEnv])
- when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())).
- thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
+ val sm = mock(classOf[StreamManager])
+ when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
+ .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
test("receive") {
val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
@@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
test("connectionTerminated") {
val dispatcher = mock(classOf[Dispatcher])
- val nettyRpcHandler = new NettyRpcHandler(dispatcher, env)
+ val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm)
val channel = mock(classOf[Channel])
val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml
new file mode 100644
index 0000000000000..dff3d33bf4ed1
--- /dev/null
+++ b/docs/_data/menu-ml.yaml
@@ -0,0 +1,10 @@
+- text: Feature extraction, transformation, and selection
+ url: ml-features.html
+- text: Decision trees for classification and regression
+ url: ml-decision-tree.html
+- text: Ensembles
+ url: ml-ensembles.html
+- text: Linear methods with elastic-net regularization
+ url: ml-linear-methods.html
+- text: Multilayer perceptron classifier
+ url: ml-ann.html
diff --git a/docs/_data/menu-mllib.yaml b/docs/_data/menu-mllib.yaml
new file mode 100644
index 0000000000000..12d22abd52826
--- /dev/null
+++ b/docs/_data/menu-mllib.yaml
@@ -0,0 +1,75 @@
+- text: Data types
+ url: mllib-data-types.html
+- text: Basic statistics
+ url: mllib-statistics.html
+ subitems:
+ - text: Summary statistics
+ url: mllib-statistics.html#summary-statistics
+ - text: Correlations
+ url: mllib-statistics.html#correlations
+ - text: Stratified sampling
+ url: mllib-statistics.html#stratified-sampling
+ - text: Hypothesis testing
+ url: mllib-statistics.html#hypothesis-testing
+ - text: Random data generation
+ url: mllib-statistics.html#random-data-generation
+- text: Classification and regression
+ url: mllib-classification-regression.html
+ subitems:
+ - text: Linear models (SVMs, logistic regression, linear regression)
+ url: mllib-linear-methods.html
+ - text: Naive Bayes
+ url: mllib-naive-bayes.html
+  - text: Decision trees
+    url: mllib-decision-tree.html
+  - text: Ensembles of trees (Random Forests and Gradient-Boosted Trees)
+    url: mllib-ensembles.html
+  - text: Isotonic regression
+    url: mllib-isotonic-regression.html
+- text: Collaborative filtering
+ url: mllib-collaborative-filtering.html
+ subitems:
+ - text: alternating least squares (ALS)
+ url: mllib-collaborative-filtering.html#collaborative-filtering
+- text: Clustering
+ url: mllib-clustering.html
+ subitems:
+ - text: k-means
+ url: mllib-clustering.html#k-means
+ - text: Gaussian mixture
+ url: mllib-clustering.html#gaussian-mixture
+ - text: power iteration clustering (PIC)
+ url: mllib-clustering.html#power-iteration-clustering-pic
+ - text: latent Dirichlet allocation (LDA)
+ url: mllib-clustering.html#latent-dirichlet-allocation-lda
+ - text: streaming k-means
+ url: mllib-clustering.html#streaming-k-means
+- text: Dimensionality reduction
+ url: mllib-dimensionality-reduction.html
+ subitems:
+ - text: singular value decomposition (SVD)
+ url: mllib-dimensionality-reduction.html#singular-value-decomposition-svd
+ - text: principal component analysis (PCA)
+ url: mllib-dimensionality-reduction.html#principal-component-analysis-pca
+- text: Feature extraction and transformation
+ url: mllib-feature-extraction.html
+- text: Frequent pattern mining
+ url: mllib-frequent-pattern-mining.html
+ subitems:
+ - text: FP-growth
+ url: mllib-frequent-pattern-mining.html#fp-growth
+ - text: association rules
+ url: mllib-frequent-pattern-mining.html#association-rules
+ - text: PrefixSpan
+ url: mllib-frequent-pattern-mining.html#prefix-span
+- text: Evaluation metrics
+ url: mllib-evaluation-metrics.html
+- text: PMML model export
+ url: mllib-pmml-model-export.html
+- text: Optimization (developer)
+ url: mllib-optimization.html
+ subitems:
+ - text: stochastic gradient descent
+ url: mllib-optimization.html#stochastic-gradient-descent-sgd
+ - text: limited-memory BFGS (L-BFGS)
+ url: mllib-optimization.html#limited-memory-bfgs-l-bfgs
diff --git a/docs/_includes/nav-left-wrapper-ml.html b/docs/_includes/nav-left-wrapper-ml.html
new file mode 100644
index 0000000000000..0103e890cc21a
--- /dev/null
+++ b/docs/_includes/nav-left-wrapper-ml.html
@@ -0,0 +1,8 @@
+
\ No newline at end of file
diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html
new file mode 100644
index 0000000000000..73176f4132554
--- /dev/null
+++ b/docs/_includes/nav-left.html
@@ -0,0 +1,17 @@
+{% assign navurl = page.url | remove: 'index.html' %}
+
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 467ff7a03fb70..1b09e2221e173 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -124,16 +124,24 @@
-
- {% if page.displayTitle %}
-
{{ page.displayTitle }}
- {% else %}
-
{{ page.title }}
- {% endif %}
+
- {{ content }}
+ {% if page.url contains "/ml" %}
+ {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %}
+ {% endif %}
-
+
+
+ {% if page.displayTitle %}
+
{{ page.displayTitle }}
+ {% else %}
+ {{ page.title }}
+ {% endif %}
+
+ {{ content }}
+
+
+
diff --git a/docs/configuration.md b/docs/configuration.md
index c496146e3ed63..4de202d7f7631 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1020,6 +1020,7 @@ Apart from these, the following properties are also available, and may be useful
(random) |
Port for the executor to listen on. This is used for communicating with the driver.
+ This is only relevant when using the Akka RPC backend.
|
@@ -1027,6 +1028,7 @@ Apart from these, the following properties are also available, and may be useful
(random) |
Port for the driver's HTTP file server to listen on.
+ This is only relevant when using the Akka RPC backend.
|
diff --git a/docs/css/main.css b/docs/css/main.css
index d770173be1014..356b324d6303b 100755
--- a/docs/css/main.css
+++ b/docs/css/main.css
@@ -39,8 +39,18 @@
margin-left: 10px;
}
+body .container-wrapper {
+ position: absolute;
+ width: 100%;
+ display: flex;
+}
+
body #content {
+ position: relative;
+
line-height: 1.6; /* Inspired by Github's wiki style */
+ background-color: white;
+ padding-left: 15px;
}
.title {
@@ -155,3 +165,30 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu {
* AnchorJS (anchor links when hovering over headers)
*/
a.anchorjs-link:hover { text-decoration: none; }
+
+
+/**
+ * The left navigation bar.
+ */
+.left-menu-wrapper {
+ position: absolute;
+ height: 100%;
+
+ width: 256px;
+ margin-top: -20px;
+ padding-top: 20px;
+ background-color: #F0F8FC;
+}
+
+.left-menu {
+ position: fixed;
+ max-width: 350px;
+
+ padding-right: 10px;
+ width: 256px;
+}
+
+.left-menu h3 {
+ margin-left: 10px;
+ line-height: 30px;
+}
\ No newline at end of file
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index db6bfa69ee0fe..925a1e0ba6fcf 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -326,6 +326,15 @@ If you need a reference to the proper location to put log files in the YARN so t
Otherwise, the client process will exit after submission.
+
+ spark.yarn.am.nodeLabelExpression |
+ (none) |
+
+  A YARN node label expression that restricts the set of nodes the AM will be scheduled on.
+ Only versions of YARN greater than or equal to 2.6 support node label expressions, so when
+ running against earlier versions, this property will be ignored.
+ |
+
spark.yarn.executor.nodeLabelExpression |
(none) |
diff --git a/docs/security.md b/docs/security.md
index 177109415180b..e1af221d446b0 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -149,7 +149,8 @@ configure those ports.
(random) |
Schedule tasks |
spark.executor.port |
- Akka-based. Set to "0" to choose a port randomly. |
+  Akka-based. Set to "0" to choose a port randomly. Only used if the Akka RPC backend is
+  configured. |
Executor |
@@ -157,7 +158,7 @@ configure those ports.
(random) |
File server for files and jars |
spark.fileserver.port |
- Jetty-based |
+  Jetty-based. Only used if the Akka RPC backend is configured. |
Executor |
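
If these ports need to be pinned (for example to match firewall rules), they can be set like any other Spark property. A sketch with illustrative values, relevant only while the Akka RPC backend is in use:

```scala
import org.apache.spark.SparkConf

// Fix the Akka executor port and the Jetty file server port instead of letting Spark
// pick random ones (values are illustrative).
val conf = new SparkConf()
  .set("spark.executor.port", "51000")
  .set("spark.fileserver.port", "51100")
```
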
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 96b36b7a73209..ed6b28c282135 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -723,7 +723,7 @@ Some of these advanced sources are as follows.
- **Kinesis:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kinesis Client Library 1.2.1. See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details.
-- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using
+- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j to get the public stream of tweets using
[Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information
can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by
Twitter4J library. You can either get the public stream, or get the filtered stream based on a
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 9327e21e43db7..9fd652a3df4c4 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -595,7 +595,7 @@ def launch_cluster(conn, opts, cluster_name):
dev = BlockDeviceType()
dev.ephemeral_name = 'ephemeral%d' % i
# The first ephemeral drive is /dev/sdb.
- name = '/dev/sd' + string.letters[i + 1]
+ name = '/dev/sd' + string.ascii_letters[i + 1]
block_map[name] = dev
# Launch slaves
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
index 593f8fb3e9fe9..4ad7676c8d32b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
@@ -37,7 +37,7 @@ public static void main(String[] args) {
// $example on$
// Load training data
DataFrame training = sqlContext.read().format("libsvm")
- .load("data/mllib/sample_libsvm_data.txt");
+ .load("data/mllib/sample_linear_regression_data.txt");
LinearRegression lr = new LinearRegression()
.setMaxIter(10)
diff --git a/examples/src/main/python/ml/linear_regression_with_elastic_net.py b/examples/src/main/python/ml/linear_regression_with_elastic_net.py
index b0278276330c3..a4cd40cf26726 100644
--- a/examples/src/main/python/ml/linear_regression_with_elastic_net.py
+++ b/examples/src/main/python/ml/linear_regression_with_elastic_net.py
@@ -29,7 +29,8 @@
# $example on$
# Load training data
- training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
+ training = sqlContext.read.format("libsvm")\
+ .load("data/mllib/sample_linear_regression_data.txt")
lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala
similarity index 51%
rename from examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
rename to examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala
index dc13f82488af7..424f00158c2f2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala
@@ -16,7 +16,7 @@
*/
// scalastyle:off println
-package org.apache.spark.examples.mllib
+package org.apache.spark.examples.ml
import java.io.File
@@ -24,25 +24,22 @@ import com.google.common.io.Files
import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
-import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext, DataFrame}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
/**
- * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with
+ * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with
* {{{
- * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * ./bin/run-example ml.DataFrameExample [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
-object DatasetExample {
+object DataFrameExample {
- case class Params(
- input: String = "data/mllib/sample_libsvm_data.txt",
- dataFormat: String = "libsvm") extends AbstractParams[Params]
+ case class Params(input: String = "data/mllib/sample_libsvm_data.txt")
+ extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
@@ -52,9 +49,6 @@ object DatasetExample {
opt[String]("input")
.text(s"input path to dataset")
.action((x, c) => c.copy(input = x))
- opt[String]("dataFormat")
- .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
- .action((x, c) => c.copy(input = x))
checkConfig { params =>
success
}
@@ -69,55 +63,42 @@ object DatasetExample {
def run(params: Params) {
- val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+ val conf = new SparkConf().setAppName(s"DataFrameExample with $params")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
- import sqlContext.implicits._ // for implicit conversions
// Load input data
- val origData: RDD[LabeledPoint] = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
- }
- println(s"Loaded ${origData.count()} instances from file: ${params.input}")
-
- // Convert input data to DataFrame explicitly.
- val df: DataFrame = origData.toDF()
- println(s"Inferred schema:\n${df.schema.prettyJson}")
- println(s"Converted to DataFrame with ${df.count()} records")
-
- // Select columns
- val labelsDf: DataFrame = df.select("label")
- val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v }
- val numLabels = labels.count()
- val meanLabel = labels.fold(0.0)(_ + _) / numLabels
- println(s"Selected label column with average value $meanLabel")
-
- val featuresDf: DataFrame = df.select("features")
- val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v }
+ println(s"Loading LIBSVM file with UDT from ${params.input}.")
+ val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache()
+ println("Schema from LIBSVM:")
+ df.printSchema()
+ println(s"Loaded training data as a DataFrame with ${df.count()} records.")
+
+ // Show statistical summary of labels.
+ val labelSummary = df.describe("label")
+ labelSummary.show()
+
+ // Convert features column to an RDD of vectors.
+ val features = df.select("features").map { case Row(v: Vector) => v }
val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
(sum1, sum2) => sum1.merge(sum2))
println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
+ // Save the records in a parquet file.
val tmpDir = Files.createTempDir()
tmpDir.deleteOnExit()
val outputDir = new File(tmpDir, "dataset").toString
println(s"Saving to $outputDir as Parquet file.")
df.write.parquet(outputDir)
+ // Load the records back.
println(s"Loading Parquet file with UDT from $outputDir.")
- val newDataset = sqlContext.read.parquet(outputDir)
-
- println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
- val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v }
- val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
- (summary, feat) => summary.add(feat),
- (sum1, sum2) => sum1.merge(sum2))
- println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
+ val newDF = sqlContext.read.parquet(outputDir)
+ println(s"Schema from Parquet:")
+ newDF.printSchema()
sc.stop()
}
-
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala
index 5a51ece6f9ba7..22c824cea84d3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala
@@ -33,7 +33,8 @@ object LinearRegressionWithElasticNetExample {
// $example on$
// Load training data
- val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
+ val training = sqlCtx.read.format("libsvm")
+ .load("data/mllib/sample_linear_regression_data.txt")
val lr = new LinearRegression()
.setMaxIter(10)
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index 3ee6bd92e47fc..55fe156cf665f 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -148,7 +148,7 @@ List buildClassPath(String appClassPath) throws IOException {
String scala = getScalaVersion();
List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx",
"streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver",
- "yarn", "launcher");
+ "yarn", "launcher", "network/common", "network/shuffle", "network/yarn");
if (prependClasses) {
if (!isTesting) {
System.err.println(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 6f15b37abcb30..4b2b3f8489fd0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter
import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -232,20 +231,9 @@ object Pipeline extends MLReadable[Pipeline] {
stages: Array[PipelineStage],
sc: SparkContext,
path: String): Unit = {
- // Copied and edited from DefaultParamsWriter.saveMetadata
- // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
- val uid = instance.uid
- val cls = instance.getClass.getName
val stageUids = stages.map(_.uid)
val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
- val metadata = ("class" -> cls) ~
- ("timestamp" -> System.currentTimeMillis()) ~
- ("sparkVersion" -> sc.version) ~
- ("uid" -> uid) ~
- ("paramMap" -> jsonParams)
- val metadataPath = new Path(path, "metadata").toString
- val metadataJson = compact(render(metadata))
- sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))
// Save stages
val stagesDir = new Path(path, "stages").toString
@@ -266,30 +254,10 @@ object Pipeline extends MLReadable[Pipeline] {
implicit val format = DefaultFormats
val stagesDir = new Path(path, "stages").toString
- val stageUids: Array[String] = metadata.params match {
- case JObject(pairs) =>
- if (pairs.length != 1) {
- // Should not happen unless file is corrupted or we have a bug.
- throw new RuntimeException(
- s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
- }
- pairs.head match {
- case ("stageUids", jsonValue) =>
- jsonValue.extract[Seq[String]].toArray
- case (paramName, jsonValue) =>
- // Should not happen unless file is corrupted or we have a bug.
- throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" +
- s" in metadata: ${metadata.metadataStr}")
- }
- case _ =>
- throw new IllegalArgumentException(
- s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
- }
+ val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
- val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
- val cls = Utils.classForName(stageMetadata.className)
- cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath)
+ DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc)
}
(metadata.uid, stages)
}
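
The refactor above keeps Pipeline persistence behaviorally unchanged. A hedged usage sketch, assuming an active SparkContext; the path and column names are illustrative:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// Build a two-stage pipeline, persist it, and read it back; the stage UIDs end up in
// the metadata's paramMap exactly as written by SharedReadWrite.saveImpl above.
val pipeline = new Pipeline().setStages(Array(
  new Tokenizer().setInputCol("text").setOutputCol("words"),
  new HashingTF().setInputCol("words").setOutputCol("features")))

pipeline.write.overwrite().save("/tmp/pipeline-example")
val restored = Pipeline.load("/tmp/pipeline-example")
assert(restored.getStages.length == 2)
```
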
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 1fe3abaca81c3..bfb70963b151d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.2.0")
@Experimental
class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
- extends Evaluator with HasRawPredictionCol with HasLabelCol {
+ extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable {
@Since("1.2.0")
def this() = this(Identifiable.randomUID("binEval"))
@@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
@Since("1.4.1")
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
}
+
+@Since("1.6.0")
+object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] {
+
+ @Since("1.6.0")
+ override def load(path: String): BinaryClassificationEvaluator = super.load(path)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index df5f04ca5a8d9..c44db0ec595ea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.sql.types.DoubleType
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.5.0")
@Experimental
class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
- extends Evaluator with HasPredictionCol with HasLabelCol {
+ extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("mcEval"))
@@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
@Since("1.5.0")
override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
}
+
+@Since("1.6.0")
+object MulticlassClassificationEvaluator
+ extends DefaultParamsReadable[MulticlassClassificationEvaluator] {
+
+ @Since("1.6.0")
+ override def load(path: String): MulticlassClassificationEvaluator = super.load(path)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index ba012f444d3e0..daaa174a086e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
@Since("1.4.0")
@Experimental
final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
- extends Evaluator with HasPredictionCol with HasLabelCol {
+ extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("regEval"))
@@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.5.0")
override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
}
+
+@Since("1.6.0")
+object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] {
+
+ @Since("1.6.0")
+ override def load(path: String): RegressionEvaluator = super.load(path)
+}
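
With `DefaultParamsWritable`/`DefaultParamsReadable` mixed in above, the evaluators round-trip through the standard ML persistence path. A short usage sketch, assuming an active SparkContext; the path is illustrative:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator

// Persist an evaluator's params and read them back via the companion object added above.
val evaluator = new RegressionEvaluator().setMetricName("r2")
evaluator.write.overwrite().save("/tmp/regression-evaluator")
val restored = RegressionEvaluator.load("/tmp/regression-evaluator")
assert(restored.getMetricName == "r2")
```
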
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 32d7afee6e73b..aa88cb03d23c5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -73,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v}
val pca = new feature.PCA(k = $(k))
val pcaModel = pca.fit(input)
- copyValues(new PCAModel(uid, pcaModel).setParent(this))
+ copyValues(new PCAModel(uid, pcaModel.pc).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -99,18 +99,17 @@ object PCA extends DefaultParamsReadable[PCA] {
/**
* :: Experimental ::
* Model fitted by [[PCA]].
+ *
+ * @param pc A principal components Matrix. Each column is one principal component.
*/
@Experimental
class PCAModel private[ml] (
override val uid: String,
- pcaModel: feature.PCAModel)
+ val pc: DenseMatrix)
extends Model[PCAModel] with PCAParams with MLWritable {
import PCAModel._
- /** a principal components Matrix. Each column is one principal component. */
- val pc: DenseMatrix = pcaModel.pc
-
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -124,6 +123,7 @@ class PCAModel private[ml] (
*/
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
+ val pcaModel = new feature.PCAModel($(k), pc)
val pcaOp = udf { pcaModel.transform _ }
dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
}
@@ -139,7 +139,7 @@ class PCAModel private[ml] (
}
override def copy(extra: ParamMap): PCAModel = {
- val copied = new PCAModel(uid, pcaModel)
+ val copied = new PCAModel(uid, pc)
copyValues(copied, extra).setParent(parent)
}
@@ -152,11 +152,11 @@ object PCAModel extends MLReadable[PCAModel] {
private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {
- private case class Data(k: Int, pc: DenseMatrix)
+ private case class Data(pc: DenseMatrix)
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
- val data = Data(instance.getK, instance.pc)
+ val data = Data(instance.pc)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
@@ -169,11 +169,10 @@ object PCAModel extends MLReadable[PCAModel] {
override def load(path: String): PCAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
- .select("k", "pc")
+ val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
+ .select("pc")
.head()
- val oldModel = new feature.PCAModel(k, pc)
- val model = new PCAModel(metadata.uid, oldModel)
+ val model = new PCAModel(metadata.uid, pc)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 0feec0549852b..801096fed27bf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -84,6 +84,8 @@ class VectorAssembler(override val uid: String)
val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
Array.fill(numAttrs)(NumericAttribute.defaultAttr)
}
+ case otherType =>
+ throw new SparkException(s"VectorAssembler does not support the $otherType type")
}
}
val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 4d35177ad9b0f..b798aa1fab767 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.json4s.{DefaultFormats, JValue}
+import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, Partitioner}
import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
@@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] {
private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
- val extraMetadata = render("rank" -> instance.rank)
+ val extraMetadata = "rank" -> instance.rank
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
val userPath = new Path(path, "userFactors").toString
instance.userFactors.write.format("parquet").save(userPath)
@@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] {
override def load(path: String): ALSModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
implicit val format = DefaultFormats
- val rank: Int = metadata.extraMetadata match {
- case Some(m: JValue) =>
- (m \ "rank").extract[Int]
- case None =>
- throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" +
- s" ${metadata.metadataStr}")
- }
-
+ val rank = (metadata.metadata \ "rank").extract[Int]
val userPath = new Path(path, "userFactors").toString
val userFactors = sqlContext.read.format("parquet").load(userPath)
val itemPath = new Path(path, "itemFactors").toString
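
The simplified read above pulls `rank` straight out of the metadata JSON with a json4s path expression. A standalone sketch with an illustrative JSON document:

```scala
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse

// Extract a top-level "rank" field the same way ALSModelReader now does.
implicit val format = DefaultFormats
val metadataJson = parse("""{"class":"org.apache.spark.ml.recommendation.ALSModel","rank":10}""")
val rank = (metadataJson \ "rank").extract[Int]
assert(rank == 10)
```
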
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 77d9948ed86b9..83a9048374267 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -18,17 +18,24 @@
package org.apache.spark.ml.tuning
import com.github.fommil.netlib.F2jBLAS
+import org.apache.hadoop.fs.Path
+import org.json4s.{JObject, DefaultFormats}
+import org.json4s.jackson.JsonMethods._
-import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.classification.OneVsRestParams
+import org.apache.spark.ml.feature.RFormulaModel
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
+
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
@@ -53,7 +60,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
*/
@Experimental
class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
- with CrossValidatorParams with Logging {
+ with CrossValidatorParams with MLWritable with Logging {
def this() = this(Identifiable.randomUID("cv"))
@@ -131,6 +138,166 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
}
copied
}
+
+ // Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types.
+ // E.g., this may fail if a [[Param]] is an instance of an [[Estimator]].
+ // However, this case should be unusual.
+ @Since("1.6.0")
+ override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this)
+}
+
+@Since("1.6.0")
+object CrossValidator extends MLReadable[CrossValidator] {
+
+ @Since("1.6.0")
+ override def read: MLReader[CrossValidator] = new CrossValidatorReader
+
+ @Since("1.6.0")
+ override def load(path: String): CrossValidator = super.load(path)
+
+ private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {
+
+ SharedReadWrite.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit =
+ SharedReadWrite.saveImpl(path, instance, sc)
+ }
+
+ private class CrossValidatorReader extends MLReader[CrossValidator] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[CrossValidator].getName
+
+ override def load(path: String): CrossValidator = {
+ val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
+ SharedReadWrite.load(path, sc, className)
+ new CrossValidator(metadata.uid)
+ .setEstimator(estimator)
+ .setEvaluator(evaluator)
+ .setEstimatorParamMaps(estimatorParamMaps)
+ .setNumFolds(numFolds)
+ }
+ }
+
+ private object CrossValidatorReader {
+ /**
+ * Examine the given estimator (which may be a compound estimator) and extract a mapping
+ * from UIDs to corresponding [[Params]] instances.
+ */
+ def getUidMap(instance: Params): Map[String, Params] = {
+ val uidList = getUidMapImpl(instance)
+ val uidMap = uidList.toMap
+ if (uidList.size != uidMap.size) {
+ throw new RuntimeException("CrossValidator.load found a compound estimator with stages" +
+ s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}")
+ }
+ uidMap
+ }
+
+ def getUidMapImpl(instance: Params): List[(String, Params)] = {
+ val subStages: Array[Params] = instance match {
+ case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
+ case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
+ case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+ case ovr: OneVsRestParams =>
+ // TODO: SPARK-11892: This case may require special handling.
+ throw new UnsupportedOperationException("CrossValidator write will fail because it" +
+ " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
+ case rform: RFormulaModel =>
+ // TODO: SPARK-11891: This case may require special handling.
+ throw new UnsupportedOperationException("CrossValidator write will fail because it" +
+ " cannot yet handle an estimator containing an RFormulaModel")
+ case _: Params => Array()
+ }
+ val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
+ List((instance.uid, instance)) ++ subStageMaps
+ }
+ }
+
+ private[tuning] object SharedReadWrite {
+
+ /**
+ * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable.
+ * This does not check [[CrossValidator.estimatorParamMaps]].
+ */
+ def validateParams(instance: ValidatorParams): Unit = {
+ def checkElement(elem: Params, name: String): Unit = elem match {
+ case stage: MLWritable => // good
+ case other =>
+ throw new UnsupportedOperationException("CrossValidator write will fail " +
+ s" because it contains $name which does not implement Writable." +
+ s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+ }
+ checkElement(instance.getEvaluator, "evaluator")
+ checkElement(instance.getEstimator, "estimator")
+ // Check to make sure all Params apply to this estimator. Throw an error if any do not.
+ // Extraneous Params would cause problems when loading the estimatorParamMaps.
+ val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
+ instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
+ pMap.toSeq.foreach { case ParamPair(p, v) =>
+ require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" +
+ s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" +
+ s" Evaluator. An extraneous Param was found: $p")
+ }
+ }
+ }
+
+ private[tuning] def saveImpl(
+ path: String,
+ instance: CrossValidatorParams,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None): Unit = {
+ import org.json4s.JsonDSL._
+
+ val estimatorParamMapsJson = compact(render(
+ instance.getEstimatorParamMaps.map { case paramMap =>
+ paramMap.toSeq.map { case ParamPair(p, v) =>
+ Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
+ }
+ }.toSeq
+ ))
+ val jsonParams = List(
+ "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
+ "estimatorParamMaps" -> parse(estimatorParamMapsJson)
+ )
+ DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+ val evaluatorPath = new Path(path, "evaluator").toString
+ instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
+ val estimatorPath = new Path(path, "estimator").toString
+ instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
+ }
+
+ private[tuning] def load[M <: Model[M]](
+ path: String,
+ sc: SparkContext,
+ expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = {
+
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+ implicit val format = DefaultFormats
+ val evaluatorPath = new Path(path, "evaluator").toString
+ val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
+ val estimatorPath = new Path(path, "estimator").toString
+ val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
+
+ val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)
+
+ val numFolds = (metadata.params \ "numFolds").extract[Int]
+ val estimatorParamMaps: Array[ParamMap] =
+ (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
+ pMap =>
+ val paramPairs = pMap.map { case pInfo: Map[String, String] =>
+ val est = uidToParams(pInfo("parent"))
+ val param = est.getParam(pInfo("name"))
+ val value = param.jsonDecode(pInfo("value"))
+ param -> value
+ }
+ ParamMap(paramPairs: _*)
+ }.toArray
+ (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
+ }
+ }
}
/**
@@ -139,14 +306,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
*
* @param bestModel The best model selected from k-fold cross validation.
* @param avgMetrics Average cross-validation metrics for each paramMap in
- * [[estimatorParamMaps]], in the corresponding order.
+ * [[CrossValidator.estimatorParamMaps]], in the corresponding order.
*/
@Experimental
class CrossValidatorModel private[ml] (
override val uid: String,
val bestModel: Model[_],
val avgMetrics: Array[Double])
- extends Model[CrossValidatorModel] with CrossValidatorParams {
+ extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
override def validateParams(): Unit = {
bestModel.validateParams()
@@ -168,4 +335,54 @@ class CrossValidatorModel private[ml] (
avgMetrics.clone())
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this)
+}
+
+@Since("1.6.0")
+object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
+
+ import CrossValidator.SharedReadWrite
+
+ @Since("1.6.0")
+ override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): CrossValidatorModel = super.load(path)
+
+ private[CrossValidatorModel]
+ class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
+
+ SharedReadWrite.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ import org.json4s.JsonDSL._
+ val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
+ SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
+ val bestModelPath = new Path(path, "bestModel").toString
+ instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+ }
+ }
+
+ private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[CrossValidatorModel].getName
+
+ override def load(path: String): CrossValidatorModel = {
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
+ SharedReadWrite.load(path, sc, className)
+ val bestModelPath = new Path(path, "bestModel").toString
+ val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
+ val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
+ val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
+ cv.set(cv.estimator, estimator)
+ .set(cv.evaluator, evaluator)
+ .set(cv.estimatorParamMaps, estimatorParamMaps)
+ .set(cv.numFolds, numFolds)
+ }
+ }
}
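
For reference, the estimatorParamMaps persistence added above boils down to a small JSON round trip: each ParamMap is written as an array of {parent, name, value} objects, and the reader resolves each (parent, name) pair against the UID map built from the estimator before calling jsonDecode. The following standalone sketch is editorial and not part of the patch; ParamInfo and the UID strings are illustrative stand-ins for ParamPair and real Param parents, and only json4s is assumed.

// Editorial sketch, not part of the patch: the JSON shape used for
// estimatorParamMaps. ParamInfo stands in for ParamPair plus the UID lookup.
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object ParamMapJsonSketch {
  case class ParamInfo(parent: String, name: String, value: String)

  // Writer side: one JSON object per param, one JSON array per ParamMap.
  def encode(paramMaps: Seq[Seq[ParamInfo]]): String = compact(render(
    paramMaps.map(_.map(p => Map("parent" -> p.parent, "name" -> p.name, "value" -> p.value)))
  ))

  // Reader side: recover the structure; the real loader then resolves each
  // (parent, name) pair via CrossValidatorReader.getUidMap and jsonDecode.
  def decode(json: String): Seq[Seq[ParamInfo]] = {
    implicit val formats = DefaultFormats
    parse(json).extract[Seq[Seq[Map[String, String]]]]
      .map(_.map(m => ParamInfo(m("parent"), m("name"), m("value"))))
  }

  def main(args: Array[String]): Unit = {
    val maps = Seq(
      Seq(ParamInfo("logreg_example", "regParam", "0.1")),
      Seq(ParamInfo("logreg_example", "regParam", "0.2")))
    assert(decode(encode(maps)) == maps)
  }
}
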
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index ff9322dba122a..8484b1f801066 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -202,25 +202,36 @@ private[ml] object DefaultParamsWriter {
* - timestamp
* - sparkVersion
* - uid
- * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
+ * - paramMap
+ * - (optionally, extra metadata)
+ * @param extraMetadata Extra metadata to be saved at the same level as uid, paramMap, etc.
+ * @param paramMap If given, this is saved in the "paramMap" field.
+ * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
+ * [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
def saveMetadata(
instance: Params,
path: String,
sc: SparkContext,
- extraMetadata: Option[JValue] = None): Unit = {
+ extraMetadata: Option[JObject] = None,
+ paramMap: Option[JValue] = None): Unit = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
- val jsonParams = params.map { case ParamPair(p, v) =>
+ val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
- }.toList
- val metadata = ("class" -> cls) ~
+ }.toList))
+ val basicMetadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
- ("paramMap" -> jsonParams) ~
- ("extraMetadata" -> extraMetadata)
+ ("paramMap" -> jsonParams)
+ val metadata = extraMetadata match {
+ case Some(jObject) =>
+ basicMetadata ~ jObject
+ case None =>
+ basicMetadata
+ }
val metadataPath = new Path(path, "metadata").toString
val metadataJson = compact(render(metadata))
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
@@ -251,8 +262,8 @@ private[ml] object DefaultParamsReader {
/**
* All info from metadata file.
* @param params paramMap, as a [[JValue]]
- * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]]
- * @param metadataStr Full metadata file String (for debugging)
+ * @param metadata All metadata, including the other fields
+ * @param metadataJson Full metadata file String (for debugging)
*/
case class Metadata(
className: String,
@@ -260,8 +271,8 @@ private[ml] object DefaultParamsReader {
timestamp: Long,
sparkVersion: String,
params: JValue,
- extraMetadata: Option[JValue],
- metadataStr: String)
+ metadata: JValue,
+ metadataJson: String)
/**
* Load metadata from file.
@@ -279,13 +290,12 @@ private[ml] object DefaultParamsReader {
val timestamp = (metadata \ "timestamp").extract[Long]
val sparkVersion = (metadata \ "sparkVersion").extract[String]
val params = metadata \ "paramMap"
- val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]]
if (expectedClassName.nonEmpty) {
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
s" $expectedClassName but found class name $className")
}
- Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr)
+ Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
}
/**
@@ -303,7 +313,17 @@ private[ml] object DefaultParamsReader {
}
case _ =>
throw new IllegalArgumentException(
- s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
+ s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
}
+
+ /**
+ * Load a [[Params]] instance from the given path, and return it.
+ * This assumes the instance implements [[MLReadable]].
+ */
+ def loadParamsInstance[T](path: String, sc: SparkContext): T = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc)
+ val cls = Utils.classForName(metadata.className)
+ cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
+ }
}
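
The key behavioral change in saveMetadata is that extraMetadata is now merged into the top level of the metadata JSON object (via ~) instead of being nested under an "extraMetadata" key, which is why readers such as ALSModel earlier in this patch fetch values with `metadata.metadata \ "rank"`. A minimal json4s-only sketch of that merge-and-read pattern follows; it is editorial, the class name, uid, and rank value are made up, and the paramMap is left empty.

// Editorial sketch, not part of the patch: how an extra top-level field is
// merged into the metadata JSON and read back. Values are illustrative.
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object ExtraMetadataSketch {
  def main(args: Array[String]): Unit = {
    implicit val formats = DefaultFormats

    val basicMetadata: JObject =
      ("class" -> "org.apache.spark.ml.recommendation.ALSModel") ~
      ("uid" -> "als_example") ~
      ("paramMap" -> JObject())
    val extraMetadata: JObject = "rank" -> 10

    // Same merge rule as the new saveMetadata: ~ concatenates the fields,
    // so "rank" ends up next to "class", "uid" and "paramMap".
    val metadataJson = compact(render(basicMetadata ~ extraMetadata))

    // A reader pulls the field straight off the parsed metadata.
    val rank = (parse(metadataJson) \ "rank").extract[Int]
    assert(rank == 10)
  }
}
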
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 12aba6bc6dbeb..8c86767456368 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,11 +17,9 @@
package org.apache.spark.ml
-import java.io.File
-
import scala.collection.JavaConverters._
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar.mock
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
index def869fe66777..a535c1218ecfa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
-class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
+class BinaryClassificationEvaluatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new BinaryClassificationEvaluator)
}
+
+ test("read/write") {
+ val evaluator = new BinaryClassificationEvaluator()
+ .setRawPredictionCol("myRawPrediction")
+ .setLabelCol("myLabel")
+ .setMetricName("areaUnderPR")
+ testDefaultReadWrite(evaluator)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
index 6d8412b0b3701..7ee65975d22f7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
@@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
-class MulticlassClassificationEvaluatorSuite extends SparkFunSuite {
+class MulticlassClassificationEvaluatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new MulticlassClassificationEvaluator)
}
+
+ test("read/write") {
+ val evaluator = new MulticlassClassificationEvaluator()
+ .setPredictionCol("myPrediction")
+ .setLabelCol("myLabel")
+ .setMetricName("recall")
+ testDefaultReadWrite(evaluator)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index aa722da323935..60886bf77d2f0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
-class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RegressionEvaluatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new RegressionEvaluator)
@@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
evaluator.setMetricName("mae")
assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
}
+
+ test("read/write") {
+ val evaluator = new RegressionEvaluator()
+ .setPredictionCol("myPrediction")
+ .setLabelCol("myLabel")
+ .setMetricName("r2")
+ testDefaultReadWrite(evaluator)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index 5a21cd20ceede..edab21e6c3072 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -32,7 +32,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
test("params") {
ParamsSuite.checkParams(new PCA)
val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
- val model = new PCAModel("pca", new OldPCAModel(2, mat))
+ val model = new PCAModel("pca", mat)
ParamsSuite.checkParams(model)
}
@@ -66,23 +66,18 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
}
}
- test("read/write") {
+ test("PCA read/write") {
+ val t = new PCA()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setK(3)
+ testDefaultReadWrite(t)
+ }
- def checkModelData(model1: PCAModel, model2: PCAModel): Unit = {
- assert(model1.pc === model2.pc)
- }
- val allParams: Map[String, Any] = Map(
- "k" -> 3,
- "inputCol" -> "features",
- "outputCol" -> "pca_features"
- )
- val data = Seq(
- (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))),
- (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
- (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
- )
- val df = sqlContext.createDataFrame(data).toDF("id", "features")
- val pca = new PCA().setK(3)
- testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData)
+ test("PCAModel read/write") {
+ val instance = new PCAModel("myPCAModel",
+ Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix])
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.pc === instance.pc)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index fb21ab6b9bf2c..9c1c00f41ab1d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -69,6 +69,17 @@ class VectorAssemblerSuite
}
}
+ test("transform should throw an exception in case of unsupported type") {
+ val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("a", "b", "c"))
+ .setOutputCol("features")
+ val thrown = intercept[SparkException] {
+ assembler.transform(df)
+ }
+ assert(thrown.getMessage contains "VectorAssembler does not support the StringType type")
+ }
+
test("ML attributes") {
val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index cbe09292a0337..dd6366050c020 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -18,19 +18,22 @@
package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.feature.HashingTF
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.{Pipeline, Estimator, Model}
+import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{ParamPair, ParamMap}
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.StructType
-class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class CrossValidatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var dataset: DataFrame = _
@@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("validateParams should check estimatorParamMaps") {
- import CrossValidatorSuite._
+ import CrossValidatorSuite.{MyEstimator, MyEvaluator}
val est = new MyEstimator("est")
val eval = new MyEvaluator
@@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
cv.validateParams()
}
}
+
+ test("read/write: CrossValidator with simple estimator") {
+ val lr = new LogisticRegression().setMaxIter(3)
+ val evaluator = new BinaryClassificationEvaluator()
+ .setMetricName("areaUnderPR") // not default metric
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build()
+ val cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEvaluator(evaluator)
+ .setNumFolds(20)
+ .setEstimatorParamMaps(paramMaps)
+
+ val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+ assert(cv.uid === cv2.uid)
+ assert(cv.getNumFolds === cv2.getNumFolds)
+
+ assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
+ assert(evaluator.uid === evaluator2.uid)
+ assert(evaluator.getMetricName === evaluator2.getMetricName)
+
+ cv2.getEstimator match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+
+ CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+ }
+
+ test("read/write: CrossValidator with complex estimator") {
+ // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]]
+ val lrEvaluator = new BinaryClassificationEvaluator()
+ .setMetricName("areaUnderPR") // not default metric
+
+ val lr = new LogisticRegression().setMaxIter(3)
+ val lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build()
+ val lrcv = new CrossValidator()
+ .setEstimator(lr)
+ .setEvaluator(lrEvaluator)
+ .setEstimatorParamMaps(lrParamMaps)
+
+ val hashingTF = new HashingTF()
+ val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv))
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(hashingTF.numFeatures, Array(10, 20))
+ .addGrid(lr.elasticNetParam, Array(0.0, 1.0))
+ .build()
+ val evaluator = new BinaryClassificationEvaluator()
+
+ val cv = new CrossValidator()
+ .setEstimator(pipeline)
+ .setEvaluator(evaluator)
+ .setNumFolds(20)
+ .setEstimatorParamMaps(paramMaps)
+
+ val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+ assert(cv.uid === cv2.uid)
+ assert(cv.getNumFolds === cv2.getNumFolds)
+
+ assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ assert(cv.getEvaluator.uid === cv2.getEvaluator.uid)
+
+ CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+
+ cv2.getEstimator match {
+ case pipeline2: Pipeline =>
+ assert(pipeline.uid === pipeline2.uid)
+ pipeline2.getStages match {
+ case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) =>
+ assert(hashingTF.uid === hashingTF2.uid)
+ lrcv2.getEstimator match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ case other =>
+ throw new AssertionError(s"Loaded internal CrossValidator expected to be" +
+ s" LogisticRegression but found type ${other.getClass.getName}")
+ }
+ assert(lrcv.uid === lrcv2.uid)
+ assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ assert(lrEvaluator.uid === lrcv2.getEvaluator.uid)
+ CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps)
+ case other =>
+ throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" +
+ " but found: " + other.map(_.getClass.getName).mkString(", "))
+ }
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+ s" CrossValidator but found ${other.getClass.getName}")
+ }
+ }
+
+ test("read/write: CrossValidator fails for extraneous Param") {
+ val lr = new LogisticRegression()
+ val lr2 = new LogisticRegression()
+ val evaluator = new BinaryClassificationEvaluator()
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .addGrid(lr2.regParam, Array(0.1, 0.2))
+ .build()
+ val cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEvaluator(evaluator)
+ .setEstimatorParamMaps(paramMaps)
+ withClue("CrossValidator.write failed to catch extraneous Param error") {
+ intercept[IllegalArgumentException] {
+ cv.write
+ }
+ }
+ }
+
+ test("read/write: CrossValidatorModel") {
+ val lr = new LogisticRegression()
+ .setThreshold(0.6)
+ val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2)
+ .setThreshold(0.6)
+ val evaluator = new BinaryClassificationEvaluator()
+ .setMetricName("areaUnderPR") // not default metric
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build()
+ val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6))
+ cv.set(cv.estimator, lr)
+ .set(cv.evaluator, evaluator)
+ .set(cv.numFolds, 20)
+ .set(cv.estimatorParamMaps, paramMaps)
+
+ val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+ assert(cv.uid === cv2.uid)
+ assert(cv.getNumFolds === cv2.getNumFolds)
+
+ assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+ val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
+ assert(evaluator.uid === evaluator2.uid)
+ assert(evaluator.getMetricName === evaluator2.getMetricName)
+
+ cv2.getEstimator match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getThreshold === lr2.getThreshold)
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+
+ CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+
+ cv2.bestModel match {
+ case lrModel2: LogisticRegressionModel =>
+ assert(lrModel.uid === lrModel2.uid)
+ assert(lrModel.getThreshold === lrModel2.getThreshold)
+ assert(lrModel.coefficients === lrModel2.coefficients)
+ assert(lrModel.intercept === lrModel2.intercept)
+ case other =>
+ throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" +
+ s" LogisticRegressionModel but found ${other.getClass.getName}")
+ }
+ assert(cv.avgMetrics === cv2.avgMetrics)
+ }
}
-object CrossValidatorSuite {
+object CrossValidatorSuite extends SparkFunSuite {
+
+ /**
+ * Assert sequences of estimatorParamMaps are identical.
+ * Params must be simple types comparable with `===`.
+ */
+ def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = {
+ assert(pMaps.length === pMaps2.length)
+ pMaps.zip(pMaps2).foreach { case (pMap, pMap2) =>
+ assert(pMap.size === pMap2.size)
+ pMap.toSeq.foreach { case ParamPair(p, v) =>
+ assert(pMap2.contains(p))
+ assert(pMap2(p) === v)
+ }
+ }
+ }
abstract class MyModel extends Model[MyModel]
diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
index 02230a00e69fc..88ba3ccebdf20 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
@@ -30,13 +30,19 @@
*/
class StreamInterceptor implements TransportFrameDecoder.Interceptor {
+ private final TransportResponseHandler handler;
private final String streamId;
private final long byteCount;
private final StreamCallback callback;
private volatile long bytesRead;
- StreamInterceptor(String streamId, long byteCount, StreamCallback callback) {
+ StreamInterceptor(
+ TransportResponseHandler handler,
+ String streamId,
+ long byteCount,
+ StreamCallback callback) {
+ this.handler = handler;
this.streamId = streamId;
this.byteCount = byteCount;
this.callback = callback;
@@ -45,11 +51,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {
@Override
public void exceptionCaught(Throwable cause) throws Exception {
+ handler.deactivateStream();
callback.onFailure(streamId, cause);
}
@Override
public void channelInactive() throws Exception {
+ handler.deactivateStream();
callback.onFailure(streamId, new ClosedChannelException());
}
@@ -65,8 +73,10 @@ public boolean handle(ByteBuf buf) throws Exception {
RuntimeException re = new IllegalStateException(String.format(
"Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
callback.onFailure(streamId, re);
+ handler.deactivateStream();
throw re;
} else if (bytesRead == byteCount) {
+ handler.deactivateStream();
callback.onComplete(streamId);
}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index a0ba223e340a2..876fcd846791c 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -73,10 +73,12 @@ public class TransportClient implements Closeable {
private final Channel channel;
private final TransportResponseHandler handler;
@Nullable private String clientId;
+ private volatile boolean timedOut;
public TransportClient(Channel channel, TransportResponseHandler handler) {
this.channel = Preconditions.checkNotNull(channel);
this.handler = Preconditions.checkNotNull(handler);
+ this.timedOut = false;
}
public Channel getChannel() {
@@ -84,7 +86,7 @@ public Channel getChannel() {
}
public boolean isActive() {
- return channel.isOpen() || channel.isActive();
+ return !timedOut && (channel.isOpen() || channel.isActive());
}
public SocketAddress getSocketAddress() {
@@ -263,6 +265,11 @@ public void onFailure(Throwable e) {
}
}
+ /** Mark this channel as having timed out. */
+ public void timeOut() {
+ this.timedOut = true;
+ }
+
@Override
public void close() {
// close is a local operation and should finish with milliseconds; timeout just to be safe
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 42a4f664e697c..61bafc8380049 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -136,8 +136,19 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO
TransportClient cachedClient = clientPool.clients[clientIndex];
if (cachedClient != null && cachedClient.isActive()) {
- logger.trace("Returning cached connection to {}: {}", address, cachedClient);
- return cachedClient;
+ // Make sure that the channel will not time out by updating the last use time of the
+ // handler. Then check that the client is still alive, in case it timed out before
+ // this code was able to update things.
+ TransportChannelHandler handler = cachedClient.getChannel().pipeline()
+ .get(TransportChannelHandler.class);
+ synchronized (handler) {
+ handler.getResponseHandler().updateTimeOfLastRequest();
+ }
+
+ if (cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}", address, cachedClient);
+ return cachedClient;
+ }
}
// If we reach here, we don't have an existing connection open. Let's create a new one.
@@ -159,8 +170,10 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO
}
/**
- * Create a completely new {@link TransportClient} to the given remote host / port
- * But this connection is not pooled.
+ * Create a completely new {@link TransportClient} to the given remote host / port.
+ * This connection is not pooled.
+ *
+ * As with {@link #createClient(String, int)}, this method is blocking.
*/
public TransportClient createUnmanagedClient(String remoteHost, int remotePort)
throws IOException {
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index ed3f36af58048..be181e0660826 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -24,6 +24,7 @@
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
+import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -56,6 +57,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private final Map<Long, RpcResponseCallback> outstandingRpcs;
private final Queue<StreamCallback> streamCallbacks;
+ private volatile boolean streamActive;
/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
private final AtomicLong timeOfLastRequestNs;
@@ -69,7 +71,7 @@ public TransportResponseHandler(Channel channel) {
}
public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
- timeOfLastRequestNs.set(System.nanoTime());
+ updateTimeOfLastRequest();
outstandingFetches.put(streamChunkId, callback);
}
@@ -78,7 +80,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) {
}
public void addRpcRequest(long requestId, RpcResponseCallback callback) {
- timeOfLastRequestNs.set(System.nanoTime());
+ updateTimeOfLastRequest();
outstandingRpcs.put(requestId, callback);
}
@@ -87,9 +89,15 @@ public void removeRpcRequest(long requestId) {
}
public void addStreamCallback(StreamCallback callback) {
+ timeOfLastRequestNs.set(System.nanoTime());
streamCallbacks.offer(callback);
}
+ @VisibleForTesting
+ public void deactivateStream() {
+ streamActive = false;
+ }
+
/**
* Fire the failure callback for all outstanding requests. This is called when we have an
* uncaught exception or pre-mature connection termination.
@@ -177,14 +185,16 @@ public void handle(ResponseMessage message) {
StreamResponse resp = (StreamResponse) message;
StreamCallback callback = streamCallbacks.poll();
if (callback != null) {
- StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount,
+ StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
callback);
try {
TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
frameDecoder.setInterceptor(interceptor);
+ streamActive = true;
} catch (Exception e) {
logger.error("Error installing stream handler.", e);
+ deactivateStream();
}
} else {
logger.error("Could not find callback for StreamResponse.");
@@ -208,7 +218,8 @@ public void handle(ResponseMessage message) {
/** Returns total number of outstanding requests (fetch requests + rpcs) */
public int numOutstandingRequests() {
- return outstandingFetches.size() + outstandingRpcs.size();
+ return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() +
+ (streamActive ? 1 : 0);
}
/** Returns the time in nanoseconds of when the last request was sent out. */
@@ -216,4 +227,9 @@ public long getTimeOfLastRequestNs() {
return timeOfLastRequestNs.get();
}
+ /** Updates the time of the last request to the current system time. */
+ public void updateTimeOfLastRequest() {
+ timeOfLastRequestNs.set(System.nanoTime());
+ }
+
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index f8fcd1c3d7d76..3164e00679035 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -116,20 +116,33 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc
// there are outstanding requests, we also do a secondary consistency check to ensure
// there's no race between the idle timeout and incrementing the numOutstandingRequests
// (see SPARK-7003).
- boolean isActuallyOverdue =
- System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs;
- if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) {
- if (responseHandler.numOutstandingRequests() > 0) {
- String address = NettyUtils.getRemoteAddress(ctx.channel());
- logger.error("Connection to {} has been quiet for {} ms while there are outstanding " +
- "requests. Assuming connection is dead; please adjust spark.network.timeout if this " +
- "is wrong.", address, requestTimeoutNs / 1000 / 1000);
- ctx.close();
- } else if (closeIdleConnections) {
- // While CloseIdleConnections is enable, we also close idle connection
- ctx.close();
+ //
+ // To avoid a race between TransportClientFactory.createClient() and this code which could
+ // result in an inactive client being returned, this needs to run in a synchronized block.
+ synchronized (this) {
+ boolean isActuallyOverdue =
+ System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs;
+ if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) {
+ if (responseHandler.numOutstandingRequests() > 0) {
+ String address = NettyUtils.getRemoteAddress(ctx.channel());
+ logger.error("Connection to {} has been quiet for {} ms while there are outstanding " +
+ "requests. Assuming connection is dead; please adjust spark.network.timeout if this " +
+ "is wrong.", address, requestTimeoutNs / 1000 / 1000);
+ client.timeOut();
+ ctx.close();
+ } else if (closeIdleConnections) {
+ // While closeIdleConnections is enabled, we also close the idle connection
+ client.timeOut();
+ ctx.close();
+ }
}
}
}
+ ctx.fireUserEventTriggered(evt);
}
+
+ public TransportResponseHandler getResponseHandler() {
+ return responseHandler;
+ }
+
}
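
Taken together with the TransportClientFactory change earlier in the patch, the fix is a classic check-then-act race closed with one lock: the factory refreshes the handler's last-request timestamp and then re-checks isActive under the channel handler's monitor, while the idle-timeout path decides whether the connection is overdue and marks the client timed out under that same monitor. The simplified Scala sketch below is editorial; Handler and Client are stand-ins rather than the Netty/Spark classes, and only the locking structure mirrors the patch.

// Editorial sketch, not part of the patch: the locking structure used to
// avoid returning a cached client that is about to be timed out. Handler and
// Client are simplified stand-ins for TransportChannelHandler /
// TransportResponseHandler / TransportClient.
import java.util.concurrent.atomic.AtomicLong

class Handler {
  private val timeOfLastRequestNs = new AtomicLong(System.nanoTime())
  def updateTimeOfLastRequest(): Unit = timeOfLastRequestNs.set(System.nanoTime())
  def timeOfLastRequestNanos: Long = timeOfLastRequestNs.get()
}

class Client(val handler: Handler) {
  @volatile private var timedOut = false
  def timeOut(): Unit = timedOut = true
  def isActive: Boolean = !timedOut
}

object TimeoutRaceSketch {
  // Factory side: refresh the timestamp under the handler's lock, then
  // re-check liveness in case the idle-timeout path won the race first.
  def reuseIfAlive(cached: Client): Option[Client] = {
    cached.handler.synchronized { cached.handler.updateTimeOfLastRequest() }
    if (cached.isActive) Some(cached) else None
  }

  // Idle-timeout side: decide and mark the timeout under the same lock, so a
  // concurrent reuseIfAlive either sees a fresh timestamp or a dead client.
  def onAllIdle(client: Client, requestTimeoutNs: Long): Unit =
    client.handler.synchronized {
      if (System.nanoTime() - client.handler.timeOfLastRequestNanos > requestTimeoutNs) {
        client.timeOut()
      }
    }
}
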
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
index 17a03ebe88a93..30144f4a9fc7a 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.network;
+import io.netty.channel.Channel;
import io.netty.channel.local.LocalChannel;
import org.junit.Test;
@@ -28,12 +29,16 @@
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamResponse;
+import org.apache.spark.network.util.TransportFrameDecoder;
public class TransportResponseHandlerSuite {
@Test
@@ -112,4 +117,26 @@ public void handleFailedRPC() {
verify(callback, times(1)).onFailure((Throwable) any());
assertEquals(0, handler.numOutstandingRequests());
}
+
+ @Test
+ public void testActiveStreams() {
+ Channel c = new LocalChannel();
+ c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
+ TransportResponseHandler handler = new TransportResponseHandler(c);
+
+ StreamResponse response = new StreamResponse("stream", 1234L, null);
+ StreamCallback cb = mock(StreamCallback.class);
+ handler.addStreamCallback(cb);
+ assertEquals(1, handler.numOutstandingRequests());
+ handler.handle(response);
+ assertEquals(1, handler.numOutstandingRequests());
+ handler.deactivateStream();
+ assertEquals(0, handler.numOutstandingRequests());
+
+ StreamFailure failure = new StreamFailure("stream", "uh-oh");
+ handler.addStreamCallback(cb);
+ assertEquals(1, handler.numOutstandingRequests());
+ handler.handle(failure);
+ assertEquals(0, handler.numOutstandingRequests());
+ }
}
diff --git a/pom.xml b/pom.xml
index ad849112ce76c..234fd5dea1a6e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1958,7 +1958,6 @@
false
false
false
- true
true
src
@@ -1997,7 +1996,6 @@
1
false
false
- true
true
__not_used__
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 67724c4e9e411..f575f0012d59e 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -632,7 +632,6 @@ object TestSettings {
javaOptions in Test += "-Dspark.master.rest.enabled=false",
javaOptions in Test += "-Dspark.ui.enabled=false",
javaOptions in Test += "-Dspark.ui.showConsoleProgress=false",
- javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true",
javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
javaOptions in Test += "-Dderby.system.durability=test",
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 9ca8e1f264cfa..81fd4e782628a 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -346,9 +346,10 @@ def cast(self, dataType):
if isinstance(dataType, basestring):
jc = self._jc.cast(dataType)
elif isinstance(dataType, DataType):
- sc = SparkContext._active_spark_context
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- jdt = ssql_ctx.parseDataType(dataType.json())
+ from pyspark.sql import SQLContext
+ sc = SparkContext.getOrCreate()
+ ctx = SQLContext.getOrCreate(sc)
+ jdt = ctx._ssql_ctx.parseDataType(dataType.json())
jc = self._jc.cast(jdt)
else:
raise TypeError("unexpected type: %s" % type(dataType))
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index c3da513c13897..a1ca723bbd7ab 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1457,14 +1457,15 @@ def __init__(self, func, returnType, name=None):
self._judf = self._create_judf(name)
def _create_judf(self, name):
+ from pyspark.sql import SQLContext
f, returnType = self.func, self.returnType # put them in closure `func`
func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
- sc = SparkContext._active_spark_context
+ sc = SparkContext.getOrCreate()
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- jdt = ssql_ctx.parseDataType(self.returnType.json())
+ ctx = SQLContext.getOrCreate(sc)
+ jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
if name is None:
name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 5032f05c2edba..0c10a56c555f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -54,8 +54,13 @@ object ExpressionEncoder {
val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
val fromRowExpression = ScalaReflection.constructorFor[T]
+ val schema = ScalaReflection.schemaFor[T] match {
+ case ScalaReflection.Schema(s: StructType, _) => s
+ case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable)
+ }
+
new ExpressionEncoder[T](
- toRowExpression.dataType,
+ schema,
flat,
toRowExpression.flatten,
fromRowExpression,
@@ -71,7 +76,13 @@ object ExpressionEncoder {
encoders.foreach(_.assertUnresolved())
val schema = StructType(encoders.zipWithIndex.map {
- case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+ case (e, i) =>
+ val (dataType, nullable) = if (e.flat) {
+ e.schema.head.dataType -> e.schema.head.nullable
+ } else {
+ e.schema -> true
+ }
+ StructField(s"_${i + 1}", dataType, nullable)
})
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index d5a0af3c1ffe5..62d09f0f55105 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -236,11 +236,6 @@ case class NewInstance(
}
if (propagateNull) {
- val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
- s"${ev.isNull} = ${ev.value} == null;"
- } else {
- ""
- }
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala
index 2b83651f9086d..515c071c283b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala
@@ -52,7 +52,8 @@ private[sql] trait DataTypeParser extends StandardTokenParsers {
"(?i)decimal".r ^^^ DecimalType.USER_DEFAULT |
"(?i)date".r ^^^ DateType |
"(?i)timestamp".r ^^^ TimestampType |
- varchar
+ varchar |
+ char
protected lazy val fixedDecimalType: Parser[DataType] =
("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
@@ -60,6 +61,9 @@ private[sql] trait DataTypeParser extends StandardTokenParsers {
DecimalType(precision.toInt, scale.toInt)
}
+ protected lazy val char: Parser[DataType] =
+ "(?i)char".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType
+
protected lazy val varchar: Parser[DataType] =
"(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 76459b34a484f..d6ca138672ef1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
-import org.apache.spark.sql.types.ArrayType
+import org.apache.spark.sql.types.{StructType, ArrayType}
case class RepeatedStruct(s: Seq[PrimitiveData])
@@ -238,6 +238,42 @@ class ExpressionEncoderSuite extends SparkFunSuite {
ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
}
+ test("nullable of encoder schema") {
+ def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = {
+ assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq)
+ }
+
+ // test for flat encoders
+ checkNullable[Int](false)
+ checkNullable[Option[Int]](true)
+ checkNullable[java.lang.Integer](true)
+ checkNullable[String](true)
+
+ // test for product encoders
+ checkNullable[(String, Int)](true, false)
+ checkNullable[(Int, java.lang.Long)](false, true)
+
+ // test for nested product encoders
+ {
+ val schema = ExpressionEncoder[(Int, (String, Int))].schema
+ assert(schema(0).nullable === false)
+ assert(schema(1).nullable === true)
+ assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true)
+ assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false)
+ }
+
+ // test for tupled encoders
+ {
+ val schema = ExpressionEncoder.tuple(
+ ExpressionEncoder[Int],
+ ExpressionEncoder[(String, Int)]).schema
+ assert(schema(0).nullable === false)
+ assert(schema(1).nullable === true)
+ assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true)
+ assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false)
+ }
+ }
+
private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap()
outers.put(getClass.getName, this)
private def encodeDecodeTest[T : ExpressionEncoder](
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala
index 1e3409a9db6eb..bebf708965474 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala
@@ -49,7 +49,9 @@ class DataTypeParserSuite extends SparkFunSuite {
checkDataType("DATE", DateType)
checkDataType("timestamp", TimestampType)
checkDataType("string", StringType)
+ checkDataType("ChaR(5)", StringType)
checkDataType("varchAr(20)", StringType)
+ checkDataType("cHaR(27)", StringType)
checkDataType("BINARY", BinaryType)
checkDataType("array", ArrayType(DoubleType, true))
@@ -83,7 +85,8 @@ class DataTypeParserSuite extends SparkFunSuite {
|struct<
| struct:struct,
| MAP:Map,
- | arrAy:Array>
+ | arrAy:Array,
+ | anotherArray:Array>
""".stripMargin,
StructType(
StructField("struct",
@@ -91,7 +94,8 @@ class DataTypeParserSuite extends SparkFunSuite {
StructField("deciMal", DecimalType.USER_DEFAULT, true) ::
StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) ::
StructField("MAP", MapType(TimestampType, StringType), true) ::
- StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil)
+ StructField("arrAy", ArrayType(DoubleType, true), true) ::
+ StructField("anotherArray", ArrayType(StringType, true), true) :: Nil)
)
// A column name can be a reserved word in our DDL parser and SqlParser.
checkDataType(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 6ce41aaf01e27..a9719128a626e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -23,9 +23,8 @@ import org.apache.spark.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -46,14 +45,12 @@ object TypedAggregateExpression {
/**
* This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has
* the following limitations:
- * - It assumes the aggregator reduces and returns a single column of type `long`.
- * - It might only work when there is a single aggregator in the first column.
* - It assumes the aggregator has a zero, `0`.
*/
case class TypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any],
aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
- bEncoder: ExpressionEncoder[Any], // Should be bound.
+ unresolvedBEncoder: ExpressionEncoder[Any],
cEncoder: ExpressionEncoder[Any],
children: Seq[Attribute],
mutableAggBufferOffset: Int,
@@ -80,10 +77,14 @@ case class TypedAggregateExpression(
override lazy val inputTypes: Seq[DataType] = Nil
- override val aggBufferSchema: StructType = bEncoder.schema
+ override val aggBufferSchema: StructType = unresolvedBEncoder.schema
override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
+ val bEncoder = unresolvedBEncoder
+ .resolve(aggBufferAttributes, OuterScopes.outerScopes)
+ .bind(aggBufferAttributes)
+
// Note: although this simply copies aggBufferAttributes, this common code can not be placed
// in the superclass because that will lead to initialization ordering issues.
override val inputAggBufferAttributes: Seq[AttributeReference] =
@@ -93,12 +94,18 @@ case class TypedAggregateExpression(
lazy val boundA = aEncoder.get
private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
- // todo: need a more neat way to assign the value.
var i = 0
while (i < aggBufferAttributes.length) {
+ val offset = mutableAggBufferOffset + i
aggBufferSchema(i).dataType match {
- case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i))
- case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i))
+ case BooleanType => buffer.setBoolean(offset, value.getBoolean(i))
+ case ByteType => buffer.setByte(offset, value.getByte(i))
+ case ShortType => buffer.setShort(offset, value.getShort(i))
+ case IntegerType => buffer.setInt(offset, value.getInt(i))
+ case LongType => buffer.setLong(offset, value.getLong(i))
+ case FloatType => buffer.setFloat(offset, value.getFloat(i))
+ case DoubleType => buffer.setDouble(offset, value.getDouble(i))
+ case other => buffer.update(offset, value.get(i, other))
}
i += 1
}
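
Note on the hunk above: each slot of the aggregation buffer is now copied with the setter that matches its declared DataType, and anything non-primitive (strings, structs, the tuple buffer used by the test suite below) falls through to the generic update; resolving and binding unresolvedBEncoder against aggBufferAttributes is what lets the buffer encoder read those slots back into JVM objects. The following is a standalone sketch of that dispatch, not part of the patch: SimpleBuffer is a hypothetical stand-in for Spark's MutableRow, and only the org.apache.spark.sql.types import is the real API.

    import org.apache.spark.sql.types._

    object BufferCopySketch {
      // Hypothetical stand-in for MutableRow: typed setters plus a generic update.
      class SimpleBuffer(size: Int) {
        private val slots = new Array[Any](size)
        def setInt(i: Int, v: Int): Unit = slots(i) = v
        def setLong(i: Int, v: Long): Unit = slots(i) = v
        def setDouble(i: Int, v: Double): Unit = slots(i) = v
        def update(i: Int, v: Any): Unit = slots(i) = v
        def get(i: Int): Any = slots(i)
      }

      // Copy `values` into `buffer` starting at `offset`, choosing the setter from the
      // buffer schema -- the same dispatch shape as updateBuffer above, minus InternalRow.
      def copyToBuffer(schema: StructType, values: Seq[Any], buffer: SimpleBuffer, offset: Int): Unit = {
        var i = 0
        while (i < schema.length) {
          val target = offset + i
          schema(i).dataType match {
            case IntegerType => buffer.setInt(target, values(i).asInstanceOf[Int])
            case LongType    => buffer.setLong(target, values(i).asInstanceOf[Long])
            case DoubleType  => buffer.setDouble(target, values(i).asInstanceOf[Double])
            case _           => buffer.update(target, values(i)) // strings, structs, etc.
          }
          i += 1
        }
      }
    }
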
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 9377589790011..19dce5d1e2f37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -67,7 +67,7 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L
}
case class AggData(a: Int, b: String)
-object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
+object ClassInputAgg extends Aggregator[AggData, Int, Int] {
/** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
override def zero: Int = 0
@@ -88,6 +88,28 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
override def merge(b1: Int, b2: Int): Int = b1 + b2
}
+object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
+ /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
+ override def zero: (Int, AggData) = 0 -> AggData(0, "0")
+
+ /**
+ * Combine two values to produce a new value. For performance, the function may modify `b` and
+ * return it instead of constructing a new object for `b`.
+ */
+ override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
+
+ /**
+ * Transform the output of the reduction.
+ */
+ override def finish(reduction: (Int, AggData)): Int = reduction._1
+
+ /**
+ * Merge two intermediate values.
+ */
+ override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
+ (b1._1 + b2._1, b1._2)
+}
+
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -168,4 +190,21 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
("one", 1))
}
+
+ test("typed aggregation: complex input") {
+ val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
+
+ checkAnswer(
+ ds.select(ComplexBufferAgg.toColumn),
+ 2
+ )
+
+ checkAnswer(
+ ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn),
+ (1.5, 2))
+
+ checkAnswer(
+ ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn),
+ ("one", 1), ("two", 1))
+ }
}
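
To make the expectation of the new tests concrete, here is a hedged, Spark-free walk-through of ComplexBufferAgg using the definitions added above: each partition is folded with reduce, the partial (count, lastRow) buffers are combined with merge, and finish drops the carried AggData and returns the count. Illustrative only, not part of the patch; it assumes AggData and ComplexBufferAgg from the suite are in scope.

    // reduce() is applied per partition, starting from zero.
    val partitions = Seq(
      Seq(AggData(1, "one")),
      Seq(AggData(2, "two"), AggData(3, "three")))
    val partials = partitions.map(_.foldLeft(ComplexBufferAgg.zero)(ComplexBufferAgg.reduce))

    // merge() combines partial buffers; finish() extracts the final count.
    val total = ComplexBufferAgg.finish(partials.reduce(ComplexBufferAgg.merge))
    println(total)  // 3 -- one reduce call per input row
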
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 9da02550b39ce..cc8e4325fd2f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -386,7 +386,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Seq((JavaData(1), 1L), (JavaData(2), 1L)))
}
- ignore("Java encoder self join") {
+ test("Java encoder self join") {
implicit val kryoEncoder = Encoders.javaSerialization[JavaData]
val ds = Seq(JavaData(1), JavaData(2)).toDS()
assert(ds.joinWith(ds, lit(true)).collect().toSet ==
@@ -396,6 +396,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
(JavaData(2), JavaData(1)),
(JavaData(2), JavaData(2))))
}
+
+ test("SPARK-11894: Incorrect results are returned when using null") {
+ val nullInt = null.asInstanceOf[java.lang.Integer]
+ val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()
+ val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, lit(true)),
+ ((nullInt, "1"), (nullInt, "1")),
+ ((new java.lang.Integer(22), "2"), (nullInt, "1")),
+ ((nullInt, "1"), (new java.lang.Integer(22), "2")),
+ ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2")))
+ }
}
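
A note on why the SPARK-11894 test casts null to java.lang.Integer rather than Int: a Scala Int has no null representation, so the cast silently turns null into 0, while the boxed type preserves the null that the joined columns are supposed to carry. A quick illustrative check, not part of the patch:

    val asPrimitive: Int = null.asInstanceOf[Int]                          // unboxes to 0
    val asBoxed: java.lang.Integer = null.asInstanceOf[java.lang.Integer]  // stays null
    println(s"primitive = $asPrimitive, boxed = $asBoxed")                 // primitive = 0, boxed = null
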
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 12af8068c398f..26c1ff520406c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -85,6 +85,7 @@ case class AllDataTypesScan(
Date.valueOf("1970-01-01"),
new Timestamp(20000 + i),
s"varchar_$i",
+ s"char_$i",
Seq(i, i + 1),
Seq(Map(s"str_$i" -> Row(i.toLong))),
Map(i -> i.toString),
@@ -115,6 +116,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext {
Date.valueOf("1970-01-01"),
new Timestamp(20000 + i),
s"varchar_$i",
+ s"char_$i",
Seq(i, i + 1),
Seq(Map(s"str_$i" -> Row(i.toLong))),
Map(i -> i.toString),
@@ -154,6 +156,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext {
|dateField dAte,
|timestampField tiMestamp,
|varcharField varchaR(12),
+ |charField ChaR(18),
|arrayFieldSimple Array<inT>,
|arrayFieldComplex Array<Map<String, struct<key:bigInt>>>