Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace CtorDtor with simplified FromTo #1243

Merged
merged 1 commit into from
Feb 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 70 additions & 92 deletions core/src/main/scala/shapeless/generic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -473,10 +473,19 @@ trait CaseClassMacros extends ReprTypes with CaseClassMacrosVersionSpecifics {
def nameOf(tpe: Type): Name =
tpe.typeSymbol.name

def mkHListValue(elems: List[Tree]): Tree =
elems.foldRight(q"_root_.shapeless.HNil": Tree) {
case (elem, acc) => q"_root_.shapeless.::($elem, $acc)"
def mkHListValue(elems: List[Tree]): Tree = {
val cons = objectRef[::.type]
elems.foldRight(objectRef[HNil.type]) {
case (elem, acc) => q"$cons($elem, $acc)"
}
}

def mkHListPattern(elems: List[TermName]): Tree = {
val cons = objectRef[::.type]
elems.foldRight(objectRef[HNil.type]) {
case (elem, acc) => pq"$cons($elem, $acc)"
}
}

/**
* Fold `items` into a type using `cons` as a type constructor.
Expand Down Expand Up @@ -838,110 +847,83 @@ trait CaseClassMacros extends ReprTypes with CaseClassMacrosVersionSpecifics {
}
}

trait CtorDtor {
def construct(args: List[Tree]): Tree
def binding: (Tree, List[Tree])
def reprBinding: (Tree, List[Tree])
}
case class FromTo(from: CaseDef, to: CaseDef)
object FromTo {
import c.internal.gen

object CtorDtor {
def apply(tpe: Type): CtorDtor = {
val sym = tpe.typeSymbol
val isCaseClass = sym.asClass.isCaseClass
def of(tpe: Type, reprTpt: Tree): FromTo = {
val wildcard = Ident(termNames.WILDCARD)
// like pq"_*" except that it does work
val repWCard = Star(wildcard)

val repWCard = Star(Ident(termNames.WILDCARD)) // like pq"_*" except that it does work
def narrow(tree: Tree, tpe: Type): Tree = tpe match {
case ConstantType(c) => q"$c.asInstanceOf[$tpe]"
case _ => tree
}

def narrow(tree: Tree, tpe: Type): Tree =
tpe match {
case ConstantType(c) =>
q"$c.asInstanceOf[$tpe]"
case _ =>
tree
}
def const(tree: Tree): CaseDef =
cq"_ => $tree"

def narrow1(tree: Tree, tpe: Type): Tree =
if(isVararg(tpe))
q"$tree: _*"
else
narrow(tree, tpe)

def mkCtorDtor0(elems0: List[(TermName, Type)]) = {
val elems = elems0.map { case (_, tpe) => (TermName(c.freshName("pat")), tpe) }
val pattern = pq"${companionRef(tpe)}(..${elems.map { case (binder, tpe) => if(isVararg(tpe)) pq"$binder @ $repWCard" else pq"$binder"}})"
val reprPattern =
elems.foldRight(q"_root_.shapeless.HNil": Tree) {
case ((bound, _), acc) => pq"_root_.shapeless.::($bound, $acc)"
}
val nonCaseParamLists: List[List[Tree]] = List.fill(numNonCaseParamLists(tpe))(Nil)
new CtorDtor {
def construct(args: List[Tree]): Tree = q"${companionRef(tpe)}[..${tpe.typeArgs}](...${args :: nonCaseParamLists})"
def binding: (Tree, List[Tree]) = (pattern, elems.map { case (binder, tpe) => narrow(q"$binder", tpe) })
def reprBinding: (Tree, List[Tree]) = (reprPattern, elems.map { case (binder, tpe) => narrow1(q"$binder", tpe) })
}
def from(fields: List[(TermName, Type)])(construct: List[Tree] => Tree): CaseDef = {
val (pats, args) = fields.map { case (field, tpe) =>
val pat = c.freshName(field)
(pat, if (isVararg(tpe)) q"$pat: _*" else narrow(q"$pat", tpe))
}.unzip
cq"${mkHListPattern(pats)} => ${construct(args)}"
}

def mkCtorDtor1(elems: List[(TermName, TermName, Type)], pattern: Tree, rhs: List[Tree]) = {
val reprPattern =
elems.foldRight(q"_root_.shapeless.HNil": Tree) {
case ((bound, _, _), acc) => pq"_root_.shapeless.::($bound, $acc)"
}
new CtorDtor {
def construct(args: List[Tree]): Tree = q"new $tpe(..$args)"
def binding: (Tree, List[Tree]) = (pattern, rhs)
def reprBinding: (Tree, List[Tree]) = (reprPattern, elems.map { case (binder, _, tpe) => narrow1(q"$binder", tpe) })
}
def to(pattern: Tree, args: List[Tree]): CaseDef =
cq"$pattern => ${mkHListValue(args)}.asInstanceOf[$reprTpt]"

def fromApply(fields: List[(TermName, Type)]): CaseDef = from(fields) { args =>
val nonCaseArgs = List.fill(numNonCaseParamLists(tpe))(List.empty[Tree])
q"${companionRef(tpe)}[..${tpe.typeArgs}](..$args)(...$nonCaseArgs)"
}

def fromConstructor(fields: List[(TermName, Type)]): CaseDef =
from(fields)(args => q"new $tpe(..$args)")

def toUnapply(fields: List[(TermName, Type)]): CaseDef = {
val (pats, args) = fields.map { case (field, tpe) =>
val pat = c.freshName(field)
(Bind(pat, if (isVararg(tpe)) repWCard else wildcard), narrow(Ident(pat), tpe))
}.unzip
to(pq"${companionRef(tpe)}(..$pats)", args)
}

def toGetters(fields: List[(TermName, Type)]): CaseDef = {
val pattern = c.freshName(TermName("x"))
to(pq"$pattern", fields.map { case (field, tpe) => narrow(q"$pattern.$field", tpe) })
}

lowerKind(tpe) match {
// case 1: Unit
case tpe if tpe =:= typeOf[Unit] =>
new CtorDtor {
def construct(args: List[Tree]): Tree = q"()"
def binding: (Tree, List[Tree]) = (pq"()", Nil)
def reprBinding: (Tree, List[Tree]) = (pq"_root_.shapeless.HNil", Nil)
}

FromTo(const(q"()"), const(objectRef[HNil.type]))
// case 2: singleton
case tpe if isCaseObjectLike(tpe.typeSymbol.asClass) =>
val singleton =
tpe match {
case SingleType(pre, sym) =>
c.internal.gen.mkAttributedRef(pre, sym)
case TypeRef(pre, sym, List()) if sym.isModule =>
c.internal.gen.mkAttributedRef(pre, sym.asModule)
case TypeRef(pre, sym, List()) if sym.isModuleClass =>
c.internal.gen.mkAttributedRef(pre, sym.asClass.module)
case _ =>
abort(s"Bad case object-like type $tpe")
}
new CtorDtor {
def construct(args: List[Tree]): Tree = q"$singleton: $tpe"
def binding: (Tree, List[Tree]) = (pq"_: $tpe", Nil)
def reprBinding: (Tree, List[Tree]) = (pq"_root_.shapeless.HNil", Nil)
val singleton = tpe match {
case SingleType(pre, sym) => gen.mkAttributedRef(pre, sym)
case TypeRef(pre, sym, Nil) if sym.isModule => gen.mkAttributedRef(pre, sym.asModule)
case TypeRef(pre, sym, Nil) if sym.isModuleClass => gen.mkAttributedRef(pre, sym.asClass.module)
case _ => abort(s"Bad case object-like type $tpe")
}

FromTo(const(singleton), const(objectRef[HNil.type]))
// case 3: case class
case tpe if isCaseClass => mkCtorDtor0(fieldsOf(tpe))

case tpe if tpe.typeSymbol.asClass.isCaseClass =>
val fields = fieldsOf(tpe)
FromTo(fromApply(fields), toUnapply(fields))
// case 4: exactly one matching public apply/unapply
case HasApplyUnapply(args) => mkCtorDtor0(args)

case HasApplyUnapply(args) =>
FromTo(fromApply(args), toUnapply(args))
// case 5: concrete, exactly one public constructor with matching public unapply
case HasCtorUnapply(args) =>
val elems = args.map { case (name, tpe) => (TermName(c.freshName("pat")), name, tpe) }
val pattern = pq"${companionRef(tpe)}(..${elems.map { case (binder, _, tpe) => if(isVararg(tpe)) pq"$binder @ $repWCard" else pq"$binder" }})"
val rhs = elems.map { case (binder, _, tpe) => narrow(q"$binder", tpe) }
mkCtorDtor1(elems, pattern, rhs)

FromTo(fromConstructor(args), toUnapply(args))
// case 6: concrete, exactly one public constructor with matching accessible fields
case HasUniqueCtor(args) =>
val elems = args.map { case (name, tpe) => (TermName(c.freshName("pat")), name, tpe) }
val binder = TermName(c.freshName("pat"))
val pattern = pq"$binder"
val rhs = elems.map { case (_, name, tpe) => narrow(q"$binder.$name", tpe) }
mkCtorDtor1(elems, pattern, rhs)

case _ => abort(s"Bad product type $tpe")
FromTo(fromConstructor(args), toGetters(args))
case _ =>
abort(s"Bad product type $tpe")
}
}
}
Expand All @@ -965,11 +947,7 @@ class GenericMacros(val c: whitebox.Context) extends CaseClassMacros {

def mkProductGeneric(tpe: Type): Tree = {
val repr = mkHListTpe(fieldsOf(tpe).map(_._2))
val ctorDtor = CtorDtor(tpe)
val (p, ts) = ctorDtor.binding
val to = cq"$p => ${mkHListValue(ts)}.asInstanceOf[$repr]"
val (rp, rts) = ctorDtor.reprBinding
val from = cq"$rp => ${ctorDtor.construct(rts)}"
val FromTo(from, to) = FromTo.of(tpe, TypeTree(repr))
q"$generic.instance[$tpe, $repr]({ case $to }, { case $from })"
}

Expand Down
15 changes: 5 additions & 10 deletions core/src/main/scala/shapeless/generic1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,17 +217,12 @@ class Generic1Macros(val c: whitebox.Context) extends CaseClassMacros {
}

def mkProductGeneric1(tpe: Type, frTpe: Type): Tree = {
val ctorDtor = CtorDtor(tpe)
val (p, ts) = ctorDtor.binding
val to = cq"$p => ${mkHListValue(ts)}"
val (rp, rts) = ctorDtor.reprBinding
val from = cq"$rp => ${ctorDtor.construct(rts)}"
val name = TypeName(c.freshName("P"))
val reprTpt = reprTypTree1(tpe, name)
val reprName = TypeName(c.freshName("R"))

val tparam = c.freshName(TypeName("P"))
val reprName = c.freshName(TypeName("R"))
val reprTpt = reprTypTree1(tpe, tparam)
val FromTo(from, to) = FromTo.of(tpe, tq"$reprName[$AnyTpe]")
q"""
type $reprName[$name] = $reprTpt
type $reprName[$tparam] = $reprTpt
$generic1.unsafeInstance[$tpe, $frTpe, $reprName]({ case $to }, { case $from })
"""
}
Expand Down