Skip to content

Commit

Permalink
Fix bag constructor parsing (#1500)
Browse files Browse the repository at this point in the history
  • Loading branch information
alancai98 authored Jul 3, 2024
1 parent 4dd0972 commit 5b86afc
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ internal class PartiQLPigParser(val customTypes: List<CustomType> = listOf()) :
val tokenStream = createTokenStream(queryStream)
val parser = parserInit(tokenStream)
val tree = parser.root()
val visitor = PartiQLPigVisitor(customTypes, tokenStream.parameterIndexes)
val visitor = PartiQLPigVisitor(tokenStream, customTypes, tokenStream.parameterIndexes)
return visitor.visit(tree) as PartiqlAst.Statement
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import com.amazon.ionelement.api.ionNull
import com.amazon.ionelement.api.ionString
import com.amazon.ionelement.api.ionSymbol
import com.amazon.ionelement.api.loadSingleElement
import org.antlr.v4.runtime.CommonTokenStream
import org.antlr.v4.runtime.ParserRuleContext
import org.antlr.v4.runtime.Token
import org.antlr.v4.runtime.tree.TerminalNode
Expand Down Expand Up @@ -62,6 +63,7 @@ import org.partiql.lang.util.getPrecisionFromTimeString
import org.partiql.lang.util.unaryMinus
import org.partiql.parser.internal.antlr.PartiQLParser
import org.partiql.parser.internal.antlr.PartiQLParserBaseVisitor
import org.partiql.parser.internal.antlr.PartiQLTokens
import org.partiql.pig.runtime.SymbolPrimitive
import org.partiql.value.datetime.DateTimeException
import org.partiql.value.datetime.TimeZone
Expand Down Expand Up @@ -116,6 +118,7 @@ import java.time.format.DateTimeParseException
* There could be clever ways of exploiting this, to avoid the dispatch via `visit()`.
*/
internal class PartiQLPigVisitor(
private val tokens: CommonTokenStream,
val customTypes: List<CustomType> = listOf(),
private val parameterIndexes: Map<Int, Int> = mapOf(),
) :
Expand Down Expand Up @@ -1507,6 +1510,12 @@ internal class PartiQLPigVisitor(
*/

override fun visitBag(ctx: PartiQLParser.BagContext) = PartiqlAst.build {
// Prohibit hidden characters between angle brackets
val startTokenIndex = ctx.start.tokenIndex
val endTokenIndex = ctx.stop.tokenIndex
if (tokens.getHiddenTokensToRight(startTokenIndex, PartiQLTokens.HIDDEN) != null || tokens.getHiddenTokensToLeft(endTokenIndex, PartiQLTokens.HIDDEN) != null) {
throw ParserException("Invalid bag expression", ErrorCode.PARSE_INVALID_QUERY)
}
val exprList = ctx.expr().map { visitExpr(it) }
bag(exprList, ctx.ANGLE_LEFT(0).getSourceMetaContainer())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5019,4 +5019,38 @@ class PartiQLParserTest : PartiQLParserTestBase() {
lit(ionInt(1))
)
}

// regression tests for bag constructor angle bracket
@Test
fun testBagConstructor() = assertExpression("<<<<1>>>>") {
bag(
bag(
lit(ionInt(1))
)
)
}

@Test
fun testSpacesInBagConstructor() = checkInputThrowingParserException(
"< < < < 1 > > > >",
ErrorCode.PARSE_UNEXPECTED_TOKEN, // partiql-ast parser ErrorCode
expectErrorContextValues = mapOf(
Property.LINE_NUMBER to 1L,
Property.COLUMN_NUMBER to 1L,
Property.TOKEN_DESCRIPTION to PartiQLParser.ANGLE_LEFT.getAntlrDisplayString(),
Property.TOKEN_VALUE to ION.newSymbol("<")
)
)

@Test
fun testCommentsInBagConstructor() = checkInputThrowingParserException(
"</* some comment */<<<1>>>>",
ErrorCode.PARSE_UNEXPECTED_TOKEN, // partiql-ast parser ErrorCode
expectErrorContextValues = mapOf(
Property.LINE_NUMBER to 1L,
Property.COLUMN_NUMBER to 1L,
Property.TOKEN_DESCRIPTION to PartiQLParser.ANGLE_LEFT.getAntlrDisplayString(),
Property.TOKEN_VALUE to ION.newSymbol("<")
)
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ abstract class PartiQLParserTestBase : TestBase() {
input: String,
errorCode: ErrorCode,
expectErrorContextValues: Map<Property, Any>,
targets: Array<ParserTarget> = arrayOf(ParserTarget.DEFAULT),
assertContext: Boolean = true,
): Unit = forEachTarget {
softAssert {
try {
parser.parseAstStatement(input)
fail("Expected ParserException but there was no Exception")
} catch (ex: ParserException) {
// split parser target does not use ErrorCode
// NOTE: only perform error code and error context checks for `ParserTarget.EXPERIMENTAL` (partiql-ast
// parser).
if (assertContext && (this@forEachTarget == ParserTarget.EXPERIMENTAL)) {
checkErrorAndErrorContext(errorCode, ex, expectErrorContextValues)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ internal class PartiQLParserDefault : PartiQLParser {
*/
@OptIn(PartiQLValueExperimental::class)
private class Visitor(
private val tokens: CommonTokenStream,
private val locations: SourceLocations.Mutable,
private val parameters: Map<Int, Int> = mapOf(),
) : PartiQLParserBaseVisitor<AstNode>() {
Expand All @@ -442,7 +443,7 @@ internal class PartiQLParserDefault : PartiQLParser {
tree: GeneratedParser.RootContext,
): PartiQLParser.Result {
val locations = SourceLocations.Mutable()
val visitor = Visitor(locations, tokens.parameterIndexes)
val visitor = Visitor(tokens, locations, tokens.parameterIndexes)
val root = visitor.visitAs<AstNode>(tree) as Statement
return PartiQLParser.Result(
source = source,
Expand Down Expand Up @@ -2022,6 +2023,12 @@ internal class PartiQLParserDefault : PartiQLParser {
*/

override fun visitBag(ctx: GeneratedParser.BagContext) = translate(ctx) {
// Prohibit hidden characters between angle brackets
val startTokenIndex = ctx.start.tokenIndex
val endTokenIndex = ctx.stop.tokenIndex
if (tokens.getHiddenTokensToRight(startTokenIndex, GeneratedLexer.HIDDEN) != null || tokens.getHiddenTokensToLeft(endTokenIndex, GeneratedLexer.HIDDEN) != null) {
throw error(ctx, "Invalid bag expression")
}
val expressions = visitOrEmpty<Expr>(ctx.expr())
exprCollection(Expr.Collection.Type.BAG, expressions)
}
Expand Down

0 comments on commit 5b86afc

Please sign in to comment.