Skip to content

Commit

Permalink
Experimental support for remote persistent workers
Browse files Browse the repository at this point in the history
Add a new --experimental_remote_mark_tool_inputs flag. When it is set, Bazel
tags tool inputs when executing actions remotely and adds a tool-inputs key
to the platform proto sent as part of the remote execution request.

This allows a remote execution system to implement persistent workers, i.e.,
to keep worker processes around and reuse them for subsequent actions. In a
trivial example, this improves build performance by ~3x.

We use "persistentWorkerKey" for the platform property, with the value being
a hash of the tool inputs, and "bazel_tool_input" as the node property name,
with an empty string as value (this is just a boolean tag).

Implements bazelbuild#10091.

Change-Id: Iccb36081fee399855be7c487c2d4091cb36f8df3
  • Loading branch information
ulfjack committed Feb 16, 2021
1 parent 415c6cc commit a630961
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import build.bazel.remote.execution.v2.Platform;
import build.bazel.remote.execution.v2.Platform.Property;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.Ordering;
import com.google.devtools.build.lib.actions.Spawn;
Expand All @@ -28,8 +30,11 @@
import com.google.protobuf.TextFormat;
import com.google.protobuf.TextFormat.ParseException;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -63,29 +68,42 @@ public static Platform buildPlatformProto(Map<String, String> executionPropertie
/**
 * Convenience overload of {@code getPlatformProto(Spawn, RemoteOptions, Map)} that passes an
 * empty map of additional properties.
 *
 * @param spawn the spawn to build the platform proto for
 * @param remoteOptions the remote options, may be null
 * @return whatever the three-argument overload returns for an empty additional-properties map
 * @throws UserExecException propagated from the three-argument overload
 */
@Nullable
public static Platform getPlatformProto(Spawn spawn, @Nullable RemoteOptions remoteOptions)
    throws UserExecException {
  Map<String, String> noAdditionalProperties = ImmutableMap.of();
  return getPlatformProto(spawn, remoteOptions, noAdditionalProperties);
}

@Nullable
public static Platform getPlatformProto(
Spawn spawn,
@Nullable RemoteOptions remoteOptions,
Map<String, String> additionalProperties)
throws UserExecException {
SortedMap<String, String> defaultExecProperties =
remoteOptions != null
? remoteOptions.getRemoteDefaultExecProperties()
: ImmutableSortedMap.of();

if (spawn.getExecutionPlatform() == null
&& spawn.getCombinedExecProperties().isEmpty()
&& defaultExecProperties.isEmpty()) {
&& defaultExecProperties.isEmpty()
&& additionalProperties.isEmpty()) {
return null;
}

Platform.Builder platformBuilder = Platform.newBuilder();

Map<String, String> properties;
if (!spawn.getCombinedExecProperties().isEmpty()) {
for (Map.Entry<String, String> entry : spawn.getCombinedExecProperties().entrySet()) {
platformBuilder.addPropertiesBuilder().setName(entry.getKey()).setValue(entry.getValue());
}
properties = spawn.getCombinedExecProperties();
} else if (spawn.getExecutionPlatform() != null
&& !Strings.isNullOrEmpty(spawn.getExecutionPlatform().remoteExecutionProperties())) {
// Try and get the platform info from the execution properties.
properties = new HashMap<>();
// Try and get the platform info from the execution properties. This is pretty inefficient; it
// would be better to store the parsed properties instead of the String text proto.
try {
Platform.Builder platformBuilder = Platform.newBuilder();
TextFormat.getParser()
.merge(spawn.getExecutionPlatform().remoteExecutionProperties(), platformBuilder);
for (Property property : platformBuilder.getPropertiesList()) {
properties.put(property.getName(), property.getValue());
}
} catch (ParseException e) {
String message =
String.format(
Expand All @@ -95,12 +113,23 @@ public static Platform getPlatformProto(Spawn spawn, @Nullable RemoteOptions rem
e, createFailureDetail(message, Code.INVALID_REMOTE_EXECUTION_PROPERTIES));
}
} else {
for (Map.Entry<String, String> property : defaultExecProperties.entrySet()) {
platformBuilder.addProperties(
Property.newBuilder().setName(property.getKey()).setValue(property.getValue()).build());
properties = defaultExecProperties;
}

if (!additionalProperties.isEmpty()) {
if (properties.isEmpty()) {
properties = additionalProperties;
} else {
// Merge the two maps.
properties = new HashMap<>(properties);
properties.putAll(additionalProperties);
}
}

Platform.Builder platformBuilder = Platform.newBuilder();
for (Map.Entry<String, String> entry : properties.entrySet()) {
platformBuilder.addPropertiesBuilder().setName(entry.getKey()).setValue(entry.getValue());
}
sortPlatformProperties(platformBuilder);
return platformBuilder.build();
}
Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/google/devtools/build/lib/remote/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ java_library(
"//src/main/java/com/google/devtools/build/lib/sandbox",
"//src/main/java/com/google/devtools/build/lib/skyframe:mutable_supplier",
"//src/main/java/com/google/devtools/build/lib/skyframe:tree_artifact_value",
"//src/main/java/com/google/devtools/build/lib/util",
"//src/main/java/com/google/devtools/build/lib/util:abrupt_exit_exception",
"//src/main/java/com/google/devtools/build/lib/util:detailed_exit_code",
"//src/main/java/com/google/devtools/build/lib/util:exit_code",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,6 @@

package com.google.devtools.build.lib.remote;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.devtools.build.lib.profiler.ProfilerTask.REMOTE_DOWNLOAD;
import static com.google.devtools.build.lib.profiler.ProfilerTask.REMOTE_EXECUTION;
import static com.google.devtools.build.lib.profiler.ProfilerTask.UPLOAD_TIME;
import static com.google.devtools.build.lib.remote.util.Utils.createSpawnResult;
import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture;
import static com.google.devtools.build.lib.remote.util.Utils.getInMemoryOutputPath;
import static com.google.devtools.build.lib.remote.util.Utils.hasFilesToDownload;
import static com.google.devtools.build.lib.remote.util.Utils.shouldDownloadAllSpawnOutputs;

import build.bazel.remote.execution.v2.Action;
import build.bazel.remote.execution.v2.ActionResult;
import build.bazel.remote.execution.v2.Command;
Expand All @@ -46,16 +36,20 @@
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import com.google.devtools.build.lib.actions.ActionInput;
import com.google.devtools.build.lib.actions.ActionInputHelper;
import com.google.devtools.build.lib.actions.Artifact;
import com.google.devtools.build.lib.actions.CommandLines.ParamFileActionInput;
import com.google.devtools.build.lib.actions.ExecException;
import com.google.devtools.build.lib.actions.FileArtifactValue;
import com.google.devtools.build.lib.actions.MetadataProvider;
import com.google.devtools.build.lib.actions.Spawn;
import com.google.devtools.build.lib.actions.SpawnMetrics;
import com.google.devtools.build.lib.actions.SpawnResult;
import com.google.devtools.build.lib.actions.SpawnResult.Status;
import com.google.devtools.build.lib.actions.Spawns;
import com.google.devtools.build.lib.actions.cache.VirtualActionInput;
import com.google.devtools.build.lib.analysis.platform.PlatformUtils;
import com.google.devtools.build.lib.collect.nestedset.NestedSet;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.events.Event;
import com.google.devtools.build.lib.events.Reporter;
Expand Down Expand Up @@ -83,6 +77,7 @@
import com.google.devtools.build.lib.server.FailureDetails;
import com.google.devtools.build.lib.server.FailureDetails.FailureDetail;
import com.google.devtools.build.lib.util.ExitCode;
import com.google.devtools.build.lib.util.Fingerprint;
import com.google.devtools.build.lib.util.io.FileOutErr;
import com.google.devtools.build.lib.vfs.Path;
import com.google.devtools.build.lib.vfs.PathFragment;
Expand All @@ -93,6 +88,8 @@
import com.google.protobuf.util.Timestamps;
import io.grpc.Context;
import io.grpc.Status.Code;

import javax.annotation.Nullable;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
Expand All @@ -101,11 +98,22 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.devtools.build.lib.profiler.ProfilerTask.REMOTE_DOWNLOAD;
import static com.google.devtools.build.lib.profiler.ProfilerTask.REMOTE_EXECUTION;
import static com.google.devtools.build.lib.profiler.ProfilerTask.UPLOAD_TIME;
import static com.google.devtools.build.lib.remote.util.Utils.createSpawnResult;
import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture;
import static com.google.devtools.build.lib.remote.util.Utils.getInMemoryOutputPath;
import static com.google.devtools.build.lib.remote.util.Utils.hasFilesToDownload;
import static com.google.devtools.build.lib.remote.util.Utils.shouldDownloadAllSpawnOutputs;

/** A client for the remote execution service. */
@ThreadSafe
Expand Down Expand Up @@ -216,16 +224,33 @@ public SpawnResult exec(Spawn spawn, SpawnExecutionContext context)
context.report(ProgressStatus.SCHEDULING, getName());
RemoteOutputsMode remoteOutputsMode = remoteOptions.remoteOutputsMode;
SortedMap<PathFragment, ActionInput> inputMap = context.getInputMapping();
ToolSignature toolSignature =
remoteOptions.markToolInputs && Spawns.supportsWorkers(spawn)
? collectAndHashToolInputs(
spawn.getToolFiles(), context.getArtifactExpander(), context.getMetadataProvider())
: null;
final MerkleTree merkleTree =
MerkleTree.build(inputMap, context.getMetadataProvider(), execRoot, digestUtil);
MerkleTree.build(
inputMap,
toolSignature == null ? ImmutableSet.of() : toolSignature.toolInputs,
context.getMetadataProvider(),
execRoot,
digestUtil);
SpawnMetrics.Builder spawnMetrics =
SpawnMetrics.Builder.forRemoteExec()
.setInputBytes(merkleTree.getInputBytes())
.setInputFiles(merkleTree.getInputFiles());
maybeWriteParamFilesLocally(spawn);

// Get the remote platform properties.
Platform platform = PlatformUtils.getPlatformProto(spawn, remoteOptions);
Platform platform;
if (toolSignature != null) {
platform =
PlatformUtils.getPlatformProto(
spawn, remoteOptions, ImmutableMap.of("persistentWorkerKey", toolSignature.key));
} else {
platform = PlatformUtils.getPlatformProto(spawn, remoteOptions);
}

Command command =
buildCommand(
Expand Down Expand Up @@ -454,6 +479,28 @@ static void spawnMetricsAccounting(
}
}

/**
 * Expands the given tool inputs and computes a {@link ToolSignature} for them.
 *
 * <p>The signature key is a fingerprint fed, for every expanded input in the order returned by
 * {@code ActionInputHelper.expandArtifacts}, with the input's exec path followed by the digest
 * taken from its metadata. The signature also carries the set of exec paths of all expanded
 * inputs.
 *
 * @param toolInputs the (unexpanded) tool inputs of the spawn
 * @param artifactExpander used to expand the tool inputs
 * @param metadataProvider used to look up the metadata (and thereby the digest) of each input
 * @return {@code null} if {@code toolInputs} is empty, otherwise the computed signature
 * @throws IOException if one of the calls made here fails with it
 */
@Nullable
private ToolSignature collectAndHashToolInputs(
    NestedSet<? extends ActionInput> toolInputs,
    Artifact.ArtifactExpander artifactExpander,
    MetadataProvider metadataProvider)
    throws IOException {
  if (toolInputs.isEmpty()) {
    return null;
  }
  // Expand exactly once and reuse the result for both the fingerprint and the path set, so that
  // the two are guaranteed to describe the same inputs and the expansion is not repeated.
  List<ActionInput> toolInputsList =
      ActionInputHelper.expandArtifacts(toolInputs, artifactExpander);
  Fingerprint fingerprint = new Fingerprint();
  for (ActionInput input : toolInputsList) {
    fingerprint.addPath(input.getExecPath());
    // NOTE(review): this dereferences the metadata without a null check - confirm that
    // getMetadata() cannot return null for tool inputs.
    FileArtifactValue md = metadataProvider.getMetadata(input);
    fingerprint.addBytes(md.getDigest());
  }
  return new ToolSignature(
      fingerprint.hexDigestAndReset(),
      toolInputsList.stream().map(ActionInput::getExecPath).collect(Collectors.toSet()));
}

private SpawnResult downloadAndFinalizeSpawnResult(
RemoteActionExecutionContext remoteActionExecutionContext,
String actionId,
Expand Down Expand Up @@ -841,4 +888,18 @@ private static RemoteRetrier createExecuteRetrier(
return new ExecuteRetrier(
options.remoteMaxRetryAttempts, retryService, Retrier.ALLOW_ALL_CALLS);
}

/**
 * Immutable holder pairing the key computed for a spawn's tool inputs (a hash over the inputs
 * and their digests) with the set of relative paths of those tool inputs.
 */
private static final class ToolSignature {
  /** Hash of the tool inputs and their digests. */
  private final String key;

  /** Relative paths of all tool inputs. */
  private final Set<PathFragment> toolInputs;

  private ToolSignature(String signatureKey, Set<PathFragment> toolInputPaths) {
    key = signatureKey;
    toolInputs = toolInputPaths;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,32 @@ static class FileNode extends Node {
private final ByteString data;
private final Digest digest;
private final boolean isExecutable;
private final boolean toolInput;

/** Creates a path-backed file node that is not marked as a tool input. */
FileNode(String pathSegment, Path path, Digest digest, boolean isExecutable) {
  this(pathSegment, path, digest, isExecutable, /* toolInput= */ false);
}

/**
 * Creates a path-backed file node; {@code data} is left {@code null}.
 *
 * @param pathSegment passed to the superclass constructor
 * @param path must not be null
 * @param digest must not be null
 * @param isExecutable stored as this node's executable flag
 * @param toolInput stored as this node's tool-input flag
 */
FileNode(String pathSegment, Path path, Digest digest, boolean isExecutable, boolean toolInput) {
  super(pathSegment);
  // Null checks stay in their original order: path first, then digest.
  this.path = Preconditions.checkNotNull(path, "path");
  this.digest = Preconditions.checkNotNull(digest, "digest");
  this.data = null;
  this.toolInput = toolInput;
  this.isExecutable = isExecutable;
}

/** Creates a data-backed file node that is not marked as a tool input. */
FileNode(String pathSegment, ByteString data, Digest digest) {
  this(pathSegment, data, digest, /* toolInput= */ false);
}

/**
 * Creates a data-backed file node; {@code path} is left {@code null} and the node is never
 * executable.
 *
 * @param pathSegment passed to the superclass constructor
 * @param data must not be null
 * @param digest must not be null
 * @param toolInput stored as this node's tool-input flag
 */
FileNode(String pathSegment, ByteString data, Digest digest, boolean toolInput) {
  super(pathSegment);
  // Null checks stay in their original order: data first, then digest.
  this.data = Preconditions.checkNotNull(data, "data");
  this.digest = Preconditions.checkNotNull(digest, "digest");
  this.path = null;
  this.toolInput = toolInput;
  this.isExecutable = false;
}

Digest getDigest() {
Expand All @@ -106,9 +117,13 @@ public boolean isExecutable() {
return isExecutable;
}

/** Returns the tool-input flag this node was constructed with. */
boolean isToolInput() {
  return this.toolInput;
}

/** Returns a hash combining the superclass hash with this node's fields. */
// NOTE(review): two return statements appear below - presumably the removed and the added line
// of the diff rendered without +/- markers; only the second one includes toolInput.
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), path, data, digest, isExecutable);
return Objects.hash(super.hashCode(), path, data, digest, isExecutable, toolInput);
}

@Override
Expand All @@ -119,7 +134,8 @@ public boolean equals(Object o) {
&& Objects.equals(path, other.path)
&& Objects.equals(data, other.data)
&& Objects.equals(digest, other.digest)
&& isExecutable == other.isExecutable;
&& isExecutable == other.isExecutable
&& toolInput == other.toolInput;
}
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import build.bazel.remote.execution.v2.Digest;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.devtools.build.lib.actions.ActionInput;
import com.google.devtools.build.lib.actions.ActionInputHelper;
import com.google.devtools.build.lib.actions.FileArtifactValue;
Expand All @@ -32,6 +33,7 @@
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;

Expand Down Expand Up @@ -59,8 +61,18 @@ static DirectoryTree fromActionInputs(
Path execRoot,
DigestUtil digestUtil)
throws IOException {
return fromActionInputs(inputs, ImmutableSet.of(), metadataProvider, execRoot, digestUtil);
}

/**
 * Builds a {@link DirectoryTree} from the given action inputs by delegating to
 * {@code buildFromActionInputs} and wrapping the populated map together with the returned file
 * count.
 *
 * @param inputs the action inputs, keyed by path
 * @param toolInputs paths forwarded to {@code buildFromActionInputs}; presumably the paths
 *     whose file nodes get flagged as tool inputs - confirm against that method
 * @param metadataProvider forwarded to {@code buildFromActionInputs}
 * @param execRoot forwarded to {@code buildFromActionInputs}
 * @param digestUtil forwarded to {@code buildFromActionInputs}
 * @throws IOException propagated from {@code buildFromActionInputs}
 */
static DirectoryTree fromActionInputs(
SortedMap<PathFragment, ActionInput> inputs,
Set<PathFragment> toolInputs,
MetadataProvider metadataProvider,
Path execRoot,
DigestUtil digestUtil)
throws IOException {
// Populated in place by buildFromActionInputs.
Map<PathFragment, DirectoryNode> tree = new HashMap<>();
// NOTE(review): two numFiles declarations appear below - presumably the removed and the added
// line of the diff rendered without +/- markers; only the second one forwards toolInputs.
int numFiles = buildFromActionInputs(inputs, metadataProvider, execRoot, digestUtil, tree);
int numFiles = buildFromActionInputs(inputs, toolInputs, metadataProvider, execRoot, digestUtil, tree);
return new DirectoryTree(tree, numFiles);
}

Expand Down Expand Up @@ -116,6 +128,7 @@ private static int buildFromPaths(
*/
private static int buildFromActionInputs(
SortedMap<PathFragment, ActionInput> inputs,
Set<PathFragment> toolInputs,
MetadataProvider metadataProvider,
Path execRoot,
DigestUtil digestUtil,
Expand Down Expand Up @@ -149,14 +162,14 @@ private static int buildFromActionInputs(
isExecutable = inputPath.isExecutable();
}

currDir.addChild(new FileNode(path.getBaseName(), inputPath, d, isExecutable));
currDir.addChild(new FileNode(path.getBaseName(), inputPath, d, isExecutable, toolInputs.contains(path)));
return 1;

case DIRECTORY:
SortedMap<PathFragment, ActionInput> directoryInputs =
explodeDirectory(path, execRoot);
return buildFromActionInputs(
directoryInputs, metadataProvider, execRoot, digestUtil, tree);
directoryInputs, toolInputs, metadataProvider, execRoot, digestUtil, tree);

case SYMLINK:
throw new IllegalStateException(
Expand Down
Loading

0 comments on commit a630961

Please sign in to comment.