From a6c8818d29664b6d081dd0ad057f85302925271e Mon Sep 17 00:00:00 2001 From: ilist Date: Wed, 17 Nov 2021 05:38:00 -0800 Subject: [PATCH] Implement proto_common.create_proto_info in Starlark. There is some confusion in naming of the fields in ProtoInfo, for example: ``` ImmutableList directSources; // in Starlark .proto_sources (added) ImmutableList directProtoSources; // in Starlark .sources (was there before) ``` We need to keep old names in Starlark and create some new. The naming opportunity for the new Starlark names clashed with what was already used natively. PiperOrigin-RevId: 410502067 --- .../lib/rules/proto/BazelProtoCommon.java | 86 +++++++++---- .../lib/rules/proto/ProtoConfiguration.java | 10 ++ .../build/lib/rules/proto/ProtoInfo.java | 6 + .../build/lib/rules/proto/ProtoSource.java | 5 + .../lib/starlarkbuildapi/ProtoInfoApi.java | 3 + .../common/proto/proto_common.bzl | 118 ++++++++++++++++++ .../common/proto/proto_library.bzl | 5 +- 7 files changed, 206 insertions(+), 27 deletions(-) diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommon.java b/src/main/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommon.java index 4eb0511f70e59e..5f6d1481fb8dba 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommon.java +++ b/src/main/java/com/google/devtools/build/lib/rules/proto/BazelProtoCommon.java @@ -14,16 +14,14 @@ package com.google.devtools.build.lib.rules.proto; -import com.google.devtools.build.lib.analysis.starlark.StarlarkRuleContext; -import com.google.devtools.build.lib.cmdline.Label; -import com.google.devtools.build.lib.packages.BazelModuleContext; +import com.google.devtools.build.lib.actions.Artifact; +import com.google.devtools.build.lib.collect.nestedset.Depset; import com.google.devtools.build.lib.starlarkbuildapi.proto.ProtoCommonApi; -import javax.annotation.Nullable; +import com.google.devtools.build.lib.vfs.PathFragment; import net.starlark.java.annot.Param; import net.starlark.java.annot.StarlarkMethod; import net.starlark.java.eval.EvalException; -import net.starlark.java.eval.Module; -import net.starlark.java.eval.Starlark; +import net.starlark.java.eval.StarlarkList; import net.starlark.java.eval.StarlarkThread; /** Protocol buffers support for Starlark. */ @@ -33,26 +31,66 @@ public class BazelProtoCommon implements ProtoCommonApi { protected BazelProtoCommon() {} @StarlarkMethod( - name = "create_proto_info", + name = "ProtoSource", documented = false, - parameters = {@Param(name = "ctx", doc = "The rule context")}, - useStarlarkThread = true, - allowReturnNones = true) - @Nullable - public ProtoInfo createProtoInfo(StarlarkRuleContext ruleContext, StarlarkThread thread) + parameters = { + @Param(name = "source_file", doc = "The proto file."), + @Param(name = "original_source_file", doc = "Original proto file."), + @Param(name = "proto_path", doc = "Path to proto file."), + }, + useStarlarkThread = true) + public ProtoSource protoSource( + Artifact sourceFile, Artifact originalSourceFile, String sourceRoot, StarlarkThread thread) throws EvalException { - Label label = - ((BazelModuleContext) Module.ofInnermostEnclosingStarlarkFunction(thread).getClientData()) - .label(); - if (!label.getPackageIdentifier().getRepository().toString().equals("@_builtins")) { - throw Starlark.errorf("Rule in '%s' cannot use private API", label.getPackageName()); - } + ProtoCommon.checkPrivateStarlarkificationAllowlist(thread); + return new ProtoSource(sourceFile, originalSourceFile, PathFragment.create(sourceRoot)); + } - return ProtoCommon.createProtoInfo( - ruleContext.getRuleContext(), - ruleContext - .getRuleContext() - .getFragment(ProtoConfiguration.class) - .generatedProtosInVirtualImports()); + @StarlarkMethod( + name = "ProtoInfo", + documented = false, + parameters = { + @Param(name = "direct_sources", doc = "Direct sources."), + @Param(name = "proto_path", doc = "Proto path."), + @Param(name = "transitive_sources", doc = "Transitive sources."), + @Param(name = "transitive_proto_sources", doc = "Transitive proto sources."), + @Param(name = "transitive_proto_path", doc = "Transitive proto path."), + @Param(name = "check_deps_sources", doc = "Check deps sources."), + @Param(name = "direct_descriptor_set", doc = "Direct descriptor set."), + @Param(name = "transitive_descriptor_set", doc = "Transitive descriptor sets."), + @Param(name = "exported_sources", doc = "Exported sources"), + @Param(name = "strict_importable_sources", doc = "Strict importable sources."), + @Param(name = "public_import_protos", doc = "Public import protos."), + }, + useStarlarkThread = true) + @SuppressWarnings("unchecked") + public ProtoInfo protoInfo( + StarlarkList directSources, + String directProtoSourceRoot, + Depset transitiveProtoSources, + Depset transitiveSources, + Depset transitiveProtoSourceRoots, + Depset strictImportableProtoSourcesForDependents, + Artifact directDescriptorSet, + Depset transitiveDescriptorSets, + Depset exportedSources, + Depset strictImportableSources, + Depset publicImportSources, + StarlarkThread thread) + throws EvalException { + ProtoCommon.checkPrivateStarlarkificationAllowlist(thread); + return new ProtoInfo( + ((StarlarkList) directSources).getImmutableList(), + PathFragment.create(directProtoSourceRoot), + Depset.cast(transitiveSources, ProtoSource.class, "transitive_sources"), + Depset.cast(transitiveProtoSources, Artifact.class, "transitive_proto_sources"), + Depset.cast(transitiveProtoSourceRoots, String.class, "transitive_proto_path"), + Depset.cast( + strictImportableProtoSourcesForDependents, Artifact.class, "check_deps_sources"), + directDescriptorSet, + Depset.cast(transitiveDescriptorSets, Artifact.class, "transitive_descriptor_set"), + Depset.cast(exportedSources, ProtoSource.class, "exported_sources"), + Depset.cast(strictImportableSources, ProtoSource.class, "strict_importable_sources"), + Depset.cast(publicImportSources, ProtoSource.class, "public_import_protos")); } } diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConfiguration.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConfiguration.java index 8b508f737cccb1..893bc777f1166f 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConfiguration.java +++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoConfiguration.java @@ -285,6 +285,16 @@ public boolean strictPublicImports() { return options.experimentalJavaProtoAddAllowedPublicImports; } + @StarlarkMethod( + name = "generated_protos_in_virtual_imports", + useStarlarkThread = true, + documented = false) + public boolean generatedProtosInVirtualImportsForStarlark(StarlarkThread thread) + throws EvalException { + ProtoCommon.checkPrivateStarlarkificationAllowlist(thread); + return generatedProtosInVirtualImports(); + } + public boolean generatedProtosInVirtualImports() { return options.generatedProtosInVirtualImports; } diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java index 742c8447288e66..ddfb2d3582bd8b 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java +++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoInfo.java @@ -209,6 +209,12 @@ public NestedSet getTransitiveDescriptorSets() { return transitiveDescriptorSets; } + @Override + public Depset getExportedSourcesForStarlark(StarlarkThread thread) throws EvalException { + ProtoCommon.checkPrivateStarlarkificationAllowlist(thread); + return Depset.of(ProtoSource.TYPE, getExportedSources()); + } + /** * Returns a set of {@code .proto} sources that may be imported by {@code proto_library} targets * directly depending on this {@code ProtoInfo}. diff --git a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java index c48de52fa4eddb..8ceb66e8ed6e11 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java +++ b/src/main/java/com/google/devtools/build/lib/rules/proto/ProtoSource.java @@ -82,4 +82,9 @@ public PathFragment getImportPath() { public String toString() { return "ProtoSource('" + getImportPath() + "')"; } + + @Override + public boolean isImmutable() { + return true; + } } diff --git a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java index ea7214f3251db3..b892f9162e64ae 100644 --- a/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java +++ b/src/main/java/com/google/devtools/build/lib/starlarkbuildapi/ProtoInfoApi.java @@ -113,4 +113,7 @@ interface ProtoInfoProviderApi extends ProviderApi { @StarlarkMethod(name = "transitive_proto_sources", documented = false, useStarlarkThread = true) Depset getTransitiveSourcesForStarlark(StarlarkThread thread) throws EvalException; + + @StarlarkMethod(name = "exported_sources", documented = false, useStarlarkThread = true) + Depset getExportedSourcesForStarlark(StarlarkThread thread) throws EvalException; } diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl index be24c4c463be16..3c3b49b6b0a29b 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_common.bzl @@ -17,8 +17,125 @@ Definition of proto_common module. """ load(":common/proto/proto_semantics.bzl", "semantics") +load(":common/paths.bzl", "paths") ProtoInfo = _builtins.toplevel.ProtoInfo +native_proto_common = _builtins.toplevel.proto_common + +def _join(*path): + return "/".join([p for p in path if p != ""]) + +def _create_proto_info(ctx): + srcs = ctx.files.srcs + deps = [dep[ProtoInfo] for dep in ctx.attr.deps] + exports = [dep[ProtoInfo] for dep in ctx.attr.exports] + + import_prefix = ctx.attr.import_prefix if hasattr(ctx.attr, "import_prefix") else "" + if not paths.is_normalized(import_prefix): + fail("should be normalized (without uplevel references or '.' path segments)", attr = "import_prefix") + + strip_import_prefix = ctx.attr.strip_import_prefix + if not paths.is_normalized(strip_import_prefix): + fail("should be normalized (without uplevel references or '.' path segments)", attr = "strip_import_prefix") + if strip_import_prefix.startswith("/"): + strip_import_prefix = strip_import_prefix[1:] + elif strip_import_prefix != "DO_NOT_STRIP": # Relative to current package + strip_import_prefix = _join(ctx.label.package, strip_import_prefix) + else: + strip_import_prefix = "" + + has_generated_sources = False + if ctx.fragments.proto.generated_protos_in_virtual_imports(): + has_generated_sources = any([not src.is_source for src in srcs]) + + direct_sources = [] + if import_prefix != "" or strip_import_prefix != "" or has_generated_sources: + # Use virtual source roots + if paths.is_absolute(import_prefix): + fail("should be a relative path", attr = "import_prefix") + + virtual_imports = _join("_virtual_imports", ctx.label.name) + if ctx.label.workspace_name == "" or ctx.label.workspace_root.startswith(".."): # siblingRepositoryLayout + proto_path = _join(ctx.genfiles_dir.path, ctx.label.package, virtual_imports) + else: + proto_path = _join(ctx.genfiles_dir.path, ctx.label.workspace_root, ctx.label.package, virtual_imports) + + for src in srcs: + if ctx.label.workspace_name == "": + repository_relative_path = src.short_path + else: + repository_relative_path = paths.relativize(src.short_path, "../" + ctx.label.workspace_name) + + if not repository_relative_path.startswith(strip_import_prefix): + fail(".proto file '%s' is not under the specified strip prefix '%s'" % + (src.short_path, strip_import_prefix)) + import_path = repository_relative_path[len(strip_import_prefix):] + + virtual_src = ctx.actions.declare_file(_join(virtual_imports, import_prefix, import_path)) + ctx.actions.symlink( + output = virtual_src, + target_file = src, + progress_message = "Symlinking virtual .proto sources for %{label}", + ) + direct_sources.append(native_proto_common.ProtoSource(virtual_src, src, proto_path)) + + else: + # No virtual source roots + proto_path = "." + for src in srcs: + direct_sources.append(native_proto_common.ProtoSource(src, src, ctx.label.workspace_root + src.root.path)) + + # Construct ProtoInfo + transitive_proto_sources = depset( + direct = direct_sources, + transitive = [dep.transitive_proto_sources() for dep in deps], + order = "preorder", + ) + transitive_sources = depset( + direct = [src.source_file() for src in direct_sources], + transitive = [dep.transitive_sources for dep in deps], + order = "preorder", + ) + transitive_proto_path = depset( + direct = [proto_path], + transitive = [dep.transitive_proto_path for dep in deps], + ) + if direct_sources: + check_deps_sources = depset(direct = [src.source_file() for src in direct_sources]) + else: + check_deps_sources = depset(transitive = [dep.check_deps_sources for dep in deps]) + + direct_descriptor_set = ctx.actions.declare_file(ctx.label.name + "-descriptor-set.proto.bin") + transitive_descriptor_sets = depset( + direct = [direct_descriptor_set], + transitive = [dep.transitive_descriptor_sets for dep in deps], + ) + + # Layering checks. + if direct_sources: + exported_sources = depset(direct = direct_sources) + strict_importable_sources = depset( + direct = direct_sources, + transitive = [dep.exported_sources() for dep in deps], + ) + else: + exported_sources = depset(transitive = [dep.exported_sources() for dep in deps]) + strict_importable_sources = depset() + public_import_protos = depset(transitive = [export.exported_sources() for export in exports]) + + return native_proto_common.ProtoInfo( + direct_sources, + proto_path, + transitive_sources, + transitive_proto_sources, + transitive_proto_path, + check_deps_sources, + direct_descriptor_set, + transitive_descriptor_sets, + exported_sources, + strict_importable_sources, + public_import_protos, + ) def _write_descriptor_set(ctx, proto_info): output = proto_info.direct_descriptor_set @@ -83,5 +200,6 @@ def _ExpandImportArgsFn(proto_source): return "-I%s=%s" % (proto_source.import_path(), proto_source.source_file().path) proto_common = struct( + create_proto_info = _create_proto_info, write_descriptor_set = _write_descriptor_set, ) diff --git a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl index 731a7a61fcf04d..de73546af4df55 100644 --- a/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl +++ b/src/main/starlark/builtins_bzl/common/proto/proto_library.bzl @@ -20,7 +20,6 @@ load(":common/proto/proto_semantics.bzl", "semantics") load(":common/proto/proto_common.bzl", "proto_common") ProtoInfo = _builtins.toplevel.ProtoInfo -native_proto_common = _builtins.toplevel.proto_common def _check_srcs_package(target_package, srcs): """Makes sure the given srcs live in the given package.""" @@ -33,7 +32,7 @@ def _proto_library_impl(ctx): _check_srcs_package(ctx.label.package, ctx.attr.srcs) - proto_info = native_proto_common.create_proto_info(ctx) + proto_info = proto_common.create_proto_info(ctx) proto_common.write_descriptor_set(ctx, proto_info) @@ -64,7 +63,7 @@ proto_library = rule( "exports": attr.label_list( providers = [ProtoInfo], ), - "strip_import_prefix": attr.string(), + "strip_import_prefix": attr.string(default = "DO_NOT_STRIP"), "data": attr.label_list( allow_files = True, flags = ["SKIP_CONSTRAINTS_OVERRIDE"],