From 3cf5c5c2b10fdc0e8eab04f47b5825b78ff05e65 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Sat, 17 Feb 2024 14:37:55 +0800 Subject: [PATCH] implementation and tests for unrolling non-first parameter list --- build.sc | 1 + unroll/plugin/src-2/UnrollPhaseScala2.scala | 53 ++++++++++++------- unroll/plugin/src-3/UnrollPhaseScala3.scala | 37 +++++++++---- .../secondParameterList/v1/src/Unrolled.scala | 5 ++ .../v1/test/src/UnrollTestMain.scala | 23 ++++++++ .../secondParameterList/v2/src/Unrolled.scala | 7 +++ .../v2/test/src/UnrollTestMain.scala | 25 +++++++++ .../secondParameterList/v3/src/Unrolled.scala | 10 ++++ .../src-js/UnrollTestPlatformSpecific.scala | 5 ++ .../src-jvm/UnrollTestPlatformSpecific.scala | 30 +++++++++++ .../UnrollTestPlatformSpecific.scala | 5 ++ .../v3/test/src/UnrollTestMain.scala | 28 ++++++++++ 12 files changed, 200 insertions(+), 29 deletions(-) create mode 100644 unroll/tests/secondParameterList/v1/src/Unrolled.scala create mode 100644 unroll/tests/secondParameterList/v1/test/src/UnrollTestMain.scala create mode 100644 unroll/tests/secondParameterList/v2/src/Unrolled.scala create mode 100644 unroll/tests/secondParameterList/v2/test/src/UnrollTestMain.scala create mode 100644 unroll/tests/secondParameterList/v3/src/Unrolled.scala create mode 100644 unroll/tests/secondParameterList/v3/test/src-js/UnrollTestPlatformSpecific.scala create mode 100644 unroll/tests/secondParameterList/v3/test/src-jvm/UnrollTestPlatformSpecific.scala create mode 100644 unroll/tests/secondParameterList/v3/test/src-native/UnrollTestPlatformSpecific.scala create mode 100644 unroll/tests/secondParameterList/v3/test/src/UnrollTestMain.scala diff --git a/build.sc b/build.sc index dc7a805..5091256 100644 --- a/build.sc +++ b/build.sc @@ -67,6 +67,7 @@ trait UnrollModule extends Cross.Module[String]{ "primaryConstructor", "secondaryConstructor", "caseclass", + "secondParameterList", // "abstractTraitMethod", // "abstractClassMethod" ) diff --git a/unroll/plugin/src-2/UnrollPhaseScala2.scala b/unroll/plugin/src-2/UnrollPhaseScala2.scala index 3051370..df1430b 100644 --- a/unroll/plugin/src-2/UnrollPhaseScala2.scala +++ b/unroll/plugin/src-2/UnrollPhaseScala2.scala @@ -60,16 +60,17 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT } } - val forwarderMethodType = defdef.symbol.tpe match{ + def forwarderMethodType0(t: Type, n: Int): Type = t match{ case MethodType(originalParams, result) => val forwarderParams = originalParams.map(symbolReplacements) - MethodType(forwarderParams.take(paramIndex), result) + if (n == annotatedParamListIndex) MethodType(forwarderParams.take(paramIndex), result) + else MethodType(forwarderParams, forwarderMethodType0(result, n + 1)) - case PolyType(tparams, MethodType(originalParams, result)) => - val forwarderParams = originalParams.map(symbolReplacements) - PolyType(tparams, MethodType(forwarderParams.take(paramIndex), result)) + case PolyType(tparams, res) => PolyType(tparams, forwarderMethodType0(res, n)) } + val forwarderMethodType = forwarderMethodType0(defdef.symbol.tpe, 0) + forwarderDefSymbol.setInfo(forwarderMethodType) val newParamLists = paramLists @@ -80,25 +81,15 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT } .map(_.map(copyValDef)) - val defaultCalls = Range(paramIndex, paramLists(annotatedParamListIndex).size).map{n => - val mangledName = defdef.name.toString + "$default$" + (n + 1) - - val defaultOwner = - if (defdef.symbol.isConstructor) implDef.symbol.companionModule - else implDef.symbol - val defaultMember = defaultOwner.tpe.member(TermName(scala.reflect.NameTransformer.encode(mangledName))) - Ident(mangledName).setSymbol(defaultMember).set(defaultMember) - } + val defaultOffset = paramLists + .iterator + .take(annotatedParamListIndex) + .map(_.size) + .sum val forwardedValueParams = newParamLists(annotatedParamListIndex).map(p => Ident(p.name).set(p.symbol)) - val forwarderThis = This(defdef.symbol.owner).set(defdef.symbol.owner) - - val forwarderInner = - if (defdef.symbol.isConstructor) Super(forwarderThis, typeNames.EMPTY).set(defdef.symbol.owner) - else forwarderThis - val nestedForwarderMethodTypes = Seq .iterate(defdef.symbol.tpe, defdef.vparamss.length + 1){ case MethodType(args, res) => res @@ -106,11 +97,33 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT } .drop(1) + val defaultCalls = Range(paramIndex, paramLists(annotatedParamListIndex).size).map{n => + val mangledName = defdef.name.toString + "$default$" + (defaultOffset + n + 1) + + val defaultOwner = + if (defdef.symbol.isConstructor) implDef.symbol.companionModule + else implDef.symbol + + val defaultMember = defaultOwner.tpe.member(TermName(scala.reflect.NameTransformer.encode(mangledName))) + newParamLists.take(annotatedParamListIndex).map(_.map( p => Ident(p.name).set(p.symbol))) + .zip(nestedForwarderMethodTypes) + .foldLeft(Ident(mangledName).setSymbol(defaultMember).set(defaultMember).set(defaultMember): Tree) { + case (lhs, (ps, methodType)) => Apply(fun = lhs, args = ps).setType(methodType) + } + + } + val forwarderCallArgs = newParamLists.zipWithIndex.map{case (v, i) => if (i == annotatedParamListIndex) forwardedValueParams ++ defaultCalls else v.map( p => Ident(p.name).set(p.symbol)) } + val forwarderThis = This(defdef.symbol.owner).set(defdef.symbol.owner) + + val forwarderInner = + if (defdef.symbol.isConstructor) Super(forwarderThis, typeNames.EMPTY).set(defdef.symbol.owner) + else forwarderThis + val forwarderCall0 = forwarderCallArgs .zip(nestedForwarderMethodTypes) .foldLeft(Select(forwarderInner, defdef.name).set(defdef.symbol): Tree){ diff --git a/unroll/plugin/src-3/UnrollPhaseScala3.scala b/unroll/plugin/src-3/UnrollPhaseScala3.scala index 0234304..a7836bd 100644 --- a/unroll/plugin/src-3/UnrollPhaseScala3.scala +++ b/unroll/plugin/src-3/UnrollPhaseScala3.scala @@ -40,6 +40,7 @@ class UnrollPhaseScala3() extends PluginPhase { ) } + def isTypeClause(p: ParamClause) = p.headOption.exists(_.isInstanceOf[TypeDef]) def generateSingleForwarder(defdef: DefDef, prevMethodType: Type, paramIndex: Int, @@ -48,14 +49,16 @@ class UnrollPhaseScala3() extends PluginPhase { isCaseApply: Boolean) (using Context) = { - def truncateMethodType0(tpe: Type): Type = { + def truncateMethodType0(tpe: Type, n: Int): Type = { tpe match{ - case pt: PolyType => PolyType(pt.paramNames, pt.paramInfos, truncateMethodType0(pt.resType)) - case mt: MethodType => MethodType(mt.paramInfos.take(paramIndex), mt.resType) + case pt: PolyType => PolyType(pt.paramNames, pt.paramInfos, truncateMethodType0(pt.resType, n + 1)) + case mt: MethodType => + if (n == annotatedParamListIndex) MethodType(mt.paramInfos.take(paramIndex), mt.resType) + else MethodType(mt.paramInfos, truncateMethodType0(mt.resType, n + 1)) } } - val truncatedMethodType = truncateMethodType0(prevMethodType) + val truncatedMethodType = truncateMethodType0(prevMethodType, 0) val forwarderDefSymbol = Symbols.newSymbol( defdef.symbol.owner, defdef.name, @@ -66,22 +69,38 @@ class UnrollPhaseScala3() extends PluginPhase { val newParamLists: List[ParamClause] = paramLists.zipWithIndex.map{ case (ps, i) => if (i == annotatedParamListIndex) ps.take(paramIndex).map(p => copyParam(p.asInstanceOf[ValDef], forwarderDefSymbol)) else { - if (ps.headOption.exists(_.isInstanceOf[TypeDef])) ps.map(p => copyParam2(p.asInstanceOf[TypeDef], forwarderDefSymbol)) + if (isTypeClause(ps)) ps.map(p => copyParam2(p.asInstanceOf[TypeDef], forwarderDefSymbol)) else ps.map(p => copyParam(p.asInstanceOf[ValDef], forwarderDefSymbol)) } } + val defaultOffset = paramLists + .iterator + .take(annotatedParamListIndex) + .filter(!isTypeClause(_)) + .map(_.size) + .sum + val defaultCalls = Range(paramIndex, paramLists(annotatedParamListIndex).size).map(n => - if (defdef.symbol.isConstructor) { + val inner = if (defdef.symbol.isConstructor) { ref(defdef.symbol.owner.companionModule) - .select(DefaultGetterName(defdef.name, n)) + .select(DefaultGetterName(defdef.name, n + defaultOffset)) } else if (isCaseApply) { ref(defdef.symbol.owner.companionModule) - .select(DefaultGetterName(termName(""), n)) + .select(DefaultGetterName(termName(""), n + defaultOffset)) } else { This(defdef.symbol.owner.asClass) - .select(DefaultGetterName(defdef.name, n)) + .select(DefaultGetterName(defdef.name, n + defaultOffset)) } + + newParamLists + .take(annotatedParamListIndex) + .map(_.map(p => ref(p.symbol))) + .foldLeft[Tree](inner){ + case (lhs: Tree, newParams) => + if (newParams.headOption.exists(_.isInstanceOf[TypeTree])) TypeApply(lhs, newParams) + else Apply(lhs, newParams) + } ) val forwarderInner: Tree = This(defdef.symbol.owner.asClass).select(defdef.symbol) diff --git a/unroll/tests/secondParameterList/v1/src/Unrolled.scala b/unroll/tests/secondParameterList/v1/src/Unrolled.scala new file mode 100644 index 0000000..f9ddac2 --- /dev/null +++ b/unroll/tests/secondParameterList/v1/src/Unrolled.scala @@ -0,0 +1,5 @@ +package unroll + +class Unrolled{ + def foo(f: String => String)(s: String) = f(s) +} diff --git a/unroll/tests/secondParameterList/v1/test/src/UnrollTestMain.scala b/unroll/tests/secondParameterList/v1/test/src/UnrollTestMain.scala new file mode 100644 index 0000000..fe9ac8f --- /dev/null +++ b/unroll/tests/secondParameterList/v1/test/src/UnrollTestMain.scala @@ -0,0 +1,23 @@ +package unroll + +import unroll.TestUtils.logAssertStartsWith + +object UnrollTestMain{ + def main(args: Array[String]): Unit = { + logAssertStartsWith(new Unrolled().foo(identity)("cow"), "cow") + } +} + + + + + + + + + + + + + + diff --git a/unroll/tests/secondParameterList/v2/src/Unrolled.scala b/unroll/tests/secondParameterList/v2/src/Unrolled.scala new file mode 100644 index 0000000..81d481d --- /dev/null +++ b/unroll/tests/secondParameterList/v2/src/Unrolled.scala @@ -0,0 +1,7 @@ +package unroll + +import scala.annotation.unroll + +class Unrolled{ + def foo(f: String => String)(s: String, @unroll n: Int = 1, b: Boolean = true) = f(s + n + b) +} diff --git a/unroll/tests/secondParameterList/v2/test/src/UnrollTestMain.scala b/unroll/tests/secondParameterList/v2/test/src/UnrollTestMain.scala new file mode 100644 index 0000000..4dad1be --- /dev/null +++ b/unroll/tests/secondParameterList/v2/test/src/UnrollTestMain.scala @@ -0,0 +1,25 @@ +package unroll + +import unroll.TestUtils.logAssertStartsWith + +object UnrollTestMain{ + def main(args: Array[String]): Unit = { + logAssertStartsWith(new Unrolled().foo(identity)("cow"), "cow1true") + logAssertStartsWith(new Unrolled().foo(identity)("cow", 2), "cow2true") + logAssertStartsWith(new Unrolled().foo(identity)("cow", 2, false), "cow2false") + } +} + + + + + + + + + + + + + + diff --git a/unroll/tests/secondParameterList/v3/src/Unrolled.scala b/unroll/tests/secondParameterList/v3/src/Unrolled.scala new file mode 100644 index 0000000..215148e --- /dev/null +++ b/unroll/tests/secondParameterList/v3/src/Unrolled.scala @@ -0,0 +1,10 @@ +package unroll + +import scala.annotation.unroll + +class Unrolled{ + def foo(f: String => String)(s: String, @unroll n: Int = 1, b: Boolean = true, @unroll l: Long = 0) = f(s + n + b + l) +} + + + diff --git a/unroll/tests/secondParameterList/v3/test/src-js/UnrollTestPlatformSpecific.scala b/unroll/tests/secondParameterList/v3/test/src-js/UnrollTestPlatformSpecific.scala new file mode 100644 index 0000000..84abada --- /dev/null +++ b/unroll/tests/secondParameterList/v3/test/src-js/UnrollTestPlatformSpecific.scala @@ -0,0 +1,5 @@ +package unroll + +object UnrollTestPlatformSpecific{ + def apply() = {} +} \ No newline at end of file diff --git a/unroll/tests/secondParameterList/v3/test/src-jvm/UnrollTestPlatformSpecific.scala b/unroll/tests/secondParameterList/v3/test/src-jvm/UnrollTestPlatformSpecific.scala new file mode 100644 index 0000000..692de0e --- /dev/null +++ b/unroll/tests/secondParameterList/v3/test/src-jvm/UnrollTestPlatformSpecific.scala @@ -0,0 +1,30 @@ +package unroll + +object UnrollTestPlatformSpecific{ + def apply() = { + val instance = new Unrolled() + val cls = classOf[Unrolled] + + assert( + cls.getMethod("foo", classOf[String => String], classOf[String]) + .invoke(instance, identity[String](_), "hello") == + "hello1true0" + ) + + assert( + scala.util.Try(cls.getMethod("foo", classOf[String => String], classOf[String], classOf[Int])).isFailure + ) + assert( + cls.getMethod("foo", classOf[String => String], classOf[String], classOf[Int], classOf[Boolean]) + .invoke(instance, identity[String](_), "hello", 2: Integer, java.lang.Boolean.FALSE) == + "hello2false0" + ) + assert( + cls.getMethod("foo", classOf[String => String], classOf[String], classOf[Int], classOf[Boolean], classOf[Long]) + .invoke(instance, identity[String](_), "hello", 2: Integer, java.lang.Boolean.FALSE, 3: Integer) == + "hello2false3" + ) + + cls.getMethods.filter(_.getName.contains("foo")).foreach(println) + } +} \ No newline at end of file diff --git a/unroll/tests/secondParameterList/v3/test/src-native/UnrollTestPlatformSpecific.scala b/unroll/tests/secondParameterList/v3/test/src-native/UnrollTestPlatformSpecific.scala new file mode 100644 index 0000000..84abada --- /dev/null +++ b/unroll/tests/secondParameterList/v3/test/src-native/UnrollTestPlatformSpecific.scala @@ -0,0 +1,5 @@ +package unroll + +object UnrollTestPlatformSpecific{ + def apply() = {} +} \ No newline at end of file diff --git a/unroll/tests/secondParameterList/v3/test/src/UnrollTestMain.scala b/unroll/tests/secondParameterList/v3/test/src/UnrollTestMain.scala new file mode 100644 index 0000000..12bf0e1 --- /dev/null +++ b/unroll/tests/secondParameterList/v3/test/src/UnrollTestMain.scala @@ -0,0 +1,28 @@ +package unroll + +import unroll.TestUtils.logAssertStartsWith + +object UnrollTestMain{ + def main(args: Array[String]): Unit = { + UnrollTestPlatformSpecific() + + logAssertStartsWith(new Unrolled().foo(identity)("cow"), "cow1true0") + logAssertStartsWith(new Unrolled().foo(identity)("cow", 2), "cow2true0") + logAssertStartsWith(new Unrolled().foo(identity)("cow", 2, false), "cow2false0") + logAssertStartsWith(new Unrolled().foo(identity)("cow", 2, false, 3), "cow2false3") + } +} + + + + + + + + + + + + + +