From 9c38c3c97de9bc6539c95c44d166984a3d395e68 Mon Sep 17 00:00:00 2001 From: jackieyanghan Date: Thu, 7 Apr 2022 17:35:36 -0700 Subject: [PATCH] Remove labels and visitors for AD and KMEANS command Signed-off-by: jackieyanghan --- .../physical/MLCommonsOperatorActions.java | 10 +- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 28 ++--- .../opensearch/sql/ppl/parser/AstBuilder.java | 29 ++--- .../sql/ppl/parser/AstExpressionBuilder.java | 105 ------------------ .../sql/ppl/utils/ArgumentFactory.java | 12 +- .../sql/ppl/parser/AstBuilderTest.java | 18 --- 6 files changed, 41 insertions(+), 161 deletions(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java index 3574630452..21b232c031 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/MLCommonsOperatorActions.java @@ -7,6 +7,7 @@ package org.opensearch.sql.opensearch.planner.physical; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.List; @@ -48,11 +49,12 @@ public abstract class MLCommonsOperatorActions extends PhysicalPlan { */ protected DataFrame generateInputDataset(PhysicalPlan input) { List> inputData = new LinkedList<>(); - ImmutableMap.Builder inputDataBuilder = new ImmutableMap.Builder<>(); while (input.hasNext()) { - input.next().tupleValue().forEach((key, value) - -> inputDataBuilder.put(key, value.value())); - inputData.add(inputDataBuilder.build()); + inputData.add(new HashMap() { + { + input.next().tupleValue().forEach((key, value) -> put(key, value.value())); + } + }); } return DataFrameBuilder.load(inputData); diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index b4124b0c73..da37f8e22b 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -93,9 +93,9 @@ kmeansCommand ; kmeansParameter - : (CENTROIDS EQUAL centroids=integerLiteral) #centroids - | (ITERATIONS EQUAL iterations=integerLiteral) #iterations - | (DISTANCE_TYPE EQUAL distance_type=stringLiteral) #distance_type + : (CENTROIDS EQUAL centroids=integerLiteral) + | (ITERATIONS EQUAL iterations=integerLiteral) + | (DISTANCE_TYPE EQUAL distance_type=stringLiteral) ; adCommand @@ -103,17 +103,17 @@ adCommand ; adParameter - : (NUMBER_OF_TREES EQUAL number_of_trees=integerLiteral) #number_of_trees - | (SHINGLE_SIZE EQUAL shingle_size=integerLiteral) #shingle_size - | (SAMPLE_SIZE EQUAL sample_size=integerLiteral) #sample_size - | (OUTPUT_AFTER EQUAL output_after=integerLiteral) #output_after - | (TIME_DECAY EQUAL time_decay=decimalLiteral) #time_decay - | (ANOMALY_RATE EQUAL anomaly_rate=decimalLiteral) #anomaly_rate - | (TIME_FIELD EQUAL time_field=stringLiteral) #time_field - | (DATE_FORMAT EQUAL date_format=stringLiteral) #date_format - | (TIME_ZONE EQUAL time_zone=stringLiteral) #time_zone - | (TRAINING_DATA_SIZE EQUAL training_data_size=integerLiteral) #training_data_size - | (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold=decimalLiteral) #anomaly_score_threshold + : (NUMBER_OF_TREES EQUAL number_of_trees=integerLiteral) + | (SHINGLE_SIZE EQUAL shingle_size=integerLiteral) + | (SAMPLE_SIZE EQUAL sample_size=integerLiteral) + | (OUTPUT_AFTER EQUAL output_after=integerLiteral) + | (TIME_DECAY EQUAL time_decay=decimalLiteral) + | (ANOMALY_RATE EQUAL anomaly_rate=decimalLiteral) + | (TIME_FIELD EQUAL time_field=stringLiteral) + | (DATE_FORMAT EQUAL date_format=stringLiteral) + | (TIME_ZONE EQUAL time_zone=stringLiteral) + | (TRAINING_DATA_SIZE EQUAL training_data_size=integerLiteral) + | (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold=decimalLiteral) ; /** clauses */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index e7cc81ed71..2b25004f15 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -23,6 +23,7 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.WhereCommandContext; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -32,7 +33,6 @@ import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; import org.opensearch.sql.ast.expression.Alias; -import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; @@ -313,12 +313,13 @@ protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPla */ @Override public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { - return new Kmeans(ctx.kmeansParameter().stream() - .map(p -> (Argument) internalVisitExpression(p)) - .collect(Collectors.toMap( - Argument::getArgName, Argument::getValue, - (value1, value2) -> value2) - )); + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.kmeansParameter() + .forEach(x -> { + builder.put(x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); + return new Kmeans(builder.build()); } /** @@ -326,12 +327,14 @@ public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { */ @Override public UnresolvedPlan visitAdCommand(AdCommandContext ctx) { - return new AD(ctx.adParameter().stream() - .map(p -> (Argument) internalVisitExpression(p)) - .collect(Collectors.toMap( - Argument::getArgName, Argument::getValue, - (value1, value2) -> value2) - )); + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.adParameter() + .forEach(x -> { + builder.put(x.children.get(0).toString(), + (Literal) internalVisitExpression(x.children.get(2))); + }); + + return new AD(builder.build()); } /** diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index c7f24e20b6..79612ff2cb 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -9,20 +9,15 @@ import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Anomaly_rateContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Anomaly_score_thresholdContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BinaryArithmeticContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BySpanClauseContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CentroidsContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConvertedDataTypeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DataTypeFunctionCallContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Date_formatContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Distance_typeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext; @@ -32,43 +27,19 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.InExprContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntervalLiteralContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IterationsContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalAndContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalNotContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalOrContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalXorContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Number_of_treesContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Output_afterContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ParentheticBinaryArithmeticContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.PercentileAggFunctionContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RelevanceExpressionContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Sample_sizeContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Shingle_sizeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SpanClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StringLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TableSourceContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_decayContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_fieldContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_zoneContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Training_data_sizeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.WcFieldExpressionContext; -import static org.opensearch.sql.ppl.utils.ArgumentFactory.getArgumentValue; -import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_RATE; -import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_SCORE_THRESHOLD; -import static org.opensearch.sql.utils.MLCommonsConstants.CENTROIDS; -import static org.opensearch.sql.utils.MLCommonsConstants.DATE_FORMAT; -import static org.opensearch.sql.utils.MLCommonsConstants.DISTANCE_TYPE; -import static org.opensearch.sql.utils.MLCommonsConstants.ITERATIONS; -import static org.opensearch.sql.utils.MLCommonsConstants.NUMBER_OF_TREES; -import static org.opensearch.sql.utils.MLCommonsConstants.OUTPUT_AFTER; -import static org.opensearch.sql.utils.MLCommonsConstants.SAMPLE_SIZE; -import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE; -import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY; -import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD; -import static org.opensearch.sql.utils.MLCommonsConstants.TIME_ZONE; -import static org.opensearch.sql.utils.MLCommonsConstants.TRAINING_DATA_SIZE; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -129,82 +100,6 @@ public UnresolvedExpression visitEvalClause(EvalClauseContext ctx) { return new Let((Field) visit(ctx.fieldExpression()), visit(ctx.expression())); } - /** - * Kmeans arguments. - */ - @Override - public UnresolvedExpression visitCentroids(CentroidsContext ctx) { - return new Argument(CENTROIDS, getArgumentValue(ctx.centroids)); - } - - @Override - public UnresolvedExpression visitIterations(IterationsContext ctx) { - return new Argument(ITERATIONS, getArgumentValue(ctx.iterations)); - } - - @Override - public UnresolvedExpression visitDistance_type(Distance_typeContext ctx) { - return new Argument(DISTANCE_TYPE, getArgumentValue(ctx.distance_type)); - } - - /** - * AD arguments. - */ - @Override - public UnresolvedExpression visitNumber_of_trees(Number_of_treesContext ctx) { - return new Argument(NUMBER_OF_TREES, getArgumentValue(ctx.number_of_trees)); - } - - @Override - public UnresolvedExpression visitShingle_size(Shingle_sizeContext ctx) { - return new Argument(SHINGLE_SIZE, getArgumentValue(ctx.shingle_size)); - } - - @Override - public UnresolvedExpression visitSample_size(Sample_sizeContext ctx) { - return new Argument(SAMPLE_SIZE, getArgumentValue(ctx.sample_size)); - } - - @Override - public UnresolvedExpression visitOutput_after(Output_afterContext ctx) { - return new Argument(OUTPUT_AFTER, getArgumentValue(ctx.output_after)); - } - - @Override - public UnresolvedExpression visitTime_decay(Time_decayContext ctx) { - return new Argument(TIME_DECAY, getArgumentValue(ctx.time_decay)); - } - - @Override - public UnresolvedExpression visitAnomaly_rate(Anomaly_rateContext ctx) { - return new Argument(ANOMALY_RATE, getArgumentValue(ctx.anomaly_rate)); - } - - @Override - public UnresolvedExpression visitTime_field(Time_fieldContext ctx) { - return new Argument(TIME_FIELD, getArgumentValue(ctx.time_field)); - } - - @Override - public UnresolvedExpression visitDate_format(Date_formatContext ctx) { - return new Argument(DATE_FORMAT, getArgumentValue(ctx.date_format)); - } - - @Override - public UnresolvedExpression visitTime_zone(Time_zoneContext ctx) { - return new Argument(TIME_ZONE, getArgumentValue(ctx.time_zone)); - } - - @Override - public UnresolvedExpression visitTraining_data_size(Training_data_sizeContext ctx) { - return new Argument(TRAINING_DATA_SIZE, getArgumentValue(ctx.training_data_size)); - } - - @Override - public UnresolvedExpression visitAnomaly_score_threshold(Anomaly_score_thresholdContext ctx) { - return new Argument(ANOMALY_SCORE_THRESHOLD, getArgumentValue(ctx.anomaly_score_threshold)); - } - /** * Logical expression excluding boolean, comparison. */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index a635013b1e..09afd2075f 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -140,14 +140,12 @@ public static List getArgumentList(RareCommandContext ctx) { * @param ctx ParserRuleContext instance * @return Literal */ - public static Literal getArgumentValue(ParserRuleContext ctx) { + private static Literal getArgumentValue(ParserRuleContext ctx) { return ctx instanceof IntegerLiteralContext - ? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) - : ctx instanceof BooleanLiteralContext - ? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN) - : ctx instanceof DecimalLiteralContext - ? new Literal(Double.valueOf(ctx.getText()), DataType.DOUBLE) - : new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING); + ? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER) + : ctx instanceof BooleanLiteralContext + ? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN) + : new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING); } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 453ef4490e..5ee0e2be6b 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -591,30 +591,12 @@ public void testKmeansCommand() { )); } - @Test - public void testKmeansCommand_withDuplicateParameters() { - assertEqual("source=t | kmeans centroids=3 centroids=2", - new Kmeans(relation("t"), ImmutableMap.builder() - .put("centroids", new Literal(2, DataType.INTEGER)) - .build() - )); - } - @Test public void testKmeansCommandWithoutParameter() { assertEqual("source=t | kmeans", new Kmeans(relation("t"), ImmutableMap.of())); } - @Test - public void test_fitRCFADCommand_withDuplicateParameters() { - assertEqual("source=t | AD shingle_size=10 shingle_size=8", - new AD(relation("t"), ImmutableMap.builder() - .put("shingle_size", new Literal(8, DataType.INTEGER)) - .build() - )); - } - @Test public void test_fitRCFADCommand_withoutDataFormat() { assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp' "