From 8d0ce5e4b15a8817a1b8fb8259361b048bf23c13 Mon Sep 17 00:00:00 2001 From: ryanstull Date: Thu, 28 Mar 2019 00:48:51 -0400 Subject: [PATCH] Support for custom default in absent case --- README.md | 18 ++- .../com/ryanstull/nullsafe/package.scala | 114 +++++++++++++++++- .../scala/com/ryanstull/nullsafe/Tests.scala | 60 ++++++++- 3 files changed, 188 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 0de3e73..767e3a3 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Key: ✔️ = Good, ⚠️ = Problematic, ⛔ = Bad Add the dependency: -`libraryDependencies += "com.ryanstull" %% "scalanullsafe" % "1.0.3"` +`libraryDependencies += "com.ryanstull" %% "scalanullsafe" % "1.1.0"` Example use: @@ -113,6 +113,17 @@ All of the above work for method invocation as well as property access, and the `?(someObj.methodA().field1.twoArgMethod("test",1).otherField)` will be translated properly. + +For the `?` macro, you can also provide a custom default instead of null, by passing it in as the second +parameter. For example + +``` +case class Person(name: String) + +val person: Person = null + +assert(?(person.name,"") == "") +``` ## Performance @@ -148,4 +159,7 @@ equivalent performance. * Using the `?` macro on an expression whose type is `<: AnyVal`, will result in returning the corresponding java wrapper type. For example `?(a.getInt)` will return `java.lang.Integer` instead of `Int` because the return type for this macro must -be nullable. The conversions are the default ones defined in `scala.Predef` \ No newline at end of file +be nullable. The conversions are the default ones defined in `scala.Predef` + +* If you're having trouble with resolving the correct method when using the `?` macro with a default arg, try explicitly +specifying the type of the default \ No newline at end of file diff --git a/src/main/scala/com/ryanstull/nullsafe/package.scala b/src/main/scala/com/ryanstull/nullsafe/package.scala index 0928401..6a3de77 100644 --- a/src/main/scala/com/ryanstull/nullsafe/package.scala +++ b/src/main/scala/com/ryanstull/nullsafe/package.scala @@ -21,6 +21,17 @@ package object nullsafe { * access on `null`, returns null instead. */ def ?(expr: Byte): java.lang.Byte = macro qMarkImpl[java.lang.Byte] + /** + * Translates an expression that could cause a NullPointerException due to method/field access on `null` + * and adds explicit null-checks to avoid that. + * + * @param expr Some expression that might cause a NullPointerExpression due to method/field access on `null` + * @param default Custom default value to return in the case where NPE would have occurred. + * @return The value of the expression. If there would have been a NullPointerException due to method/field + * access on `null`, returns `default` instead. + */ + def ?(expr: Byte, default: java.lang.Byte): java.lang.Byte = macro qMarkImplDefault[java.lang.Byte] + /** * Translates an expression that could cause a NullPointerException due to method/field access on `null` * and adds explicit null-checks to avoid that. @@ -30,6 +41,17 @@ package object nullsafe { * access on `null`, returns null instead. */ def ?(expr: Short): java.lang.Short = macro qMarkImpl[java.lang.Short] + /** + * Translates an expression that could cause a NullPointerException due to method/field access on `null` + * and adds explicit null-checks to avoid that. + * + * @param expr Some expression that might cause a NullPointerExpression due to method/field access on `null` + * @param default Custom default value to return in the case where NPE would have occurred. + * @return The value of the expression. If there would have been a NullPointerException due to method/field + * access on `null`, returns `default` instead. + */ + def ?(expr: Short, default: java.lang.Short): java.lang.Short = macro qMarkImplDefault[java.lang.Short] + /** * Translates an expression that could cause a NullPointerException due to method/field access on `null` * and adds explicit null-checks to avoid that. @@ -39,6 +61,17 @@ package object nullsafe { * access on `null`, returns null instead. */ def ?(expr: Char): java.lang.Character = macro qMarkImpl[java.lang.Character] + /** + * Translates an expression that could cause a NullPointerException due to method/field access on `null` + * and adds explicit null-checks to avoid that. + * + * @param expr Some expression that might cause a NullPointerExpression due to method/field access on `null` + * @param default Custom default value to return in the case where NPE would have occurred. + * @return The value of the expression. If there would have been a NullPointerException due to method/field + * access on `null`, returns `default` instead. + */ + def ?(expr: Char, default: java.lang.Character): java.lang.Character = macro qMarkImplDefault[java.lang.Character] + /** * Translates an expression that could cause a NullPointerException due to method/field access on `null` * and adds explicit null-checks to avoid that. @@ -48,6 +81,17 @@ package object nullsafe { * access on `null`, returns null instead. */ def ?(expr: Int): java.lang.Integer = macro qMarkImpl[java.lang.Integer] + /** + * Translates an expression that could cause a NullPointerException due to method/field access on `null` + * and adds explicit null-checks to avoid that. + * + * @param expr Some expression that might cause a NullPointerExpression due to method/field access on `null` + * @param default Custom default value to return in the case where NPE would have occurred. + * @return The value of the expression. If there would have been a NullPointerException due to method/field + * access on `null`, returns `default` instead. + */ + def ?(expr: Int, default: java.lang.Integer): java.lang.Integer = macro qMarkImplDefault[java.lang.Integer] + /** * Translates an expression that could cause a NullPointerException due to method/field access on `null` * and adds explicit null-checks to avoid that. @@ -57,6 +101,17 @@ package object nullsafe { * access on `null`, returns null instead. */ def ?(expr: Long): java.lang.Long = macro qMarkImpl[java.lang.Long] + /** + * Translates an expression that could cause a NullPointerException due to method/field access on `null` + * and adds explicit null-checks to avoid that. + * + * @param expr Some expression that might cause a NullPointerExpression due to method/field access on `null` + * @param default Custom default value to return in the case where NPE would have occurred. + * @return The value of the expression. If there would have been a NullPointerException due to method/field + * access on `null`, returns `default` instead. + */ + def ?(expr: Long, default: java.lang.Long): java.lang.Long = macro qMarkImplDefault[java.lang.Long] + /** * Translates an expression that could cause a NullPointerException due to method/field access on `null` * and adds explicit null-checks to avoid that. @@ -66,6 +121,16 @@ package object nullsafe { * access on `null`, returns null instead. */ def ?(expr: Float): java.lang.Float = macro qMarkImpl[java.lang.Float] + /** + * Translates an expression that could cause a NullPointerException due to method/field access on `null` + * and adds explicit null-checks to avoid that. + * + * @param expr Some expression that might cause a NullPointerExpression due to method/field access on `null` + * @param default Custom default value to return in the case where NPE would have occurred. + * @return The value of the expression. If there would have been a NullPointerException due to method/field + * access on `null`, returns `default` instead. + */ + def ?(expr: Float, default: java.lang.Float): java.lang.Float = macro qMarkImplDefault[java.lang.Float] /** * Translates an expression that could cause a NullPointerException due to method/field access on `null` * and adds explicit null-checks to avoid that. @@ -75,6 +140,17 @@ package object nullsafe { * access on `null`, returns null instead. */ def ?(expr: Double): java.lang.Double = macro qMarkImpl[java.lang.Double] + /** + * Translates an expression that could cause a NullPointerException due to method/field access on `null` + * and adds explicit null-checks to avoid that. + * + * @param expr Some expression that might cause a NullPointerExpression due to method/field access on `null` + * @param default Custom default value to return in the case where NPE would have occurred. + * @return The value of the expression. If there would have been a NullPointerException due to method/field + * access on `null`, returns `default` instead. + */ + def ?(expr: Double, default: java.lang.Double): java.lang.Double = macro qMarkImplDefault[java.lang.Double] + /** * Translates an expression that could cause a NullPointerException due to method/field access on `null` * and adds explicit null-checks to avoid that. @@ -84,7 +160,16 @@ package object nullsafe { * access on `null`, returns null instead. */ def ?(expr: Boolean): java.lang.Boolean = macro qMarkImpl[java.lang.Boolean] - + /** + * Translates an expression that could cause a NullPointerException due to method/field access on `null` + * and adds explicit null-checks to avoid that. + * + * @param expr Some expression that might cause a NullPointerExpression due to method/field access on `null` + * @param default Custom default value to return in the case where NPE would have occurred. + * @return The value of the expression. If there would have been a NullPointerException due to method/field + * access on `null`, returns `default` instead. + */ + def ?(expr: Boolean, default: java.lang.Boolean): java.lang.Boolean = macro qMarkImplDefault[java.lang.Boolean] /** * Translates an expression that could cause a NullPointerException due to method/field access on `null` * and adds explicit null-checks to avoid that. @@ -95,6 +180,16 @@ package object nullsafe { * access on `null`, returns null instead. */ def ?[A <: AnyRef](expr: A): A = macro qMarkImpl[A] + /** + * Translates an expression that could cause a NullPointerException due to method/field access on `null` + * and adds explicit null-checks to avoid that. + * + * @param expr Some expression that might cause a NullPointerExpression due to method/field access on `null` + * @param default Custom default value to return in the case where NPE would have occurred. + * @return The value of the expression. If there would have been a NullPointerException due to method/field + * access on `null`, returns `default` instead. + */ + def ?[A <: AnyRef](expr: A, default: A): A = macro qMarkImplDefault[A] /** * Translates an expression that could cause a NullPointerException due to method/field access on `null` @@ -105,6 +200,16 @@ package object nullsafe { * access on `null`, returns null instead. */ def ?(expr: Unit): Unit = macro qMarkUnitImpl[Unit] + /** + * Translates an expression that could cause a NullPointerException due to method/field access on `null` + * and adds explicit null-checks to avoid that. + * + * @param expr Some expression that might cause a NullPointerExpression due to method/field access on `null` + * @param default Custom default value to return in the case where NPE would have occurred. + * @return The value of the expression. If there would have been a NullPointerException due to method/field + * access on `null`, returns `default` instead. + */ + def ?(expr: Unit, default: Unit): Unit = macro qMarkImplDefault[Unit] /** * Translates an expression that could cause a NullPointerException due to method/field access on `null` @@ -130,6 +235,7 @@ package object nullsafe { //Putting the implementations in an object to avoid namespace pollution. private[this] object MacroImplementations { + def qMarkImpl[A : c.WeakTypeTag](c: blackbox.Context)(expr: c.Expr[A]): c.Expr[A] = { import c.universe._ @@ -138,6 +244,12 @@ package object nullsafe { c.Expr(result) } + def qMarkImplDefault[A : c.WeakTypeTag](c: blackbox.Context)(expr: c.Expr[A], default: c.Expr[A]): c.Expr[A] = { + val tree = expr.tree + val result = rewriteToNullSafe(c)(tree)(default.tree,a => a) + c.Expr(result) + } + def qMarkUnitImpl[A : c.WeakTypeTag](c: blackbox.Context)(expr: c.Expr[A]): c.Expr[A] = { import c.universe._ diff --git a/src/test/scala/com/ryanstull/nullsafe/Tests.scala b/src/test/scala/com/ryanstull/nullsafe/Tests.scala index ebe516f..7035833 100644 --- a/src/test/scala/com/ryanstull/nullsafe/Tests.scala +++ b/src/test/scala/com/ryanstull/nullsafe/Tests.scala @@ -173,11 +173,69 @@ class Tests extends FlatSpec { assert(!notNull(a.b.c.d.e)) } + + "Using custom default" should "work for absent case" in { + val a = A(B(null)) + + assert(?(a.b.c.d.e.s,"") == "") + } + + "Using custom default" should "work for present case" in { + val a = A(B(C(D(E("Hello"))))) + + assert(?(a.b.c.d.e.s,"") == "Hello") + } + + "Using custom default" should "work for absent case with Int" in { + val a = A(B(null)) + + assert(?(a.b.c.d.getInt,new Integer(3)) == 3) + } + + "Using custom default" should "work for present case with Int" in { + val a = A(B(C(D(E("Hello"))))) + + assert(?(a.b.c.d.getInt,new Integer(3)) == 0) + } + + "Using custom default" should "work for absent case with Boolean" in { + val a = A(B(null)) + + assert(?(a.b.c.d.e.getBool,true) == true) + } + + "Using custom default" should "work for present case with Boolean" in { + val a = A(B(C(D(E("Hello"))))) + + assert(?(a.b.c.d.e.getBool,true) == false) + } + + "Using custom default" should "work for absent case with Double" in { + val a = A(B(null)) + + assert(?(a.b.c.d.e.getDouble,3.0) == 3.0) + } + + "Using custom default" should "work for present case with Double" in { + val a = A(B(C(D(E("Hello"))))) + + assert(?(a.b.c.d.e.getDouble,3.0) == 0.0) + } + + "Using custom default" should "work for absent case with Unit" in { + val a = A(B(null)) + + ?(a.b.c.d.e.s.notify(),println("Absent")) + } + } //Example of deeply nested domain object object Tests { - case class E(s: String) + case class E(s: String){ + def getBool: Boolean = false + def getDouble: Double = 0.0 + } case class D(e: E){ def getE: E = E(null) def getInt: Int = 0