Extract a new class for entity frequency tracking (#389)
* Extract a new class for entity frequency tracking

HC detectors use a 1-pass algorithm for estimating heavy hitters in a stream. Our method maintains a time-decayed count for each entity, which allows us to compare the frequencies of entities from different detectors in the stream. To reuse the code in historical detectors, I created a new class, PriorityTracker, and moved all related logic there. When an entity is hit, the caller calls PriorityTracker.updatePriority to update the entity's priority. Callers can find the most frequently occurring entities in the stream using PriorityTracker.getTopNEntities.
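As an illustration (not part of this commit's diff), here is a minimal sketch of the new API in Java. The constructor argument order mirrors the call site in CacheBuffer below; the exact signature and return type of getTopNEntities are assumptions, since PriorityTracker's own source is not shown on this page.

import java.time.Clock;
import java.util.List;

Clock clock = Clock.systemUTC();
long intervalSecs = 60;                                // detector interval length in seconds
long landmarkSecs = clock.instant().getEpochSecond();  // when tracking starts
int maxTrackedEntities = 1_000_000;                    // mirrors MAX_TRACKING_ENTITIES below

PriorityTracker tracker = new PriorityTracker(clock, intervalSecs, landmarkSecs, maxTrackedEntities);

// Every time an entity is hit in the stream, bump its time-decayed priority.
tracker.updatePriority("server_1");
tracker.updatePriority("server_2");
tracker.updatePriority("server_1");

// Fetch the heaviest hitters seen so far (assumed to return entity ids
// in descending priority order).
List<String> top = tracker.getTopNEntities(10);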

This PR also adds tests for NodeStateManager.

Testing done:
1. manually verified that the basic workflow of HC detectors still works.
2. added new tests for PriorityTracker.
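
For readers skimming the diff, the algebra behind the time-decayed count can be read off the code this PR moves out of CacheBuffer (getUpdatedPriority, computeWeightedCountIncrement, and getMinimumPriority below); g and the closed forms are my notation, following the \log(C) formula in the Javadoc. A hit n whole intervals after the landmark L carries weight g(n) = e^{0.125 n} (0.125 because periods >> DECAY_CONSTANT with DECAY_CONSTANT = 3 divides by 8). Priorities are kept in log space, so each hit applies

p \leftarrow p + \log(1 + e^{0.125 n - p}) = \log(e^{p} + e^{0.125 n}),

which after hits at periods n_1, \dots, n_k gives p = \log \sum_{i=1}^{k} g(n_i). Subtracting \log g(t - L) (the detectorWeight in getMinimumPriority) normalizes away the current time t, so priorities recorded by detectors with different landmarks become comparable.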
kaituo authored Feb 25, 2021
1 parent 6a3e1dd commit 0515a35
Showing 9 changed files with 532 additions and 170 deletions.
1 change: 0 additions & 1 deletion build.gradle
@@ -309,7 +309,6 @@ List<String> jacocoExclusions = [
'com.amazon.opendistroforelasticsearch.ad.transport.SearchAnomalyDetectorInfoTransportAction*',

// TODO: hc caused coverage to drop
'com.amazon.opendistroforelasticsearch.ad.NodeStateManager',
'com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices',
'com.amazon.opendistroforelasticsearch.ad.transport.handler.MultiEntityResultHandler',
'com.amazon.opendistroforelasticsearch.ad.util.ThrowingSupplierWrapper',
Binary file added docs/entity-priority.pdf
Binary file not shown.
src/main/java/com/amazon/opendistroforelasticsearch/ad/caching/CacheBuffer.java
@@ -18,26 +18,21 @@
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Comparator;
import java.util.List;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.stream.Collectors;

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import com.amazon.opendistroforelasticsearch.ad.ExpiringState;
import com.amazon.opendistroforelasticsearch.ad.MaintenanceState;
import com.amazon.opendistroforelasticsearch.ad.MemoryTracker;
import com.amazon.opendistroforelasticsearch.ad.MemoryTracker.Origin;
import com.amazon.opendistroforelasticsearch.ad.annotation.Generated;
import com.amazon.opendistroforelasticsearch.ad.ml.CheckpointDao;
import com.amazon.opendistroforelasticsearch.ad.ml.EntityModel;
import com.amazon.opendistroforelasticsearch.ad.ml.ModelState;
@@ -64,90 +59,22 @@
public class CacheBuffer implements ExpiringState, MaintenanceState {
private static final Logger LOG = LogManager.getLogger(CacheBuffer.class);

static class PriorityNode {
private String key;
private float priority;

PriorityNode(String key, float priority) {
this.priority = priority;
this.key = key;
}

@Generated
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
if (obj instanceof PriorityNode) {
PriorityNode other = (PriorityNode) obj;

EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(key, other.key);
return equalsBuilder.isEquals();
}
return false;
}

@Generated
@Override
public int hashCode() {
return new HashCodeBuilder().append(key).toHashCode();
}

@Generated
@Override
public String toString() {
ToStringBuilder builder = new ToStringBuilder(this);
builder.append("key", key);
builder.append("priority", priority);
return builder.toString();
}
}

static class PriorityNodeComparator implements Comparator<PriorityNode> {

@Override
public int compare(PriorityNode priority, PriorityNode priority2) {
int equality = priority.key.compareTo(priority2.key);
if (equality == 0) {
// this is consistent with PriorityNode's equals method
return 0;
}
// if not equal, first check priority
int cmp = Float.compare(priority.priority, priority2.priority);
if (cmp == 0) {
// if priority is equal, use lexicographical order of key
cmp = equality;
}
return cmp;
}
}
// max entities to track per detector
private final int MAX_TRACKING_ENTITIES = 1000000;

private final int minimumCapacity;
// key -> Priority node
private final ConcurrentHashMap<String, PriorityNode> key2Priority;
private final ConcurrentSkipListSet<PriorityNode> priorityList;
// key -> value
private final ConcurrentHashMap<String, ModelState<EntityModel>> items;
// when detector is created.  Can be reset.  Unit: seconds
private long landmarkSecs;
// length of seconds in one interval.  Used to compute elapsed periods
// since the detector has been enabled.
private long intervalSecs;
// memory consumption per entity
private final long memoryConsumptionPerEntity;
private final MemoryTracker memoryTracker;
private final Clock clock;
private final CheckpointDao checkpointDao;
private final Duration modelTtl;
private final String detectorId;
private Instant lastUsedTime;
private final int DECAY_CONSTANT;
private final long reservedBytes;
private final PriorityTracker priorityTracker;
private final Clock clock;

public CacheBuffer(
int minimumCapacity,
@@ -163,20 +90,20 @@ public CacheBuffer(
throw new IllegalArgumentException("minimum capacity should be larger than 0");
}
this.minimumCapacity = minimumCapacity;
this.key2Priority = new ConcurrentHashMap<>();
this.priorityList = new ConcurrentSkipListSet<>(new PriorityNodeComparator());

this.items = new ConcurrentHashMap<>();
this.landmarkSecs = clock.instant().getEpochSecond();
this.intervalSecs = intervalSecs;

this.memoryConsumptionPerEntity = memoryConsumptionPerEntity;
this.memoryTracker = memoryTracker;
this.clock = clock;

this.checkpointDao = checkpointDao;
this.modelTtl = modelTtl;
this.detectorId = detectorId;
this.lastUsedTime = clock.instant();
this.DECAY_CONSTANT = 3;

this.reservedBytes = memoryConsumptionPerEntity * minimumCapacity;
this.clock = clock;
this.priorityTracker = new PriorityTracker(clock, intervalSecs, clock.instant().getEpochSecond(), MAX_TRACKING_ENTITIES);
}

/**
@@ -186,50 +113,13 @@ public CacheBuffer(
* @param entityModelId model Id
*/
private void update(String entityModelId) {
PriorityNode node = key2Priority.computeIfAbsent(entityModelId, k -> new PriorityNode(entityModelId, 0f));
// reposition this node
this.priorityList.remove(node);
node.priority = getUpdatedPriority(node.priority);
this.priorityList.add(node);
priorityTracker.updatePriority(entityModelId);

Instant now = clock.instant();
items.get(entityModelId).setLastUsedTime(now);
lastUsedTime = now;
}

public float getUpdatedPriority(float oldPriority) {
long increment = computeWeightedCountIncrement();
// if overflowed, we take the short cut from now on
oldPriority += Math.log(1 + Math.exp(increment - oldPriority));
// if overflow happens, using \log(g(t_k-L)) instead.
if (oldPriority == Float.POSITIVE_INFINITY) {
oldPriority = increment;
}
return oldPriority;
}

/**
* Compute periods relative to landmark and the weighted count increment using 0.125n.
* Multiply by 0.125 is implemented using right shift for efficiency.
* @return the weighted count increment used in the priority update step.
*/
private long computeWeightedCountIncrement() {
long periods = (clock.instant().getEpochSecond() - landmarkSecs) / intervalSecs;
return periods >> DECAY_CONSTANT;
}

/**
* Compute the weighted total count by considering landmark
* \log(C)=\log(\sum_{i=1}^{n} g(t_i-L)/g(t-L))=\log(\sum_{i=1}^{n} g(t_i-L))-\log(g(t-L))
* @return the minimum priority entity's ID and priority
*/
public Entry<String, Float> getMinimumPriority() {
PriorityNode smallest = priorityList.first();
long periods = (clock.instant().getEpochSecond() - landmarkSecs) / intervalSecs;
float detectorWeight = periods >> DECAY_CONSTANT;
return new SimpleImmutableEntry<>(smallest.key, smallest.priority - detectorWeight);
}

/**
* Insert the model state associated with a model Id to the cache
* @param entityModelId the model Id
@@ -257,9 +147,7 @@ public void put(String entityModelId, ModelState<EntityModel> value) {
private void put(String entityModelId, ModelState<EntityModel> value, float priority) {
ModelState<EntityModel> contentNode = items.get(entityModelId);
if (contentNode == null) {
PriorityNode node = new PriorityNode(entityModelId, priority);
key2Priority.put(entityModelId, node);
priorityList.add(node);
priorityTracker.addPriority(entityModelId, priority);
items.put(entityModelId, value);
Instant now = clock.instant();
value.setLastUsedTime(now);
@@ -319,9 +207,9 @@ public ModelState<EntityModel> remove() {
// The removed one loses references and soon GC will collect it.
// We have memory tracking correction to fix incorrect memory usage record.
// put: not a problem as it is unlikely we are removing and putting the same thing
PriorityNode smallest = priorityList.first();
if (smallest != null) {
return remove(smallest.key);
Optional<String> key = priorityTracker.getMinimumPriorityEntityId();
if (key.isPresent()) {
return remove(key.get());
}
return null;
}
@@ -334,12 +222,11 @@ public ModelState<EntityModel> remove() {
* is no associated ModelState for the key
*/
public ModelState<EntityModel> remove(String keyToRemove) {
// remove if the key matches; priority does not matter
priorityList.remove(new PriorityNode(keyToRemove, 0));
priorityTracker.removePriority(keyToRemove);

// if shared cache is empty, we are using reserved memory
boolean reserved = sharedCacheEmpty();

key2Priority.remove(keyToRemove);
ModelState<EntityModel> valueRemoved = items.remove(keyToRemove);

if (valueRemoved != null) {
@@ -382,15 +269,17 @@ public long getMemoryConsumptionPerEntity() {

/**
*
* If the cache is not full, check if some other items can replace internal entities.
* If the cache is not full, check if some other items can replace internal entities
* within the same detector.
*
* @param priority another entity's priority
* @return whether one entity can be replaced by another entity with a certain priority
*/
public boolean canReplace(float priority) {
public boolean canReplaceWithinDetector(float priority) {
if (items.isEmpty()) {
return false;
}
Entry<String, Float> minPriorityItem = getMinimumPriority();
Entry<String, Float> minPriorityItem = priorityTracker.getMinimumPriority();
return minPriorityItem != null && priority > minPriorityItem.getValue();
}

Expand All @@ -415,15 +304,6 @@ public void maintenance() {
ModelState<EntityModel> modelState = entry.getValue();
Instant now = clock.instant();

// we can have ConcurrentModificationException when serializing
// and updating rcf model at the same time. To prevent this,
// we need to have a deep copy of models or have a lock. Both
// options are costly.
// As we are gonna retry serializing either when the entity is
// evicted out of cache or during the next maintenance period,
// don't do anything when the exception happens.
checkpointDao.write(modelState, entityModelId);

if (modelState.getLastUsedTime().plus(modelTtl).isBefore(now)) {
// race conditions can happen between the put and one of the following operations:
// remove: not a problem as all of the data structures are concurrent.
@@ -433,7 +313,17 @@ public void maintenance() {
// We have memory tracking correction to fix incorrect memory usage record.
// put: not a problem as we are unlikely to maintain an entry that's not
// already in the cache
// remove method saves checkpoint as well
remove(entityModelId);
} else {
// we can have ConcurrentModificationException when serializing
// and updating rcf model at the same time. To prevent this,
// we need to have a deep copy of models or have a lock. Both
// options are costly.
// As we are gonna retry serializing either when the entity is
// evicted out of cache or during the next maintenance period,
// don't do anything when the exception happens.
checkpointDao.write(modelState, entityModelId);
}
} catch (Exception e) {
LOG.warn("Failed to finish maintenance for model id " + entityModelId, e);
@@ -471,14 +361,6 @@ public long getLastUsedTime(String entityModelId) {
return -1;
}

/**
*
* @return Get the model of highest priority entity
*/
public Optional<String> getHighestPriorityEntityModelId() {
return Optional.of(priorityList).map(list -> list.last()).map(node -> node.key);
}

/**
*
* @param entityModelId entity Id
@@ -501,8 +383,7 @@ public void clear() {
memoryTracker.releaseMemory(getBytesInSharedCache(), false, Origin.MULTI_ENTITY_DETECTOR);
}
items.clear();
key2Priority.clear();
priorityList.clear();
priorityTracker.clearPriority();
}

/**
@@ -561,4 +442,8 @@ public String getDetectorId() {
public List<ModelState<?>> getAllModels() {
return items.values().stream().collect(Collectors.toList());
}

public PriorityTracker getPriorityTracker() {
return priorityTracker;
}
}
