diff --git a/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruption.java b/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruption.java
index 76f68a7c0983d..944ddb9b05fff 100644
--- a/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruption.java
+++ b/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruption.java
@@ -21,12 +21,16 @@
 
 import org.elasticsearch.common.SuppressForbidden;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.test.InternalTestCluster;
 
-import java.util.HashSet;
+import java.util.Arrays;
 import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.regex.Pattern;
+import java.util.stream.Collectors;
 
 /**
  * Suspends all threads on the specified node in order to simulate a long gc.
@@ -34,8 +38,8 @@
 public class LongGCDisruption extends SingleNodeDisruption {
 
     private static final Pattern[] unsafeClasses = new Pattern[]{
-            // logging has shared JVM locks - we may suspend a thread and block other nodes from doing their thing
-            Pattern.compile("Logger")
+        // logging has shared JVM locks - we may suspend a thread and block other nodes from doing their thing
+        Pattern.compile("logging\\.log4j")
     };
 
     protected final String disruptedNode;
@@ -49,13 +53,68 @@ public LongGCDisruption(Random random, String disruptedNode) {
     @Override
     public synchronized void startDisrupting() {
         if (suspendedThreads == null) {
-            suspendedThreads = new HashSet<>();
-            stopNodeThreads(disruptedNode, suspendedThreads);
+            boolean success = false;
+            try {
+                suspendedThreads = ConcurrentHashMap.newKeySet();
+
+                final String currentThreadName = Thread.currentThread().getName();
+                assert currentThreadName.contains("[" + disruptedNode + "]") == false :
+                    "current thread match pattern. thread name: " + currentThreadName + ", node: " + disruptedNode;
+                // we spawn a background thread to protect against deadlock which can happen
+                // if there are shared resources between caller thread and suspended threads
+                // see unsafeClasses for how to avoid that
+                final AtomicReference<Exception> stoppingError = new AtomicReference<>();
+                final Thread stoppingThread = new Thread(new AbstractRunnable() {
+                    @Override
+                    public void onFailure(Exception e) {
+                        stoppingError.set(e);
+                    }
+
+                    @Override
+                    protected void doRun() throws Exception {
+                        // keep trying to stop threads, until no new threads are discovered.
+                        while (stopNodeThreads(disruptedNode, suspendedThreads)) {
+                            if (Thread.interrupted()) {
+                                return;
+                            }
+                        }
+                    }
+                });
+                stoppingThread.setName(currentThreadName + "[LongGCDisruption][threadStopper]");
+                stoppingThread.start();
+                try {
+                    stoppingThread.join(getStoppingTimeoutInMillis());
+                } catch (InterruptedException e) {
+                    stoppingThread.interrupt(); // best effort to signal stopping
+                    Thread.currentThread().interrupt(); // keep the caller's interrupt status visible
+                    throw new RuntimeException(e);
+                }
+                if (stoppingError.get() != null) {
+                    throw new RuntimeException("unknown error while stopping threads", stoppingError.get());
+                }
+                if (stoppingThread.isAlive()) {
+                    logger.warn("failed to stop node [{}]'s threads within [{}] millis. Stopping thread stack trace:\n {}"
+                        , disruptedNode, getStoppingTimeoutInMillis(), stackTrace(stoppingThread));
+                    stoppingThread.interrupt(); // best effort
+                    throw new RuntimeException("stopping node threads took too long");
+                }
+                success = true;
+            } finally {
+                if (success == false) {
+                    // resume threads if failed
+                    resumeThreads(suspendedThreads);
+                    suspendedThreads = null;
+                }
+            }
         } else {
             throw new IllegalStateException("can't disrupt twice, call stopDisrupting() first");
         }
     }
 
+    private String stackTrace(Thread thread) {
+        return Arrays.stream(thread.getStackTrace()).map(Object::toString).collect(Collectors.joining("\n"));
+    }
+
     @Override
     public synchronized void stopDisrupting() {
         if (suspendedThreads != null) {
@@ -75,6 +134,13 @@ public TimeValue expectedTimeToHeal() {
         return TimeValue.timeValueMillis(0);
     }
 
+    /**
+     * resolves all threads belonging to given node and suspends them if their current stack trace
+     * is "safe". Threads are added to nodeThreads if suspended.
+     *
+     * returns true if some live threads were found. The caller is expected to call this method
+     * until no more live threads are found.
+     */
     @SuppressWarnings("deprecation") // stops/resumes threads intentionally
     @SuppressForbidden(reason = "stops/resumes threads intentionally")
     protected boolean stopNodeThreads(String node, Set<Thread> nodeThreads) {
@@ -86,7 +152,7 @@ protected boolean stopNodeThreads(String node, Set<Thread> nodeThreads) {
                 allThreads = null;
             }
         }
