diff --git a/testware/integration-tests/src/test/java/org/opencypher/gremlin/queries/FunctionTest.java b/testware/integration-tests/src/test/java/org/opencypher/gremlin/queries/FunctionTest.java index ac8df281..5ed7914d 100644 --- a/testware/integration-tests/src/test/java/org/opencypher/gremlin/queries/FunctionTest.java +++ b/testware/integration-tests/src/test/java/org/opencypher/gremlin/queries/FunctionTest.java @@ -18,6 +18,7 @@ import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; +import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -298,6 +299,15 @@ public void sizeOfPatternExpressionInWhere() { assertThat(results) .extracting("n.name") .containsExactly("marko", "josh"); + } + + @Test + public void propertiesOnNode() { + List> results = submitAndGet( + "MATCH (n {name:'marko'}) RETURN properties(n) as r"); + assertThat(results) + .extracting("r") + .containsExactly(ImmutableMap.of("age", 29L, "name", "marko")); } } diff --git a/testware/integration-tests/src/test/java/org/opencypher/gremlin/queries/ReturnTest.java b/testware/integration-tests/src/test/java/org/opencypher/gremlin/queries/ReturnTest.java index 2267cf0b..1cac16c8 100644 --- a/testware/integration-tests/src/test/java/org/opencypher/gremlin/queries/ReturnTest.java +++ b/testware/integration-tests/src/test/java/org/opencypher/gremlin/queries/ReturnTest.java @@ -503,4 +503,15 @@ public void plusTest() throws Exception { .containsExactly(result); } } + + @Test + public void optionalProjection() throws Exception { + String cypher = "OPTIONAL MATCH (n:notExisting) WITH (n) as m RETURN m"; + + List> cypherResults = submitAndGet(cypher); + + assertThat(cypherResults) + .extracting("m") + .containsExactly((Object) null); + } } diff --git a/tinkerpop/cypher-gremlin-extensions/src/main/java/org/opencypher/gremlin/traversal/CustomFunctions.java b/tinkerpop/cypher-gremlin-extensions/src/main/java/org/opencypher/gremlin/traversal/CustomFunctions.java index 17225a8c..1f329e17 100644 --- a/tinkerpop/cypher-gremlin-extensions/src/main/java/org/opencypher/gremlin/traversal/CustomFunctions.java +++ b/tinkerpop/cypher-gremlin-extensions/src/main/java/org/opencypher/gremlin/traversal/CustomFunctions.java @@ -142,6 +142,11 @@ public static Function cypherToFloat() { public static Function cypherProperties() { return traverser -> { Object argument = traverser.get(); + + if (argument == Tokens.NULL) { + return Tokens.NULL; + } + if (argument instanceof Map) { return argument; } diff --git a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/TranslationWriter.scala b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/TranslationWriter.scala index 2b23aa1f..9415cec9 100644 --- a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/TranslationWriter.scala +++ b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/TranslationWriter.scala @@ -96,18 +96,18 @@ sealed class TranslationWriter[T, P] private (translator: Translator[T, P], para .getOrElse(g.by(writeLocalSteps(traversal))) case Cap(sideEffectKey) => g.cap(sideEffectKey) - case ChooseC(choiceTraversal) => + case ChooseT(choiceTraversal, None, None) => g.choose(writeLocalSteps(choiceTraversal)) - case ChooseT(traversalPredicate, trueChoice, falseChoice) => - if (trueChoice.nonEmpty && falseChoice.nonEmpty) { - g.choose(writeLocalSteps(traversalPredicate), writeLocalSteps(trueChoice), writeLocalSteps(falseChoice)) - } - case ChooseP(predicate, trueChoice, falseChoice) => - if (trueChoice.nonEmpty && falseChoice.nonEmpty) { - g.choose(writePredicate(predicate), writeLocalSteps(trueChoice), writeLocalSteps(falseChoice)) - } else if (trueChoice.nonEmpty) { - g.choose(writePredicate(predicate), writeLocalSteps(trueChoice)) - } + case c @ ChooseT(_, None, Some(_)) => + throw new UnsupportedOperationException(s"Unsupported $c") + case c @ ChooseT(_, Some(_), None) => + throw new UnsupportedOperationException(s"Unsupported $c") + case ChooseT(traversalPredicate, Some(trueChoice), Some(falseChoice)) => + g.choose(writeLocalSteps(traversalPredicate), writeLocalSteps(trueChoice), writeLocalSteps(falseChoice)) + case ChooseP(predicate, trueChoice, None) => + g.choose(writePredicate(predicate), writeLocalSteps(trueChoice)) + case ChooseP(predicate, trueChoice, Some(falseChoice)) => + g.choose(writePredicate(predicate), writeLocalSteps(trueChoice), writeLocalSteps(falseChoice)) case Coalesce(coalesceTraversals @ _*) => g.coalesce(coalesceTraversals.map(writeLocalSteps): _*) case Constant(e) => diff --git a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/builder/IRGremlinSteps.scala b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/builder/IRGremlinSteps.scala index 17fe3037..d19a09a3 100644 --- a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/builder/IRGremlinSteps.scala +++ b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/builder/IRGremlinSteps.scala @@ -103,7 +103,7 @@ class IRGremlinSteps extends GremlinSteps[Seq[GremlinStep], GremlinPredicate] { override def choose(choiceTraversal: GremlinSteps[Seq[GremlinStep], GremlinPredicate]) : GremlinSteps[Seq[GremlinStep], GremlinPredicate] = { - buf += ChooseC(choiceTraversal.current()) + buf += ChooseT(choiceTraversal.current()) this } @@ -112,7 +112,7 @@ class IRGremlinSteps extends GremlinSteps[Seq[GremlinStep], GremlinPredicate] { trueChoice: GremlinSteps[Seq[GremlinStep], GremlinPredicate], falseChoice: GremlinSteps[Seq[GremlinStep], GremlinPredicate]) : GremlinSteps[Seq[GremlinStep], GremlinPredicate] = { - buf += ChooseT(traversalPredicate.current(), trueChoice.current(), falseChoice.current()) + buf += ChooseT(traversalPredicate.current(), Some(trueChoice.current()), Some(falseChoice.current())) this } @@ -121,13 +121,13 @@ class IRGremlinSteps extends GremlinSteps[Seq[GremlinStep], GremlinPredicate] { trueChoice: GremlinSteps[Seq[GremlinStep], GremlinPredicate], falseChoice: GremlinSteps[Seq[GremlinStep], GremlinPredicate]) : GremlinSteps[Seq[GremlinStep], GremlinPredicate] = { - buf += ChooseP(predicate, trueChoice.current(), falseChoice.current()) + buf += ChooseP(predicate, trueChoice.current(), Some(falseChoice.current())) this } override def choose(predicate: GremlinPredicate, trueChoice: GremlinSteps[Seq[GremlinStep], GremlinPredicate]) : GremlinSteps[Seq[GremlinStep], GremlinPredicate] = { - buf += ChooseP(predicate, trueChoice.current(), Nil) + buf += ChooseP(predicate, trueChoice.current()) this } diff --git a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/model/GremlinStep.scala b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/model/GremlinStep.scala index fc0dd1a7..2deb6484 100644 --- a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/model/GremlinStep.scala +++ b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/model/GremlinStep.scala @@ -69,7 +69,7 @@ case object Barrier extends GremlinStep case class BothE(edgeLabels: String*) extends GremlinStep -case class By(traversal: Seq[GremlinStep], order: Option[TraversalOrder]) extends GremlinStep { +case class By(traversal: Seq[GremlinStep], order: Option[TraversalOrder] = None) extends GremlinStep { override def mapTraversals(f: Seq[GremlinStep] => Seq[GremlinStep]): GremlinStep = { By(f(traversal), order) } @@ -81,38 +81,36 @@ case class By(traversal: Seq[GremlinStep], order: Option[TraversalOrder]) extend case class Cap(sideEffectKey: String) extends GremlinStep -case class ChooseC(choiceTraversal: Seq[GremlinStep]) extends GremlinStep { - override def mapTraversals(f: Seq[GremlinStep] => Seq[GremlinStep]): GremlinStep = { - ChooseC(f(choiceTraversal)) - } - - override def foldTraversals[R](z: R)(op: (R, Seq[GremlinStep]) => R): R = { - op(z, choiceTraversal) - } -} - case class ChooseT( traversalPredicate: Seq[GremlinStep], - trueChoice: Seq[GremlinStep], - falseChoice: Seq[GremlinStep] = Nil) + trueChoice: Option[Seq[GremlinStep]] = None, + falseChoice: Option[Seq[GremlinStep]] = None) extends GremlinStep { + override def mapTraversals(f: Seq[GremlinStep] => Seq[GremlinStep]): GremlinStep = { - ChooseT(f(traversalPredicate), f(trueChoice), f(falseChoice)) + ChooseT(f(traversalPredicate), trueChoice.map(f(_)), falseChoice.map(f(_))) } override def foldTraversals[R](z: R)(op: (R, Seq[GremlinStep]) => R): R = { - op(op(op(z, traversalPredicate), trueChoice), falseChoice) + val predicateFold = op(z, traversalPredicate) + val trueChoiceFold = trueChoice.map(op(predicateFold, _)).getOrElse(predicateFold) + falseChoice.map(op(trueChoiceFold, _)).getOrElse(trueChoiceFold) } } -case class ChooseP(predicate: GremlinPredicate, trueChoice: Seq[GremlinStep], falseChoice: Seq[GremlinStep] = Nil) +case class ChooseP( + predicate: GremlinPredicate, + trueChoice: Seq[GremlinStep], + falseChoice: Option[Seq[GremlinStep]] = None) extends GremlinStep { + override def mapTraversals(f: Seq[GremlinStep] => Seq[GremlinStep]): GremlinStep = { - ChooseP(predicate, f(trueChoice), f(falseChoice)) + ChooseP(predicate, f(trueChoice), falseChoice.map(f(_))) } override def foldTraversals[R](z: R)(op: (R, Seq[GremlinStep]) => R): R = { - op(op(z, trueChoice), falseChoice) + val trueChoiceFold = op(z, trueChoice) + falseChoice.map(op(trueChoiceFold, _)).getOrElse(trueChoiceFold) } } diff --git a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/CustomFunctionFallback.scala b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/CustomFunctionFallback.scala index a644aed2..d6a177b6 100644 --- a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/CustomFunctionFallback.scala +++ b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/CustomFunctionFallback.scala @@ -36,7 +36,7 @@ object CustomFunctionFallback extends GremlinRewriter { Path :: From(text) :: rest case SelectC(values) :: MapF(function) :: rest if function.getName == cypherPlus().getName => - SelectC(values) :: Local(Unfold :: ChooseP(Neq(NULL), Sum :: Nil, Constant(NULL) :: Nil) :: Nil) :: rest + SelectC(values) :: Local(Unfold :: ChooseP(Neq(NULL), Sum :: Nil, None) :: Nil) :: rest case MapF(function) :: rest if function.getName == cypherSize().getName => CountS(local) :: rest diff --git a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/RemoveUselessNullChecks.scala b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/RemoveUselessNullChecks.scala index 3f7779d5..d88e19d5 100644 --- a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/RemoveUselessNullChecks.scala +++ b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/RemoveUselessNullChecks.scala @@ -37,10 +37,10 @@ object RemoveUselessNullChecks extends GremlinRewriter { private def splitSegment(steps: Seq[GremlinStep]): (Seq[GremlinStep], Seq[GremlinStep]) = { val (segment, rest) = steps.span { - case By(SelectK(_) :: ChooseP(Neq(NULL), _, Constant(NULL) :: Nil) :: Nil, None) => false - case By(ChooseP(Neq(NULL), _, Constant(NULL) :: Nil) :: Nil, None) => false - case ChooseP(Neq(NULL), _, Constant(NULL) :: Nil) => false - case _ => true + case By(SelectK(_) :: ChooseP(Neq(NULL), _, None) :: Nil, None) => false + case By(ChooseP(Neq(NULL), _, None) :: Nil, None) => false + case ChooseP(Neq(NULL), _, None) => false + case _ => true } rest match { case head :: tail => (segment :+ head, tail) @@ -62,11 +62,11 @@ object RemoveUselessNullChecks extends GremlinRewriter { } val last = steps.last match { - case By(SelectK(key) :: ChooseP(Neq(NULL), t, Constant(NULL) :: Nil) :: Nil, None) => - By(SelectK(key) +: t, None) :: Nil - case By(ChooseP(Neq(NULL), t, Constant(NULL) :: Nil) :: Nil, None) => - By(t, None) :: Nil - case ChooseP(Neq(NULL), traversal, Constant(NULL) :: Nil) => + case By(SelectK(key) :: ChooseP(Neq(NULL), traversal, None) :: Nil, None) => + By(SelectK(key) +: traversal, None) :: Nil + case By(ChooseP(Neq(NULL), traversal, None) :: Nil, None) => + By(traversal, None) :: Nil + case ChooseP(Neq(NULL), traversal, None) => traversal case step => step :: Nil diff --git a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/SimplifyDelete.scala b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/SimplifyDelete.scala index 1f56447e..3e54595c 100644 --- a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/SimplifyDelete.scala +++ b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/SimplifyDelete.scala @@ -26,7 +26,7 @@ object SimplifyDelete extends GremlinRewriter { def apply(steps: Seq[GremlinStep]): Seq[GremlinStep] = { val withChoose = foldTraversals(false)((acc, localSteps) => { acc || extract({ - case ChooseP(IsNode(), Aggregate(DELETE) :: Nil, Aggregate(DETACH_DELETE) :: Nil) :: _ => true + case ChooseP(IsNode(), Aggregate(DELETE) :: Nil, Some(Aggregate(DETACH_DELETE) :: Nil)) :: _ => true })(localSteps).contains(true) })(steps) diff --git a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/SimplifyPropertySetters.scala b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/SimplifyPropertySetters.scala index 69251c1b..a4aa9efc 100644 --- a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/SimplifyPropertySetters.scala +++ b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/SimplifyPropertySetters.scala @@ -33,7 +33,7 @@ object SimplifyPropertySetters extends GremlinRewriter { PropertyV(key, value) :: rest case PropertyTC(cardinality, key, Constant(value) :: Nil) :: rest => PropertyVC(cardinality, key, value) :: rest - case ChooseT(_, PropertyV(key, value) :: Nil, drop) :: rest => + case ChooseT(_, Some(PropertyV(key, value) :: Nil), Some(drop)) :: rest => val empty = value match { case NULL => true case coll: util.Collection[_] => coll.isEmpty @@ -44,7 +44,7 @@ object SimplifyPropertySetters extends GremlinRewriter { } else { PropertyV(key, value) :: rest } - case step @ ChooseT(_, prop @ PropertyT(_, Project(_*) :: valueTail) :: Nil, _) :: rest + case step @ ChooseT(_, Some(prop @ PropertyT(_, Project(_*) :: valueTail) :: Nil), _) :: rest if valueTail.init.forall(_.isInstanceOf[By]) => valueTail.last match { case _: By | _: SelectC => prop ++ rest diff --git a/translation/src/main/scala/org/opencypher/gremlin/translation/walker/ExpressionWalker.scala b/translation/src/main/scala/org/opencypher/gremlin/translation/walker/ExpressionWalker.scala index 412feb6c..548bb6c6 100644 --- a/translation/src/main/scala/org/opencypher/gremlin/translation/walker/ExpressionWalker.scala +++ b/translation/src/main/scala/org/opencypher/gremlin/translation/walker/ExpressionWalker.scala @@ -219,20 +219,20 @@ private class ExpressionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte case "abs" => traversals.head.math("abs(_)") case "coalesce" => __.coalesce(traversals.init.map(_.is(p.neq(NULL))) :+ traversals.last: _*) case "exists" => traversals.head.flatMap(anyMatch(__.is(p.neq(NULL)))) - case "head" => traversals.head.flatMap(notNull(emptyToNull(__.limit(Scope.local, 1), context), context)) + case "head" => traversals.head.flatMap(emptyToNull(__.limit(Scope.local, 1), context)) case "id" => traversals.head.flatMap(notNull(__.id(), context)) case "keys" if onEntity => traversals.head.properties().key().fold() case "keys" => traversals.head.select(Column.keys) case "labels" => traversals.head.label().is(p.neq(Vertex.DEFAULT_LABEL)).fold() case "length" => traversals.head.count(Scope.local).math("(_-1)/2") - case "last" => traversals.head.flatMap(notNull(emptyToNull(__.tail(Scope.local, 1), context), context)) + case "last" => traversals.head.flatMap(emptyToNull(__.tail(Scope.local, 1), context)) case "nodes" => traversals.head.flatMap(filterElements(args, includeNodes = true)) case "properties" => traversals.head.flatMap(properties(args)) case "range" => range(args) case "relationships" => traversals.head.flatMap(filterElements(args, includeRelationships = true)) case "size" => traversals.head.flatMap(size(args)) case "sqrt" => traversals.head.math("sqrt(_)") - case "tail" => traversals.head.flatMap(notNull(__.range(Scope.local, 1, -1), context)) + case "tail" => traversals.head.flatMap(__.range(Scope.local, 1, -1)) case "type" => traversals.head.flatMap(notNull(__.label().is(p.neq(Vertex.DEFAULT_LABEL)), context)) case "toboolean" => traversals.head.map(CustomFunction.cypherToBoolean()) case "tofloat" => traversals.head.map(CustomFunction.cypherToFloat()) @@ -440,21 +440,23 @@ private class ExpressionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte } private def properties(args: Seq[Expression]): GremlinSteps[T, P] = { - lazy val elementT = __.local( - __.properties() - .group() - .by(__.key()) - .by(__.map(__.value())) - ) + lazy val elementT = + notNull( + __.local( + __.properties() + .group() + .by(__.key()) + .by(__.map(__.value())) + ), + context) val typ = typeOf(args.head) - val traversal = typ match { + typ match { case _: NodeType => elementT case _: RelationshipType => elementT case _: MapType => __.identity() case _ => __.map(CustomFunction.cypherProperties()) } - notNull(traversal, context) } private val injectHardLimit = 10000 diff --git a/translation/src/main/scala/org/opencypher/gremlin/translation/walker/NodeUtils.scala b/translation/src/main/scala/org/opencypher/gremlin/translation/walker/NodeUtils.scala index e566faa6..003824c2 100644 --- a/translation/src/main/scala/org/opencypher/gremlin/translation/walker/NodeUtils.scala +++ b/translation/src/main/scala/org/opencypher/gremlin/translation/walker/NodeUtils.scala @@ -125,7 +125,7 @@ object NodeUtils { def notNull[T, P](traversal: GremlinSteps[T, P], context: WalkerContext[T, P]): GremlinSteps[T, P] = { val g = context.dsl.steps() val p = context.dsl.predicates() - g.start().choose(p.neq(NULL), traversal, g.start().constant(NULL)) + g.start().choose(p.neq(NULL), traversal) } def emptyToNull[T, P](traversal: GremlinSteps[T, P], context: WalkerContext[T, P]): GremlinSteps[T, P] = { diff --git a/translation/src/main/scala/org/opencypher/gremlin/translation/walker/ProjectionWalker.scala b/translation/src/main/scala/org/opencypher/gremlin/translation/walker/ProjectionWalker.scala index e0c4c6f4..8b525af1 100644 --- a/translation/src/main/scala/org/opencypher/gremlin/translation/walker/ProjectionWalker.scala +++ b/translation/src/main/scala/org/opencypher/gremlin/translation/walker/ProjectionWalker.scala @@ -269,11 +269,6 @@ private class ProjectionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte dependencyNames.distinct } - private def nullIfNull(g: GremlinSteps[T, P], trueChoice: GremlinSteps[T, P]): GremlinSteps[T, P] = { - val p = context.dsl.predicates() - g.choose(p.neq(NULL), trueChoice, g.start().constant(NULL)) - } - private def subTraversal(alias: String, expression: Expression): (ReturnFunctionType, GremlinSteps[T, P]) = { if (expression.containsAggregate) { aggregation(alias, expression) @@ -315,12 +310,13 @@ private class ProjectionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte .valueMap(true) .fold()) + def notNull(traversal: GremlinSteps[T, P]): GremlinSteps[T, P] = { + NodeUtils.notNull(traversal, context) + } + qualifiedType(expression) match { case (_: NodeType, _) => - nullIfNull( - subTraversal, - finalizeNode - ) + subTraversal.flatMap(notNull(finalizeNode)) case (_: ListType, _: NodeType) => __.flatMap(subTraversal) .unfold() @@ -328,28 +324,20 @@ private class ProjectionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte .flatMap(finalizeNode) .fold() case (_: RelationshipType, _) => - nullIfNull( - subTraversal, - finalizeRelationship - ) + subTraversal.flatMap(notNull(finalizeRelationship)) case (_: ListType, _: RelationshipType) => - nullIfNull( - subTraversal, - __.unfold() - .flatMap(finalizeRelationship) - .fold() - ) + subTraversal.flatMap( + notNull( + __.unfold() + .flatMap(finalizeRelationship) + .fold())) case (_: PathType, _) => - nullIfNull( - subTraversal, - finalizePath - ) + subTraversal.flatMap(notNull(finalizePath)) case (_: ListType, _: PathType) => - nullIfNull( - subTraversal, - __.flatMap(finalizePath) - .fold() - ) + subTraversal.flatMap( + notNull( + __.flatMap(finalizePath) + .fold())) case _ => subTraversal } diff --git a/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/CustomFunctionsFallbackTest.scala b/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/CustomFunctionsFallbackTest.scala index dee2b718..b058b531 100644 --- a/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/CustomFunctionsFallbackTest.scala +++ b/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/CustomFunctionsFallbackTest.scala @@ -59,7 +59,7 @@ class CustomFunctionsFallbackTest { .removes(__.select(Column.values).map(CustomFunction.cypherPlus())) .adds( __.select(Column.values) - .local(__.unfold().choose(P.neq(Tokens.NULL), __.sum(), __.start().constant(Tokens.NULL)))) + .local(__.unfold().choose(P.neq(Tokens.NULL), __.sum()))) } @Test diff --git a/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/RemoveUselessNullChecksTest.scala b/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/RemoveUselessNullChecksTest.scala index fb58162c..53ff3ddb 100644 --- a/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/RemoveUselessNullChecksTest.scala +++ b/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/RemoveUselessNullChecksTest.scala @@ -41,7 +41,7 @@ class RemoveUselessNullChecksTest { """.stripMargin)) .withFlavor(flavor) .rewritingWith(RemoveUselessNullChecks) - .removes(__.by(__.choose(P.neq(NULL), __.valueMap(true), __.constant(NULL)))) + .removes(__.by(__.choose(P.neq(NULL), __.valueMap(true)))) .adds(__.by(__.valueMap(true))) } @@ -53,9 +53,108 @@ class RemoveUselessNullChecksTest { """.stripMargin)) .withFlavor(flavor) .rewritingWith(RemoveUselessNullChecks) - .removes(__.by(__.select("n").choose(P.neq(NULL), __.valueMap(true), __.constant(NULL)))) - .removes(__.by(__.select("m").choose(P.neq(NULL), __.valueMap(true), __.constant(NULL)))) + .removes(__.by(__.select("n").choose(P.neq(NULL), __.valueMap(true)))) + .removes(__.by(__.select("m").choose(P.neq(NULL), __.valueMap(true)))) .adds(__.by(__.select("n").valueMap(true))) .adds(__.by(__.select("m").valueMap(true))) } + + @Test + def singleOptionalProjection(): Unit = { + assertThat(parse(""" + |OPTIONAL MATCH (n) + |RETURN n + """.stripMargin)) + .withFlavor(flavor) + .rewritingWith(RemoveUselessNullChecks) + .keeps(__.choose(P.neq(NULL), __.valueMap(true))) + } + + @Test + def functionInvocation(): Unit = { + assertThat(parse(""" + MATCH (n:notExising) WITH n AS n RETURN head(collect(n)) AS head + """.stripMargin)) + .withFlavor(flavor) + .rewritingWith(RemoveUselessNullChecks) + .keeps(__.choose(P.neq(NULL), __.valueMap(true))) + } + + @Test + def multipleOptionalProjections(): Unit = { + assertThat(parse(""" + |OPTIONAL MATCH (n)-->(m) + |RETURN n, m + """.stripMargin)) + .withFlavor(flavor) + .keeps(__.select("n").choose(P.neq(NULL), __.valueMap(true))) + .keeps(__.select("m").choose(P.neq(NULL), __.valueMap(true))) + } + + @Test + def optionalWithProjection(): Unit = { + assertThat(parse(""" + |OPTIONAL MATCH (n:notExisting) WITH (n) as m RETURN m + """.stripMargin)) + .withFlavor(flavor) + .contains(__.choose(P.neq(NULL), __.valueMap(true))) + } + + @Test + def create(): Unit = { + assertThat(parse(""" + |CREATE (n)-[r:knows]->(m) RETURN n, r, m + """.stripMargin)) + .withFlavor(flavor) + .rewritingWith(RemoveUselessNullChecks) + .removes(__.by(__.select("n").choose(P.neq(NULL), __.valueMap(true)))) + .removes( + __.by( + __.select("r") + .choose( + P.neq(" cypher.null"), + __.project(" cypher.element", " cypher.inv", " cypher.outv") + .by(__.valueMap(true)) + .by(__.inV().id()) + .by(__.outV().id())))) + .removes(__.by(__.select("m").choose(P.neq(NULL), __.valueMap(true)))) + .adds(__.by(__.select("n").valueMap(true))) + .adds( + __.by( + __.select("r") + .project(" cypher.element", " cypher.inv", " cypher.outv") + .by(__.valueMap(true)) + .by(__.inV().id()) + .by(__.outV().id()))) + .adds(__.by(__.select("m").valueMap(true))) + } + + @Test + def merge(): Unit = { + assertThat(parse(""" + |MERGE (n)-[r:knows]->(m) RETURN n, r, m + """.stripMargin)) + .withFlavor(flavor) + .rewritingWith(RemoveUselessNullChecks) + .removes(__.by(__.select("n").choose(P.neq(NULL), __.valueMap(true)))) + .removes( + __.by( + __.select("r") + .choose( + P.neq(" cypher.null"), + __.project(" cypher.element", " cypher.inv", " cypher.outv") + .by(__.valueMap(true)) + .by(__.inV().id()) + .by(__.outV().id())))) + .removes(__.by(__.select("m").choose(P.neq(NULL), __.valueMap(true)))) + .adds(__.by(__.select("n").valueMap(true))) + .adds( + __.by( + __.select("r") + .project(" cypher.element", " cypher.inv", " cypher.outv") + .by(__.valueMap(true)) + .by(__.inV().id()) + .by(__.outV().id()))) + .adds(__.by(__.select("m").valueMap(true))) + } }