diff --git a/build.sbt b/build.sbt index 54924d9d..1c5d7756 100644 --- a/build.sbt +++ b/build.sbt @@ -195,7 +195,11 @@ lazy val core = project ProblemFilters.exclude[IncompatibleResultTypeProblem]( "sangria.schema.WithInputTypeRendering.deprecationTracker"), ProblemFilters.exclude[ReversedMissingMethodProblem]( - "sangria.schema.WithInputTypeRendering.deprecationTracker") + "sangria.schema.WithInputTypeRendering.deprecationTracker"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.validation.RuleBasedQueryValidator.validateInputDocument"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "sangria.validation.RuleBasedQueryValidator.validateInputDocument") ), Test / testOptions += Tests.Argument(TestFrameworks.ScalaTest, "-oF"), libraryDependencies ++= Seq( diff --git a/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala b/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala index f0169695..592cc8c5 100644 --- a/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala +++ b/modules/benchmarks/src/main/scala/sangria/benchmarks/OverlappingFieldsCanBeMergedBenchmark.scala @@ -98,7 +98,7 @@ class OverlappingFieldsCanBeMergedBenchmark { bh.consume(doValidate(validator, deepAbstractConcrete)) private def doValidate(validator: QueryValidator, document: Document): Vector[Violation] = { - val result = validator.validateQuery(schema, document, None) + val result = validator.validateQuery(schema, document, Map.empty, None) require(result.isEmpty) result } diff --git a/modules/core/src/main/scala/sangria/execution/Executor.scala b/modules/core/src/main/scala/sangria/execution/Executor.scala index aa166324..95daf472 100644 --- a/modules/core/src/main/scala/sangria/execution/Executor.scala +++ b/modules/core/src/main/scala/sangria/execution/Executor.scala @@ -29,107 +29,116 @@ case class Executor[Ctx, Root]( operationName: Option[String] = None, variables: Input = emptyMapVars )(implicit um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] = { - val (violations, validationTiming) = - TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit)) + val scalarMiddleware = Middleware.composeFromScalarMiddleware(middleware, userContext) + val valueCollector = new ValueCollector[Ctx, Input]( + schema, + variables, + queryAst.sourceMapper, + deprecationTracker, + userContext, + exceptionHandler, + scalarMiddleware, + false)(um) - if (violations.nonEmpty) - Future.failed(ValidationError(violations, exceptionHandler)) - else { - val scalarMiddleware = Middleware.composeFromScalarMiddleware(middleware, userContext) - val valueCollector = new ValueCollector[Ctx, Input]( - schema, - variables, - queryAst.sourceMapper, - deprecationTracker, - userContext, - exceptionHandler, + val operationCtx = for { + operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) + unmarshalledVariables <- valueCollector.getVariableValues( + operation.variables, scalarMiddleware, - false)(um) - - val executionResult = for { - operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) - unmarshalledVariables <- valueCollector.getVariableValues( - operation.variables, - scalarMiddleware, - errorsLimit - ) - fieldCollector = new FieldCollector[Ctx, Root]( - schema, - queryAst, - unmarshalledVariables, - queryAst.sourceMapper, - valueCollector, - exceptionHandler) - tpe <- Executor.getOperationRootType( - schema, - exceptionHandler, - operation, - queryAst.sourceMapper) - fields <- fieldCollector.collectFields(ExecutionPath.empty, tpe, Vector(operation)) - } yield { - val preparedFields = fields.fields.flatMap { - case CollectedField(_, astField, Success(_)) => - val allFields = - tpe.getField(schema, astField.name).asInstanceOf[Vector[Field[Ctx, Root]]] - val field = allFields.head - val args = valueCollector.getFieldArgumentValues( - ExecutionPath.empty.add(astField, tpe), - Some(astField), - field.arguments, - astField.arguments, - unmarshalledVariables) + errorsLimit + ) + } yield (operation, unmarshalledVariables) - args.toOption.map(PreparedField(field, _)) - case _ => None - } + operationCtx match { + case Failure(error) => Future.failed(error) + case Success((operation, unmarshalledVariables)) => + val (violations, validationTiming) = + TimeMeasurement.measure( + queryValidator.validateQuery(schema, queryAst, unmarshalledVariables, errorsLimit)) - QueryReducerExecutor - .reduceQuery( - schema, - queryReducers, - exceptionHandler, - fieldCollector, - valueCollector, - unmarshalledVariables, - tpe, - fields, - userContext) - .map { case (newCtx, timing) => - new PreparedQuery[Ctx, Root, Input]( - queryAst, + if (violations.nonEmpty) + Future.failed(ValidationError(violations, exceptionHandler)) + else { + val executionResult = for { + tpe <- Executor.getOperationRootType( + schema, + exceptionHandler, operation, - tpe, - newCtx, - root, - preparedFields, - (c: Ctx, r: Root, m: ResultMarshaller, scheme: ExecutionScheme) => - executeOperation( + queryAst.sourceMapper) + fieldCollector = new FieldCollector[Ctx, Root]( + schema, + queryAst, + unmarshalledVariables, + queryAst.sourceMapper, + valueCollector, + exceptionHandler) + fields <- fieldCollector.collectFields(ExecutionPath.empty, tpe, Vector(operation)) + } yield { + val preparedFields = fields.fields.flatMap { + case CollectedField(_, astField, Success(_)) => + val allFields = + tpe.getField(schema, astField.name).asInstanceOf[Vector[Field[Ctx, Root]]] + val field = allFields.head + val args = valueCollector.getFieldArgumentValues( + ExecutionPath.empty.add(astField, tpe), + Some(astField), + field.arguments, + astField.arguments, + unmarshalledVariables) + + args.toOption.map(PreparedField(field, _)) + case _ => None + } + + QueryReducerExecutor + .reduceQuery( + schema, + queryReducers, + exceptionHandler, + fieldCollector, + valueCollector, + unmarshalledVariables, + tpe, + fields, + userContext) + .map { case (newCtx, timing) => + new PreparedQuery[Ctx, Root, Input]( queryAst, - operationName, - variables, - um, operation, - queryAst.sourceMapper, - valueCollector, - fieldCollector, - m, - unmarshalledVariables, tpe, - fields, - c, - r, - scheme, - validationTiming, - timing + newCtx, + root, + preparedFields, + (c: Ctx, r: Root, m: ResultMarshaller, scheme: ExecutionScheme) => + executeOperation( + queryAst, + operationName, + variables, + um, + operation, + queryAst.sourceMapper, + valueCollector, + fieldCollector, + m, + unmarshalledVariables, + tpe, + fields, + c, + r, + scheme, + validationTiming, + timing + ) ) - ) + } } - } - executionResult match { - case Success(future) => future - case Failure(error) => Future.failed(error) - } + executionResult match { + case Success(future) => future + case Failure(error) => Future.failed(error) + } + } + } } @@ -143,82 +152,92 @@ case class Executor[Ctx, Root]( marshaller: ResultMarshaller, um: InputUnmarshaller[Input], scheme: ExecutionScheme): scheme.Result[Ctx, marshaller.Node] = { - val (violations, validationTiming) = - TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit)) - if (violations.nonEmpty) - scheme.failed(ValidationError(violations, exceptionHandler)) - else { - val scalarMiddleware = Middleware.composeFromScalarMiddleware(middleware, userContext) - val valueCollector = new ValueCollector[Ctx, Input]( - schema, - variables, - queryAst.sourceMapper, - deprecationTracker, - userContext, - exceptionHandler, + val scalarMiddleware = Middleware.composeFromScalarMiddleware(middleware, userContext) + val valueCollector = new ValueCollector[Ctx, Input]( + schema, + variables, + queryAst.sourceMapper, + deprecationTracker, + userContext, + exceptionHandler, + scalarMiddleware, + false)(um) + + val operationCtx = for { + operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) + unmarshalledVariables <- valueCollector.getVariableValues( + operation.variables, scalarMiddleware, - false)(um) + errorsLimit + ) + } yield (operation, unmarshalledVariables) - val executionResult = for { - operation <- Executor.getOperation(exceptionHandler, queryAst, operationName) - unmarshalledVariables <- valueCollector.getVariableValues( - operation.variables, - scalarMiddleware, - errorsLimit - ) - fieldCollector = new FieldCollector[Ctx, Root]( - schema, - queryAst, - unmarshalledVariables, - queryAst.sourceMapper, - valueCollector, - exceptionHandler) - tpe <- Executor.getOperationRootType( - schema, - exceptionHandler, - operation, - queryAst.sourceMapper) - fields <- fieldCollector.collectFields(ExecutionPath.empty, tpe, Vector(operation)) - } yield { - val reduced = QueryReducerExecutor.reduceQuery( - schema, - queryReducers, - exceptionHandler, - fieldCollector, - valueCollector, - unmarshalledVariables, - tpe, - fields, - userContext) - scheme.flatMapFuture(reduced) { case (newCtx, timing) => - executeOperation( - queryAst, - operationName, - variables, - um, - operation, - queryAst.sourceMapper, - valueCollector, - fieldCollector, - marshaller, - unmarshalledVariables, - tpe, - fields, - newCtx, - root, - scheme, - validationTiming, - timing - ) - } - } + operationCtx match { + case Failure(error) => scheme.failed(error) + case Success((operation, unmarshalledVariables)) => + val (violations, validationTiming) = + TimeMeasurement.measure( + queryValidator.validateQuery(schema, queryAst, unmarshalledVariables, errorsLimit)) - executionResult match { - case Success(result) => result - case Failure(error) => scheme.failed(error) - } + if (violations.nonEmpty) + scheme.failed(ValidationError(violations, exceptionHandler)) + else { + val executionResult = for { + tpe <- Executor.getOperationRootType( + schema, + exceptionHandler, + operation, + queryAst.sourceMapper) + fieldCollector = new FieldCollector[Ctx, Root]( + schema, + queryAst, + unmarshalledVariables, + queryAst.sourceMapper, + valueCollector, + exceptionHandler) + fields <- fieldCollector.collectFields(ExecutionPath.empty, tpe, Vector(operation)) + } yield { + val reduced = QueryReducerExecutor.reduceQuery( + schema, + queryReducers, + exceptionHandler, + fieldCollector, + valueCollector, + unmarshalledVariables, + tpe, + fields, + userContext) + scheme.flatMapFuture(reduced) { case (newCtx, timing) => + executeOperation( + queryAst, + operationName, + variables, + um, + operation, + queryAst.sourceMapper, + valueCollector, + fieldCollector, + marshaller, + unmarshalledVariables, + tpe, + fields, + newCtx, + root, + scheme, + validationTiming, + timing + ) + } + } + + executionResult match { + case Success(result) => result + case Failure(error) => scheme.failed(error) + } + } } + } private def executeOperation[Input]( diff --git a/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala b/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala index a70ba15e..f92c9bbc 100644 --- a/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala +++ b/modules/core/src/main/scala/sangria/execution/InputDocumentMaterializer.scala @@ -28,16 +28,17 @@ case class InputDocumentMaterializer[Vars]( None, false)(iu) - val violations = QueryValidator.default.validateInputDocument(schema, document, inputType) + val variableDefinitions = inferVariableDefinitions(document, inputType) - if (violations.nonEmpty) - Failure(InputDocumentMaterializationError(violations, ExceptionHandler.empty)) - else { - val variableDefinitions = inferVariableDefinitions(document, inputType) + collector.getVariableValues(variableDefinitions, None) match { + case Failure(e) => Failure(e) + case Success(vars) => + val violations = + QueryValidator.default.validateInputDocument(schema, document, inputType, vars) - collector.getVariableValues(variableDefinitions, None) match { - case Failure(e) => Failure(e) - case Success(vars) => + if (violations.nonEmpty) + Failure(InputDocumentMaterializationError(violations, ExceptionHandler.empty)) + else { try Success(document.values.flatMap { value => collector.coercionHelper.coerceInputValue( @@ -56,7 +57,7 @@ case class InputDocumentMaterializer[Vars]( catch { case NonFatal(e) => Failure(e) } - } + } } } diff --git a/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala b/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala index 258ec369..d6d3dc82 100644 --- a/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala +++ b/modules/core/src/main/scala/sangria/execution/QueryReducerExecutor.scala @@ -22,7 +22,7 @@ object QueryReducerExecutor { middleware: List[Middleware[Ctx]] = Nil, errorsLimit: Option[Int] = None )(implicit executionContext: ExecutionContext): Future[(Ctx, TimeMeasurement)] = { - val violations = queryValidator.validateQuery(schema, queryAst, errorsLimit) + val violations = queryValidator.validateQuery(schema, queryAst, Map.empty, errorsLimit) if (violations.nonEmpty) Future.failed(ValidationError(violations, exceptionHandler)) diff --git a/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala b/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala index 795fe07e..7fba78be 100644 --- a/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala +++ b/modules/core/src/main/scala/sangria/execution/batch/BatchExecutor.scala @@ -101,7 +101,10 @@ object BatchExecutor { inferVariableDefinitions, exceptionHandler)) .flatMap { case res @ (updatedDocument, _) => - val violations = queryValidator.validateQuery(schema, updatedDocument, errorsLimit) + // we're not going to pass variables here, as we call validateQuery again on + // executeIndividual which has the unmarshalled variables at that point + val violations = + queryValidator.validateQuery(schema, updatedDocument, Map.empty, errorsLimit) if (violations.nonEmpty) Failure(ValidationError(violations, exceptionHandler)) else Success(res) diff --git a/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala b/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala index 7d711e68..c0549829 100644 --- a/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala +++ b/modules/core/src/main/scala/sangria/schema/ResolverBasedAstSchemaBuilder.scala @@ -63,7 +63,8 @@ class ResolverBasedAstSchemaBuilder[Ctx](val resolvers: Seq[AstSchemaResolver[Ct schema: ast.Document, validator: QueryValidator = ResolverBasedAstSchemaBuilder.validator, errorsLimit: Option[Int] = None): Vector[Violation] = - allowKnownDynamicDirectives(validator.validateQuery(validationSchema, schema, errorsLimit)) + allowKnownDynamicDirectives( + validator.validateQuery(validationSchema, schema, Map.empty, errorsLimit)) def validateSchemaWithException( schema: ast.Document, diff --git a/modules/core/src/main/scala/sangria/validation/QueryValidator.scala b/modules/core/src/main/scala/sangria/validation/QueryValidator.scala index 635e2250..7e776e5b 100644 --- a/modules/core/src/main/scala/sangria/validation/QueryValidator.scala +++ b/modules/core/src/main/scala/sangria/validation/QueryValidator.scala @@ -3,6 +3,7 @@ package sangria.validation import sangria.ast import sangria.ast.AstVisitorCommand._ import sangria.ast.{AstVisitor, AstVisitorCommand, SourceMapper} +import sangria.execution import sangria.renderer.SchemaRenderer import sangria.schema._ import sangria.validation.rules._ @@ -14,6 +15,7 @@ trait QueryValidator { def validateQuery( schema: Schema[_, _], queryAst: ast.Document, + variableValues: Map[String, execution.VariableValue], errorsLimit: Option[Int]): Vector[Violation] } @@ -56,6 +58,7 @@ object QueryValidator { def validateQuery( schema: Schema[_, _], queryAst: ast.Document, + variableValues: Map[String, execution.VariableValue], errorsLimit: Option[Int]): Vector[Violation] = Vector.empty } @@ -66,12 +69,15 @@ class RuleBasedQueryValidator(rules: List[ValidationRule]) extends QueryValidato def validateQuery( schema: Schema[_, _], queryAst: ast.Document, - errorsLimit: Option[Int]): Vector[Violation] = { + variables: Map[String, execution.VariableValue], + errorsLimit: Option[Int] + ): Vector[Violation] = { val ctx = new ValidationContext( schema, queryAst, queryAst.sourceMapper, new TypeInfo(schema), + variables, errorsLimit) validateUsingRules(queryAst, ctx, rules.map(_.visitor(ctx)), topLevel = true) @@ -82,9 +88,11 @@ class RuleBasedQueryValidator(rules: List[ValidationRule]) extends QueryValidato def validateInputDocument( schema: Schema[_, _], doc: ast.InputDocument, - inputTypeName: String): Vector[Violation] = + inputTypeName: String, + variables: Map[String, execution.VariableValue] + ): Vector[Violation] = schema.getInputType(ast.NamedType(inputTypeName)) match { - case Some(it) => validateInputDocument(schema, doc, it) + case Some(it) => validateInputDocument(schema, doc, it, variables) case None => throw new IllegalStateException( s"Can't find input type '$inputTypeName' in the schema. Known input types are: ${schema.inputTypes.keys.toVector.sorted @@ -94,10 +102,18 @@ class RuleBasedQueryValidator(rules: List[ValidationRule]) extends QueryValidato def validateInputDocument( schema: Schema[_, _], doc: ast.InputDocument, - inputType: InputType[_]): Vector[Violation] = { + inputType: InputType[_], + variables: Map[String, execution.VariableValue] + ): Vector[Violation] = { val typeInfo = new TypeInfo(schema, Some(inputType)) - val ctx = ValidationContext(schema, ast.Document.emptyStub, doc.sourceMapper, typeInfo) + val ctx = ValidationContext( + schema, + ast.Document.emptyStub, + doc.sourceMapper, + typeInfo, + variables + ) validateUsingRules(doc, ctx, rules.map(_.visitor(ctx)), topLevel = true) @@ -164,6 +180,7 @@ class ValidationContext( val doc: ast.Document, val sourceMapper: Option[SourceMapper], val typeInfo: TypeInfo, + val variables: Map[String, execution.VariableValue], errorsLimit: Option[Int]) { // Using mutable data-structures and mutability to minimize validation footprint @@ -194,9 +211,10 @@ object ValidationContext { schema: Schema[_, _], doc: ast.Document, sourceMapper: Option[SourceMapper], - typeInfo: TypeInfo + typeInfo: TypeInfo, + variables: Map[String, execution.VariableValue] ): ValidationContext = - new ValidationContext(schema, doc, sourceMapper, typeInfo, None) + new ValidationContext(schema, doc, sourceMapper, typeInfo, variables, None) @deprecated( "The validations are now implemented as a part of `ValuesOfCorrectType` validation.", diff --git a/modules/core/src/main/scala/sangria/validation/rules/ExactlyOneOfFieldGiven.scala b/modules/core/src/main/scala/sangria/validation/rules/ExactlyOneOfFieldGiven.scala index d3ac6921..faad0d37 100644 --- a/modules/core/src/main/scala/sangria/validation/rules/ExactlyOneOfFieldGiven.scala +++ b/modules/core/src/main/scala/sangria/validation/rules/ExactlyOneOfFieldGiven.scala @@ -1,37 +1,99 @@ package sangria.validation.rules import sangria.ast +import sangria.execution +import sangria.execution.Trinary.{Defined, NullWithDefault} import sangria.schema import sangria.ast.AstVisitorCommand import sangria.validation._ +import sangria.marshalling.CoercedScalaResultMarshaller /** For oneOf input objects, exactly one field should be non-null. */ class ExactlyOneOfFieldGiven extends ValidationRule { + val marshaller = CoercedScalaResultMarshaller.default + override def visitor(ctx: ValidationContext) = new AstValidatingVisitor { - override val onEnter: ValidationVisit = { case ast.ObjectValue(fields, _, pos) => - ctx.typeInfo.inputType match { - case Some(inputType) => - inputType.namedInputType match { - case schema.InputObjectType(name, _, _, directives, _) if directives.exists { d => - d.name == schema.OneOfDirective.name - } => - val nonNullFields = fields.filter { field => - field.value match { - case ast.NullValue(_, _) => false - case _ => true + private def getResolvedVariableValue( + name: String, + inputType: schema.InputType[_], + variableValues: Map[String, execution.VariableValue] + ): Option[Any] = { + val variableValue = ctx.variables.get(name) + + variableValue.map(_.resolve(marshaller, marshaller, inputType)) match { + case Some(Right(Defined(resolved))) => Some(resolved) + case Some(Right(NullWithDefault(resolved))) => Some(resolved) + case _ => None + } + } + + private def hasOneOfDirective(inputObject: schema.InputObjectType[_]) = + inputObject.astDirectives.exists(_.name == schema.OneOfDirective.name) + + private def visitNode( + inputType: Option[schema.InputType[_]], + node: Either[ast.ObjectValue, ast.VariableValue] + ) = + inputType.fold(AstVisitorCommand.RightContinue) { inputType => + inputType.namedInputType match { + case namedInputType: schema.InputObjectType[_] if hasOneOfDirective(namedInputType) => + val pos = node match { + case Left(node) => node.location + case Right(node) => node.location + } + + val nonNullFields = node match { + case Left(ast.ObjectValue(fields, _, _)) => + fields.filter { field => + field.value match { + case ast.NullValue(_, _) => false + case ast.VariableValue(name, _, _) => + val fieldInputType = namedInputType.fieldsByName + .get(field.name) + .map(_.fieldType) + + val variableValue = fieldInputType.flatMap { fieldInputType => + getResolvedVariableValue(name, fieldInputType, ctx.variables) + } + + variableValue.isDefined + case _ => true + } } - } + case Right(ast.VariableValue(name, _, _)) => + val variableValue = getResolvedVariableValue(name, namedInputType, ctx.variables) - nonNullFields.size match { - case 1 => AstVisitorCommand.RightContinue - case _ => - Left(Vector(NotExactlyOneOfField(name, ctx.sourceMapper, pos.toList))) - } + try + variableValue match { + case Some(resolved) => + val variableObj = resolved.asInstanceOf[Map[String, Any]] + namedInputType.fields.filter { field => + variableObj.get(field.name).fold(false)(_ != None) + } + case _ => Vector.empty + } + catch { + // could get this from asInstanceOf failing for unexpected variable type. + // other validation will cover this problem. + case _: Throwable => Vector.empty + } + } - case _ => AstVisitorCommand.RightContinue - } - case None => AstVisitorCommand.RightContinue + nonNullFields.size match { + case 1 => AstVisitorCommand.RightContinue + case _ => + Left( + Vector( + NotExactlyOneOfField(namedInputType.name, ctx.sourceMapper, pos.toList) + ) + ) + } + case _ => AstVisitorCommand.RightContinue + } } + override val onEnter: ValidationVisit = { + case node: ast.ObjectValue => visitNode(ctx.typeInfo.inputType, Left(node)) + case node: ast.VariableValue => visitNode(ctx.typeInfo.inputType, Right(node)) } } } diff --git a/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala b/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala index 30d7fc17..978e83c1 100644 --- a/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala +++ b/modules/core/src/test/scala/sangria/execution/InputDocumentMaterializerSpec.scala @@ -111,7 +111,7 @@ class InputDocumentMaterializerSpec extends AnyWordSpec with Matchers with Strin } """ - val errors = QueryValidator.default.validateInputDocument(schema, inp, "Config") + val errors = QueryValidator.default.validateInputDocument(schema, inp, "Config", Map.empty) assertViolations( errors, @@ -164,7 +164,7 @@ class InputDocumentMaterializerSpec extends AnyWordSpec with Matchers with Strin } """ - val errors = QueryValidator.default.validateInputDocument(schema, inp, "Config") + val errors = QueryValidator.default.validateInputDocument(schema, inp, "Config", Map.empty) assertViolations( errors, diff --git a/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala b/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala index 203db9ce..4cc424a3 100644 --- a/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala +++ b/modules/core/src/test/scala/sangria/starWars/StartWarsValidationSpec.scala @@ -30,7 +30,8 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should be(Symbol("empty")) + QueryValidator.default.validateQuery(StarWarsSchema, query, Map.empty, None) should be( + Symbol("empty")) } "Notes that non-existent fields are invalid" in { @@ -42,7 +43,11 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 + QueryValidator.default.validateQuery( + StarWarsSchema, + query, + Map.empty, + None) should have size 1 } "Requires fields on objects" in { @@ -52,7 +57,11 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 + QueryValidator.default.validateQuery( + StarWarsSchema, + query, + Map.empty, + None) should have size 1 } "Disallows fields on scalars" in { @@ -66,7 +75,11 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 + QueryValidator.default.validateQuery( + StarWarsSchema, + query, + Map.empty, + None) should have size 1 } "Disallows object fields on interfaces" in { @@ -79,7 +92,11 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should have size 1 + QueryValidator.default.validateQuery( + StarWarsSchema, + query, + Map.empty, + None) should have size 1 } "Allows object fields in fragments" in { @@ -96,7 +113,8 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should be(Symbol("empty")) + QueryValidator.default.validateQuery(StarWarsSchema, query, Map.empty, None) should be( + Symbol("empty")) } "Allows object fields in inline fragments" in { @@ -111,7 +129,8 @@ class StartWarsValidationSpec extends AnyWordSpec with Matchers with FutureResul } """) - QueryValidator.default.validateQuery(StarWarsSchema, query, None) should be(Symbol("empty")) + QueryValidator.default.validateQuery(StarWarsSchema, query, Map.empty, None) should be( + Symbol("empty")) } } } diff --git a/modules/core/src/test/scala/sangria/util/CatsSupport.scala b/modules/core/src/test/scala/sangria/util/CatsSupport.scala index 60d56f92..573894af 100644 --- a/modules/core/src/test/scala/sangria/util/CatsSupport.scala +++ b/modules/core/src/test/scala/sangria/util/CatsSupport.scala @@ -221,7 +221,7 @@ object CatsScenarioExecutor extends FutureResultSupport { case Validate(rules) => ValidationResult( new RuleBasedQueryValidator(rules.toList) - .validateQuery(`given`.schema, QueryParser.parse(`given`.query).get, None)) + .validateQuery(`given`.schema, QueryParser.parse(`given`.query).get, Map.empty, None)) case Execute(validate, value, vars, op) => val validator = if (validate) QueryValidator.default else QueryValidator.empty diff --git a/modules/core/src/test/scala/sangria/util/ValidationSupport.scala b/modules/core/src/test/scala/sangria/util/ValidationSupport.scala index 95db047d..6e8c289b 100644 --- a/modules/core/src/test/scala/sangria/util/ValidationSupport.scala +++ b/modules/core/src/test/scala/sangria/util/ValidationSupport.scala @@ -11,6 +11,8 @@ import org.scalatest.matchers.should.Matchers import sangria.ast.Document import sangria.util.tag.@@ import sangria.marshalling.FromInput.CoercedScalaResult +import sangria.execution.ExceptionHandler +import sangria.execution.ValueCollector trait ValidationSupport extends Matchers { type TestField = Field[Unit, Unit] @@ -247,6 +249,7 @@ trait ValidationSupport extends Matchers { val QueryRoot = ObjectType( "QueryRoot", List[TestField]( + Field("foo", OptionType(StringType), resolve = _ => None), Field( "human", OptionType(Human), @@ -362,10 +365,14 @@ trait ValidationSupport extends Matchers { s: Schema[_, _], rules: List[ValidationRule], query: String, - expectedErrors: Seq[(String, Seq[Pos])]) = { + expectedErrors: Seq[(String, Seq[Pos])], + vars: (String, String) = "" -> "" + ) = { val Success(doc) = QueryParser.parse(query) - assertViolations(validator(rules).validateQuery(s, doc, None), expectedErrors: _*) + val variables = getVariableValues(s, vars) + + assertViolations(validator(rules).validateQuery(s, doc, variables, None), expectedErrors: _*) } def expectInputInvalid( @@ -376,12 +383,47 @@ trait ValidationSupport extends Matchers { typeName: String) = { val Success(doc) = QueryParser.parseInputDocumentWithVariables(query) - assertViolations(validator(rules).validateInputDocument(s, doc, typeName), expectedErrors: _*) + assertViolations( + validator(rules).validateInputDocument(s, doc, typeName, Map.empty), + expectedErrors: _*) + } + + private def getVariableValues(s: Schema[_, _], vars: (String, String)) = { + import spray.json._ + import sangria.marshalling.sprayJson._ + + val valueCollector = new ValueCollector( + s, + (if (vars._2.nonEmpty) vars._2 else "{}").parseJson, + None, + None, + (), + ExceptionHandler.empty, + None, + true) + + valueCollector + .getVariableValues( + QueryParser + .parse(s"query Foo${if (vars._1.nonEmpty) "(" + vars._1 + ")" else ""} {foo}") + .get + .operations(Some("Foo")) + .variables, + None) + .get } - def expectValid(s: Schema[_, _], rules: List[ValidationRule], query: String) = { + def expectValid( + s: Schema[_, _], + rules: List[ValidationRule], + query: String, + vars: (String, String) = "" -> "" + ) = { val Success(doc) = QueryParser.parse(query) - val errors = validator(rules).validateQuery(s, doc, None) + + val variables = getVariableValues(s, vars) + + val errors = validator(rules).validateQuery(s, doc, variables, None) withClue(renderViolations(errors)) { errors should have size 0 @@ -396,15 +438,18 @@ trait ValidationSupport extends Matchers { val Success(doc) = QueryParser.parseInputDocumentWithVariables(query) withClue("Should validate") { - validator(rules).validateInputDocument(s, doc, typeName) should have size 0 + validator(rules).validateInputDocument(s, doc, typeName, Map.empty) should have size 0 } } def expectPassesRule(rule: ValidationRule, query: String) = expectValid(schema, rule :: Nil, query) - def expectPasses(query: String) = - expectValid(schema, defaultRule.get :: Nil, query) + def expectPasses( + query: String, + vars: (String, String) = "" -> "" + ) = + expectValid(schema, defaultRule.get :: Nil, query, vars) def expectInputPasses(typeName: String, query: String) = expectValidInput(schema, defaultRule.get :: Nil, query, typeName) @@ -419,12 +464,18 @@ trait ValidationSupport extends Matchers { query, expectedErrors.map { case (msg, pos) => msg -> pos.toList }) - def expectFails(query: String, expectedErrors: List[(String, Option[Pos])]) = + def expectFails( + query: String, + expectedErrors: List[(String, Option[Pos])], + vars: (String, String) = "" -> "" + ) = expectInvalid( schema, defaultRule.get :: Nil, query, - expectedErrors.map { case (msg, pos) => msg -> pos.toList }) + expectedErrors.map { case (msg, pos) => msg -> pos.toList }, + vars + ) def expectInputFails(typeName: String, query: String, expectedErrors: List[(String, List[Pos])]) = expectInputInvalid(schema, defaultRule.get :: Nil, query, expectedErrors, typeName) @@ -446,7 +497,7 @@ trait ValidationSupport extends Matchers { violationCheck: Violation => Unit): Unit = { val schema = Schema.buildFromAst(initialSchemaDoc) val Success(docUnderTest) = QueryParser.parse(sdlUnderTest) - val violations = validator(v.toList).validateQuery(schema, docUnderTest, None) + val violations = validator(v.toList).validateQuery(schema, docUnderTest, Map.empty, None) violations shouldNot be(empty) violations.size shouldBe 1 violationCheck(violations.head) @@ -465,7 +516,7 @@ trait ValidationSupport extends Matchers { v: Option[ValidationRule]): Unit = { val schema = Schema.buildFromAst(initialSchemaDoc) val Success(docUnderTest) = QueryParser.parse(sdlUnderTest) - val violations = validator(v.toList).validateQuery(schema, docUnderTest, None) + val violations = validator(v.toList).validateQuery(schema, docUnderTest, Map.empty, None) violations shouldBe empty } diff --git a/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala b/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala index d64e81ef..7e48b59b 100644 --- a/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala +++ b/modules/core/src/test/scala/sangria/validation/QueryValidatorSpec.scala @@ -42,7 +42,7 @@ class QueryValidatorSpec extends AnyWordSpec { "not limit number of errors returned if the limit is not provided" in { val Success(doc) = QueryParser.parse(invalidQuery) - val result = validator.validateQuery(schema, doc, None) + val result = validator.validateQuery(schema, doc, Map.empty, None) // 10 errors are expected because there are 5 input objects in the list with 2 missing fields each assertResult(10)(result.length) @@ -51,7 +51,7 @@ class QueryValidatorSpec extends AnyWordSpec { val errorsLimit = 5 val Success(doc) = QueryParser.parse(invalidQuery) - val result = validator.validateQuery(schema, doc, Some(errorsLimit)) + val result = validator.validateQuery(schema, doc, Map.empty, Some(errorsLimit)) assertResult(errorsLimit)(result.length) } diff --git a/modules/core/src/test/scala/sangria/validation/rules/ExactlyOneOfFieldGivenSpec.scala b/modules/core/src/test/scala/sangria/validation/rules/ExactlyOneOfFieldGivenSpec.scala index cbff6a1c..7fb506a6 100644 --- a/modules/core/src/test/scala/sangria/validation/rules/ExactlyOneOfFieldGivenSpec.scala +++ b/modules/core/src/test/scala/sangria/validation/rules/ExactlyOneOfFieldGivenSpec.scala @@ -8,7 +8,7 @@ class ExactlyOneOfFieldGivenSpec extends AnyWordSpec with ValidationSupport { override val defaultRule = Some(new ExactlyOneOfFieldGiven) "Validate: exactly oneOf field given" should { - "with exactly one non-null field given" in expectPasses(""" + "pass with exactly one non-null field given" in expectPasses(""" query OneOfQuery { oneOfQuery(input: { catName: "Gretel" @@ -20,7 +20,7 @@ class ExactlyOneOfFieldGivenSpec extends AnyWordSpec with ValidationSupport { } """) - "with exactly one null field given" in expectFails( + "fail with exactly one null field given" in expectFails( """ query OneOfQuery { oneOfQuery(input: { @@ -35,7 +35,7 @@ class ExactlyOneOfFieldGivenSpec extends AnyWordSpec with ValidationSupport { List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 31))) ) - "with no fields given" in expectFails( + "fail with no fields given" in expectFails( """ query OneOfQuery { oneOfQuery(input: {}) { @@ -48,7 +48,7 @@ class ExactlyOneOfFieldGivenSpec extends AnyWordSpec with ValidationSupport { List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 31))) ) - "with more than one non-null fields given" in expectFails( + "fail with more than one non-null fields given" in expectFails( """ query OneOfQuery { oneOfQuery(input: { @@ -66,5 +66,149 @@ class ExactlyOneOfFieldGivenSpec extends AnyWordSpec with ValidationSupport { """, List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 31))) ) + + "pass with a null variable and non-null arg given" in expectPasses( + """ + query OneOfQuery($catName: String) { + oneOfQuery(input: { + catName: $catName, + dogId: 123 + }) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + "$catName: String" -> """{"catName": null}""" + ) + + "fail with a non-null variable and non-null arg given" in expectFails( + """ + query OneOfQuery($catName: String) { + oneOfQuery(input: { + catName: $catName, + dogId: 123 + }) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 27))), + "$catName: String" -> """{"catName": "Gretel"}""" + ) + + "pass with a variable object with only one non-null value" in expectPasses( + """ + query OneOfQuery($input: OneOfInput!) { + oneOfQuery(input: $input) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + "$input: OneOfInput!" -> """{"input":{"catName": "Gretel", "dogId": null}}""" + ) + + "fail with a variable object with only null values" in expectFails( + """ + query OneOfQuery($input: OneOfInput!) { + oneOfQuery(input: $input) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 27))), + "$input: OneOfInput!" -> """{"input":{"catName": null}}""" + ) + + "fail with a variable object with more than one non-null values" in expectFails( + """ + query OneOfQuery($input: OneOfInput!) { + oneOfQuery(input: $input) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 27))), + "$input: OneOfInput!" -> """{"input":{"catName": "Gretel", "dogId": 123}}""" + ) + + "pass with a variable object with exactly one non-null values" in expectPasses( + """ + query OneOfQuery($input: OneOfInput!) { + oneOfQuery(input: $input) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + "$input: OneOfInput!" -> """{"input":{"dogId": 123}}""" + ) + + "pass with a variables with default value but only one resolved to non-null" in expectPasses( + """ + query OneOfQuery($catName: String = "Gretel", $dogId: Int) { + oneOfQuery(input: { + catName: $catName, + dogId: $dogId + }) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + """$catName: String = "Gretel", $dogId: Int""" -> """{}""" + ) + + "fail with a variable object with default value that causes there to be more than one non-null value" in expectFails( + """ + query OneOfQuery($catName: String = "Gretel", $dogId: Int) { + oneOfQuery(input: { + catName: $catName, + dogId: $dogId + }) { + ... on Cat { + name + } + ... on Dog { + name + } + } + } + """, + List("Exactly one key must be specified for oneOf type 'OneOfInput'." -> Some(Pos(3, 27))), + """$catName: String = "Gretel", $dogId: Int""" -> """{"dogId":123}""" + ) } }