-        boolean stopped = false;
+        boolean liveThreadsFound = false;
         final String nodeThreadNamePart = "[" + node + "]";
         for (Thread thread : allThreads) {
             if (thread == null) {
@@ -95,7 +161,7 @@ protected boolean stopNodeThreads(String node, Set<Thread> nodeThreads) {
             String name = thread.getName();
             if (name.contains(nodeThreadNamePart)) {
                 if (thread.isAlive() && nodeThreads.add(thread)) {
-                    stopped = true;
+                    liveThreadsFound = true;
                     logger.trace("stopping thread [{}]", name);
                     thread.suspend();
                     // double check the thread is not in a shared resource like logging. If so, let it go and come back..
@@ -103,7 +169,7 @@ protected boolean stopNodeThreads(String node, Set<Thread> nodeThreads) {
                     safe:
                     for (StackTraceElement stackElement : thread.getStackTrace()) {
                         String className = stackElement.getClassName();
-                        for (Pattern unsafePattern : unsafeClasses) {
+                        for (Pattern unsafePattern : getUnsafeClasses()) {
                             if (unsafePattern.matcher(className).find()) {
                                 safe = false;
                                 break safe;
@@ -118,7 +184,17 @@ protected boolean stopNodeThreads(String node, Set<Thread> nodeThreads) {
                 }
             }
         }
-        return stopped;
+        return liveThreadsFound;
+    }
+
+    // overridable so tests can declare their own "unsafe" classes
+    protected Pattern[] getUnsafeClasses() {
+        return unsafeClasses;
+    }
+
+    // overridable so tests can shorten the stopping timeout
+    protected long getStoppingTimeoutInMillis() {
+        return TimeValue.timeValueSeconds(30).getMillis();
     }
 
     @SuppressWarnings("deprecation") // stops/resumes threads intentionally
diff --git a/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruptionTest.java b/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruptionTest.java
new file mode 100644
index 0000000000000..38190444758b8
--- /dev/null
+++ b/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruptionTest.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.test.disruption;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.regex.Pattern;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+
+public class LongGCDisruptionTest extends ESTestCase {
+
+    static class LockedExecutor {
+        ReentrantLock lock = new ReentrantLock();
+
+        public void executeLocked(Runnable r) {
+            lock.lock();
+            try {
+                r.run();
+            } finally {
+                lock.unlock();
+            }
+        }
+    }
+
+    /**
+     * Checks that starting the disruption gives up with an exception, instead of hanging,
+     * when some node threads stay inside an "unsafe" class for longer than the stopping timeout
+     */
+    public void testBlockingTimeout() throws Exception {
+        final String nodeName = "test_node";
+        LongGCDisruption disruption = new LongGCDisruption(random(), nodeName) {
+            @Override
+            protected Pattern[] getUnsafeClasses() {
+                return new Pattern[]{
+                    Pattern.compile(LockedExecutor.class.getSimpleName())
+                };
+            }
+
+            @Override
+            protected long getStoppingTimeoutInMillis() {
+                return 100;
+            }
+        };
+        final AtomicBoolean stop = new AtomicBoolean();
+        final CountDownLatch underLock = new CountDownLatch(1);
+        final CountDownLatch pauseUnderLock = new CountDownLatch(1);
+        final LockedExecutor lockedExecutor = new LockedExecutor();
+        final AtomicLong ops = new AtomicLong();
+        try {
+            Thread[] threads = new Thread[10];
+            for (int i = 0; i < 10; i++) {
+                // at least one locked and one non-locked thread
+                final boolean lockedExec = (i < 9 && randomBoolean()) || i == 0;
+                threads[i] = new Thread(() -> {
+                    while (stop.get() == false) {
+                        if (lockedExec) {
+                            lockedExecutor.executeLocked(() -> {
+                                try {
+                                    underLock.countDown();
+                                    ops.incrementAndGet();
+                                    pauseUnderLock.await();
+                                } catch (InterruptedException ignored) {
+                                    // ignored on purpose: fall through and re-check the stop flag
+                                }
+                            });
+                        } else {
+                            ops.incrementAndGet();
+                        }
+                    }
+                });
+                threads[i].setName("[" + nodeName + "][" + i + "]");
+                threads[i].start();
+            }
+            // make sure some threads are under lock
+            underLock.await();
+            RuntimeException e = expectThrows(RuntimeException.class, disruption::startDisrupting);
+            assertThat(e.getMessage(), containsString("stopping node threads took too long"));
+        } finally {
+            stop.set(true);
+            pauseUnderLock.countDown();
+        }
+    }
+
+    /**
+     * Checks that a GC disruption never blocks threads while they are doing something "unsafe"
+     * but does keep retrying until all threads can be safely paused
+     */
+    public void testNotBlockingUnsafeStackTraces() throws Exception {
+        final String nodeName = "test_node";
+        LongGCDisruption disruption = new LongGCDisruption(random(), nodeName) {
+            @Override
+            protected Pattern[] getUnsafeClasses() {
+                return new Pattern[]{
+                    Pattern.compile(LockedExecutor.class.getSimpleName())
+                };
+            }
+        };
+        final AtomicBoolean stop = new AtomicBoolean();
+        final LockedExecutor lockedExecutor = new LockedExecutor();
+        final AtomicLong ops = new AtomicLong();
+        try {
+            Thread[] threads = new Thread[10];
+            for (int i = 0; i < 10; i++) {
+                threads[i] = new Thread(() -> {
+                    for (int iter = 0; stop.get() == false; iter++) {
+                        if (iter % 2 == 0) {
+                            lockedExecutor.executeLocked(() -> {
+                                Thread.yield(); // give some chance to catch this stack trace
+                                ops.incrementAndGet();
+                            });
+                        } else {
+                            Thread.yield(); // give some chance to catch this stack trace
+                            ops.incrementAndGet();
+                        }
+                    }
+                });
+                threads[i].setName("[" + nodeName + "][" + i + "]");
+                threads[i].start();
+            }
+            // suspend the node threads; none of them may be caught while inside LockedExecutor
+            disruption.startDisrupting();
+            long first = ops.get();
+            assertThat(lockedExecutor.lock.isLocked(), equalTo(false)); // no threads should own the lock
+            Thread.sleep(100);
+            assertThat(ops.get(), equalTo(first));
+            disruption.stopDisrupting();
+            assertBusy(() -> assertThat(ops.get(), greaterThan(first)));
+        } finally {
+            stop.set(true);
+        }
+    }
+}