Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unquote text and identifiers in PPL parsing #393

Merged
merged 6 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ class FlintSparkPPLAggregationsITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row(36.25))

// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
assert(results.sorted.sameElements(expectedResults.sorted))
Expand Down Expand Up @@ -76,7 +75,6 @@ class FlintSparkPPLAggregationsITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row(25))

// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
assert(results.sorted.sameElements(expectedResults.sorted))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,34 @@ class FlintSparkPPLBasicITSuite
}

test("create ppl simple query test") {
val frame = sql(s"""
| source = $testTable
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("Jake", 70, "California", "USA", 2023, 4),
Row("Hello", 30, "New York", "USA", 2023, 4),
Row("John", 25, "Ontario", "Canada", 2023, 4),
Row("Jane", 20, "Quebec", "Canada", 2023, 4))
// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val expectedPlan: LogicalPlan =
Project(
Seq(UnresolvedStar(None)),
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
// Compare the two plans
assert(expectedPlan === logicalPlan)
val testTableQuoted = "`spark_catalog`.`default`.`flint_ppl_test`"
Seq(testTable, testTableQuoted).foreach { table =>
val frame = sql(s"""
| source = $table
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("Jake", 70, "California", "USA", 2023, 4),
Row("Hello", 30, "New York", "USA", 2023, 4),
Row("John", 25, "Ontario", "Canada", 2023, 4),
Row("Jane", 20, "Quebec", "Canada", 2023, 4))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val expectedPlan: LogicalPlan =
Project(
Seq(UnresolvedStar(None)),
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
// Compare the two plans
assert(expectedPlan === logicalPlan)
}
}

test("create ppl simple query with head (limit) 3 test") {
Expand Down Expand Up @@ -90,7 +92,6 @@ class FlintSparkPPLBasicITSuite
| source = $testTable| sort name | head 2
| """.stripMargin)

// Retrieve the results
// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 2)
Expand Down Expand Up @@ -187,27 +188,29 @@ class FlintSparkPPLBasicITSuite
}

test("create ppl simple query two with fields and head (limit) with sorting test") {
val frame = sql(s"""
| source = $testTable| fields name, age | head 1 | sort age
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 1)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val project = Project(
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
// Define the expected logical plan
val limitPlan: LogicalPlan = Limit(Literal(1), project)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan)

