diff --git a/src/main/java/org/springframework/data/redis/cache/DefaultRedisCacheWriter.java b/src/main/java/org/springframework/data/redis/cache/DefaultRedisCacheWriter.java index b91733f444..205487fa1b 100644 --- a/src/main/java/org/springframework/data/redis/cache/DefaultRedisCacheWriter.java +++ b/src/main/java/org/springframework/data/redis/cache/DefaultRedisCacheWriter.java @@ -32,6 +32,7 @@ import org.springframework.data.redis.connection.ReactiveStringCommands; import org.springframework.data.redis.connection.RedisConnection; import org.springframework.data.redis.connection.RedisConnectionFactory; +import org.springframework.data.redis.connection.RedisStringCommands; import org.springframework.data.redis.connection.RedisStringCommands.SetOption; import org.springframework.data.redis.core.types.Expiration; import org.springframework.data.redis.util.ByteUtils; @@ -219,8 +220,10 @@ public byte[] putIfAbsent(String name, byte[] key, byte[] value, @Nullable Durat return execute(name, connection -> { + boolean wasLocked = false; if (isLockingCacheWriter()) { doLock(name, key, value, connection); + wasLocked = true; } try { @@ -242,7 +245,7 @@ public byte[] putIfAbsent(String name, byte[] key, byte[] value, @Nullable Durat return connection.stringCommands().get(key); } finally { - if (isLockingCacheWriter()) { + if (isLockingCacheWriter() && wasLocked) { doUnlock(name, connection); } } @@ -319,15 +322,17 @@ void lock(String name) { execute(name, connection -> doLock(name, name, null, connection)); } - @Nullable - protected Boolean doLock(String name, Object contextualKey, @Nullable Object contextualValue, - RedisConnection connection) { + boolean doLock(String name, Object contextualKey, @Nullable Object contextualValue, RedisConnection connection) { + RedisStringCommands commands = connection.stringCommands(); Expiration expiration = Expiration.from(this.lockTtl.getTimeToLive(contextualKey, contextualValue)); + byte[] cacheLockKey = createCacheLockKey(name); - while (!ObjectUtils.nullSafeEquals(connection.stringCommands().set(createCacheLockKey(name), new byte[0], expiration, SetOption.SET_IF_ABSENT),true)) { + while (!ObjectUtils.nullSafeEquals(commands.set(cacheLockKey, new byte[0], expiration, SetOption.SET_IF_ABSENT), + true)) { checkAndPotentiallyWaitUntilUnlocked(name, connection); } + return true; } @@ -341,7 +346,7 @@ void unlock(String name) { } @Nullable - private Long doUnlock(String name, RedisConnection connection) { + Long doUnlock(String name, RedisConnection connection) { return connection.keyCommands().del(createCacheLockKey(name)); } @@ -489,8 +494,7 @@ public CompletableFuture retrieve(String name, byte[] key, @Nullable Dur Mono cacheLockCheck = isLockingCacheWriter() ? waitForLock(connection, name) : Mono.empty(); ReactiveStringCommands stringCommands = connection.stringCommands(); - Mono get = shouldExpireWithin(ttl) - ? stringCommands.getEx(wrappedKey, Expiration.from(ttl)) + Mono get = shouldExpireWithin(ttl) ? stringCommands.getEx(wrappedKey, Expiration.from(ttl)) : stringCommands.get(wrappedKey); return cacheLockCheck.then(get).map(ByteUtils::getBytes).toFuture(); @@ -502,8 +506,7 @@ public CompletableFuture store(String name, byte[] key, byte[] value, @Nul return doWithConnection(connection -> { - Mono mono = isLockingCacheWriter() - ? doStoreWithLocking(name, key, value, ttl, connection) + Mono mono = isLockingCacheWriter() ? doStoreWithLocking(name, key, value, ttl, connection) : doStore(key, value, ttl, connection); return mono.then().toFuture(); @@ -531,7 +534,6 @@ private Mono doStore(byte[] cacheKey, byte[] value, @Nullable Duration } } - private Mono doLock(String name, Object contextualKey, @Nullable Object contextualValue, ReactiveRedisConnection connection) { diff --git a/src/test/java/org/springframework/data/redis/cache/DefaultRedisCacheWriterTests.java b/src/test/java/org/springframework/data/redis/cache/DefaultRedisCacheWriterTests.java index c44a9e85af..1f07ea6110 100644 --- a/src/test/java/org/springframework/data/redis/cache/DefaultRedisCacheWriterTests.java +++ b/src/test/java/org/springframework/data/redis/cache/DefaultRedisCacheWriterTests.java @@ -21,14 +21,19 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.ArrayList; import java.util.Collection; +import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import org.junit.jupiter.api.BeforeEach; + import org.springframework.data.redis.connection.RedisConnection; import org.springframework.data.redis.connection.RedisConnectionFactory; import org.springframework.data.redis.connection.RedisStringCommands.SetOption; @@ -421,43 +426,73 @@ void noOpStatisticsCollectorReturnsEmptyStatsInstance() { assertThat(stats.getPuts()).isZero(); } - @ParameterizedRedisTest + @ParameterizedRedisTest // GH-1686 void doLockShouldGetLock() throws InterruptedException { int threadCount = 3; CountDownLatch beforeWrite = new CountDownLatch(threadCount); CountDownLatch afterWrite = new CountDownLatch(threadCount); + AtomicLong concurrency = new AtomicLong(); - DefaultRedisCacheWriter cw = new DefaultRedisCacheWriter(connectionFactory, Duration.ofMillis(50), - BatchStrategies.keys()){ - @Nullable - protected Boolean doLock(String name, Object contextualKey, @Nullable Object contextualValue, - RedisConnection connection) { - Boolean doLock = super.doLock(name, contextualKey, contextualValue, connection); - assertThat(doLock).isTrue(); + DefaultRedisCacheWriter cw = new DefaultRedisCacheWriter(connectionFactory, Duration.ofMillis(10), + BatchStrategies.keys()) { + + boolean doLock(String name, Object contextualKey, @Nullable Object contextualValue, RedisConnection connection) { + + boolean doLock = super.doLock(name, contextualKey, contextualValue, connection); + + // any concurrent access (aka not waiting until the lock is acquired) will result in a concurrency greater 1 + assertThat(concurrency.incrementAndGet()).isOne(); return doLock; } + + @Nullable + @Override + Long doUnlock(String name, RedisConnection connection) { + try { + return super.doUnlock(name, connection); + } finally { + concurrency.decrementAndGet(); + } + } }; cw.lock(CACHE_NAME); + // introduce concurrency + List> completions = new ArrayList<>(); for (int i = 0; i < threadCount; i++) { + + CompletableFuture completion = new CompletableFuture<>(); + completions.add(completion); + Thread th = new Thread(() -> { beforeWrite.countDown(); - cw.putIfAbsent(CACHE_NAME, binaryCacheKey, binaryCacheValue, Duration.ZERO); + try { + cw.putIfAbsent(CACHE_NAME, binaryCacheKey, binaryCacheValue, Duration.ZERO); + completion.complete(null); + } catch (Throwable e) { + completion.completeExceptionally(e); + } afterWrite.countDown(); }); th.start(); } - beforeWrite.await(); - - Thread.sleep(200); + assertThat(beforeWrite.await(5, TimeUnit.SECONDS)).isTrue(); + Thread.sleep(100); cw.unlock(CACHE_NAME); - afterWrite.await(); + assertThat(afterWrite.await(5, TimeUnit.SECONDS)).isTrue(); + for (CompletableFuture completion : completions) { + assertThat(completion).isCompleted().isCompletedWithValue(null); + } + + doWithConnection(conn -> { + assertThat(conn.exists("default-redis-cache-writer-tests~lock".getBytes())).isFalse(); + }); } private void doWithConnection(Consumer callback) {