Skip to content

Commit

Permalink
Adds Compile-Time Thread.interrupted() checks (#398)
Browse files Browse the repository at this point in the history
* Add Thread.interrupted() checks
  • Loading branch information
dlurton authored May 5, 2021
1 parent eee99ae commit 7926b29
Show file tree
Hide file tree
Showing 23 changed files with 390 additions and 26 deletions.
14 changes: 10 additions & 4 deletions lang/src/org/partiql/lang/CompilerPipeline.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import org.partiql.lang.eval.*
import org.partiql.lang.eval.builtins.*
import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure
import org.partiql.lang.syntax.*
import org.partiql.lang.util.interruptibleFold

/**
* Contains all of the information needed for processing steps.
Expand Down Expand Up @@ -180,7 +181,7 @@ interface CompilerPipeline {
}
}

private class CompilerPipelineImpl(
internal class CompilerPipelineImpl(
override val valueFactory: ExprValueFactory,
private val parser: Parser,
override val compileOptions: CompileOptions,
Expand All @@ -198,10 +199,15 @@ private class CompilerPipelineImpl(
override fun compile(query: ExprNode): Expression {
val context = StepContext(valueFactory, compileOptions, functions, procedures)

val preProcessedQuery = preProcessingSteps.fold(query) { currentExprNode, step ->
step(currentExprNode, context)
}
val preProcessedQuery = executePreProcessingSteps(query, context)

return compiler.compile(preProcessedQuery)
}

internal fun executePreProcessingSteps(
query: ExprNode,
context: StepContext
) = preProcessingSteps.interruptibleFold(query) { currentExprNode, step ->
step(currentExprNode, context)
}
}
12 changes: 8 additions & 4 deletions lang/src/org/partiql/lang/ast/AstDeserialization.kt
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class AstDeserializerBuilder(val ion: IonSystem) {
}
}

private class AstDeserializerInternal(
internal class AstDeserializerInternal(
val astVersion: AstVersion,
val ion: IonSystem,
private val metaDeserializers: Map<String, MetaDeserializer>
Expand All @@ -262,7 +262,9 @@ private class AstDeserializerInternal(
return deserializeExprNode(sexp)
}

private fun validate(rootSexp: IonSexp) {
internal fun validate(rootSexp: IonSexp) {
checkThreadInterrupted()

val nodeTag = rootSexp.nodeTag // Throws if nodeTag is invalid for the current AstVersion
val nodeArgs = rootSexp.args

Expand Down Expand Up @@ -321,8 +323,9 @@ private class AstDeserializerInternal(
/**
* Given a serialized AST, return its [ExprNode] representation.
*/
private fun deserializeExprNode(metaOrTermOrExp: IonSexp): ExprNode =
deserializeSexpMetaOrTerm(metaOrTermOrExp) { target, metas ->
internal fun deserializeExprNode(metaOrTermOrExp: IonSexp): ExprNode {
checkThreadInterrupted()
return deserializeSexpMetaOrTerm(metaOrTermOrExp) { target, metas ->
val nodeTag = target.nodeTag
val targetArgs = target.args //args is an extension property--call it once for efficiency
//.toList() forces immutability
Expand Down Expand Up @@ -417,6 +420,7 @@ private class AstDeserializerInternal(
NodeTag.TYPE -> errInvalidContext(nodeTag)
}
}
}

private fun deserializeLit(targetArgs: List<IonValue>, metas: MetaContainer) = Literal(targetArgs.first(), metas)
private fun deserializeMissing(metas: MetaContainer) = LiteralMissing(metas)
Expand Down
2 changes: 2 additions & 0 deletions lang/src/org/partiql/lang/ast/AstSerialization.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.amazon.ion.IonSystem
import org.partiql.lang.util.IonWriterContext
import org.partiql.lang.util.asIonSexp
import org.partiql.lang.util.case
import org.partiql.lang.util.checkThreadInterrupted
import kotlin.UnsupportedOperationException

/**
Expand Down Expand Up @@ -73,6 +74,7 @@ private class AstSerializerImpl(val astVersion: AstVersion, val ion: IonSystem):

private fun IonWriterContext.writeExprNode(expr: ExprNode): Unit =
writeAsTerm(expr.metas) {
checkThreadInterrupted()
sexp {
when (expr) {
// Leaf nodes
Expand Down
2 changes: 2 additions & 0 deletions lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.partiql.lang.ast
import com.amazon.ionelement.api.emptyMetaContainer
import com.amazon.ionelement.api.toIonElement
import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.util.checkThreadInterrupted
import org.partiql.pig.runtime.SymbolPrimitive
import org.partiql.pig.runtime.asPrimitive

Expand Down Expand Up @@ -67,6 +68,7 @@ private fun ExprNode.toAstExec() : PartiqlAst.Statement {
}

fun ExprNode.toAstExpr(): PartiqlAst.Expr {
checkThreadInterrupted()
val node = this
val metas = this.metas.toIonElementMetaContainer()

Expand Down
2 changes: 2 additions & 0 deletions lang/src/org/partiql/lang/ast/StatementToExprNode.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.amazon.ion.IonSystem
import com.amazon.ionelement.api.toIonValue
import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.domains.PartiqlAst.*
import org.partiql.lang.util.checkThreadInterrupted

import org.partiql.pig.runtime.SymbolPrimitive
import org.partiql.lang.ast.SetQuantifier as ExprNodeSetQuantifier // Conflicts with PartiqlAst.SetQuantifier
Expand Down Expand Up @@ -69,6 +70,7 @@ private class StatementTransformer(val ion: IonSystem) {
this.map { it.toExprNode() }

private fun Expr.toExprNode(): ExprNode {
checkThreadInterrupted()
val metas = this.metas.toPartiQlMetaContainer()
return when (this) {
is Expr.Missing -> LiteralMissing(metas)
Expand Down
20 changes: 18 additions & 2 deletions lang/src/org/partiql/lang/ast/ast.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,35 @@ import java.util.*
sealed class AstNode : Iterable<AstNode> {

/**
* returns all the children nodes.
* Returns all the children nodes.
*
* This property is [deprecated](see https://github.com/partiql/partiql-lang-kotlin/issues/396). Use
* one of the following PIG-generated classes to analyze AST nodes instead:
*
* - [org.partiql.lang.domains.PartiqlAst.Visitor]
* - [org.partiql.lang.domains.PartiqlAst.VisitorFold]
*/
@Deprecated("DO NOT USE - see kdoc, see https://github.com/partiql/partiql-lang-kotlin/issues/396")
abstract val children: List<AstNode>

/**
* Depth first iterator over all nodes.
*
* While collecting child nodes, throws [InterruptedException] if the [Thread.interrupted] flag has been set.
*
* This property is [deprecated](see https://github.com/partiql/partiql-lang-kotlin/issues/396). Use
* one of the following PIG-generated classes to analyze AST nodes instead:
*
* - [org.partiql.lang.domains.PartiqlAst.Visitor]
* - [org.partiql.lang.domains.PartiqlAst.VisitorFold]
*/
@Deprecated("DO NOT USE - see kdoc for alternatives")
override operator fun iterator(): Iterator<AstNode> {
val allNodes = mutableListOf<AstNode>()

fun depthFirstSequence(node: AstNode) {
allNodes.add(node)
node.children.map { depthFirstSequence(it) }
node.children.interruptibleMap { depthFirstSequence(it) }
}

depthFirstSequence(this)
Expand Down
9 changes: 6 additions & 3 deletions lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package org.partiql.lang.ast.passes

import org.partiql.lang.ast.*
import org.partiql.lang.util.checkThreadInterrupted

/**
* Provides a minimal interface for an AST rewriter implementation.
Expand All @@ -28,11 +29,12 @@ interface AstRewriter {
* This is the base-class for an AST rewriter which simply makes an exact copy of the original AST.
* Simple rewrites can be performed by inheritors.
*/
@Deprecated("New rewriters should implement PIG's PartiqlAst.VisitorTransform instead")
@Deprecated("New rewriters should implement PIG's VisitorTransformBase instead")
open class AstRewriterBase : AstRewriter {

override fun rewriteExprNode(node: ExprNode): ExprNode =
when (node) {
override fun rewriteExprNode(node: ExprNode): ExprNode {
checkThreadInterrupted()
return when (node) {
is Literal -> rewriteLiteral(node)
is LiteralMissing -> rewriteLiteralMissing(node)
is VariableReference -> rewriteVariableReference(node)
Expand All @@ -55,6 +57,7 @@ open class AstRewriterBase : AstRewriter {
is DateTimeType.Date -> rewriteDate(node)
is DateTimeType.Time -> rewriteTime(node)
}
}

open fun rewriteMetas(itemWithMetas: HasMetas): MetaContainer = itemWithMetas.metas

Expand Down
2 changes: 1 addition & 1 deletion lang/src/org/partiql/lang/ast/passes/AstVisitor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.partiql.lang.ast.*
*
* One `visit*` function is included for each base type in the AST.
*/
@Deprecated("Use AstNode#iterator() or AstNode#children()")
@Deprecated("Use org.lang.partiql.domains.PartiqlAst.Visitor instead")
interface AstVisitor {
/**
* Invoked by [AstWalker] for every instance of [ExprNode] encountered.
Expand Down
1 change: 1 addition & 0 deletions lang/src/org/partiql/lang/ast/passes/AstWalker.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ open class AstWalker(private val visitor: AstVisitor) {

protected open fun walkExprNode(vararg exprs: ExprNode?) {
exprs.filterNotNull().forEach { expr: ExprNode ->
checkThreadInterrupted()
visitor.visitExprNode(expr)

when (expr) {
Expand Down
11 changes: 10 additions & 1 deletion lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import org.partiql.lang.eval.visitors.PartiqlAstSanityValidator
import org.partiql.lang.syntax.SqlParser
import org.partiql.lang.util.*
import java.math.*
import java.time.LocalDate
import java.util.*
import kotlin.collections.*

Expand Down Expand Up @@ -207,6 +206,10 @@ internal class EvaluatingCompiler(

/**
* Compiles an [ExprNode] tree to an [Expression].
*
* Checks [Thread.interrupted] before every expression and sub-expression is compiled
* and throws [InterruptedException] if [Thread.interrupted] it has been set in the
* hope that long running compilations may be aborted by the caller.
*/
fun compile(originalAst: ExprNode): Expression {
val visitorTransformer = compileOptions.visitorTransformMode.createVisitorTransform()
Expand Down Expand Up @@ -257,7 +260,13 @@ internal class EvaluatingCompiler(
*/
fun eval(ast: ExprNode, session: EvaluationSession): ExprValue = compile(ast).eval(session)

/**
* Compiles the specified [ExprNode] into a [ThunkEnv].
*
* This function will [InterruptedException] if [Thread.interrupted] has been set.
*/
private fun compileExprNode(expr: ExprNode): ThunkEnv {
checkThreadInterrupted()
return when (expr) {
is Literal -> compileLiteral(expr)
is LiteralMissing -> compileLiteralMissing(expr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ import org.partiql.pig.runtime.SymbolPrimitive
*
* If provided with a query that has all of the from source aliases already specified, an exact clone is returned.
*/
class FromSourceAliasVisitorTransform : PartiqlAst.VisitorTransform() {
class FromSourceAliasVisitorTransform : VisitorTransformBase() {

private class InnerFromSourceAliasVisitorTransform : PartiqlAst.VisitorTransform() {
private class InnerFromSourceAliasVisitorTransform : VisitorTransformBase() {
private var fromSourceCounter = 0

override fun transformFromSourceScan_asAlias(node: PartiqlAst.FromSource.Scan): SymbolPrimitive? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.partiql.pig.runtime.SymbolPrimitive
*
* If provided with a query with all of the group by item aliases already specified, an exact clone is returned.
*/
class GroupByItemAliasVisitorTransform(var nestLevel: Int = 0) : PartiqlAst.VisitorTransform() {
class GroupByItemAliasVisitorTransform(var nestLevel: Int = 0) : VisitorTransformBase() {

override fun transformGroupBy(node: PartiqlAst.GroupBy): PartiqlAst.GroupBy {
return PartiqlAst.build {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.partiql.lang.eval.visitors

import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.util.interruptibleFold

/**
* A simple visitor transformer that provides a pipeline of transformers to be executed in sequential order.
Expand All @@ -11,7 +12,7 @@ class PipelinedVisitorTransform(vararg transformers: PartiqlAst.VisitorTransform
private val transformerList = transformers.toList()

override fun transformStatement(node: PartiqlAst.Statement): PartiqlAst.Statement =
transformerList.fold(node) {
transformerList.interruptibleFold(node) {
intermediateNode, transformer ->
transformer.transformStatement(intermediateNode)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.partiql.pig.runtime.SymbolPrimitive
*
* ```
*/
class SelectListItemAliasVisitorTransform : PartiqlAst.VisitorTransform() {
class SelectListItemAliasVisitorTransform : VisitorTransformBase() {

override fun transformProjectionProjectList(node: PartiqlAst.Projection.ProjectList): PartiqlAst.Projection {
return PartiqlAst.build {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.partiql.lang.ast.UniqueNameMeta
import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.eval.errNoContext

class SelectStarVisitorTransform : PartiqlAst.VisitorTransform() {
class SelectStarVisitorTransform : VisitorTransformBase() {

/**
* Copies all parts of [PartiqlAst.Expr.Select] except [newProjection] for [PartiqlAst.Projection].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ enum class StaticTypeVisitorTransformConstraints {
*/
class StaticTypeVisitorTransform(private val ion: IonSystem,
globalBindings: Bindings<StaticType>,
constraints: Set<StaticTypeVisitorTransformConstraints> = setOf()) : PartiqlAst.VisitorTransform() {
constraints: Set<StaticTypeVisitorTransformConstraints> = setOf()) : VisitorTransformBase() {

/** Used to allow certain binding lookups to occur directly in the global scope. */
private val globalEnv = wrapBindings(globalBindings, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ data class SubstitutionPair(val target: PartiqlAst.Expr, val replacement: Partiq
*
* This class is `open` to allow subclasses to restrict the nodes to which the substitution should occur.
*/
open class SubstitutionVisitorTransform(protected val substitutions: Map<PartiqlAst.Expr, SubstitutionPair>): PartiqlAst.VisitorTransform() {
open class SubstitutionVisitorTransform(protected val substitutions: Map<PartiqlAst.Expr, SubstitutionPair>): VisitorTransformBase() {

/**
* If [node] matches any of the target nodes in [substitutions], replaces the node with the replacement.
Expand All @@ -59,7 +59,7 @@ open class SubstitutionVisitorTransform(protected val substitutions: Map<Partiql
* After .copy() and copying metas is added to PIG (https://github.com/partiql/partiql-ir-generator/pull/53) change
* this and its usages to use .copy().
*/
inner class MetaVisitorTransform(private val newMetas: MetaContainer) : PartiqlAst.VisitorTransform() {
inner class MetaVisitorTransform(private val newMetas: MetaContainer) : VisitorTransformBase() {
override fun transformMetas(metas: MetaContainer): MetaContainer = newMetas
}

Expand Down
14 changes: 13 additions & 1 deletion lang/src/org/partiql/lang/eval/visitors/VisitorTransformBase.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
package org.partiql.lang.eval.visitors

import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.util.checkThreadInterrupted

/**
* Base-class for visitor transforms that provides additional functions outside of [PartiqlAst.VisitorTransform].
* Base-class for visitor transforms that provides additional `transform*` functions that outside of
* the PIG-generated [PartiqlAst.VisitorTransform] class and adds a [Thread.interrupted] check
* to [transformExpr].
*
* All transforms should derive from this class instead of [PartiqlAst.VisitorTransform] so that they can
* be interrupted if they take a long time to process large ASTs.
*/
abstract class VisitorTransformBase : PartiqlAst.VisitorTransform() {

override fun transformExpr(node: PartiqlAst.Expr): PartiqlAst.Expr {
checkThreadInterrupted()
return super.transformExpr(node)
}

/**
* Transforms the [PartiqlAst.Expr.Select] expression following the PartiQL evaluation order. That is:
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ fun basicVisitorTransforms() = PipelinedVisitorTransform(

/** A stateless visitor transform that returns the input. */
@JvmField
internal val IDENTITY_VISITOR_TRANSFORM: PartiqlAst.VisitorTransform = object : PartiqlAst.VisitorTransform() {
internal val IDENTITY_VISITOR_TRANSFORM: PartiqlAst.VisitorTransform = object : VisitorTransformBase() {
override fun transformStatement(node: PartiqlAst.Statement): PartiqlAst.Statement = node
}
6 changes: 6 additions & 0 deletions lang/src/org/partiql/lang/syntax/SqlParser.kt
Original file line number Diff line number Diff line change
Expand Up @@ -1039,12 +1039,17 @@ class SqlParser(private val ion: IonSystem) : Parser {
/**
* Parses the given token list.
*
* Throws [InterruptedException] if [Thread.interrupted] is set. This is the best place to do
* that for the parser because this is the main function called to parse an expression and so
* is called quite frequently during parsing by many parts of the parser.
*
* @param precedence The precedence of the current expression parsing.
* A negative value represents the "top-level" parsing.
*
* @return The parse tree for the given expression.
*/
internal fun List<Token>.parseExpression(precedence: Int = -1): ParseNode {
checkThreadInterrupted()
var expr = parseUnaryTerm()
var rem = expr.remaining

Expand Down Expand Up @@ -2815,6 +2820,7 @@ class SqlParser(private val ion: IonSystem) : Parser {
* If [dmlListTokenSeen] is true, it means it has been encountered at least once before while traversing the parse tree.
*/
private fun validateTopLevelNodes(node: ParseNode, level: Int, topLevelTokenSeen: Boolean, dmlListTokenSeen: Boolean) {
checkThreadInterrupted()
val isTopLevelType = when (node.type.isDml) {
// DML_LIST token type allows multiple DML keywords to be used in the same statement.
// Hence, DML keyword tokens are not treated as top level tokens if present with the DML_LIST token type
Expand Down
Loading

0 comments on commit 7926b29

Please sign in to comment.