Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-29870][SQL] Unify the logic of multi-units interval string to CalendarInterval #26491

Closed
wants to merge 15 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ singleTableSchema
: colTypeList EOF
;

singleInterval
: INTERVAL? multiUnitsInterval EOF
;

statement
: query #statementDefault
| ctes? dmlStatementNoWith #dmlStatement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
// IntervalConverter
private[this] def castToInterval(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => IntervalUtils.stringToInterval(s))
buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s))
}

// LongConverter
Expand Down Expand Up @@ -1216,7 +1216,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case StringType =>
val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
(c, evPrim, evNull) =>
code"""$evPrim = $util.stringToInterval($c);
code"""$evPrim = $util.safeStringToInterval($c);
if(${evPrim} == null) {
${evNull} = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class TimeWindow(
timeColumn: Expression,
Expand Down Expand Up @@ -103,7 +104,7 @@ object TimeWindow {
* precision.
*/
private def getIntervalInMicroSeconds(interval: String): Long = {
val cal = IntervalUtils.fromString(interval)
val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
if (cal.months != 0) {
throw new IllegalArgumentException(
s"Intervals greater than a month is not supported ($interval).")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
}

override def visitSingleInterval(ctx: SingleIntervalContext): CalendarInterval = {
withOrigin(ctx)(visitMultiUnitsInterval(ctx.multiUnitsInterval))
}

/* ********************************************************************************************
* Plan parsing
* ******************************************************************************************** */
Expand Down Expand Up @@ -1870,7 +1866,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
toLiteral(stringToTimestamp(_, zoneId), TimestampType)
case "INTERVAL" =>
val interval = try {
IntervalUtils.fromString(value)
IntervalUtils.stringToInterval(UTF8String.fromString(value))
} catch {
case e: IllegalArgumentException =>
val ex = new ParseException("Cannot parse the INTERVAL value: " + value, ctx)
Expand Down Expand Up @@ -2069,22 +2065,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = {
withOrigin(ctx) {
val units = ctx.intervalUnit().asScala.map { unit =>
val u = unit.getText.toLowerCase(Locale.ROOT)
// Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/...
if (u.endsWith("s")) u.substring(0, u.length - 1) else u
}.map(IntervalUtils.IntervalUnit.withName).toArray

val values = ctx.intervalValue().asScala.map { value =>
if (value.STRING() != null) {
string(value.STRING())
} else {
value.getText
}
}.toArray

val units = ctx.intervalUnit().asScala
val values = ctx.intervalValue().asScala
try {
IntervalUtils.fromUnitStrings(units, values)
assert(units.length == values.length)
val kvs = units.indices.map { i =>
val u = units(i).getText
val v = if (values(i).STRING() != null) {
string(values(i).STRING())
} else {
values(i).getText
}
UTF8String.fromString(" " + v + " " + u)
}
IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*))
} catch {
case i: IllegalArgumentException =>
val e = new ParseException(i.getMessage, ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,12 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.unsafe.types.CalendarInterval

/**
* Base SQL parsing infrastructure.
*/
abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Logging {

/**
* Creates [[CalendarInterval]] for a given SQL String. Throws [[ParseException]] if the SQL
* string is not a valid interval format.
*/
def parseInterval(sqlText: String): CalendarInterval = parse(sqlText) { parser =>
astBuilder.visitSingleInterval(parser.singleInterval())
}

/** Creates/Resolves DataType for a given SQL string. */
override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import java.util.concurrent.TimeUnit

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
Expand Down Expand Up @@ -101,34 +100,6 @@ object IntervalUtils {
Decimal(result, 18, 6)
}

/**
* Converts a string to [[CalendarInterval]] case-insensitively.
*
* @throws IllegalArgumentException if the input string is not in valid interval format.
*/
def fromString(str: String): CalendarInterval = {
if (str == null) throw new IllegalArgumentException("Interval string cannot be null")
try {
CatalystSqlParser.parseInterval(str)
} catch {
case e: ParseException =>
val ex = new IllegalArgumentException(s"Invalid interval string: $str\n" + e.message)
ex.setStackTrace(e.getStackTrace)
throw ex
}
}

/**
* A safe version of `fromString`. It returns null for invalid input string.
*/
def safeFromString(str: String): CalendarInterval = {
try {
fromString(str)
} catch {
case _: IllegalArgumentException => null
}
}

private def toLongWithRange(
fieldName: IntervalUnit,
s: String,
Expand Down Expand Up @@ -250,46 +221,6 @@ object IntervalUtils {
}
}

def fromUnitStrings(units: Array[IntervalUnit], values: Array[String]): CalendarInterval = {
assert(units.length == values.length)
var months: Int = 0
var days: Int = 0
var microseconds: Long = 0
var i = 0
while (i < units.length) {
try {
units(i) match {
case YEAR =>
months = Math.addExact(months, Math.multiplyExact(values(i).toInt, 12))
case MONTH =>
months = Math.addExact(months, values(i).toInt)
case WEEK =>
days = Math.addExact(days, Math.multiplyExact(values(i).toInt, 7))
case DAY =>
days = Math.addExact(days, values(i).toInt)
case HOUR =>
val hoursUs = Math.multiplyExact(values(i).toLong, MICROS_PER_HOUR)
microseconds = Math.addExact(microseconds, hoursUs)
case MINUTE =>
val minutesUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MINUTE)
microseconds = Math.addExact(microseconds, minutesUs)
case SECOND =>
microseconds = Math.addExact(microseconds, parseSecondNano(values(i)))
case MILLISECOND =>
val millisUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MILLIS)
microseconds = Math.addExact(microseconds, millisUs)
case MICROSECOND =>
microseconds = Math.addExact(microseconds, values(i).toLong)
}
} catch {
case e: Exception =>
throw new IllegalArgumentException(s"Error parsing interval string: ${e.getMessage}", e)
}
i += 1
}
new CalendarInterval(months, days, microseconds)
}

// Parses a string with nanoseconds, truncates the result and returns microseconds
private def parseNanos(nanosStr: String, isNegative: Boolean): Long = {
if (nanosStr != null) {
Expand All @@ -305,30 +236,6 @@ object IntervalUtils {
}
}

/**
* Parse second_nano string in ss.nnnnnnnnn format to microseconds
*/
private def parseSecondNano(secondNano: String): Long = {
def parseSeconds(secondsStr: String): Long = {
toLongWithRange(
SECOND,
secondsStr,
Long.MinValue / MICROS_PER_SECOND,
Long.MaxValue / MICROS_PER_SECOND) * MICROS_PER_SECOND
}

secondNano.split("\\.") match {
case Array(secondsStr) => parseSeconds(secondsStr)
case Array("", nanosStr) => parseNanos(nanosStr, false)
case Array(secondsStr, nanosStr) =>
val seconds = parseSeconds(secondsStr)
Math.addExact(seconds, parseNanos(nanosStr, seconds < 0))
case _ =>
throw new IllegalArgumentException(
"Interval string does not match second-nano format of ss.nnnnnnnnn")
}
}

/**
* Gets interval duration
*
Expand Down Expand Up @@ -452,20 +359,40 @@ object IntervalUtils {
private final val millisStr = unitToUtf8(MILLISECOND)
private final val microsStr = unitToUtf8(MICROSECOND)

/**
* A safe version of `stringToInterval`. It returns null for invalid input string.
*/
def safeStringToInterval(input: UTF8String): CalendarInterval = {
try {
stringToInterval(input)
} catch {
case _: IllegalArgumentException => null
}
}

/**
* Converts a string to [[CalendarInterval]] case-insensitively.
*
* @throws IllegalArgumentException if the input string is not in valid interval format.
*/
def stringToInterval(input: UTF8String): CalendarInterval = {
import ParseState._
var state = PREFIX
def throwIAE(msg: String, e: Exception = null) = {
throw new IllegalArgumentException(s"Error parsing '$input' to interval, $msg", e)
}

if (input == null) {
return null
throwIAE("interval string cannot be null")
}
// scalastyle:off caselocale .toLowerCase
val s = input.trim.toLowerCase
// scalastyle:on
val bytes = s.getBytes
if (bytes.isEmpty) {
return null
throwIAE("interval string cannot be empty")
}
var state = PREFIX

var i = 0
var currentValue: Long = 0
var isNegative: Boolean = false
Expand All @@ -482,13 +409,19 @@ object IntervalUtils {
}
}

def currentWord: String = {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan this method should be able to extract the error word

val strings = s.toString.split("\\s+")
val lenLeft = s.substring(i, s.numBytes()).toString.split("\\s+").length
strings(strings.length - lenLeft)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC lenLeft should be lenRight?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I guess left is a bit ambiguous here, right is better

}

while (i < bytes.length) {
val b = bytes(i)
state match {
case PREFIX =>
if (s.startsWith(intervalStr)) {
if (s.numBytes() == intervalStr.numBytes()) {
return null
throwIAE("interval string cannot be empty")
} else {
i += intervalStr.numBytes()
}
Expand Down Expand Up @@ -521,7 +454,7 @@ object IntervalUtils {
fractionScale = (NANOS_PER_SECOND / 10).toInt
i += 1
state = VALUE_FRACTIONAL_PART
case _ => return null
case _ => throwIAE( s"unrecognized number '$currentWord'")
}
case TRIM_BEFORE_VALUE => trimToNextState(b, VALUE)
case VALUE =>
Expand All @@ -530,13 +463,13 @@ object IntervalUtils {
try {
currentValue = Math.addExact(Math.multiplyExact(10, currentValue), (b - '0'))
} catch {
case _: ArithmeticException => return null
case e: ArithmeticException => throwIAE(e.getMessage, e)
}
case ' ' => state = TRIM_BEFORE_UNIT
case '.' =>
fractionScale = (NANOS_PER_SECOND / 10).toInt
state = VALUE_FRACTIONAL_PART
case _ => return null
case _ => throwIAE(s"invalid value '$currentWord'")
}
i += 1
case VALUE_FRACTIONAL_PART =>
Expand All @@ -547,14 +480,17 @@ object IntervalUtils {
case ' ' =>
fraction /= NANOS_PER_MICROS.toInt
state = TRIM_BEFORE_UNIT
case _ => return null
case _ if '0' <= b && b <= '9' =>
throwIAE(s"interval can only support nanosecond precision, '$currentWord' is out" +
s" of range")
case _ => throwIAE(s"invalid value '$currentWord' in fractional part")
}
i += 1
case TRIM_BEFORE_UNIT => trimToNextState(b, UNIT_BEGIN)
case UNIT_BEGIN =>
// Checks that only seconds can have the fractional part
if (b != 's' && fractionScale >= 0) {
return null
throwIAE(s"'$currentWord' cannot have fractional part")
}
if (isNegative) {
currentValue = -currentValue
Expand Down Expand Up @@ -598,26 +534,26 @@ object IntervalUtils {
} else if (s.matchAt(microsStr, i)) {
microseconds = Math.addExact(microseconds, currentValue)
i += microsStr.numBytes()
} else return null
case _ => return null
} else throwIAE(s"invalid unit '$currentWord'")
case _ => throwIAE(s"invalid unit '$currentWord'")
}
} catch {
case _: ArithmeticException => return null
case e: ArithmeticException => throwIAE(e.getMessage, e)
}
state = UNIT_SUFFIX
case UNIT_SUFFIX =>
b match {
case 's' => state = UNIT_END
case ' ' => state = TRIM_BEFORE_SIGN
case _ => return null
case _ => throwIAE(s"invalid unit '$currentWord'")
}
i += 1
case UNIT_END =>
b match {
case ' ' =>
i += 1
state = TRIM_BEFORE_SIGN
case _ => return null
case _ => throwIAE(s"invalid unit '$currentWord'")
}
}
}
Expand Down
Loading