From 878cc3e1919eb62c6948d492358b6ec2c0a17d39 Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Thu, 18 Apr 2024 14:13:04 +0200 Subject: [PATCH] Fix mapping and pickling of annotated types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Annotation.mapWith` maps an `Annotation` with a type map `tm`. Before actually applying `tm` to the annotation’s `tree`, it first checks if `tm` would result in any change by applying it to the types of the annotation’s arguments, and checking if the mapped types are different. This optimization had two problems: it didn’t include type parameters, and used `frozen_=:=` to compare types, which failed to detected some changes. This commit changes `Annotation.arguments` to also include type parameters, and, and changes `Annotation.MapWith` to use `==` to compare types instead of `frozen_=:=`. Furthermore, in case of changes, the symbol in the annotation's tree should be copied to make sure that the same symbol is not used for different trees. This commit achieves this by using a custom `TreeTypeMap` with an overridden `withMappedSyms` method where `Symbols.mapSymbols` is called with the argument `mapAlways = true`. Finally, positons of trees that appear inside `AnnotatedType` only were not pickled. This commit also fixes this. --- .../src/dotty/tools/dotc/ast/TreeInfo.scala | 4 +-- .../dotty/tools/dotc/core/Annotations.scala | 29 ++++++++-------- .../dotc/core/tasty/PositionPickler.scala | 4 +++ .../tools/dotc/core/tasty/TreePickler.scala | 7 ++++ .../tools/dotc/quoted/PickledQuotes.scala | 2 +- .../dotty/tools/dotc/transform/Pickler.scala | 2 +- tests/pos/annot-17939.scala | 7 ++++ tests/pos/annot-17939b.scala | 10 ++++++ tests/pos/annot-19846.scala | 8 +++++ tests/pos/annot-19846b.scala | 7 ++++ tests/pos/annot-5789.scala | 10 ++++++ tests/printing/annot-18064.check | 16 +++++++++ tests/printing/annot-18064.scala | 7 ++++ tests/printing/annot-19846b.check | 33 +++++++++++++++++++ tests/printing/annot-19846b.scala | 7 ++++ 15 files changed, 133 insertions(+), 20 deletions(-) create mode 100644 tests/pos/annot-17939.scala create mode 100644 tests/pos/annot-17939b.scala create mode 100644 tests/pos/annot-19846.scala create mode 100644 tests/pos/annot-19846b.scala create mode 100644 tests/pos/annot-5789.scala create mode 100644 tests/printing/annot-18064.check create mode 100644 tests/printing/annot-18064.scala create mode 100644 tests/printing/annot-19846b.check create mode 100644 tests/printing/annot-19846b.scala diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index a1bba544cc06..1ba86a66aae0 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -134,10 +134,10 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] => case _ => argss loop(tree, Nil) - /** All term arguments of an application in a single flattened list */ + /** All type and term arguments of an application in a single flattened list */ def allArguments(tree: Tree): List[Tree] = unsplice(tree) match { case Apply(fn, args) => allArguments(fn) ::: args - case TypeApply(fn, _) => allArguments(fn) + case TypeApply(fn, args) => allArguments(fn) ::: args case Block(_, expr) => allArguments(expr) case _ => Nil } diff --git a/compiler/src/dotty/tools/dotc/core/Annotations.scala b/compiler/src/dotty/tools/dotc/core/Annotations.scala index a5ef4c26eed1..9a3360e7b0ad 100644 --- a/compiler/src/dotty/tools/dotc/core/Annotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Annotations.scala @@ -3,8 +3,9 @@ package dotc package core import Symbols.*, Types.*, Contexts.*, Constants.*, Phases.* -import ast.tpd, tpd.* -import util.Spans.Span +import ast.{tpd, untpd, TreeTypeMap} +import tpd.* +import util.Spans.{Span, NoSpan} import printing.{Showable, Printer} import printing.Texts.Text @@ -30,7 +31,7 @@ object Annotations { def derivedAnnotation(tree: Tree)(using Context): Annotation = if (tree eq this.tree) this else Annotation(tree) - /** All arguments to this annotation in a single flat list */ + /** All type and term arguments to this annotation in a single flat list */ def arguments(using Context): List[Tree] = tpd.allArguments(tree) def argument(i: Int)(using Context): Option[Tree] = { @@ -54,19 +55,15 @@ object Annotations { * type, since ranges cannot be types of trees. */ def mapWith(tm: TypeMap)(using Context) = - val args = arguments - if args.isEmpty then this - else - val findDiff = new TreeAccumulator[Type]: - def apply(x: Type, tree: Tree)(using Context): Type = - if tm.isRange(x) then x - else - val tp1 = tm(tree.tpe) - foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree) - val diff = findDiff(NoType, args) - if tm.isRange(diff) then EmptyAnnotation - else if diff.exists then derivedAnnotation(tm.mapOver(tree)) - else this + val originalTree = tree + val mappedTree = tm.mapOver(originalTree) + if mappedTree neq originalTree then + val ttm = + new TreeTypeMap(typeMap = tm): + final override def withMappedSyms(syms: List[Symbol]): TreeTypeMap = + withMappedSyms(syms, mapSymbols(syms, this, mapAlways = true)) + derivedAnnotation(ttm.transform(tree)) + derivedAnnotation(tm.mapOver(tree)) /** Does this annotation refer to a parameter of `tl`? */ def refersToParamOf(tl: TermLambda)(using Context): Boolean = diff --git a/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala index 86076517021a..3d8080e72a29 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala @@ -33,6 +33,7 @@ object PositionPickler: pickler: TastyPickler, addrOfTree: TreeToAddr, treeAnnots: untpd.MemberDef => List[tpd.Tree], + typeAnnots: List[tpd.Tree], relativePathReference: String, source: SourceFile, roots: List[Tree], @@ -136,6 +137,9 @@ object PositionPickler: } for (root <- roots) traverse(root, NoSource) + + for annotTree <- typeAnnots do + traverse(annotTree, NoSource) end picklePositions end PositionPickler diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala index 0a8669292a74..1cede78c96f4 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala @@ -40,6 +40,10 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) { */ private val annotTrees = util.EqHashMap[untpd.MemberDef, mutable.ListBuffer[Tree]]() + /** A set of annotation trees appearing in annotated types. + */ + private val annotatedTypeTrees = mutable.ListBuffer[Tree]() + /** A map from member definitions to their doc comments, so that later * parallel comment pickling does not need to access symbols of trees (which * would involve accessing symbols of named types and possibly changing phases @@ -56,6 +60,8 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) { val ts = annotTrees.lookup(tree) if ts == null then Nil else ts.toList + def typeAnnots: List[Tree] = annotatedTypeTrees.toList + def docString(tree: untpd.MemberDef): Option[Comment] = Option(docStrings.lookup(tree)) @@ -266,6 +272,7 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) { case tpe: AnnotatedType => writeByte(ANNOTATEDtype) withLength { pickleType(tpe.parent, richTypes); pickleTree(tpe.annot.tree) } + annotatedTypeTrees += tpe.annot.tree case tpe: AndType => writeByte(ANDtype) withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) } diff --git a/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala b/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala index 8ebd1f6973f2..db40283076aa 100644 --- a/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala +++ b/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala @@ -224,7 +224,7 @@ object PickledQuotes { if tree.span.exists then val positionWarnings = new mutable.ListBuffer[Message]() val reference = ctx.settings.sourceroot.value - PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference, + PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference, ctx.compilationUnit.source, tree :: Nil, positionWarnings) positionWarnings.foreach(report.warning(_)) diff --git a/compiler/src/dotty/tools/dotc/transform/Pickler.scala b/compiler/src/dotty/tools/dotc/transform/Pickler.scala index 3a4212547d16..6841b9b686a8 100644 --- a/compiler/src/dotty/tools/dotc/transform/Pickler.scala +++ b/compiler/src/dotty/tools/dotc/transform/Pickler.scala @@ -143,7 +143,7 @@ class Pickler extends Phase { if tree.span.exists then val reference = ctx.settings.sourceroot.value PositionPickler.picklePositions( - pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference, + pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference, unit.source, tree :: Nil, positionWarnings, scratch.positionBuffer, scratch.pickledIndices) diff --git a/tests/pos/annot-17939.scala b/tests/pos/annot-17939.scala new file mode 100644 index 000000000000..2b3adf0ac1cc --- /dev/null +++ b/tests/pos/annot-17939.scala @@ -0,0 +1,7 @@ +class qualified[T](f: T => Boolean) extends annotation.StaticAnnotation + +class Box[T](val x: T) +class Box2(val x: Int) + +class A(a: String @qualified((x: Int) => Box(3).x == 3)) // crash +class A2(a2: String @qualified((x: Int) => Box2(3).x == 3)) // works diff --git a/tests/pos/annot-17939b.scala b/tests/pos/annot-17939b.scala new file mode 100644 index 000000000000..a48f4690d0b2 --- /dev/null +++ b/tests/pos/annot-17939b.scala @@ -0,0 +1,10 @@ +import scala.annotation.Annotation +class myRefined(f: ? => Boolean) extends Annotation + +def test(axes: Int) = true + +trait Tensor: + def mean(axes: Int): Int @myRefined(_ => test(axes)) + +class TensorImpl() extends Tensor: + def mean(axes: Int) = ??? diff --git a/tests/pos/annot-19846.scala b/tests/pos/annot-19846.scala new file mode 100644 index 000000000000..09c24a5cf3cf --- /dev/null +++ b/tests/pos/annot-19846.scala @@ -0,0 +1,8 @@ +class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation + +class EqualPair(val x: Int, val y: Int @qualified[Int](it => it == x)) + +@main def main = + val p = EqualPair(42, 42) + val y = p.y + println(42) diff --git a/tests/pos/annot-19846b.scala b/tests/pos/annot-19846b.scala new file mode 100644 index 000000000000..81f25065d980 --- /dev/null +++ b/tests/pos/annot-19846b.scala @@ -0,0 +1,7 @@ +class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation + +def f(x: Int): Int @qualified[Int](it => it == x) = ??? + +@main def main = + val z = f(42) + () diff --git a/tests/pos/annot-5789.scala b/tests/pos/annot-5789.scala new file mode 100644 index 000000000000..bdf4438c9d5d --- /dev/null +++ b/tests/pos/annot-5789.scala @@ -0,0 +1,10 @@ +class Annot[T] extends scala.annotation.Annotation + +class D[T](val f: Int@Annot[T]) + +object A{ + def main(a:Array[String]) = { + val c = new D[Int](1) + c.f + } +} diff --git a/tests/printing/annot-18064.check b/tests/printing/annot-18064.check new file mode 100644 index 000000000000..d93ddb95afee --- /dev/null +++ b/tests/printing/annot-18064.check @@ -0,0 +1,16 @@ +[[syntax trees at end of typer]] // tests/printing/annot-18064.scala +package { + class myAnnot[T >: Nothing <: Any]() extends annotation.Annotation() { + T + } + trait Tensor[T >: Nothing <: Any]() extends Object { + T + def add: Tensor[Tensor.this.T] @myAnnot[T] + } + class TensorImpl[A >: Nothing <: Any]() extends Object(), Tensor[ + TensorImpl.this.A] { + A + def add: Tensor[A] @myAnnot[A] = this + } +} + diff --git a/tests/printing/annot-18064.scala b/tests/printing/annot-18064.scala new file mode 100644 index 000000000000..95554fd3a1b7 --- /dev/null +++ b/tests/printing/annot-18064.scala @@ -0,0 +1,7 @@ +class myAnnot[T]() extends annotation.Annotation + +trait Tensor[T]: + def add: Tensor[T] @myAnnot[T]() + +class TensorImpl[A]() extends Tensor[A]: + def add /* : Tensor[A] @myAnnot[A] */ = this diff --git a/tests/printing/annot-19846b.check b/tests/printing/annot-19846b.check new file mode 100644 index 000000000000..3f63a46c4286 --- /dev/null +++ b/tests/printing/annot-19846b.check @@ -0,0 +1,33 @@ +[[syntax trees at end of typer]] // tests/printing/annot-19846b.scala +package { + class lambdaAnnot(g: () => Int) extends scala.annotation.Annotation(), + annotation.StaticAnnotation { + private[this] val g: () => Int + } + final lazy module val Test: Test = new Test() + final module class Test() extends Object() { this: Test.type => + val y: Int = ??? + val z: + Int @lambdaAnnot( + { + def $anonfun(): Int = Test.y + closure($anonfun) + } + ) + = f(Test.y) + } + final lazy module val annot-19846b$package: annot-19846b$package = + new annot-19846b$package() + final module class annot-19846b$package() extends Object() { + this: annot-19846b$package.type => + def f(x: Int): + Int @lambdaAnnot( + { + def $anonfun(): Int = x + closure($anonfun) + } + ) + = x + } +} + diff --git a/tests/printing/annot-19846b.scala b/tests/printing/annot-19846b.scala new file mode 100644 index 000000000000..951a3c8116ff --- /dev/null +++ b/tests/printing/annot-19846b.scala @@ -0,0 +1,7 @@ +class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation + +def f(x: Int): Int @lambdaAnnot(() => x) = x + +object Test: + val y: Int = ??? + val z /* : Int @lambdaAnnot(() => y) */ = f(y)