diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index b5cb3f0a0f9dc..c1a91c27eef2d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,7 +24,7 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable @@ -34,6 +34,16 @@ private[deploy] object DeployMessages { // Worker to Master + /** + * @param id the worker id + * @param host the worker host + * @param port the worker post + * @param worker the worker endpoint ref + * @param cores the core number of worker + * @param memory the memory size of worker + * @param workerWebUiUrl the worker Web UI address + * @param masterAddress the master address used by the worker to connect + */ case class RegisterWorker( id: String, host: String, @@ -41,7 +51,8 @@ private[deploy] object DeployMessages { worker: RpcEndpointRef, cores: Int, memory: Int, - workerWebUiUrl: String) + workerWebUiUrl: String, + masterAddress: RpcAddress) extends DeployMessage { Utils.checkHost(host) assert (port > 0) @@ -80,8 +91,16 @@ private[deploy] object DeployMessages { sealed trait RegisterWorkerResponse - case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage - with RegisterWorkerResponse + /** + * @param master the master ref + * @param masterWebUiUrl the master Web UI address + * @param masterAddress the master address used by the worker to connect. It should be + * [[RegisterWorker.masterAddress]]. + */ + case class RegisteredWorker( + master: RpcEndpointRef, + masterWebUiUrl: String, + masterAddress: RpcAddress) extends DeployMessage with RegisterWorkerResponse case class RegisterWorkerFailed(message: String) extends DeployMessage with RegisterWorkerResponse diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e061939623cbb..53384e7373252 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -231,7 +231,8 @@ private[deploy] class Master( logError("Leadership has been revoked -- master shutting down.") System.exit(0) - case RegisterWorker(id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress) => logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { @@ -243,7 +244,7 @@ private[deploy] class Master( workerRef, workerWebUiUrl) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - workerRef.send(RegisteredWorker(self, masterWebUiUrl)) + workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress)) schedule() } else { val workerAddress = worker.endpoint.address diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 34e3a4c020c80..1198e3cb05eaa 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -99,6 +99,20 @@ private[deploy] class Worker( private val testing: Boolean = sys.props.contains("spark.testing") private var master: Option[RpcEndpointRef] = None + + /** + * Whether to use the master address in `masterRpcAddresses` if possible. If it's disabled, Worker + * will just use the address received from Master. + */ + private val preferConfiguredMasterAddress = + conf.getBoolean("spark.worker.preferConfiguredMasterAddress", false) + /** + * The master address to connect in case of failure. When the connection is broken, worker will + * use this address to connect. This is usually just one of `masterRpcAddresses`. However, when + * a master is restarted or takes over leadership, it will be an address sent from master, which + * may not be in `masterRpcAddresses`. + */ + private var masterAddressToConnect: Option[RpcAddress] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" private var workerWebUiUrl: String = "" @@ -196,10 +210,19 @@ private[deploy] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { + /** + * Change to use the new master. + * + * @param masterRef the new master ref + * @param uiUrl the new master Web UI address + * @param masterAddress the new master address which the worker should use to connect in case of + * failure + */ + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String, masterAddress: RpcAddress) { // activeMasterUrl it's a valid Spark url since we receive it from master. activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl + masterAddressToConnect = Some(masterAddress) master = Some(masterRef) connected = true if (conf.getBoolean("spark.ui.reverseProxy", false)) { @@ -266,7 +289,8 @@ private[deploy] class Worker( if (registerMasterFutures != null) { registerMasterFutures.foreach(_.cancel(true)) } - val masterAddress = masterRef.address + val masterAddress = + if (preferConfiguredMasterAddress) masterAddressToConnect.get else masterRef.address registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { override def run(): Unit = { try { @@ -342,15 +366,27 @@ private[deploy] class Worker( } private def sendRegisterMessageToMaster(masterEndpoint: RpcEndpointRef): Unit = { - masterEndpoint.send(RegisterWorker(workerId, host, port, self, cores, memory, workerWebUiUrl)) + masterEndpoint.send(RegisterWorker( + workerId, + host, + port, + self, + cores, + memory, + workerWebUiUrl, + masterEndpoint.address)) } private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized { msg match { - case RegisteredWorker(masterRef, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + case RegisteredWorker(masterRef, masterWebUiUrl, masterAddress) => + if (preferConfiguredMasterAddress) { + logInfo("Successfully registered with master " + masterAddress.toSparkURL) + } else { + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + } registered = true - changeMaster(masterRef, masterWebUiUrl) + changeMaster(masterRef, masterWebUiUrl, masterAddress) forwordMessageScheduler.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { self.send(SendHeartbeat) @@ -419,7 +455,7 @@ private[deploy] class Worker( case MasterChanged(masterRef, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) - changeMaster(masterRef, masterWebUiUrl) + changeMaster(masterRef, masterWebUiUrl, masterRef.address) val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) @@ -561,7 +597,8 @@ private[deploy] class Worker( } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (master.exists(_.address == remoteAddress)) { + if (master.exists(_.address == remoteAddress) || + masterAddressToConnect.exists(_ == remoteAddress)) { logInfo(s"$remoteAddress Disassociated !") masterDisconnected() } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 2127da48ece49..539264652d7d5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -34,7 +34,7 @@ import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy._ import org.apache.spark.deploy.DeployMessages._ -import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv} class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter { @@ -447,8 +447,15 @@ class MasterSuite extends SparkFunSuite } }) - master.self.send( - RegisterWorker("1", "localhost", 9999, fakeWorker, 10, 1024, "http://localhost:8080")) + master.self.send(RegisterWorker( + "1", + "localhost", + 9999, + fakeWorker, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost", 9999))) val executors = (0 until 3).map { i => new ExecutorDescription(appId = i.toString, execId = i, 2, ExecutorState.RUNNING) } @@ -459,4 +466,37 @@ class MasterSuite extends SparkFunSuite assert(killedDrivers.asScala.toList.sorted === List("0", "1", "2")) } } + + test("SPARK-20529: Master should reply the address received from worker") { + val master = makeMaster() + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") + } + + @volatile var receivedMasterAddress: RpcAddress = null + val fakeWorker = master.rpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = master.rpcEnv + + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(_, _, masterAddress) => + receivedMasterAddress = masterAddress + } + }) + + master.self.send(RegisterWorker( + "1", + "localhost", + 9999, + fakeWorker, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost2", 10000))) + + eventually(timeout(10.seconds)) { + assert(receivedMasterAddress === RpcAddress("localhost2", 10000)) + } + } }