[SPARK-48974][SQL][SS][ML][MLLIB] Use `SparkSession.implicits` instead of `SQLContext.implicits`

### What changes were proposed in this pull request?
This PR replaces `SQLContext.implicits` with `SparkSession.implicits` throughout the Spark codebase.

### Why are the changes needed?
Reduces the use of legacy `SQLContext` APIs within Spark's internal code in favor of the `SparkSession` entry point.
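
For illustration, a minimal self-contained sketch of the before/after pattern (the `ImplicitsDemo` object and `local[*]` session setup are illustrative, not taken from this diff):

```scala
import org.apache.spark.sql.SparkSession

object ImplicitsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("implicits-demo")
      .getOrCreate()

    // Before this PR (legacy style): the same implicits reached via the SQLContext handle.
    // import spark.sqlContext.implicits._

    // After this PR: encoders and the $"col" interpolator come directly from SparkSession.
    import spark.implicits._

    // Call sites compile unchanged either way: both imports expose identical implicits.
    val df = Seq(1, 2, 3).toDF("value")
    df.filter($"value" > 1).show()

    spark.stop()
  }
}
```

For Dataset-scoped imports the change reads the same way: `ds.sqlContext.implicits._` becomes `ds.sparkSession.implicits._`, as in the hunks below.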

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Pass GitHub Actions

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47457 from LuciferYang/use-sparksession-implicits.

Lead-authored-by: yangjie01 <yangjie01@baidu.com>
Co-authored-by: YangJie <yangjie01@baidu.com>
Signed-off-by: yangjie01 <yangjie01@baidu.com>
LuciferYang committed Jul 24, 2024
1 parent fdcf975 commit 877c3f2
Showing 16 changed files with 28 additions and 28 deletions.
@@ -119,7 +119,7 @@ object MLUtils extends Logging {
       ).resolveRelation(checkFilesExist = false))
       .select("value")

-    import lines.sqlContext.implicits._
+    import lines.sparkSession.implicits._

     lines.select(trim($"value").as("line"))
       .filter(not((length($"line") === 0).or($"line".startsWith("#"))))

@@ -52,8 +52,8 @@ class FMClassifierSuite extends MLTest with DefaultReadWriteTest {
   }

   test("FMClassifier: Predictor, Classifier methods") {
-    val sqlContext = smallBinaryDataset.sqlContext
-    import sqlContext.implicits._
+    val session = smallBinaryDataset.sparkSession
+    import session.implicits._
     val fm = new FMClassifier()

     val model = fm.fit(smallBinaryDataset)

@@ -550,8 +550,8 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
   }

   test("multinomial logistic regression: Predictor, Classifier methods") {
-    val sqlContext = smallMultinomialDataset.sqlContext
-    import sqlContext.implicits._
+    val session = smallMultinomialDataset.sparkSession
+    import session.implicits._
     val mlr = new LogisticRegression().setFamily("multinomial")

     val model = mlr.fit(smallMultinomialDataset)
@@ -590,8 +590,8 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
   }

   test("binary logistic regression: Predictor, Classifier methods") {
-    val sqlContext = smallBinaryDataset.sqlContext
-    import sqlContext.implicits._
+    val session = smallBinaryDataset.sparkSession
+    import session.implicits._
     val lr = new LogisticRegression().setFamily("binomial")

     val model = lr.fit(smallBinaryDataset)
@@ -1427,8 +1427,8 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
     val trainer2 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight")
       .setElasticNetParam(1.0).setRegParam(6.0).setStandardization(false)

-    val sqlContext = multinomialDataset.sqlContext
-    import sqlContext.implicits._
+    val session = multinomialDataset.sparkSession
+    import session.implicits._
     val model1 = trainer1.fit(multinomialDataset)
     val model2 = trainer2.fit(multinomialDataset)

@@ -29,8 +29,8 @@ class CollectTopKSuite extends MLTest {

   override def beforeAll(): Unit = {
     super.beforeAll()
-    val sqlContext = spark.sqlContext
-    import sqlContext.implicits._
+    val session = spark
+    import session.implicits._
     dataFrame = Seq(
       (0, 3, 54f),
       (0, 4, 44f),

@@ -962,8 +962,8 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
   }

   test("linear regression with weighted samples") {
-    val sqlContext = spark.sqlContext
-    import sqlContext.implicits._
+    val session = spark
+    import session.implicits._
     val numClasses = 0
     def modelEquals(m1: LinearRegressionModel, m2: LinearRegressionModel): Unit = {
       assert(m1.coefficients ~== m2.coefficients relTol 0.01)

@@ -220,7 +220,7 @@ object MLTestingUtils extends SparkFunSuite {
       numClasses: Int,
       modelEquals: (M, M) => Unit,
       outlierRatio: Int): Unit = {
-    import data.sqlContext.implicits._
+    import data.sparkSession.implicits._
     val outlierDS = data.withColumn("weight", lit(1.0)).as[Instance].flatMap {
       case Instance(l, w, f) =>
         val outlierLabel = if (numClasses == 0) -l else numClasses - l - 1

@@ -32,7 +32,7 @@ object CSVUtils {
     // Note that this was separately made by SPARK-18362. Logically, this should be the same
     // with the one below, `filterCommentAndEmpty` but execution path is different. One of them
     // might have to be removed in the near future if possible.
-    import lines.sqlContext.implicits._
+    import lines.sparkSession.implicits._
     val aliased = lines.toDF("value")
     val nonEmptyLines = aliased.filter(length(trim($"value")) > 0)
     if (options.isCommentSet) {

@@ -183,7 +183,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
       .contains(MyQueryStagePrepRule()))
     assert(session.sessionState.columnarRules.contains(
       MyColumnarRule(MyNewQueryStageRule(), MyNewQueryStageRule())))
-    import session.sqlContext.implicits._
+    import session.implicits._
     val data = Seq((100L), (200L), (300L)).toDF("vals").repartition(1)
     val df = data.selectExpr("vals + 1")
     df.collect()
@@ -225,7 +225,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
     session.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
     assert(session.sessionState.columnarRules.contains(
       MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
-    import session.sqlContext.implicits._
+    import session.implicits._
     // perform a join to inject a shuffle exchange
     val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2")
     val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2")
@@ -283,7 +283,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
     session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE)
     assert(session.sessionState.columnarRules.contains(
       MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
-    import session.sqlContext.implicits._
+    import session.implicits._
     // perform a join to inject a broadcast exchange
     val left = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("l1", "l2")
     val right = Seq((1, 50L), (2, 100L), (3, 150L)).toDF("r1", "r2")
@@ -327,7 +327,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
     try {
       assert(session.sessionState.columnarRules.contains(
         MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
-      import session.sqlContext.implicits._
+      import session.implicits._

       val input = Seq((100L), (200L), (300L))
       val data = input.toDF("vals").repartition(1)

@@ -49,7 +49,7 @@ class BlockingSource extends StreamSourceProvider with StreamSinkProvider {
       override def schema: StructType = fakeSchema
       override def getOffset: Option[Offset] = Some(new LongOffset(0))
       override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
-        import spark.implicits._
+        import spark.sparkSession.implicits._
         Seq[Int]().toDS().toDF()
       }
       override def stop(): Unit = {}

@@ -54,7 +54,7 @@ class HiveContextCompatibilitySuite extends SparkFunSuite {

   test("basic operations") {
     val _hc = hc
-    import _hc.implicits._
+    import _hc.sparkSession.implicits._
     val df1 = (1 to 20).map { i => (i, i) }.toDF("a", "x")
     val df2 = (1 to 100).map { i => (i, i % 10, i % 2 == 0) }.toDF("a", "b", "c")
       .select($"a", $"b")
@@ -71,7 +71,7 @@ class HiveContextCompatibilitySuite extends SparkFunSuite {

   test("basic DDLs") {
     val _hc = hc
-    import _hc.implicits._
+    import _hc.sparkSession.implicits._
     val databases = hc.sql("SHOW DATABASES").collect().map(_.getString(0))
     assert(databases.toSeq == Seq("default"))
     hc.sql("CREATE DATABASE mee_db")

@@ -733,7 +733,7 @@ object SPARK_9757 extends QueryTest {

     val hiveContext = new TestHiveContext(sparkContext)
     spark = hiveContext.sparkSession
-    import hiveContext.implicits._
+    import hiveContext.sparkSession.implicits._

     val dir = Utils.createTempDir()
     dir.delete()

@@ -30,7 +30,7 @@ class ListTablesSuite extends QueryTest
   with TestHiveSingleton
   with BeforeAndAfterAll {
   import hiveContext._
-  import hiveContext.implicits._
+  import hiveContext.sparkSession.implicits._

   val df = sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value")

@@ -47,7 +47,7 @@ case class TestData(a: Int, b: String)
  */
 @SlowHiveTest
 class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAndAfter {
-  import org.apache.spark.sql.hive.test.TestHive.implicits._
+  import org.apache.spark.sql.hive.test.TestHive.sparkSession.implicits._

   private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled

@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution

 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, sql}
-import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.hive.test.TestHive.sparkSession.implicits._
 import org.apache.spark.tags.SlowHiveTest

 case class Nested(a: Int, B: Int)

@@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton}
 import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.hive.test.TestHive.sparkSession.implicits._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.tags.SlowHiveTest

@@ -22,7 +22,7 @@ import java.io.File

 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, Expression, HiveHash, Literal, Pmod}
 import org.apache.spark.sql.hive.HiveUtils
-import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.hive.test.TestHive.sparkSession.implicits._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION

