[ML] Persist data counts and datafeed timing stats asynchronously (elastic#93000)

When an anomaly detection job runs, the majority of results
originate from the C++ autodetect process, and so can be
persisted in bulk. However, two types of results, namely data
counts and datafeed timing stats, are generated wholly within
the ML Java code, and there are serious downsides to batching
them up with the output of the C++ process. (If we batched them
and the C++ process stopped generating results, then the
input-side stats would also stall, so it is better that the
input-side stats are written independently.)

The approach used in this PR is to write data counts and
datafeed timing stats asynchronously _except_ at certain key
points, such as job flush and close, and datafeed stop. At these
key points the latest stats _are_ persisted synchronously, as
before. When large amounts of data are being processed, the code
generates updated stats documents faster than they can be
indexed. The approach taken here is to skip persistence of the
newer document if persistence of the previous document is still
in progress. This can leave the stats slightly out of date while
a job is running. However, at key points like flush and close
the data counts will be up to date, and the datafeed timing
stats will be written at least once per datafeed `frequency`, so
should never be more out of date than that.
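
The skip-if-busy pattern at the heart of this change can be condensed into a short sketch. This is a simplified illustration, not the exact Elasticsearch code; AsyncStatsWriter and StatsPersister are hypothetical names, and the real implementations are in DatafeedTimingStatsReporter and JobDataCountsPersister in the diff below.

import java.util.concurrent.CountDownLatch;

/**
 * Condensed sketch of the skip-if-busy persistence pattern. A newer stats
 * document is skipped while the previous persist is still in flight; at key
 * points (flush, close, datafeed stop) the caller passes mustWait=true so
 * the latest stats are durable before proceeding.
 */
class AsyncStatsWriter<T> {

    /** Stand-in for the listener-based persister; onDone runs when indexing completes. */
    interface StatsPersister<S> {
        void persist(S stats, Runnable onDone);
    }

    private final StatsPersister<T> persister;
    private CountDownLatch persistInProgressLatch;

    AsyncStatsWriter(StatsPersister<T> persister) {
        this.persister = persister;
    }

    synchronized void flush(T stats, boolean mustWait) throws InterruptedException {
        if (mustWait && persistInProgressLatch != null) {
            // Drain any persist already in flight before writing the latest stats.
            persistInProgressLatch.await();
            persistInProgressLatch = null;
        }
        if (persistInProgressLatch != null && persistInProgressLatch.getCount() > 0) {
            // Previous persist still in progress - skip this update; a later
            // flush (or a mustWait call) will pick up the newest stats.
            return;
        }
        CountDownLatch thisPersistLatch = new CountDownLatch(1);
        persistInProgressLatch = thisPersistLatch;
        persister.persist(stats, thisPersistLatch::countDown);
        if (mustWait) {
            thisPersistLatch.await();
        }
    }
}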
droberts195 authored Jan 23, 2023
1 parent 603dbfa commit 69914bf
Showing 14 changed files with 275 additions and 91 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/93000.yaml
@@ -0,0 +1,5 @@
pr: 93000
summary: Persist data counts and datafeed timing stats asynchronously
area: Machine Learning
type: enhancement
issues: []
@@ -491,7 +491,7 @@ public void testGetDataCountsModelSizeAndTimingStatsWithSomeDocs() throws Except
storedDataCounts.incrementInputBytes(1L);
storedDataCounts.incrementMissingFieldCount(1L);
JobDataCountsPersister jobDataCountsPersister = new JobDataCountsPersister(client(), resultsPersisterService, auditor);
jobDataCountsPersister.persistDataCounts(job.getId(), storedDataCounts);
jobDataCountsPersister.persistDataCounts(job.getId(), storedDataCounts, true);
jobResultsPersister.commitWrites(job.getId(), JobResultsPersister.CommitType.RESULTS);

setOrThrow.get();
@@ -1046,9 +1046,9 @@ private void indexScheduledEvents(List<ScheduledEvent> events) throws IOExceptio
}
}

private void indexDataCounts(DataCounts counts, String jobId) {
private void indexDataCounts(DataCounts counts, String jobId) throws InterruptedException {
JobDataCountsPersister persister = new JobDataCountsPersister(client(), resultsPersisterService, auditor);
persister.persistDataCounts(jobId, counts);
persister.persistDataCounts(jobId, counts, true);
}

private void indexFilters(List<MlFilter> filters) throws IOException {
@@ -140,7 +140,7 @@ private void previewDatafeed(
job,
xContentRegistry,
// Fake DatafeedTimingStatsReporter that does not have access to results index
new DatafeedTimingStatsReporter(new DatafeedTimingStats(datafeedConfig.getJobId()), (ts, refreshPolicy) -> {}),
new DatafeedTimingStatsReporter(new DatafeedTimingStats(datafeedConfig.getJobId()), (ts, refreshPolicy, listener1) -> {}),
listener.delegateFailure(
(l, dataExtractorFactory) -> isDateNanos(
previewDatafeedConfig,
@@ -355,7 +355,7 @@ private void createDataExtractor(
job,
xContentRegistry,
// Fake DatafeedTimingStatsReporter that does not have access to results index
new DatafeedTimingStatsReporter(new DatafeedTimingStats(job.getId()), (ts, refreshPolicy) -> {}),
new DatafeedTimingStatsReporter(new DatafeedTimingStats(job.getId()), (ts, refreshPolicy, listener1) -> {}),
ActionListener.wrap(
unused -> persistentTasksService.sendStartRequest(
MlTasks.datafeedTaskId(params.getDatafeedId()),
@@ -8,38 +8,47 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedTimingStats;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts;

import java.util.Objects;
import java.util.concurrent.CountDownLatch;

/**
* {@link DatafeedTimingStatsReporter} class handles the logic of persisting {@link DatafeedTimingStats} if they changed significantly
* since the last time they were persisted.
*
* {@link DatafeedTimingStatsReporter} class handles the logic of persisting {@link DatafeedTimingStats} for a single job
* if they changed significantly since the last time they were persisted.
* <p>
* This class is not thread-safe.
*/
public class DatafeedTimingStatsReporter {

private static final Logger LOGGER = LogManager.getLogger(DatafeedTimingStatsReporter.class);
private static final Logger logger = LogManager.getLogger(DatafeedTimingStatsReporter.class);

/** Interface used for persisting current timing stats to the results index. */
@FunctionalInterface
public interface DatafeedTimingStatsPersister {
/** Does nothing by default. This behavior is useful when creating fake {@link DatafeedTimingStatsReporter} objects. */
void persistDatafeedTimingStats(DatafeedTimingStats timingStats, WriteRequest.RefreshPolicy refreshPolicy);
void persistDatafeedTimingStats(
DatafeedTimingStats timingStats,
WriteRequest.RefreshPolicy refreshPolicy,
ActionListener<BulkResponse> listener
);
}

/** Persisted timing stats. May be stale. */
private DatafeedTimingStats persistedTimingStats;
private volatile DatafeedTimingStats persistedTimingStats;
/** Current timing stats. */
private volatile DatafeedTimingStats currentTimingStats;
private final DatafeedTimingStats currentTimingStats;
/** Object used to persist current timing stats. */
private final DatafeedTimingStatsPersister persister;
/** Whether or not timing stats will be persisted by the persister object. */
private volatile boolean allowedPersisting;
/** Records whether a persist is currently in progress. */
private CountDownLatch persistInProgressLatch;

public DatafeedTimingStatsReporter(DatafeedTimingStats timingStats, DatafeedTimingStatsPersister persister) {
Objects.requireNonNull(timingStats);
@@ -81,9 +90,11 @@ public void reportDataCounts(DataCounts dataCounts) {

/** Finishes reporting of timing stats. Makes timing stats persisted immediately. */
public void finishReporting() {
// Don't flush if current timing stats are identical to the persisted ones
if (currentTimingStats.equals(persistedTimingStats) == false) {
flush(WriteRequest.RefreshPolicy.IMMEDIATE);
try {
flush(WriteRequest.RefreshPolicy.IMMEDIATE, true);
} catch (InterruptedException e) {
logger.warn("[{}] interrupted while finishing reporting of datafeed timing stats", currentTimingStats.getJobId());
Thread.currentThread().interrupt();
}
}

@@ -94,19 +105,50 @@ public void disallowPersisting() {

private void flushIfDifferSignificantly() {
if (differSignificantly(currentTimingStats, persistedTimingStats)) {
flush(WriteRequest.RefreshPolicy.NONE);
try {
flush(WriteRequest.RefreshPolicy.NONE, false);
} catch (InterruptedException e) {
assert false : "This should never happen when flush is called with mustWait set to false";
Thread.currentThread().interrupt();
}
}
}

private void flush(WriteRequest.RefreshPolicy refreshPolicy) {
persistedTimingStats = new DatafeedTimingStats(currentTimingStats);
if (allowedPersisting) {
try {
persister.persistDatafeedTimingStats(persistedTimingStats, refreshPolicy);
} catch (Exception ex) {
// Since persisting datafeed timing stats is not critical, we just log a warning here.
LOGGER.warn(() -> "[" + currentTimingStats.getJobId() + "] failed to report datafeed timing stats", ex);
}
private synchronized void flush(WriteRequest.RefreshPolicy refreshPolicy, boolean mustWait) throws InterruptedException {
String jobId = currentTimingStats.getJobId();
if (allowedPersisting && mustWait && persistInProgressLatch != null) {
persistInProgressLatch.await();
persistInProgressLatch = null;
}
// Don't persist if:
// 1. Persistence is disallowed
// 2. There is already a persist in progress
// 3. Current timing stats are identical to the persisted ones
if (allowedPersisting == false) {
logger.trace("[{}] not persisting datafeed timing stats as persistence is disallowed", jobId);
return;
}
if (persistInProgressLatch != null && persistInProgressLatch.getCount() > 0) {
logger.trace("[{}] not persisting datafeed timing stats as the previous persist is still in progress", jobId);
return;
}
if (currentTimingStats.equals(persistedTimingStats)) {
logger.trace("[{}] not persisting datafeed timing stats as they are identical to latest already persisted", jobId);
return;
}
final CountDownLatch thisPersistLatch = new CountDownLatch(1);
final DatafeedTimingStats thisPersistTimingStats = new DatafeedTimingStats(currentTimingStats);
persistInProgressLatch = thisPersistLatch;
persister.persistDatafeedTimingStats(thisPersistTimingStats, refreshPolicy, ActionListener.wrap(r -> {
persistedTimingStats = thisPersistTimingStats;
thisPersistLatch.countDown();
}, e -> {
thisPersistLatch.countDown();
// Since persisting datafeed timing stats is not critical, we just log a warning here.
logger.warn(() -> "[" + jobId + "] failed to report datafeed timing stats", e);
}));
if (mustWait) {
thisPersistLatch.await();
}
}

@@ -688,7 +688,7 @@ public void revertSnapshot(
CheckedConsumer<Boolean, Exception> updateHandler = response -> {
if (response) {
ModelSizeStats revertedModelSizeStats = new ModelSizeStats.Builder(modelSizeStats).setLogTime(new Date()).build();
jobResultsPersister.persistModelSizeStats(
jobResultsPersister.persistModelSizeStatsWithoutRetries(
revertedModelSizeStats,
WriteRequest.RefreshPolicy.IMMEDIATE,
ActionListener.wrap(modelSizeStatsResponseHandler, actionListener::onFailure)
@@ -23,14 +23,17 @@

import java.io.IOException;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;

import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;

/**
* Update a job's dataCounts
* i.e. the number of processed records, fields etc.
* Updates job data counts, i.e. the number of processed records, fields etc.
* One instance of this class handles updates for all jobs.
*/
public class JobDataCountsPersister {

Expand All @@ -40,6 +43,8 @@ public class JobDataCountsPersister {
private final Client client;
private final AnomalyDetectionAuditor auditor;

private final Map<String, CountDownLatch> ongoingPersists = new ConcurrentHashMap<>();

public JobDataCountsPersister(Client client, ResultsPersisterService resultsPersisterService, AnomalyDetectionAuditor auditor) {
this.resultsPersisterService = resultsPersisterService;
this.client = client;
@@ -52,12 +57,27 @@ private static XContentBuilder serialiseCounts(DataCounts counts) throws IOExcep
}

/**
* Update the job's data counts stats and figures.
* NOTE: This call is synchronous and pauses the calling thread.
* @param jobId Job to update
* @param counts The counts
* Update a job's data counts stats and figures.
* If a previous persist for the same job is still in progress, behaviour depends on <code>mustWait</code>.
* @param jobId Job to update.
* @param counts The counts.
* @param mustWait Whether to wait for the counts to be persisted.
* This will involve waiting for the supplied counts
* and also potentially the previous counts to be
* persisted if that previous persist is still ongoing.
* @return <code>true</code> if the counts were sent for persistence, or <code>false</code>
* if the previous persist was still in progress.
*/
public void persistDataCounts(String jobId, DataCounts counts) {
public boolean persistDataCounts(String jobId, DataCounts counts, boolean mustWait) throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
CountDownLatch previousLatch = ongoingPersists.putIfAbsent(jobId, latch);
while (previousLatch != null) {
if (mustWait == false) {
return false;
}
previousLatch.await();
previousLatch = ongoingPersists.putIfAbsent(jobId, latch);
}
counts.setLogTime(Instant.now());
try {
resultsPersisterService.indexWithRetry(
@@ -69,21 +89,29 @@ public void persistDataCounts(String jobId, DataCounts counts) {
DataCounts.documentId(jobId),
true,
() -> true,
retryMessage -> logger.debug("[{}] Job data_counts {}", jobId, retryMessage)
retryMessage -> logger.debug("[{}] Job data_counts {}", jobId, retryMessage),
ActionListener.wrap(r -> ongoingPersists.remove(jobId).countDown(), e -> {
ongoingPersists.remove(jobId).countDown();
logger.error(() -> "[" + jobId + "] Failed persisting data_counts stats", e);
auditor.error(jobId, "Failed persisting data_counts stats: " + e.getMessage());
})
);
} catch (IOException ioe) {
logger.error(() -> "[" + jobId + "] Failed writing data_counts stats", ioe);
} catch (Exception ex) {
logger.error(() -> "[" + jobId + "] Failed persisting data_counts stats", ex);
auditor.error(jobId, "Failed persisting data_counts stats: " + ex.getMessage());
} catch (IOException e) {
// An exception caught here basically means toXContent() failed, which should never happen
logger.error(() -> "[" + jobId + "] Failed writing data_counts stats", e);
return false;
}
if (mustWait) {
latch.await();
}
return true;
}

/**
* The same as {@link JobDataCountsPersister#persistDataCounts(String, DataCounts)} but done Asynchronously.
*
* Very similar to {@link JobDataCountsPersister#persistDataCounts(String, DataCounts, boolean)}.
* <p>
* Two differences are:
* - The listener is notified on persistence failure
* - The caller is notified on persistence failure
* - If the persistence fails, it is not automatically retried
* @param jobId Job to update
* @param counts The counts
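JobDataCountsPersister applies the same skip-if-busy idea, but one instance serves every job, so the in-progress latch is keyed by job ID in a concurrent map. A minimal sketch of that variant, again with hypothetical names rather than the real Elasticsearch classes:

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;

/**
 * Sketch of the per-job variant: the in-progress latch is tracked per job
 * ID, so a slow persist for one job never blocks or skips writes for another.
 */
class PerJobStatsWriter<T> {

    interface StatsPersister<S> {
        void persist(String jobId, S stats, Runnable onDone);
    }

    private final StatsPersister<T> persister;
    private final Map<String, CountDownLatch> ongoingPersists = new ConcurrentHashMap<>();

    PerJobStatsWriter(StatsPersister<T> persister) {
        this.persister = persister;
    }

    /** Returns false if a persist for this job was in progress and mustWait was false. */
    boolean persist(String jobId, T stats, boolean mustWait) throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(1);
        CountDownLatch previous = ongoingPersists.putIfAbsent(jobId, latch);
        while (previous != null) {
            if (mustWait == false) {
                return false; // skip: previous persist for this job still running
            }
            previous.await();
            previous = ongoingPersists.putIfAbsent(jobId, latch);
        }
        persister.persist(jobId, stats, () -> {
            // Completion callback: free the slot, then release any waiter.
            ongoingPersists.remove(jobId);
            latch.countDown();
        });
        if (mustWait) {
            latch.await();
        }
        return true;
    }
}
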
@@ -20,6 +20,7 @@
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.common.util.concurrent.ThreadContext;
@@ -350,7 +351,7 @@ public void persistQuantiles(Quantiles quantiles, WriteRequest.RefreshPolicy ref

Persistable persistable = new Persistable(indexOrAlias, quantiles.getJobId(), quantiles, quantilesDocId);
persistable.setRefreshPolicy(refreshPolicy);
persistable.persist(listener, AnomalyDetectorsIndex.jobStateIndexWriteAlias().equals(indexOrAlias));
persistable.persistWithoutRetries(listener, AnomalyDetectorsIndex.jobStateIndexWriteAlias().equals(indexOrAlias));
}, listener::onFailure);

// Step 1: Search for existing quantiles document in .ml-state*
Expand Down Expand Up @@ -410,7 +411,7 @@ public void persistModelSizeStats(ModelSizeStats modelSizeStats, Supplier<Boolea
/**
* Persist the memory usage data
*/
public void persistModelSizeStats(
public void persistModelSizeStatsWithoutRetries(
ModelSizeStats modelSizeStats,
WriteRequest.RefreshPolicy refreshPolicy,
ActionListener<IndexResponse> listener
Expand All @@ -424,7 +425,7 @@ public void persistModelSizeStats(
modelSizeStats.getId()
);
persistable.setRefreshPolicy(refreshPolicy);
persistable.persist(listener, true);
persistable.persistWithoutRetries(listener, true);
}

/**
@@ -486,8 +487,13 @@ public void commitWrites(String jobId, Set<CommitType> commitTypes) {
*
* @param timingStats datafeed timing stats to persist
* @param refreshPolicy refresh policy to apply
* @param listener listener for response or error
*/
public BulkResponse persistDatafeedTimingStats(DatafeedTimingStats timingStats, WriteRequest.RefreshPolicy refreshPolicy) {
public void persistDatafeedTimingStats(
DatafeedTimingStats timingStats,
WriteRequest.RefreshPolicy refreshPolicy,
ActionListener<BulkResponse> listener
) {
String jobId = timingStats.getJobId();
logger.trace("[{}] Persisting datafeed timing stats", jobId);
Persistable persistable = new Persistable(
Expand All @@ -498,7 +504,7 @@ public BulkResponse persistDatafeedTimingStats(DatafeedTimingStats timingStats,
DatafeedTimingStats.documentId(timingStats.getJobId())
);
persistable.setRefreshPolicy(refreshPolicy);
return persistable.persist(() -> true, true);
persistable.persist(() -> true, true, listener);
}

private static XContentBuilder toXContentBuilder(ToXContent obj, ToXContent.Params params) throws IOException {
@@ -534,9 +540,15 @@ void setRefreshPolicy(WriteRequest.RefreshPolicy refreshPolicy) {
}

BulkResponse persist(Supplier<Boolean> shouldRetry, boolean requireAlias) {
final PlainActionFuture<BulkResponse> getResponseFuture = PlainActionFuture.newFuture();
persist(shouldRetry, requireAlias, getResponseFuture);
return getResponseFuture.actionGet();
}

void persist(Supplier<Boolean> shouldRetry, boolean requireAlias, ActionListener<BulkResponse> listener) {
logCall();
try {
return resultsPersisterService.indexWithRetry(
resultsPersisterService.indexWithRetry(
jobId,
indexName,
object,
Expand All @@ -545,20 +557,23 @@ BulkResponse persist(Supplier<Boolean> shouldRetry, boolean requireAlias) {
id,
requireAlias,
shouldRetry,
retryMessage -> logger.debug("[{}] {} {}", jobId, id, retryMessage)
retryMessage -> logger.debug("[{}] {} {}", jobId, id, retryMessage),
listener
);
} catch (IOException e) {
logger.error(() -> format("[%s] Error writing [%s]", jobId, (id == null) ? "auto-generated ID" : id), e);
IndexResponse.Builder notCreatedResponse = new IndexResponse.Builder();
notCreatedResponse.setResult(Result.NOOP);
return new BulkResponse(
new BulkItemResponse[] { BulkItemResponse.success(0, DocWriteRequest.OpType.INDEX, notCreatedResponse.build()) },
0
listener.onResponse(
new BulkResponse(
new BulkItemResponse[] { BulkItemResponse.success(0, DocWriteRequest.OpType.INDEX, notCreatedResponse.build()) },
0
)
);
}
}

void persist(ActionListener<IndexResponse> listener, boolean requireAlias) {
void persistWithoutRetries(ActionListener<IndexResponse> listener, boolean requireAlias) {
logCall();

try (XContentBuilder content = toXContentBuilder(object, params)) {
@@ -585,5 +600,4 @@ private void logCall() {
}
}
}

}
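
The blocking persist(...) overload above keeps its old synchronous contract by waiting on a future that the new listener-based overload completes; that is the role PlainActionFuture plays in this diff. A generic sketch of that sync-over-async bridge follows, with CompletableFuture standing in for PlainActionFuture and a hypothetical Result type standing in for BulkResponse:

import java.util.concurrent.CompletableFuture;

/**
 * Sketch of the sync-over-async bridge: the blocking overload delegates to
 * the listener-based one and waits on a future the listener completes.
 */
class SyncOverAsyncBridge {

    // Hypothetical stand-in for BulkResponse.
    record Result(boolean succeeded) {}

    // Listener-based primitive; a real implementation would hand the
    // future to an async indexing client as its completion callback.
    void persistAsync(CompletableFuture<Result> listener) {
        listener.complete(new Result(true)); // completes inline in this sketch
    }

    // Blocking overload built on the async one.
    Result persistBlocking() {
        CompletableFuture<Result> future = new CompletableFuture<>();
        persistAsync(future);
        return future.join(); // the analogue of PlainActionFuture.actionGet()
    }
}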