Skip to content

Commit

Permalink
WIP trendline command
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Trochimiak <kacper.trochimiak@eliatra.com>
  • Loading branch information
kt-eliatra committed Oct 9, 2024
1 parent 2e16414 commit 714bfcf
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.{QueryTest, Row}

/**
 * Integration tests for the PPL `trendline` command (simple moving average).
 *
 * NOTE(review): WIP — the expected results and expected logical plan below are
 * placeholders; they must be filled in once the trendline-to-Catalyst
 * translation is complete.
 */
class FlintSparkPPLTrendlineITSuite
    extends QueryTest
    with LogicalPlanTestUtils
    with FlintPPLSuite
    with StreamTest {

  /** Test table and index name */
  private val testTable = "spark_catalog.default.flint_ppl_test"

  override def beforeAll(): Unit = {
    super.beforeAll()

    // Create test table
    createPartitionedStateCountryTable(testTable)
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Stop all streaming jobs if any
    spark.streams.active.foreach { job =>
      job.stop()
      job.awaitTermination()
    }
  }

  test("trendline sma") {
    // Two SMA computations over `age` with different window sizes, then project them.
    val frame = sql(s"""
         | source = $testTable | trendline sma(2, age) as first_age_sma sma(3, age) as second_age_sma | fields name, first_age_sma, second_age_sma
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    // Define the expected results
    // NOTE(review): empty placeholder — the containment loop below asserts nothing
    // until real expected rows are provided. TODO: add expected SMA rows.
    val expectedResults: Array[Row] = Array()

    // Convert actual results to a Set for quick lookup
    val resultsSet: Set[Row] = results.toSet
    // Check that each expected row is present in the actual results
    expectedResults.foreach { expectedRow =>
      assert(resultsSet.contains(expectedRow), s"Expected row $expectedRow not found in results")
    }
    // Retrieve the logical plan
    val logicalPlan: LogicalPlan =
      frame.queryExecution.commandExecuted.asInstanceOf[CommandResult].commandLogicalPlan
    // Define the expected logical plan
    // NOTE(review): placeholder — `Project(Seq(), …)` does not reflect the
    // `fields name, first_age_sma, second_age_sma` projection nor the trendline
    // window expressions; update once the planner emits the real plan.
    val expectedPlan: LogicalPlan = Project(Seq(), new UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
    // Compare the two plans
    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
  }
}
5 changes: 5 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ KMEANS: 'KMEANS';
AD: 'AD';
ML: 'ML';
FILLNULL: 'FILLNULL';
TRENDLINE: 'TRENDLINE';

//Native JOIN KEYWORDS
JOIN: 'JOIN';
Expand Down Expand Up @@ -86,6 +87,10 @@ STR: 'STR';
IP: 'IP';
NUM: 'NUM';

//TRENDLINE KEYWORDS
SMA: 'SMA';
WMA: 'WMA';

// ARGUMENT KEYWORDS
KEEPEMPTY: 'KEEPEMPTY';
CONSECUTIVE: 'CONSECUTIVE';
Expand Down
14 changes: 14 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ commands
| lookupCommand
| renameCommand
| fillnullCommand
| trendlineCommand
;

searchCommand
Expand Down Expand Up @@ -208,6 +209,19 @@ fillnullCommand
;


// trendline command: computes one or more moving-average columns, e.g.
//   ... | trendline sma(2, age) as age_sma
trendlineCommand
: TRENDLINE trendlineClause (trendlineClause)*
;

// One computation: TYPE ( numberOfDataPoints , field ) AS alias
trendlineClause
: trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS AS alias = fieldExpression
;

// Supported moving-average flavours: simple (SMA) and weighted (WMA).
trendlineType
: SMA
| WMA
;

kmeansCommand
: KMEANS (kmeansParameter)*
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ public T visitLookup(Lookup node, C context) {
return visitChildren(node, context);
}

/**
 * Visit a {@link Trendline} command node. Default behaviour delegates to
 * {@code visitChildren}; concrete visitors override this to handle the command.
 */
public T visitTrendline(Trendline node, C context) {
return visitChildren(node, context);
}

/**
 * Visit a single {@link Trendline.TrendlineComputation} expression. Default
 * behaviour delegates to {@code visitChildren}; override to translate it.
 */
public T visitTrendlineComputation(Trendline.TrendlineComputation node, C context) {
return visitChildren(node, context);
}

public T visitCorrelation(Correlation node, C context) {
return visitChildren(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.List;
import java.util.Locale;

/**
 * AST node for the PPL {@code trendline} command, which computes one or more
 * moving-average columns over the incoming rows.
 */
@ToString
@Getter
@RequiredArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Trendline extends UnresolvedPlan {

    // Input plan this command operates on; set via attach().
    private UnresolvedPlan child;
    // One computation per trendline clause, e.g. "sma(2, age) as age_sma".
    private final List<UnresolvedExpression> computations;

    @Override
    public UnresolvedPlan attach(UnresolvedPlan child) {
        this.child = child;
        return this;
    }

    @Override
    public List<? extends Node> getChild() {
        return ImmutableList.of(child);
    }

    @Override
    public <T, C> T accept(AbstractNodeVisitor<T, C> visitor, C context) {
        return visitor.visitTrendline(this, context);
    }

    /** A single trendline computation: TYPE(numberOfDataPoints, dataField) AS alias. */
    @Getter
    public static class TrendlineComputation extends UnresolvedExpression {

        // Window size (number of rows) the moving average is computed over.
        private final Integer numberOfDataPoints;
        // Field whose values are averaged.
        private final UnresolvedExpression dataField;
        // Name of the resulting output column.
        private final String alias;
        // Kind of moving average (SMA/WMA).
        private final TrendlineType computationType;

        /**
         * @param numberOfDataPoints window size; expected to be positive
         * @param dataField          field to compute the trendline over
         * @param alias              output column name
         * @param computationType    case-insensitive trendline type name ("sma"/"wma")
         * @throws IllegalArgumentException if {@code computationType} is not a known type
         */
        public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dataField, String alias, String computationType) {
            this.numberOfDataPoints = numberOfDataPoints;
            this.dataField = dataField;
            this.alias = alias;
            // Locale.ROOT makes enum resolution locale-independent (avoids the
            // Turkish dotted/dotless-i case-mapping pitfall of bare toUpperCase()).
            this.computationType = Trendline.TrendlineType.valueOf(computationType.toUpperCase(Locale.ROOT));
        }

        @Override
        public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
            return nodeVisitor.visitTrendlineComputation(this, context);
        }

    }

    /** Supported moving-average flavours. */
    public enum TrendlineType {
        SMA,
        WMA
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,25 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.Ascending$;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.CurrentRow$;
import org.apache.spark.sql.catalyst.expressions.Descending$;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.InSubquery$;
import org.apache.spark.sql.catalyst.expressions.ListQuery$;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.RowFrame$;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.*;
import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame;
import org.apache.spark.sql.catalyst.expressions.WindowExpression;
import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns$;
import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
import org.apache.spark.sql.catalyst.plans.logical.Limit;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project$;
import org.apache.spark.sql.execution.ExplainMode;
import org.apache.spark.sql.execution.command.DescribeTableCommand;
import org.apache.spark.sql.execution.command.ExplainCommand;
Expand All @@ -35,6 +45,7 @@
import org.opensearch.sql.ast.expression.BinaryExpression;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Function;
Expand Down Expand Up @@ -76,8 +87,10 @@
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.SubqueryAlias;
import org.opensearch.sql.ast.tree.TopAggregation;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.ppl.utils.AggregatorTranslator;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
Expand All @@ -89,7 +102,11 @@
import scala.collection.IterableLike;
import scala.collection.Seq;

import java.util.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Stack;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -250,6 +267,13 @@ public LogicalPlan visitLookup(Lookup node, CatalystPlanContext context) {
});
}

@Override
public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) {
// Translate the child plan first, then translate each trendline computation
// into a named window expression (pushed onto the context's expression stack).
LogicalPlan child = node.getChild().get(0).accept(this, context);
visitExpressionList(node.getComputations(), context);
// NOTE(review): WIP — the computed window expressions are left on the context
// stack and the child plan is returned unchanged; presumably a Project over the
// computed columns still needs to be emitted here. Confirm before relying on it.
return child;
}

@Override
public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
Expand Down Expand Up @@ -612,6 +636,31 @@ public Expression visitSpan(Span node, CatalystPlanContext context) {
return context.getNamedParseExpressions().push(window(field, value, node.getUnit()));
}

@Override
public Expression visitTrendlineComputation(Trendline.TrendlineComputation node, CatalystPlanContext context) {
    // Build the AVG aggregate over the data field and pop it off the expression stack.
    visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context);
    final Expression averageExpr = context.popNamedParseExpressions().get();

    // Frame lower bound is -(n - 1) preceding rows, so the frame spans exactly
    // n rows ending at (and including) the current row.
    visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context);
    final Expression frameLowerBound = context.popNamedParseExpressions().get();

    // Only simple moving average is translated so far.
    if (node.getComputationType() != Trendline.TrendlineType.SMA) {
        throw new IllegalArgumentException("WMA is not supported");
    }

    // Unpartitioned, unordered window with an explicit row frame of n rows.
    final WindowSpecDefinition windowSpec = new WindowSpecDefinition(
            seq(),
            seq(),
            new SpecifiedWindowFrame(RowFrame$.MODULE$, frameLowerBound, CurrentRow$.MODULE$));
    final Expression aliasedSma = org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(
            new WindowExpression(averageExpr, windowSpec),
            node.getAlias(),
            NamedExpression.newExprId(),
            seq(new java.util.ArrayList<String>()),
            Option.empty(),
            seq(new java.util.ArrayList<String>()));
    return context.getNamedParseExpressions().push(aliasedSma);
}

@Override
public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) {
node.getField().accept(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,15 @@ private java.util.Map<Alias, Field> buildLookupPair(List<OpenSearchPPLParser.Loo
.collect(Collectors.toMap(and -> (Alias) and.getLeft(), and -> (Field) and.getRight(), (x, y) -> y, LinkedHashMap::new));
}

@Override
public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) {
    // Each trendline clause (e.g. "sma(2, age) as age_sma") becomes one
    // TrendlineComputation expression, in source order.
    final List<UnresolvedExpression> computations =
            ctx.trendlineClause().stream()
                    .map(clause -> expressionBuilder.visit(clause))
                    .collect(Collectors.toList());
    return new Trendline(computations);
}

/** Top command. */
@Override
public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.When;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.ppl.utils.ArgumentFactory;

Expand Down Expand Up @@ -112,6 +113,15 @@ public UnresolvedExpression visitEvalClause(OpenSearchPPLParser.EvalClauseContex
return new Let((Field) visit(ctx.fieldExpression()), visit(ctx.expression()));
}

/**
 * Build a {@link Trendline.TrendlineComputation} from one trendline clause,
 * e.g. {@code sma(2, age) as age_sma}.
 *
 * @throws IllegalArgumentException if the window size is not positive or the
 *                                  trendline type is unknown
 */
@Override
public UnresolvedExpression visitTrendlineClause(OpenSearchPPLParser.TrendlineClauseContext ctx) {
    // Window size for the moving average; the grammar only guarantees an integer
    // literal, so reject non-positive values here before they reach the planner,
    // where -(n - 1) is used as a window frame lower bound.
    final int numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText());
    if (numberOfDataPoints < 1) {
        throw new IllegalArgumentException(
                "Number of trendline data points must be greater than or equal to 1");
    }
    final Field dataField = (Field) this.visitFieldExpression(ctx.field);
    final String alias = ctx.alias.getText();
    // Type text (e.g. "sma") is mapped to TrendlineType by the AST node itself.
    final String computationType = ctx.trendlineType().getText();
    return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, computationType);
}

/**
* Logical expression excluding boolean, comparison.
*/
Expand Down

0 comments on commit 714bfcf

Please sign in to comment.