From f73afbdf0a6adcb5e2d8a8ee8a52c922387570ab Mon Sep 17 00:00:00 2001 From: Sean Kao Date: Sun, 23 Jun 2024 14:52:00 -0700 Subject: [PATCH] unquote text and identifiers in PPL parsing Signed-off-by: Sean Kao --- .../spark/ppl/FlintSparkPPLBasicITSuite.scala | 51 +++++++++ .../sql/common/utils/StringUtils.java | 100 ++++++++++++++++++ .../sql/ppl/parser/AstExpressionBuilder.java | 12 ++- ...lPlanBasicQueriesTranslatorTestSuite.scala | 9 ++ 4 files changed, 167 insertions(+), 5 deletions(-) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/utils/StringUtils.java diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala index ba925339e..4130a8c0b 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala @@ -65,6 +65,33 @@ class FlintSparkPPLBasicITSuite assert(expectedPlan === logicalPlan) } + test("create ppl simple query with escaped identifiers test") { + val frame = sql("source = `spark_catalog`.`default`.`flint_ppl_test`") + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA", 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val expectedPlan: LogicalPlan = + Project( + Seq(UnresolvedStar(None)), + 
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + // Compare the two plans + assert(expectedPlan === logicalPlan) + } + test("create ppl simple query with head (limit) 3 test") { val frame = sql(s""" | source = $testTable| head 2 @@ -210,4 +237,28 @@ class FlintSparkPPLBasicITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + test("create ppl simple query with quoted field names test") { + val frame = sql(s""" + | source = $testTable| fields `name`, `age` | head 1 | sort `age` + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val project = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))) + // Define the expected logical plan + val limitPlan: LogicalPlan = Limit(Literal(1), project) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan) + + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan); + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/utils/StringUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/utils/StringUtils.java new file mode 100644 index 000000000..914bb1dfc --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/utils/StringUtils.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.common.utils; + +import com.google.common.base.Strings; + +import java.util.IllegalFormatException; +import java.util.Locale; + +public class StringUtils { + /** + * Unquote Identifier which has " or ' as mark. 
In a string quoted by ' or ", two of these quotes + * appearing next to each other act as an escape for a single quote.
+ * Example: 'Test''s' will result in 'Test's', similar with those single quotes being replaced + * with double quote. Supports escaping quotes (single/double) and escape characters using the `\` + * characters. + * + * @param text string + * @return An unquoted string whose outer pair of (single/double) quotes have been removed + */ + public static String unquoteText(String text) { + if (text.length() < 2) { + return text; + } + + char enclosingQuote = 0; + char firstChar = text.charAt(0); + char lastChar = text.charAt(text.length() - 1); + + if (firstChar != lastChar) { + return text; + } + + if (firstChar == '`') { + return text.substring(1, text.length() - 1); + } + + if (firstChar == lastChar && (firstChar == '\'' || firstChar == '"')) { + enclosingQuote = firstChar; + } else { + return text; + } + + char currentChar; + char nextChar; + + StringBuilder textSB = new StringBuilder(); + + // Ignores first and last character as they are the quotes that should be removed + for (int chIndex = 1; chIndex < text.length() - 1; chIndex++) { + currentChar = text.charAt(chIndex); + nextChar = text.charAt(chIndex + 1); + + if ((currentChar == '\\' && (nextChar == '"' || nextChar == '\\' || nextChar == '\'')) + || (currentChar == nextChar && currentChar == enclosingQuote)) { + chIndex++; + currentChar = nextChar; + } + textSB.append(currentChar); + } + return textSB.toString(); + } + + /** + * Unquote Identifier which has ` as mark. + * + * @param identifier identifier that possibly enclosed by double quotes or back ticks + * @return An unquoted string whose outer pair of (double/back-tick) quotes have been removed + */ + public static String unquoteIdentifier(String identifier) { + if (isQuoted(identifier, "`")) { + return identifier.substring(1, identifier.length() - 1); + } else { + return identifier; + } + } + + /** + * Returns a formatted string using the specified format string and arguments, as well as the + * {@link Locale#ROOT} locale. 
+ * + * @param format format string + * @param args arguments referenced by the format specifiers in the format string + * @return A formatted string + * @throws IllegalFormatException If a format string contains an illegal syntax, a format + * specifier that is incompatible with the given arguments, insufficient arguments given the + * format string, or other illegal conditions. + * @see String#format(Locale, String, Object...) + */ + public static String format(final String format, Object... args) { + return String.format(Locale.ROOT, format, args); + } + + private static boolean isQuoted(String text, String mark) { + return !Strings.isNullOrEmpty(text) && text.startsWith(mark) && text.endsWith(mark); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 3344cd7c2..b265047b5 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -32,6 +32,7 @@ import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.utils.ArgumentFactory; import java.util.Arrays; @@ -322,7 +323,7 @@ public UnresolvedExpression visitIntervalLiteral(OpenSearchPPLParser.IntervalLit @Override public UnresolvedExpression visitStringLiteral(OpenSearchPPLParser.StringLiteralContext ctx) { - return new Literal(ctx.getText(), DataType.STRING); + return new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING); } @Override @@ -349,7 +350,7 @@ public UnresolvedExpression visitBySpanClause(OpenSearchPPLParser.BySpanClauseCo String name = ctx.spanClause().getText(); return ctx.alias != null ? 
new Alias( - name, visit(ctx.spanClause()), ctx.alias.getText()) + name, visit(ctx.spanClause()), StringUtils.unquoteIdentifier(ctx.alias.getText())) : new Alias(name, visit(ctx.spanClause())); } @@ -363,6 +364,7 @@ private QualifiedName visitIdentifiers(List ctx) { return new QualifiedName( ctx.stream() .map(RuleContext::getText) + .map(StringUtils::unquoteIdentifier) .collect(Collectors.toList())); } @@ -373,10 +375,10 @@ private List singleFieldRelevanceArguments( ImmutableList.Builder builder = ImmutableList.builder(); builder.add( new UnresolvedArgument( - "field", new QualifiedName(ctx.field.getText()))); + "field", new QualifiedName(StringUtils.unquoteText(ctx.field.getText())))); builder.add( new UnresolvedArgument( - "query", new Literal(ctx.query.getText(), DataType.STRING))); + "query", new Literal(StringUtils.unquoteText(ctx.query.getText()), DataType.STRING))); ctx.relevanceArg() .forEach( v -> @@ -384,7 +386,7 @@ private List singleFieldRelevanceArguments( new UnresolvedArgument( v.relevanceArgName().getText().toLowerCase(), new Literal( - v.relevanceArgValue().getText(), + StringUtils.unquoteText(v.relevanceArgValue().getText()), DataType.STRING)))); return builder.build(); } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index 1b04189db..b3e767d23 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -32,7 +32,16 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table"))) assertEquals(expectedPlan, logPlan) + 
} + + test("test simple search with escaped table name") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = planTrnasformer.visit(plan(pplParser, "source=`table`", false), context) + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table"))) + assertEquals(expectedPlan, logPlan) } test("test simple search with schema.table and no explicit fields (defaults to all fields)") {