diff --git a/lang/resources/org/partiql/type-domains/partiql.ion b/lang/resources/org/partiql/type-domains/partiql.ion index 8224ca3948..32658a807e 100644 --- a/lang/resources/org/partiql/type-domains/partiql.ion +++ b/lang/resources/org/partiql/type-domains/partiql.ion @@ -210,12 +210,68 @@ may then be further optimized by selecting better implementations of each operat // UNPIVOT [AS ] [AT ] [BY ] (unpivot expr::expr as_alias::(? symbol) at_alias::(? symbol) by_alias::(? symbol)) - // JOIN [INNER | LEFT | RIGHT | FULL] ON - (join type::join_type left::from_source right::from_source predicate::(? expr))) + // JOIN [INNER | LEFT | RIGHT | FULL] ON + (join type::join_type left::from_source right::from_source predicate::(? expr)) + + // MATCH + (graph_match expr::expr graph_expr::graph_match_expr)) // Indicates the logical type of join. (sum join_type (inner) (left) (right) (full)) + // The direction of an edge + // | Orientation | Edge pattern | Abbreviation | + // |---------------------------+--------------+--------------| + // | Pointing left | <−[ spec ]− | <− | + // | Undirected | ~[ spec ]~ | ~ | + // | Pointing right | −[ spec ]−> | −> | + // | Left or undirected | <~[ spec ]~ | <~ | + // | Undirected or right | ~[ spec ]~> | ~> | + // | Left or right | <−[ spec ]−> | <−> | + // | Left, undirected or right | −[ spec ]− | − | + // + // Fig. 5. Table of edge patterns: + // https://arxiv.org/abs/2112.06217 + (sum graph_match_direction + (edge_left) + (edge_undirected) + (edge_right) + (edge_left_or_undirected) + (edge_undirected_or_right) + (edge_left_or_right) + (edge_left_or_undirected_or_right)) + + // A part of a graph pattern + (sum graph_match_pattern_part + // A single node in a graph pattern. + (node + predicate::(? expr) // an optional node pre-filter, e.g.: `WHERE c.name='Alarm'` in `MATCH (c WHERE c.name='Alarm')` + variable::(? symbol) // the optional element variable of the node match, e.g.: `x` in `MATCH (x)` + label::(* symbol 0)) // the optional label(s) to match for the node, e.g.: `Entity` in `MATCH (x:Entity)` + + // A single edge in a graph pattern. + (edge + direction::graph_match_direction // edge direction + quantifier::(? graph_match_quantifier) // an optional quantifier for the edge match + predicate::(? expr) // an optional edge pre-filter, e.g.: `WHERE t.capacity>100` in `MATCH −[t:hasSupply WHERE t.capacity>100]−>` + variable::(? symbol) // the optional element variable of the edge match, e.g.: `t` in `MATCH −[t]−>` + label::(* symbol 0)) // the optional label(s) to match for the edge. e.g.: `Target` in `MATCH −[t:Target]−>` + // A sub-pattern. + (pattern pattern::graph_match_pattern)) + + // A quantifier for graph edges or patterns. (e.g., the `{2,5}` in `MATCH (x)->{2,5}(y)`) + (product graph_match_quantifier lower::int upper::(? int)) + + // A single graph match pattern. + (product graph_match_pattern + quantifier::(? graph_match_quantifier) // an optional quantifier for the entire pattern match + parts::(* graph_match_pattern_part 1)) // the ordered pattern parts + + // A graph match clause as defined in GPML + // See https://arxiv.org/abs/2112.06217 + (product graph_match_expr patterns::(*graph_match_pattern 1)) + + // A generic pair of expressions. Used in the `struct`, `searched_case` and `simple_case` expr variants above. (product expr_pair first::expr second::expr) diff --git a/lang/src/org/partiql/lang/ast/StatementToExprNode.kt b/lang/src/org/partiql/lang/ast/StatementToExprNode.kt index 920b3e3b7d..e6a44e2d97 100644 --- a/lang/src/org/partiql/lang/ast/StatementToExprNode.kt +++ b/lang/src/org/partiql/lang/ast/StatementToExprNode.kt @@ -266,6 +266,7 @@ private class StatementTransformer(val ion: IonSystem) { condition = predicate?.toExprNode() ?: Literal(ion.newBool(true), metaContainerOf(StaticTypeMeta(StaticType.BOOL))), metas = metas ) + is PartiqlAst.FromSource.GraphMatch -> error("$this node has no representation in prior ASTs.") } } diff --git a/lang/src/org/partiql/lang/errors/ErrorCode.kt b/lang/src/org/partiql/lang/errors/ErrorCode.kt index 61cc8e8e59..c4518b9c35 100644 --- a/lang/src/org/partiql/lang/errors/ErrorCode.kt +++ b/lang/src/org/partiql/lang/errors/ErrorCode.kt @@ -393,6 +393,42 @@ enum class ErrorCode( "expected identifier for alias" ), + PARSE_EXPECTED_IDENT_FOR_MATCH( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected identifier for match" + ), + + PARSE_EXPECTED_LEFT_PAREN_FOR_MATCH_NODE( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected left parenthesis for match node" + ), + + PARSE_EXPECTED_RIGHT_PAREN_FOR_MATCH_NODE( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected right parenthesis for match node" + ), + + PARSE_EXPECTED_LEFT_BRACKET_FOR_MATCH_EDGE( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected left bracket for match edge" + ), + + PARSE_EXPECTED_RIGHT_BRACKET_FOR_MATCH_EDGE( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected right bracket for match edge" + ), + + PARSE_EXPECTED_EDGE_PATTERN_MATCH_EDGE( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected edge pattern for match edge" + ), + PARSE_EXPECTED_AS_FOR_LET( ErrorCategory.PARSER, LOC_TOKEN, diff --git a/lang/src/org/partiql/lang/eval/visitors/GroupByPathExpressionVisitorTransform.kt b/lang/src/org/partiql/lang/eval/visitors/GroupByPathExpressionVisitorTransform.kt index 1687bdefab..9e53169414 100644 --- a/lang/src/org/partiql/lang/eval/visitors/GroupByPathExpressionVisitorTransform.kt +++ b/lang/src/org/partiql/lang/eval/visitors/GroupByPathExpressionVisitorTransform.kt @@ -76,6 +76,9 @@ class GroupByPathExpressionVisitorTransform( is PartiqlAst.FromSource.Unpivot -> listOfNotNull(fromSource.asAlias?.text, fromSource.atAlias?.text) + + is PartiqlAst.FromSource.GraphMatch -> + TODO("Handle MATCH for GROUP BY") } } diff --git a/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt b/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt index 8b179008fe..29a072163e 100644 --- a/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt +++ b/lang/src/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt @@ -165,4 +165,8 @@ private object FromSourceToBexpr : PartiqlAst.FromSource.Converter = DateTimePart.values() "+", "-", "not" ) +/** Operators specific to the `MATCH` clause. */ +@JvmField internal val MATCH_OPERATORS = setOf( + "~" +) + /** All operators with special parsing rules. */ @JvmField internal val SPECIAL_OPERATORS = SPECIAL_INFIX_OPERATORS + setOf( "@" ) @JvmField internal val ALL_SINGLE_LEXEME_OPERATORS = - SINGLE_LEXEME_BINARY_OPERATORS + UNARY_OPERATORS + SPECIAL_OPERATORS + SINGLE_LEXEME_BINARY_OPERATORS + UNARY_OPERATORS + SPECIAL_OPERATORS + MATCH_OPERATORS @JvmField internal val ALL_OPERATORS = - BINARY_OPERATORS + UNARY_OPERATORS + SPECIAL_OPERATORS + BINARY_OPERATORS + UNARY_OPERATORS + SPECIAL_OPERATORS + MATCH_OPERATORS /** * Operator precedence groups @@ -585,7 +590,7 @@ internal const val DIGIT_CHARS = "0" + NON_ZERO_DIGIT_CHARS @JvmField internal val E_NOTATION_CHARS = allCase("E") -internal const val NON_OVERLOADED_OPERATOR_CHARS = "^%=@+" +internal const val NON_OVERLOADED_OPERATOR_CHARS = "^%=@+~" internal const val OPERATOR_CHARS = NON_OVERLOADED_OPERATOR_CHARS + "-*/<>|!" @JvmField internal val ALPHA_CHARS = allCase("ABCDEFGHIJKLMNOPQRSTUVWXYZ") diff --git a/lang/src/org/partiql/lang/syntax/SqlParser.kt b/lang/src/org/partiql/lang/syntax/SqlParser.kt index a8ca86ba08..4ed58a3c4b 100644 --- a/lang/src/org/partiql/lang/syntax/SqlParser.kt +++ b/lang/src/org/partiql/lang/syntax/SqlParser.kt @@ -175,6 +175,13 @@ class SqlParser( FROM, FROM_CLAUSE, FROM_SOURCE_JOIN, + MATCH, + MATCH_EXPR, + MATCH_EXPR_NODE, + MATCH_EXPR_EDGE, + MATCH_EXPR_EDGE_DIRECTION, + MATCH_EXPR_NAME, + MATCH_EXPR_LABEL, CHECK, ON_CONFLICT, CONFLICT_ACTION, @@ -912,11 +919,130 @@ class SqlParser( if (isCrossJoin) metas + metaContainerOf(IsImplictJoinMeta.instance) else metas ) } + ParseType.MATCH -> toGraphMatch() else -> unwrapAliasesAndUnpivot() } } } + private fun ParseNode.toGraphMatch(): PartiqlAst.FromSource { + val metas = getMetas() + + return PartiqlAst.build { + val expr = children[0].toAstExpr() + val patterns = children.tail.map { + if (it.type != ParseType.MATCH_EXPR) error("Invalid parse tree: expecting match expression in MATCH") + it.toGraphMatchPattern() + } + + val matchExpr = PartiqlAst.GraphMatchExpr(patterns, metas = metas) + PartiqlAst.FromSource.GraphMatch(expr, matchExpr, metas) + } + } + + private fun ParseNode.toGraphMatchPattern(): PartiqlAst.GraphMatchPattern { + val metas = getMetas() + + return PartiqlAst.build { + val parts = children.map { + when (it.type) { + ParseType.MATCH_EXPR_NODE -> it.toGraphMatchNode() + ParseType.MATCH_EXPR_EDGE -> it.toGraphMatchEdge() + else -> { + TODO("Handle pattern part other than node&edge") + } + } + } + + // TODO quantifier + PartiqlAst.GraphMatchPattern(quantifier = null, parts = parts, metas = metas) + } + } + + private fun ParseNode.toGraphMatchNode(): PartiqlAst.GraphMatchPatternPart.Node { + val metas = getMetas() + + var name: SymbolPrimitive? = null + val label = mutableListOf() + val predicate = null + + for (child in children) { + when (child.type) { + ParseType.MATCH_EXPR_NAME -> { + if (name != null) error("Invalid parse tree: name encountered more than once in MATCH") + name = SymbolPrimitive(child.children[0].token!!.text!!, child.getMetas()) + } + ParseType.MATCH_EXPR_LABEL -> { + label.add(SymbolPrimitive(child.children[0].token!!.text!!, child.getMetas())) + } + else -> { + TODO("Unhandled case for graph match node") + } + } + } + + return PartiqlAst.build { + PartiqlAst.GraphMatchPatternPart.Node( + variable = name, + label = label, + predicate = predicate, + metas = metas + ) + } + } + + private fun ParseNode.toGraphMatchEdge(): PartiqlAst.GraphMatchPatternPart.Edge { + val metas = getMetas() + + var direction: PartiqlAst.GraphMatchDirection? = null + var name: SymbolPrimitive? = null + val label = mutableListOf() + val predicate = null + + for (child in children) { + when (child.type) { + ParseType.MATCH_EXPR_NAME -> { + if (name != null) error("Invalid parse tree: name encountered more than once in MATCH") + name = SymbolPrimitive(child.children[0].token!!.text!!, child.getMetas()) + } + ParseType.MATCH_EXPR_LABEL -> { + label.add(SymbolPrimitive(child.children[0].token!!.text!!, child.getMetas())) + } + ParseType.MATCH_EXPR_EDGE_DIRECTION -> { + direction = when (child.token!!.text!!) { + "<-" -> PartiqlAst.GraphMatchDirection.EdgeLeft() + "~" -> PartiqlAst.GraphMatchDirection.EdgeUndirected() + "->" -> PartiqlAst.GraphMatchDirection.EdgeRight() + "<~" -> PartiqlAst.GraphMatchDirection.EdgeLeftOrUndirected() + "~>" -> PartiqlAst.GraphMatchDirection.EdgeUndirectedOrRight() + "<->" -> PartiqlAst.GraphMatchDirection.EdgeLeftOrRight() + "-" -> PartiqlAst.GraphMatchDirection.EdgeLeftOrUndirectedOrRight() + else -> error("Invalid parse tree: unknown edge direction ${child.token.text!!}") + } + } + else -> { + TODO("Unhandled case for graph match edge") + } + } + } + + if (direction == null) { + error("Invalid parse tree: null edge direction") + } + + // TODO quantifier + return PartiqlAst.build { + PartiqlAst.GraphMatchPatternPart.Edge( + direction = direction, + quantifier = null, + variable = name, + label = label, + predicate = predicate, + metas = metas + ) + } + } + private fun ParseNode.unwrapAliasesAndUnpivot(): PartiqlAst.FromSource { val (aliases, unwrappedParseNode) = unwrapAliases() @@ -2928,6 +3054,10 @@ class SqlParser( } rem = child.remaining + child = rem.parseOptionalMatchClause(child).also { + rem = it.remaining + } + child = rem.parseOptionalAsAlias(child).also { rem = it.remaining } @@ -3153,6 +3283,324 @@ class SqlParser( ErrorCode.PARSE_UNEXPECTED_TERM, PropertyValueMap() ) + private fun List.parseOptionalMatchClause(child: ParseNode): ParseNode { + val rem = this + return when (rem.head?.keywordText) { + "match" -> { + rem.parseMatch(child) + } + else -> { + child + } + } + } + + private fun List.parseMatch(expr: ParseNode): ParseNode { + val matches = ArrayList() + var rem = this.tail + + fun consume(type: TokenType): Boolean { + if (rem.head?.type == type) { + rem = rem.tail + return true + } + return false + } + + do { + val pattern = rem.parseMatchPattern() + matches.add(pattern) + rem = pattern.remaining + } while (consume(TokenType.COMMA)) + + return ParseNode(ParseType.MATCH, this.head, listOf(expr) + matches, rem) + } + + private fun List.parseMatchPattern(): ParseNode { + // left/right/undirected edge directions essentially form a 3-bit flag + // represent here the 7 (non-zero) 3-bit combinations to their abbreviation for lookup + val abbreviationMap = Array(2) { Array(2) { Array(2) { "" } } }.also { + it[0][0][1] = "~" + it[0][1][0] = "->" + it[0][1][1] = "~>" + it[1][0][0] = "<-" + it[1][0][1] = "<~" + it[1][1][0] = "<->" + it[1][1][1] = "-" + } + + data class EdgeType(val left: Boolean, val right: Boolean, val undirected: Boolean) { + // Just union all the left/right/undirected flags + fun union(other: EdgeType): EdgeType { + val left = this.left.or(other.left) + val right = this.right.or(other.right) + val undirected = this.undirected.or(other.undirected) + val union = EdgeType(left = left, right = right, undirected = undirected) + return union + } + + // Combine leading and trailing edge detection. + // If leading thinks right and trailing thinks left + // then left + right + undirected + // else + // union(leading, trailing) + fun combine(other: EdgeType): EdgeType { + return if (this == EdgeType(left = false, right = true, undirected = false) && + other == EdgeType(left = true, right = false, undirected = false) + ) { + EdgeType(left = true, right = true, undirected = true) + } else { + this.union(other) + } + } + } + + fun EdgeType.abbreviation(): String { + return abbreviationMap[if (left) 1 else 0][if (right) 1 else 0][if (undirected) 1 else 0] + } + + val abbreviations = HashMap().also { + for (lIdx in 0..1) { + for (rIdx in 0..1) { + for (unIdx in 0..1) { + val abbreviation = abbreviationMap[lIdx][rIdx][unIdx] + if (abbreviation.isNotEmpty()) { + it[abbreviation] = EdgeType(left = (lIdx > 0), right = (rIdx > 0), undirected = (unIdx > 0)) + } + } + } + } + } + + var rem = this + + fun parseName(): ParseNode? { + if (rem.head?.type?.isIdentifier() == true) { + val name = rem.atomFromHead() + return ParseNode(ParseType.MATCH_EXPR_NAME, null, listOf(name), name.remaining) + } else { + return null + } + } + + fun parseLabel(): ParseNode? { + return when (rem.head?.type) { + TokenType.COLON -> { + rem = rem.tail + if (rem.head?.type?.isIdentifier() == true) { + val name = rem.atomFromHead() + return ParseNode(ParseType.MATCH_EXPR_LABEL, null, listOf(name), name.remaining) + } else { + rem.head.err( + "Expected identifier for", + ErrorCode.PARSE_EXPECTED_IDENT_FOR_MATCH + ) + } + } + else -> null + } + } + + // 'consume' a single token matching the specified type and optionally matching a specified keyword + fun consume(type: TokenType, keywordText: String? = null): Boolean { + if (rem.head?.type == type) { + if (keywordText == null || rem.head?.keywordText == keywordText) { + rem = rem.tail + return true + } + } + return false + } + + // `consume` a single token matching the specified type or throw an error if not possible + fun expect(type: TokenType, errorCode: ErrorCode, errorMsg: String) { + if (!consume(type)) { + rem.head.err( + "Expected ${type.name} for $errorMsg", errorCode + ) + } + } + + fun parseNode(): ParseNode { + expect(TokenType.LEFT_PAREN, ErrorCode.PARSE_EXPECTED_LEFT_PAREN_FOR_MATCH_NODE, "match node") + val name = parseName() + rem = name?.remaining ?: rem + val label = parseLabel() + rem = label?.remaining ?: rem + expect(TokenType.RIGHT_PAREN, ErrorCode.PARSE_EXPECTED_RIGHT_PAREN_FOR_MATCH_NODE, "match node") + + return ParseNode(ParseType.MATCH_EXPR_NODE, null, listOfNotNull(name, label), rem) + } + + fun errorEdgeParse(): Nothing { + rem.head.err("Expected edge pattern for match", ErrorCode.PARSE_EXPECTED_EDGE_PATTERN_MATCH_EDGE) + } + + fun parseLeftEdgePattern(): EdgeType { + val direction = if (rem.head?.type == TokenType.OPERATOR) { + when (rem.head!!.keywordText) { + "<" -> { + rem = rem.tail + if (rem.head?.type == TokenType.OPERATOR) { + when (rem.head!!.keywordText) { + "-" -> EdgeType(left = true, right = false, undirected = false) + "~" -> EdgeType(left = true, right = false, undirected = true) + else -> errorEdgeParse() + } + } else { + errorEdgeParse() + } + } + "-" -> EdgeType(left = false, right = true, undirected = false) + "~" -> EdgeType(left = false, right = false, undirected = true) + else -> errorEdgeParse() + } + } else { + errorEdgeParse() + } + rem = rem.tail + + return direction + } + + fun parseRightEdgePattern(): EdgeType { + val direction = if (rem.head?.type == TokenType.OPERATOR) { + when (rem.head!!.keywordText) { + "-" -> { + if (rem.tail.head?.type == TokenType.OPERATOR && rem.tail.head?.keywordText == ">") { + rem = rem.tail + EdgeType(left = false, right = true, undirected = false) + } else { + EdgeType(left = true, right = false, undirected = false) + } + } + "~" -> { + if (rem.tail.head?.type == TokenType.OPERATOR && rem.tail.head?.keywordText == ">") { + rem = rem.tail + EdgeType(left = false, right = true, undirected = true) + } else { + EdgeType(left = false, right = false, undirected = true) + } + } + else -> errorEdgeParse() + } + } else { + errorEdgeParse() + } + rem = rem.tail + + return direction + } + + // Parses an edge pattern containing a spec as defined by + // + // | Orientation | Edge pattern | Abbreviation | + // |---------------------------+--------------+--------------| + // | Pointing left | <−[ spec ]− | <− | + // | Undirected | ~[ spec ]~ | ~ | + // | Pointing right | −[ spec ]−> | −> | + // | Left or undirected | <~[ spec ]~ | <~ | + // | Undirected or right | ~[ spec ]~> | ~> | + // | Left or right | <−[ spec ]−> | <−> | + // | Left, undirected or right | −[ spec ]− | − | + // + // Fig. 5. Table of edge patterns: + // https://arxiv.org/abs/2112.06217 + fun parseEdgeWithSpec(): ParseNode { + val dir1 = parseLeftEdgePattern() + expect(TokenType.LEFT_BRACKET, ErrorCode.PARSE_EXPECTED_LEFT_BRACKET_FOR_MATCH_EDGE, "match edge") + val name = parseName() + rem = name?.remaining ?: rem + val label = parseLabel() + rem = label?.remaining ?: rem + expect(TokenType.RIGHT_BRACKET, ErrorCode.PARSE_EXPECTED_RIGHT_BRACKET_FOR_MATCH_EDGE, "match edge") + val dir2 = parseRightEdgePattern() + + val dir = dir1.combine(dir2) + + val directionToken = + Token(TokenType.OPERATOR, ion.newSymbol(dir.abbreviation()), SourceSpan(0, 0, 0)) + val direction = ParseNode(ParseType.MATCH_EXPR_EDGE_DIRECTION, directionToken, emptyList(), rem) + + return ParseNode(ParseType.MATCH_EXPR_EDGE, null, listOfNotNull(direction, name, label), rem) + } + + // Parses an abbreviated edge pattern (i.e, no label, no variable, no predicate) as defined by + // + // | Orientation | Edge pattern | Abbreviation | + // |---------------------------+--------------+--------------| + // | Pointing left | <−[ spec ]− | <− | + // | Undirected | ~[ spec ]~ | ~ | + // | Pointing right | −[ spec ]−> | −> | + // | Left or undirected | <~[ spec ]~ | <~ | + // | Undirected or right | ~[ spec ]~> | ~> | + // | Left or right | <−[ spec ]−> | <−> | + // | Left, undirected or right | −[ spec ]− | − | + // + // Fig. 5. Table of edge patterns: + // https://arxiv.org/abs/2112.06217 + fun parseEdgeAbbreviated(): ParseNode { + var candidates: Map = abbreviations + do { + if (rem.head?.type == TokenType.OPERATOR) { + val char = rem.head!!.keywordText!! + rem = rem.tail + candidates = candidates.filterKeys { it.startsWith(char) }.mapKeys { it.key.removePrefix(char) } + if (candidates.size == 1) { + val edge = candidates.values.first() + val directionToken = + Token(TokenType.OPERATOR, ion.newSymbol(edge.abbreviation()), SourceSpan(0, 0, 0)) + val direction = ParseNode(ParseType.MATCH_EXPR_EDGE_DIRECTION, directionToken, emptyList(), rem) + return ParseNode(ParseType.MATCH_EXPR_EDGE, null, listOf(direction), rem) + } + } else if (candidates.contains("")) { + val edge = candidates[""]!! + val directionToken = + Token(TokenType.OPERATOR, ion.newSymbol(edge.abbreviation()), SourceSpan(0, 0, 0)) + val direction = ParseNode(ParseType.MATCH_EXPR_EDGE_DIRECTION, directionToken, emptyList(), rem) + return ParseNode(ParseType.MATCH_EXPR_EDGE, null, listOf(direction), rem) + } else { + errorEdgeParse() + } + } while (candidates.isNotEmpty()) + errorEdgeParse() + } + + fun parseEdge(): ParseNode { + var preRem = rem + return try { + parseEdgeWithSpec() + } catch (e: ParserException) { + rem = preRem + parseEdgeAbbreviated() + } + } + + val patterns = ArrayList() + val nodeLeft = try { + parseNode() + } catch (e: ParserException) { + null + } + patterns.add(nodeLeft) + do { + val edge = try { + parseEdge() + } catch (e: ParserException) { + null + } + patterns.add(edge) + val nodeRight = try { + parseNode() + } catch (e: ParserException) { + null + } + patterns.add(nodeRight) + } while (edge != null && nodeRight != null) + + return ParseNode(ParseType.MATCH_EXPR, null, patterns.filterNotNull(), rem) + } + /** * Validates tree to make sure that the top level tokens are not found below the top level. * Top level tokens are the tokens or keywords which are valid to be used only at the top level in the query. diff --git a/lang/test/org/partiql/lang/syntax/SqlParserMatchTest.kt b/lang/test/org/partiql/lang/syntax/SqlParserMatchTest.kt new file mode 100644 index 0000000000..cb243a7d31 --- /dev/null +++ b/lang/test/org/partiql/lang/syntax/SqlParserMatchTest.kt @@ -0,0 +1,516 @@ +package org.partiql.lang.syntax + +import com.amazon.ionelement.api.ionBool +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionString +import org.junit.Ignore +import org.junit.Test +import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.domains.id + +class SqlParserMatchTest : SqlParserTestBase() { + @Test + fun allNodesNoLabel() = assertExpressionNoRoundTrip( + "SELECT 1 FROM my_graph MATCH ()" + ) { + select( + project = projectList(projectExpr(lit(ionInt(1)))), + from = graphMatch( + expr = id("my_graph"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + node( + predicate = null, + variable = null, + label = listOf() + ) + ) + ) + ) + ) + ), + where = null + ) + } + + @Test + fun allNodesNoLabelFilter() = assertExpressionNoRoundTrip( + "SELECT 1 FROM my_graph MATCH () WHERE contains_value('1')", + ) { + select( + project = projectList(projectExpr(lit(ionInt(1)))), + from = graphMatch( + expr = id("my_graph"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + node( + predicate = null, + variable = null, + label = listOf() + ) + ) + ) + ) + ) + ), + where = call(funcName = "contains_value", args = listOf(lit(ionString("1")))) + ) + } + + @Test + fun allNodes() = assertExpressionNoRoundTrip( + "SELECT x.info AS info FROM my_graph MATCH (x) WHERE x.name LIKE 'foo'", + ) { + select( + project = projectList( + projectExpr( + expr = path(id("x"), pathExpr(lit(ionString("info")), caseInsensitive())), + asAlias = "info" + ) + ), + from = graphMatch( + expr = id("my_graph"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + node( + predicate = null, + variable = "x", + label = listOf() + ) + ) + ) + ) + ) + ), + where = like( + value = path(id("x"), pathExpr(lit(ionString("name")), caseInsensitive())), + pattern = lit(ionString("foo")) + ) + ) + } + + @Test + fun labelledNodes() = assertExpressionNoRoundTrip( + "SELECT x AS target FROM my_graph MATCH (x:Label) WHERE x.has_data = true", + ) { + select( + project = projectList(projectExpr(expr = id("x"), asAlias = "target")), + from = graphMatch( + expr = id("my_graph"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + node( + predicate = null, + variable = "x", + label = listOf("Label") + ) + ) + ) + ) + ) + ), + where = eq( + listOf( + path(id("x"), pathExpr(lit(ionString("has_data")), caseInsensitive())), + lit(ionBool(true)) + ) + ) + ) + } + + @Test + fun allEdges() = assertExpressionNoRoundTrip( + "SELECT 1 FROM g MATCH -[]-> ", + ) { + select( + project = projectList(projectExpr(lit(ionInt(1)))), + from = graphMatch( + expr = id("g"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + edge( + direction = edgeRight(), + quantifier = null, + predicate = null, + variable = null, + label = listOf() + ) + ) + ) + ) + ) + ), + where = null + ) + } + + val simpleGraphAST = { direction: PartiqlAst.GraphMatchDirection, variable: String?, label: List? -> + PartiqlAst.build { + select( + project = projectList(projectExpr(id("a")), projectExpr(id("b"))), + from = graphMatch( + expr = id("g"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + quantifier = null, + parts = listOf( + node( + predicate = null, + variable = "a", + label = listOf("A") + ), + edge( + direction = direction, + quantifier = null, + predicate = null, + variable = variable, + label = label ?: emptyList() + ), + node( + predicate = null, + variable = "b", + label = listOf("B") + ), + ) + ) + ) + ) + ), + where = null + ) + } + } + + @Test + fun rightDirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) -[e:E]-> (b:B)", + ) { + simpleGraphAST(edgeRight(), "e", listOf("E")) + } + + @Test + fun rightDirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) -> (b:B)", + ) { + simpleGraphAST(edgeRight(), null, null) + } + + @Test + fun leftDirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <-[e:E]- (b:B)", + ) { + simpleGraphAST(edgeLeft(), "e", listOf("E")) + } + + @Test + fun leftDirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <- (b:B)", + ) { + simpleGraphAST(edgeLeft(), null, null) + } + + @Test + fun undirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) ~[e:E]~ (b:B)", + ) { + simpleGraphAST(edgeUndirected(), "e", listOf("E")) + } + + @Test + fun undirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) ~ (b:B)", + ) { + simpleGraphAST(edgeUndirected(), null, null) + } + + @Test + fun rightOrUnDirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) ~[e:E]~> (b:B)", + ) { + simpleGraphAST(edgeUndirectedOrRight(), "e", listOf("E")) + } + + @Test + fun rightOrUnDirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) ~> (b:B)", + ) { + simpleGraphAST(edgeUndirectedOrRight(), null, null) + } + + @Test + fun leftOrUnDirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <~[e:E]~ (b:B)", + ) { + simpleGraphAST(edgeLeftOrUndirected(), "e", listOf("E")) + } + + @Test + fun leftOrUnDirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <~ (b:B)", + ) { + simpleGraphAST(edgeLeftOrUndirected(), null, null) + } + + @Test + fun leftOrRight() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <-[e:E]-> (b:B)", + ) { + simpleGraphAST(edgeLeftOrRight(), "e", listOf("E")) + } + + @Test + fun leftOrRightAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) <-> (b:B)", + ) { + simpleGraphAST(edgeLeftOrRight(), null, null) + } + + @Test + fun leftOrRightOrUndirected() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) -[e:E]- (b:B)", + ) { + simpleGraphAST(edgeLeftOrUndirectedOrRight(), "e", listOf("E")) + } + + @Test + fun leftOrRightOrUndirectedAbbreviated() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a:A) - (b:B)", + ) { + simpleGraphAST(edgeLeftOrUndirectedOrRight(), null, null) + } + + @Test + fun singleEdgeMatch() = assertExpressionNoRoundTrip( + "SELECT the_a.name AS src, the_b.name AS dest FROM my_graph MATCH (the_a:a) -[the_y:y]-> (the_b:b) WHERE the_y.score > 10", + ) { + select( + project = projectList( + projectExpr( + expr = path(id("the_a"), pathExpr(lit(ionString("name")), caseInsensitive())), + asAlias = "src" + ), + projectExpr( + expr = path(id("the_b"), pathExpr(lit(ionString("name")), caseInsensitive())), + asAlias = "dest" + ) + ), + from = graphMatch( + expr = id("my_graph"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + node( + predicate = null, + variable = "the_a", + label = listOf("a") + ), + edge( + direction = edgeRight(), + quantifier = null, + predicate = null, + variable = "the_y", + label = listOf("y") + ), + node( + predicate = null, + variable = "the_b", + label = listOf("b") + ), + ) + ) + ) + ) + ), + where = gt( + listOf( + path(id("the_y"), pathExpr(lit(ionString("score")), caseInsensitive())), + lit(ionInt(10)) + ) + ) + ) + } + + @Test + fun twoHopTriples() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a) -[:has]-> (x), (x)-[:contains]->(b)", + ) { + select( + project = projectList( + projectExpr(expr = id("a")), + projectExpr(expr = id("b")) + ), + from = graphMatch( + expr = id("g"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + node( + predicate = null, + variable = "a", + label = listOf() + ), + edge( + direction = edgeRight(), + quantifier = null, + predicate = null, + variable = null, + label = listOf("has") + ), + node( + predicate = null, + variable = "x", + label = listOf() + ), + ) + ), + graphMatchPattern( + parts = listOf( + node( + predicate = null, + variable = "x", + label = listOf() + ), + edge( + direction = edgeRight(), + quantifier = null, + predicate = null, + variable = null, + label = listOf("contains") + ), + node( + predicate = null, + variable = "b", + label = listOf() + ), + ) + ) + ) + ) + ) + ) + } + + @Test + fun twoHopPattern() = assertExpressionNoRoundTrip( + "SELECT a,b FROM g MATCH (a)-[:has]->()-[:contains]->(b)", + ) { + select( + project = projectList( + projectExpr(expr = id("a")), + projectExpr(expr = id("b")) + ), + from = graphMatch( + expr = id("g"), + graphExpr = graphMatchExpr( + patterns = listOf( + graphMatchPattern( + parts = listOf( + node( + predicate = null, + variable = "a", + label = listOf() + ), + edge( + direction = edgeRight(), + quantifier = null, + predicate = null, + variable = null, + label = listOf("has") + ), + node( + predicate = null, + variable = null, + label = listOf() + ), + edge( + direction = edgeRight(), + quantifier = null, + predicate = null, + variable = null, + label = listOf("contains") + ), + node( + predicate = null, + variable = "b", + label = listOf() + ), + ) + ) + ) + ) + ) + ) + } + + // TODO prefilters + @Test + @Ignore + fun prefilters() = assertExpressionNoRoundTrip( + "SELECT u as banCandidate FROM g MATCH (p:Post Where p.isFlagged = true) ~[ep:createdPost]~ (u:User WHERE u.isBanned = false AND u.karma < 20) -[ec:createdComment]->(c:Comment WHERE c.isFlagged = true)", + ) { + TODO() + } + + // TODO label combinators + @Test + @Ignore + fun labelDisjunction() = assertExpressionNoRoundTrip( + "SELECT x FROM g MATCH (x:Label|OtherLabel)", + ) { + TODO() + } + + @Test + @Ignore + fun labelConjunction() = assertExpressionNoRoundTrip( + "SELECT x FROM g MATCH (x:Label&OtherLabel)", + ) { + TODO() + } + + @Test + @Ignore + fun labelNegation() = assertExpressionNoRoundTrip( + "SELECT x FROM g MATCH (x:!Label)", + ) { + TODO() + } + + @Test + @Ignore + fun labelWildcard() = assertExpressionNoRoundTrip( + "SELECT x FROM g MATCH (x:%)", + ) { + TODO() + } + + @Test + @Ignore + fun labelCombo() = assertExpressionNoRoundTrip( + "SELECT x FROM g MATCH (x: L1|L2&L3|!L4|(L5&%)", + ) { + TODO() + } + + // TODO path variable (e.g., `MATCH p = (x) -> (y)` + // TODO quantifiers (e.g., `MATCH (a:Node)−[:Edge]−>{2,5}(b:Node)`, `*`, `+`) + // TODO group variables (e.g., `MATCH ... WHERE SUM()...`) + // TODO union & multiset (e.g., `MATCH (a:Label) | (a:Label2)` , `MATCH (a:Label) |+| (a:Label2)` + // TODO conditional variables + // TODO graphical predicates (i.e., `IS DIRECTED`, `IS SOURCE OF`, `IS DESTINATION OF`, `SAME`, `ALL DIFFERENT`) + // TODO restrictors & selectors (i.e., `TRAIL`|`ACYCLIC`|`SIMPLE` & ANY SHORTEST, ALL SHORTEST, ANY, ANY k, SHORTEST k, SHORTEST k GROUP) + // TODO selector filters +} diff --git a/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt b/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt index 2c90d3ec2c..d53958728c 100644 --- a/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt +++ b/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt @@ -66,7 +66,8 @@ abstract class SqlParserTestBase : TestBase() { */ protected fun assertExpression( source: String, - expectedPigAst: String + expectedPigAst: String, + roundTrip: Boolean = true, ) { val actualStatement = parse(source) val expectedIonSexp = loadIonSexp(expectedPigAst) @@ -80,7 +81,25 @@ abstract class SqlParserTestBase : TestBase() { pigDomainAssert(actualStatement, expectedElement) // Check equals for actual value after round trip transformation: astStatement -> ExprNode -> astStatement - assertRoundTripPigAstToExprNode(actualStatement) + if (roundTrip) { + assertRoundTripPigAstToExprNode(actualStatement) + } + } + + /** + * This method is used by test cases for parsing a string. + * The test are performed with only PIG AST. + * The expected PIG AST is a PIG builder. + * No ExprNode <-> PIG AST round trip is performed. + */ + protected fun assertExpressionNoRoundTrip( + source: String, + expectedPigBuilder: PartiqlAst.Builder.() -> PartiqlAst.PartiqlAstNode + ) { + val expectedPigAst = PartiqlAst.build { expectedPigBuilder() }.toIonElement().toString() + + // Refer to comments inside the main body of the following function to see what checks are performed. + assertExpression(source, expectedPigAst, roundTrip = false) } /** @@ -95,7 +114,7 @@ abstract class SqlParserTestBase : TestBase() { val expectedPigAst = PartiqlAst.build { expectedPigBuilder() }.toIonElement().toString() // Refer to comments inside the main body of the following function to see what checks are performed. - assertExpression(source, expectedPigAst) + assertExpression(source, expectedPigAst, roundTrip = true) } /**