From 09c621e4cf5b968f4c6cdf905ab142d5961f9ddc Mon Sep 17 00:00:00 2001 From: chiwang Date: Tue, 23 Mar 2021 19:10:30 -0700 Subject: [PATCH] Remote: Fix a race that AsyncTaskCache#Execution could be reused after disposed which results in CancellationException("disposed") propagated to downstream. Also added a test case to verify the fix. PiperOrigin-RevId: 364699975 --- .../build/lib/remote/util/AsyncTaskCache.java | 140 +++++++++++------- .../lib/remote/util/AsyncTaskCacheTest.java | 61 +++++--- 2 files changed, 133 insertions(+), 68 deletions(-) diff --git a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java index 7005364f284b8c..c3d2c25267f50d 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.concurrent.GuardedBy; @@ -54,7 +55,7 @@ public final class AsyncTaskCache { private final Map finished; @GuardedBy("lock") - private final Map inProgress; + private final Map> inProgress; public static AsyncTaskCache create() { return new AsyncTaskCache<>(); @@ -90,18 +91,22 @@ public Single executeIfNot(KeyT key, Single task) { return execute(key, task, false); } - private class Execution { + private static class Execution { + private final AtomicBoolean isTaskDisposed = new AtomicBoolean(false); private final Single task; private final AsyncSubject asyncSubject = AsyncSubject.create(); - private final AtomicInteger subscriberCount = new AtomicInteger(0); + private final AtomicInteger referenceCount = new AtomicInteger(0); private final AtomicReference taskDisposable = new AtomicReference<>(null); Execution(Single task) { this.task = task; } - public Single start() { - if (taskDisposable.get() == null) { + Single executeIfNot() { + checkState(!isTaskDisposed(), "disposed"); + + int subscribed = referenceCount.getAndIncrement(); + if (taskDisposable.get() == null && subscribed == 0) { task.subscribe( new SingleObserver() { @Override @@ -122,27 +127,39 @@ public void onError(@NonNull Throwable e) { }); } - return Single.fromObservable(asyncSubject) - .doOnSubscribe(d -> subscriberCount.incrementAndGet()) - .doOnDispose( - () -> { - if (subscriberCount.decrementAndGet() == 0) { - Disposable d = taskDisposable.get(); - if (d != null) { - d.dispose(); - } - asyncSubject.onError(new CancellationException("disposed")); - } - }); + return Single.fromObservable(asyncSubject); + } + + boolean isTaskTerminated() { + return asyncSubject.hasComplete() || asyncSubject.hasThrowable(); + } + + boolean isTaskDisposed() { + return isTaskDisposed.get(); + } + + void tryDisposeTask() { + checkState(!isTaskDisposed(), "disposed"); + checkState(!isTaskTerminated(), "terminated"); + + if (referenceCount.decrementAndGet() == 0) { + isTaskDisposed.set(true); + asyncSubject.onError(new CancellationException("disposed")); + + Disposable d = taskDisposable.get(); + if (d != null) { + d.dispose(); + } + } } } /** Returns count of subscribers for a task. */ public int getSubscriberCount(KeyT key) { synchronized (lock) { - Execution execution = inProgress.get(key); + Execution execution = inProgress.get(key); if (execution != null) { - return execution.subscriberCount.get(); + return execution.referenceCount.get(); } } @@ -158,49 +175,72 @@ public int getSubscriberCount(KeyT key) { * error if any. */ public Single execute(KeyT key, Single task, boolean force) { - return Single.defer( - () -> { + return Single.create( + emitter -> { synchronized (lock) { if (!force && finished.containsKey(key)) { - return Single.just(finished.get(key)); + emitter.onSuccess(finished.get(key)); + return; } finished.remove(key); - Execution execution = + Execution execution = inProgress.computeIfAbsent( key, - missingKey -> { + ignoredKey -> { AtomicInteger subscribeTimes = new AtomicInteger(0); - return new Execution( + return new Execution<>( Single.defer( - () -> { - int times = subscribeTimes.incrementAndGet(); - checkState(times == 1, "Subscribed more than once to the task"); - return task; - }) - .doOnSuccess( - value -> { - synchronized (lock) { - finished.put(key, value); - inProgress.remove(key); - } - }) - .doOnError( - error -> { - synchronized (lock) { - inProgress.remove(key); - } - }) - .doOnDispose( - () -> { - synchronized (lock) { - inProgress.remove(key); - } - })); + () -> { + int times = subscribeTimes.incrementAndGet(); + checkState(times == 1, "Subscribed more than once to the task"); + return task; + })); }); - return execution.start(); + execution + .executeIfNot() + .subscribe( + new SingleObserver() { + @Override + public void onSubscribe(@NonNull Disposable d) { + emitter.setCancellable( + () -> { + d.dispose(); + + if (!execution.isTaskTerminated()) { + synchronized (lock) { + execution.tryDisposeTask(); + if (execution.isTaskDisposed()) { + inProgress.remove(key); + } + } + } + }); + } + + @Override + public void onSuccess(@NonNull ValueT value) { + synchronized (lock) { + finished.put(key, value); + inProgress.remove(key); + } + + emitter.onSuccess(value); + } + + @Override + public void onError(@NonNull Throwable e) { + synchronized (lock) { + inProgress.remove(key); + } + + if (!emitter.isDisposed()) { + emitter.onError(e); + } + } + }); } }); } diff --git a/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java index 9e2d641f1edc1b..8e3ee28b2cee30 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java @@ -14,16 +14,15 @@ package com.google.devtools.build.lib.remote.util; import static com.google.common.truth.Truth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.core.SingleEmitter; import io.reactivex.rxjava3.observers.TestObserver; -import io.reactivex.rxjava3.plugins.RxJavaPlugins; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import org.junit.After; -import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -32,21 +31,7 @@ @RunWith(JUnit4.class) public class AsyncTaskCacheTest { - private final AtomicReference rxGlobalThrowable = new AtomicReference<>(null); - - @Before - public void setUp() { - RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set); - } - - @After - public void tearDown() throws Throwable { - // Make sure rxjava didn't receive global errors - Throwable t = rxGlobalThrowable.getAndSet(null); - if (t != null) { - throw t; - } - } + @Rule public final RxNoGlobalErrorsRule rxNoGlobalErrorsRule = new RxNoGlobalErrorsRule(); @Test public void execute_noSubscription_noExecution() { @@ -296,4 +281,44 @@ public void execute_multipleTasks_completeOne() { assertThat(cache.getInProgressTasks()).containsExactly("key2"); assertThat(cache.getFinishedTasks()).containsExactly("key1"); } + + @Test + public void execute_executeAndDisposeLoop_noErrors() throws InterruptedException { + AsyncTaskCache cache = AsyncTaskCache.create(); + Single task = Single.timer(1, SECONDS); + AtomicReference error = new AtomicReference<>(null); + AtomicInteger errorCount = new AtomicInteger(0); + int executionCount = 100; + Runnable runnable = + () -> { + try { + for (int i = 0; i < executionCount; ++i) { + TestObserver observer = cache.execute("key1", task, true).test(); + observer.assertNoErrors(); + observer.dispose(); + } + } catch (Throwable t) { + errorCount.incrementAndGet(); + error.set(t); + } + }; + int threadCount = 10; + Thread[] threads = new Thread[threadCount]; + for (int i = 0; i < threadCount; ++i) { + Thread thread = new Thread(runnable); + threads[i] = thread; + } + + for (Thread thread : threads) { + thread.start(); + } + for (Thread thread : threads) { + thread.join(); + } + + if (error.get() != null) { + throw new IllegalStateException( + String.format("%s/%s errors", errorCount.get(), threadCount), error.get()); + } + } }