diff --git a/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruption.java b/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruption.java index 76f68a7c0983d..0e7ac45c628e0 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruption.java +++ b/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruption.java @@ -21,11 +21,13 @@ import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.test.InternalTestCluster; -import java.util.HashSet; import java.util.Random; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; import java.util.regex.Pattern; /** @@ -35,11 +37,12 @@ public class LongGCDisruption extends SingleNodeDisruption { private static final Pattern[] unsafeClasses = new Pattern[]{ // logging has shared JVM locks - we may suspend a thread and block other nodes from doing their thing - Pattern.compile("Logger") + Pattern.compile("logging\\.log4j") }; protected final String disruptedNode; private Set suspendedThreads; + private long stoppingTimeoutInMillis; public LongGCDisruption(Random random, String disruptedNode) { super(random); @@ -49,13 +52,66 @@ public LongGCDisruption(Random random, String disruptedNode) { @Override public synchronized void startDisrupting() { if (suspendedThreads == null) { - suspendedThreads = new HashSet<>(); - stopNodeThreads(disruptedNode, suspendedThreads); + boolean success = false; + try { + suspendedThreads = ConcurrentHashMap.newKeySet(); + + final String currentThreadNamme = Thread.currentThread().getName(); + assert currentThreadNamme.contains("[" + disruptedNode + "]") == false : + "current thread match pattern. thread name: " + currentThreadNamme + ", node: " + disruptedNode; + // we spawn a background thread to protect against deadlock which can happen + // if there are shared resources between caller thread and and suspended threads + // see unsafeClasses to how to avoid that + final AtomicReference stoppingError = new AtomicReference<>(); + final Thread stoppingThread = new Thread(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + stoppingError.set(e); + } + + @Override + protected void doRun() throws Exception { + while (stopNodeThreads(disruptedNode, suspendedThreads)) ; + } + }); + stoppingThread.setName(currentThreadNamme + "[LongGCDisruption][threadStopper]"); + stoppingThread.start(); + try { + stoppingThread.join(getStoppingTimeoutInMillis()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + if (stoppingError.get() != null) { + throw new RuntimeException("uknown error while stopping threads", stoppingError.get()); + } + if (stoppingThread.isAlive()) { + logger.warn("failed to stop node [{}]'s thread within [{}] millis. Stopping thread stack trace:\n {}" + , disruptedNode, getStoppingTimeoutInMillis(), stackTrace(stoppingThread)); + stoppingThread.interrupt(); // best effort; + throw new RuntimeException("stopping node threads took too long"); + } + success = true; + } finally { + if (success == false) { + // resume threads if failed + resumeThreads(suspendedThreads); + suspendedThreads = null; + } + } } else { throw new IllegalStateException("can't disrupt twice, call stopDisrupting() first"); } } + private String stackTrace(Thread thread) { + String result = ""; + for (StackTraceElement s : thread.getStackTrace()) { + result += "\tat " + s.getClassName() + "." + s.getMethodName() + + "(" + s.getFileName() + ":" + s.getLineNumber() + ")" + "\n"; + } + return result; + } + @Override public synchronized void stopDisrupting() { if (suspendedThreads != null) { @@ -77,6 +133,14 @@ public TimeValue expectedTimeToHeal() { @SuppressWarnings("deprecation") // stops/resumes threads intentionally @SuppressForbidden(reason = "stops/resumes threads intentionally") + /** + * resolves all threads belonging to given node and suspends them if their current stack trace + * is "safe". Threads are added to nodeThreads if suspended. + * + * returns true if some live threads were found. The caller is expected to call this method + * until no more "live" are found. + * + */ protected boolean stopNodeThreads(String node, Set nodeThreads) { Thread[] allThreads = null; while (allThreads == null) { @@ -86,7 +150,7 @@ protected boolean stopNodeThreads(String node, Set nodeThreads) { allThreads = null; } } - boolean stopped = false; + boolean liveThreadsFound = false; final String nodeThreadNamePart = "[" + node + "]"; for (Thread thread : allThreads) { if (thread == null) { @@ -95,7 +159,7 @@ protected boolean stopNodeThreads(String node, Set nodeThreads) { String name = thread.getName(); if (name.contains(nodeThreadNamePart)) { if (thread.isAlive() && nodeThreads.add(thread)) { - stopped = true; + liveThreadsFound = true; logger.trace("stopping thread [{}]", name); thread.suspend(); // double check the thread is not in a shared resource like logging. If so, let it go and come back.. @@ -103,7 +167,7 @@ protected boolean stopNodeThreads(String node, Set nodeThreads) { safe: for (StackTraceElement stackElement : thread.getStackTrace()) { String className = stackElement.getClassName(); - for (Pattern unsafePattern : unsafeClasses) { + for (Pattern unsafePattern : getUnsafeClasses()) { if (unsafePattern.matcher(className).find()) { safe = false; break safe; @@ -118,7 +182,17 @@ protected boolean stopNodeThreads(String node, Set nodeThreads) { } } } - return stopped; + return liveThreadsFound; + } + + // for testing + protected Pattern[] getUnsafeClasses() { + return unsafeClasses; + } + + // for testing + protected long getStoppingTimeoutInMillis() { + return 30 * 1000L; } @SuppressWarnings("deprecation") // stops/resumes threads intentionally diff --git a/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruptionTest.java b/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruptionTest.java new file mode 100644 index 0000000000000..c27185947aa5e --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/test/disruption/LongGCDisruptionTest.java @@ -0,0 +1,147 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.test.disruption; + +import org.elasticsearch.test.ESTestCase; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; + +public class LongGCDisruptionTest extends ESTestCase { + + static class LockedExecutor { + ReentrantLock lock = new ReentrantLock(); + + public void executeLocked(Runnable r) { + lock.lock(); + try { + r.run(); + } finally { + lock.unlock(); + } + } + } + + public void testBlockingTimeout() throws Exception { + final String nodeName = "test_node"; + LongGCDisruption disruption = new LongGCDisruption(random(), nodeName) { + @Override + protected Pattern[] getUnsafeClasses() { + return new Pattern[]{ + Pattern.compile("LockedExecutor") + }; + } + + @Override + protected long getStoppingTimeoutInMillis() { + return 100; + } + }; + final AtomicBoolean stop = new AtomicBoolean(); + final CountDownLatch underLock = new CountDownLatch(1); + final CountDownLatch pauseUnderLock = new CountDownLatch(1); + final LockedExecutor lockedExecutor = new LockedExecutor(); + final AtomicLong ops = new AtomicLong(); + try { + Thread[] threads = new Thread[10]; + for (int i = 0; i < 10; i++) { + // at least one locked and one none lock thread + final boolean lockedExec = (i < 9 && randomBoolean()) || i == 0; + threads[i] = new Thread(() -> { + while (stop.get() == false) { + if (lockedExec) { + lockedExecutor.executeLocked(() -> { + try { + underLock.countDown(); + ops.incrementAndGet(); + pauseUnderLock.await(); + } catch (InterruptedException e) { + + } + }); + } else { + ops.incrementAndGet(); + } + } + }); + threads[i].setName("[" + nodeName + "][" + i + "]"); + threads[i].start(); + } + // make sure some threads are under lock + underLock.await(); + RuntimeException e = expectThrows(RuntimeException.class, disruption::startDisrupting); + assertThat(e.getMessage(), containsString("stopping node threads took too long")); + } finally { + stop.set(true); + pauseUnderLock.countDown(); + } + } + + public void testNotBlockingUnsafeStackTraces() throws Exception { + final String nodeName = "test_node"; + LongGCDisruption disruption = new LongGCDisruption(random(), nodeName) { + @Override + protected Pattern[] getUnsafeClasses() { + return new Pattern[]{ + Pattern.compile("LockedExecutor") + }; + } + }; + final AtomicBoolean stop = new AtomicBoolean(); + final LockedExecutor lockedExecutor = new LockedExecutor(); + final AtomicLong ops = new AtomicLong(); + try { + Thread[] threads = new Thread[10]; + for (int i = 0; i < 10; i++) { + threads[i] = new Thread(() -> { + for (int iter = 0; stop.get() == false; iter++) { + if (iter % 2 == 0) { + lockedExecutor.executeLocked(() -> { + Thread.yield(); // give some chance to catch this stack trace + ops.incrementAndGet(); + }); + } else { + Thread.yield(); // give some chance to catch this stack trace + ops.incrementAndGet(); + } + } + }); + threads[i].setName("[" + nodeName + "][" + i + "]"); + threads[i].start(); + } + // make sure some threads are under lock + disruption.startDisrupting(); + long first = ops.get(); + assertThat(lockedExecutor.lock.isLocked(), equalTo(false)); // no threads should own the lock + Thread.sleep(100); + assertThat(ops.get(), equalTo(first)); + disruption.stopDisrupting(); + assertBusy(() -> assertThat(ops.get(), greaterThan(first))); + } finally { + stop.set(true); + } + } +}