Skip to content

Commit

Permalink
[ML] fixes inference timeout handling bug that throws unexpected Null…
Browse files Browse the repository at this point in the history
…PointerException (elastic#87533)

The timeout thread was being scheduled to be executed before all the requisite fields were populated. 

This would (usually only in extreme circumstances), would throw a null pointer exception that looked like:

```
2022-06-06T23:40:42,534][ERROR][o.e.b.ElasticsearchUncaughtExceptionHandler] [javaRestTest-2] uncaught exception in thread [elasticsearch[javaRestTest-2][ml_utility][T#4]]
java.lang.NullPointerException: Cannot invoke "org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager$ProcessContext.getTimeoutCount()" because "this.processContext" is null
	at org.elasticsearch.xpack.ml.inference.deployment.AbstractPyTorchAction.onTimeout(AbstractPyTorchAction.java:58) ~[?:?]
	at org.elasticsearch.common.util.concurrent.ThreadContext$ContextPreservingRunnable.run(ThreadContext.java:709) ~[elasticsearch-8.4.0-SNAPSHOT.jar:?]
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136) ~[?:?]
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635) ~[?:?]
	at java.lang.Thread.run(Thread.java:833) [?:?]
```

This PR fixes this bug and an additional (more minor) bug where inference call rejections were not accurately counted.

closes: elastic#87457
  • Loading branch information
benwtrent authored Jun 13, 2022
1 parent af9c0da commit f1b48de
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 35 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/87533.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 87533
summary: Fixes inference timeout handling bug that throws unexpected `NullPointerException`
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,28 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.process.AbstractInitializableRunnable;

import java.util.concurrent.atomic.AtomicBoolean;

import static org.elasticsearch.core.Strings.format;

abstract class AbstractPyTorchAction<T> extends AbstractRunnable {
abstract class AbstractPyTorchAction<T> extends AbstractInitializableRunnable {

private final String modelId;
private final long requestId;
private final TimeValue timeout;
private final Scheduler.Cancellable timeoutHandler;
private Scheduler.Cancellable timeoutHandler;
private final DeploymentManager.ProcessContext processContext;
private final AtomicBoolean notified = new AtomicBoolean();

private final ActionListener<T> listener;
private final ThreadPool threadPool;

protected AbstractPyTorchAction(
String modelId,
Expand All @@ -41,16 +41,23 @@ protected AbstractPyTorchAction(
ThreadPool threadPool,
ActionListener<T> listener
) {
this.modelId = modelId;
this.modelId = ExceptionsHelper.requireNonNull(modelId, "modelId");
this.requestId = requestId;
this.timeout = timeout;
this.timeoutHandler = threadPool.schedule(
this::onTimeout,
ExceptionsHelper.requireNonNull(timeout, "timeout"),
MachineLearning.UTILITY_THREAD_POOL_NAME
);
this.processContext = processContext;
this.listener = listener;
this.timeout = ExceptionsHelper.requireNonNull(timeout, "timeout");
this.processContext = ExceptionsHelper.requireNonNull(processContext, "processContext");
this.listener = ExceptionsHelper.requireNonNull(listener, "listener");
this.threadPool = ExceptionsHelper.requireNonNull(threadPool, "threadPool");
}

/**
* Needs to be called after construction. This init starts the timeout handler and needs to be called before added to the executor for
* scheduled work.
*/
@Override
public final void init() {
if (this.timeoutHandler == null) {
this.timeoutHandler = threadPool.schedule(this::onTimeout, timeout, MachineLearning.UTILITY_THREAD_POOL_NAME);
}
}

void onTimeout() {
Expand All @@ -66,17 +73,31 @@ void onTimeout() {
}

void onSuccess(T result) {
timeoutHandler.cancel();
if (timeoutHandler != null) {
timeoutHandler.cancel();
} else {
assert false : "init() not called, timeout handler unexpectedly null";
}
if (notified.compareAndSet(false, true)) {
listener.onResponse(result);
return;
}
getLogger().debug("[{}] request [{}] received inference response but listener already notified", modelId, requestId);
}

@Override
public void onRejection(Exception e) {
super.onRejection(e);
processContext.getRejectedExecutionCount().incrementAndGet();
}

@Override
public void onFailure(Exception e) {
timeoutHandler.cancel();
if (timeoutHandler != null) {
timeoutHandler.cancel();
} else {
assert false : "init() not called, timeout handler unexpectedly null";
}
if (notified.compareAndSet(false, true)) {
processContext.getResultProcessor().ignoreResponseWithoutNotifying(String.valueOf(requestId));
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

package org.elasticsearch.xpack.ml.inference.pytorch;

import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.xpack.ml.job.process.AbstractInitializableRunnable;
import org.elasticsearch.xpack.ml.job.process.AbstractProcessWorkerExecutorService;

import java.util.concurrent.PriorityBlockingQueue;
Expand All @@ -35,7 +35,7 @@ public enum RequestPriority {
* A Runnable sorted first by RequestPriority then a tie breaker which in
* most cases will be the insertion order
*/
public static record OrderedRunnable(RequestPriority priority, long tieBreaker, Runnable runnable)
public record OrderedRunnable(RequestPriority priority, long tieBreaker, Runnable runnable)
implements
Comparable<OrderedRunnable>,
Runnable {
Expand Down Expand Up @@ -76,7 +76,8 @@ public PriorityProcessWorkerExecutorService(ThreadContext contextHolder, String
* @param priority Request priority
* @param tieBreaker For sorting requests of equal priority
*/
public synchronized void executeWithPriority(AbstractRunnable command, RequestPriority priority, long tieBreaker) {
public synchronized void executeWithPriority(AbstractInitializableRunnable command, RequestPriority priority, long tieBreaker) {
command.init();
if (isShutdown()) {
EsRejectedExecutionException rejected = new EsRejectedExecutionException(processName + " worker service has shutdown", true);
command.onRejection(rejected);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.job.process;

import org.elasticsearch.common.util.concurrent.AbstractRunnable;

/**
* Abstract runnable that has an `init` function that can be called when passed to an executor
*/
public abstract class AbstractInitializableRunnable extends AbstractRunnable {
public abstract void init();
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ public ProcessWorkerExecutorService(ThreadContext contextHolder, String processN

@Override
public synchronized void execute(Runnable command) {
if (command instanceof AbstractInitializableRunnable initializableRunnable) {
initializableRunnable.init();
}
if (isShutdown()) {
EsRejectedExecutionException rejected = new EsRejectedExecutionException(processName + " worker service has shutdown", true);
if (command instanceof AbstractRunnable runnable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void testRunNotCalledAfterNotified() {
tp,
listener
);

action.init();
action.onTimeout();
action.run();
verify(resultProcessor, times(1)).ignoreResponseWithoutNotifying("1");
Expand All @@ -84,6 +84,7 @@ public void testRunNotCalledAfterNotified() {
tp,
listener
);
action.init();

action.onFailure(new IllegalStateException());
action.run();
Expand Down Expand Up @@ -122,6 +123,7 @@ public void testDoRun() throws IOException {
tp,
listener
);
action.init();

action.run();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;
Expand All @@ -31,9 +30,6 @@
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -74,6 +70,7 @@ public void testRejectedExecution() {
Long taskId = 1L;
when(task.getId()).thenReturn(taskId);
when(task.isStopped()).thenReturn(Boolean.FALSE);
when(task.getModelId()).thenReturn("test-rejected");

DeploymentManager deploymentManager = new DeploymentManager(
mock(Client.class),
Expand All @@ -82,9 +79,12 @@ public void testRejectedExecution() {
mock(PyTorchProcessFactory.class)
);

PriorityProcessWorkerExecutorService executorService = mock(PriorityProcessWorkerExecutorService.class);
doThrow(new EsRejectedExecutionException("mock executor rejection")).when(executorService)
.executeWithPriority(any(AbstractRunnable.class), any(), anyLong());
PriorityProcessWorkerExecutorService executorService = new PriorityProcessWorkerExecutorService(
tp.getThreadContext(),
"test reject",
10
);
executorService.shutdown();

AtomicInteger rejectedCount = new AtomicInteger();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public void testInferListenerOnlyCalledOnce() {
tp,
listener
);

action.init();
action.onSuccess(new WarningInferenceResults("foo"));
for (int i = 0; i < 10; i++) {
action.onSuccess(new WarningInferenceResults("foo"));
Expand All @@ -95,7 +95,7 @@ public void testInferListenerOnlyCalledOnce() {
tp,
listener
);

action.init();
action.onTimeout();
for (int i = 0; i < 10; i++) {
action.onSuccess(new WarningInferenceResults("foo"));
Expand All @@ -116,7 +116,7 @@ public void testInferListenerOnlyCalledOnce() {
tp,
listener
);

action.init();
action.onFailure(new Exception("bar"));
for (int i = 0; i < 10; i++) {
action.onSuccess(new WarningInferenceResults("foo"));
Expand Down Expand Up @@ -146,7 +146,7 @@ public void testRunNotCalledAfterNotified() {
tp,
listener
);

action.init();
action.onTimeout();
action.run();
verify(resultProcessor, times(1)).ignoreResponseWithoutNotifying("1");
Expand All @@ -163,7 +163,7 @@ public void testRunNotCalledAfterNotified() {
tp,
listener
);

action.init();
action.onFailure(new IllegalStateException());
action.run();
verify(resultProcessor, never()).registerRequest(anyString(), any());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

package org.elasticsearch.xpack.ml.inference.pytorch;

import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.ml.job.process.AbstractInitializableRunnable;
import org.junit.After;

import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -39,7 +39,12 @@ public void testQueueCapacityReached() {
var r3 = new RunOrderValidator(3, counter);
executor.executeWithPriority(r3, RequestPriority.NORMAL, 101L);

assertTrue(r1.initialized);
assertTrue(r2.initialized);
assertTrue(r3.initialized);
assertTrue(r3.hasBeenRejected);
assertFalse(r1.hasBeenRejected);
assertFalse(r2.hasBeenRejected);
}

public void testQueueCapacityReached_HighestPriority() {
Expand All @@ -56,8 +61,11 @@ public void testQueueCapacityReached_HighestPriority() {
var r5 = new RunOrderValidator(5, counter);
executor.executeWithPriority(r5, RequestPriority.NORMAL, 105L);

assertTrue(r3.initialized);
assertTrue(r3.hasBeenRejected);
assertTrue(highestPriorityAlwaysAccepted.initialized);
assertFalse(highestPriorityAlwaysAccepted.hasBeenRejected);
assertTrue(r5.initialized);
assertTrue(r5.hasBeenRejected);
}

Expand All @@ -78,6 +86,10 @@ public void testOrderedRunnables_NormalPriority() {

executor.start();

assertTrue(r1.initialized);
assertTrue(r2.initialized);
assertTrue(r3.initialized);

assertTrue(r1.hasBeenRun);
assertTrue(r2.hasBeenRun);
assertTrue(r3.hasBeenRun);
Expand Down Expand Up @@ -114,18 +126,24 @@ private PriorityProcessWorkerExecutorService createProcessWorkerExecutorService(
);
}

private static class RunOrderValidator extends AbstractRunnable {
private static class RunOrderValidator extends AbstractInitializableRunnable {

private boolean hasBeenRun = false;
private boolean hasBeenRejected = false;
private final int expectedOrder;
private final AtomicInteger counter;
private boolean initialized = false;

RunOrderValidator(int expectedOrder, AtomicInteger counter) {
this.expectedOrder = expectedOrder;
this.counter = counter;
}

@Override
public void init() {
initialized = true;
}

@Override
public void onRejection(Exception e) {
hasBeenRejected = true;
Expand All @@ -143,7 +161,7 @@ protected void doRun() {
}
}

private static class ShutdownExecutorRunnable extends AbstractRunnable {
private static class ShutdownExecutorRunnable extends AbstractInitializableRunnable {

PriorityProcessWorkerExecutorService executor;

Expand All @@ -162,5 +180,9 @@ protected void doRun() {
executor.shutdown();
}

@Override
public void init() {
// do nothing
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,19 @@ public void testAutodetectWorkerExecutorService_SubmitAfterShutdown() {
threadPool.generic().execute(executor::start);
executor.shutdown();
AtomicBoolean rejected = new AtomicBoolean(false);
executor.execute(new AbstractRunnable() {
AtomicBoolean initialized = new AtomicBoolean(false);
executor.execute(new AbstractInitializableRunnable() {
@Override
public void onRejection(Exception e) {
assertThat(e, isA(EsRejectedExecutionException.class));
rejected.set(true);
}

@Override
public void init() {
initialized.set(true);
}

@Override
public void onFailure(Exception e) {
fail("onFailure should not be called after the worker is shutdown");
Expand All @@ -61,6 +67,7 @@ protected void doRun() throws Exception {
});

assertTrue(rejected.get());
assertTrue(initialized.get());
}

public void testAutodetectWorkerExecutorService_TasksNotExecutedCallHandlerOnShutdown() throws Exception {
Expand Down

0 comments on commit f1b48de

Please sign in to comment.