[SPARK-20331][SQL] Enhanced Hive partition pruning predicate pushdown #17633
@@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
import scala.util.control.NonFatal
import scala.util.Try

import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.conf.HiveConf
@@ -46,6 +47,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTableParti
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegralType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

/**
@@ -589,18 +591,67 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
        col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME))
      .map(col => col.getName).toSet

-    filters.collect {
-      case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) =>
-        s"${a.name} ${op.symbol} $v"
-      case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) =>
-        s"$v ${op.symbol} ${a.name}"
-      case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType))
-          if !varcharKeys.contains(a.name) =>
-        s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}"""
-      case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute)
-          if !varcharKeys.contains(a.name) =>
-        s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}"""
-    }.mkString(" and ")
+    object ExtractableLiteral {
+      def unapply(expr: Expression): Option[String] = expr match {
+        case Literal(value, _: IntegralType) => Some(value.toString)
+        case Literal(value, _: StringType) => Some(quoteStringLiteral(value.toString))
+        case _ => None
+      }
+    }
+
+    object ExtractableLiterals {
+      def unapply(exprs: Seq[Expression]): Option[Seq[String]] = {
+        exprs.map(ExtractableLiteral.unapply).foldLeft(Option(Seq.empty[String])) {
+          case (Some(accum), Some(value)) => Some(accum :+ value)
+          case _ => None
+        }
+      }
+    }
+
+    object ExtractableValues {
+      private lazy val valueToLiteralString: PartialFunction[Any, String] = {
+        case value: Byte => value.toString
+        case value: Short => value.toString
+        case value: Int => value.toString
+        case value: Long => value.toString
+        case value: UTF8String => quoteStringLiteral(value.toString)
+      }
+
+      def unapply(values: Set[Any]): Option[Seq[String]] = {
+        values.toSeq.foldLeft(Option(Seq.empty[String])) {
+          case (Some(accum), value) if valueToLiteralString.isDefinedAt(value) =>
+            Some(accum :+ valueToLiteralString(value))
+          case _ => None
+        }
+      }
+    }

Review comment (on the foldLeft above): ditto

+    def convertInToOr(a: Attribute, values: Seq[String]): String = {
+      values.map(value => s"${a.name} = $value").mkString("(", " or ", ")")
+    }
+
+    lazy val convert: PartialFunction[Expression, String] = {
+      case In(a: Attribute, ExtractableLiterals(values))
+          if !varcharKeys.contains(a.name) && values.nonEmpty =>
+        convertInToOr(a, values)
+      case InSet(a: Attribute, ExtractableValues(values))
+          if !varcharKeys.contains(a.name) && values.nonEmpty =>
+        convertInToOr(a, values)
+      case op @ BinaryComparison(a: Attribute, ExtractableLiteral(value))
+          if !varcharKeys.contains(a.name) =>
+        s"${a.name} ${op.symbol} $value"
+      case op @ BinaryComparison(ExtractableLiteral(value), a: Attribute)
+          if !varcharKeys.contains(a.name) =>
+        s"$value ${op.symbol} ${a.name}"

Review thread (anchored near the BinaryComparison cases):
- shall we add
- Is there a problem with leaving them out?
- nvm, realized that

+      case op @ And(expr1, expr2)
+          if convert.isDefinedAt(expr1) || convert.isDefinedAt(expr2) =>
+        (convert.lift(expr1) ++ convert.lift(expr2)).mkString("(", " and ", ")")
+      case op @ Or(expr1, expr2)
+          if convert.isDefinedAt(expr1) && convert.isDefinedAt(expr2) =>
+        s"(${convert(expr1)} or ${convert(expr2)})"
+    }
+
+    filters.map(convert.lift).collect { case Some(filterString) => filterString }.mkString(" and ")
  }

  private def quoteStringLiteral(str: String): String = {
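For a sense of what the new convert function produces, here is a small, self-contained sketch (my illustration, not code from the PR; only the helper's shape mirrors convertInToOr above) of how an IN predicate over a partition column becomes an OR chain in the Hive metastore filter string:

object InToOrSketch extends App {
  // Same shape as convertInToOr in the patch: join pre-rendered literal strings into
  // "(col = v1 or col = v2 ...)", which the metastore filter language accepts.
  def convertInToOr(column: String, values: Seq[String]): String =
    values.map(value => s"$column = $value").mkString("(", " or ", ")")

  // ds in (20170102, 20170103) becomes: (ds = 20170102 or ds = 20170103)
  println(convertInToOr("ds", Seq("20170102", "20170103")))

  // chunk in ('ab', 'ba') becomes: (chunk = "ab" or chunk = "ba")
  // String values are assumed to arrive already quoted, as quoteStringLiteral would render them.
  println(convertInToOr("chunk", Seq("\"ab\"", "\"ba\"")))
}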
@@ -19,21 +19,25 @@ package org.apache.spark.sql.hive.client

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.conf.HiveConf
+import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EmptyRow, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, InSet, LessThan, LessThanOrEqual, Like, Literal, Or}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.hive.HiveUtils
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{ByteType, IntegerType, StringType}

-class HiveClientSuite extends SparkFunSuite {
-  private val clientBuilder = new HiveClientBuilder
+// TODO: Refactor this to `HivePartitionFilteringSuite`
+class HiveClientSuite(version: String)
+    extends HiveVersionSuite(version) with BeforeAndAfterAll {
+  import CatalystSqlParser._

  private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname

-  test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
-    val testPartitionCount = 5
+  private val testPartitionCount = 3 * 24 * 4

+  private def init(tryDirectSql: Boolean): HiveClient = {
    val storageFormat = CatalogStorageFormat(
      locationUri = None,
      inputFormat = None,
@@ -43,19 +47,214 @@ class HiveClientSuite extends SparkFunSuite {
      properties = Map.empty)

    val hadoopConf = new Configuration()
-    hadoopConf.setBoolean(tryDirectSqlKey, false)
-    val client = clientBuilder.buildClient(HiveUtils.hiveExecutionVersion, hadoopConf)
-    client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (part INT)")
+    hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql)
+    val client = buildClient(hadoopConf)
+    client
+      .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)")

+    val partitions =
+      for {
+        ds <- 20170101 to 20170103
+        h <- 0 to 23
+        chunk <- Seq("aa", "ab", "ba", "bb")
+      } yield CatalogTablePartition(Map(
+        "ds" -> ds.toString,
+        "h" -> h.toString,
+        "chunk" -> chunk
+      ), storageFormat)
+    assert(partitions.size == testPartitionCount)

-    val partitions = (1 to testPartitionCount).map { part =>
-      CatalogTablePartition(Map("part" -> part.toString), storageFormat)
-    }
    client.createPartitions(
      "default", "test", partitions, ignoreIfExists = false)
+    client
+  }

+  override def beforeAll() {
+    client = init(true)
+  }

+  test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
+    val client = init(false)
    val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
-      Seq(EqualTo(AttributeReference("part", IntegerType)(), Literal(3))))
+      Seq(parseExpression("ds=20170101")))

    assert(filteredPartitions.size == testPartitionCount)
  }

test("getPartitionsByFilter: ds=20170101") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will these tests be executed on all supported hive versions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
No. That's something I'll look into. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They will now. See |
||
    testMetastorePartitionFiltering(
      "ds=20170101",
      20170101 to 20170101,
      0 to 23,
      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
  }

  test("getPartitionsByFilter: ds=(20170101 + 1) and h=0") {
    // Should return all partitions where h=0 because getPartitionsByFilter does not support
    // comparisons to non-literal values
    testMetastorePartitionFiltering(
      "ds=(20170101 + 1) and h=0",
      20170101 to 20170103,
      0 to 0,
      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
  }
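A brief aside on why the ds conjunct above is dropped (my illustration, not part of the PR): the pushdown only extracts plain Literal children, so an arithmetic right-hand side never matches. For example, using the same CatalystSqlParser these tests rely on:

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// "ds=(20170101 + 1)" parses to an EqualTo whose right child is Add(Literal(20170101), Literal(1)).
// The Add node is not a Literal, so ExtractableLiteral does not match it and no metastore filter
// is generated for ds; only the "h=0" conjunct is pushed down.
val expr = CatalystSqlParser.parseExpression("ds=(20170101 + 1) and h=0")
println(expr.treeString)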

  test("getPartitionsByFilter: chunk='aa'") {
    testMetastorePartitionFiltering(
      "chunk='aa'",
      20170101 to 20170103,
      0 to 23,
      "aa" :: Nil)
  }

  test("getPartitionsByFilter: 20170101=ds") {
    testMetastorePartitionFiltering(
      "20170101=ds",
      20170101 to 20170101,
      0 to 23,
      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
  }

  test("getPartitionsByFilter: ds=20170101 and h=10") {
    testMetastorePartitionFiltering(
      "ds=20170101 and h=10",
      20170101 to 20170101,
      10 to 10,
      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
  }

  test("getPartitionsByFilter: ds=20170101 or ds=20170102") {
    testMetastorePartitionFiltering(
      "ds=20170101 or ds=20170102",
      20170101 to 20170102,
      0 to 23,
      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
  }

  test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") {
    testMetastorePartitionFiltering(
      "ds in (20170102, 20170103)",
      20170102 to 20170103,
      0 to 23,
      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
  }

  test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") {
    testMetastorePartitionFiltering(
      "ds in (20170102, 20170103)",
      20170102 to 20170103,
      0 to 23,
      "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
        case expr @ In(v, list) if expr.inSetConvertible =>
          InSet(v, Set() ++ list.map(_.eval(EmptyRow)))
      })
  }

  test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") {
    testMetastorePartitionFiltering(
      "chunk in ('ab', 'ba')",
      20170101 to 20170103,
      0 to 23,
      "ab" :: "ba" :: Nil)
  }

  test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") {
    testMetastorePartitionFiltering(
      "chunk in ('ab', 'ba')",
      20170101 to 20170103,
      0 to 23,
      "ab" :: "ba" :: Nil, {
        case expr @ In(v, list) if expr.inSetConvertible =>
          InSet(v, Set() ++ list.map(_.eval(EmptyRow)))
      })
  }

  test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") {
    val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
    val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb"))
    testMetastorePartitionFiltering(
      "(ds=20170101 and h>=8) or (ds=20170102 and h<8)",
      day1 :: day2 :: Nil)
  }

  test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") {
    val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
    // Day 2 should include all hours because we can't build a filter for h<(7+1)
    val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb"))
    testMetastorePartitionFiltering(
      "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))",
      day1 :: day2 :: Nil)
  }

  test("getPartitionsByFilter: " +
      "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") {
    val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba"))
    val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba"))
    testMetastorePartitionFiltering(
      "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))",
      day1 :: day2 :: Nil)
  }

  private def testMetastorePartitionFiltering(
      filterString: String,
      expectedDs: Seq[Int],
      expectedH: Seq[Int],
      expectedChunks: Seq[String]): Unit = {
    testMetastorePartitionFiltering(
      filterString,
      (expectedDs, expectedH, expectedChunks) :: Nil,
      identity)
  }

  private def testMetastorePartitionFiltering(
      filterString: String,
      expectedDs: Seq[Int],
      expectedH: Seq[Int],
      expectedChunks: Seq[String],
      transform: Expression => Expression): Unit = {
    testMetastorePartitionFiltering(
      filterString,
      (expectedDs, expectedH, expectedChunks) :: Nil,
      identity)
  }

  private def testMetastorePartitionFiltering(
      filterString: String,
      expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = {
    testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity)
  }

  private def testMetastorePartitionFiltering(
      filterString: String,
      expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])],
      transform: Expression => Expression): Unit = {
    val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
      Seq(
        transform(parseExpression(filterString))
      ))

    val expectedPartitionCount = expectedPartitionCubes.map {
      case (expectedDs, expectedH, expectedChunks) =>
        expectedDs.size * expectedH.size * expectedChunks.size
    }.sum

    val expectedPartitions = expectedPartitionCubes.map {
      case (expectedDs, expectedH, expectedChunks) =>
        for {
          ds <- expectedDs
          h <- expectedH
          chunk <- expectedChunks
        } yield Set(
          "ds" -> ds.toString,
          "h" -> h.toString,
          "chunk" -> chunk
        )
    }.reduce(_ ++ _)

    val actualFilteredPartitionCount = filteredPartitions.size

    assert(actualFilteredPartitionCount == expectedPartitionCount,
      s"Expected $expectedPartitionCount partitions but got $actualFilteredPartitionCount")
    assert(filteredPartitions.map(_.spec.toSet).toSet == expectedPartitions.toSet)
  }
}
@@ -0,0 +1,29 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.hive.client

import scala.collection.immutable.IndexedSeq

import org.scalatest.Suite

class HiveClientSuites extends Suite with HiveClientVersions {
  override def nestedSuites: IndexedSeq[Suite] = {
    // Hive 0.12 does not provide the partition filtering API we call
    versions.filterNot(_ == "0.12").map(new HiveClientSuite(_))
  }
}

Review thread (on excluding Hive 0.12 from this suite):
- then we will lose test coverage because we don't test hive 0.12 for basic operations... How about we create a new
- No test coverage has been lost. I added "TODO" comments suggesting better names for these classes, for renaming later.
Review thread (on the foldLeft usage in the shim change above):
- I'd like it to be more java style:
- Is there something wrong with the way it is now?
- foldLeft may not be friendly to some Spark developers, but it's not a big deal.
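To make the style discussion concrete, here is a hedged sketch (my illustration, not code from the PR or from any reviewer) of what a more imperative, "Java style" equivalent of ExtractableValues.unapply might look like; it bails out as soon as one value cannot be rendered as a literal string:

object ExtractableValuesSketch {
  // Assumes the same valueToLiteralString partial function defined in the PR.
  def unapplyImperative(
      values: Set[Any],
      valueToLiteralString: PartialFunction[Any, String]): Option[Seq[String]] = {
    val rendered = Seq.newBuilder[String]
    val it = values.iterator
    var allConvertible = true
    while (allConvertible && it.hasNext) {
      val value = it.next()
      if (valueToLiteralString.isDefinedAt(value)) {
        rendered += valueToLiteralString(value)
      } else {
        allConvertible = false // one unconvertible value invalidates the whole IN list
      }
    }
    if (allConvertible) Some(rendered.result()) else None
  }
}

Behaviorally this matches the foldLeft version: any value outside the supported types makes the whole match fail, so the corresponding IN/INSET predicate is simply not pushed down to the metastore.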