Enable local auth
Kimahriman committed Feb 13, 2025
1 parent e92e12a commit f02423a
Showing 5 changed files with 78 additions and 7 deletions.
11 changes: 10 additions & 1 deletion python/pyspark/sql/connect/client/core.py
@@ -313,7 +313,8 @@ class DefaultChannelBuilder(ChannelBuilder):

    @staticmethod
    def default_port() -> int:
-        if "SPARK_TESTING" in os.environ and not is_remote_only():
+        print("Env check", "SPARK_LOCAL_CONNECT" in os.environ)
+        if "SPARK_LOCAL_REMOTE" in os.environ and not is_remote_only():
            from pyspark.sql.session import SparkSession as PySparkSession

            # In the case when Spark Connect uses the local mode, it starts the regular Spark
@@ -437,6 +438,7 @@ def toChannel(self) -> grpc.Channel:
        -------
        GRPC Channel instance.
        """
+        print('USING ENDPOINT', self.endpoint)

        if not self.secure:
            return self._insecure_channel(self.endpoint)
@@ -637,6 +639,13 @@ def __init__(
            if isinstance(connection, ChannelBuilder)
            else DefaultChannelBuilder(connection, channel_options)
        )
+
+        if "SPARK_LOCAL_CONNECT_TOKEN" in os.environ:
+            self._builder.set(
+                ChannelBuilder.PARAM_TOKEN,
+                os.environ["SPARK_LOCAL_CONNECT_TOKEN"]
+            )
+
        self._user_id = None
        self._retry_policies: List[RetryPolicy] = []
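Note on the client-side change above: when SPARK_LOCAL_CONNECT_TOKEN is present, the token is installed on the channel builder via ChannelBuilder.PARAM_TOKEN and sent as a bearer token with every RPC. The same mechanism is already exposed through the connection string's token parameter, so a minimal sketch of supplying a token explicitly could look like the following. The host, port, and secret are illustrative assumptions, and note that the stock Python client treats an explicit token as implying a TLS channel, so this form targets a TLS-fronted server.

```python
# Minimal sketch (assumed values): pass the bearer token through the standard
# "token" parameter of a Spark Connect connection string, which is what
# ChannelBuilder.PARAM_TOKEN controls.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # The client sends this as "Authorization: Bearer my-secret" on each call.
    .remote("sc://localhost:15002/;token=my-secret")
    .getOrCreate()
)
spark.range(3).show()
```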
14 changes: 9 additions & 5 deletions python/pyspark/sql/connect/session.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import uuid
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)
@@ -277,6 +278,7 @@ def __init__(self, connection: Union[str, DefaultChannelBuilder], userId: Option
        the $USER environment. Defining the user ID as part of the connection string
        takes precedence.
        """
+        print("Creating connection", connection)
        self._client = SparkConnectClient(connection=connection, user_id=userId)
        self._session_id = self._client._session_id

@@ -1041,11 +1043,18 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
        init_opts.update(opts)
        opts = init_opts

+        token = str(uuid.uuid4())
+
        # Configurations to be overwritten
        overwrite_conf = opts
        overwrite_conf["spark.master"] = master
        overwrite_conf["spark.local.connect"] = "1"
+        # When running a local server, always use an ephemeral port
+        overwrite_conf["spark.connect.grpc.binding.port"] = "0"
+        overwrite_conf["spark.connect.authenticate.secret"] = token
        os.environ["SPARK_LOCAL_CONNECT"] = "1"
+        os.environ["SPARK_LOCAL_CONNECT_TOKEN"] = token
+        print("CONNECTING LOCALLY")

        # Configurations to be set if unset.
        default_conf = {
@@ -1054,11 +1063,6 @@
"spark.sql.artifact.isolation.alwaysApplyClassloader": "true",
}

if "SPARK_TESTING" in os.environ:
# For testing, we use 0 to use an ephemeral port to allow parallel testing.
# See also SPARK-42272.
overwrite_conf["spark.connect.grpc.binding.port"] = "0"

origin_remote = os.environ.get("SPARK_REMOTE", None)
try:
if origin_remote is not None:
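Taken together, the server-start path now works like this: generate a uuid4 secret, force it into the local server's configuration as spark.connect.authenticate.secret, and export SPARK_LOCAL_CONNECT_TOKEN so the in-process client (see the core.py change above) authenticates automatically. Nothing changes in the user-facing API; a hedged sketch of the end-to-end flow:

```python
# Sketch of the intended end-to-end flow (no new user-facing API): starting a
# local Spark Connect session now auto-generates a per-session shared secret.
from pyspark.sql import SparkSession

# Under the hood, _start_connect_server() creates a uuid4 token, starts the
# local server with spark.connect.authenticate.secret=<token> on an ephemeral
# port, and exports SPARK_LOCAL_CONNECT_TOKEN for the client in this process.
spark = SparkSession.builder.remote("local[*]").getOrCreate()
spark.range(5).count()
```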
9 changes: 9 additions & 0 deletions sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -313,4 +313,13 @@ object Connect {
      .internal()
      .booleanConf
      .createWithDefault(true)
+
+  val CONNECT_AUTHENTICATE_SECRET =
+    buildConf("spark.connect.authenticate.secret")
+      .doc("A pre-shared key that will be used to authenticate clients. This secret must be" +
+        " passed as a bearer token for clients to connect.")
+      .version("4.0.0")
+      .internal()
+      .stringConf
+      .createOptional
}
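The secret is optional: left unset, the server keeps its existing unauthenticated behavior; once set, every RPC must carry a matching bearer token. Since the check is a plain header comparison, the wire contract can be illustrated with a bare gRPC channel (the port and secret below are assumptions for illustration):

```python
# Illustrative sketch of the wire contract: any gRPC call to a server started
# with spark.connect.authenticate.secret must include this metadata header.
# gRPC metadata keys are case-insensitive, so the lowercase form matches the
# server's "Authorization" lookup.
import grpc

channel = grpc.insecure_channel("localhost:15002")  # assumed local server port
auth_metadata = [("authorization", "Bearer my-secret")]
# Pass auth_metadata on every stub call, e.g.:
#   stub.AnalyzePlan(request, metadata=auth_metadata)
# A missing header yields UNAUTHENTICATED ("No auth token provided"); a
# mismatched one yields UNAUTHENTICATED ("Invalid auth token").
```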
45 changes: 45 additions & 0 deletions sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor, Status}


class PreSharedKeyAuthenticationInterceptor(token: String) extends ServerInterceptor {

  override def interceptCall[ReqT, RespT](
      call: ServerCall[ReqT, RespT],
      metadata: Metadata,
      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
    val authHeaderValue =
      metadata.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER))

    if (authHeaderValue == null) {
      val status = Status.UNAUTHENTICATED.withDescription("No auth token provided")
      call.close(status, new Metadata())
    } else if (authHeaderValue != s"Bearer $token") {
      val status = Status.UNAUTHENTICATED.withDescription("Invalid auth token")
      call.close(status, new Metadata())
    } else {
      return next.startCall(call, metadata)
    }
    // Return a no-op listener for calls that were closed above
    new ServerCall.Listener[ReqT]() {}
  }
}
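The interceptor closes a call with UNAUTHENTICATED before it ever reaches the Spark Connect service unless the Authorization header equals exactly "Bearer <token>", and hands back a no-op listener for rejected calls. For comparison, an equivalent server-side interceptor in Python gRPC terms (an illustrative analogue, not part of this commit) could read:

```python
# Illustrative Python analogue of PreSharedKeyAuthenticationInterceptor:
# reject unauthenticated calls with UNAUTHENTICATED before dispatching them.
import grpc

class PreSharedKeyInterceptor(grpc.ServerInterceptor):
    def __init__(self, token: str):
        self._expected = f"Bearer {token}"

        def deny(request, context):
            # Mirrors call.close(Status.UNAUTHENTICATED, ...) in the Scala code.
            context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid auth token")

        self._deny_handler = grpc.unary_unary_rpc_method_handler(deny)

    def intercept_service(self, continuation, handler_call_details):
        headers = dict(handler_call_details.invocation_metadata)
        if headers.get("authorization") == self._expected:
            return continuation(handler_call_details)  # authenticated: proceed
        return self._deny_handler
```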
6 changes: 5 additions & 1 deletion sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -39,7 +39,7 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.HOST
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
+import org.apache.spark.sql.connect.config.Connect.{CONNECT_AUTHENTICATE_SECRET, CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
import org.apache.spark.sql.connect.utils.ErrorUtils
@@ -381,6 +381,10 @@ object SparkConnectService extends Logging {
    sb.maxInboundMessageSize(SparkEnv.get.conf.get(CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE).toInt)
      .addService(sparkConnectService)

+    SparkEnv.get.conf.get(CONNECT_AUTHENTICATE_SECRET).foreach { token =>
+      sb.intercept(new PreSharedKeyAuthenticationInterceptor(token))
+    }
+
    // Add all registered interceptors to the server builder.
    SparkConnectInterceptorRegistry.chainInterceptors(sb, configuredInterceptors)

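Because CONNECT_AUTHENTICATE_SECRET is created with createOptional, the authentication interceptor is attached only when a secret is actually configured, in addition to any registry-provided interceptors. The same conditional wiring, sketched with the Python gRPC server API and the analogue interceptor from the previous file:

```python
# Illustrative sketch of the wiring above in Python gRPC terms, reusing the
# PreSharedKeyInterceptor sketch earlier: attach the auth interceptor only
# when a secret is configured, mirroring Option.foreach.
from concurrent import futures

import grpc

secret = "my-secret"  # stand-in for conf.get(CONNECT_AUTHENTICATE_SECRET); may be None
interceptors = [PreSharedKeyInterceptor(secret)] if secret is not None else []
server = grpc.server(
    futures.ThreadPoolExecutor(max_workers=4),
    interceptors=interceptors,
)
server.add_insecure_port("[::]:15002")  # assumed port
```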
