Skip to content

Commit

Permalink
Remove labels and visitors for AD and KMEANS command
Browse files Browse the repository at this point in the history
Signed-off-by: jackieyanghan <jkhanjob@gmail.com>
  • Loading branch information
jackiehanyang committed Apr 8, 2022
1 parent 465bdec commit 9c38c3c
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 161 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package org.opensearch.sql.opensearch.planner.physical;

import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
Expand Down Expand Up @@ -48,11 +49,12 @@ public abstract class MLCommonsOperatorActions extends PhysicalPlan {
*/
protected DataFrame generateInputDataset(PhysicalPlan input) {
List<Map<String, Object>> inputData = new LinkedList<>();
ImmutableMap.Builder<String, Object> inputDataBuilder = new ImmutableMap.Builder<>();
while (input.hasNext()) {
input.next().tupleValue().forEach((key, value)
-> inputDataBuilder.put(key, value.value()));
inputData.add(inputDataBuilder.build());
inputData.add(new HashMap<String, Object>() {
{
input.next().tupleValue().forEach((key, value) -> put(key, value.value()));
}
});
}

return DataFrameBuilder.load(inputData);
Expand Down
28 changes: 14 additions & 14 deletions ppl/src/main/antlr/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,27 @@ kmeansCommand
;

kmeansParameter
: (CENTROIDS EQUAL centroids=integerLiteral) #centroids
| (ITERATIONS EQUAL iterations=integerLiteral) #iterations
| (DISTANCE_TYPE EQUAL distance_type=stringLiteral) #distance_type
: (CENTROIDS EQUAL centroids=integerLiteral)
| (ITERATIONS EQUAL iterations=integerLiteral)
| (DISTANCE_TYPE EQUAL distance_type=stringLiteral)
;

adCommand
: AD (adParameter)*
;

adParameter
: (NUMBER_OF_TREES EQUAL number_of_trees=integerLiteral) #number_of_trees
| (SHINGLE_SIZE EQUAL shingle_size=integerLiteral) #shingle_size
| (SAMPLE_SIZE EQUAL sample_size=integerLiteral) #sample_size
| (OUTPUT_AFTER EQUAL output_after=integerLiteral) #output_after
| (TIME_DECAY EQUAL time_decay=decimalLiteral) #time_decay
| (ANOMALY_RATE EQUAL anomaly_rate=decimalLiteral) #anomaly_rate
| (TIME_FIELD EQUAL time_field=stringLiteral) #time_field
| (DATE_FORMAT EQUAL date_format=stringLiteral) #date_format
| (TIME_ZONE EQUAL time_zone=stringLiteral) #time_zone
| (TRAINING_DATA_SIZE EQUAL training_data_size=integerLiteral) #training_data_size
| (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold=decimalLiteral) #anomaly_score_threshold
: (NUMBER_OF_TREES EQUAL number_of_trees=integerLiteral)
| (SHINGLE_SIZE EQUAL shingle_size=integerLiteral)
| (SAMPLE_SIZE EQUAL sample_size=integerLiteral)
| (OUTPUT_AFTER EQUAL output_after=integerLiteral)
| (TIME_DECAY EQUAL time_decay=decimalLiteral)
| (ANOMALY_RATE EQUAL anomaly_rate=decimalLiteral)
| (TIME_FIELD EQUAL time_field=stringLiteral)
| (DATE_FORMAT EQUAL date_format=stringLiteral)
| (TIME_ZONE EQUAL time_zone=stringLiteral)
| (TRAINING_DATA_SIZE EQUAL training_data_size=integerLiteral)
| (ANOMALY_SCORE_THRESHOLD EQUAL anomaly_score_threshold=decimalLiteral)
;

/** clauses */
Expand Down
29 changes: 16 additions & 13 deletions ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.WhereCommandContext;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
Expand All @@ -32,7 +33,6 @@
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.tree.ParseTree;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
Expand Down Expand Up @@ -313,25 +313,28 @@ protected UnresolvedPlan aggregateResult(UnresolvedPlan aggregate, UnresolvedPla
*/
@Override
public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) {
return new Kmeans(ctx.kmeansParameter().stream()
.map(p -> (Argument) internalVisitExpression(p))
.collect(Collectors.toMap(
Argument::getArgName, Argument::getValue,
(value1, value2) -> value2)
));
ImmutableMap.Builder<String, Literal> builder = ImmutableMap.builder();
ctx.kmeansParameter()
.forEach(x -> {
builder.put(x.children.get(0).toString(),
(Literal) internalVisitExpression(x.children.get(2)));
});
return new Kmeans(builder.build());
}

/**
* AD command.
*/
@Override
public UnresolvedPlan visitAdCommand(AdCommandContext ctx) {
return new AD(ctx.adParameter().stream()
.map(p -> (Argument) internalVisitExpression(p))
.collect(Collectors.toMap(
Argument::getArgName, Argument::getValue,
(value1, value2) -> value2)
));
ImmutableMap.Builder<String, Literal> builder = ImmutableMap.builder();
ctx.adParameter()
.forEach(x -> {
builder.put(x.children.get(0).toString(),
(Literal) internalVisitExpression(x.children.get(2)));
});

return new AD(builder.build());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,15 @@
import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Anomaly_rateContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Anomaly_score_thresholdContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BinaryArithmeticContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BySpanClauseContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CentroidsContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConvertedDataTypeContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DataTypeFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Date_formatContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Distance_typeContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext;
Expand All @@ -32,43 +27,19 @@
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.InExprContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntervalLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IterationsContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalAndContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalNotContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalOrContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalXorContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Number_of_treesContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Output_afterContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ParentheticBinaryArithmeticContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.PercentileAggFunctionContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RelevanceExpressionContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Sample_sizeContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Shingle_sizeContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SpanClauseContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StringLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TableSourceContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_decayContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_fieldContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Time_zoneContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.Training_data_sizeContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.WcFieldExpressionContext;
import static org.opensearch.sql.ppl.utils.ArgumentFactory.getArgumentValue;
import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_RATE;
import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_SCORE_THRESHOLD;
import static org.opensearch.sql.utils.MLCommonsConstants.CENTROIDS;
import static org.opensearch.sql.utils.MLCommonsConstants.DATE_FORMAT;
import static org.opensearch.sql.utils.MLCommonsConstants.DISTANCE_TYPE;
import static org.opensearch.sql.utils.MLCommonsConstants.ITERATIONS;
import static org.opensearch.sql.utils.MLCommonsConstants.NUMBER_OF_TREES;
import static org.opensearch.sql.utils.MLCommonsConstants.OUTPUT_AFTER;
import static org.opensearch.sql.utils.MLCommonsConstants.SAMPLE_SIZE;
import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_ZONE;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAINING_DATA_SIZE;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -129,82 +100,6 @@ public UnresolvedExpression visitEvalClause(EvalClauseContext ctx) {
return new Let((Field) visit(ctx.fieldExpression()), visit(ctx.expression()));
}

/**
* Kmeans arguments.
*/
@Override
public UnresolvedExpression visitCentroids(CentroidsContext ctx) {
return new Argument(CENTROIDS, getArgumentValue(ctx.centroids));
}

@Override
public UnresolvedExpression visitIterations(IterationsContext ctx) {
return new Argument(ITERATIONS, getArgumentValue(ctx.iterations));
}

@Override
public UnresolvedExpression visitDistance_type(Distance_typeContext ctx) {
return new Argument(DISTANCE_TYPE, getArgumentValue(ctx.distance_type));
}

/**
* AD arguments.
*/
@Override
public UnresolvedExpression visitNumber_of_trees(Number_of_treesContext ctx) {
return new Argument(NUMBER_OF_TREES, getArgumentValue(ctx.number_of_trees));
}

@Override
public UnresolvedExpression visitShingle_size(Shingle_sizeContext ctx) {
return new Argument(SHINGLE_SIZE, getArgumentValue(ctx.shingle_size));
}

@Override
public UnresolvedExpression visitSample_size(Sample_sizeContext ctx) {
return new Argument(SAMPLE_SIZE, getArgumentValue(ctx.sample_size));
}

@Override
public UnresolvedExpression visitOutput_after(Output_afterContext ctx) {
return new Argument(OUTPUT_AFTER, getArgumentValue(ctx.output_after));
}

@Override
public UnresolvedExpression visitTime_decay(Time_decayContext ctx) {
return new Argument(TIME_DECAY, getArgumentValue(ctx.time_decay));
}

@Override
public UnresolvedExpression visitAnomaly_rate(Anomaly_rateContext ctx) {
return new Argument(ANOMALY_RATE, getArgumentValue(ctx.anomaly_rate));
}

@Override
public UnresolvedExpression visitTime_field(Time_fieldContext ctx) {
return new Argument(TIME_FIELD, getArgumentValue(ctx.time_field));
}

@Override
public UnresolvedExpression visitDate_format(Date_formatContext ctx) {
return new Argument(DATE_FORMAT, getArgumentValue(ctx.date_format));
}

@Override
public UnresolvedExpression visitTime_zone(Time_zoneContext ctx) {
return new Argument(TIME_ZONE, getArgumentValue(ctx.time_zone));
}

@Override
public UnresolvedExpression visitTraining_data_size(Training_data_sizeContext ctx) {
return new Argument(TRAINING_DATA_SIZE, getArgumentValue(ctx.training_data_size));
}

@Override
public UnresolvedExpression visitAnomaly_score_threshold(Anomaly_score_thresholdContext ctx) {
return new Argument(ANOMALY_SCORE_THRESHOLD, getArgumentValue(ctx.anomaly_score_threshold));
}

/**
* Logical expression excluding boolean, comparison.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,12 @@ public static List<Argument> getArgumentList(RareCommandContext ctx) {
* @param ctx ParserRuleContext instance
* @return Literal
*/
public static Literal getArgumentValue(ParserRuleContext ctx) {
private static Literal getArgumentValue(ParserRuleContext ctx) {
return ctx instanceof IntegerLiteralContext
? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER)
: ctx instanceof BooleanLiteralContext
? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN)
: ctx instanceof DecimalLiteralContext
? new Literal(Double.valueOf(ctx.getText()), DataType.DOUBLE)
: new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING);
? new Literal(Integer.parseInt(ctx.getText()), DataType.INTEGER)
: ctx instanceof BooleanLiteralContext
? new Literal(Boolean.valueOf(ctx.getText()), DataType.BOOLEAN)
: new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -591,30 +591,12 @@ public void testKmeansCommand() {
));
}

@Test
public void testKmeansCommand_withDuplicateParameters() {
assertEqual("source=t | kmeans centroids=3 centroids=2",
new Kmeans(relation("t"), ImmutableMap.<String, Literal>builder()
.put("centroids", new Literal(2, DataType.INTEGER))
.build()
));
}

@Test
public void testKmeansCommandWithoutParameter() {
assertEqual("source=t | kmeans",
new Kmeans(relation("t"), ImmutableMap.of()));
}

@Test
public void test_fitRCFADCommand_withDuplicateParameters() {
assertEqual("source=t | AD shingle_size=10 shingle_size=8",
new AD(relation("t"), ImmutableMap.<String, Literal>builder()
.put("shingle_size", new Literal(8, DataType.INTEGER))
.build()
));
}

@Test
public void test_fitRCFADCommand_withoutDataFormat() {
assertEqual("source=t | AD shingle_size=10 time_decay=0.0001 time_field='timestamp' "
Expand Down

0 comments on commit 9c38c3c

Please sign in to comment.