Skip to content

Commit

Permalink
Support optional toolchains with find_cpp_toolchain
Browse files Browse the repository at this point in the history
Rules that want to use optional toolchains but still support builds that
don't use C++ toolchain resolution should point their `_cc_toolchain`
attribute to the new
`@bazel_tools//tools/cpp:optional_current_cc_toolchain` target.
  • Loading branch information
fmeum committed Jan 25, 2023
1 parent 942ddda commit ff81792
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
import com.google.devtools.build.lib.analysis.RunfilesProvider;
import com.google.devtools.build.lib.analysis.TemplateVariableInfo;
import com.google.devtools.build.lib.analysis.platform.ToolchainInfo;
import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
import com.google.devtools.build.lib.collect.nestedset.Order;
import com.google.devtools.build.lib.packages.RuleClass;
import com.google.devtools.build.lib.packages.Type;
import javax.annotation.Nullable;

/** Implementation of the {@code cc_toolchain_alias} rule. */
Expand All @@ -50,6 +53,7 @@ public RuleClass build(RuleClass.Builder builder, RuleDefinitionEnvironment env)
.add(
attr(CcToolchain.CC_TOOLCHAIN_TYPE_ATTRIBUTE_NAME, NODEP_LABEL)
.value(CppRuleClasses.ccToolchainTypeAttribute(env)))
.add(attr("mandatory", Type.BOOLEAN).value(true))
.requiresConfigurationFragments(PlatformConfiguration.class)
.addToolchainTypes(CppRuleClasses.ccToolchainTypeRequirement(env))
.build();
Expand All @@ -72,7 +76,14 @@ public ConfiguredTarget create(RuleContext ruleContext)
throws InterruptedException, RuleErrorException, ActionConflictException {

CcToolchainProvider ccToolchainProvider =
CppHelper.getToolchainUsingDefaultCcToolchainAttribute(ruleContext);
CppHelper.getToolchainUsingDefaultCcToolchainAttribute(ruleContext,
ruleContext.attributes().get("mandatory", Type.BOOLEAN));
if (ccToolchainProvider == null) {
return new RuleConfiguredTargetBuilder(ruleContext)
.addProvider(RunfilesProvider.simple(Runfiles.EMPTY))
.setFilesToBuild(NestedSetBuilder.emptySet(Order.STABLE_ORDER))
.build();
}

TemplateVariableInfo templateVariableInfo =
CcToolchain.createMakeVariableProvider(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,19 @@ public static NestedSet<Artifact> getGcovFilesIfNeeded(
@Nullable
public static CcToolchainProvider getToolchainUsingDefaultCcToolchainAttribute(
RuleContext ruleContext) throws RuleErrorException {
return getToolchainUsingDefaultCcToolchainAttribute(ruleContext, true);
}

@Nullable
public static CcToolchainProvider getToolchainUsingDefaultCcToolchainAttribute(
RuleContext ruleContext, boolean mandatory) throws RuleErrorException {
if (ruleContext.attributes().has(CcToolchain.CC_TOOLCHAIN_DEFAULT_ATTRIBUTE_NAME)) {
return getToolchain(ruleContext, CcToolchain.CC_TOOLCHAIN_DEFAULT_ATTRIBUTE_NAME);
return getToolchain(ruleContext, CcToolchain.CC_TOOLCHAIN_DEFAULT_ATTRIBUTE_NAME, mandatory);
} else if (ruleContext
.attributes()
.has(CcToolchain.CC_TOOLCHAIN_DEFAULT_ATTRIBUTE_NAME_FOR_STARLARK)) {
return getToolchain(
ruleContext, CcToolchain.CC_TOOLCHAIN_DEFAULT_ATTRIBUTE_NAME_FOR_STARLARK);
ruleContext, CcToolchain.CC_TOOLCHAIN_DEFAULT_ATTRIBUTE_NAME_FOR_STARLARK, mandatory);
}
return null;
}
Expand Down Expand Up @@ -314,10 +320,12 @@ public static NestedSet<Artifact> getDefaultCcToolchainStaticRuntimeInputs(

/**
* Makes sure that the given info collection has a {@link CcToolchainProvider} (gives an error
* otherwise), and returns a reference to that {@link CcToolchainProvider}.
* otherwise), and returns a reference to that {@link CcToolchainProvider}. May return
* {@code null} if {@code mandatory} is {@code false}.
*/
public static CcToolchainProvider getToolchain(RuleContext ruleContext, String toolchainAttribute)
throws RuleErrorException {
@Nullable
public static CcToolchainProvider getToolchain(RuleContext ruleContext, String toolchainAttribute,
boolean mandatory) throws RuleErrorException {
if (!ruleContext.isAttrDefined(toolchainAttribute, LABEL)) {
throw ruleContext.throwWithRuleError(
String.format(
Expand All @@ -326,7 +334,7 @@ public static CcToolchainProvider getToolchain(RuleContext ruleContext, String t
toolchainAttribute));
}
TransitiveInfoCollection dep = ruleContext.getPrerequisite(toolchainAttribute);
return getToolchain(ruleContext, dep);
return getToolchain(ruleContext, dep, mandatory);
}

/**
Expand All @@ -336,15 +344,21 @@ public static CcToolchainProvider getToolchain(RuleContext ruleContext, String t
*/
public static CcToolchainProvider getToolchain(
RuleContext ruleContext, TransitiveInfoCollection dep) throws RuleErrorException {
return getToolchain(ruleContext, dep, true);
}

public static CcToolchainProvider getToolchain(
RuleContext ruleContext, TransitiveInfoCollection dep, boolean mandatory)
throws RuleErrorException {
Label toolchainType = getToolchainTypeFromRuleClass(ruleContext);
return getToolchain(ruleContext, dep, toolchainType);
return getToolchain(ruleContext, dep, toolchainType, mandatory);
}

public static CcToolchainProvider getToolchain(
RuleContext ruleContext, TransitiveInfoCollection dep, Label toolchainType)
RuleContext ruleContext, TransitiveInfoCollection dep, Label toolchainType, boolean mandatory)
throws RuleErrorException {
if (toolchainType != null && useToolchainResolution(ruleContext)) {
return getToolchainFromPlatformConstraints(ruleContext, toolchainType);
return getToolchainFromPlatformConstraints(ruleContext, toolchainType, mandatory);
}
return getToolchainFromLegacyToolchain(ruleContext, dep);
}
Expand All @@ -366,12 +380,15 @@ public static Label getToolchainTypeFromRuleClass(RuleContext ruleContext) {
}

private static CcToolchainProvider getToolchainFromPlatformConstraints(
RuleContext ruleContext, Label toolchainType) throws RuleErrorException {
RuleContext ruleContext, Label toolchainType, boolean mandatory) throws RuleErrorException {
ToolchainInfo toolchainInfo = ruleContext.getToolchainInfo(toolchainType);
if (toolchainInfo == null) {
throw ruleContext.throwWithRuleError(
"Unable to find a CC toolchain using toolchain resolution. Did you properly set"
+ " --platforms?");
if (mandatory) {
throw ruleContext.throwWithRuleError(
"Unable to find a CC toolchain using toolchain resolution. Did you properly set"
+ " --platforms?");
}
return null;
}
try {
return (CcToolchainProvider) toolchainInfo.getValue("cc");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ private ConfiguredAspect buildAspect(

try {
CcToolchainProvider ccToolchain =
CppHelper.getToolchain(ruleContext, ":j2objc_cc_toolchain");
CppHelper.getToolchain(ruleContext, ":j2objc_cc_toolchain", /*mandatory*/true);
ImmutableList<String> extraCompileArgs =
j2objcCompileWithARC(ruleContext)
? ImmutableList.of("-fno-strict-overflow", "-fobjc-arc-exceptions")
Expand Down
12 changes: 9 additions & 3 deletions src/main/starlark/builtins_bzl/common/cc/cc_helper.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -200,18 +200,22 @@ def _get_dynamic_library_for_runtime_or_none(library, linking_statically):

_CPP_TOOLCHAIN_TYPE = "@" + objc_semantics.get_repo() + "//tools/cpp:toolchain_type"

def _find_cpp_toolchain(ctx):
def _find_cpp_toolchain(ctx, *, mandatory = True):
"""
Finds the c++ toolchain.
If the c++ toolchain is in use, returns it. Otherwise, returns a c++
toolchain derived from legacy toolchain selection.
toolchain derived from legacy toolchain selection, constructed from
the CppConfiguration.
Args:
ctx: The rule context for which to find a toolchain.
mandatory: If this is set to False, this function will return None rather
than fail if no toolchain is found.
Returns:
A CcToolchainProvider.
A CcToolchainProvider, or None if the c++ toolchain is declared as
optional, mandatory is False and no toolchain has been found.
"""

# Check the incompatible flag for toolchain resolution.
Expand All @@ -220,6 +224,8 @@ def _find_cpp_toolchain(ctx):
fail("In order to use find_cpp_toolchain, you must include the '//tools/cpp:toolchain_type' in the toolchains argument to your rule.")
toolchain_info = ctx.toolchains[_CPP_TOOLCHAIN_TYPE]
if toolchain_info == None:
if not mandatory:
return None
# No cpp toolchain was found, so report an error.
fail("Unable to find a CC toolchain using toolchain resolution. Target: %s, Platform: %s, Exec platform: %s" %
(ctx.label, ctx.fragments.platform.platform, ctx.fragments.platform.host_platform))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ TemplateVariableInfo = _builtins.toplevel.platform_common.TemplateVariableInfo
ToolchainInfo = _builtins.toplevel.platform_common.ToolchainInfo

def _impl(ctx):
cc_toolchain = cc_helper.find_cpp_toolchain(ctx)
cc_toolchain = cc_helper.find_cpp_toolchain(ctx, mandatory = ctx.attr.mandatory)
if not cc_toolchain:
return []
make_variables = cc_toolchain.get_additional_make_variables()
cc_provider_make_variables = cc_helper.get_toolchain_global_make_variables(cc_toolchain)
template_variable_info = TemplateVariableInfo(make_variables | cc_provider_make_variables)
Expand All @@ -43,6 +45,7 @@ cc_toolchain_alias = rule(
implementation = _impl,
fragments = ["cpp", "platform"],
attrs = {
"mandatory": attr.bool(default = True),
"_cc_toolchain": attr.label(default = configuration_field(fragment = "cpp", name = "cc_toolchain"), providers = [CcToolchainInfo]),
"_cc_toolchain_type": attr.label(default = "@" + semantics.get_repo() + "//tools/cpp:toolchain_type"),
},
Expand Down
67 changes: 67 additions & 0 deletions src/test/shell/bazel/cc_integration_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1910,4 +1910,71 @@ EOF
expect_log "runtime error: index 10 out of bounds"
}

function setup_find_optional_cpp_toolchain() {
mkdir -p pkg

cat > pkg/BUILD <<'EOF'
load(":rules.bzl", "my_rule")
my_rule(
name = "my_rule",
)
platform(
name = "exotic_platform",
constraint_values = [
"@platforms//cpu:wasm64",
"@platforms//os:windows",
],
)
EOF

cat > pkg/rules.bzl <<'EOF'
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain")
def _my_rule_impl(ctx):
out = ctx.actions.declare_file(ctx.attr.name)
toolchain = find_cpp_toolchain(ctx, mandatory = False)
if toolchain:
ctx.actions.write(out, "Toolchain found")
else:
ctx.actions.write(out, "Toolchain not found")
return [DefaultInfo(files = depset([out]))]
my_rule = rule(
implementation = _my_rule_impl,
attrs = {
"_cc_toolchain": attr.label(
default = "@bazel_tools//tools/cpp:optional_current_cc_toolchain",
),
},
toolchains = use_cpp_toolchain(mandatory = False),
)
EOF
}

function test_find_optional_cpp_toolchain_present_without_toolchain_resolution() {
setup_find_optional_cpp_toolchain

bazel build //pkg:my_rule --noincompatible_enable_cc_toolchain_resolution \
&> "$TEST_log" || fail "Build failed"
assert_contains "Toolchain found" bazel-bin/pkg/my_rule
}

function test_find_optional_cpp_toolchain_present_with_toolchain_resolution() {
setup_find_optional_cpp_toolchain

bazel build //pkg:my_rule --incompatible_enable_cc_toolchain_resolution \
&> "$TEST_log" || fail "Build failed"
assert_contains "Toolchain found" bazel-bin/pkg/my_rule
}

function test_find_optional_cpp_toolchain_not_present_with_toolchain_resolution() {
setup_find_optional_cpp_toolchain

bazel build //pkg:my_rule --incompatible_enable_cc_toolchain_resolution \
--platforms=//pkg:exotic_platform &> "$TEST_log" || fail "Build failed"
assert_contains "Toolchain not found" bazel-bin/pkg/my_rule
}

run_suite "cc_integration_test"
7 changes: 2 additions & 5 deletions tools/cpp/BUILD.tools
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,9 @@ constraint_value(

cc_toolchain_alias(name = "current_cc_toolchain")

# In future versions of Bazel, this target will not fail if no C++ toolchain is
# available. Instead, it will not advertise the cc_common.CcToolchainInfo
# provider.
alias(
cc_toolchain_alias(
name = "optional_current_cc_toolchain",
actual = ":current_cc_toolchain",
mandatory = False,
)

cc_host_toolchain_alias(name = "current_cc_host_toolchain")
Expand Down
12 changes: 8 additions & 4 deletions tools/cpp/toolchain_utils.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ def find_cpp_toolchain(ctx, *, mandatory = True):
Args:
ctx: The rule context for which to find a toolchain.
mandatory: This is currently a no-op. In future releases of Bazel, if this
is set to False, this function will return None rather than fail if no
toolchain is found.
mandatory: If this is set to False, this function will return None rather
than fail if no toolchain is found. To use this parameter, the calling
rule should have a `_cc_toolchain` label attribute with default
`@bazel_tools//tools/cpp:optional_current_cc_toolchain`.
Returns:
A CcToolchainProvider.
A CcToolchainProvider, or None if the c++ toolchain is declared as
optional, mandatory is False and no toolchain has been found.
"""

# Check the incompatible flag for toolchain resolution.
Expand All @@ -43,6 +45,8 @@ def find_cpp_toolchain(ctx, *, mandatory = True):
fail("In order to use find_cpp_toolchain, you must include the '%s' in the toolchains argument to your rule." % CPP_TOOLCHAIN_TYPE)
toolchain_info = ctx.toolchains[CPP_TOOLCHAIN_TYPE]
if toolchain_info == None:
if not mandatory:
return None
# No cpp toolchain was found, so report an error.
fail("Unable to find a CC toolchain using toolchain resolution. Target: %s, Platform: %s, Exec platform: %s" %
(ctx.label, ctx.fragments.platform.platform, ctx.fragments.platform.host_platform))
Expand Down

0 comments on commit ff81792

Please sign in to comment.