Skip to content

Commit

Permalink
[SPARK-29870][SQL] Unify the logic of multi-units interval string to …
Browse files Browse the repository at this point in the history
…CalendarInterval

### What changes were proposed in this pull request?

We now have two different implementation for multi-units interval strings to CalendarInterval type values.

One is used to covert interval string literals to CalendarInterval. This approach will re-delegate the interval string to spark parser which handles the string as a `singleInterval` -> `multiUnitsInterval` -> eventually call `IntervalUtils.fromUnitStrings`

The other is used in `Cast`, which eventually calls `IntervalUtils.stringToInterval`. This approach is ~10 times faster than the other.

We should unify these two for better performance and simple logic. this pr uses the 2nd approach.

### Why are the changes needed?

We should unify these two for better performance and simple logic.

### Does this PR introduce any user-facing change?

no

### How was this patch tested?

we shall not fail on existing uts

Closes #26491 from yaooqinn/SPARK-29870.

Authored-by: Kent Yao <yaooqinn@hotmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
yaooqinn authored and cloud-fan committed Nov 18, 2019
1 parent 5cebe58 commit 50f6d93
Show file tree
Hide file tree
Showing 19 changed files with 169 additions and 222 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ singleTableSchema
: colTypeList EOF
;

singleInterval
: INTERVAL? multiUnitsInterval EOF
;

statement
: query #statementDefault
| ctes? dmlStatementNoWith #dmlStatement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
// IntervalConverter
private[this] def castToInterval(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => IntervalUtils.stringToInterval(s))
buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s))
}

