Skip to content

Commit

Permalink
Simplify null guards (#186)
Browse files Browse the repository at this point in the history
Simplify null guards

- Omit unnecessary `_.constant(NULL)` in null guards:
  - `choose(p.neq(NULL), traversal, constant(NULL))` → `choose(p.neq(NULL), traversal)`
- Use `NodeUtils#notNull` in projection walker
- Skip null checks in gremlin steps that handle `  cypher.null`: `limit`, `tail`, `range`, `identity`
- More tests for nulls

Signed-off-by: Dwitry dwitry@users.noreply.github.com
  • Loading branch information
dwitry authored Oct 4, 2018
1 parent d870ab8 commit 8b5e5f3
Show file tree
Hide file tree
Showing 15 changed files with 203 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;

import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -298,6 +299,15 @@ public void sizeOfPatternExpressionInWhere() {
assertThat(results)
.extracting("n.name")
.containsExactly("marko", "josh");
}

@Test
public void propertiesOnNode() {
List<Map<String, Object>> results = submitAndGet(
"MATCH (n {name:'marko'}) RETURN properties(n) as r");

assertThat(results)
.extracting("r")
.containsExactly(ImmutableMap.of("age", 29L, "name", "marko"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -503,4 +503,15 @@ public void plusTest() throws Exception {
.containsExactly(result);
}
}

@Test
public void optionalProjection() throws Exception {
String cypher = "OPTIONAL MATCH (n:notExisting) WITH (n) as m RETURN m";

List<Map<String, Object>> cypherResults = submitAndGet(cypher);

assertThat(cypherResults)
.extracting("m")
.containsExactly((Object) null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ public static Function<Traverser, Object> cypherToFloat() {
public static Function<Traverser, Object> cypherProperties() {
return traverser -> {
Object argument = traverser.get();

if (argument == Tokens.NULL) {
return Tokens.NULL;
}

if (argument instanceof Map) {
return argument;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,18 @@ sealed class TranslationWriter[T, P] private (translator: Translator[T, P], para
.getOrElse(g.by(writeLocalSteps(traversal)))
case Cap(sideEffectKey) =>
g.cap(sideEffectKey)
case ChooseC(choiceTraversal) =>
case ChooseT(choiceTraversal, None, None) =>
g.choose(writeLocalSteps(choiceTraversal))
case ChooseT(traversalPredicate, trueChoice, falseChoice) =>
if (trueChoice.nonEmpty && falseChoice.nonEmpty) {
g.choose(writeLocalSteps(traversalPredicate), writeLocalSteps(trueChoice), writeLocalSteps(falseChoice))
}
case ChooseP(predicate, trueChoice, falseChoice) =>
if (trueChoice.nonEmpty && falseChoice.nonEmpty) {
g.choose(writePredicate(predicate), writeLocalSteps(trueChoice), writeLocalSteps(falseChoice))
} else if (trueChoice.nonEmpty) {
g.choose(writePredicate(predicate), writeLocalSteps(trueChoice))
}
case c @ ChooseT(_, None, Some(_)) =>
throw new UnsupportedOperationException(s"Unsupported $c")
case c @ ChooseT(_, Some(_), None) =>
throw new UnsupportedOperationException(s"Unsupported $c")
case ChooseT(traversalPredicate, Some(trueChoice), Some(falseChoice)) =>
g.choose(writeLocalSteps(traversalPredicate), writeLocalSteps(trueChoice), writeLocalSteps(falseChoice))
case ChooseP(predicate, trueChoice, None) =>
g.choose(writePredicate(predicate), writeLocalSteps(trueChoice))
case ChooseP(predicate, trueChoice, Some(falseChoice)) =>
g.choose(writePredicate(predicate), writeLocalSteps(trueChoice), writeLocalSteps(falseChoice))
case Coalesce(coalesceTraversals @ _*) =>
g.coalesce(coalesceTraversals.map(writeLocalSteps): _*)
case Constant(e) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class IRGremlinSteps extends GremlinSteps[Seq[GremlinStep], GremlinPredicate] {

override def choose(choiceTraversal: GremlinSteps[Seq[GremlinStep], GremlinPredicate])
: GremlinSteps[Seq[GremlinStep], GremlinPredicate] = {
buf += ChooseC(choiceTraversal.current())
buf += ChooseT(choiceTraversal.current())
this
}

Expand All @@ -112,7 +112,7 @@ class IRGremlinSteps extends GremlinSteps[Seq[GremlinStep], GremlinPredicate] {
trueChoice: GremlinSteps[Seq[GremlinStep], GremlinPredicate],
falseChoice: GremlinSteps[Seq[GremlinStep], GremlinPredicate])
: GremlinSteps[Seq[GremlinStep], GremlinPredicate] = {
buf += ChooseT(traversalPredicate.current(), trueChoice.current(), falseChoice.current())
buf += ChooseT(traversalPredicate.current(), Some(trueChoice.current()), Some(falseChoice.current()))
this
}

Expand All @@ -121,13 +121,13 @@ class IRGremlinSteps extends GremlinSteps[Seq[GremlinStep], GremlinPredicate] {
trueChoice: GremlinSteps[Seq[GremlinStep], GremlinPredicate],
falseChoice: GremlinSteps[Seq[GremlinStep], GremlinPredicate])
: GremlinSteps[Seq[GremlinStep], GremlinPredicate] = {
buf += ChooseP(predicate, trueChoice.current(), falseChoice.current())
buf += ChooseP(predicate, trueChoice.current(), Some(falseChoice.current()))
this
}

override def choose(predicate: GremlinPredicate, trueChoice: GremlinSteps[Seq[GremlinStep], GremlinPredicate])
: GremlinSteps[Seq[GremlinStep], GremlinPredicate] = {
buf += ChooseP(predicate, trueChoice.current(), Nil)
buf += ChooseP(predicate, trueChoice.current())
this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ case object Barrier extends GremlinStep

case class BothE(edgeLabels: String*) extends GremlinStep

case class By(traversal: Seq[GremlinStep], order: Option[TraversalOrder]) extends GremlinStep {
case class By(traversal: Seq[GremlinStep], order: Option[TraversalOrder] = None) extends GremlinStep {
override def mapTraversals(f: Seq[GremlinStep] => Seq[GremlinStep]): GremlinStep = {
By(f(traversal), order)
}
Expand All @@ -81,38 +81,36 @@ case class By(traversal: Seq[GremlinStep], order: Option[TraversalOrder]) extend

case class Cap(sideEffectKey: String) extends GremlinStep

case class ChooseC(choiceTraversal: Seq[GremlinStep]) extends GremlinStep {
override def mapTraversals(f: Seq[GremlinStep] => Seq[GremlinStep]): GremlinStep = {
ChooseC(f(choiceTraversal))
}

override def foldTraversals[R](z: R)(op: (R, Seq[GremlinStep]) => R): R = {
op(z, choiceTraversal)
}
}

case class ChooseT(
traversalPredicate: Seq[GremlinStep],
trueChoice: Seq[GremlinStep],
falseChoice: Seq[GremlinStep] = Nil)
trueChoice: Option[Seq[GremlinStep]] = None,
falseChoice: Option[Seq[GremlinStep]] = None)
extends GremlinStep {

override def mapTraversals(f: Seq[GremlinStep] => Seq[GremlinStep]): GremlinStep = {
ChooseT(f(traversalPredicate), f(trueChoice), f(falseChoice))
ChooseT(f(traversalPredicate), trueChoice.map(f(_)), falseChoice.map(f(_)))
}

override def foldTraversals[R](z: R)(op: (R, Seq[GremlinStep]) => R): R = {
op(op(op(z, traversalPredicate), trueChoice), falseChoice)
val predicateFold = op(z, traversalPredicate)
val trueChoiceFold = trueChoice.map(op(predicateFold, _)).getOrElse(predicateFold)
falseChoice.map(op(trueChoiceFold, _)).getOrElse(trueChoiceFold)
}
}

case class ChooseP(predicate: GremlinPredicate, trueChoice: Seq[GremlinStep], falseChoice: Seq[GremlinStep] = Nil)
case class ChooseP(
predicate: GremlinPredicate,
trueChoice: Seq[GremlinStep],
falseChoice: Option[Seq[GremlinStep]] = None)
extends GremlinStep {

override def mapTraversals(f: Seq[GremlinStep] => Seq[GremlinStep]): GremlinStep = {
ChooseP(predicate, f(trueChoice), f(falseChoice))
ChooseP(predicate, f(trueChoice), falseChoice.map(f(_)))
}

override def foldTraversals[R](z: R)(op: (R, Seq[GremlinStep]) => R): R = {
op(op(z, trueChoice), falseChoice)
val trueChoiceFold = op(z, trueChoice)
falseChoice.map(op(trueChoiceFold, _)).getOrElse(trueChoiceFold)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ object CustomFunctionFallback extends GremlinRewriter {
Path :: From(text) :: rest

case SelectC(values) :: MapF(function) :: rest if function.getName == cypherPlus().getName =>
SelectC(values) :: Local(Unfold :: ChooseP(Neq(NULL), Sum :: Nil, Constant(NULL) :: Nil) :: Nil) :: rest
SelectC(values) :: Local(Unfold :: ChooseP(Neq(NULL), Sum :: Nil, None) :: Nil) :: rest

case MapF(function) :: rest if function.getName == cypherSize().getName =>
CountS(local) :: rest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ object RemoveUselessNullChecks extends GremlinRewriter {

private def splitSegment(steps: Seq[GremlinStep]): (Seq[GremlinStep], Seq[GremlinStep]) = {
val (segment, rest) = steps.span {
case By(SelectK(_) :: ChooseP(Neq(NULL), _, Constant(NULL) :: Nil) :: Nil, None) => false
case By(ChooseP(Neq(NULL), _, Constant(NULL) :: Nil) :: Nil, None) => false
case ChooseP(Neq(NULL), _, Constant(NULL) :: Nil) => false
case _ => true
case By(SelectK(_) :: ChooseP(Neq(NULL), _, None) :: Nil, None) => false
case By(ChooseP(Neq(NULL), _, None) :: Nil, None) => false
case ChooseP(Neq(NULL), _, None) => false
case _ => true
}
rest match {
case head :: tail => (segment :+ head, tail)
Expand All @@ -62,11 +62,11 @@ object RemoveUselessNullChecks extends GremlinRewriter {
}

val last = steps.last match {
case By(SelectK(key) :: ChooseP(Neq(NULL), t, Constant(NULL) :: Nil) :: Nil, None) =>
By(SelectK(key) +: t, None) :: Nil
case By(ChooseP(Neq(NULL), t, Constant(NULL) :: Nil) :: Nil, None) =>
By(t, None) :: Nil
case ChooseP(Neq(NULL), traversal, Constant(NULL) :: Nil) =>
case By(SelectK(key) :: ChooseP(Neq(NULL), traversal, None) :: Nil, None) =>
By(SelectK(key) +: traversal, None) :: Nil
case By(ChooseP(Neq(NULL), traversal, None) :: Nil, None) =>
By(traversal, None) :: Nil
case ChooseP(Neq(NULL), traversal, None) =>
traversal
case step =>
step :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ object SimplifyDelete extends GremlinRewriter {
def apply(steps: Seq[GremlinStep]): Seq[GremlinStep] = {
val withChoose = foldTraversals(false)((acc, localSteps) => {
acc || extract({
case ChooseP(IsNode(), Aggregate(DELETE) :: Nil, Aggregate(DETACH_DELETE) :: Nil) :: _ => true
case ChooseP(IsNode(), Aggregate(DELETE) :: Nil, Some(Aggregate(DETACH_DELETE) :: Nil)) :: _ => true
})(localSteps).contains(true)
})(steps)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ object SimplifyPropertySetters extends GremlinRewriter {
PropertyV(key, value) :: rest
case PropertyTC(cardinality, key, Constant(value) :: Nil) :: rest =>
PropertyVC(cardinality, key, value) :: rest
case ChooseT(_, PropertyV(key, value) :: Nil, drop) :: rest =>
case ChooseT(_, Some(PropertyV(key, value) :: Nil), Some(drop)) :: rest =>
val empty = value match {
case NULL => true
case coll: util.Collection[_] => coll.isEmpty
Expand All @@ -44,7 +44,7 @@ object SimplifyPropertySetters extends GremlinRewriter {
} else {
PropertyV(key, value) :: rest
}
case step @ ChooseT(_, prop @ PropertyT(_, Project(_*) :: valueTail) :: Nil, _) :: rest
case step @ ChooseT(_, Some(prop @ PropertyT(_, Project(_*) :: valueTail) :: Nil), _) :: rest
if valueTail.init.forall(_.isInstanceOf[By]) =>
valueTail.last match {
case _: By | _: SelectC => prop ++ rest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,20 +219,20 @@ private class ExpressionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte
case "abs" => traversals.head.math("abs(_)")
case "coalesce" => __.coalesce(traversals.init.map(_.is(p.neq(NULL))) :+ traversals.last: _*)
case "exists" => traversals.head.flatMap(anyMatch(__.is(p.neq(NULL))))
case "head" => traversals.head.flatMap(notNull(emptyToNull(__.limit(Scope.local, 1), context), context))
case "head" => traversals.head.flatMap(emptyToNull(__.limit(Scope.local, 1), context))
case "id" => traversals.head.flatMap(notNull(__.id(), context))
case "keys" if onEntity => traversals.head.properties().key().fold()
case "keys" => traversals.head.select(Column.keys)
case "labels" => traversals.head.label().is(p.neq(Vertex.DEFAULT_LABEL)).fold()
case "length" => traversals.head.count(Scope.local).math("(_-1)/2")
case "last" => traversals.head.flatMap(notNull(emptyToNull(__.tail(Scope.local, 1), context), context))
case "last" => traversals.head.flatMap(emptyToNull(__.tail(Scope.local, 1), context))
case "nodes" => traversals.head.flatMap(filterElements(args, includeNodes = true))
case "properties" => traversals.head.flatMap(properties(args))
case "range" => range(args)
case "relationships" => traversals.head.flatMap(filterElements(args, includeRelationships = true))
case "size" => traversals.head.flatMap(size(args))
case "sqrt" => traversals.head.math("sqrt(_)")
case "tail" => traversals.head.flatMap(notNull(__.range(Scope.local, 1, -1), context))
case "tail" => traversals.head.flatMap(__.range(Scope.local, 1, -1))
case "type" => traversals.head.flatMap(notNull(__.label().is(p.neq(Vertex.DEFAULT_LABEL)), context))
case "toboolean" => traversals.head.map(CustomFunction.cypherToBoolean())
case "tofloat" => traversals.head.map(CustomFunction.cypherToFloat())
Expand Down Expand Up @@ -440,21 +440,23 @@ private class ExpressionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte
}

private def properties(args: Seq[Expression]): GremlinSteps[T, P] = {
lazy val elementT = __.local(
__.properties()
.group()
.by(__.key())
.by(__.map(__.value()))
)
lazy val elementT =
notNull(
__.local(
__.properties()
.group()
.by(__.key())
.by(__.map(__.value()))
),
context)

val typ = typeOf(args.head)
val traversal = typ match {
typ match {
case _: NodeType => elementT
case _: RelationshipType => elementT
case _: MapType => __.identity()
case _ => __.map(CustomFunction.cypherProperties())
}
notNull(traversal, context)
}

private val injectHardLimit = 10000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ object NodeUtils {
def notNull[T, P](traversal: GremlinSteps[T, P], context: WalkerContext[T, P]): GremlinSteps[T, P] = {
val g = context.dsl.steps()
val p = context.dsl.predicates()
g.start().choose(p.neq(NULL), traversal, g.start().constant(NULL))
g.start().choose(p.neq(NULL), traversal)
}

def emptyToNull[T, P](traversal: GremlinSteps[T, P], context: WalkerContext[T, P]): GremlinSteps[T, P] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,6 @@ private class ProjectionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte
dependencyNames.distinct
}

private def nullIfNull(g: GremlinSteps[T, P], trueChoice: GremlinSteps[T, P]): GremlinSteps[T, P] = {
val p = context.dsl.predicates()
g.choose(p.neq(NULL), trueChoice, g.start().constant(NULL))
}

private def subTraversal(alias: String, expression: Expression): (ReturnFunctionType, GremlinSteps[T, P]) = {
if (expression.containsAggregate) {
aggregation(alias, expression)
Expand Down Expand Up @@ -315,41 +310,34 @@ private class ProjectionWalker[T, P](context: WalkerContext[T, P], g: GremlinSte
.valueMap(true)
.fold())

def notNull(traversal: GremlinSteps[T, P]): GremlinSteps[T, P] = {
NodeUtils.notNull(traversal, context)
}

qualifiedType(expression) match {
case (_: NodeType, _) =>
nullIfNull(
subTraversal,
finalizeNode
)
subTraversal.flatMap(notNull(finalizeNode))
case (_: ListType, _: NodeType) =>
__.flatMap(subTraversal)
.unfold()
.is(p.neq(NULL))
.flatMap(finalizeNode)
.fold()
case (_: RelationshipType, _) =>
nullIfNull(
subTraversal,
finalizeRelationship
)
subTraversal.flatMap(notNull(finalizeRelationship))
case (_: ListType, _: RelationshipType) =>
nullIfNull(
subTraversal,
__.unfold()
.flatMap(finalizeRelationship)
.fold()
)
subTraversal.flatMap(
notNull(
__.unfold()
.flatMap(finalizeRelationship)
.fold()))
case (_: PathType, _) =>
nullIfNull(
subTraversal,
finalizePath
)
subTraversal.flatMap(notNull(finalizePath))
case (_: ListType, _: PathType) =>
nullIfNull(
subTraversal,
__.flatMap(finalizePath)
.fold()
)
subTraversal.flatMap(
notNull(
__.flatMap(finalizePath)
.fold()))
case _ =>
subTraversal
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class CustomFunctionsFallbackTest {
.removes(__.select(Column.values).map(CustomFunction.cypherPlus()))
.adds(
__.select(Column.values)
.local(__.unfold().choose(P.neq(Tokens.NULL), __.sum(), __.start().constant(Tokens.NULL))))
.local(__.unfold().choose(P.neq(Tokens.NULL), __.sum())))
}

@Test
Expand Down
Loading

0 comments on commit 8b5e5f3

Please sign in to comment.