val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan);
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
Seq(("name, age", "age"), ("`name`, `age`", "`age`")).foreach {
case (selectFields, sortField) =>
val frame = sql(s"""
| source = $testTable| fields $selectFields | head 1 | sort $sortField
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 1)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val project = Project(
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
// Define the expected logical plan
val limitPlan: LogicalPlan = Limit(Literal(1), project)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan)

val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan);
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ class FlintSparkPPLFiltersITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row("John", 25))
// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

Expand All @@ -72,7 +71,6 @@ class FlintSparkPPLFiltersITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row("John", 25), Row("Jane", 20))
// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

Expand Down Expand Up @@ -182,7 +180,6 @@ class FlintSparkPPLFiltersITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30))
// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

Expand All @@ -209,7 +206,6 @@ class FlintSparkPPLFiltersITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20))
// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

Expand Down Expand Up @@ -287,7 +283,6 @@ class FlintSparkPPLFiltersITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20))

// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class FlintSparkPPLTimeWindowITSuite
override def beforeAll(): Unit = {
super.beforeAll()
// Create test table
// Update table creation
createTimeSeriesTransactionTable(testTable)
}

Expand All @@ -39,16 +38,6 @@ class FlintSparkPPLTimeWindowITSuite
}

test("create ppl query count sales by days window test") {
/*
val dataFrame = spark.read.table(testTable)
val query = dataFrame
.groupBy(
window(
col("transactionDate"), " 1 days")
).agg(sum(col("productsAmount")))

query.show(false)
*/
val frame = sql(s"""
| source = $testTable| stats sum(productsAmount) by span(transactionDate, 1d) as age_date
| """.stripMargin)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.common.utils;

import com.google.common.base.Strings;

import java.util.IllegalFormatException;
import java.util.Locale;

public class StringUtils {
seankao-az marked this conversation as resolved.
Show resolved Hide resolved
seankao-az marked this conversation as resolved.
Show resolved Hide resolved
/**
* Unquote Identifier which has " or ' as mark. Strings quoted by ' or " with two of these quotes
* appearing next to each other in the quote acts as an escape<br>
* Example: 'Test''s' will result in 'Test's', similar with those single quotes being replaced
* with double quote. Supports escaping quotes (single/double) and escape characters using the `\`
* characters.
*
* @param text string
* @return An unquoted string whose outer pair of (single/double) quotes have been removed
*/
public static String unquoteText(String text) {
if (text.length() < 2) {
return text;
}

char enclosingQuote = 0;
char firstChar = text.charAt(0);
char lastChar = text.charAt(text.length() - 1);

if (firstChar != lastChar) {
return text;
}

if (firstChar == '`') {
return text.substring(1, text.length() - 1);
}

if (firstChar == lastChar && (firstChar == '\'' || firstChar == '"')) {
enclosingQuote = firstChar;
} else {
return text;
}

char currentChar;
char nextChar;

StringBuilder textSB = new StringBuilder();

// Ignores first and last character as they are the quotes that should be removed
for (int chIndex = 1; chIndex < text.length() - 1; chIndex++) {
currentChar = text.charAt(chIndex);
nextChar = text.charAt(chIndex + 1);

if ((currentChar == '\\' && (nextChar == '"' || nextChar == '\\' || nextChar == '\''))
|| (currentChar == nextChar && currentChar == enclosingQuote)) {
chIndex++;
currentChar = nextChar;
}
textSB.append(currentChar);
}
return textSB.toString();
}

/**
* Unquote Identifier which has ` as mark.
*
* @param identifier identifier that possibly enclosed by backticks
* @return An unquoted string whose outer pair of backticks have been removed
*/
public static String unquoteIdentifier(String identifier) {
if (isQuoted(identifier, "`")) {
return identifier.substring(1, identifier.length() - 1);
} else {
return identifier;
}
}

/**
* Returns a formatted string using the specified format string and arguments, as well as the
* {@link Locale#ROOT} locale.
*
* @param format format string
* @param args arguments referenced by the format specifiers in the format string
* @return A formatted string
* @throws IllegalFormatException If a format string contains an illegal syntax, a format
* specifier that is incompatible with the given arguments, insufficient arguments given the
* format string, or other illegal conditions.
* @see String#format(Locale, String, Object...)
*/
public static String format(final String format, Object... args) {
return String.format(Locale.ROOT, format, args);
}

private static boolean isQuoted(String text, String mark) {
return !Strings.isNullOrEmpty(text) && text.startsWith(mark) && text.endsWith(mark);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.sql.ast.expression.UnresolvedArgument;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.ppl.utils.ArgumentFactory;

import java.util.Arrays;
Expand Down Expand Up @@ -322,7 +323,7 @@ public UnresolvedExpression visitIntervalLiteral(OpenSearchPPLParser.IntervalLit

@Override
public UnresolvedExpression visitStringLiteral(OpenSearchPPLParser.StringLiteralContext ctx) {
return new Literal(ctx.getText(), DataType.STRING);
return new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING);
}

@Override
Expand All @@ -349,7 +350,7 @@ public UnresolvedExpression visitBySpanClause(OpenSearchPPLParser.BySpanClauseCo
String name = ctx.spanClause().getText();
return ctx.alias != null
? new Alias(
name, visit(ctx.spanClause()), ctx.alias.getText())
name, visit(ctx.spanClause()), StringUtils.unquoteIdentifier(ctx.alias.getText()))
: new Alias(name, visit(ctx.spanClause()));
}

Expand All @@ -363,6 +364,7 @@ private QualifiedName visitIdentifiers(List<? extends ParserRuleContext> ctx) {
return new QualifiedName(
ctx.stream()
.map(RuleContext::getText)
.map(StringUtils::unquoteIdentifier)
.collect(Collectors.toList()));
}

Expand All @@ -373,18 +375,18 @@ private List<UnresolvedExpression> singleFieldRelevanceArguments(
ImmutableList.Builder<UnresolvedExpression> builder = ImmutableList.builder();
builder.add(
new UnresolvedArgument(
"field", new QualifiedName(ctx.field.getText())));
"field", new QualifiedName(StringUtils.unquoteText(ctx.field.getText()))));
builder.add(
new UnresolvedArgument(
"query", new Literal(ctx.query.getText(), DataType.STRING)));
"query", new Literal(StringUtils.unquoteText(ctx.query.getText()), DataType.STRING)));
ctx.relevanceArg()
.forEach(
v ->
builder.add(
new UnresolvedArgument(
v.relevanceArgName().getText().toLowerCase(),
new Literal(
v.relevanceArgValue().getText(),
StringUtils.unquoteText(v.relevanceArgValue().getText()),
DataType.STRING))));
return builder.build();
}
Expand Down
Loading
Loading