Skip to content

Commit

Permalink
Mbz static method validation (#1914)
Browse files Browse the repository at this point in the history
* Added validation for classes passed as TypeReference based on ModelDatat.typeDefinitions
  • Loading branch information
maciej-brzezinski authored Jul 21, 2021
1 parent 7789665 commit 984d30b
Show file tree
Hide file tree
Showing 21 changed files with 222 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ case class ExpressionConfig(globalProcessVariables: Map[String, WithCategories[A
// DictInstance have the same dictId as DictDefinition
dictionaries: Map[String, WithCategories[DictDefinition]] = Map.empty,
hideMetaVariable: Boolean = false,
strictMethodsChecking: Boolean = true
strictMethodsChecking: Boolean = true,
staticMethodInvocationsChecking: Boolean = false
)

object ExpressionConfig {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package pl.touk.nussknacker.engine.benchmarks.spel

import java.util.concurrent.TimeUnit

import cats.data.Validated.{Invalid, Valid}
import org.openjdk.jmh.annotations._
import pl.touk.nussknacker.engine.TypeDefinitionSet
import pl.touk.nussknacker.engine.api.Context
import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.NodeId
import pl.touk.nussknacker.engine.api.context.ValidationContext
Expand All @@ -19,10 +19,10 @@ class SpelBenchmarkSetup(expression: String, vars: Map[String, AnyRef]) {

private val expressionDefinition = ExpressionDefinition(globalVariables = Map(), globalImports = Nil, additionalClasses = List(),
languages = LanguageConfiguration.default, optimizeCompilation = true,
strictTypeChecking = true, dictionaries = Map.empty, hideMetaVariable = false, strictMethodsChecking = true)
strictTypeChecking = true, dictionaries = Map.empty, hideMetaVariable = false, strictMethodsChecking = true, staticMethodInvocationsChecking = false)

private val expressionCompiler = ExpressionCompiler.withOptimization(
getClass.getClassLoader, new SimpleDictRegistry(Map.empty), expressionDefinition, settings = ClassExtractionSettings.Default)
getClass.getClassLoader, new SimpleDictRegistry(Map.empty), expressionDefinition, settings = ClassExtractionSettings.Default, typeDefinitionSet = TypeDefinitionSet.empty)

private val validationContext = ValidationContext(vars.mapValues(Typed.fromInstance), Map.empty)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.apache.avro.Schema
import org.apache.avro.generic.GenericData.EnumSymbol
import org.apache.avro.generic.GenericRecord
import org.scalatest.{FunSpec, Matchers}
import pl.touk.nussknacker.engine.TypeDefinitionSet
import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.NodeId
import pl.touk.nussknacker.engine.api.context.ValidationContext
import pl.touk.nussknacker.engine.api.dict.DictInstance
Expand Down Expand Up @@ -218,7 +219,7 @@ class AvroSchemaSpelExpressionSpec extends FunSpec with Matchers {

private def parse[T:TypeTag](expr: String, validationCtx: ValidationContext) : ValidatedNel[ExpressionParseError, TypedExpression] = {
SpelExpressionParser.default(getClass.getClassLoader, new SimpleDictRegistry(Map(dictId -> EmbeddedDictDefinition(Map("key1" -> "value1")))), enableSpelForceCompile = true,
strictTypeChecking = true, Nil, Standard, strictMethodsChecking = true)(ClassExtractionSettings.Default).parse(expr, validationCtx, Typed.fromDetailedType[T])
strictTypeChecking = true, Nil, Standard, strictMethodsChecking = true, staticMethodInvocationsChecking = false, TypeDefinitionSet.empty)(ClassExtractionSettings.Default).parse(expr, validationCtx, Typed.fromDetailedType[T])
}

private def wrapWithRecordSchema(fieldsDefinition: String) =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package pl.touk.nussknacker.engine

import cats.data.{NonEmptyList, Validated}
import cats.data.Validated.{Invalid, Valid}
import org.apache.commons.lang3.ClassUtils
import org.springframework.expression.{EvaluationContext, EvaluationException}
import org.springframework.expression.spel.ExpressionState
import org.springframework.expression.spel.ast.TypeReference
import pl.touk.nussknacker.engine.api.Context
import pl.touk.nussknacker.engine.api.expression.ExpressionParseError
import pl.touk.nussknacker.engine.api.typed.typing.{TypedClass, TypingResult}
import pl.touk.nussknacker.engine.definition.{DefinitionExtractor, ProcessDefinitionExtractor, TypeInfos}
import pl.touk.nussknacker.engine.spel.TypedNode
import pl.touk.nussknacker.engine.spel.ast.SpelAst.RichSpelNode

import scala.util.{Failure, Success, Try}



object TypeDefinitionSet {

def empty: TypeDefinitionSet = TypeDefinitionSet(Set.empty)

}

case class TypeDefinitionSet(typeDefinitions: Set[TypeInfos.ClazzDefinition]) {

def validateTypeReference(typeReference: TypeReference, evaluationContext: EvaluationContext): Validated[NonEmptyList[ExpressionParseError], TypedClass] = {

/**
* getValue mutates TypeReference but is still safe
* it adds values to fields type and exitTypeDescriptor but field type is transient and exitTypeDescriptor is of a primitive type (String)
*/
val typeReferenceClazz = Try(typeReference.getValue(new ExpressionState(evaluationContext)))

typeReferenceClazz match {
case Success(typeReferenceClazz) =>
typeDefinitions.find(typeDefinition => typeDefinition.clazzName.klass.equals(typeReferenceClazz)) match {
case Some(clazzDefinition: TypeInfos.ClazzDefinition) => Valid(clazzDefinition.clazzName)
case None => Invalid(NonEmptyList.of(ExpressionParseError(s"${typeReferenceClazz} is not allowed to be passed as TypeReference")))
}
case Failure(_: EvaluationException) => Invalid(NonEmptyList.of(ExpressionParseError(s"Class ${typeReference.toStringAST} does not exist")))
case Failure(exception) => throw exception
}

}

}

Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,33 @@ import pl.touk.nussknacker.engine.spel.SpelExpressionParser
import pl.touk.nussknacker.engine.sql.SqlExpressionParser
import pl.touk.nussknacker.engine.util.Implicits._
import pl.touk.nussknacker.engine.util.validated.ValidatedSyntax
import pl.touk.nussknacker.engine.{ModelData, compiledgraph, graph}
import pl.touk.nussknacker.engine.{ModelData, TypeDefinitionSet, compiledgraph, graph}

object ExpressionCompiler {

def withOptimization(loader: ClassLoader, dictRegistry: DictRegistry, expressionConfig: ExpressionDefinition[ObjectMetadata], settings: ClassExtractionSettings): ExpressionCompiler
= default(loader, dictRegistry, expressionConfig, expressionConfig.optimizeCompilation, settings)
def withOptimization(loader: ClassLoader, dictRegistry: DictRegistry, expressionConfig: ExpressionDefinition[ObjectMetadata],
settings: ClassExtractionSettings, typeDefinitionSet: TypeDefinitionSet): ExpressionCompiler
= default(loader, dictRegistry, expressionConfig, expressionConfig.optimizeCompilation, settings, typeDefinitionSet)

def withoutOptimization(loader: ClassLoader, dictRegistry: DictRegistry, expressionConfig: ExpressionDefinition[ObjectMetadata], settings: ClassExtractionSettings): ExpressionCompiler
= default(loader, dictRegistry, expressionConfig, optimizeCompilation = false, settings)
def withoutOptimization(loader: ClassLoader, dictRegistry: DictRegistry, expressionConfig: ExpressionDefinition[ObjectMetadata],
settings: ClassExtractionSettings, typeDefinitionSet: TypeDefinitionSet): ExpressionCompiler
= default(loader, dictRegistry, expressionConfig, optimizeCompilation = false, settings, typeDefinitionSet)

def withoutOptimization(modelData: ModelData): ExpressionCompiler = {
withoutOptimization(modelData.modelClassLoader.classLoader,
modelData.dictServices.dictRegistry,
modelData.processDefinition.expressionConfig,
modelData.processDefinition.settings)
modelData.processDefinition.settings,
TypeDefinitionSet(modelData.typeDefinitions))
}

private def default(loader: ClassLoader, dictRegistry: DictRegistry, expressionConfig: ExpressionDefinition[ObjectMetadata],
optimizeCompilation: Boolean, settings: ClassExtractionSettings): ExpressionCompiler = {
optimizeCompilation: Boolean, settings: ClassExtractionSettings, typeDefinitionSet: TypeDefinitionSet): ExpressionCompiler = {
val defaultParsers = Seq(
SpelExpressionParser.default(loader, dictRegistry, optimizeCompilation, expressionConfig.strictTypeChecking, expressionConfig.globalImports, SpelExpressionParser.Standard, expressionConfig.strictMethodsChecking)(settings),
SpelExpressionParser.default(loader, dictRegistry, optimizeCompilation, expressionConfig.strictTypeChecking, expressionConfig.globalImports, SpelExpressionParser.Template, expressionConfig.strictMethodsChecking)(settings),
SpelExpressionParser.default(loader, dictRegistry, optimizeCompilation, expressionConfig.strictTypeChecking,
expressionConfig.globalImports, SpelExpressionParser.Standard, expressionConfig.strictMethodsChecking, expressionConfig.staticMethodInvocationsChecking, typeDefinitionSet)(settings),
SpelExpressionParser.default(loader, dictRegistry, optimizeCompilation, expressionConfig.strictTypeChecking,
expressionConfig.globalImports, SpelExpressionParser.Template, expressionConfig.strictMethodsChecking, expressionConfig.staticMethodInvocationsChecking, typeDefinitionSet)(settings),
SqlExpressionParser)
val parsersSeq = defaultParsers ++ expressionConfig.languages.expressionParsers
val parsers = parsersSeq.map(p => p.languageId -> p).toMap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import pl.touk.nussknacker.engine.compile.nodecompilation.NodeCompiler.NodeCompi
import pl.touk.nussknacker.engine.compiledgraph.part.{PotentiallyStartPart, TypedEnd}
import pl.touk.nussknacker.engine.compiledgraph.{CompiledProcessParts, part}
import pl.touk.nussknacker.engine.definition.DefinitionExtractor._
import pl.touk.nussknacker.engine.definition.ProcessDefinitionExtractor
import pl.touk.nussknacker.engine.definition.ProcessDefinitionExtractor.ProcessDefinition
import pl.touk.nussknacker.engine.graph.EspProcess
import pl.touk.nussknacker.engine.graph.node.{Sink, Source => _, _}
Expand Down Expand Up @@ -246,7 +247,9 @@ protected trait ProcessCompilerBase {
object ProcessValidator {

def default(definitions: ProcessDefinition[ObjectWithMethodDef], dictRegistry: DictRegistry, classLoader: ClassLoader = getClass.getClassLoader): ProcessValidator = {
val expressionCompiler = ExpressionCompiler.withoutOptimization(classLoader, dictRegistry, definitions.expressionConfig, definitions.settings)
val typeDefinitionSet = TypeDefinitionSet(ProcessDefinitionExtractor.extractTypes(definitions))

val expressionCompiler = ExpressionCompiler.withoutOptimization(classLoader, dictRegistry, definitions.expressionConfig, definitions.settings, typeDefinitionSet)
val nodeCompiler = new NodeCompiler(definitions, expressionCompiler, classLoader, PreventInvocationCollector, RunMode.Normal)
val sub = new PartSubGraphCompiler(expressionCompiler, nodeCompiler)
new ProcessCompiler(classLoader, sub, GlobalVariablesPreparer(definitions.expressionConfig), nodeCompiler)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package pl.touk.nussknacker.engine.compile

import java.util.concurrent.TimeUnit

import cats.data.ValidatedNel
import pl.touk.nussknacker.engine.Interpreter
import pl.touk.nussknacker.engine.{Interpreter, TypeDefinitionSet}
import pl.touk.nussknacker.engine.api.async.DefaultAsyncInterpretationValue
import pl.touk.nussknacker.engine.api.context.ProcessCompilationError
import pl.touk.nussknacker.engine.api.exception.EspExceptionHandler
Expand All @@ -12,7 +11,7 @@ import pl.touk.nussknacker.engine.api.{Lifecycle, MetaData, ProcessListener}
import pl.touk.nussknacker.engine.compile.nodecompilation.NodeCompiler
import pl.touk.nussknacker.engine.compiledgraph.CompiledProcessParts
import pl.touk.nussknacker.engine.definition.DefinitionExtractor.ObjectWithMethodDef
import pl.touk.nussknacker.engine.definition.LazyInterpreterDependencies
import pl.touk.nussknacker.engine.definition.{LazyInterpreterDependencies, ProcessDefinitionExtractor}
import pl.touk.nussknacker.engine.definition.ProcessDefinitionExtractor.ProcessDefinition
import pl.touk.nussknacker.engine.dict.DictServicesFactoryLoader
import pl.touk.nussknacker.engine.expression.ExpressionEvaluator
Expand Down Expand Up @@ -42,7 +41,8 @@ object ProcessCompilerData {
val dictRegistryFactory = loadDictRegistry(userCodeClassLoader)
val dictRegistry = dictRegistryFactory.createEngineDictRegistry(definitions.expressionConfig.dictionaries)

val expressionCompiler = ExpressionCompiler.withOptimization(userCodeClassLoader, dictRegistry, definitions.expressionConfig, definitions.settings)
val typeDefinitionSet = TypeDefinitionSet(ProcessDefinitionExtractor.extractTypes(definitions))
val expressionCompiler = ExpressionCompiler.withOptimization(userCodeClassLoader, dictRegistry, definitions.expressionConfig, definitions.settings, typeDefinitionSet)
//for testing environment it's important to take classloader from user jar
val nodeCompiler = new NodeCompiler(definitions, expressionCompiler, userCodeClassLoader, resultsCollector, runMode)
val subCompiler = new PartSubGraphCompiler(expressionCompiler, nodeCompiler)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import pl.touk.nussknacker.engine.api.MethodToInvoke
import pl.touk.nussknacker.engine.api.context.transformation.{GenericNodeTransformation, OutputVariableNameValue, TypedNodeDependencyValue}
import pl.touk.nussknacker.engine.api.definition.{OutputVariableNameDependency, Parameter, TypedNodeDependency, WithExplicitMethodToInvoke, WithExplicitTypesToExtract}
import pl.touk.nussknacker.engine.api.process.{ClassExtractionSettings, SingleNodeConfig, WithCategories}
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypingResult, Unknown}
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedClass, TypingResult, Unknown}
import pl.touk.nussknacker.engine.api.util.ReflectUtils
import pl.touk.nussknacker.engine.definition.DefinitionExtractor.{ObjectMetadata, _}
import pl.touk.nussknacker.engine.definition.MethodDefinitionExtractor.MethodDefinition
Expand All @@ -31,13 +31,13 @@ class DefinitionExtractor[T](methodDefinitionExtractor: MethodDefinitionExtracto

(obj match {
//TODO: how validators/editors in NodeConfig should be handled for GenericNodeTransformation/WithExplicitMethodToInvoke?
case e:GenericNodeTransformation[_] =>
case e: GenericNodeTransformation[_] =>
// Here in general we do not have a specified "returnType", hence Undefined/Void
val returnType = if (e.nodeDependencies.contains(OutputVariableNameDependency)) Unknown else Typed[Void]
val parametersList = StandardParameterEnrichment.enrichParameterDefinitions(e.initialParameters, objWithCategories.nodeConfig)
val definition = ObjectDefinition(parametersList, returnType, objWithCategories.categories, objWithCategories.nodeConfig)
Right(GenericNodeTransformationMethodDef(e, definition))
case e:WithExplicitMethodToInvoke =>
case e: WithExplicitMethodToInvoke =>
WithExplicitMethodToInvokeMethodDefinitionExtractor.extractMethodDefinition(e,
classOf[WithExplicitMethodToInvoke].getMethods.find(_.getName == "invoke").get, nodeConfig).right.map(fromMethodDefinition)
case _ =>
Expand Down Expand Up @@ -73,7 +73,7 @@ object DefinitionExtractor {
def categories: List[String]

// TODO: Use ContextTransformation API to check if custom node is adding some output variable
def hasNoReturn : Boolean = Set[TypingResult](Typed[Void], Typed[Unit], Typed[BoxedUnit]).contains(returnType)
def hasNoReturn: Boolean = Set[TypingResult](Typed[Void], Typed[Unit], Typed[BoxedUnit]).contains(returnType)

}

Expand All @@ -83,7 +83,7 @@ object DefinitionExtractor {

def invokeMethod(params: Map[String, Any],
outputVariableNameOpt: Option[String],
additional: Seq[AnyRef]) : Any
additional: Seq[AnyRef]): Any

def objectDefinition: ObjectDefinition

Expand Down Expand Up @@ -139,11 +139,11 @@ object DefinitionExtractor {
case class FinalStateValue(value: Option[Any])

case class StandardObjectWithMethodDef(obj: Any,
methodDef: MethodDefinition,
objectDefinition: ObjectDefinition) extends ObjectWithMethodDef with LazyLogging {
methodDef: MethodDefinition,
objectDefinition: ObjectDefinition) extends ObjectWithMethodDef with LazyLogging {
def invokeMethod(params: Map[String, Any],
outputVariableNameOpt: Option[String],
additional: Seq[AnyRef]) : Any = {
additional: Seq[AnyRef]): Any = {
val values = methodDef.orderedDependencies.prepareValues(params, outputVariableNameOpt, additional)
try {
methodDef.invocation(obj, values)
Expand All @@ -166,16 +166,16 @@ object DefinitionExtractor {
}

case class ObjectDefinition(parameters: List[Parameter],
returnType: TypingResult,
categories: List[String],
nodeConfig: SingleNodeConfig) extends ObjectMetadata
returnType: TypingResult,
categories: List[String],
nodeConfig: SingleNodeConfig) extends ObjectMetadata


object ObjectWithMethodDef {

import cats.syntax.semigroup._

def forMap[T](objs: Map[String, WithCategories[_<:T]], methodExtractor: MethodDefinitionExtractor[T], externalConfig: Map[String, SingleNodeConfig]): Map[String, ObjectWithMethodDef] = {
def forMap[T](objs: Map[String, WithCategories[_ <: T]], methodExtractor: MethodDefinitionExtractor[T], externalConfig: Map[String, SingleNodeConfig]): Map[String, ObjectWithMethodDef] = {
objs.map { case (id, obj) =>
val config = externalConfig.getOrElse(id, SingleNodeConfig.zero) |+| obj.nodeConfig
id -> (obj, config)
Expand Down Expand Up @@ -211,7 +211,7 @@ object DefinitionExtractor {

def explicitTypes(obj: ObjectWithMethodDef): List[TypingResult] = {
obj.obj match {
case explicit : WithExplicitTypesToExtract => explicit.typesToExtract
case explicit: WithExplicitTypesToExtract => explicit.typesToExtract
case _ => Nil
}
}
Expand All @@ -238,12 +238,12 @@ object DefinitionExtractor {
}

object TypeInfos {

@JsonCodec(encodeOnly = true) case class Parameter(name: String, refClazz: TypingResult)

@JsonCodec(encodeOnly = true) case class MethodInfo(parameters: List[Parameter], refClazz: TypingResult, description: Option[String], varArgs: Boolean)

case class ClazzDefinition(clazzName: TypingResult, methods: Map[String, List[MethodInfo]]) {
case class ClazzDefinition(clazzName: TypedClass, methods: Map[String, List[MethodInfo]]) {

def getPropertyOrFieldType(methodName: String): Option[TypingResult] = {
val filtered = methods.get(methodName).toList
Expand Down
Loading

0 comments on commit 984d30b

Please sign in to comment.