diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD b/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD index 8d6e346b48266c..702cec5b2fa023 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD +++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/BUILD @@ -20,5 +20,6 @@ java_library( "//third_party:jsr305", "//third_party:rxjava3", "//third_party/grpc-java:grpc-jar", + "@maven//:io_netty_netty_transport_native_unix_common", ], ) diff --git a/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java b/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java index bb2dddb9d3a3f0..8398e73cc30aa8 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java +++ b/src/main/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactory.java @@ -16,7 +16,12 @@ import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.netty.channel.unix.Errors; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.disposables.Disposable; import io.reactivex.rxjava3.functions.Action; @@ -118,7 +123,13 @@ public Single create() { .map( conn -> new SharedConnection( - conn, /* onClose= */ () -> tokenBucket.addToken(token)))); + conn, + /* onClose= */ () -> tokenBucket.addToken(token), + /* onFatalError= */ () -> { + synchronized (this) { + connectionAsyncSubject = null; + } + }))); } /** Returns current number of available connections. */ @@ -130,16 +141,33 @@ public int numAvailableConnections() { public static class SharedConnection implements Connection { private final Connection connection; private final Action onClose; + private final Runnable onFatalError; - public SharedConnection(Connection connection, Action onClose) { + public SharedConnection(Connection connection, Action onClose, Runnable onFatalError) { this.connection = connection; this.onClose = onClose; + this.onFatalError = onFatalError; } @Override public ClientCall call( MethodDescriptor method, CallOptions options) { - return connection.call(method, options); + return new SimpleForwardingClientCall<>(connection.call(method, options)) { + @Override + public void start(Listener responseListener, Metadata headers) { + super.start( + new SimpleForwardingClientCallListener<>(responseListener) { + @Override + public void onClose(Status status, Metadata trailers) { + if (isFatalError(status.getCause())) { + onFatalError.run(); + } + super.onClose(status, trailers); + } + }, + headers); + } + }; } @Override @@ -155,5 +183,11 @@ public void close() throws IOException { public Connection getUnderlyingConnection() { return connection; } + + private static boolean isFatalError(@Nullable Throwable t) { + // A low-level netty error indicates that the connection is fundamentally broken + // and should not be reused for retries. + return t instanceof Errors.NativeIoException; + } } } diff --git a/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD b/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD index f8b88c67127d7b..5dd8863eeb9bb1 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD +++ b/src/test/java/com/google/devtools/build/lib/remote/grpc/BUILD @@ -29,5 +29,7 @@ java_test( "//third_party:mockito", "//third_party:rxjava3", "//third_party:truth", + "@maven//:io_grpc_grpc_api", + "@maven//:io_netty_netty_transport_native_unix_common", ], ) diff --git a/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java b/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java index ad3f3c73f1999d..97840716602203 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/grpc/SharedConnectionFactoryTest.java @@ -14,6 +14,8 @@ package com.google.devtools.build.lib.remote.grpc; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -21,9 +23,19 @@ import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection; import com.google.devtools.build.lib.remote.util.RxNoGlobalErrorsRule; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.netty.channel.unix.Errors; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.observers.TestObserver; import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayDeque; +import java.util.List; +import java.util.Queue; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -145,6 +157,76 @@ public void create_concurrentCreate_shareConnections() throws InterruptedExcepti verify(connectionFactory, times(1)).create(); } + @Test + public void create_belowMaxConcurrency_fatalErrorPreventsReuse() throws IOException { + Connection brokenConnection = + new Connection() { + @Override + public ClientCall call( + MethodDescriptor method, CallOptions options) { + var call = mock(ClientCall.class); + doAnswer( + invocationOnMock -> { + ((ClientCall.Listener) invocationOnMock.getArgument(0)) + .onClose( + Status.fromThrowable(mock(Errors.NativeIoException.class)), + new Metadata()); + return null; + }) + .when(call) + .start(any(), any()); + return call; + } + + @Override + public void close() {} + }; + Connection newConnection = mock(Connection.class); + Queue connectionsToCreate = + new ArrayDeque<>(List.of(brokenConnection, newConnection)); + when(connectionFactory.create()) + .thenAnswer(invocation -> Single.just(connectionsToCreate.remove())); + + SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 2); + + TestObserver observer1 = factory.create().test(); + assertThat(factory.numAvailableConnections()).isEqualTo(1); + observer1 + .assertValue(conn -> conn.getUnderlyingConnection() == brokenConnection) + .assertComplete(); + + // Submit a call on the first connection and have it fail. + MethodDescriptor.Marshaller nullMarshaller = + new MethodDescriptor.Marshaller<>() { + @Override + public InputStream stream(byte[] bytes) { + return null; + } + + @Override + public byte[] parse(InputStream inputStream) { + return null; + } + }; + try (Connection firstConnection = observer1.values().getFirst()) { + var call = + firstConnection.call( + MethodDescriptor.newBuilder(nullMarshaller, nullMarshaller) + .setType(MethodDescriptor.MethodType.CLIENT_STREAMING) + .setFullMethodName("testMethod") + .build(), + CallOptions.DEFAULT); + ClientCall.Listener listener = new ClientCall.Listener<>() {}; + call.start(listener, new io.grpc.Metadata()); + listener.onClose(Status.fromThrowable(mock(Errors.NativeIoException.class)), new Metadata()); + } + + // Validate that the connection is not reused. + TestObserver observer2 = factory.create().test(); + observer2.assertValue(conn -> conn.getUnderlyingConnection() == newConnection).assertComplete(); + assertThat(factory.numAvailableConnections()).isEqualTo(1); + } + @Test public void create_afterLastFailed_success() { AtomicInteger times = new AtomicInteger(0);