Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-29870][SQL] Unify the logic of multi-units interval string to CalendarInterval #26491

Closed
wants to merge 15 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ singleTableSchema
: colTypeList EOF
;

singleInterval
: INTERVAL? multiUnitsInterval EOF
;

statement
: query #statementDefault
| ctes? dmlStatementNoWith #dmlStatement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
}

override def visitSingleInterval(ctx: SingleIntervalContext): CalendarInterval = {
withOrigin(ctx)(visitMultiUnitsInterval(ctx.multiUnitsInterval))
}

/* ********************************************************************************************
* Plan parsing
* ******************************************************************************************** */
Expand Down Expand Up @@ -2075,22 +2071,21 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = {
withOrigin(ctx) {
val units = ctx.intervalUnit().asScala.map { unit =>
val u = unit.getText.toLowerCase(Locale.ROOT)
// Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/...
if (u.endsWith("s")) u.substring(0, u.length - 1) else u
}.map(IntervalUtils.IntervalUnit.withName).toArray

val values = ctx.intervalValue().asScala.map { value =>
if (value.STRING() != null) {
string(value.STRING())
} else {
value.getText
}
}.toArray

val units = ctx.intervalUnit().asScala
val values = ctx.intervalValue().asScala
try {
IntervalUtils.fromUnitStrings(units, values)
assert(units.length == values.length)
val kvs = units.indices.map { i =>
val u = units(i).getText
val v = if (values(i).STRING() != null) {
string(values(i).STRING())
} else {
values(i).getText
}
v + " " + u
}
val str = kvs.mkString(" ")
IntervalUtils.fromString(str)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when we at here, the parsing is already done by antlr, so this may make things slower. But it's better to unify the code as the perf doesn't matter that much for interval literals.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, the performance improvement is actually for type constructor to parse interval string literals, which can't be seen via code change.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With a particular modified IntervalBenchmark test, which mocks the type constructor logic, which is directly different with an IntervalUtils.fromString call only.

  private def addCase(benchmark: Benchmark, cardinality: Long, units: Seq[String]): Unit = {
    Seq(true, false).foreach { withPrefix =>
      val expr = buildString(withPrefix, units).cast("interval")
      val note = if (withPrefix) "w/ interval" else "w/o interval"
      benchmark.addCase(s"${units.length + 1} units $note", numIters = 3) { _ =>
//        doBenchmark(cardinality, expr)
        (0L until cardinality).foreach(_ => IntervalUtils.fromString(units.mkString(" ")))
      }
    }
  }

we can see huge perfomance improment here. Any way, this is just used to parse typed literals, not a big deal acturally.

info]   Running case: 1 units w/ interval
[info]   Stopped after 3 iterations, 98544 ms
[info]   Running case: 1 units w/o interval
[info]   Stopped after 3 iterations, 78871 ms
[info]   Running case: 2 units w/ interval
[info]   Stopped after 3 iterations, 72469 ms
[info]   Running case: 2 units w/o interval
[info]   Stopped after 3 iterations, 78753 ms
[info]   Running case: 1 units w/ interval
[info]   Stopped after 3 iterations, 8926 ms
[info]   Running case: 1 units w/o interval
[info]   Stopped after 3 iterations, 8881 ms
[info]   Running case: 2 units w/ interval
[info]   Stopped after 3 iterations, 8773 ms
[info]   Running case: 2 units w/o interval
[info]   Stopped after 3 iterations, 8815 ms

} catch {
case i: IllegalArgumentException =>
val e = new ParseException(i.getMessage, ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,12 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.unsafe.types.CalendarInterval

/**
* Base SQL parsing infrastructure.
*/
abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Logging {

/**
* Creates [[CalendarInterval]] for a given SQL String. Throws [[ParseException]] if the SQL
* string is not a valid interval format.
*/
def parseInterval(sqlText: String): CalendarInterval = parse(sqlText) { parser =>
astBuilder.visitSingleInterval(parser.singleInterval())
}

/** Creates/Resolves DataType for a given SQL string. */
override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import java.util.concurrent.TimeUnit

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
Expand Down Expand Up @@ -108,25 +107,11 @@ object IntervalUtils {
*/
def fromString(str: String): CalendarInterval = {
if (str == null) throw new IllegalArgumentException("Interval string cannot be null")
try {
CatalystSqlParser.parseInterval(str)
} catch {
case e: ParseException =>
val ex = new IllegalArgumentException(s"Invalid interval string: $str\n" + e.message)
ex.setStackTrace(e.getStackTrace)
throw ex
}
}

/**
* A safe version of `fromString`. It returns null for invalid input string.
*/
def safeFromString(str: String): CalendarInterval = {
try {
fromString(str)
} catch {
case _: IllegalArgumentException => null
val interval = stringToInterval(UTF8String.fromString(str))
if (interval == null) {
throw new IllegalArgumentException(s"Invalid interval string: $str")
}
interval
}

private def toLongWithRange(
Expand Down Expand Up @@ -250,46 +235,6 @@ object IntervalUtils {
}
}

def fromUnitStrings(units: Array[IntervalUnit], values: Array[String]): CalendarInterval = {
assert(units.length == values.length)
var months: Int = 0
var days: Int = 0
var microseconds: Long = 0
var i = 0
while (i < units.length) {
try {
units(i) match {
case YEAR =>
months = Math.addExact(months, Math.multiplyExact(values(i).toInt, 12))
case MONTH =>
months = Math.addExact(months, values(i).toInt)
case WEEK =>
days = Math.addExact(days, Math.multiplyExact(values(i).toInt, 7))
case DAY =>
days = Math.addExact(days, values(i).toInt)
case HOUR =>
val hoursUs = Math.multiplyExact(values(i).toLong, MICROS_PER_HOUR)
microseconds = Math.addExact(microseconds, hoursUs)
case MINUTE =>
val minutesUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MINUTE)
microseconds = Math.addExact(microseconds, minutesUs)
case SECOND =>
microseconds = Math.addExact(microseconds, parseSecondNano(values(i)))
case MILLISECOND =>
val millisUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MILLIS)
microseconds = Math.addExact(microseconds, millisUs)
case MICROSECOND =>
microseconds = Math.addExact(microseconds, values(i).toLong)
}
} catch {
case e: Exception =>
throw new IllegalArgumentException(s"Error parsing interval string: ${e.getMessage}", e)
}
i += 1
}
new CalendarInterval(months, days, microseconds)
}

// Parses a string with nanoseconds, truncates the result and returns microseconds
private def parseNanos(nanosStr: String, isNegative: Boolean): Long = {
if (nanosStr != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ class ExpressionParserSuite extends AnalysisTest {
MICROSECOND)

def intervalLiteral(u: IntervalUnit, s: String): Literal = {
Literal(IntervalUtils.fromUnitStrings(Array(u), Array(s)))
Literal(IntervalUtils.fromString(s + " " + u.toString))
}

test("intervals") {
Expand Down