Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Compile-Time Thread.interrupted() checks #398

Merged
merged 2 commits into from
May 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions lang/src/org/partiql/lang/CompilerPipeline.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import org.partiql.lang.eval.*
import org.partiql.lang.eval.builtins.*
import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure
import org.partiql.lang.syntax.*
import org.partiql.lang.util.interruptibleFold

/**
* Contains all of the information needed for processing steps.
Expand Down Expand Up @@ -180,7 +181,7 @@ interface CompilerPipeline {
}
}

private class CompilerPipelineImpl(
internal class CompilerPipelineImpl(
override val valueFactory: ExprValueFactory,
private val parser: Parser,
override val compileOptions: CompileOptions,
Expand All @@ -198,10 +199,15 @@ private class CompilerPipelineImpl(
override fun compile(query: ExprNode): Expression {
val context = StepContext(valueFactory, compileOptions, functions, procedures)

val preProcessedQuery = preProcessingSteps.fold(query) { currentExprNode, step ->
step(currentExprNode, context)
}
val preProcessedQuery = executePreProcessingSteps(query, context)

return compiler.compile(preProcessedQuery)
}

internal fun executePreProcessingSteps(
query: ExprNode,
context: StepContext
) = preProcessingSteps.interruptibleFold(query) { currentExprNode, step ->
step(currentExprNode, context)
}
}
12 changes: 8 additions & 4 deletions lang/src/org/partiql/lang/ast/AstDeserialization.kt
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class AstDeserializerBuilder(val ion: IonSystem) {
}
}

private class AstDeserializerInternal(
internal class AstDeserializerInternal(
val astVersion: AstVersion,
val ion: IonSystem,
private val metaDeserializers: Map<String, MetaDeserializer>
Expand All @@ -262,7 +262,9 @@ private class AstDeserializerInternal(
return deserializeExprNode(sexp)
}

private fun validate(rootSexp: IonSexp) {
internal fun validate(rootSexp: IonSexp) {
checkThreadInterrupted()

val nodeTag = rootSexp.nodeTag // Throws if nodeTag is invalid for the current AstVersion
val nodeArgs = rootSexp.args

Expand Down Expand Up @@ -321,8 +323,9 @@ private class AstDeserializerInternal(
/**
* Given a serialized AST, return its [ExprNode] representation.
*/
private fun deserializeExprNode(metaOrTermOrExp: IonSexp): ExprNode =
deserializeSexpMetaOrTerm(metaOrTermOrExp) { target, metas ->
internal fun deserializeExprNode(metaOrTermOrExp: IonSexp): ExprNode {
checkThreadInterrupted()
return deserializeSexpMetaOrTerm(metaOrTermOrExp) { target, metas ->
val nodeTag = target.nodeTag
val targetArgs = target.args //args is an extension property--call it once for efficiency
//.toList() forces immutability
Expand Down Expand Up @@ -417,6 +420,7 @@ private class AstDeserializerInternal(
NodeTag.TYPE -> errInvalidContext(nodeTag)
}
}
}

private fun deserializeLit(targetArgs: List<IonValue>, metas: MetaContainer) = Literal(targetArgs.first(), metas)
private fun deserializeMissing(metas: MetaContainer) = LiteralMissing(metas)
Expand Down
2 changes: 2 additions & 0 deletions lang/src/org/partiql/lang/ast/AstSerialization.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.amazon.ion.IonSystem
import org.partiql.lang.util.IonWriterContext
import org.partiql.lang.util.asIonSexp
import org.partiql.lang.util.case
import org.partiql.lang.util.checkThreadInterrupted
import kotlin.UnsupportedOperationException

/**
Expand Down Expand Up @@ -73,6 +74,7 @@ private class AstSerializerImpl(val astVersion: AstVersion, val ion: IonSystem):

private fun IonWriterContext.writeExprNode(expr: ExprNode): Unit =
writeAsTerm(expr.metas) {
checkThreadInterrupted()
sexp {
when (expr) {
// Leaf nodes
Expand Down
2 changes: 2 additions & 0 deletions lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.partiql.lang.ast
import com.amazon.ionelement.api.emptyMetaContainer
import com.amazon.ionelement.api.toIonElement
import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.util.checkThreadInterrupted
import org.partiql.pig.runtime.SymbolPrimitive
import org.partiql.pig.runtime.asPrimitive

Expand Down Expand Up @@ -67,6 +68,7 @@ private fun ExprNode.toAstExec() : PartiqlAst.Statement {
}

fun ExprNode.toAstExpr(): PartiqlAst.Expr {
checkThreadInterrupted()
val node = this
val metas = this.metas.toIonElementMetaContainer()

Expand Down
2 changes: 2 additions & 0 deletions lang/src/org/partiql/lang/ast/StatementToExprNode.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.amazon.ion.IonSystem
import com.amazon.ionelement.api.toIonValue
import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.domains.PartiqlAst.*
import org.partiql.lang.util.checkThreadInterrupted

import org.partiql.pig.runtime.SymbolPrimitive
import org.partiql.lang.ast.SetQuantifier as ExprNodeSetQuantifier // Conflicts with PartiqlAst.SetQuantifier
Expand Down Expand Up @@ -69,6 +70,7 @@ private class StatementTransformer(val ion: IonSystem) {
this.map { it.toExprNode() }

private fun Expr.toExprNode(): ExprNode {
checkThreadInterrupted()
val metas = this.metas.toPartiQlMetaContainer()
return when (this) {
is Expr.Missing -> LiteralMissing(metas)
Expand Down
20 changes: 18 additions & 2 deletions lang/src/org/partiql/lang/ast/ast.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,35 @@ import java.util.*
sealed class AstNode : Iterable<AstNode> {

/**
* returns all the children nodes.
* Returns all the children nodes.
*
* This property is [deprecated](see https://github.com/partiql/partiql-lang-kotlin/issues/396). Use
* one of the following PIG-generated classes to analyze AST nodes instead:
*
* - [org.partiql.lang.domains.PartiqlAst.Visitor]
* - [org.partiql.lang.domains.PartiqlAst.VisitorFold]
*/
@Deprecated("DO NOT USE - see kdoc, see https://github.com/partiql/partiql-lang-kotlin/issues/396")
abstract val children: List<AstNode>

/**
* Depth first iterator over all nodes.
*
* While collecting child nodes, throws [InterruptedException] if the [Thread.interrupted] flag has been set.
*
* This property is [deprecated](see https://github.com/partiql/partiql-lang-kotlin/issues/396). Use
* one of the following PIG-generated classes to analyze AST nodes instead:
*
* - [org.partiql.lang.domains.PartiqlAst.Visitor]
* - [org.partiql.lang.domains.PartiqlAst.VisitorFold]
*/
@Deprecated("DO NOT USE - see kdoc for alternatives")
override operator fun iterator(): Iterator<AstNode> {
val allNodes = mutableListOf<AstNode>()

fun depthFirstSequence(node: AstNode) {
allNodes.add(node)
node.children.map { depthFirstSequence(it) }
node.children.interruptibleMap { depthFirstSequence(it) }
}

depthFirstSequence(this)
Expand Down
9 changes: 6 additions & 3 deletions lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package org.partiql.lang.ast.passes

import org.partiql.lang.ast.*
import org.partiql.lang.util.checkThreadInterrupted

/**
* Provides a minimal interface for an AST rewriter implementation.
Expand All @@ -28,11 +29,12 @@ interface AstRewriter {
* This is the base-class for an AST rewriter which simply makes an exact copy of the original AST.
* Simple rewrites can be performed by inheritors.
*/
@Deprecated("New rewriters should implement PIG's PartiqlAst.VisitorTransform instead")
@Deprecated("New rewriters should implement PIG's VisitorTransformBase instead")
open class AstRewriterBase : AstRewriter {

override fun rewriteExprNode(node: ExprNode): ExprNode =
when (node) {
override fun rewriteExprNode(node: ExprNode): ExprNode {
checkThreadInterrupted()
return when (node) {
is Literal -> rewriteLiteral(node)
is LiteralMissing -> rewriteLiteralMissing(node)
is VariableReference -> rewriteVariableReference(node)
Expand All @@ -55,6 +57,7 @@ open class AstRewriterBase : AstRewriter {
is DateTimeType.Date -> rewriteDate(node)
is DateTimeType.Time -> rewriteTime(node)
}
}

open fun rewriteMetas(itemWithMetas: HasMetas): MetaContainer = itemWithMetas.metas

Expand Down
2 changes: 1 addition & 1 deletion lang/src/org/partiql/lang/ast/passes/AstVisitor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.partiql.lang.ast.*
*
* One `visit*` function is included for each base type in the AST.
*/
@Deprecated("Use AstNode#iterator() or AstNode#children()")
@Deprecated("Use org.lang.partiql.domains.PartiqlAst.Visitor instead")
interface AstVisitor {
/**
* Invoked by [AstWalker] for every instance of [ExprNode] encountered.
Expand Down
1 change: 1 addition & 0 deletions lang/src/org/partiql/lang/ast/passes/AstWalker.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ open class AstWalker(private val visitor: AstVisitor) {

protected open fun walkExprNode(vararg exprs: ExprNode?) {
exprs.filterNotNull().forEach { expr: ExprNode ->
checkThreadInterrupted()
visitor.visitExprNode(expr)

when (expr) {
Expand Down
11 changes: 10 additions & 1 deletion lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import org.partiql.lang.eval.visitors.PartiqlAstSanityValidator
import org.partiql.lang.syntax.SqlParser
import org.partiql.lang.util.*
import java.math.*
import java.time.LocalDate
import java.util.*
import kotlin.collections.*

Expand Down Expand Up @@ -207,6 +206,10 @@ internal class EvaluatingCompiler(

/**
* Compiles an [ExprNode] tree to an [Expression].
*
* Checks [Thread.interrupted] before every expression and sub-expression is compiled
* and throws [InterruptedException] if [Thread.interrupted] it has been set in the
* hope that long running compilations may be aborted by the caller.
*/
fun compile(originalAst: ExprNode): Expression {
val visitorTransformer = compileOptions.visitorTransformMode.createVisitorTransform()
Expand Down Expand Up @@ -257,7 +260,13 @@ internal class EvaluatingCompiler(
*/
fun eval(ast: ExprNode, session: EvaluationSession): ExprValue = compile(ast).eval(session)

/**
* Compiles the specified [ExprNode] into a [ThunkEnv].
*
* This function will [InterruptedException] if [Thread.interrupted] has been set.
*/
private fun compileExprNode(expr: ExprNode): ThunkEnv {
checkThreadInterrupted()
return when (expr) {
is Literal -> compileLiteral(expr)
is LiteralMissing -> compileLiteralMissing(expr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ import org.partiql.pig.runtime.SymbolPrimitive
*
* If provided with a query that has all of the from source aliases already specified, an exact clone is returned.
*/
class FromSourceAliasVisitorTransform : PartiqlAst.VisitorTransform() {
class FromSourceAliasVisitorTransform : VisitorTransformBase() {

private class InnerFromSourceAliasVisitorTransform : PartiqlAst.VisitorTransform() {
private class InnerFromSourceAliasVisitorTransform : VisitorTransformBase() {
private var fromSourceCounter = 0

override fun transformFromSourceScan_asAlias(node: PartiqlAst.FromSource.Scan): SymbolPrimitive? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.partiql.pig.runtime.SymbolPrimitive
*
* If provided with a query with all of the group by item aliases already specified, an exact clone is returned.
*/
class GroupByItemAliasVisitorTransform(var nestLevel: Int = 0) : PartiqlAst.VisitorTransform() {
class GroupByItemAliasVisitorTransform(var nestLevel: Int = 0) : VisitorTransformBase() {

override fun transformGroupBy(node: PartiqlAst.GroupBy): PartiqlAst.GroupBy {
return PartiqlAst.build {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.partiql.lang.eval.visitors

import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.util.interruptibleFold

/**
* A simple visitor transformer that provides a pipeline of transformers to be executed in sequential order.
Expand All @@ -11,7 +12,7 @@ class PipelinedVisitorTransform(vararg transformers: PartiqlAst.VisitorTransform
private val transformerList = transformers.toList()

override fun transformStatement(node: PartiqlAst.Statement): PartiqlAst.Statement =
transformerList.fold(node) {
transformerList.interruptibleFold(node) {
intermediateNode, transformer ->
transformer.transformStatement(intermediateNode)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.partiql.pig.runtime.SymbolPrimitive
*
* ```
*/
class SelectListItemAliasVisitorTransform : PartiqlAst.VisitorTransform() {
class SelectListItemAliasVisitorTransform : VisitorTransformBase() {

override fun transformProjectionProjectList(node: PartiqlAst.Projection.ProjectList): PartiqlAst.Projection {
return PartiqlAst.build {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.partiql.lang.ast.UniqueNameMeta
import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.eval.errNoContext

class SelectStarVisitorTransform : PartiqlAst.VisitorTransform() {
class SelectStarVisitorTransform : VisitorTransformBase() {

/**
* Copies all parts of [PartiqlAst.Expr.Select] except [newProjection] for [PartiqlAst.Projection].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ enum class StaticTypeVisitorTransformConstraints {
*/
class StaticTypeVisitorTransform(private val ion: IonSystem,
globalBindings: Bindings<StaticType>,
constraints: Set<StaticTypeVisitorTransformConstraints> = setOf()) : PartiqlAst.VisitorTransform() {
constraints: Set<StaticTypeVisitorTransformConstraints> = setOf()) : VisitorTransformBase() {

/** Used to allow certain binding lookups to occur directly in the global scope. */
private val globalEnv = wrapBindings(globalBindings, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ data class SubstitutionPair(val target: PartiqlAst.Expr, val replacement: Partiq
*
* This class is `open` to allow subclasses to restrict the nodes to which the substitution should occur.
*/
open class SubstitutionVisitorTransform(protected val substitutions: Map<PartiqlAst.Expr, SubstitutionPair>): PartiqlAst.VisitorTransform() {
open class SubstitutionVisitorTransform(protected val substitutions: Map<PartiqlAst.Expr, SubstitutionPair>): VisitorTransformBase() {

/**
* If [node] matches any of the target nodes in [substitutions], replaces the node with the replacement.
Expand All @@ -59,7 +59,7 @@ open class SubstitutionVisitorTransform(protected val substitutions: Map<Partiql
* After .copy() and copying metas is added to PIG (https://github.com/partiql/partiql-ir-generator/pull/53) change
* this and its usages to use .copy().
*/
inner class MetaVisitorTransform(private val newMetas: MetaContainer) : PartiqlAst.VisitorTransform() {
inner class MetaVisitorTransform(private val newMetas: MetaContainer) : VisitorTransformBase() {
override fun transformMetas(metas: MetaContainer): MetaContainer = newMetas
}

Expand Down
14 changes: 13 additions & 1 deletion lang/src/org/partiql/lang/eval/visitors/VisitorTransformBase.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
package org.partiql.lang.eval.visitors

import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.util.checkThreadInterrupted

/**
* Base-class for visitor transforms that provides additional functions outside of [PartiqlAst.VisitorTransform].
* Base-class for visitor transforms that provides additional `transform*` functions that outside of
* the PIG-generated [PartiqlAst.VisitorTransform] class and adds a [Thread.interrupted] check
* to [transformExpr].
*
* All transforms should derive from this class instead of [PartiqlAst.VisitorTransform] so that they can
* be interrupted if they take a long time to process large ASTs.
*/
abstract class VisitorTransformBase : PartiqlAst.VisitorTransform() {

override fun transformExpr(node: PartiqlAst.Expr): PartiqlAst.Expr {
checkThreadInterrupted()
return super.transformExpr(node)
}

/**
* Transforms the [PartiqlAst.Expr.Select] expression following the PartiQL evaluation order. That is:
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ fun basicVisitorTransforms() = PipelinedVisitorTransform(

/** A stateless visitor transform that returns the input. */
@JvmField
internal val IDENTITY_VISITOR_TRANSFORM: PartiqlAst.VisitorTransform = object : PartiqlAst.VisitorTransform() {
internal val IDENTITY_VISITOR_TRANSFORM: PartiqlAst.VisitorTransform = object : VisitorTransformBase() {
override fun transformStatement(node: PartiqlAst.Statement): PartiqlAst.Statement = node
}
6 changes: 6 additions & 0 deletions lang/src/org/partiql/lang/syntax/SqlParser.kt
Original file line number Diff line number Diff line change
Expand Up @@ -1039,12 +1039,17 @@ class SqlParser(private val ion: IonSystem) : Parser {
/**
* Parses the given token list.
*
* Throws [InterruptedException] if [Thread.interrupted] is set. This is the best place to do
* that for the parser because this is the main function called to parse an expression and so
* is called quite frequently during parsing by many parts of the parser.
*
* @param precedence The precedence of the current expression parsing.
* A negative value represents the "top-level" parsing.
*
* @return The parse tree for the given expression.
*/
internal fun List<Token>.parseExpression(precedence: Int = -1): ParseNode {
checkThreadInterrupted()
var expr = parseUnaryTerm()
var rem = expr.remaining

Expand Down Expand Up @@ -2815,6 +2820,7 @@ class SqlParser(private val ion: IonSystem) : Parser {
* If [dmlListTokenSeen] is true, it means it has been encountered at least once before while traversing the parse tree.
*/
private fun validateTopLevelNodes(node: ParseNode, level: Int, topLevelTokenSeen: Boolean, dmlListTokenSeen: Boolean) {
checkThreadInterrupted()
val isTopLevelType = when (node.type.isDml) {
// DML_LIST token type allows multiple DML keywords to be used in the same statement.
// Hence, DML keyword tokens are not treated as top level tokens if present with the DML_LIST token type
Expand Down
Loading