Skip to content

Commit

Permalink
[WIP] Fix mapping of annotations containing defs
Browse files Browse the repository at this point in the history
  • Loading branch information
mbovel committed Apr 16, 2024
1 parent d148973 commit 3bd7df2
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 5 deletions.
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
case _ => argss
loop(tree, Nil)

/** All term arguments of an application in a single flattened list */
/** All type and term arguments of an application in a single flattened list */
def allArguments(tree: Tree): List[Tree] = unsplice(tree) match {
case Apply(fn, args) => allArguments(fn) ::: args
case TypeApply(fn, _) => allArguments(fn)
case TypeApply(fn, args) => allArguments(fn) ::: args
case Block(_, expr) => allArguments(expr)
case _ => Nil
}
Expand Down
17 changes: 14 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import printing.{Showable, Printer}
import printing.Texts.Text

import scala.annotation.internal.sharable
import dotty.tools.dotc.ast.TreeTypeMap

object Annotations {

Expand All @@ -30,7 +31,7 @@ object Annotations {
def derivedAnnotation(tree: Tree)(using Context): Annotation =
if (tree eq this.tree) this else Annotation(tree)

/** All arguments to this annotation in a single flat list */
/** All type and term arguments to this annotation in a single flat list */
def arguments(using Context): List[Tree] = tpd.allArguments(tree)

def argument(i: Int)(using Context): Option[Tree] = {
Expand Down Expand Up @@ -62,10 +63,20 @@ object Annotations {
if tm.isRange(x) then x
else
val tp1 = tm(tree.tpe)
foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree)
foldOver(if tp1 == tree.tpe then x else tp1, tree)
val diff = findDiff(NoType, args)
if tm.isRange(diff) then EmptyAnnotation
else if diff.exists then derivedAnnotation(tm.mapOver(tree))
else if diff.exists then
val ttm = new TreeTypeMap(typeMap = tm):
/*
final override def transformDefs[TT <: Tree](trees: List[TT])(using Context): (TreeTypeMap, List[TT]) =
val syms = localSyms(trees)
val ttmap = withMappedSyms(syms, mapSymbols(syms, this, mapAlways = true))
(ttmap, ttmap.transformSub(trees))
*/
final override def withMappedSyms(syms: List[Symbol]): TreeTypeMap =
withMappedSyms(syms, mapSymbols(syms, this, mapAlways = true))
derivedAnnotation(ttm.transform(tree))
else this

/** Does this annotation refer to a parameter of `tl`? */
Expand Down
7 changes: 7 additions & 0 deletions tests/pos/annot-17939.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class qualified[T](f: T => Boolean) extends annotation.StaticAnnotation

class Box[T](val x: T)
class Box2(val x: Int)

class A(a: String @qualified((x: Int) => Box(3).x == 3)) // crash
class A2(a2: String @qualified((x: Int) => Box2(3).x == 3)) // works
8 changes: 8 additions & 0 deletions tests/pos/annot-19846.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation

class EqualPair(val x: Int, val y: Int @qualified[Int](it => it == x))

@main def main =
val p = EqualPair(42, 42)
val y = p.y
println(42)
16 changes: 16 additions & 0 deletions tests/printing/annot-18064.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[[syntax trees at end of typer]] // tests/printing/annot-18064.scala
package <empty> {
class myAnnot[T >: Nothing <: Any]() extends annotation.Annotation() {
T
}
trait Tensor[T >: Nothing <: Any]() extends Object {
T
def add: Tensor[Tensor.this.T] @myAnnot[T]
}
class TensorImpl[A >: Nothing <: Any]() extends Object(), Tensor[
TensorImpl.this.A] {
A
def add: Tensor[A] @myAnnot[A] = this
}
}

7 changes: 7 additions & 0 deletions tests/printing/annot-18064.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class myAnnot[T]() extends annotation.Annotation

trait Tensor[T]:
def add: Tensor[T] @myAnnot[T]()

class TensorImpl[A]() extends Tensor[A]:
def add /* : Tensor[A] @myAnnot[A] */ = this
33 changes: 33 additions & 0 deletions tests/printing/annot-19846b.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
[[syntax trees at end of typer]] // tests/printing/annot-19846b.scala
package <empty> {
class lambdaAnnot(g: () => Int) extends scala.annotation.Annotation(),
annotation.StaticAnnotation {
private[this] val g: () => Int
}
final lazy module val Test: Test = new Test()
final module class Test() extends Object() { this: Test.type =>
val y: Int = ???
val z:
Int @lambdaAnnot(
{
def $anonfun(): Int = Test.y
closure($anonfun)
}
)
= f(Test.y)
}
final lazy module val annot-19846b$package: annot-19846b$package =
new annot-19846b$package()
final module class annot-19846b$package() extends Object() {
this: annot-19846b$package.type =>
def f(x: Int):
Int @lambdaAnnot(
{
def $anonfun(): Int = x
closure($anonfun)
}
)
= x
}
}

7 changes: 7 additions & 0 deletions tests/printing/annot-19846b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation

def f(x: Int): Int @lambdaAnnot(() => x) = x

object Test:
val y: Int = ???
val z /* : Int @lambdaAnnot(() => y) */ = f(y)

0 comments on commit 3bd7df2

Please sign in to comment.