diff --git a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala index 4080c26fd84..f4b6a561a90 100644 --- a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala +++ b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala @@ -163,6 +163,13 @@ class SourceAnalysisTraverser extends Traverser { }) } + def visitMods(mods: List[Mod]): Unit = { + mods.foreach({ + case Mod.Annot(init) => apply(init) // rely on `Init` extraction in main parsing match code + case _ => () + }) + } + override def apply(tree: Tree): Unit = tree match { case Pkg(ref, stats) => { val name = extractName(ref) @@ -170,36 +177,42 @@ class SourceAnalysisTraverser extends Traverser { withNamePart(name, () => super.apply(stats)) } - case Pkg.Object(_mods, nameNode, templ) => { + case Pkg.Object(mods, nameNode, templ) => { + visitMods(mods) val name = extractName(nameNode) recordScope(name) visitTemplate(templ, name) } - case Defn.Class(_mods, nameNode, _tparams, _ctor, templ) => { + case Defn.Class(mods, nameNode, _tparams, _ctor, templ) => { + visitMods(mods) val name = extractName(nameNode) recordProvidedName(name, sawClass = true) visitTemplate(templ, name) } - case Defn.Trait(_mods, nameNode, _tparams, _ctor, templ) => { + case Defn.Trait(mods, nameNode, _tparams, _ctor, templ) => { + visitMods(mods) val name = extractName(nameNode) recordProvidedName(name, sawTrait = true) visitTemplate(templ, name) } - case Defn.Object(_mods, nameNode, templ) => { + case Defn.Object(mods, nameNode, templ) => { + visitMods(mods) val name = extractName(nameNode) recordProvidedName(name, sawObject = true) visitTemplate(templ, name) } - case Defn.Type(_mods, nameNode, _tparams, _body) => { + case Defn.Type(mods, nameNode, _tparams, _body) => { + visitMods(mods) val name = extractName(nameNode) recordProvidedName(name) } - case Defn.Val(_mods, pats, decltpe, rhs) => { + case Defn.Val(mods, pats, decltpe, rhs) => { + visitMods(mods) pats.headOption.foreach(pat => { val name = extractName(pat) recordProvidedName(name) @@ -210,7 +223,8 @@ class SourceAnalysisTraverser extends Traverser { super.apply(rhs) } - case Defn.Var(_mods, pats, decltpe, rhs) => { + case Defn.Var(mods, pats, decltpe, rhs) => { + visitMods(mods) pats.headOption.foreach(pat => { val name = extractName(pat) recordProvidedName(name) @@ -221,7 +235,8 @@ class SourceAnalysisTraverser extends Traverser { super.apply(rhs) } - case Defn.Def(_mods, nameNode, _tparams, params, decltpe, body) => { + case Defn.Def(mods, nameNode, _tparams, params, decltpe, body) => { + visitMods(mods) val name = extractName(nameNode) recordProvidedName(name) @@ -263,19 +278,22 @@ class SourceAnalysisTraverser extends Traverser { extractNamesFromTypeTree(tpe).foreach(recordConsumedSymbol(_)) } - case Term.Param(_mods, _name, decltpe, _default) => { + case Term.Param(mods, _name, decltpe, _default) => { + visitMods(mods) decltpe.foreach(tpe => { extractNamesFromTypeTree(tpe).foreach(recordConsumedSymbol(_)) }) } - case Ctor.Primary(_mods, _name, params_list) => { + case Ctor.Primary(mods, _name, params_list) => { + visitMods(mods) params_list.foreach(params => { params.foreach(param => apply(param)) }) } - case Ctor.Secondary(_mods, _name, params_list, init, stats) => { + case Ctor.Secondary(mods, _name, params_list, init, stats) => { + visitMods(mods) params_list.foreach(params => { params.foreach(param => apply(param)) }) diff --git a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py index ea048187e5c..b8929147731 100644 --- a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py +++ b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py @@ -388,3 +388,39 @@ def test_package_object(rule_runner: RuleRunner) -> None: ), ) assert sorted(analysis.provided_symbols) == ["foo.bar.Hello"] + + +def test_extract_annotations(rule_runner: RuleRunner) -> None: + analysis = _analyze( + rule_runner, + textwrap.dedent( + """ + package foo + + @objectAnnotation("hello") + object Object { + @deprecated + def foo(arg: String @argAnnotation("foo")): Unit = {} + } + + @classAnnotation("world") + class Class { + @valAnnotation val foo = 3 + @varAnnotation var bar = 4 + } + + @traitAnnotation + trait Trait {} + """ + ), + ) + assert sorted(analysis.fully_qualified_consumed_symbols()) == [ + "foo.String", + "foo.Unit", + "foo.classAnnotation", + "foo.deprecated", + "foo.objectAnnotation", + "foo.traitAnnotation", + "foo.valAnnotation", + "foo.varAnnotation", + ]