From 031a92538f38c5e90a79d9323c08699ed03f7d28 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Sun, 13 Feb 2022 11:58:43 +0100 Subject: [PATCH] Replace CtorDtor with simplified FromTo `FromTo` decouples and simplifies the strategy to generate the `from` and `to` functions. This will let us more easily add other strategies in the future and refine the choice of which strategy to use in which case. --- core/src/main/scala/shapeless/generic.scala | 162 ++++++++----------- core/src/main/scala/shapeless/generic1.scala | 15 +- 2 files changed, 75 insertions(+), 102 deletions(-) diff --git a/core/src/main/scala/shapeless/generic.scala b/core/src/main/scala/shapeless/generic.scala index ed076fdb8..36a6cc64e 100644 --- a/core/src/main/scala/shapeless/generic.scala +++ b/core/src/main/scala/shapeless/generic.scala @@ -473,10 +473,19 @@ trait CaseClassMacros extends ReprTypes with CaseClassMacrosVersionSpecifics { def nameOf(tpe: Type): Name = tpe.typeSymbol.name - def mkHListValue(elems: List[Tree]): Tree = - elems.foldRight(q"_root_.shapeless.HNil": Tree) { - case (elem, acc) => q"_root_.shapeless.::($elem, $acc)" + def mkHListValue(elems: List[Tree]): Tree = { + val cons = objectRef[::.type] + elems.foldRight(objectRef[HNil.type]) { + case (elem, acc) => q"$cons($elem, $acc)" } + } + + def mkHListPattern(elems: List[TermName]): Tree = { + val cons = objectRef[::.type] + elems.foldRight(objectRef[HNil.type]) { + case (elem, acc) => pq"$cons($elem, $acc)" + } + } /** * Fold `items` into a type using `cons` as a type constructor. @@ -838,110 +847,83 @@ trait CaseClassMacros extends ReprTypes with CaseClassMacrosVersionSpecifics { } } - trait CtorDtor { - def construct(args: List[Tree]): Tree - def binding: (Tree, List[Tree]) - def reprBinding: (Tree, List[Tree]) - } + case class FromTo(from: CaseDef, to: CaseDef) + object FromTo { + import c.internal.gen - object CtorDtor { - def apply(tpe: Type): CtorDtor = { - val sym = tpe.typeSymbol - val isCaseClass = sym.asClass.isCaseClass + def of(tpe: Type, reprTpt: Tree): FromTo = { + val wildcard = Ident(termNames.WILDCARD) + // like pq"_*" except that it does work + val repWCard = Star(wildcard) - val repWCard = Star(Ident(termNames.WILDCARD)) // like pq"_*" except that it does work + def narrow(tree: Tree, tpe: Type): Tree = tpe match { + case ConstantType(c) => q"$c.asInstanceOf[$tpe]" + case _ => tree + } - def narrow(tree: Tree, tpe: Type): Tree = - tpe match { - case ConstantType(c) => - q"$c.asInstanceOf[$tpe]" - case _ => - tree - } + def const(tree: Tree): CaseDef = + cq"_ => $tree" - def narrow1(tree: Tree, tpe: Type): Tree = - if(isVararg(tpe)) - q"$tree: _*" - else - narrow(tree, tpe) - - def mkCtorDtor0(elems0: List[(TermName, Type)]) = { - val elems = elems0.map { case (_, tpe) => (TermName(c.freshName("pat")), tpe) } - val pattern = pq"${companionRef(tpe)}(..${elems.map { case (binder, tpe) => if(isVararg(tpe)) pq"$binder @ $repWCard" else pq"$binder"}})" - val reprPattern = - elems.foldRight(q"_root_.shapeless.HNil": Tree) { - case ((bound, _), acc) => pq"_root_.shapeless.::($bound, $acc)" - } - val nonCaseParamLists: List[List[Tree]] = List.fill(numNonCaseParamLists(tpe))(Nil) - new CtorDtor { - def construct(args: List[Tree]): Tree = q"${companionRef(tpe)}[..${tpe.typeArgs}](...${args :: nonCaseParamLists})" - def binding: (Tree, List[Tree]) = (pattern, elems.map { case (binder, tpe) => narrow(q"$binder", tpe) }) - def reprBinding: (Tree, List[Tree]) = (reprPattern, elems.map { case (binder, tpe) => narrow1(q"$binder", tpe) }) - } + def from(fields: List[(TermName, Type)])(construct: List[Tree] => Tree): CaseDef = { + val (pats, args) = fields.map { case (field, tpe) => + val pat = c.freshName(field) + (pat, if (isVararg(tpe)) q"$pat: _*" else narrow(q"$pat", tpe)) + }.unzip + cq"${mkHListPattern(pats)} => ${construct(args)}" } - def mkCtorDtor1(elems: List[(TermName, TermName, Type)], pattern: Tree, rhs: List[Tree]) = { - val reprPattern = - elems.foldRight(q"_root_.shapeless.HNil": Tree) { - case ((bound, _, _), acc) => pq"_root_.shapeless.::($bound, $acc)" - } - new CtorDtor { - def construct(args: List[Tree]): Tree = q"new $tpe(..$args)" - def binding: (Tree, List[Tree]) = (pattern, rhs) - def reprBinding: (Tree, List[Tree]) = (reprPattern, elems.map { case (binder, _, tpe) => narrow1(q"$binder", tpe) }) - } + def to(pattern: Tree, args: List[Tree]): CaseDef = + cq"$pattern => ${mkHListValue(args)}.asInstanceOf[$reprTpt]" + + def fromApply(fields: List[(TermName, Type)]): CaseDef = from(fields) { args => + val nonCaseArgs = List.fill(numNonCaseParamLists(tpe))(List.empty[Tree]) + q"${companionRef(tpe)}[..${tpe.typeArgs}](..$args)(...$nonCaseArgs)" + } + + def fromConstructor(fields: List[(TermName, Type)]): CaseDef = + from(fields)(args => q"new $tpe(..$args)") + + def toUnapply(fields: List[(TermName, Type)]): CaseDef = { + val (pats, args) = fields.map { case (field, tpe) => + val pat = c.freshName(field) + (Bind(pat, if (isVararg(tpe)) repWCard else wildcard), narrow(Ident(pat), tpe)) + }.unzip + to(pq"${companionRef(tpe)}(..$pats)", args) + } + + def toGetters(fields: List[(TermName, Type)]): CaseDef = { + val pattern = c.freshName(TermName("x")) + to(pq"$pattern", fields.map { case (field, tpe) => narrow(q"$pattern.$field", tpe) }) } lowerKind(tpe) match { // case 1: Unit case tpe if tpe =:= typeOf[Unit] => - new CtorDtor { - def construct(args: List[Tree]): Tree = q"()" - def binding: (Tree, List[Tree]) = (pq"()", Nil) - def reprBinding: (Tree, List[Tree]) = (pq"_root_.shapeless.HNil", Nil) - } - + FromTo(const(q"()"), const(objectRef[HNil.type])) // case 2: singleton case tpe if isCaseObjectLike(tpe.typeSymbol.asClass) => - val singleton = - tpe match { - case SingleType(pre, sym) => - c.internal.gen.mkAttributedRef(pre, sym) - case TypeRef(pre, sym, List()) if sym.isModule => - c.internal.gen.mkAttributedRef(pre, sym.asModule) - case TypeRef(pre, sym, List()) if sym.isModuleClass => - c.internal.gen.mkAttributedRef(pre, sym.asClass.module) - case _ => - abort(s"Bad case object-like type $tpe") - } - new CtorDtor { - def construct(args: List[Tree]): Tree = q"$singleton: $tpe" - def binding: (Tree, List[Tree]) = (pq"_: $tpe", Nil) - def reprBinding: (Tree, List[Tree]) = (pq"_root_.shapeless.HNil", Nil) + val singleton = tpe match { + case SingleType(pre, sym) => gen.mkAttributedRef(pre, sym) + case TypeRef(pre, sym, Nil) if sym.isModule => gen.mkAttributedRef(pre, sym.asModule) + case TypeRef(pre, sym, Nil) if sym.isModuleClass => gen.mkAttributedRef(pre, sym.asClass.module) + case _ => abort(s"Bad case object-like type $tpe") } - + FromTo(const(singleton), const(objectRef[HNil.type])) // case 3: case class - case tpe if isCaseClass => mkCtorDtor0(fieldsOf(tpe)) - + case tpe if tpe.typeSymbol.asClass.isCaseClass => + val fields = fieldsOf(tpe) + FromTo(fromApply(fields), toUnapply(fields)) // case 4: exactly one matching public apply/unapply - case HasApplyUnapply(args) => mkCtorDtor0(args) - + case HasApplyUnapply(args) => + FromTo(fromApply(args), toUnapply(args)) // case 5: concrete, exactly one public constructor with matching public unapply case HasCtorUnapply(args) => - val elems = args.map { case (name, tpe) => (TermName(c.freshName("pat")), name, tpe) } - val pattern = pq"${companionRef(tpe)}(..${elems.map { case (binder, _, tpe) => if(isVararg(tpe)) pq"$binder @ $repWCard" else pq"$binder" }})" - val rhs = elems.map { case (binder, _, tpe) => narrow(q"$binder", tpe) } - mkCtorDtor1(elems, pattern, rhs) - + FromTo(fromConstructor(args), toUnapply(args)) // case 6: concrete, exactly one public constructor with matching accessible fields case HasUniqueCtor(args) => - val elems = args.map { case (name, tpe) => (TermName(c.freshName("pat")), name, tpe) } - val binder = TermName(c.freshName("pat")) - val pattern = pq"$binder" - val rhs = elems.map { case (_, name, tpe) => narrow(q"$binder.$name", tpe) } - mkCtorDtor1(elems, pattern, rhs) - - case _ => abort(s"Bad product type $tpe") + FromTo(fromConstructor(args), toGetters(args)) + case _ => + abort(s"Bad product type $tpe") } } } @@ -965,11 +947,7 @@ class GenericMacros(val c: whitebox.Context) extends CaseClassMacros { def mkProductGeneric(tpe: Type): Tree = { val repr = mkHListTpe(fieldsOf(tpe).map(_._2)) - val ctorDtor = CtorDtor(tpe) - val (p, ts) = ctorDtor.binding - val to = cq"$p => ${mkHListValue(ts)}.asInstanceOf[$repr]" - val (rp, rts) = ctorDtor.reprBinding - val from = cq"$rp => ${ctorDtor.construct(rts)}" + val FromTo(from, to) = FromTo.of(tpe, TypeTree(repr)) q"$generic.instance[$tpe, $repr]({ case $to }, { case $from })" } diff --git a/core/src/main/scala/shapeless/generic1.scala b/core/src/main/scala/shapeless/generic1.scala index cca6be38e..015df351f 100644 --- a/core/src/main/scala/shapeless/generic1.scala +++ b/core/src/main/scala/shapeless/generic1.scala @@ -217,17 +217,12 @@ class Generic1Macros(val c: whitebox.Context) extends CaseClassMacros { } def mkProductGeneric1(tpe: Type, frTpe: Type): Tree = { - val ctorDtor = CtorDtor(tpe) - val (p, ts) = ctorDtor.binding - val to = cq"$p => ${mkHListValue(ts)}" - val (rp, rts) = ctorDtor.reprBinding - val from = cq"$rp => ${ctorDtor.construct(rts)}" - val name = TypeName(c.freshName("P")) - val reprTpt = reprTypTree1(tpe, name) - val reprName = TypeName(c.freshName("R")) - + val tparam = c.freshName(TypeName("P")) + val reprName = c.freshName(TypeName("R")) + val reprTpt = reprTypTree1(tpe, tparam) + val FromTo(from, to) = FromTo.of(tpe, tq"$reprName[$AnyTpe]") q""" - type $reprName[$name] = $reprTpt + type $reprName[$tparam] = $reprTpt $generic1.unsafeInstance[$tpe, $frTpe, $reprName]({ case $to }, { case $from }) """ }