diff --git a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala index 87a471daa3d..e4f8ef21366 100644 --- a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala +++ b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala @@ -17,6 +17,7 @@ case class Analysis( providedSymbolsEncoded: Vector[String], importsByScope: HashMap[String, ArrayBuffer[AnImport]], consumedSymbolsByScope: HashMap[String, HashSet[String]], + scopes: Vector[String], ) case class ProvidedSymbol(sawClass: Boolean, sawTrait: Boolean, sawObject: Boolean) @@ -28,6 +29,7 @@ class SourceAnalysisTraverser extends Traverser { val providedSymbolsByScope = HashMap[String, HashMap[String, ProvidedSymbol]]() val importsByScope = HashMap[String, ArrayBuffer[AnImport]]() val consumedSymbolsByScope = HashMap[String, HashSet[String]]() + val scopes = HashSet[String]() // Extract a qualified name from a tree. def extractName(tree: Tree): String = { @@ -134,6 +136,11 @@ class SourceAnalysisTraverser extends Traverser { consumedSymbolsByScope(fullPackageName).add(name) } + def recordScope(name: String): Unit = { + val scopeName = (nameParts.toVector ++ Vector(name)).mkString(".") + scopes.add(scopeName) + } + def visitTemplate(templ: Template, name: String): Unit = { templ.inits.foreach(init => apply(init)) withNamePart(name, () => { @@ -144,7 +151,9 @@ class SourceAnalysisTraverser extends Traverser { override def apply(tree: Tree): Unit = tree match { case Pkg(ref, stats) => { - withNamePart(extractName(ref), () => super.apply(stats)) + val name = extractName(ref) + recordScope(name) + withNamePart(name, () => super.apply(stats)) } case Defn.Class(_mods, nameNode, _tparams, _ctor, templ) => { @@ -283,6 +292,7 @@ class SourceAnalysisTraverser extends Traverser { providedSymbolsEncoded = gatherEncodedProvidedSymbols(), importsByScope = importsByScope, consumedSymbolsByScope = consumedSymbolsByScope, + scopes = scopes.toVector, ) } } diff --git a/src/python/pants/backend/scala/dependency_inference/scala_parser.py b/src/python/pants/backend/scala/dependency_inference/scala_parser.py index bd78860d904..7ed556b5b37 100644 --- a/src/python/pants/backend/scala/dependency_inference/scala_parser.py +++ b/src/python/pants/backend/scala/dependency_inference/scala_parser.py @@ -115,6 +115,7 @@ class ScalaSourceDependencyAnalysis: provided_symbols_encoded: FrozenOrderedSet[str] imports_by_scope: FrozenDict[str, tuple[ScalaImport, ...]] consumed_symbols_by_scope: FrozenDict[str, FrozenOrderedSet[str]] + scopes: FrozenOrderedSet[str] def all_imports(self) -> Iterator[str]: # TODO: This might also be an import relative to its scope. @@ -180,6 +181,7 @@ def from_json_dict(cls, d: dict) -> ScalaSourceDependencyAnalysis: for key, values in d["consumedSymbolsByScope"].items() } ), + scopes=FrozenOrderedSet(d["scopes"]), ) def to_debug_json_dict(self) -> dict[str, Any]: @@ -193,6 +195,7 @@ def to_debug_json_dict(self) -> dict[str, Any]: "consumed_symbols_by_scope": { k: sorted(list(v)) for k, v in self.consumed_symbols_by_scope.items() }, + "scopes": list(self.scopes), } diff --git a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py index b9d2b85f636..150e1fd3861 100644 --- a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py +++ b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py @@ -305,3 +305,49 @@ def this(bar: SomeTypeInSecondaryConstructor) { "org.pantsbuild.example.TupleTypeArg1", "org.pantsbuild.example.TupleTypeArg2", } + + +def test_extract_package_scopes(rule_runner: RuleRunner) -> None: + rule_runner.write_files( + { + "BUILD": textwrap.dedent( + """ + scala_source( + name="source", + source="Source.scala", + ) + """ + ), + "Source.scala": textwrap.dedent( + """ + package outer + package more.than.one.part.at.once + package inner + """ + ), + } + ) + + target = rule_runner.get_target(address=Address("", target_name="source")) + + source_files = rule_runner.request( + SourceFiles, + [ + SourceFilesRequest( + (target.get(SourcesField),), + for_sources_types=(ScalaSourceField,), + enable_codegen=True, + ) + ], + ) + + analysis = rule_runner.request( + ScalaSourceDependencyAnalysis, + [source_files], + ) + + assert sorted(analysis.scopes) == [ + "outer", + "outer.more.than.one.part.at.once", + "outer.more.than.one.part.at.once.inner", + ]