From 714bfcf8ce963c6bb489de6fada59f2aa27159a0 Mon Sep 17 00:00:00 2001 From: Kacper Trochimiak Date: Fri, 27 Sep 2024 07:16:56 +0200 Subject: [PATCH] WIP trendline command Signed-off-by: Kacper Trochimiak --- .../ppl/FlintSparkPPLTrendlineITSuite.scala | 62 ++++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 5 ++ .../src/main/antlr4/OpenSearchPPLParser.g4 | 14 ++++ .../sql/ast/AbstractNodeVisitor.java | 8 +++ .../opensearch/sql/ast/tree/Trendline.java | 65 +++++++++++++++++++ .../sql/ppl/CatalystQueryPlanVisitor.java | 53 ++++++++++++++- .../opensearch/sql/ppl/parser/AstBuilder.java | 9 +++ .../sql/ppl/parser/AstExpressionBuilder.java | 10 +++ 8 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala new file mode 100644 index 000000000..ace5d75f8 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.{QueryTest, Row} + +class FlintSparkPPLTrendlineITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedStateCountryTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("trendline sma") { + val frame = sql(s""" + | source = $testTable | trendline sma(2, age) as first_age_sma sma(3, age) as second_age_sma | fields name, first_age_sma, second_age_sma + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array() + + // Convert actual results to a Set for quick lookup + val resultsSet: Set[Row] = results.toSet + // Check that each expected row is present in the actual results + expectedResults.foreach { expectedRow => + assert(resultsSet.contains(expectedRow), s"Expected row $expectedRow not found in results") + } + // Retrieve the logical plan + val logicalPlan: LogicalPlan = + frame.queryExecution.commandExecuted.asInstanceOf[CommandResult].commandLogicalPlan + // Define the expected logical plan + val expectedPlan: LogicalPlan = Project(Seq(), new UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index dd43007f4..991539623 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -36,6 +36,7 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; +TRENDLINE: 'TRENDLINE'; //Native JOIN KEYWORDS JOIN: 'JOIN'; @@ -86,6 +87,10 @@ STR: 'STR'; IP: 'IP'; NUM: 'NUM'; +//TRENDLINE KEYWORDS +SMA: 'SMA'; +WMA: 'WMA'; + // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index fb1c79bd2..b82e00e4b 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -52,6 +52,7 @@ commands | lookupCommand | renameCommand | fillnullCommand + | trendlineCommand ; searchCommand @@ -208,6 +209,19 @@ fillnullCommand ; +trendlineCommand + : TRENDLINE trendlineClause (trendlineClause)* + ; + +trendlineClause + : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS AS alias = fieldExpression + ; + +trendlineType + : SMA + | WMA + ; + kmeansCommand : KMEANS (kmeansParameter)* ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index e42306965..07d5dc9da 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -106,6 +106,14 @@ public T visitLookup(Lookup node, C context) { return visitChildren(node, context); } + public T visitTrendline(Trendline node, C context) { + return visitChildren(node, context); + } + + public T visitTrendlineComputation(Trendline.TrendlineComputation node, C context) { + return visitChildren(node, context); + } + public T visitCorrelation(Correlation node, C context) { return visitChildren(node, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java new file mode 100644 index 000000000..27cabf568 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -0,0 +1,65 @@ +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; + +@ToString +@Getter +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class Trendline extends UnresolvedPlan { + + private UnresolvedPlan child; + private final List computations; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } + + @Getter + public static class TrendlineComputation extends UnresolvedExpression { + + private final Integer numberOfDataPoints; + private final UnresolvedExpression dataField; + private final String alias; + private final TrendlineType computationType; + + public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dataField, String alias, String computationType) { + this.numberOfDataPoints = numberOfDataPoints; + this.dataField = dataField; + this.alias = alias; + this.computationType = Trendline.TrendlineType.valueOf(computationType.toUpperCase()); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitTrendlineComputation(this, context); + } + + } + + public enum TrendlineType { + SMA, + WMA + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index e6ab083ee..1e973af15 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -12,15 +12,25 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Ascending$; import org.apache.spark.sql.catalyst.expressions.CaseWhen; +import org.apache.spark.sql.catalyst.expressions.CurrentRow$; import org.apache.spark.sql.catalyst.expressions.Descending$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.InSubquery$; import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.RowFrame$; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; -import org.apache.spark.sql.catalyst.plans.logical.*; +import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame; +import org.apache.spark.sql.catalyst.expressions.WindowExpression; +import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition; +import org.apache.spark.sql.catalyst.plans.logical.Aggregate; +import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$; +import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; +import org.apache.spark.sql.catalyst.plans.logical.Limit; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project$; import org.apache.spark.sql.execution.ExplainMode; import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.execution.command.ExplainCommand; @@ -35,6 +45,7 @@ import org.opensearch.sql.ast.expression.BinaryExpression; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; @@ -76,8 +87,10 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TopAggregation; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; @@ -89,7 +102,11 @@ import scala.collection.IterableLike; import scala.collection.Seq; -import java.util.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Stack; import java.util.function.BiFunction; import java.util.stream.Collectors; @@ -250,6 +267,13 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) { }); } + @Override + public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + visitExpressionList(node.getComputations(), context); + return child; + } + @Override public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { node.getChild().get(0).accept(this, context); @@ -612,6 +636,31 @@ public Expression visitSpan(Span node, CatalystPlanContext context) { return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); } + @Override + public Expression visitTrendlineComputation(Trendline.TrendlineComputation node, CatalystPlanContext context) { + this.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); + Expression avgFunction = context.popNamedParseExpressions().get(); + this.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context); + Expression windowLowerBoundary = context.popNamedParseExpressions().get(); + if (node.getComputationType() == Trendline.TrendlineType.SMA) { + WindowExpression sma = new WindowExpression( + avgFunction, + new WindowSpecDefinition( + seq(), + seq(), + new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$))); + return context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(sma, + node.getAlias(), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + } else { + throw new IllegalArgumentException("WMA is not supported"); + } + } + @Override public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { node.getField().accept(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 8673b1582..0c7abbb33 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -371,6 +371,15 @@ private java.util.Map buildLookupPair(List (Alias) and.getLeft(), and -> (Field) and.getRight(), (x, y) -> y, LinkedHashMap::new)); } + @Override + public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) { + List trendlineComputations = ctx.trendlineClause() + .stream() + .map(expressionBuilder::visit) + .collect(Collectors.toList()); + return new Trendline(trendlineComputations); + } + /** Top command. */ @Override public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index f5e9269be..b835a7fca 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -37,6 +37,7 @@ import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.When; import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -112,6 +113,15 @@ public UnresolvedExpression visitEvalClause(OpenSearchPPLParser.EvalClauseContex return new Let((Field) visit(ctx.fieldExpression()), visit(ctx.expression())); } + @Override + public UnresolvedExpression visitTrendlineClause(OpenSearchPPLParser.TrendlineClauseContext ctx) { + Integer numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); + Field dataField = (Field) this.visitFieldExpression(ctx.field); + String alias = ctx.alias.getText(); + String computationType = ctx.trendlineType().getText(); + return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, computationType); + } + /** * Logical expression excluding boolean, comparison. */