diff --git a/lang/src/org/partiql/lang/syntax/SqlParser.kt b/lang/src/org/partiql/lang/syntax/SqlParser.kt index 33642aec78..b25da01a4a 100644 --- a/lang/src/org/partiql/lang/syntax/SqlParser.kt +++ b/lang/src/org/partiql/lang/syntax/SqlParser.kt @@ -1118,41 +1118,39 @@ class SqlParser(private val ion: IonSystem) : Parser { return expr } - private fun List.parseUnaryTerm(): ParseNode = - when (head?.isUnaryOperator) { + private fun List.parseUnaryTerm(): ParseNode { + return when (head?.isUnaryOperator) { true -> { val op = head!! - - val term = tail.parseUnaryTerm() - var expr: ParseNode? = null + fun makeUnaryParseNode(term: ParseNode) = + ParseNode(UNARY, op, listOf(term), term.remaining) // constant fold unary plus/minus into constant literals when (op.keywordText) { - "+" -> when { - term.isNumericLiteral -> { - // unary plus is a NO-OP - expr = term + "+" -> { + val term = tail.parseUnaryTerm() + when { + // unary plus is a no-op on numeric literals. + term.isNumericLiteral -> term + else -> makeUnaryParseNode(term) } } - "-" -> when { - term.isNumericLiteral -> { - val num = -term.numberValue() - expr = ParseNode(ATOM, - term.token!!.copy(value = num.ionValue(ion)), - emptyList(), - term.remaining) + "-" -> { + val term = tail.parseUnaryTerm() + when { + // for numbers, drop the minus sign but also negate the value + term.isNumericLiteral -> + term.copy(token = term.token!!.copy(value = (-term.numberValue()).ionValue(ion))) + else -> makeUnaryParseNode(term) } } - "not" -> { - val children = tail.parseExpression(op.prefixPrecedence) - expr = ParseNode(UNARY, op, listOf(children), children.remaining) - } + else -> makeUnaryParseNode(tail.parseExpression(op.prefixPrecedence)) } - - expr ?: ParseNode(UNARY, op, listOf(term), term.remaining) } else -> parsePathTerm() } + } + private fun List.parsePathTerm(pathMode: PathMode = PathMode.FULL_PATH): ParseNode { val term = when (pathMode) { diff --git a/lang/test/org/partiql/lang/syntax/SqlParserTest.kt b/lang/test/org/partiql/lang/syntax/SqlParserTest.kt index dfc3c97921..e18b01f201 100644 --- a/lang/test/org/partiql/lang/syntax/SqlParserTest.kt +++ b/lang/test/org/partiql/lang/syntax/SqlParserTest.kt @@ -25,6 +25,7 @@ import org.partiql.lang.ast.SourceLocationMeta import org.partiql.lang.ast.sourceLocation import org.partiql.lang.domains.PartiqlAst import org.partiql.lang.domains.id +import kotlin.concurrent.thread /** * Originally just meant to test the parser, this class now tests several different things because @@ -3998,4 +3999,35 @@ class SqlParserTest : SqlParserTestBase() { from = scan(id("bar")) ))) } + + @Test + fun manyNestedNotPerformanceRegressionTest() { + val startTime = System.currentTimeMillis() + val t = thread { + parse( + """ + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not + not not not not not not not not not not not not not not not not not not not not not not not not false + """) + } + val maxParseTime: Long = 5000 + t.join(maxParseTime) + t.interrupt() + + assertTrue( + "parsing many nested unary nots should take less than $maxParseTime", + System.currentTimeMillis() - startTime < maxParseTime) + } }