// LongConverter
Expand Down Expand Up @@ -1234,7 +1234,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case StringType =>
val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
(c, evPrim, evNull) =>
code"""$evPrim = $util.stringToInterval($c);
code"""$evPrim = $util.safeStringToInterval($c);
if(${evPrim} == null) {
${evNull} = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class TimeWindow(
timeColumn: Expression,
Expand Down Expand Up @@ -103,7 +104,7 @@ object TimeWindow {
* precision.
*/
private def getIntervalInMicroSeconds(interval: String): Long = {
val cal = IntervalUtils.fromString(interval)
val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
if (cal.months != 0) {
throw new IllegalArgumentException(
s"Intervals greater than a month is not supported ($interval).")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
}

override def visitSingleInterval(ctx: SingleIntervalContext): CalendarInterval = {
withOrigin(ctx)(visitMultiUnitsInterval(ctx.multiUnitsInterval))
}

/* ********************************************************************************************
* Plan parsing
* ******************************************************************************************** */
Expand Down Expand Up @@ -1870,7 +1866,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
toLiteral(stringToTimestamp(_, zoneId), TimestampType)
case "INTERVAL" =>
val interval = try {
IntervalUtils.fromString(value)
IntervalUtils.stringToInterval(UTF8String.fromString(value))
} catch {
case e: IllegalArgumentException =>
val ex = new ParseException("Cannot parse the INTERVAL value: " + value, ctx)
Expand Down Expand Up @@ -2069,22 +2065,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = {
withOrigin(ctx) {
val units = ctx.intervalUnit().asScala.map { unit =>
val u = unit.getText.toLowerCase(Locale.ROOT)
// Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/...
if (u.endsWith("s")) u.substring(0, u.length - 1) else u
}.map(IntervalUtils.IntervalUnit.withName).toArray

val values = ctx.intervalValue().asScala.map { value =>
if (value.STRING() != null) {
string(value.STRING())
} else {
value.getText
}
}.toArray

val units = ctx.intervalUnit().asScala
val values = ctx.intervalValue().asScala
try {
IntervalUtils.fromUnitStrings(units, values)
assert(units.length == values.length)
val kvs = units.indices.map { i =>
val u = units(i).getText
val v = if (values(i).STRING() != null) {
string(values(i).STRING())
} else {
values(i).getText
}
UTF8String.fromString(" " + v + " " + u)
}
IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*))
} catch {
case i: IllegalArgumentException =>
val e = new ParseException(i.getMessage, ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,12 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.unsafe.types.CalendarInterval

/**
* Base SQL parsing infrastructure.
*/
abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Logging {

/**
* Creates [[CalendarInterval]] for a given SQL String. Throws [[ParseException]] if the SQL
* string is not a valid interval format.
*/
def parseInterval(sqlText: String): CalendarInterval = parse(sqlText) { parser =>
astBuilder.visitSingleInterval(parser.singleInterval())
}

/** Creates/Resolves DataType for a given SQL string. */
override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import java.util.concurrent.TimeUnit

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
Expand Down Expand Up @@ -102,34 +101,6 @@ object IntervalUtils {
Decimal(result, 18, 6)
}

/**
* Converts a string to [[CalendarInterval]] case-insensitively.
*
* @throws IllegalArgumentException if the input string is not in valid interval format.
*/
def fromString(str: String): CalendarInterval = {
if (str == null) throw new IllegalArgumentException("Interval string cannot be null")
try {
CatalystSqlParser.parseInterval(str)
} catch {
case e: ParseException =>
val ex = new IllegalArgumentException(s"Invalid interval string: $str\n" + e.message)
ex.setStackTrace(e.getStackTrace)
throw ex
}
}

/**
* A safe version of `fromString`. It returns null for invalid input string.
*/
def safeFromString(str: String): CalendarInterval = {
try {
fromString(str)
} catch {
case _: IllegalArgumentException => null
}
}

private def toLongWithRange(
fieldName: IntervalUnit,
s: String,
Expand Down Expand Up @@ -251,46 +222,6 @@ object IntervalUtils {
}
}

def fromUnitStrings(units: Array[IntervalUnit], values: Array[String]): CalendarInterval = {
assert(units.length == values.length)
var months: Int = 0
var days: Int = 0
var microseconds: Long = 0
var i = 0
while (i < units.length) {
try {
units(i) match {
case YEAR =>
months = Math.addExact(months, Math.multiplyExact(values(i).toInt, 12))
case MONTH =>
months = Math.addExact(months, values(i).toInt)
case WEEK =>
days = Math.addExact(days, Math.multiplyExact(values(i).toInt, 7))
case DAY =>
days = Math.addExact(days, values(i).toInt)
case HOUR =>
val hoursUs = Math.multiplyExact(values(i).toLong, MICROS_PER_HOUR)
microseconds = Math.addExact(microseconds, hoursUs)
case MINUTE =>
val minutesUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MINUTE)
microseconds = Math.addExact(microseconds, minutesUs)
case SECOND =>
microseconds = Math.addExact(microseconds, parseSecondNano(values(i)))
case MILLISECOND =>
val millisUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MILLIS)
microseconds = Math.addExact(microseconds, millisUs)
case MICROSECOND =>
microseconds = Math.addExact(microseconds, values(i).toLong)
}
} catch {
case e: Exception =>
throw new IllegalArgumentException(s"Error parsing interval string: ${e.getMessage}", e)
}
i += 1
}
new CalendarInterval(months, days, microseconds)
}

// Parses a string with nanoseconds, truncates the result and returns microseconds
private def parseNanos(nanosStr: String, isNegative: Boolean): Long = {
if (nanosStr != null) {
Expand All @@ -306,30 +237,6 @@ object IntervalUtils {
}
}

/**
* Parse second_nano string in ss.nnnnnnnnn format to microseconds
*/
private def parseSecondNano(secondNano: String): Long = {
def parseSeconds(secondsStr: String): Long = {
toLongWithRange(
SECOND,
secondsStr,
Long.MinValue / MICROS_PER_SECOND,
Long.MaxValue / MICROS_PER_SECOND) * MICROS_PER_SECOND
}

secondNano.split("\\.") match {
case Array(secondsStr) => parseSeconds(secondsStr)
case Array("", nanosStr) => parseNanos(nanosStr, false)
case Array(secondsStr, nanosStr) =>
val seconds = parseSeconds(secondsStr)
Math.addExact(seconds, parseNanos(nanosStr, seconds < 0))
case _ =>
throw new IllegalArgumentException(
"Interval string does not match second-nano format of ss.nnnnnnnnn")
}
}

/**
* Gets interval duration
*
Expand Down Expand Up @@ -558,18 +465,37 @@ object IntervalUtils {
private final val millisStr = unitToUtf8(MILLISECOND)
private final val microsStr = unitToUtf8(MICROSECOND)

/**
* A safe version of `stringToInterval`. It returns null for invalid input string.
*/
def safeStringToInterval(input: UTF8String): CalendarInterval = {
try {
stringToInterval(input)
} catch {
case _: IllegalArgumentException => null
}
}

/**
* Converts a string to [[CalendarInterval]] case-insensitively.
*
* @throws IllegalArgumentException if the input string is not in valid interval format.
*/
def stringToInterval(input: UTF8String): CalendarInterval = {
import ParseState._
def throwIAE(msg: String, e: Exception = null) = {
throw new IllegalArgumentException(s"Error parsing '$input' to interval, $msg", e)
}

if (input == null) {
return null
throwIAE("interval string cannot be null")
}
// scalastyle:off caselocale .toLowerCase
val s = input.trim.toLowerCase
// scalastyle:on
val bytes = s.getBytes
if (bytes.isEmpty) {
return null
throwIAE("interval string cannot be empty")
}
var state = PREFIX
var i = 0
Expand All @@ -588,13 +514,19 @@ object IntervalUtils {
}
}

def currentWord: UTF8String = {
val strings = s.split(UTF8String.blankString(1), -1)
val lenRight = s.substring(i, s.numBytes()).split(UTF8String.blankString(1), -1).length
strings(strings.length - lenRight)
}

while (i < bytes.length) {
val b = bytes(i)
state match {
case PREFIX =>
if (s.startsWith(intervalStr)) {
if (s.numBytes() == intervalStr.numBytes()) {
return null
throwIAE("interval string cannot be empty")
} else {
i += intervalStr.numBytes()
}
Expand Down Expand Up @@ -627,7 +559,7 @@ object IntervalUtils {
fractionScale = (NANOS_PER_SECOND / 10).toInt
i += 1
state = VALUE_FRACTIONAL_PART
case _ => return null
case _ => throwIAE( s"unrecognized number '$currentWord'")
}
case TRIM_BEFORE_VALUE => trimToNextState(b, VALUE)
case VALUE =>
Expand All @@ -636,13 +568,13 @@ object IntervalUtils {
try {
currentValue = Math.addExact(Math.multiplyExact(10, currentValue), (b - '0'))
} catch {
case _: ArithmeticException => return null
case e: ArithmeticException => throwIAE(e.getMessage, e)
}
case ' ' => state = TRIM_BEFORE_UNIT
case '.' =>
fractionScale = (NANOS_PER_SECOND / 10).toInt
state = VALUE_FRACTIONAL_PART
case _ => return null
case _ => throwIAE(s"invalid value '$currentWord'")
}
i += 1
case VALUE_FRACTIONAL_PART =>
Expand All @@ -653,14 +585,17 @@ object IntervalUtils {
case ' ' =>
fraction /= NANOS_PER_MICROS.toInt
state = TRIM_BEFORE_UNIT
case _ => return null
case _ if '0' <= b && b <= '9' =>
throwIAE(s"interval can only support nanosecond precision, '$currentWord' is out" +
s" of range")
case _ => throwIAE(s"invalid value '$currentWord'")
}
i += 1
case TRIM_BEFORE_UNIT => trimToNextState(b, UNIT_BEGIN)
case UNIT_BEGIN =>
// Checks that only seconds can have the fractional part
if (b != 's' && fractionScale >= 0) {
return null
throwIAE(s"'$currentWord' cannot have fractional part")
}
if (isNegative) {
currentValue = -currentValue
Expand Down Expand Up @@ -704,26 +639,26 @@ object IntervalUtils {
} else if (s.matchAt(microsStr, i)) {
microseconds = Math.addExact(microseconds, currentValue)
i += microsStr.numBytes()
} else return null
case _ => return null
} else throwIAE(s"invalid unit '$currentWord'")
case _ => throwIAE(s"invalid unit '$currentWord'")
}
} catch {
case _: ArithmeticException => return null
case e: ArithmeticException => throwIAE(e.getMessage, e)
}
state = UNIT_SUFFIX
case UNIT_SUFFIX =>
b match {
case 's' => state = UNIT_END
case ' ' => state = TRIM_BEFORE_SIGN
case _ => return null
case _ => throwIAE(s"invalid unit '$currentWord'")
}
i += 1
case UNIT_END =>
b match {
case ' ' =>
i += 1
state = TRIM_BEFORE_SIGN
case _ => return null
case _ => throwIAE(s"invalid unit '$currentWord'")
}
}
}
Expand Down
Loading

0 comments on commit 50f6d93

Please sign in to comment.