From d6264c1cfddc38fd35cb4d0dfa4a744319a790df Mon Sep 17 00:00:00 2001 From: Josh Pschorr Date: Thu, 23 Jun 2022 12:24:57 -0700 Subject: [PATCH] Add experimental draft support for GPML-style graph query --- .../org/partiql/type-domains/partiql.ion | 48 +- .../partiql/lang/ast/StatementToExprNode.kt | 1 + lang/src/org/partiql/lang/errors/ErrorCode.kt | 36 ++ .../GroupByPathExpressionVisitorTransform.kt | 3 + .../AstToLogicalVisitorTransform.kt | 4 + .../org/partiql/lang/syntax/LexerConstants.kt | 11 +- lang/src/org/partiql/lang/syntax/SqlParser.kt | 411 ++++++++++++++++ .../partiql/lang/syntax/SqlParserMatchTest.kt | 462 ++++++++++++++++++ .../partiql/lang/syntax/SqlParserTestBase.kt | 25 +- 9 files changed, 993 insertions(+), 8 deletions(-) create mode 100644 lang/test/org/partiql/lang/syntax/SqlParserMatchTest.kt diff --git a/lang/resources/org/partiql/type-domains/partiql.ion b/lang/resources/org/partiql/type-domains/partiql.ion index 8224ca3948..def42f1324 100644 --- a/lang/resources/org/partiql/type-domains/partiql.ion +++ b/lang/resources/org/partiql/type-domains/partiql.ion @@ -210,12 +210,56 @@ may then be further optimized by selecting better implementations of each operat // UNPIVOT [AS ] [AT ] [BY ] (unpivot expr::expr as_alias::(? symbol) at_alias::(? symbol) by_alias::(? symbol)) - // JOIN [INNER | LEFT | RIGHT | FULL] ON - (join type::join_type left::from_source right::from_source predicate::(? expr))) + // JOIN [INNER | LEFT | RIGHT | FULL] ON + (join type::join_type left::from_source right::from_source predicate::(? expr)) + + // MATCH + (graph_match expr::expr graph_expr::graph_match_expr)) // Indicates the logical type of join. (sum join_type (inner) (left) (right) (full)) + // A quantifier for graph edges or patterns. (e.g., the `{2,5}` in `MATCH (x)->{2,5}->(y)`) + (product graph_match_quantifier lower::int upper::(? int)) + + // A single node in a graph pattern. + (product graph_match_node + predicate::(? 
expr) // an optional node pre-filter + variable::(? symbol) // the optional element variable of the node match + label::(* symbol 0)) // the optional label(s) to match for the node + + // A single edge in a graph pattern. + (product graph_match_edge + direction::graph_match_direction // edge direction + quantifier::(? graph_match_quantifier) // an optional quantifier for the edge match + predicate::(? expr) // an optional edge pre-filter + variable::(? symbol) // the optional element variable of the edge match + label::(* symbol 0)) // the optional label(s) to match for the edge + + // The direction of an edge + (sum graph_match_direction + (edge_left) + (edge_undirected) + (edge_right) + (edge_left_or_undirected) + (edge_undirected_or_right) + (edge_left_or_right) + (edge_left_or_undirected_or_right)) + + (sum graph_match_pattern_part + (node node::graph_match_node) + (edge edge::graph_match_edge) + (pattern pattern::graph_match_pattern)) + + // A single graph match pattern. + (product graph_match_pattern + quantifier::(? graph_match_quantifier) // an optional quantifier for the entire pattern match + parts::(* graph_match_pattern_part 1)) // the ordered pattern parts + + // A graph match clause. + (product graph_match_expr patterns::(*graph_match_pattern 1)) + + // A generic pair of expressions. Used in the `struct`, `searched_case` and `simple_case` expr variants above. 
(product expr_pair first::expr second::expr) diff --git a/lang/src/org/partiql/lang/ast/StatementToExprNode.kt b/lang/src/org/partiql/lang/ast/StatementToExprNode.kt index 920b3e3b7d..e6a44e2d97 100644 --- a/lang/src/org/partiql/lang/ast/StatementToExprNode.kt +++ b/lang/src/org/partiql/lang/ast/StatementToExprNode.kt @@ -266,6 +266,7 @@ private class StatementTransformer(val ion: IonSystem) { condition = predicate?.toExprNode() ?: Literal(ion.newBool(true), metaContainerOf(StaticTypeMeta(StaticType.BOOL))), metas = metas ) + is PartiqlAst.FromSource.GraphMatch -> error("$this node has no representation in prior ASTs.") } } diff --git a/lang/src/org/partiql/lang/errors/ErrorCode.kt b/lang/src/org/partiql/lang/errors/ErrorCode.kt index 61cc8e8e59..c4518b9c35 100644 --- a/lang/src/org/partiql/lang/errors/ErrorCode.kt +++ b/lang/src/org/partiql/lang/errors/ErrorCode.kt @@ -393,6 +393,42 @@ enum class ErrorCode( "expected identifier for alias" ), + PARSE_EXPECTED_IDENT_FOR_MATCH( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected identifier for match" + ), + + PARSE_EXPECTED_LEFT_PAREN_FOR_MATCH_NODE( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected left parenthesis for match node" + ), + + PARSE_EXPECTED_RIGHT_PAREN_FOR_MATCH_NODE( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected right parenthesis for match node" + ), + + PARSE_EXPECTED_LEFT_BRACKET_FOR_MATCH_EDGE( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected left bracket for match edge" + ), + + PARSE_EXPECTED_RIGHT_BRACKET_FOR_MATCH_EDGE( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected right bracket for match edge" + ), + + PARSE_EXPECTED_EDGE_PATTERN_MATCH_EDGE( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected edge pattern for match edge" + ), + PARSE_EXPECTED_AS_FOR_LET( ErrorCategory.PARSER, LOC_TOKEN, diff --git a/lang/src/org/partiql/lang/eval/visitors/GroupByPathExpressionVisitorTransform.kt b/lang/src/org/partiql/lang/eval/visitors/GroupByPathExpressionVisitorTransform.kt index 1687bdefab..9e53169414 
100644 --- a/lang/src/org/partiql/lang/eval/visitors/GroupByPathExpressionVisitorTransform.kt +++ b/lang/src/org/partiql/lang/eval/visitors/GroupByPathExpressionVisitorTransform.kt @@ -76,6 +76,9 @@ class GroupByPathExpressionVisitorTransform( is PartiqlAst.FromSource.Unpivot -> listOfNotNull(fromSource.asAlias?.text, fromSource.atAlias?.text) + + is PartiqlAst.FromSource.GraphMatch -> + TODO("Handle MATCH for GROUP BY") } } diff --git a/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt b/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt index 8b179008fe..29a072163e 100644 --- a/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt +++ b/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt @@ -165,4 +165,8 @@ private object FromSourceToBexpr : PartiqlAst.FromSource.Converter = DateTimePart.values() "+", "-", "not" ) +/** Operators used by graph `MATCH` edge patterns (e.g. `~` for an undirected edge). */ +@JvmField internal val MATCH_OPERATORS = setOf( + "~" +) + /** All operators with special parsing rules. */ @JvmField internal val SPECIAL_OPERATORS = SPECIAL_INFIX_OPERATORS + setOf( "@" ) @JvmField internal val ALL_SINGLE_LEXEME_OPERATORS = - SINGLE_LEXEME_BINARY_OPERATORS + UNARY_OPERATORS + SPECIAL_OPERATORS + SINGLE_LEXEME_BINARY_OPERATORS + UNARY_OPERATORS + SPECIAL_OPERATORS + MATCH_OPERATORS @JvmField internal val ALL_OPERATORS = - BINARY_OPERATORS + UNARY_OPERATORS + SPECIAL_OPERATORS + BINARY_OPERATORS + UNARY_OPERATORS + SPECIAL_OPERATORS + MATCH_OPERATORS /** * Operator precedence groups @@ -585,7 +590,7 @@ internal const val DIGIT_CHARS = "0" + NON_ZERO_DIGIT_CHARS @JvmField internal val E_NOTATION_CHARS = allCase("E") -internal const val NON_OVERLOADED_OPERATOR_CHARS = "^%=@+" +internal const val NON_OVERLOADED_OPERATOR_CHARS = "^%=@+~" internal const val OPERATOR_CHARS = NON_OVERLOADED_OPERATOR_CHARS + "-*/<>|!" 
@JvmField internal val ALPHA_CHARS = allCase("ABCDEFGHIJKLMNOPQRSTUVWXYZ") diff --git a/lang/src/org/partiql/lang/syntax/SqlParser.kt b/lang/src/org/partiql/lang/syntax/SqlParser.kt index a8ca86ba08..ad694cdfbe 100644 --- a/lang/src/org/partiql/lang/syntax/SqlParser.kt +++ b/lang/src/org/partiql/lang/syntax/SqlParser.kt @@ -175,6 +175,13 @@ class SqlParser( FROM, FROM_CLAUSE, FROM_SOURCE_JOIN, + MATCH, + MATCH_EXPR, + MATCH_EXPR_NODE, + MATCH_EXPR_EDGE, + MATCH_EXPR_EDGE_DIRECTION, + MATCH_EXPR_NAME, + MATCH_EXPR_LABEL, CHECK, ON_CONFLICT, CONFLICT_ACTION, @@ -912,11 +919,134 @@ class SqlParser( if (isCrossJoin) metas + metaContainerOf(IsImplictJoinMeta.instance) else metas ) } + ParseType.MATCH -> toGraphMatch() else -> unwrapAliasesAndUnpivot() } } } + private fun ParseNode.toGraphMatch(): PartiqlAst.FromSource { + val metas = getMetas() + + return PartiqlAst.build { + val expr = children[0].toAstExpr() + val patterns = children.tail.map { + if (it.type != ParseType.MATCH_EXPR) error("Invalid parse tree: expecting match expression in MATCH") + it.toGraphMatchPattern() + } + + val matchExpr = PartiqlAst.GraphMatchExpr(patterns, metas = metas) + PartiqlAst.FromSource.GraphMatch(expr, matchExpr, metas) + } + } + + private fun ParseNode.toGraphMatchPattern(): PartiqlAst.GraphMatchPattern { + val metas = getMetas() + + return PartiqlAst.build { + val parts = children.map { + when (it.type) { + ParseType.MATCH_EXPR_NODE -> { + PartiqlAst.GraphMatchPatternPart.Node(it.toGraphMatchNode(), it.getMetas()) + } + ParseType.MATCH_EXPR_EDGE -> { + PartiqlAst.GraphMatchPatternPart.Edge(it.toGraphMatchEdge(), it.getMetas()) + } + else -> { + TODO("Handle pattern part other than node&edge") + } + } + } + + // TODO quantifier + PartiqlAst.GraphMatchPattern(quantifier = null, parts = parts, metas = metas) + } + } + + private fun ParseNode.toGraphMatchNode(): PartiqlAst.GraphMatchNode { + val metas = getMetas() + + var name: SymbolPrimitive? 
= null + var label = mutableListOf() + var predicate = null + + for (child in children) { + when (child.type) { + ParseType.MATCH_EXPR_NAME -> { + if (name != null) error("Invalid parse tree: name encountered more than once in MATCH") + name = SymbolPrimitive(child.children[0].token!!.text!!, child.getMetas()) + } + ParseType.MATCH_EXPR_LABEL -> { + label.add(SymbolPrimitive(child.children[0].token!!.text!!, child.getMetas())) + } + else -> { + TODO("Unhandled case for graph match node") + } + } + } + + return PartiqlAst.build { + PartiqlAst.GraphMatchNode( + variable = name, + label = label, + predicate = predicate, + metas = metas + ) + } + } + + private fun ParseNode.toGraphMatchEdge(): PartiqlAst.GraphMatchEdge { + val metas = getMetas() + + var direction: PartiqlAst.GraphMatchDirection? = null + var name: SymbolPrimitive? = null + var label = mutableListOf() + var predicate = null + + for (child in children) { + when (child.type) { + ParseType.MATCH_EXPR_NAME -> { + if (name != null) error("Invalid parse tree: name encountered more than once in MATCH") + name = SymbolPrimitive(child.children[0].token!!.text!!, child.getMetas()) + } + ParseType.MATCH_EXPR_LABEL -> { + label.add(SymbolPrimitive(child.children[0].token!!.text!!, child.getMetas())) + } + ParseType.MATCH_EXPR_EDGE_DIRECTION -> { + direction = when (child.token!!.text!!) 
{ + "<-" -> PartiqlAst.GraphMatchDirection.EdgeLeft() + "~" -> PartiqlAst.GraphMatchDirection.EdgeUndirected() + "->" -> PartiqlAst.GraphMatchDirection.EdgeRight() + "<~" -> PartiqlAst.GraphMatchDirection.EdgeLeftOrUndirected() + "~>" -> PartiqlAst.GraphMatchDirection.EdgeUndirectedOrRight() + "<->" -> PartiqlAst.GraphMatchDirection.EdgeLeftOrRight() + "-" -> PartiqlAst.GraphMatchDirection.EdgeLeftOrUndirectedOrRight() + else -> error("Invalid parse tree: unknown edge direction ${child.token.text!!}") + } + } + else -> { + TODO("Unhandled case for graph match edge") + } + } + } + + if (direction == null) { + error("Invalid parse tree: null edge direction") + } + + // TODO quantifier + return PartiqlAst.build { + PartiqlAst.GraphMatchEdge( + direction = direction, + quantifier = null, + variable = name, + label = label, + predicate = predicate, + metas = metas + ) + } + } + private fun ParseNode.unwrapAliasesAndUnpivot(): PartiqlAst.FromSource { val (aliases, unwrappedParseNode) = unwrapAliases() @@ -2928,6 +3058,10 @@ class SqlParser( } rem = child.remaining + child = rem.parseOptionalMatchClause(child).also { + rem = it.remaining + } + child = rem.parseOptionalAsAlias(child).also { rem = it.remaining } @@ -3153,6 +3287,283 @@ class SqlParser( ErrorCode.PARSE_UNEXPECTED_TERM, PropertyValueMap() ) + private fun List.parseOptionalMatchClause(child: ParseNode): ParseNode { + var rem = this + return when (rem.head?.keywordText) { + "match" -> { + rem.parseMatch(child) + } + else -> { + child + } + } + } + + private fun List.parseMatch(expr: ParseNode): ParseNode { + val matches = ArrayList() + var rem = this.tail + + fun consume(type: TokenType): Boolean { + if (rem.head?.type == type) { + rem = rem.tail + return true + } + return false + } + + do { + val pattern = rem.parseMatchPattern() + matches.add(pattern) + rem = pattern.remaining + } while (consume(TokenType.COMMA)) + + return ParseNode(ParseType.MATCH, this.head, listOf(expr) + matches, rem) + } + + private fun 
List.parseMatchPattern(): ParseNode { + + val abbreviationMap = Array(2) { Array(2) { Array(2) { "" } } }.also { + it[0][0][1] = "~" + it[0][1][0] = "->" + it[0][1][1] = "~>" + it[1][0][0] = "<-" + it[1][0][1] = "<~" + it[1][1][0] = "<->" + it[1][1][1] = "-" + } + + data class EdgeType(val left: Boolean, val right: Boolean, val undirected: Boolean) { + fun union(other: EdgeType): EdgeType { + val left = this.left.or(other.left) + val right = this.right.or(other.right) + val undirected = this.undirected.or(other.undirected) + val union = EdgeType(left = left, right = right, undirected = undirected) + return union + } + } + + fun EdgeType.abbreviation(): String { + return abbreviationMap[if (left) 1 else 0][if (right) 1 else 0][if (undirected) 1 else 0] + } + + val abbreviations = HashMap().also { + for (lIdx in 0..1) { + for (rIdx in 0..1) { + for (unIdx in 0..1) { + val abbreviation = abbreviationMap[lIdx][rIdx][unIdx] + if (abbreviation.isNotEmpty()) { + it[abbreviation] = EdgeType(left = (lIdx > 0), right = (rIdx > 0), undirected = (unIdx > 0)) + } + } + } + } + } + + var rem = this + + fun parseName(): ParseNode? { + if (rem.head?.type?.isIdentifier() == true) { + val name = rem.atomFromHead() + return ParseNode(ParseType.MATCH_EXPR_NAME, null, listOf(name), name.remaining) + } else { + return null + } + } + + fun parseLabel(): ParseNode? { + return when (rem.head?.type) { + TokenType.COLON -> { + rem = rem.tail + if (rem.head?.type?.isIdentifier() == true) { + val name = rem.atomFromHead() + return ParseNode(ParseType.MATCH_EXPR_LABEL, null, listOf(name), name.remaining) + } else { + rem.head.err( + "Expected identifier for", + ErrorCode.PARSE_EXPECTED_IDENT_FOR_MATCH + ) + } + } + else -> null + } + } + + fun consume(type: TokenType, keywordText: String? 
= null): Boolean { + if (rem.head?.type == type) { + if (keywordText == null || rem.head?.keywordText == keywordText) { + rem = rem.tail + return true + } + } + return false + } + + fun expect1(type: TokenType, errorCode: ErrorCode, errorMsg: String) { + if (!consume(type)) { + rem.head.err( + "Expected ${type.name} for $errorMsg", errorCode + ) + } + } + + fun parseNode(): ParseNode { + expect1(TokenType.LEFT_PAREN, ErrorCode.PARSE_EXPECTED_LEFT_PAREN_FOR_MATCH_NODE, "match node") + val name = parseName() + rem = name?.remaining ?: rem + val label = parseLabel() + rem = label?.remaining ?: rem + expect1(TokenType.RIGHT_PAREN, ErrorCode.PARSE_EXPECTED_RIGHT_PAREN_FOR_MATCH_NODE, "match node") + + return ParseNode(ParseType.MATCH_EXPR_NODE, null, listOfNotNull(name, label), rem) + } + + fun errorEdgeParse(): Nothing { + rem.head.err("Expected edge pattern for match", ErrorCode.PARSE_EXPECTED_EDGE_PATTERN_MATCH_EDGE) + } + + fun parseLeftEdgePattern(): EdgeType { + val direction = if (rem.head?.type == TokenType.OPERATOR) { + when (rem.head!!.keywordText) { + "<" -> { + rem = rem.tail + if (rem.head?.type == TokenType.OPERATOR) { + when (rem.head!!.keywordText) { + "-" -> EdgeType(left = true, right = false, undirected = false) + "~" -> EdgeType(left = true, right = false, undirected = true) + else -> errorEdgeParse() + } + } else { + errorEdgeParse() + } + } + "-" -> EdgeType(left = false, right = true, undirected = false) + "~" -> EdgeType(left = false, right = false, undirected = true) + else -> errorEdgeParse() + } + } else { + errorEdgeParse() + } + rem = rem.tail + + return direction + } + + fun parseRightEdgePattern(): EdgeType { + val direction = if (rem.head?.type == TokenType.OPERATOR) { + when (rem.head!!.keywordText) { + "-" -> { + if (rem.tail.head?.type == TokenType.OPERATOR && rem.tail.head?.keywordText == ">") { + rem = rem.tail + EdgeType(left = false, right = true, undirected = false) + } else { + EdgeType(left = true, right = false, undirected = 
false) + } + } + "~" -> { + if (rem.tail.head?.type == TokenType.OPERATOR && rem.tail.head?.keywordText == ">") { + rem = rem.tail + EdgeType(left = false, right = true, undirected = true) + } else { + EdgeType(left = false, right = false, undirected = true) + } + } + else -> errorEdgeParse() + } + } else { + errorEdgeParse() + } + rem = rem.tail + + return direction + } + + fun parseEdgeWithSpec(): ParseNode { + val dir1 = parseLeftEdgePattern() + expect1(TokenType.LEFT_BRACKET, ErrorCode.PARSE_EXPECTED_LEFT_BRACKET_FOR_MATCH_EDGE, "match edge") + val name = parseName() + rem = name?.remaining ?: rem + val label = parseLabel() + rem = label?.remaining ?: rem + expect1(TokenType.RIGHT_BRACKET, ErrorCode.PARSE_EXPECTED_RIGHT_BRACKET_FOR_MATCH_EDGE, "match edge") + val dir2 = parseRightEdgePattern() + + val dir = if (dir1 == EdgeType(left = false, right = true, undirected = false) && + dir2 == EdgeType(left = true, right = false, undirected = false) + ) { + EdgeType(left = true, right = true, undirected = true) + } else { + dir1.union(dir2) + } + + val directionToken = + Token(TokenType.OPERATOR, ion.newSymbol(dir.abbreviation()), SourceSpan(0, 0, 0)) + val direction = ParseNode(ParseType.MATCH_EXPR_EDGE_DIRECTION, directionToken, emptyList(), rem) + + return ParseNode(ParseType.MATCH_EXPR_EDGE, null, listOfNotNull(direction, name, label), rem) + } + + fun parseEdgeAbbreviated(): ParseNode { + var candidates: Map = abbreviations + do { + if (rem.head?.type == TokenType.OPERATOR) { + val char = rem.head!!.keywordText!! 
+ rem = rem.tail + candidates = candidates.filterKeys { it.startsWith(char) }.mapKeys { it.key.removePrefix(char) } + if (candidates.size == 1) { + val edge = candidates.values.first() + val directionToken = + Token(TokenType.OPERATOR, ion.newSymbol(edge.abbreviation()), SourceSpan(0, 0, 0)) + val direction = ParseNode(ParseType.MATCH_EXPR_EDGE_DIRECTION, directionToken, emptyList(), rem) + return ParseNode(ParseType.MATCH_EXPR_EDGE, null, listOf(direction), rem) + } + } else if (candidates.contains("")) { + val edge = candidates[""]!! + val directionToken = + Token(TokenType.OPERATOR, ion.newSymbol(edge.abbreviation()), SourceSpan(0, 0, 0)) + val direction = ParseNode(ParseType.MATCH_EXPR_EDGE_DIRECTION, directionToken, emptyList(), rem) + return ParseNode(ParseType.MATCH_EXPR_EDGE, null, listOf(direction), rem) + } else { + errorEdgeParse() + } + } while (candidates.isNotEmpty()) + errorEdgeParse() + } + + fun parseEdge(): ParseNode { + var preRem = rem + return try { + parseEdgeWithSpec() + } catch (e: ParserException) { + rem = preRem + parseEdgeAbbreviated() + } + } + + val patterns = ArrayList() + val nodeLeft = try { + parseNode() + } catch (e: ParserException) { + null + } + patterns.add(nodeLeft) + do { + val edge = try { + parseEdge() + } catch (e: ParserException) { + null + } + patterns.add(edge) + val nodeRight = try { + parseNode() + } catch (e: ParserException) { + null + } + patterns.add(nodeRight) + } while (edge != null && nodeRight != null) + + return ParseNode(ParseType.MATCH_EXPR, null, patterns.filterNotNull(), rem) + } + /** * Validates tree to make sure that the top level tokens are not found below the top level. * Top level tokens are the tokens or keywords which are valid to be used only at the top level in the query. 
diff --git a/lang/test/org/partiql/lang/syntax/SqlParserMatchTest.kt b/lang/test/org/partiql/lang/syntax/SqlParserMatchTest.kt new file mode 100644 index 0000000000..ac3f5ad934 --- /dev/null +++ b/lang/test/org/partiql/lang/syntax/SqlParserMatchTest.kt @@ -0,0 +1,462 @@ +package org.partiql.lang.syntax + +import com.amazon.ionelement.api.ionBool +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionString +import org.junit.Ignore +import org.junit.Test +import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.domains.id + +class SqlParserMatchTest : SqlParserTestBase() { + @Test + fun allNodesNoLabel() = assertExpressionNoRoundTrip( + "SELECT 1 FROM my_graph MATCH ()" + ) { + select( + project = projectList(projectExpr(lit(ionInt(1)))), + from = graphMatch( + expr = id("my_graph"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf(PartiqlAst.GraphMatchPatternPart.Node((graphMatchNode()))) + ) + ) + ) + ), + where = null + ) + } + + @Test + fun allNodesNoLabelFilter() = assertExpressionNoRoundTrip( + "SELECT 1 FROM my_graph MATCH () WHERE contains_value('1')", + ) { + select( + project = projectList(projectExpr(lit(ionInt(1)))), + from = graphMatch( + expr = id("my_graph"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf(PartiqlAst.GraphMatchPatternPart.Node((graphMatchNode()))) + ) + ) + ) + ), + where = call(funcName = "contains_value", args = listOf(lit(ionString("1")))) + ) + } + + @Test + fun allNodes() = assertExpressionNoRoundTrip( + "SELECT x.info AS info FROM my_graph MATCH (x) WHERE x.name LIKE 'foo'", + ) { + select( + project = projectList( + projectExpr( + expr = path(id("x"), pathExpr(lit(ionString("info")), caseInsensitive())), + asAlias = "info" + ) + ), + from = graphMatch( + expr = id("my_graph"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = 
listOf(PartiqlAst.GraphMatchPatternPart.Node((graphMatchNode(variable = "x")))) + ) + ) + ) + ), + where = like( + value = path(id("x"), pathExpr(lit(ionString("name")), caseInsensitive())), + pattern = lit(ionString("foo")) + ) + ) + } + + @Test + fun labelledNodes() = assertExpressionNoRoundTrip( + "SELECT x AS target FROM my_graph MATCH (x:Label) WHERE x.has_data = true", + ) { + select( + project = projectList(projectExpr(expr = id("x"), asAlias = "target")), + from = graphMatch( + expr = id("my_graph"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + PartiqlAst.GraphMatchPatternPart.Node( + graphMatchNode( + variable = "x", + label = listOf("Label") + ) + ) + ) + ) + ) + ) + ), + where = eq( + listOf( + path(id("x"), pathExpr(lit(ionString("has_data")), caseInsensitive())), + lit(ionBool(true)) + ) + ) + ) + } + + @Test + fun allEdges() = assertExpressionNoRoundTrip( + "SELECT 1 FROM g MATCH -[]-> ", + ) { + select( + project = projectList(projectExpr(lit(ionInt(1)))), + from = graphMatch( + expr = id("g"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf(PartiqlAst.GraphMatchPatternPart.Edge((graphMatchEdge(direction = edgeRight())))) + ) + ) + ) + ), + where = null + ) + } + + val simpleGraphAST = { direction: PartiqlAst.GraphMatchDirection, variable: String?, label: List? 
-> + PartiqlAst.build { + select( + project = projectList(projectExpr(id("a")), projectExpr(id("b"))), + from = graphMatch( + expr = id("g"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + PartiqlAst.GraphMatchPatternPart.Node( + graphMatchNode( + variable = "a", + label = listOf("A") + ) + ), + PartiqlAst.GraphMatchPatternPart.Edge( + graphMatchEdge( + direction = direction, + variable = variable, + label = label ?: emptyList() + ) + ), + PartiqlAst.GraphMatchPatternPart.Node( + graphMatchNode( + variable = "b", + label = listOf("B") + ) + ), + ) + ) + ) + ) + ), + where = null + ) + } + } + + @Test + fun rightDirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) -[e:E]-> (b:B)", + ) { + simpleGraphAST(edgeRight(), "e", listOf("E")) + } + + @Test + fun rightDirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) -> (b:B)", + ) { + simpleGraphAST(edgeRight(), null, null) + } + + @Test + fun leftDirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <-[e:E]- (b:B)", + ) { + simpleGraphAST(edgeLeft(), "e", listOf("E")) + } + + @Test + fun leftDirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <- (b:B)", + ) { + simpleGraphAST(edgeLeft(), null, null) + } + + @Test + fun undirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) ~[e:E]~ (b:B)", + ) { + simpleGraphAST(edgeUndirected(), "e", listOf("E")) + } + + @Test + fun undirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) ~ (b:B)", + ) { + simpleGraphAST(edgeUndirected(), null, null) + } + + @Test + fun rightOrUnDirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) ~[e:E]~> (b:B)", + ) { + simpleGraphAST(edgeUndirectedOrRight(), "e", listOf("E")) + } + + @Test + fun rightOrUnDirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) ~> (b:B)", + ) { + 
simpleGraphAST(edgeUndirectedOrRight(), null, null) + } + + @Test + fun leftOrUnDirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <~[e:E]~ (b:B)", + ) { + simpleGraphAST(edgeLeftOrUndirected(), "e", listOf("E")) + } + + @Test + fun leftOrUnDirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <~ (b:B)", + ) { + simpleGraphAST(edgeLeftOrUndirected(), null, null) + } + + @Test + fun leftOrRight() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <-[e:E]-> (b:B)", + ) { + simpleGraphAST(edgeLeftOrRight(), "e", listOf("E")) + } + + @Test + fun leftOrRightAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <-> (b:B)", + ) { + simpleGraphAST(edgeLeftOrRight(), null, null) + } + + @Test + fun leftOrRightOrUndirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) -[e:E]- (b:B)", + ) { + simpleGraphAST(edgeLeftOrUndirectedOrRight(), "e", listOf("E")) + } + + @Test + fun leftOrRightOrUndirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) - (b:B)", + ) { + simpleGraphAST(edgeLeftOrUndirectedOrRight(), null, null) + } + + @Test + fun singleEdgeMatch() = assertExpressionNoRoundTrip( + "SELECT the_a.name AS src, the_b.name AS dest FROM my_graph MATCH (the_a:a) -[the_y:y]-> (the_b:b) WHERE the_y.score > 10", + ) { + select( + project = projectList( + projectExpr( + expr = path(id("the_a"), pathExpr(lit(ionString("name")), caseInsensitive())), + asAlias = "src" + ), + projectExpr( + expr = path(id("the_b"), pathExpr(lit(ionString("name")), caseInsensitive())), + asAlias = "dest" + ) + ), + from = graphMatch( + expr = id("my_graph"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + PartiqlAst.GraphMatchPatternPart.Node( + graphMatchNode( + variable = "the_a", + label = listOf("a") + ) + ), + PartiqlAst.GraphMatchPatternPart.Edge( + graphMatchEdge( + direction = edgeRight(), + variable = 
"the_y", + label = listOf("y") + ) + ), + PartiqlAst.GraphMatchPatternPart.Node( + graphMatchNode( + variable = "the_b", + label = listOf("b") + ) + ), + ) + ) + ) + ) + ), + where = gt( + listOf( + path(id("the_y"), pathExpr(lit(ionString("score")), caseInsensitive())), + lit(ionInt(10)) + ) + ) + ) + } + + @Test + fun twoHopTriples() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a) -[:has]-> (x), (x)-[:contains]->(b)", + ) { + select( + project = projectList( + projectExpr(expr = id("a")), + projectExpr(expr = id("b")) + ), + from = graphMatch( + expr = id("g"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + PartiqlAst.GraphMatchPatternPart.Node(graphMatchNode(variable = "a")), + PartiqlAst.GraphMatchPatternPart.Edge( + graphMatchEdge( + direction = edgeRight(), + label = listOf("has") + ) + ), + PartiqlAst.GraphMatchPatternPart.Node(graphMatchNode(variable = "x")), + ) + ), + graphMatchPattern( + parts = listOf( + PartiqlAst.GraphMatchPatternPart.Node(graphMatchNode(variable = "x")), + PartiqlAst.GraphMatchPatternPart.Edge( + graphMatchEdge( + direction = edgeRight(), + label = listOf("contains") + ) + ), + PartiqlAst.GraphMatchPatternPart.Node(graphMatchNode(variable = "b")), + ) + ) + ) + ) + ) + ) + } + + @Test + fun twoHopPattern() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a)-[:has]->()-[:contains]->(b)", + ) { + select( + project = projectList( + projectExpr(expr = id("a")), + projectExpr(expr = id("b")) + ), + from = graphMatch( + expr = id("g"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + PartiqlAst.GraphMatchPatternPart.Node(graphMatchNode(variable = "a")), + PartiqlAst.GraphMatchPatternPart.Edge( + graphMatchEdge( + direction = edgeRight(), + label = listOf("has") + ) + ), + PartiqlAst.GraphMatchPatternPart.Node(graphMatchNode()), + PartiqlAst.GraphMatchPatternPart.Edge( + graphMatchEdge( + direction = edgeRight(), + label = 
listOf("contains") + ) + ), + PartiqlAst.GraphMatchPatternPart.Node(graphMatchNode(variable = "b")), + ) + ) + ) + ) + ) + ) + } + + // TODO prefilters + @Test + @Ignore + fun prefilters() = assertExpressionNoRoundTrip( + "SELECT u as banCandidate FROM g MATCH (p:Post Where p.isFlagged = true) ~[ep:createdPost]~ (u:User WHERE u.isBanned = false AND u.karma < 20) -[ec:createdComment]->(c:Comment WHERE c.isFlagged = true)", + ) { + TODO() + } + + // TODO label combinators + @Test + @Ignore + fun labelDisjunction() = assertExpressionNoRoundTrip( + "SELECT x FROM g MATCH (x:Label|OtherLabel)", + ) { + TODO() + } + + @Test + @Ignore + fun labelConjunction() = assertExpressionNoRoundTrip( + "SELECT x FROM g MATCH (x:Label&OtherLabel)", + ) { + TODO() + } + + @Test + @Ignore + fun labelNegation() = assertExpressionNoRoundTrip( + "SELECT x FROM g MATCH (x:!Label)", + ) { + TODO() + } + + @Test + @Ignore + fun labelWildcard() = assertExpressionNoRoundTrip( + "SELECT x FROM g MATCH (x:%)", + ) { + TODO() + } + + @Test + @Ignore + fun labelCombo() = assertExpressionNoRoundTrip( + "SELECT x FROM g MATCH (x: L1|L2&L3|!L4|(L5&%)", + ) { + TODO() + } + + // TODO path variable (e.g., `MATCH p = (x) -> (y)`) + // TODO quantifiers (e.g., `MATCH (a:Node)-[:Edge]->{2,5}(b:Node)`, `*`, `+`) + // TODO group variables (e.g., `MATCH ... 
WHERE SUM()...`) + // TODO union & multiset (e.g., `MATCH (a:Label) | (a:Label2)`, `MATCH (a:Label) |+| (a:Label2)`) + // TODO conditional variables + // TODO graphical predicates (i.e., `IS DIRECTED`, `IS SOURCE OF`, `IS DESTINATION OF`, `SAME`, `ALL DIFFERENT`) + // TODO restrictors & selectors (i.e., `TRAIL`|`ACYCLIC`|`SIMPLE` & ANY SHORTEST, ALL SHORTEST, ANY, ANY k, SHORTEST k, SHORTEST k GROUP) + // TODO selector filters +} diff --git a/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt b/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt index 2c90d3ec2c..4315347c65 100644 --- a/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt +++ b/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt @@ -66,7 +66,8 @@ abstract class SqlParserTestBase : TestBase() { */ protected fun assertExpression( source: String, - expectedPigAst: String + expectedPigAst: String, + roundTrip: Boolean = true, ) { val actualStatement = parse(source) val expectedIonSexp = loadIonSexp(expectedPigAst) @@ -80,7 +81,25 @@ abstract class SqlParserTestBase : TestBase() { pigDomainAssert(actualStatement, expectedElement) // Check equals for actual value after round trip transformation: astStatement -> ExprNode -> astStatement - assertRoundTripPigAstToExprNode(actualStatement) + if (roundTrip) { + assertRoundTripPigAstToExprNode(actualStatement) + } + } + + /** + * This method is used by test cases for parsing a string. + * The tests are performed against the PIG AST only. + * The expected PIG AST is produced by a PIG builder. + * No ExprNode <-> PIG AST round trip is performed. + */ + protected fun assertExpressionNoRoundTrip( + source: String, + expectedPigBuilder: PartiqlAst.Builder.() -> PartiqlAst.PartiqlAstNode + ) { + val expectedPigAst = PartiqlAst.build { expectedPigBuilder() }.toIonElement().toString() + + // Refer to comments inside the main body of the following function to see what checks are performed. 
+ assertExpression(source, expectedPigAst, false) } /** @@ -95,7 +114,7 @@ abstract class SqlParserTestBase : TestBase() { val expectedPigAst = PartiqlAst.build { expectedPigBuilder() }.toIonElement().toString() // Refer to comments inside the main body of the following function to see what checks are performed. - assertExpression(source, expectedPigAst) + assertExpression(source, expectedPigAst, true) } /**