Skip to content

Commit

Permalink
Don't add canEqual and equals if users already defined them
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarchambault committed Jan 17, 2025
1 parent 54a87d8 commit 5fae974
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 31 deletions.
75 changes: 44 additions & 31 deletions src/main/scala/dataclass/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,28 @@ private[dataclass] class Macros(val c: Context) extends ImplTransformers {
tq"${t.name}"
}

def hasMethod(methodName: String): Boolean = {
def hasMethod(
methodName: String,
hasArgs: Boolean = false
): Boolean = {
def fromStats =
stats.exists {
case DefDef(_, nme, tparams, vparamss, _, _)
if nme.decodedName.toString == methodName && tparams.isEmpty && vparamss
.forall(_.isEmpty) =>
if nme.decodedName.toString == methodName &&
tparams.isEmpty &&
(
hasArgs && vparamss.map(_.length).sum == 1 ||
vparamss.forall(_.isEmpty)
) =>
true
case t @ ValDef(_, name, _, _)
if name.decodedName.toString == methodName =>
if !hasArgs && name.decodedName.toString == methodName =>
true
case _ =>
false
}

val fromFields =
val fromFields = !hasArgs &&
allParams.exists(_.name.decodedName.toString() == methodName)

fromFields || fromStats
Expand All @@ -56,6 +63,8 @@ private[dataclass] class Macros(val c: Context) extends ImplTransformers {
val hasToString = hasMethod("toString")
val hasHashCode = hasMethod("hashCode")
val hasTuple = hasMethod("tuple")
val hasCanEqual = hasMethod("canEqual", hasArgs = true)
val hasEquals = hasMethod("equals", hasArgs = true)

val namedArgs = paramss.map(_.map { p =>
q"${p.name}=this.${p.name}"
Expand Down Expand Up @@ -122,33 +131,37 @@ private[dataclass] class Macros(val c: Context) extends ImplTransformers {
tq"({type L[..$tparams0]=$WildcardType})#L"
}

val canEqualMethod = {
val hashCheck =
if (cachedHashCode) q"obj.hashCode == hashCode" else q"true"
Seq(
q"""
override def canEqual(obj: Any): _root_.scala.Boolean =
obj != null && obj.isInstanceOf[$tpname[..$wildcardedTparams]] && $hashCheck
"""
)
}
val canEqualMethod =
if (hasCanEqual) Nil
else {
val hashCheck =
if (cachedHashCode) q"obj.hashCode == hashCode" else q"true"
Seq(
q"""
override def canEqual(obj: Any): _root_.scala.Boolean =
obj != null && obj.isInstanceOf[$tpname[..$wildcardedTparams]] && $hashCheck
"""
)
}

val equalsMethod = {
val fldChecks = paramss.flatten
.map { param =>
q"this.${param.name} == other.${param.name}"
}
.foldLeft[Tree](q"true")((a, b) => q"$a && $b")
Seq(
q"""
override def equals(obj: Any): _root_.scala.Boolean =
this.eq(obj.asInstanceOf[AnyRef]) || canEqual(obj) && {
val other = obj.asInstanceOf[$tpname[..$wildcardedTparams]]
$fldChecks
}
"""
)
}
val equalsMethod =
if (hasEquals) Nil
else {
val fldChecks = paramss.flatten
.map { param =>
q"this.${param.name} == other.${param.name}"
}
.foldLeft[Tree](q"true")((a, b) => q"$a && $b")
Seq(
q"""
override def equals(obj: Any): _root_.scala.Boolean =
this.eq(obj.asInstanceOf[AnyRef]) || canEqual(obj) && {
val other = obj.asInstanceOf[$tpname[..$wildcardedTparams]]
$fldChecks
}
"""
)
}

val hashCodeMethod =
if (hasHashCode) Nil
Expand Down
73 changes: 73 additions & 0 deletions src/test/scala/dataclass/MoreFieldsTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,79 @@ object MoreFieldsTests extends TestSuite {
}
}

"has equals" - {
"def" - {
@data class Foo(password: String) {
override def equals(obj: Any): Boolean =
obj.isInstanceOf[Foo] && (
password == obj.asInstanceOf[Foo].password ||
obj.asInstanceOf[Foo].password == "special"
)
}

val foo = Foo("aa")
val other = Foo("bb")
val specialFoo = Foo("special")
val equalsCopy = foo == foo.withPassword("aa")
assert(equalsCopy)
val equalsOther = foo == other
assert(!equalsOther)
val equalsSpecial = foo == specialFoo
assert(equalsSpecial)
}

"two-arg" - {
@data class Foo(password: String) {
def equals(obj: Any, thing: Any): Boolean =
true
}

val foo = Foo("aa")
val other = Foo("bb")
val specialFoo = Foo("special")
val equalsCopy = foo == foo.withPassword("aa")
assert(equalsCopy)
val equalsOther = foo == other
assert(!equalsOther)
val equalsSpecial = foo == specialFoo
assert(!equalsSpecial)
}
}

"has canEqual" - {
"def" - {
@data class Foo(password: String) {
override def canEqual(obj: Any): Boolean =
obj.isInstanceOf[Foo] || obj.toString == "Foo"
}

val foo = Foo("aa")
val other = Foo("bb")
val canEqualOther = foo.canEqual(other)
assert(canEqualOther)
val canEqualRandomString = foo.canEqual("thing")
assert(!canEqualRandomString)
val canEqualSpecialString = foo.canEqual("Foo")
assert(canEqualSpecialString)
}

"two-arg" - {
@data class Foo(password: String) {
def canEqual(obj: Any, thing: Any): Boolean =
obj.isInstanceOf[Foo] || obj.toString == "Foo"
}

val foo = Foo("aa")
val other = Foo("bb")
val canEqualOther = foo.canEqual(other)
assert(canEqualOther)
val canEqualRandomString = foo.canEqual("thing")
assert(!canEqualRandomString)
val canEqualSpecialString = foo.canEqual("Foo")
assert(!canEqualSpecialString)
}
}

"override val with default" - {
class Repository {
def versionsCheckHasModule: Boolean = false
Expand Down

0 comments on commit 5fae974

Please sign in to comment.