From 74512f115b2b5e28dc87a4f45914f3244d36bce5 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 22 Oct 2024 17:54:54 +0800 Subject: [PATCH] Support Eventstats in PPL Signed-off-by: Lantao Jin --- .../ppl/FlintSparkPPLEventstatsITSuite.scala | 379 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 1 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 2 +- .../sql/ast/AbstractNodeVisitor.java | 4 + .../org/opensearch/sql/ast/tree/Window.java | 45 +++ .../sql/ppl/CatalystQueryPlanVisitor.java | 27 ++ .../opensearch/sql/ppl/parser/AstBuilder.java | 26 +- .../sql/ppl/utils/WindowSpecTransformer.java | 19 + ...calPlanEventstatsTranslatorTestSuite.scala | 256 ++++++++++++ 9 files changed, 750 insertions(+), 9 deletions(-) create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEventstatsITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Window.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEventstatsTranslatorTestSuite.scala diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEventstatsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEventstatsITSuite.scala new file mode 100644 index 000000000..f1d287429 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEventstatsITSuite.scala @@ -0,0 +1,379 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Divide, Floor, Literal, Multiply, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Window} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLEventstatsITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test eventstats avg") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 36.25), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 36.25), + Row("Jake", 70, "California", "USA", 2023, 4, 36.25), + Row("Hello", 30, "New York", "USA", 2023, 4, 36.25)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), Nil, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 36.25, 70, 20, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 36.25, 70, 20, 4), + Row("Jake", 70, "California", "USA", 2023, 4, 36.25, 70, 20, 4), + Row("Hello", 30, "New York", "USA", 2023, 4, 36.25, 70, 20, 4)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + Nil, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg by country") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5), + Row("Jake", 70, "California", "USA", 2023, 4, 50), + Row("Hello", 30, "New York", "USA", 2023, 4, 50)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), partitionSpec, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by country") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jake", 70, "California", "USA", 2023, 4, 50, 70, 30, 2), + Row("Hello", 30, "New York", "USA", 2023, 4, 50, 70, 30, 2)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by span") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, 70, 1), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, 30, 1)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val partitionSpec = Seq(span) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by span and country") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span, country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5, 25, 20, 2), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, 70, 1), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, 30, 1)) + assertSameRows(expected, frame) + } + + test("test eventstats avg, max, min, count by span and state") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span, state + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25, 25, 1), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, 20, 20, 1), + Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, 70, 1), + Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, 30, 1)) + assertSameRows(expected, frame) + } + + test("test eventstats stddev by span with filter") { + val frame = sql(s""" + | source = $testTable | where country != 'USA' | eventstats stddev_samp(age) by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 3.5355339059327378), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 3.5355339059327378)) + assertSameRows(expected, frame) + } + + test("test eventstats stddev_pop by span with filter") { + val frame = sql(s""" + | source = $testTable | where state != 'California' | eventstats stddev_pop(age) by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 2.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 2.5), + Row("Hello", 30, "New York", "USA", 2023, 4, 0.0)) + assertSameRows(expected, frame) + } + + test("test eventstats percentile by span with filter") { + val frame = sql(s""" + | source = $testTable | where state != 'California' | eventstats percentile_approx(age, 60) by span(age, 10) as age_span + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 25), + Row("Hello", 30, "New York", "USA", 2023, 4, 30)) + assertSameRows(expected, frame) + } + + test("test multiple eventstats") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age by state, country | eventstats avg(avg_age) as avg_state_age by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25.0, 22.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20.0, 22.5), + Row("Jake", 70, "California", "USA", 2023, 4, 70.0, 50.0), + Row("Hello", 30, "New York", "USA", 2023, 4, 30.0, 50.0)) + assertSameRows(expected, frame) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val partitionSpec = Seq( + Alias(UnresolvedAttribute("state"), "state")(), + Alias(UnresolvedAttribute("country"), "country")()) + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val avgAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgAgeWindowExprAlias = Alias(avgAgeWindowExpression, "avg_age")() + val windowPlan1 = Window(Seq(avgAgeWindowExprAlias), partitionSpec, Nil, table) + + val countryPartitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgStateAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("avg_age")), isDistinct = false), + WindowSpecDefinition( + countryPartitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgStateAgeWindowExprAlias = Alias(avgStateAgeWindowExpression, "avg_state_age")() + val windowPlan2 = + Window(Seq(avgStateAgeWindowExprAlias), countryPartitionSpec, Nil, windowPlan1) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan2) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple eventstats with eval") { + val frame = sql(s""" + | source = $testTable | eventstats avg(age) as avg_age by state, country | eval new_avg_age = avg_age - 10 | eventstats avg(new_avg_age) as avg_state_age by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25.0, 15.0, 12.5), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20.0, 10.0, 12.5), + Row("Jake", 70, "California", "USA", 2023, 4, 70.0, 60.0, 40.0), + Row("Hello", 30, "New York", "USA", 2023, 4, 30.0, 20.0, 40.0)) + assertSameRows(expected, frame) + } + + test("test multiple eventstats with eval and filter") { + val frame = sql(s""" + | source = $testTable| eventstats avg(age) as avg_age by country, state, name | eval avg_age_divide_20 = avg_age - 20 | eventstats avg(avg_age_divide_20) + | as avg_state_age by country, state | where avg_state_age > 0 | eventstats count(avg_state_age) as count_country_age_greater_20 by country + | """.stripMargin) + val expected = Seq( + Row("John", 25, "Ontario", "Canada", 2023, 4, 25.0, 5.0, 5.0, 1), + Row("Jake", 70, "California", "USA", 2023, 4, 70.0, 50.0, 50.0, 2), + Row("Hello", 30, "New York", "USA", 2023, 4, 30.0, 10.0, 10.0, 2)) + assertSameRows(expected, frame) + } + + test("test eventstats distinct_count by span with filter") { + val exception = intercept[AnalysisException](sql(s""" + | source = $testTable | where state != 'California' | eventstats distinct_count(age) by span(age, 10) as age_span + | """.stripMargin)) + assert(exception.message.contains("Distinct window functions are not supported")) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 2b41530f0..d96ba5e18 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -18,6 +18,7 @@ WHERE: 'WHERE'; FIELDS: 'FIELDS'; RENAME: 'RENAME'; STATS: 'STATS'; +EVENTSTATS: 'EVENTSTATS'; DEDUP: 'DEDUP'; SORT: 'SORT'; EVAL: 'EVAL'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index e0672690d..6fd0c0c3a 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -115,7 +115,7 @@ renameCommand ; statsCommand - : STATS (PARTITIONS EQUAL partitions = integerLiteral)? (ALLNUM EQUAL allnum = booleanLiteral)? (DELIM EQUAL delim = stringLiteral)? statsAggTerm (COMMA statsAggTerm)* (statsByClause)? (DEDUP_SPLITVALUES EQUAL dedupsplit = booleanLiteral)? + : (STATS | EVENTSTATS) (PARTITIONS EQUAL partitions = integerLiteral)? (ALLNUM EQUAL allnum = booleanLiteral)? (DELIM EQUAL delim = stringLiteral)? statsAggTerm (COMMA statsAggTerm)* (statsByClause)? (DEDUP_SPLITVALUES EQUAL dedupsplit = booleanLiteral)? ; dedupCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index c361ded08..4a27130ea 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -307,4 +307,8 @@ public T visitScalarSubquery(ScalarSubquery node, C context) { public T visitExistsSubquery(ExistsSubquery node, C context) { return visitChildren(node, context); } + + public T visitWindow(Window node, C context) { + return visitChildren(node, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Window.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Window.java new file mode 100644 index 000000000..26cd08831 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Window.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class Window extends UnresolvedPlan { + private final List windowFunctionList; + private final List partExprList; + private final List sortExprList; + @Setter private UnresolvedExpression span; + private UnresolvedPlan child; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitWindow(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 902fc72e3..452f0ebb9 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -81,12 +81,14 @@ import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TopAggregation; import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.ParseStrategy; import org.opensearch.sql.ppl.utils.SortUtils; +import org.opensearch.sql.ppl.utils.WindowSpecTransformer; import scala.Option; import scala.Tuple2; import scala.collection.IterableLike; @@ -115,6 +117,7 @@ import static org.opensearch.sql.ppl.utils.RelationUtils.getTableIdentifier; import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; +import static scala.collection.JavaConverters.seqAsJavaList; /** * Utility class to traverse PPL logical plan and translate it into catalyst logical plan @@ -326,6 +329,30 @@ private static LogicalPlan extractedAggregation(CatalystPlanContext context) { return context.apply(p -> new Aggregate(groupingExpression, aggregateExpressions, p)); } + @Override + public LogicalPlan visitWindow(Window node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + List windowFunctionExpList = visitExpressionList(node.getWindowFunctionList(), context); + Seq windowFunctionExpressions = context.retainAllNamedParseExpressions(p -> p); + List partitionExpList = visitExpressionList(node.getPartExprList(), context); + UnresolvedExpression span = node.getSpan(); + if (!Objects.isNull(span)) { + visitExpression(span, context); + } + Seq partitionSpec = context.retainAllNamedParseExpressions(p -> p); + Seq orderSpec = seq(new ArrayList()); + Seq aggregatorFunctions = seq( + seqAsJavaList(windowFunctionExpressions).stream() + .map(w -> WindowSpecTransformer.buildAggregateWindowFunction(w, partitionSpec, orderSpec)) + .collect(Collectors.toList())); + return context.apply(p -> + new org.apache.spark.sql.catalyst.plans.logical.Window( + aggregatorFunctions, + partitionSpec, + orderSpec, + p)); + } + @Override public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) { expressionAnalyzer.visitAlias(node, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 1c0fe919f..ae9dd5a60 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -269,14 +269,24 @@ public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext .map(this::internalVisitExpression) .orElse(null); - Aggregation aggregation = - new Aggregation( - aggListBuilder.build(), - emptyList(), - groupList, - span, - ArgumentFactory.getArgumentList(ctx)); - return aggregation; + if (ctx.STATS() != null) { + Aggregation aggregation = + new Aggregation( + aggListBuilder.build(), + emptyList(), + groupList, + span, + ArgumentFactory.getArgumentList(ctx)); + return aggregation; + } else { + Window window = + new Window( + aggListBuilder.build(), + groupList, + emptyList()); + window.setSpan(span); + return window; + } } /** Dedup command. */ diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java index 0e6ba2a1d..e6dd12032 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java @@ -5,6 +5,7 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.expressions.Alias; import org.apache.spark.sql.catalyst.expressions.CurrentRow$; import org.apache.spark.sql.catalyst.expressions.Divide; import org.apache.spark.sql.catalyst.expressions.Expression; @@ -16,6 +17,7 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; import org.apache.spark.sql.catalyst.expressions.TimeWindow; +import org.apache.spark.sql.catalyst.expressions.UnboundedFollowing$; import org.apache.spark.sql.catalyst.expressions.UnboundedPreceding$; import org.apache.spark.sql.catalyst.expressions.WindowExpression; import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; @@ -79,4 +81,21 @@ static NamedExpression buildRowNumber(Seq partitionSpec, Seq())); } + + static NamedExpression buildAggregateWindowFunction(Expression aggregator, Seq partitionSpec, Seq orderSpec) { + Alias aggregatorAlias = (Alias) aggregator; + WindowExpression aggWindowExpression = new WindowExpression( + aggregatorAlias.child(), + new WindowSpecDefinition( + partitionSpec, + orderSpec, + new SpecifiedWindowFrame(RowFrame$.MODULE$, UnboundedPreceding$.MODULE$, UnboundedFollowing$.MODULE$))); + return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply( + aggWindowExpression, + aggregatorAlias.name(), + NamedExpression.newExprId(), + seq(new ArrayList()), + Option.empty(), + seq(new ArrayList())); + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEventstatsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEventstatsTranslatorTestSuite.scala new file mode 100644 index 000000000..53bb65950 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEventstatsTranslatorTestSuite.scala @@ -0,0 +1,256 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Divide, Floor, Literal, Multiply, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Window} + +class PPLLogicalPlanEventstatsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test eventstats avg") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source = table | eventstats avg(age)"), context) + + val table = UnresolvedRelation(Seq("table")) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), Nil, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test eventstats avg, max, min, count") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count"), + context) + + val table = UnresolvedRelation(Seq("table")) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + Nil, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + Nil, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eventstats avg by country") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = table | eventstats avg(age) by country"), + context) + + val table = UnresolvedRelation(Seq("table")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg(age)")() + val windowPlan = Window(Seq(avgWindowExprAlias), partitionSpec, Nil, table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by country") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by country"), + context) + + val table = UnresolvedRelation(Seq("table")) + val partitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(avgWindowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test eventstats avg, max, min, count by span") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age, max(age) as max_age, min(age) as min_age, count(age) as count by span(age, 10) as age_span"), + context) + + val table = UnresolvedRelation(Seq("table")) + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val partitionSpec = Seq(span) + val windowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgWindowExprAlias = Alias(windowExpression, "avg_age")() + + val maxWindowExpression = WindowExpression( + UnresolvedFunction("MAX", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val maxWindowExprAlias = Alias(maxWindowExpression, "max_age")() + + val minWindowExpression = WindowExpression( + UnresolvedFunction("MIN", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val minWindowExprAlias = Alias(minWindowExpression, "min_age")() + + val countWindowExpression = WindowExpression( + UnresolvedFunction("COUNT", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val countWindowExprAlias = Alias(countWindowExpression, "count")() + val windowPlan = Window( + Seq(avgWindowExprAlias, maxWindowExprAlias, minWindowExprAlias, countWindowExprAlias), + partitionSpec, + Nil, + table) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test multiple eventstats") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eventstats avg(age) as avg_age by state, country | eventstats avg(avg_age) as avg_state_age by country"), + context) + + val partitionSpec = Seq( + Alias(UnresolvedAttribute("state"), "state")(), + Alias(UnresolvedAttribute("country"), "country")()) + val table = UnresolvedRelation(Seq("table")) + val avgAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("age")), isDistinct = false), + WindowSpecDefinition( + partitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgAgeWindowExprAlias = Alias(avgAgeWindowExpression, "avg_age")() + val windowPlan1 = Window(Seq(avgAgeWindowExprAlias), partitionSpec, Nil, table) + + val countryPartitionSpec = Seq(Alias(UnresolvedAttribute("country"), "country")()) + val avgStateAgeWindowExpression = WindowExpression( + UnresolvedFunction("AVG", Seq(UnresolvedAttribute("avg_age")), isDistinct = false), + WindowSpecDefinition( + countryPartitionSpec, + Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing))) + val avgStateAgeWindowExprAlias = Alias(avgStateAgeWindowExpression, "avg_state_age")() + val windowPlan2 = + Window(Seq(avgStateAgeWindowExprAlias), countryPartitionSpec, Nil, windowPlan1) + val expectedPlan = Project(Seq(UnresolvedStar(None)), windowPlan2) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +}