Make Literals foldable, ensure Parquet predicates pushdown #721

Merged Jun 10, 2023 (27 commits)

Commits
3bbdb9c  #343 - unpack to Literals (chris-twiner, Jun 5, 2023)
3df02ec  #343 - unpack to Literals - more test (chris-twiner, Jun 5, 2023)
c8ecea8  #343 - unpack to Literals - comment (chris-twiner, Jun 5, 2023)
b7c3132  #343 - per review - docs missing (chris-twiner, Jun 5, 2023)
81d9315  #343 - per review - docs missing - fix reflection for all versions (chris-twiner, Jun 5, 2023)
a3567c2  #343 - add struct test showing difference between extension and exper… (chris-twiner, Jun 6, 2023)
bee3cd0  #343 - toString test to stop the patch complaint (chris-twiner, Jun 6, 2023)
bba92cb  #343 - sample docs (chris-twiner, Jun 6, 2023)
28bde88  #343 - package rename and adding logging that the extension is injected (chris-twiner, Jun 6, 2023)
f4e99b5  #343 - doc fixes (chris-twiner, Jun 6, 2023)
0cbe684  #343 - doc fixes (chris-twiner, Jun 6, 2023)
c308241  #343 - can't run that code (chris-twiner, Jun 6, 2023)
381931c  #343 - didn't stop the new sparkSession (chris-twiner, Jun 6, 2023)
3df725f  Apply suggestions from code review (chris-twiner, Jun 6, 2023)
23c3eb7  #343 - more z's, debug removal, comment adjust and brackets around ex… (chris-twiner, Jun 6, 2023)
2a83510  Refactor LitRule and LitRules tests by making them slightly more gene… (pomadchin, Jun 7, 2023)
e7ba599  Fix mdoc compilation (pomadchin, Jun 7, 2023)
e9999c1  #343 - added the struct test back (chris-twiner, Jun 7, 2023)
4e7bee3  #343 - disable the rule, foldable and eval evals (chris-twiner, Jun 7, 2023)
27e7c25  #343 - cleaned up (chris-twiner, Jun 7, 2023)
18f2bc6  More code cleanup (pomadchin, Jun 7, 2023)
82bf013  #343 - true with link for 3.2 support (chris-twiner, Jun 7, 2023)
c6bbe2c  #343 - bring back code gen with lazy to stop recompiles (chris-twiner, Jun 7, 2023)
31a023f  #343 - disable tests on 3.2, document why and renable the proper fold… (chris-twiner, Jun 8, 2023)
d7db649  #343 - more compat and a foldable only backport of SPARK-39106 and SP… (chris-twiner, Jun 8, 2023)
0e6c561  #343 - option 3 - let 3.2 fail as per oss impl, seperate tests (chris-twiner, Jun 8, 2023)
411871b  #343 - option 3 - better dir names (chris-twiner, Jun 8, 2023)
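
For context (not part of the diff), a minimal sketch of the kind of query this PR targets, using only API that also appears in the tests below (TypedDataset.create/createUnsafe, write.parquet, filter, collect); the Person case class, the path and the `session` value are illustrative and assume a frameless-compatible implicit SparkSession in scope. With Lit now reporting foldable, the literal in the filter can be constant-folded and the predicate pushed down to the Parquet scan.

import frameless.TypedDataset

case class Person(id: Long, name: String)

// assumes an implicit SparkSession is available (as TypedDatasetSuite provides via `session`)
val people = TypedDataset.create(Seq(Person(1L, "a"), Person(42L, "b")))
people.write.mode("overwrite").parquet("/tmp/frameless-people")

val fromParquet = TypedDataset.createUnsafe[Person](session.read.parquet("/tmp/frameless-people"))
val filtered = fromParquet.filter(fromParquet('id) >= 42L)  // the 42L literal should now fold and push down
filtered.collect().run()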
8 changes: 7 additions & 1 deletion build.sbt
@@ -10,6 +10,7 @@ val shapeless = "2.3.10"
val scalacheck = "1.17.0"
val scalacheckEffect = "1.0.4"
val refinedVersion = "0.10.3"
val nakedFSVersion = "0.1.0"

val Scala212 = "2.12.17"
val Scala213 = "2.13.10"
@@ -66,6 +67,7 @@ lazy val `cats-spark32` = project
lazy val dataset = project
.settings(name := "frameless-dataset")
.settings(Compile / unmanagedSourceDirectories += baseDirectory.value / "src" / "main" / "spark-3.4+")
.settings(Test / unmanagedSourceDirectories += baseDirectory.value / "src" / "test" / "spark-3.3+")
.settings(datasetSettings)
.settings(sparkDependencies(sparkVersion))
.dependsOn(core % "test->test;compile->compile")
@@ -74,6 +76,7 @@ lazy val `dataset-spark33` = project
.settings(name := "frameless-dataset-spark33")
.settings(sourceDirectory := (dataset / sourceDirectory).value)
.settings(Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-3")
.settings(Test / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "test" / "spark-3.3+")
.settings(datasetSettings)
.settings(sparkDependencies(spark33Version))
.settings(spark33Settings)
@@ -83,6 +86,7 @@ lazy val `dataset-spark32` = project
.settings(name := "frameless-dataset-spark32")
.settings(sourceDirectory := (dataset / sourceDirectory).value)
.settings(Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-3")
.settings(Test / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "test" / "spark-3.2")
.settings(datasetSettings)
.settings(sparkDependencies(spark32Version))
.settings(spark32Settings)
@@ -192,7 +196,9 @@ lazy val datasetSettings = framelessSettings ++ framelessTypedDatasetREPL ++ Seq
dmm("org.apache.spark.sql.FramelessInternals.column")
)
},
coverageExcludedPackages := "org.apache.spark.sql.reflection"
coverageExcludedPackages := "org.apache.spark.sql.reflection",

libraryDependencies += "com.globalmentor" % "hadoop-bare-naked-local-fs" % nakedFSVersion % Test exclude("org.apache.hadoop", "hadoop-commons")
)

lazy val refinedSettings = framelessSettings ++ framelessTypedDatasetREPL ++ Seq(
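
In short, the version-specific source layout wired up above (directory names taken from the settings) is:

dataset/src/main/spark-3.4+   main sources only for the default Spark 3.4 build
dataset/src/main/spark-3      main sources shared by the spark33/spark32 builds
dataset/src/test/spark-3.3+   tests compiled for the 3.4 and 3.3 builds
dataset/src/test/spark-3.2    tests compiled only for the 3.2 build
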
18 changes: 11 additions & 7 deletions dataset/src/main/scala/frameless/functions/Lit.scala
@@ -8,16 +8,17 @@ import org.apache.spark.sql.types.DataType
private[frameless] case class Lit[T <: AnyVal](
dataType: DataType,
nullable: Boolean,
toCatalyst: CodegenContext => ExprCode,
show: () => String
show: () => String,
catalystExpr: Expression // must be a generated Expression from a literal TypedEncoder's toCatalyst function
) extends Expression with NonSQLExpression {
override def toString: String = s"FramelessLit(${show()})"

def eval(input: InternalRow): Any = {
lazy val codegen = {
val ctx = new CodegenContext()
val eval = genCode(ctx)

val codeBody = s"""
val codeBody =
s"""
public scala.Function1<InternalRow, Object> generate(Object[] references) {
return new LiteralEvalImpl(references);
}
@@ -47,13 +48,16 @@ private[frameless] case class Lit[T <: AnyVal](
val (clazz, _) = CodeGenerator.compile(code)
val codegen =
clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef]

codegen(input)
codegen
}

def eval(input: InternalRow): Any = codegen(input)

def children: Seq[Expression] = Nil

protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = toCatalyst(ctx)
protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = catalystExpr.genCode(ctx)

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = this

override val foldable: Boolean = catalystExpr.foldable
}
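
Broadly speaking, Catalyst's ConstantFolding rule replaces any sub-tree whose foldable flag is true with a plain Literal, and only plain Literals end up translated into data source Filters for pushdown; delegating foldable (and eval, via the compiled codegen above) to the wrapped catalyst expression is what lets a frameless literal collapse before the scan is planned. A rough sketch of that folding step with plain Catalyst literals (not part of the PR):

import org.apache.spark.sql.catalyst.expressions.{Add, Literal}

val expr = Add(Literal(40), Literal(2))
// both children are literals, so the whole tree is foldable
assert(expr.foldable)
// ConstantFolding effectively evaluates the tree once and keeps only a Literal,
// which the Parquet source can then push down when it appears in a predicate
val folded = Literal.create(expr.eval(null), expr.dataType)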
12 changes: 6 additions & 6 deletions dataset/src/main/scala/frameless/functions/package.scala
@@ -45,8 +45,8 @@ package object functions extends Udf with UnaryFunctions {
Lit(
dataType = encoder.catalystRepr,
nullable = encoder.nullable,
toCatalyst = encoder.toCatalyst(expr).genCode(_),
show = () => value.toString
show = () => value.toString,
catalystExpr = encoder.toCatalyst(expr)
)
)
}
@@ -84,8 +84,8 @@ package object functions extends Udf with UnaryFunctions {
Lit(
dataType = i7.catalystRepr,
nullable = i7.nullable,
toCatalyst = i7.toCatalyst(expr).genCode(_),
show = () => value.toString
show = () => value.toString,
i7.toCatalyst(expr)
)
)
}
@@ -127,8 +127,8 @@ package object functions extends Udf with UnaryFunctions {
Lit(
dataType = i7.catalystRepr,
nullable = true,
toCatalyst = i7.toCatalyst(expr).genCode(_),
show = () => value.toString
show = () => value.toString,
i7.toCatalyst(expr)
)
)
}
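
As a small sketch (assumed usage, not from the diff) of what these call sites now hand to Lit: the encoder's toCatalyst rewrites an expression in the encoder's JVM representation into the catalyst form of the value, and Lit delegates foldable, eval and codegen to that one expression. For a primitive encoder the rewrite is essentially a pass-through, so the result stays foldable.

import frameless.TypedEncoder
import org.apache.spark.sql.catalyst.expressions.Literal

val encoder = TypedEncoder[Int]
// roughly the shape of expression the lit/litValue helpers build before wrapping it in Lit
val catalystExpr = encoder.toCatalyst(Literal(42, encoder.jvmRepr))
assert(catalystExpr.foldable)  // which is what Lit.foldable now reports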
7 changes: 6 additions & 1 deletion dataset/src/test/scala/frameless/LitTests.scala
@@ -79,7 +79,12 @@ class LitTests extends TypedDatasetSuite with Matchers {

val someIpsum: Option[Name] = Some(new Name("Ipsum"))

ds.withColumnReplaced('alias, functions.litValue(someIpsum)).
val lit = functions.litValue(someIpsum)
val tds = ds.withColumnReplaced('alias, functions.litValue(someIpsum))

tds.queryExecution.toString() should include (lit.toString)

tds.
collect.run() shouldBe initial.map(_.copy(alias = someIpsum))

ds.withColumnReplaced('alias, functions.litValue(Option.empty[Name])).
22 changes: 21 additions & 1 deletion dataset/src/test/scala/frameless/TypedDatasetSuite.scala
@@ -1,20 +1,34 @@
package frameless

import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem
import org.apache.hadoop.fs.local.StreamingFS
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.scalactic.anyvals.PosZInt
import org.scalatest.BeforeAndAfterAll
import org.scalatestplus.scalacheck.Checkers
import org.scalacheck.Prop
import org.scalacheck.Prop._

import scala.util.{Properties, Try}
import org.scalatest.funsuite.AnyFunSuite

trait SparkTesting { self: BeforeAndAfterAll =>

val appID: String = new java.util.Date().toString + math.floor(math.random * 10E4).toLong.toString

val conf: SparkConf = new SparkConf()
/**
* Allows the bare-naked local filesystem to be used instead of winutils for testing / dev on Windows
*/
def registerFS(sparkConf: SparkConf): SparkConf = {
if (System.getProperty("os.name").startsWith("Windows"))
sparkConf.set("spark.hadoop.fs.file.impl", classOf[BareLocalFileSystem].getName).
set("spark.hadoop.fs.AbstractFileSystem.file.impl", classOf[StreamingFS].getName)
else
sparkConf
}

val conf: SparkConf = registerFS(new SparkConf())
.setMaster("local[*]")
.setAppName("test")
.set("spark.ui.enabled", "false")
@@ -26,9 +40,15 @@ trait SparkTesting { self: BeforeAndAfterAll =>
implicit def sc: SparkContext = session.sparkContext
implicit def sqlContext: SQLContext = session.sqlContext

def registerOptimizations(sqlContext: SQLContext): Unit = { }

def addSparkConfigProperties(config: SparkConf): Unit = { }

override def beforeAll(): Unit = {
assert(s == null)
addSparkConfigProperties(conf)
s = SparkSession.builder().config(conf).getOrCreate()
registerOptimizations(sqlContext)
}

override def afterAll(): Unit = {
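
The two new hooks let a suite customise the shared session: addSparkConfigProperties runs before the session is built and registerOptimizations runs just after. A hypothetical subclass (the rule and config key below are illustrative only, not from this PR) might look like:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// a do-nothing rule, purely to show the wiring
object NoOpRule extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan
}

class NoOpRuleSuite extends TypedDatasetSuite {
  // called once from beforeAll, after the shared session exists
  override def registerOptimizations(sqlContext: SQLContext): Unit =
    sqlContext.experimental.extraOptimizations ++= Seq(NoOpRule)

  // called once from beforeAll, before the shared session is created
  override def addSparkConfigProperties(config: SparkConf): Unit =
    config.set("spark.sql.codegen.wholeStage", "false")
}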
20 changes: 20 additions & 0 deletions dataset/src/test/scala/frameless/sql/package.scala
@@ -0,0 +1,20 @@
package frameless

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{And, Or}

package object sql {
implicit class ExpressionOps(val self: Expression) extends AnyVal {
def toList: List[Expression] = {
def rec(expr: Expression, acc: List[Expression]): List[Expression] = {
expr match {
case And(left, right) => rec(left, rec(right, acc))
case Or(left, right) => rec(left, rec(right, acc))
case e => e +: acc
}
}

rec(self, Nil)
}
}
}
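
A quick illustration of the helper (plain Catalyst literals stand in for real predicates): nested And/Or trees are flattened into their leaf expressions, left to right.

import frameless.sql._
import org.apache.spark.sql.catalyst.expressions.{And, Literal, Or}

val a = Literal(true)
val b = Literal(false)
val c = Literal(true)

// flattens to the leaves in order: List(a, b, c)
And(a, Or(b, c)).toList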
74 changes: 74 additions & 0 deletions dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala
@@ -0,0 +1,74 @@
package frameless.sql.rules

import frameless._
import frameless.sql._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.scalatest.Assertion
import org.scalatest.matchers.should.Matchers

trait SQLRulesSuite extends TypedDatasetSuite with Matchers { self =>
protected lazy val path: String = {
val tmpDir = System.getProperty("java.io.tmpdir")
s"$tmpDir/${self.getClass.getName}"
}

def withDataset[A: TypedEncoder: CatalystOrdered](payload: A)(f: TypedDataset[A] => Assertion): Assertion = {
TypedDataset.create(Seq(payload)).write.mode("overwrite").parquet(path)
f(TypedDataset.createUnsafe[A](session.read.parquet(path)))
}

def predicatePushDownTest[A: TypedEncoder: CatalystOrdered](
expected: X1[A],
expectedPushDownFilters: List[Filter],
planShouldNotContain: PartialFunction[Expression, Expression],
op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean]
): Assertion = {
withDataset(expected) { dataset =>
val ds = dataset.filter(op(dataset('a)))
val actualPushDownFilters = pushDownFilters(ds)

val optimizedPlan = ds.queryExecution.optimizedPlan.collect { case logical.Filter(condition, _) => condition }.flatMap(_.toList)

// check the optimized plan
optimizedPlan.collectFirst(planShouldNotContain) should be (empty)

// compare filters
actualPushDownFilters shouldBe expectedPushDownFilters

val actual = ds.collect().run().toVector.headOption

// ensure serialization is not broken
actual should be(Some(expected))
}
}

protected def pushDownFilters[T](ds: TypedDataset[T]): List[Filter] = {
val sparkPlan = ds.queryExecution.executedPlan

val initialPlan =
if (sparkPlan.children.isEmpty) // assume it's AQE
sparkPlan match {
case aq: AdaptiveSparkPlanExec => aq.initialPlan
case _ => sparkPlan
}
else
sparkPlan

initialPlan.collect {
case fs: FileSourceScanExec =>
import scala.reflect.runtime.{universe => ru}

val runtimeMirror = ru.runtimeMirror(getClass.getClassLoader)
val instanceMirror = runtimeMirror.reflect(fs)
val getter = ru.typeOf[FileSourceScanExec].member(ru.TermName("pushedDownFilters")).asTerm.getter
val m = instanceMirror.reflectMethod(getter.asMethod)
val res = m.apply(fs).asInstanceOf[Seq[Filter]]

res
}.flatten.toList
}
}
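
A hypothetical usage sketch (not in this PR): on Spark 3.3+/3.4, where folding works, a comparison against a frameless literal is expected to surface as a pushed-down source Filter rather than a Lit left in the optimized plan. The exact filter list below is an assumption and depends on the Spark version.

package frameless.sql.rules

import frameless._
import frameless.functions.Lit
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.sources.{GreaterThanOrEqual, IsNotNull}

class ExamplePushDownTests extends SQLRulesSuite {
  test("long push-down") {
    predicatePushDownTest[Long](
      expected = X1(42L),
      expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", 42L)),
      planShouldNotContain = { case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
      op = _ >= 42L
    )
  }
}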
@@ -0,0 +1,7 @@
package org.apache.hadoop.fs.local

import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem
import org.apache.hadoop.fs.DelegateToFileSystem

class StreamingFS(uri: java.net.URI, conf: org.apache.hadoop.conf.Configuration) extends
DelegateToFileSystem(uri, new BareLocalFileSystem(), conf, "file", false) {}
@@ -0,0 +1,85 @@
package frameless.sql.rules

import frameless._
import frameless.sql._
import frameless.functions.Lit
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{currentTimestamp, microsToInstant}
import org.apache.spark.sql.sources.{Filter, IsNotNull}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericRowWithSchema}
import java.time.Instant

import org.apache.spark.sql.catalyst.plans.logical
import org.scalatest.Assertion

// Note: since InvokeLike and "ConditionalExpression" don't have SPARK-40380 and SPARK-39106 in 3.2.4, no predicate pushdowns can happen on that version
class FramelessLitPushDownTests extends SQLRulesSuite {
private val now: Long = currentTimestamp()

test("java.sql.Timestamp push-down") {
val expected = java.sql.Timestamp.from(microsToInstant(now))
val expectedStructure = X1(SQLTimestamp(now))
val expectedPushDownFilters = List(IsNotNull("a"))

predicatePushDownTest[SQLTimestamp](
expectedStructure,
expectedPushDownFilters,
{ case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
_ >= expectedStructure.a
)
}

test("java.time.Instant push-down") {
val expected = java.sql.Timestamp.from(microsToInstant(now))
val expectedStructure = X1(microsToInstant(now))
val expectedPushDownFilters = List(IsNotNull("a"))

predicatePushDownTest[Instant](
expectedStructure,
expectedPushDownFilters,
{ case e @ expressions.GreaterThanOrEqual(_, _: Lit[_]) => e },
_ >= expectedStructure.a
)
}

test("struct push-down") {
type Payload = X4[Int, Int, Int, Int]
val expectedStructure = X1(X4(1, 2, 3, 4))
val expected = new GenericRowWithSchema(Array(1, 2, 3, 4), TypedExpressionEncoder[Payload].schema)
val expectedPushDownFilters = List(IsNotNull("a"))

predicatePushDownTest[Payload](
expectedStructure,
expectedPushDownFilters,
// Cast not Lit because of SPARK-40380
{ case e @ expressions.EqualTo(_, _: Cast) => e },
_ === expectedStructure.a
)
}

override def predicatePushDownTest[A: TypedEncoder: CatalystOrdered](
expected: X1[A],
expectedPushDownFilters: List[Filter],
planShouldContain: PartialFunction[Expression, Expression],
op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean]
): Assertion = {
withDataset(expected) { dataset =>
val ds = dataset.filter(op(dataset('a)))
val actualPushDownFilters = pushDownFilters(ds)

val optimizedPlan = ds.queryExecution.optimizedPlan.collect { case logical.Filter(condition, _) => condition }.flatMap(_.toList)

// check the optimized plan
optimizedPlan.collectFirst(planShouldContain) should not be (empty)

// compare filters
actualPushDownFilters shouldBe expectedPushDownFilters

val actual = ds.collect().run().toVector.headOption

// ensure serialization is not broken
actual should be(Some(expected))
}
}

}