diff --git a/src/main/java/io/vertx/core/impl/ContextBase.java b/src/main/java/io/vertx/core/impl/ContextBase.java index 35779d54ce6..afd2d45b769 100644 --- a/src/main/java/io/vertx/core/impl/ContextBase.java +++ b/src/main/java/io/vertx/core/impl/ContextBase.java @@ -23,7 +23,7 @@ */ class ContextBase extends AtomicReferenceArray { - private final int localsLength; + final int localsLength; ContextBase(int localsLength) { super(localsLength); diff --git a/src/main/java/io/vertx/core/impl/ContextImpl.java b/src/main/java/io/vertx/core/impl/ContextImpl.java index bd17689ae09..0a96fe853e6 100644 --- a/src/main/java/io/vertx/core/impl/ContextImpl.java +++ b/src/main/java/io/vertx/core/impl/ContextImpl.java @@ -44,7 +44,7 @@ static void setResultHandler(ContextInternal ctx, Future fut, Handler void setResultHandler(ContextInternal ctx, Future fut, Handler> LOCAL_MAP = new ContextLocalImpl<>(0); + ContextLocal> LOCAL_MAP = new ContextLocalImpl<>(0, ConcurrentHashMap::new); /** * @return the current context diff --git a/src/main/java/io/vertx/core/impl/ContextLocalImpl.java b/src/main/java/io/vertx/core/impl/ContextLocalImpl.java index 02671c9a58c..2eb30e3273d 100644 --- a/src/main/java/io/vertx/core/impl/ContextLocalImpl.java +++ b/src/main/java/io/vertx/core/impl/ContextLocalImpl.java @@ -12,18 +12,27 @@ import io.vertx.core.spi.context.storage.ContextLocal; +import java.util.function.Function; + /** * @author Julien Viet */ public class ContextLocalImpl implements ContextLocal { + public static ContextLocal create(Class type, Function duplicator) { + synchronized (LocalSeq.class) { + int idx = LocalSeq.locals.size(); + ContextLocal local = new ContextLocalImpl<>(idx, duplicator); + LocalSeq.locals.add(local); + return local; + } + } + final int index; + final Function duplicator; - public ContextLocalImpl(int index) { + public ContextLocalImpl(int index, Function duplicator) { this.index = index; - } - - public ContextLocalImpl() { - this.index = LocalSeq.next(); + this.duplicator = duplicator; } } diff --git a/src/main/java/io/vertx/core/impl/DuplicatedContext.java b/src/main/java/io/vertx/core/impl/DuplicatedContext.java index fa2b307926c..4e6f6088f44 100644 --- a/src/main/java/io/vertx/core/impl/DuplicatedContext.java +++ b/src/main/java/io/vertx/core/impl/DuplicatedContext.java @@ -180,7 +180,9 @@ public boolean isWorkerContext() { @Override public ContextInternal duplicate() { - return new DuplicatedContext(delegate); + DuplicatedContext duplicate = new DuplicatedContext(delegate); + delegate.owner().duplicate(this, duplicate); + return duplicate; } @Override diff --git a/src/main/java/io/vertx/core/impl/LocalSeq.java b/src/main/java/io/vertx/core/impl/LocalSeq.java index b9b4a7f0d8a..099039903e4 100644 --- a/src/main/java/io/vertx/core/impl/LocalSeq.java +++ b/src/main/java/io/vertx/core/impl/LocalSeq.java @@ -10,28 +10,35 @@ */ package io.vertx.core.impl; -import java.util.concurrent.atomic.AtomicInteger; +import io.vertx.core.spi.context.storage.ContextLocal; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; /** * @author Julien Viet */ class LocalSeq { - // 0 : reserved slot for local context map - private static final AtomicInteger seq = new AtomicInteger(1); + static final List> locals = new ArrayList<>(); + + static { + reset(); + } /** * Hook for testing purposes */ - static void reset() { - seq.set((1)); - } - - static int get() { - return seq.get(); + synchronized static void reset() { + // 0 : reserved slot for local context map + locals.clear(); + locals.add(ContextInternal.LOCAL_MAP); } - static int next() { - return seq.getAndIncrement(); + synchronized static ContextLocal[] get() { + return locals.toArray(new ContextLocal[0]); } } diff --git a/src/main/java/io/vertx/core/impl/VertxImpl.java b/src/main/java/io/vertx/core/impl/VertxImpl.java index f701650375a..90e34119bd4 100644 --- a/src/main/java/io/vertx/core/impl/VertxImpl.java +++ b/src/main/java/io/vertx/core/impl/VertxImpl.java @@ -36,6 +36,7 @@ import io.vertx.core.impl.btc.BlockedThreadChecker; import io.vertx.core.net.impl.NetClientBuilder; import io.vertx.core.impl.transports.JDKTransport; +import io.vertx.core.spi.context.storage.ContextLocal; import io.vertx.core.spi.file.FileResolver; import io.vertx.core.file.impl.FileSystemImpl; import io.vertx.core.file.impl.WindowsFileSystem; @@ -77,6 +78,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; import java.util.function.Supplier; /** @@ -134,7 +136,7 @@ private static ThreadFactory virtualThreadFactory() { private final FileResolver fileResolver; private final Map sharedHttpServers = new HashMap<>(); private final Map sharedNetServers = new HashMap<>(); - private final int contextLocalsLength; + private final ContextLocal[] contextLocals; final WorkerPool workerPool; final WorkerPool internalWorkerPool; final WorkerPool virtualThreaWorkerPool; @@ -191,7 +193,7 @@ private static ThreadFactory virtualThreadFactory() { ThreadFactory virtualThreadFactory = virtualThreadFactory(); - contextLocalsLength = LocalSeq.get(); + contextLocals = LocalSeq.get(); closeFuture = new CloseFuture(log); maxEventLoopExecTime = maxEventLoopExecuteTime; maxEventLoopExecTimeUnit = maxEventLoopExecuteTimeUnit; @@ -580,7 +582,7 @@ private ContextImpl createVirtualThreadContext(EventLoop eventLoop, CloseFuture private ContextImpl createContext(ThreadingModel threadingModel, EventLoop eventLoop, CloseFuture closeFuture, Deployment deployment, ClassLoader tccl, EventExecutor eventExecutor, WorkerPool wp) { - return new ContextImpl(this, contextLocalsLength, threadingModel, eventLoop, eventExecutor, internalWorkerPool, wp, deployment, closeFuture, disableTCCL ? null : tccl); + return new ContextImpl(this, contextLocals.length, threadingModel, eventLoop, eventExecutor, internalWorkerPool, wp, deployment, closeFuture, disableTCCL ? null : tccl); } @Override @@ -790,6 +792,17 @@ public synchronized void close(Handler> completionHandler) { }); } + void duplicate(ContextBase src, ContextBase dst) { + for (int i = 0;i < contextLocals.length;i++) { + ContextLocalImpl contextLocal = (ContextLocalImpl) contextLocals[i]; + Object local = src.get(i); + if (local != null) { + local = ((Function)contextLocal.duplicator).apply(local); + } + dst.set(i, local); + } + } + @Override public Future deployVerticle(String name, DeploymentOptions options) { if (options.isHa() && haManager() != null) { diff --git a/src/main/java/io/vertx/core/spi/context/storage/ContextLocal.java b/src/main/java/io/vertx/core/spi/context/storage/ContextLocal.java index 8831b74098f..0f6e6e16559 100644 --- a/src/main/java/io/vertx/core/spi/context/storage/ContextLocal.java +++ b/src/main/java/io/vertx/core/spi/context/storage/ContextLocal.java @@ -14,6 +14,7 @@ import io.vertx.core.impl.ContextInternal; import io.vertx.core.impl.ContextLocalImpl; +import java.util.function.Function; import java.util.function.Supplier; /** @@ -35,7 +36,16 @@ public interface ContextLocal { * @return the context local storage */ static ContextLocal registerLocal(Class type) { - return new ContextLocalImpl<>(); + return ContextLocalImpl.create(type, Function.identity()); + } + + /** + * Registers a context local storage. + * + * @return the context local storage + */ + static ContextLocal registerLocal(Class type, Function duplicator) { + return ContextLocalImpl.create(type, duplicator); } /** diff --git a/src/test/java/io/vertx/core/ContextTest.java b/src/test/java/io/vertx/core/ContextTest.java index a8ccfe69fb3..988586fe306 100644 --- a/src/test/java/io/vertx/core/ContextTest.java +++ b/src/test/java/io/vertx/core/ContextTest.java @@ -1238,4 +1238,19 @@ public void testInterruptTask(ContextInternal context, Consumer actor) assertTrue((System.currentTimeMillis() - now) < 2000); assertTrue(interrupted.get()); } + + @Test + public void testNestedDuplicate() { + ContextInternal ctx = ((ContextInternal) vertx.getOrCreateContext()).duplicate(); + ctx.putLocal("foo", "bar"); + Object expected = new Object(); + ctx.putLocal(contextLocal, AccessMode.CONCURRENT, expected); + ContextInternal duplicate = ctx.duplicate(); + assertEquals("bar", duplicate.getLocal("foo")); + assertEquals(expected, duplicate.getLocal(contextLocal)); + ctx.removeLocal("foo"); + ctx.removeLocal(contextLocal, AccessMode.CONCURRENT); + assertEquals("bar", duplicate.getLocal("foo")); + assertEquals(expected, duplicate.getLocal(contextLocal)); + } }