From 5fae974ad788a8a8decfcf7562025205b48fb721 Mon Sep 17 00:00:00 2001 From: Alex Archambault Date: Fri, 17 Jan 2025 14:15:14 +0100 Subject: [PATCH] Don't add canEqual and equals if users already defined them --- src/main/scala/dataclass/Macros.scala | 75 +++++++++++-------- .../scala/dataclass/MoreFieldsTests.scala | 73 ++++++++++++++++++ 2 files changed, 117 insertions(+), 31 deletions(-) diff --git a/src/main/scala/dataclass/Macros.scala b/src/main/scala/dataclass/Macros.scala index fd40469..9cbc4c6 100644 --- a/src/main/scala/dataclass/Macros.scala +++ b/src/main/scala/dataclass/Macros.scala @@ -33,21 +33,28 @@ private[dataclass] class Macros(val c: Context) extends ImplTransformers { tq"${t.name}" } - def hasMethod(methodName: String): Boolean = { + def hasMethod( + methodName: String, + hasArgs: Boolean = false + ): Boolean = { def fromStats = stats.exists { case DefDef(_, nme, tparams, vparamss, _, _) - if nme.decodedName.toString == methodName && tparams.isEmpty && vparamss - .forall(_.isEmpty) => + if nme.decodedName.toString == methodName && + tparams.isEmpty && + ( + hasArgs && vparamss.map(_.length).sum == 1 || + vparamss.forall(_.isEmpty) + ) => true case t @ ValDef(_, name, _, _) - if name.decodedName.toString == methodName => + if !hasArgs && name.decodedName.toString == methodName => true case _ => false } - val fromFields = + val fromFields = !hasArgs && allParams.exists(_.name.decodedName.toString() == methodName) fromFields || fromStats @@ -56,6 +63,8 @@ private[dataclass] class Macros(val c: Context) extends ImplTransformers { val hasToString = hasMethod("toString") val hasHashCode = hasMethod("hashCode") val hasTuple = hasMethod("tuple") + val hasCanEqual = hasMethod("canEqual", hasArgs = true) + val hasEquals = hasMethod("equals", hasArgs = true) val namedArgs = paramss.map(_.map { p => q"${p.name}=this.${p.name}" @@ -122,33 +131,37 @@ private[dataclass] class Macros(val c: Context) extends ImplTransformers { tq"({type L[..$tparams0]=$WildcardType})#L" } - val canEqualMethod = { - val hashCheck = - if (cachedHashCode) q"obj.hashCode == hashCode" else q"true" - Seq( - q""" - override def canEqual(obj: Any): _root_.scala.Boolean = - obj != null && obj.isInstanceOf[$tpname[..$wildcardedTparams]] && $hashCheck - """ - ) - } + val canEqualMethod = + if (hasCanEqual) Nil + else { + val hashCheck = + if (cachedHashCode) q"obj.hashCode == hashCode" else q"true" + Seq( + q""" + override def canEqual(obj: Any): _root_.scala.Boolean = + obj != null && obj.isInstanceOf[$tpname[..$wildcardedTparams]] && $hashCheck + """ + ) + } - val equalsMethod = { - val fldChecks = paramss.flatten - .map { param => - q"this.${param.name} == other.${param.name}" - } - .foldLeft[Tree](q"true")((a, b) => q"$a && $b") - Seq( - q""" - override def equals(obj: Any): _root_.scala.Boolean = - this.eq(obj.asInstanceOf[AnyRef]) || canEqual(obj) && { - val other = obj.asInstanceOf[$tpname[..$wildcardedTparams]] - $fldChecks - } - """ - ) - } + val equalsMethod = + if (hasEquals) Nil + else { + val fldChecks = paramss.flatten + .map { param => + q"this.${param.name} == other.${param.name}" + } + .foldLeft[Tree](q"true")((a, b) => q"$a && $b") + Seq( + q""" + override def equals(obj: Any): _root_.scala.Boolean = + this.eq(obj.asInstanceOf[AnyRef]) || canEqual(obj) && { + val other = obj.asInstanceOf[$tpname[..$wildcardedTparams]] + $fldChecks + } + """ + ) + } val hashCodeMethod = if (hasHashCode) Nil diff --git a/src/test/scala/dataclass/MoreFieldsTests.scala b/src/test/scala/dataclass/MoreFieldsTests.scala index 3cc9f24..5bde8ac 100644 --- a/src/test/scala/dataclass/MoreFieldsTests.scala +++ b/src/test/scala/dataclass/MoreFieldsTests.scala @@ -259,6 +259,79 @@ object MoreFieldsTests extends TestSuite { } } + "has equals" - { + "def" - { + @data class Foo(password: String) { + override def equals(obj: Any): Boolean = + obj.isInstanceOf[Foo] && ( + password == obj.asInstanceOf[Foo].password || + obj.asInstanceOf[Foo].password == "special" + ) + } + + val foo = Foo("aa") + val other = Foo("bb") + val specialFoo = Foo("special") + val equalsCopy = foo == foo.withPassword("aa") + assert(equalsCopy) + val equalsOther = foo == other + assert(!equalsOther) + val equalsSpecial = foo == specialFoo + assert(equalsSpecial) + } + + "two-arg" - { + @data class Foo(password: String) { + def equals(obj: Any, thing: Any): Boolean = + true + } + + val foo = Foo("aa") + val other = Foo("bb") + val specialFoo = Foo("special") + val equalsCopy = foo == foo.withPassword("aa") + assert(equalsCopy) + val equalsOther = foo == other + assert(!equalsOther) + val equalsSpecial = foo == specialFoo + assert(!equalsSpecial) + } + } + + "has canEqual" - { + "def" - { + @data class Foo(password: String) { + override def canEqual(obj: Any): Boolean = + obj.isInstanceOf[Foo] || obj.toString == "Foo" + } + + val foo = Foo("aa") + val other = Foo("bb") + val canEqualOther = foo.canEqual(other) + assert(canEqualOther) + val canEqualRandomString = foo.canEqual("thing") + assert(!canEqualRandomString) + val canEqualSpecialString = foo.canEqual("Foo") + assert(canEqualSpecialString) + } + + "two-arg" - { + @data class Foo(password: String) { + def canEqual(obj: Any, thing: Any): Boolean = + obj.isInstanceOf[Foo] || obj.toString == "Foo" + } + + val foo = Foo("aa") + val other = Foo("bb") + val canEqualOther = foo.canEqual(other) + assert(canEqualOther) + val canEqualRandomString = foo.canEqual("thing") + assert(!canEqualRandomString) + val canEqualSpecialString = foo.canEqual("Foo") + assert(!canEqualSpecialString) + } + } + "override val with default" - { class Repository { def versionsCheckHasModule: Boolean = false