From 1bf147c0d9ce73b519bcca5a973c72db87f97707 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Sat, 19 Dec 2020 21:12:32 -0800 Subject: [PATCH 1/5] add AD task cache --- .../ad/AnomalyDetectorPlugin.java | 3 +- .../ad/MemoryTracker.java | 3 +- .../exception/LimitExceededException.java | 9 + .../ad/settings/AnomalyDetectorSettings.java | 14 + .../ad/task/ADBatchTaskCache.java | 131 ++++++++ .../ad/task/ADTaskCacheManager.java | 306 ++++++++++++++++++ .../ADStatsNodesTransportAction.java | 1 + .../ad/TestHelpers.java | 53 ++- .../ad/task/ADTaskCacheManagerTests.java | 173 ++++++++++ 9 files changed, 682 insertions(+), 11 deletions(-) create mode 100644 src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADBatchTaskCache.java create mode 100644 src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java create mode 100644 src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java index bc4e9fda..62a4d503 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/AnomalyDetectorPlugin.java @@ -570,7 +570,8 @@ public List> getSettings() { AnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT, AnomalyDetectorSettings.MAX_PRIMARY_SHARDS, AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, - AnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND + AnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND, + AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE ); return unmodifiableList(Stream.concat(enabledSetting.stream(), systemSetting.stream()).collect(Collectors.toList())); } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/MemoryTracker.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/MemoryTracker.java index dfd167c9..3c906117 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/MemoryTracker.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/MemoryTracker.java @@ -38,7 +38,8 @@ public class MemoryTracker { public enum Origin { SINGLE_ENTITY_DETECTOR, - MULTI_ENTITY_DETECTOR + MULTI_ENTITY_DETECTOR, + HISTORICAL_SINGLE_ENTITY_DETECTOR, } // memory tracker for total consumption of bytes diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/common/exception/LimitExceededException.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/common/exception/LimitExceededException.java index 1ee0f28e..038133df 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/common/exception/LimitExceededException.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/common/exception/LimitExceededException.java @@ -30,6 +30,15 @@ public LimitExceededException(String anomalyDetectorId, String message) { super(anomalyDetectorId, message, true); } + /** + * Constructor with error message. + * + * @param message explanation for the limit + */ + public LimitExceededException(String message) { + super(null, message, true); + } + /** * Constructor with an anomaly detector ID and an explanation, and a flag for stopping. 
* diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java index 2c0e7fbd..7077b3a6 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/settings/AnomalyDetectorSettings.java @@ -321,4 +321,18 @@ private AnomalyDetectorSettings() {} Setting.Property.NodeScope, Setting.Property.Dynamic ); + + // Maximum number of batch tasks running on one node. + // TODO: performance test and tune the setting. + public static final Setting MAX_BATCH_TASK_PER_NODE = Setting + .intSetting( + "opendistro.anomaly_detection.max_batch_task_per_node", + 2, + 1, + 100, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static int THRESHOLD_MODEL_TRAINING_SIZE = 1000; } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADBatchTaskCache.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADBatchTaskCache.java new file mode 100644 index 00000000..bf6a1629 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADBatchTaskCache.java @@ -0,0 +1,131 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.ad.task; + +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.NUM_MIN_SAMPLES; +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE; +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.NUM_TREES; +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.THRESHOLD_MODEL_TRAINING_SIZE; +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.TIME_DECAY; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + +import com.amazon.opendistroforelasticsearch.ad.ml.HybridThresholdingModel; +import com.amazon.opendistroforelasticsearch.ad.ml.ThresholdingModel; +import com.amazon.opendistroforelasticsearch.ad.model.ADTask; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; +import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings; +import com.amazon.randomcutforest.RandomCutForest; + +/** + * AD batch task cache which will hold RCF, threshold model, shingle and training data. 
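+ * One ADBatchTaskCache instance is created per running batch task and is read and updated through {@link ADTaskCacheManager}.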
+ */ +public class ADBatchTaskCache { + private final String detectorId; + private RandomCutForest rcfModel; + private ThresholdingModel thresholdModel; + private boolean thresholdModelTrained; + private Deque>> shingle; + private List thresholdModelTrainingData; + private AtomicBoolean cancelled = new AtomicBoolean(false); + private AtomicLong cacheMemorySize = new AtomicLong(0); + private String cancelReason; + private String cancelledBy; + + protected ADBatchTaskCache(ADTask adTask) { + this.detectorId = adTask.getDetectorId(); + + AnomalyDetector detector = adTask.getDetector(); + rcfModel = RandomCutForest + .builder() + .dimensions(detector.getShingleSize() * detector.getEnabledFeatureIds().size()) + .numberOfTrees(NUM_TREES) + .lambda(TIME_DECAY) + .sampleSize(NUM_SAMPLES_PER_TREE) + .outputAfter(NUM_MIN_SAMPLES) + .parallelExecutionEnabled(false) + .build(); + + this.thresholdModel = new HybridThresholdingModel( + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.THRESHOLD_MAX_RANK_ERROR, + AnomalyDetectorSettings.THRESHOLD_MAX_SCORE, + AnomalyDetectorSettings.THRESHOLD_NUM_LOGNORMAL_QUANTILES, + AnomalyDetectorSettings.THRESHOLD_DOWNSAMPLES, + AnomalyDetectorSettings.THRESHOLD_MAX_SAMPLES + ); + this.thresholdModelTrainingData = new ArrayList<>(THRESHOLD_MODEL_TRAINING_SIZE); + this.thresholdModelTrained = false; + this.shingle = new ArrayDeque<>(detector.getShingleSize()); + } + + protected String getDetectorId() { + return detectorId; + } + + protected RandomCutForest getRcfModel() { + return rcfModel; + } + + protected Deque>> getShingle() { + return shingle; + } + + protected ThresholdingModel getThresholdModel() { + return thresholdModel; + } + + protected void setThresholdModelTrained(boolean thresholdModelTrained) { + this.thresholdModelTrained = thresholdModelTrained; + } + + protected boolean isThresholdModelTrained() { + return thresholdModelTrained; + } + + protected List getThresholdModelTrainingData() { + return thresholdModelTrainingData; + } + + protected AtomicLong getCacheMemorySize() { + return cacheMemorySize; + } + + protected boolean isCancelled() { + return cancelled.get(); + } + + protected String getCancelReason() { + return cancelReason; + } + + protected String getCancelledBy() { + return cancelledBy; + } + + protected void cancel(String reason, String userName) { + this.cancelled.compareAndSet(false, true); + this.cancelReason = reason; + this.cancelledBy = userName; + } +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java new file mode 100644 index 00000000..bf0802e4 --- /dev/null +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java @@ -0,0 +1,306 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package com.amazon.opendistroforelasticsearch.ad.task; + +import static com.amazon.opendistroforelasticsearch.ad.MemoryTracker.Origin.HISTORICAL_SINGLE_ENTITY_DETECTOR; +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.NUM_TREES; +import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.THRESHOLD_MODEL_TRAINING_SIZE; + +import java.util.Deque; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; + +import com.amazon.opendistroforelasticsearch.ad.MemoryTracker; +import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; +import com.amazon.opendistroforelasticsearch.ad.ml.ThresholdingModel; +import com.amazon.opendistroforelasticsearch.ad.model.ADTask; +import com.amazon.randomcutforest.RandomCutForest; + +public class ADTaskCacheManager { + + private final Map taskCaches; + private volatile Integer maxAdBatchTaskPerNode; + private final MemoryTracker memoryTracker; + + /** + * Constructor to create AD task cache manager. + * + * @param settings ES settings + * @param clusterService ES cluster service + * @param memoryTracker AD memory tracker + */ + public ADTaskCacheManager(Settings settings, ClusterService clusterService, MemoryTracker memoryTracker) { + this.maxAdBatchTaskPerNode = MAX_BATCH_TASK_PER_NODE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BATCH_TASK_PER_NODE, it -> maxAdBatchTaskPerNode = it); + taskCaches = new ConcurrentHashMap<>(); + this.memoryTracker = memoryTracker; + } + + /** + * Put AD task into cache. + * + * @param adTask AD task + */ + public synchronized void put(ADTask adTask) { + String taskId = adTask.getTaskId(); + if (contains(taskId)) { + throw new IllegalArgumentException("AD task is already running"); + } + if (containsTaskOfDetector(adTask.getDetectorId())) { + throw new IllegalArgumentException("There is one task executing for detector"); + } + checkRunningTaskLimit(); + long neededCacheSize = calculateADTaskCacheSize(adTask); + if (!memoryTracker.canAllocate(neededCacheSize)) { + throw new LimitExceededException("No enough memory to run detector"); + } + memoryTracker.consumeMemory(neededCacheSize, false, HISTORICAL_SINGLE_ENTITY_DETECTOR); + ADBatchTaskCache taskCache = new ADBatchTaskCache(adTask); + taskCache.getCacheMemorySize().set(neededCacheSize); + taskCaches.put(taskId, taskCache); + } + + /** + * check if current running batch task on current node exceeds + * max running task limitation. + */ + public void checkRunningTaskLimit() { + if (size() >= maxAdBatchTaskPerNode) { + String error = "Can't run more than " + maxAdBatchTaskPerNode + " historical detectors per data node"; + throw new LimitExceededException(error); + } + } + + /** + * Get task RCF model. + * If task doesn't exist in cache, will throw {@link java.lang.IllegalArgumentException}. + * + * @param taskId AD task id + * @return RCF model + */ + public RandomCutForest getRcfModel(String taskId) { + return getBatchTaskCache(taskId).getRcfModel(); + } + + /** + * Get task threshold model. + * If task doesn't exist in cache, will throw {@link java.lang.IllegalArgumentException}. 
+ * + * @param taskId AD task id + * @return threshold model + */ + public ThresholdingModel getThresholdModel(String taskId) { + return getBatchTaskCache(taskId).getThresholdModel(); + } + + /** + * Get threshold model training data. + * If task doesn't exist in cache, will throw {@link java.lang.IllegalArgumentException}. + * + * @param taskId AD task id + * @return threshold model training data + */ + public List getThresholdModelTrainingData(String taskId) { + return getBatchTaskCache(taskId).getThresholdModelTrainingData(); + } + + /** + * Threshold model trained or not. + * If task doesn't exist in cache, will throw {@link java.lang.IllegalArgumentException}. + * + * @param taskId AD task id + * @return true if threshold model trained; otherwise, return false + */ + public boolean isThresholdModelTrained(String taskId) { + return getBatchTaskCache(taskId).isThresholdModelTrained(); + } + + /** + * Set threshold model trained or not. + * + * @param taskId task id + * @param trained threshold model trained or not + */ + public void setThresholdModelTrained(String taskId, boolean trained) { + ADBatchTaskCache taskCache = getBatchTaskCache(taskId); + taskCache.setThresholdModelTrained(trained); + if (trained) { + int size = taskCache.getThresholdModelTrainingData().size(); + long cacheSize = trainingDataMemorySize(size); + taskCache.getThresholdModelTrainingData().clear(); + taskCache.getCacheMemorySize().getAndAdd(-cacheSize); + memoryTracker.releaseMemory(cacheSize, false, HISTORICAL_SINGLE_ENTITY_DETECTOR); + } + } + + /** + * Get shingle data. + * + * @param taskId AD task id + * @return shingle data + */ + public Deque>> getShingle(String taskId) { + return getBatchTaskCache(taskId).getShingle(); + } + + /** + * Check if task exists in cache. + * + * @param taskId task id + * @return true if task exists in cache; otherwise, return false. + */ + public boolean contains(String taskId) { + return taskCaches.containsKey(taskId); + } + + /** + * Check if there is task in cache for detector. + * + * @param detectorId detector id + * @return true if there is task in cache; otherwise return false + */ + public boolean containsTaskOfDetector(String detectorId) { + return taskCaches.values().stream().filter(v -> Objects.equals(detectorId, v.getDetectorId())).findAny().isPresent(); + } + + /** + * Get batch task cache. If task doesn't exist in cache, will throw + * {@link java.lang.IllegalArgumentException} + * + * @param taskId task id + * @return AD batch task cache + */ + private ADBatchTaskCache getBatchTaskCache(String taskId) { + if (!contains(taskId)) { + throw new IllegalArgumentException("AD task not in cache"); + } + return taskCaches.get(taskId); + } + + /** + * Calculate AD task cache memory usage. + * + * @param adTask AD task + * @return how many bytes will consume + */ + private long calculateADTaskCacheSize(ADTask adTask) { + return memoryTracker.estimateModelSize(adTask.getDetector(), NUM_TREES) + trainingDataMemorySize(THRESHOLD_MODEL_TRAINING_SIZE) + + shingleMemorySize(adTask.getDetector().getShingleSize()); + } + + /** + * Remove task from cache. + * + * @param taskId AD task id + */ + public void remove(String taskId) { + if (contains(taskId)) { + memoryTracker.releaseMemory(getBatchTaskCache(taskId).getCacheMemorySize().get(), false, HISTORICAL_SINGLE_ENTITY_DETECTOR); + taskCaches.remove(taskId); + } + } + + /** + * Cancel AD task. 
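+ * Cancelling only marks the cache entry as cancelled and records the reason and user; it does not remove the entry from cache.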
+ * + * @param taskId AD task id + * @param reason why need to cancel task + * @param userName user name + */ + public void cancel(String taskId, String reason, String userName) { + getBatchTaskCache(taskId).cancel(reason, userName); + } + + /** + * Task is cancelled or not. + * + * @param taskId AD task id + * @return true if task is cancelled; otherwise return false + */ + public boolean isCancelled(String taskId) { + ADBatchTaskCache taskCache = getBatchTaskCache(taskId); + return taskCache.isCancelled(); + } + + /** + * Get why task cancelled. + * + * @param taskId AD task id + * @return task cancellation reason + */ + public String getCancelReason(String taskId) { + return getBatchTaskCache(taskId).getCancelReason(); + } + + /** + * Get task cancelled by which user. + * + * @param taskId AD task id + * @return user name + */ + public String getCancelledBy(String taskId) { + return getBatchTaskCache(taskId).getCancelledBy(); + } + + /** + * Get current task count in cache. + * + * @return task count + */ + public int size() { + return taskCaches.size(); + } + + /** + * Clear all tasks. + */ + public void clear() { + taskCaches.clear(); + } + + /** + * Estimate max memory usage of model training data. + * The training data is double and will cache in {@link java.util.ArrayList}. + * Check {@link ADBatchTaskCache#getThresholdModelTrainingData()} + * + * @param size training data point count + * @return how many bytes will consume + */ + public long trainingDataMemorySize(int size) { + return 24 * size; + } + + /** + * Estimate max memory usage of shingle data. + * Based on the test, one shingle data point consumes about 96 bytes. + * The shingle data is stored in {@link java.util.Deque} + * Check {@link ADBatchTaskCache#getShingle()} + * + * @param shingleSize shingle data point count + * @return how many bytes will consume + */ + public long shingleMemorySize(int shingleSize) { + return 96 * shingleSize; + } + +} diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportAction.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportAction.java index bec6b026..5abd6183 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportAction.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/transport/ADStatsNodesTransportAction.java @@ -51,6 +51,7 @@ public class ADStatsNodesTransportAction extends * @param transportService TransportService * @param actionFilters Action Filters * @param adStats ADStats object + * @param jvmService ES JVM Service */ @Inject public ADStatsNodesTransportAction( diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java index c084aa08..f844a61b 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/TestHelpers.java @@ -20,6 +20,7 @@ import static org.elasticsearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; +import static org.elasticsearch.test.ESTestCase.randomBoolean; import static org.elasticsearch.test.ESTestCase.randomDouble; import static org.elasticsearch.test.ESTestCase.randomInt; import static org.elasticsearch.test.ESTestCase.randomIntBetween; @@ -415,14 +416,7 @@ public static 
Feature randomFeature() { } public static Feature randomFeature(String featureName, String aggregationName) { - AggregationBuilder testAggregation = null; - try { - testAggregation = randomAggregation(aggregationName); - } catch (IOException e) { - logger.error("Fail to generate test aggregation"); - throw new RuntimeException(); - } - return new Feature(randomAlphaOfLength(5), featureName, ESRestTestCase.randomBoolean(), testAggregation); + return randomFeature(featureName, aggregationName, randomBoolean()); } public static Feature randomFeature(boolean enabled) { @@ -739,10 +733,51 @@ public static Map>> create return mappings; } + public static ADTask randomAdTask() throws IOException { + return randomAdTask( + randomAlphaOfLength(5), + ADTaskState.RUNNING, + Instant.now().truncatedTo(ChronoUnit.SECONDS), + randomAlphaOfLength(5), + true + ); + } + + public static ADTask randomAdTask( + String taskId, + ADTaskState state, + Instant executionEndTime, + String stoppedBy, + String detectorId, + AnomalyDetector detector + ) { + executionEndTime = executionEndTime == null ? null : executionEndTime.truncatedTo(ChronoUnit.SECONDS); + ADTask task = ADTask + .builder() + .taskId(taskId) + .taskType(ADTaskType.HISTORICAL.name()) + .detectorId(detectorId) + .detector(detector) + .state(state.name()) + .taskProgress(0.5f) + .initProgress(1.0f) + .currentPiece(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(randomIntBetween(1, 100), ChronoUnit.MINUTES)) + .executionStartTime(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(100, ChronoUnit.MINUTES)) + .executionEndTime(executionEndTime) + .isLatest(true) + .error(randomAlphaOfLength(5)) + .checkpointId(randomAlphaOfLength(5)) + .lastUpdateTime(Instant.now().truncatedTo(ChronoUnit.SECONDS)) + .startedBy(randomAlphaOfLength(5)) + .stoppedBy(stoppedBy) + .build(); + return task; + } + public static ADTask randomAdTask(String taskId, ADTaskState state, Instant executionEndTime, String stoppedBy, boolean withDetector) throws IOException { AnomalyDetector detector = withDetector - ? randomAnomalyDetector(ImmutableMap.of(), Instant.now().truncatedTo(ChronoUnit.SECONDS)) + ? randomAnomalyDetector(ImmutableMap.of(), Instant.now().truncatedTo(ChronoUnit.SECONDS), true) : null; executionEndTime = executionEndTime == null ? null : executionEndTime.truncatedTo(ChronoUnit.SECONDS); ADTask task = ADTask diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java new file mode 100644 index 00000000..99ee8595 --- /dev/null +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java @@ -0,0 +1,173 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package com.amazon.opendistroforelasticsearch.ad.task; + +import static com.amazon.opendistroforelasticsearch.ad.MemoryTracker.Origin.HISTORICAL_SINGLE_ENTITY_DETECTOR; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; + +import com.amazon.opendistroforelasticsearch.ad.MemoryTracker; +import com.amazon.opendistroforelasticsearch.ad.TestHelpers; +import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; +import com.amazon.opendistroforelasticsearch.ad.model.ADTask; +import com.amazon.opendistroforelasticsearch.ad.model.ADTaskState; +import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings; +import com.google.common.collect.ImmutableList; + +public class ADTaskCacheManagerTests extends ESTestCase { + private MemoryTracker memoryTracker; + private ADTaskCacheManager adTaskCacheManager; + private ClusterService clusterService; + private Settings settings; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + settings = Settings.builder().put(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.getKey(), 2).build(); + + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + memoryTracker = mock(MemoryTracker.class); + adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + adTaskCacheManager.clear(); + } + + public void testPutTask() throws IOException { + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + ADTask adTask = TestHelpers.randomAdTask(); + adTaskCacheManager.put(adTask); + assertEquals(1, adTaskCacheManager.size()); + assertTrue(adTaskCacheManager.contains(adTask.getTaskId())); + assertTrue(adTaskCacheManager.containsTaskOfDetector(adTask.getDetectorId())); + assertNotNull(adTaskCacheManager.getRcfModel(adTask.getTaskId())); + assertNotNull(adTaskCacheManager.getShingle(adTask.getTaskId())); + assertNotNull(adTaskCacheManager.getThresholdModel(adTask.getTaskId())); + assertNotNull(adTaskCacheManager.getThresholdModelTrainingData(adTask.getTaskId())); + assertFalse(adTaskCacheManager.isThresholdModelTrained(adTask.getTaskId())); + adTaskCacheManager.remove(adTask.getTaskId()); + assertEquals(0, adTaskCacheManager.size()); + } + + public void testPutDuplicateTask() throws IOException { + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + ADTask adTask1 = TestHelpers.randomAdTask(); + adTaskCacheManager.put(adTask1); + assertEquals(1, adTaskCacheManager.size()); + IllegalArgumentException e1 = 
expectThrows(IllegalArgumentException.class, () -> adTaskCacheManager.put(adTask1)); + assertEquals("AD task is already running", e1.getMessage()); + + ADTask adTask2 = TestHelpers + .randomAdTask( + randomAlphaOfLength(5), + ADTaskState.INIT, + adTask1.getExecutionEndTime(), + adTask1.getStoppedBy(), + adTask1.getDetectorId(), + adTask1.getDetector() + ); + IllegalArgumentException e2 = expectThrows(IllegalArgumentException.class, () -> adTaskCacheManager.put(adTask2)); + assertEquals("There is one task executing for detector", e2.getMessage()); + } + + public void testPutTaskWithMemoryExceedLimit() { + when(memoryTracker.canAllocate(anyLong())).thenReturn(false); + LimitExceededException exception = expectThrows( + LimitExceededException.class, + () -> adTaskCacheManager.put(TestHelpers.randomAdTask()) + ); + assertEquals("No enough memory to run detector", exception.getMessage()); + } + + public void testThresholdModelTrained() throws IOException { + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + ADTask adTask = TestHelpers.randomAdTask(); + adTaskCacheManager.put(adTask); + assertEquals(1, adTaskCacheManager.size()); + adTaskCacheManager.getThresholdModelTrainingData(adTask.getTaskId()).addAll(ImmutableList.of(randomDouble(), randomDouble())); + int size = adTaskCacheManager.getThresholdModelTrainingData(adTask.getTaskId()).size(); + long cacheSize = adTaskCacheManager.trainingDataMemorySize(size); + adTaskCacheManager.setThresholdModelTrained(adTask.getTaskId(), false); + verify(memoryTracker, never()).releaseMemory(anyLong(), anyBoolean(), eq(HISTORICAL_SINGLE_ENTITY_DETECTOR)); + adTaskCacheManager.setThresholdModelTrained(adTask.getTaskId(), true); + verify(memoryTracker, times(1)).releaseMemory(eq(cacheSize), eq(false), eq(HISTORICAL_SINGLE_ENTITY_DETECTOR)); + } + + public void testCancel() throws IOException { + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + ADTask adTask = TestHelpers.randomAdTask(); + adTaskCacheManager.put(adTask); + assertEquals(1, adTaskCacheManager.size()); + assertEquals(false, adTaskCacheManager.isCancelled(adTask.getTaskId())); + String cancelReason = randomAlphaOfLength(10); + String userName = randomAlphaOfLength(5); + adTaskCacheManager.cancel(adTask.getTaskId(), cancelReason, userName); + assertEquals(true, adTaskCacheManager.isCancelled(adTask.getTaskId())); + assertEquals(cancelReason, adTaskCacheManager.getCancelReason(adTask.getTaskId())); + assertEquals(userName, adTaskCacheManager.getCancelledBy(adTask.getTaskId())); + } + + public void testTaskNotExist() { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> adTaskCacheManager.getRcfModel(randomAlphaOfLength(5)) + ); + assertEquals("AD task not in cache", e.getMessage()); + } + + public void testRemoveTaskWhichNotExist() { + adTaskCacheManager.remove(randomAlphaOfLength(5)); + verify(memoryTracker, never()).releaseMemory(anyLong(), anyBoolean(), eq(HISTORICAL_SINGLE_ENTITY_DETECTOR)); + } + + public void testExceedRunningTaskLimit() throws IOException { + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + adTaskCacheManager.put(TestHelpers.randomAdTask()); + adTaskCacheManager.put(TestHelpers.randomAdTask()); + assertEquals(2, adTaskCacheManager.size()); + LimitExceededException e = expectThrows(LimitExceededException.class, () -> adTaskCacheManager.put(TestHelpers.randomAdTask())); + assertEquals("Can't run more than 2 historical detectors per data node", e.getMessage()); + } +} From 
aaced0b3feceec73cc6b0de316ba9c48045a8456 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 21 Dec 2020 14:06:10 -0800 Subject: [PATCH 2/5] add java doc for exception --- .../ad/task/ADTaskCacheManager.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java index bf0802e4..39e134c7 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java @@ -58,6 +58,9 @@ public ADTaskCacheManager(Settings settings, ClusterService clusterService, Memo /** * Put AD task into cache. + * If AD task is already in cache, will throw {@link IllegalArgumentException} + * If there is one AD task in cache for detector, will throw {@link IllegalArgumentException} + * If there is no enough memory for this AD task, will throw {@link LimitExceededException} * * @param adTask AD task */ @@ -83,6 +86,8 @@ public synchronized void put(ADTask adTask) { /** * check if current running batch task on current node exceeds * max running task limitation. + * If executing task count exceeds limitation, will throw + * {@link LimitExceededException} */ public void checkRunningTaskLimit() { if (size() >= maxAdBatchTaskPerNode) { From de29b8a937ef9125bc2c66dad02059b04622822e Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 21 Dec 2020 14:23:26 -0800 Subject: [PATCH 3/5] change to reserved memory --- .../ad/task/ADTaskCacheManager.java | 4 ++-- .../ad/task/ADTaskCacheManagerTests.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java index 39e134c7..10fb0bdf 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java @@ -154,7 +154,7 @@ public void setThresholdModelTrained(String taskId, boolean trained) { long cacheSize = trainingDataMemorySize(size); taskCache.getThresholdModelTrainingData().clear(); taskCache.getCacheMemorySize().getAndAdd(-cacheSize); - memoryTracker.releaseMemory(cacheSize, false, HISTORICAL_SINGLE_ENTITY_DETECTOR); + memoryTracker.releaseMemory(cacheSize, true, HISTORICAL_SINGLE_ENTITY_DETECTOR); } } @@ -220,7 +220,7 @@ private long calculateADTaskCacheSize(ADTask adTask) { */ public void remove(String taskId) { if (contains(taskId)) { - memoryTracker.releaseMemory(getBatchTaskCache(taskId).getCacheMemorySize().get(), false, HISTORICAL_SINGLE_ENTITY_DETECTOR); + memoryTracker.releaseMemory(getBatchTaskCache(taskId).getCacheMemorySize().get(), true, HISTORICAL_SINGLE_ENTITY_DETECTOR); taskCaches.remove(taskId); } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java index 99ee8595..1990dfda 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java @@ -132,7 +132,7 @@ public void testThresholdModelTrained() throws IOException { adTaskCacheManager.setThresholdModelTrained(adTask.getTaskId(), false); verify(memoryTracker, 
never()).releaseMemory(anyLong(), anyBoolean(), eq(HISTORICAL_SINGLE_ENTITY_DETECTOR)); adTaskCacheManager.setThresholdModelTrained(adTask.getTaskId(), true); - verify(memoryTracker, times(1)).releaseMemory(eq(cacheSize), eq(false), eq(HISTORICAL_SINGLE_ENTITY_DETECTOR)); + verify(memoryTracker, times(1)).releaseMemory(eq(cacheSize), eq(true), eq(HISTORICAL_SINGLE_ENTITY_DETECTOR)); } public void testCancel() throws IOException { From d48182d29dee44991ba9c6a0afd3b8b3e0aa7fb9 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 22 Dec 2020 12:48:30 -0800 Subject: [PATCH 4/5] fix shingle memory calculation;store threshold model training data in double array --- .../ad/task/ADBatchTaskCache.java | 19 ++++++--- .../ad/task/ADTaskCacheManager.java | 42 ++++++++++++------- .../ad/task/ADTaskCacheManagerTests.java | 4 +- 3 files changed, 43 insertions(+), 22 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADBatchTaskCache.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADBatchTaskCache.java index bf6a1629..a301370a 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADBatchTaskCache.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADBatchTaskCache.java @@ -22,12 +22,11 @@ import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.TIME_DECAY; import java.util.ArrayDeque; -import java.util.ArrayList; import java.util.Deque; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import com.amazon.opendistroforelasticsearch.ad.ml.HybridThresholdingModel; @@ -46,7 +45,8 @@ public class ADBatchTaskCache { private ThresholdingModel thresholdModel; private boolean thresholdModelTrained; private Deque>> shingle; - private List thresholdModelTrainingData; + private AtomicInteger thresholdModelTrainingDataSize = new AtomicInteger(0); + private double[] thresholdModelTrainingData; private AtomicBoolean cancelled = new AtomicBoolean(false); private AtomicLong cacheMemorySize = new AtomicLong(0); private String cancelReason; @@ -74,7 +74,7 @@ protected ADBatchTaskCache(ADTask adTask) { AnomalyDetectorSettings.THRESHOLD_DOWNSAMPLES, AnomalyDetectorSettings.THRESHOLD_MAX_SAMPLES ); - this.thresholdModelTrainingData = new ArrayList<>(THRESHOLD_MODEL_TRAINING_SIZE); + this.thresholdModelTrainingData = new double[THRESHOLD_MODEL_TRAINING_SIZE]; this.thresholdModelTrained = false; this.shingle = new ArrayDeque<>(detector.getShingleSize()); } @@ -103,10 +103,19 @@ protected boolean isThresholdModelTrained() { return thresholdModelTrained; } - protected List getThresholdModelTrainingData() { + protected double[] getThresholdModelTrainingData() { return thresholdModelTrainingData; } + protected void clearTrainingData() { + this.thresholdModelTrainingData = null; + this.thresholdModelTrainingDataSize.set(0); + } + + public AtomicInteger getThresholdModelTrainingDataSize() { + return thresholdModelTrainingDataSize; + } + protected AtomicLong getCacheMemorySize() { return cacheMemorySize; } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java index 10fb0bdf..73e45de8 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java +++ 
b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java @@ -21,11 +21,11 @@ import static com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings.THRESHOLD_MODEL_TRAINING_SIZE; import java.util.Deque; -import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; @@ -34,6 +34,7 @@ import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException; import com.amazon.opendistroforelasticsearch.ad.ml.ThresholdingModel; import com.amazon.opendistroforelasticsearch.ad.model.ADTask; +import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector; import com.amazon.randomcutforest.RandomCutForest; public class ADTaskCacheManager { @@ -125,10 +126,19 @@ public ThresholdingModel getThresholdModel(String taskId) { * @param taskId AD task id * @return threshold model training data */ - public List getThresholdModelTrainingData(String taskId) { + public double[] getThresholdModelTrainingData(String taskId) { return getBatchTaskCache(taskId).getThresholdModelTrainingData(); } + public int addThresholdModelTrainingData(String taskId, double... data) { + ADBatchTaskCache taskCache = getBatchTaskCache(taskId); + double[] thresholdModelTrainingData = taskCache.getThresholdModelTrainingData(); + AtomicInteger size = taskCache.getThresholdModelTrainingDataSize(); + int dataPointsAdded = Math.min(data.length, THRESHOLD_MODEL_TRAINING_SIZE - size.get()); + System.arraycopy(data, 0, thresholdModelTrainingData, size.get(), dataPointsAdded); + return size.addAndGet(dataPointsAdded); + } + /** * Threshold model trained or not. * If task doesn't exist in cache, will throw {@link java.lang.IllegalArgumentException}. @@ -146,13 +156,13 @@ public boolean isThresholdModelTrained(String taskId) { * @param taskId task id * @param trained threshold model trained or not */ - public void setThresholdModelTrained(String taskId, boolean trained) { + protected void setThresholdModelTrained(String taskId, boolean trained) { ADBatchTaskCache taskCache = getBatchTaskCache(taskId); taskCache.setThresholdModelTrained(trained); if (trained) { - int size = taskCache.getThresholdModelTrainingData().size(); + int size = taskCache.getThresholdModelTrainingDataSize().get(); long cacheSize = trainingDataMemorySize(size); - taskCache.getThresholdModelTrainingData().clear(); + taskCache.clearTrainingData(); taskCache.getCacheMemorySize().getAndAdd(-cacheSize); memoryTracker.releaseMemory(cacheSize, true, HISTORICAL_SINGLE_ENTITY_DETECTOR); } @@ -209,8 +219,9 @@ private ADBatchTaskCache getBatchTaskCache(String taskId) { * @return how many bytes will consume */ private long calculateADTaskCacheSize(ADTask adTask) { - return memoryTracker.estimateModelSize(adTask.getDetector(), NUM_TREES) + trainingDataMemorySize(THRESHOLD_MODEL_TRAINING_SIZE) - + shingleMemorySize(adTask.getDetector().getShingleSize()); + AnomalyDetector detector = adTask.getDetector(); + return memoryTracker.estimateModelSize(detector, NUM_TREES) + trainingDataMemorySize(THRESHOLD_MODEL_TRAINING_SIZE) + + shingleMemorySize(detector.getShingleSize(), detector.getEnabledFeatureIds().size()); } /** @@ -285,27 +296,30 @@ public void clear() { /** * Estimate max memory usage of model training data. 
- * The training data is double and will cache in {@link java.util.ArrayList}. - * Check {@link ADBatchTaskCache#getThresholdModelTrainingData()} + * The training data is double and will cache in double array. + * One double consumes 8 bytes. * * @param size training data point count * @return how many bytes will consume */ public long trainingDataMemorySize(int size) { - return 24 * size; + return 8 * size; } /** * Estimate max memory usage of shingle data. - * Based on the test, one shingle data point consumes about 96 bytes. - * The shingle data is stored in {@link java.util.Deque} + * One feature aggregated data point(double) consumes 8 bytes. + * The shingle data is stored in {@link java.util.Deque}. From testing, + * other parts except feature data consume 80 bytes. + * * Check {@link ADBatchTaskCache#getShingle()} * * @param shingleSize shingle data point count + * @param enabledFeatureSize enabled feature count * @return how many bytes will consume */ - public long shingleMemorySize(int shingleSize) { - return 96 * shingleSize; + public long shingleMemorySize(int shingleSize, int enabledFeatureSize) { + return (80 + 8 * enabledFeatureSize) * shingleSize; } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java index 1990dfda..93f144f4 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java @@ -43,7 +43,6 @@ import com.amazon.opendistroforelasticsearch.ad.model.ADTask; import com.amazon.opendistroforelasticsearch.ad.model.ADTaskState; import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings; -import com.google.common.collect.ImmutableList; public class ADTaskCacheManagerTests extends ESTestCase { private MemoryTracker memoryTracker; @@ -126,8 +125,7 @@ public void testThresholdModelTrained() throws IOException { ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.put(adTask); assertEquals(1, adTaskCacheManager.size()); - adTaskCacheManager.getThresholdModelTrainingData(adTask.getTaskId()).addAll(ImmutableList.of(randomDouble(), randomDouble())); - int size = adTaskCacheManager.getThresholdModelTrainingData(adTask.getTaskId()).size(); + int size = adTaskCacheManager.addThresholdModelTrainingData(adTask.getTaskId(), randomDouble(), randomDouble()); long cacheSize = adTaskCacheManager.trainingDataMemorySize(size); adTaskCacheManager.setThresholdModelTrained(adTask.getTaskId(), false); verify(memoryTracker, never()).releaseMemory(anyLong(), anyBoolean(), eq(HISTORICAL_SINGLE_ENTITY_DETECTOR)); From 9088012ca0e00fad15a5fb48b15080395720290b Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 22 Dec 2020 21:41:55 -0800 Subject: [PATCH 5/5] address comments --- .../ad/task/ADTaskCacheManager.java | 12 ++++++++---- .../ad/task/ADTaskCacheManagerTests.java | 13 +++++++------ 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java index 73e45de8..8b203b68 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.java @@ -42,6 +42,7 @@ public class ADTaskCacheManager { private final Map 
taskCaches; private volatile Integer maxAdBatchTaskPerNode; private final MemoryTracker memoryTracker; + private final int numberSize = 8; /** * Constructor to create AD task cache manager. @@ -75,10 +76,10 @@ public synchronized void put(ADTask adTask) { } checkRunningTaskLimit(); long neededCacheSize = calculateADTaskCacheSize(adTask); - if (!memoryTracker.canAllocate(neededCacheSize)) { + if (!memoryTracker.canAllocateReserved(adTask.getDetectorId(), neededCacheSize)) { throw new LimitExceededException("No enough memory to run detector"); } - memoryTracker.consumeMemory(neededCacheSize, false, HISTORICAL_SINGLE_ENTITY_DETECTOR); + memoryTracker.consumeMemory(neededCacheSize, true, HISTORICAL_SINGLE_ENTITY_DETECTOR); ADBatchTaskCache taskCache = new ADBatchTaskCache(adTask); taskCache.getCacheMemorySize().set(neededCacheSize); taskCaches.put(taskId, taskCache); @@ -201,6 +202,9 @@ public boolean containsTaskOfDetector(String detectorId) { /** * Get batch task cache. If task doesn't exist in cache, will throw * {@link java.lang.IllegalArgumentException} + * We throw exception rather than return {@code Optional.empty} or null + * here, so don't need to check task existence by writing duplicate null + * checking code. All AD task exceptions will be handled in AD task manager. * * @param taskId task id * @return AD batch task cache @@ -303,7 +307,7 @@ public void clear() { * @return how many bytes will consume */ public long trainingDataMemorySize(int size) { - return 8 * size; + return numberSize * size; } /** @@ -319,7 +323,7 @@ public long trainingDataMemorySize(int size) { * @return how many bytes will consume */ public long shingleMemorySize(int shingleSize, int enabledFeatureSize) { - return (80 + 8 * enabledFeatureSize) * shingleSize; + return (80 + numberSize * enabledFeatureSize) * shingleSize; } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java index 93f144f4..1a7db42d 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManagerTests.java @@ -18,6 +18,7 @@ import static com.amazon.opendistroforelasticsearch.ad.MemoryTracker.Origin.HISTORICAL_SINGLE_ENTITY_DETECTOR; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -75,7 +76,7 @@ public void tearDown() throws Exception { } public void testPutTask() throws IOException { - when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.put(adTask); assertEquals(1, adTaskCacheManager.size()); @@ -91,7 +92,7 @@ public void testPutTask() throws IOException { } public void testPutDuplicateTask() throws IOException { - when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); ADTask adTask1 = TestHelpers.randomAdTask(); adTaskCacheManager.put(adTask1); assertEquals(1, adTaskCacheManager.size()); @@ -112,7 +113,7 @@ public void testPutDuplicateTask() throws IOException { } public void 
testPutTaskWithMemoryExceedLimit() { - when(memoryTracker.canAllocate(anyLong())).thenReturn(false); + when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(false); LimitExceededException exception = expectThrows( LimitExceededException.class, () -> adTaskCacheManager.put(TestHelpers.randomAdTask()) @@ -121,7 +122,7 @@ public void testPutTaskWithMemoryExceedLimit() { } public void testThresholdModelTrained() throws IOException { - when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.put(adTask); assertEquals(1, adTaskCacheManager.size()); @@ -134,7 +135,7 @@ public void testThresholdModelTrained() throws IOException { } public void testCancel() throws IOException { - when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); ADTask adTask = TestHelpers.randomAdTask(); adTaskCacheManager.put(adTask); assertEquals(1, adTaskCacheManager.size()); @@ -161,7 +162,7 @@ public void testRemoveTaskWhichNotExist() { } public void testExceedRunningTaskLimit() throws IOException { - when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + when(memoryTracker.canAllocateReserved(anyString(), anyLong())).thenReturn(true); adTaskCacheManager.put(TestHelpers.randomAdTask()); adTaskCacheManager.put(TestHelpers.randomAdTask()); assertEquals(2, adTaskCacheManager.size());
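
A minimal usage sketch, not part of any commit in this series: it assumes a hypothetical caller class in the same package, with settings, clusterService, memoryTracker and adTask supplied by the surrounding code, and it only calls methods introduced in these patches.

package com.amazon.opendistroforelasticsearch.ad.task;

import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;

import com.amazon.opendistroforelasticsearch.ad.MemoryTracker;
import com.amazon.opendistroforelasticsearch.ad.model.ADTask;

// Hypothetical driver used only to illustrate the life cycle of a task cache entry.
public class ADTaskCacheUsageSketch {

    public void runBatchTask(Settings settings, ClusterService clusterService, MemoryTracker memoryTracker, ADTask adTask) {
        ADTaskCacheManager cacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker);

        // put() reserves memory for the RCF model, threshold model training data and shingle.
        // It throws IllegalArgumentException for a duplicate task or a second task of the same
        // detector, and LimitExceededException when memory or the per-node task limit is exceeded.
        cacheManager.put(adTask);
        String taskId = adTask.getTaskId();

        // Anomaly scores are buffered for threshold model training; after patch 4 the buffer is
        // a fixed double[] of THRESHOLD_MODEL_TRAINING_SIZE entries.
        cacheManager.addThresholdModelTrainingData(taskId, 0.1, 0.2, 0.3);

        // A long-running batch task is expected to poll the cancellation flag.
        if (cacheManager.isCancelled(taskId)) {
            // getCancelReason/getCancelledBy explain who cancelled the task and why.
            String reason = cacheManager.getCancelReason(taskId);
        }

        // Removing the entry releases the reserved memory back to the MemoryTracker.
        cacheManager.remove(taskId);
    }
}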