From d8c04cf2fb7599c993948df10f4746b70f8c52b9 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Sat, 5 Oct 2024 07:38:42 +0900
Subject: [PATCH] [SPARK-49836][SQL][SS] Fix possibly broken query when window
 is provided to window/session_window fn

### What changes were proposed in this pull request?

This PR fixes a correctness issue in which operators are lost during analysis. It happens when a window is provided to the window()/session_window() function.

The rules `TimeWindowing` and `SessionWindowing` are responsible for resolving the time window functions. When the window function has `window` as its parameter (the time column), in other words, when a time window is built from another time window, the rule wraps the window with the WindowTime function so that the rule ResolveWindowTime will resolve it further. (TimeWindowing/SessionWindowing will then resolve the result of ResolveWindowTime again.)

The issue is that the rules use `return` for this branch, intended as an "early return" since the other branch is much longer. Unfortunately, this does not work as intended: the intent is merely to exit the current local scope (essentially, the end of the enclosing curly brace), but it breaks the flow of execution on the "outer" side. (I haven't debugged further, but it is clear that it does not work as intended.)

Quoting the Scala doc:

> Nonlocal returns are implemented by throwing and catching scala.runtime.NonLocalReturnException-s.

It is not entirely clear where the NonLocalReturnException is caught in the call stack; it may exit execution at a much broader scope (context) than expected. The feature is deprecated as of Scala 3.2 and will likely be removed in the future:
https://dotty.epfl.ch/docs/reference/dropped-features/nonlocal-returns.html

Interestingly, this does not break every query with chained time window aggregations. Spark already has several tests using the DataFrame API, and they have not failed. The reproducer in the community report uses a SQL statement, where each aggregation is treated as a subquery.

This PR fixes the rules to NOT use an early return and to use a large if-else instead.

### Why are the changes needed?

Described above.

### Does this PR introduce _any_ user-facing change?

Yes, this fixes possible query breakage. The impacted workloads may not be very large, since chained time window aggregation is an advanced usage, and not every such query is broken.

### How was this patch tested?

New UTs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48309 from HeartSaVioR/SPARK-49836.
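For illustration, here is a minimal, self-contained Scala 2 sketch of the failure mode. It is not part of the patch; `transformAll` and `resolve` are hypothetical stand-ins for the rule's tree traversal:

```scala
object NonLocalReturnDemo {
  // Stand-in for a tree transformation such as transformExpressions:
  // applies `f` to every node and collects the results.
  def transformAll(nodes: Seq[String])(f: String => String): Seq[String] =
    nodes.map(f)

  def resolve(nodes: Seq[String]): Seq[String] = {
    transformAll(nodes) { n =>
      if (n.startsWith("window")) {
        // Intended as "return early from this branch"; Scala 2 compiles it
        // into throwing a control-flow exception (NonLocalReturnControl)
        // that unwinds past transformAll and exits resolve() itself.
        return Seq(n.toUpperCase)
      }
      n
    }
  }

  def main(args: Array[String]): Unit = {
    // Prints List(WINDOW) instead of List(WINDOW, filter, project):
    // the remaining nodes are silently dropped, analogous to the
    // operators lost from the analyzed plan.
    println(resolve(Seq("window", "filter", "project")))
  }
}
```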
Lead-authored-by: Jungtaek Lim
Co-authored-by: Andrzej Zera
Signed-off-by: Jungtaek Lim
---
 .../analysis/ResolveTimeWindows.scala        | 255 +++++++++---------
 .../sql/DataFrameSessionWindowingSuite.scala |  51 ++++
 .../sql/DataFrameTimeWindowingSuite.scala    |  53 ++++
 3 files changed, 232 insertions(+), 127 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
index e506a3629db17..a8680d0a01816 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
@@ -87,85 +87,86 @@ object TimeWindowing extends Rule[LogicalPlan] {
         val window = windowExpressions.head
 
+        // time window is provided as time column of window function, replace it with WindowTime
         if (StructType.acceptsType(window.timeColumn.dataType)) {
-          return p.transformExpressions {
+          p.transformExpressions {
             case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn))
           }
-        }
-
-        val metadata = window.timeColumn match {
-          case a: Attribute => a.metadata
-          case _ => Metadata.empty
-        }
-
-        val newMetadata = new MetadataBuilder()
-          .withMetadata(metadata)
-          .putBoolean(TimeWindow.marker, true)
-          .build()
+        } else {
+          val metadata = window.timeColumn match {
+            case a: Attribute => a.metadata
+            case _ => Metadata.empty
+          }
 
-        def getWindow(i: Int, dataType: DataType): Expression = {
-          val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
-          val remainder = (timestamp - window.startTime) % window.slideDuration
-          val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
-            remainder + window.slideDuration)), Some(remainder))
-          val windowStart = lastStart - i * window.slideDuration
-          val windowEnd = windowStart + window.windowDuration
+          val newMetadata = new MetadataBuilder()
+            .withMetadata(metadata)
+            .putBoolean(TimeWindow.marker, true)
+            .build()
 
-          // We make sure value fields are nullable since the dataType of TimeWindow defines them
-          // as nullable.
-          CreateNamedStruct(
-            Literal(WINDOW_START) ::
-              PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
-              Literal(WINDOW_END) ::
-              PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
-              Nil)
-        }
+          def getWindow(i: Int, dataType: DataType): Expression = {
+            val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
+            val remainder = (timestamp - window.startTime) % window.slideDuration
+            val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
+              remainder + window.slideDuration)), Some(remainder))
+            val windowStart = lastStart - i * window.slideDuration
+            val windowEnd = windowStart + window.windowDuration
+
+            // We make sure value fields are nullable since the dataType of TimeWindow defines them
+            // as nullable.
+            CreateNamedStruct(
+              Literal(WINDOW_START) ::
+                PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
+                Literal(WINDOW_END) ::
+                PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
+                Nil)
+          }
 
-        val windowAttr = AttributeReference(
-          WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
+          val windowAttr = AttributeReference(
+            WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
 
-        if (window.windowDuration == window.slideDuration) {
-          val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
-            exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
+          if (window.windowDuration == window.slideDuration) {
+            val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
+              exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
 
-          val replacedPlan = p transformExpressions {
-            case t: TimeWindow => windowAttr
-          }
+            val replacedPlan = p transformExpressions {
+              case t: TimeWindow => windowAttr
+            }
 
-          // For backwards compatibility we add a filter to filter out nulls
-          val filterExpr = IsNotNull(window.timeColumn)
+            // For backwards compatibility we add a filter to filter out nulls
+            val filterExpr = IsNotNull(window.timeColumn)
 
-          replacedPlan.withNewChildren(
-            Project(windowStruct +: child.output,
-              Filter(filterExpr, child)) :: Nil)
-        } else {
-          val overlappingWindows =
-            math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
-          val windows =
-            Seq.tabulate(overlappingWindows)(i =>
-              getWindow(i, window.timeColumn.dataType))
-
-          val projections = windows.map(_ +: child.output)
-
-          // When the condition windowDuration % slideDuration = 0 is fulfilled,
-          // the estimation of the number of windows becomes exact one,
-          // which means all produced windows are valid.
-          val filterExpr =
-            if (window.windowDuration % window.slideDuration == 0) {
-              IsNotNull(window.timeColumn)
-            } else {
-              window.timeColumn >= windowAttr.getField(WINDOW_START) &&
-              window.timeColumn < windowAttr.getField(WINDOW_END)
-            }
+            replacedPlan.withNewChildren(
+              Project(windowStruct +: child.output,
+                Filter(filterExpr, child)) :: Nil)
+          } else {
+            val overlappingWindows =
+              math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
+            val windows =
+              Seq.tabulate(overlappingWindows)(i =>
+                getWindow(i, window.timeColumn.dataType))
+
+            val projections = windows.map(_ +: child.output)
+
+            // When the condition windowDuration % slideDuration = 0 is fulfilled,
+            // the estimation of the number of windows becomes exact one,
+            // which means all produced windows are valid.
+            val filterExpr =
+              if (window.windowDuration % window.slideDuration == 0) {
+                IsNotNull(window.timeColumn)
+              } else {
+                window.timeColumn >= windowAttr.getField(WINDOW_START) &&
+                  window.timeColumn < windowAttr.getField(WINDOW_END)
+              }
+
+            val substitutedPlan = Filter(filterExpr,
+              Expand(projections, windowAttr +: child.output, child))
+
+            val renamedPlan = p transformExpressions {
+              case t: TimeWindow => windowAttr
+            }
+
+            renamedPlan.withNewChildren(substitutedPlan :: Nil)
+          }
-
-          val substitutedPlan = Filter(filterExpr,
-            Expand(projections, windowAttr +: child.output, child))
-
-          val renamedPlan = p transformExpressions {
-            case t: TimeWindow => windowAttr
-          }
-
-          renamedPlan.withNewChildren(substitutedPlan :: Nil)
         }
       } else if (numWindowExpr > 1) {
         throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
@@ -210,74 +211,74 @@ object SessionWindowing extends Rule[LogicalPlan] {
         val session = sessionExpressions.head
 
         if (StructType.acceptsType(session.timeColumn.dataType)) {
-          return p transformExpressions {
+          p transformExpressions {
             case t: SessionWindow => t.copy(timeColumn = WindowTime(session.timeColumn))
           }
-        }
+        } else {
+          val metadata = session.timeColumn match {
+            case a: Attribute => a.metadata
+            case _ => Metadata.empty
+          }
 
-        val metadata = session.timeColumn match {
-          case a: Attribute => a.metadata
-          case _ => Metadata.empty
-        }
+          val newMetadata = new MetadataBuilder()
+            .withMetadata(metadata)
+            .putBoolean(SessionWindow.marker, true)
+            .build()
 
-        val newMetadata = new MetadataBuilder()
-          .withMetadata(metadata)
-          .putBoolean(SessionWindow.marker, true)
-          .build()
-
-        val sessionAttr = AttributeReference(
-          SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
-
-        val sessionStart =
-          PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
-        val gapDuration = session.gapDuration match {
-          case expr if expr.dataType == CalendarIntervalType =>
-            expr
-          case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
-            Cast(expr, CalendarIntervalType)
-          case other =>
-            throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
-        }
-        val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
-          session.timeColumn.dataType, LongType)
-
-        // We make sure value fields are nullable since the dataType of SessionWindow defines them
-        // as nullable.
-        val literalSessionStruct = CreateNamedStruct(
-          Literal(SESSION_START) ::
-            PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
-              .castNullable() ::
-            Literal(SESSION_END) ::
-            PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
-              .castNullable() ::
-            Nil)
-
-        val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
-          exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
+          val sessionAttr = AttributeReference(
+            SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
+
+          val sessionStart =
+            PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
+          val gapDuration = session.gapDuration match {
+            case expr if expr.dataType == CalendarIntervalType =>
+              expr
+            case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
+              Cast(expr, CalendarIntervalType)
+            case other =>
+              throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
+          }
+          val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
+            session.timeColumn.dataType, LongType)
 
-        val replacedPlan = p transformExpressions {
-          case s: SessionWindow => sessionAttr
-        }
+          // We make sure value fields are nullable since the dataType of SessionWindow defines them
+          // as nullable.
+          val literalSessionStruct = CreateNamedStruct(
+            Literal(SESSION_START) ::
+              PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
+                .castNullable() ::
+              Literal(SESSION_END) ::
+              PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
+                .castNullable() ::
+              Nil)
 
-        val filterByTimeRange = if (gapDuration.foldable) {
-          val interval = gapDuration.eval().asInstanceOf[CalendarInterval]
-          interval == null || interval.months + interval.days + interval.microseconds <= 0
-        } else {
-          true
-        }
+          val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
+            exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
 
-        // As same as tumbling window, we add a filter to filter out nulls.
-        // And we also filter out events with negative or zero or invalid gap duration.
-        val filterExpr = if (filterByTimeRange) {
-          IsNotNull(session.timeColumn) &&
-            (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
-        } else {
-          IsNotNull(session.timeColumn)
-        }
+          val replacedPlan = p transformExpressions {
+            case s: SessionWindow => sessionAttr
+          }
 
-        replacedPlan.withNewChildren(
-          Filter(filterExpr,
-            Project(sessionStruct +: child.output, child)) :: Nil)
+          val filterByTimeRange = if (gapDuration.foldable) {
+            val interval = gapDuration.eval().asInstanceOf[CalendarInterval]
+            interval == null || interval.months + interval.days + interval.microseconds <= 0
+          } else {
+            true
+          }
+
+          // As same as tumbling window, we add a filter to filter out nulls.
+          // And we also filter out events with negative or zero or invalid gap duration.
+          val filterExpr = if (filterByTimeRange) {
+            IsNotNull(session.timeColumn) &&
+              (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
+          } else {
+            IsNotNull(session.timeColumn)
+          }
+
+          replacedPlan.withNewChildren(
+            Filter(filterExpr,
+              Project(sessionStruct +: child.output, child)) :: Nil)
+        }
       } else if (numWindowExpr > 1) {
         throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
       } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
index 1ac1dda374fa7..6c1ca94a03079 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
@@ -547,4 +547,55 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
       }
     }
   }
+
+  test("SPARK-49836 using window fn with window as parameter should preserve parent operator") {
+    withTempView("clicks") {
+      val df = Seq(
+        // small window: [00:00, 01:00), user1, 2
+        ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"),
+        // small window: [01:00, 02:00), user2, 2
+        ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"),
+        // small window: [03:00, 04:00), user1, 1
+        ("2024-09-30 00:03:30", "user1"),
+        // small window: [11:00, 12:00), user1, 3
+        ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"),
+        ("2024-09-30 00:11:45", "user1")
+      ).toDF("eventTime", "userId")
+
+      // session window: (01:00, 09:00), user1, 3 / (02:00, 07:00), user2, 2 /
+      // (12:00, 12:05), user1, 3
+
+      df.createOrReplaceTempView("clicks")
+
+      val aggregatedData = spark.sql(
+        """
+          |SELECT
+          |  userId,
+          |  avg(cpu_large.numClicks) AS clicksPerSession
+          |FROM
+          |(
+          |  SELECT
+          |    session_window(small_window, '5 minutes') AS session,
+          |    userId,
+          |    sum(numClicks) AS numClicks
+          |  FROM
+          |  (
+          |    SELECT
+          |      window(eventTime, '1 minute') AS small_window,
+          |      userId,
+          |      count(*) AS numClicks
+          |    FROM clicks
+          |    GROUP BY window, userId
+          |  ) cpu_small
+          |  GROUP BY session_window, userId
+          |) cpu_large
+          |GROUP BY userId
+          |""".stripMargin)
+
+      checkAnswer(
+        aggregatedData,
+        Seq(Row("user1", 3), Row("user2", 2))
+      )
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index 6ee173bc6af67..c52d428cd5dd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import java.sql.Timestamp
 import java.time.LocalDateTime
 
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -714,4 +715,56 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession {
       )
     }
   }
+
+  test("SPARK-49836 using window fn with window as parameter should preserve parent operator") {
+    withTempView("clicks") {
+      val df = Seq(
+        // small window: [00:00, 01:00), user1, 2
+        ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"),
+        // small window: [01:00, 02:00), user2, 2
+        ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"),
+        // small window: [07:00, 08:00), user1, 1
+        ("2024-09-30 00:07:00", "user1"),
+        // small window: [11:00, 12:00), user1, 3
+        ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"),
+        ("2024-09-30 00:11:45", "user1")
+      ).toDF("eventTime", "userId")
+
+      // large window: [00:00, 10:00), user1, 3, [00:00, 10:00), user2, 2, [10:00, 20:00), user1, 3
+
+      df.createOrReplaceTempView("clicks")
+
+      val aggregatedData = spark.sql(
+        """
+          |SELECT
+          |  cpu_large.large_window.end AS timestamp,
+          |  avg(cpu_large.numClicks) AS avgClicksPerUser
+          |FROM
+          |(
+          |  SELECT
+          |    window(small_window, '10 minutes') AS large_window,
+          |    userId,
+          |    sum(numClicks) AS numClicks
+          |  FROM
+          |  (
+          |    SELECT
+          |      window(eventTime, '1 minute') AS small_window,
+          |      userId,
+          |      count(*) AS numClicks
+          |    FROM clicks
+          |    GROUP BY window, userId
+          |  ) cpu_small
+          |  GROUP BY window, userId
+          |) cpu_large
+          |GROUP BY timestamp
+          |""".stripMargin)
+
+      checkAnswer(
+        aggregatedData,
+        Seq(
+          Row(Timestamp.valueOf("2024-09-30 00:10:00"), 2.5),
+          Row(Timestamp.valueOf("2024-09-30 00:20:00"), 3))
+      )
+    }
+  }
 }