Skip to content

Commit

Permalink
Adding AVLTreeDigest option with a specific Random obj
Browse files Browse the repository at this point in the history
The random element in TDigest can cause some unpredictability in certain use cases.
This commit adds a second constructor to `AVLTreeDigest`, which allows a specific random obj to be used.
If this constructor is used, then the Random object will be persisted, such that the random number generation
is consistent.

Tests have been added to verify that this option does not change the behaviour of the standard `AVLTreeDigest` constructor
  • Loading branch information
cedric-hansen committed Jan 4, 2022
1 parent 15a2de9 commit 2b8891f
Show file tree
Hide file tree
Showing 9 changed files with 232 additions and 22 deletions.
8 changes: 6 additions & 2 deletions benchmark/src/main/java/com/tdunning/Benchmark.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public class Benchmark {
private Random gen = new Random();
private double[] data;

@Param({"merge", "tree"})
@Param({"merge", "tree", "seededTree"})
public String method;

@Param({"20", "50", "100", "200", "500"})
Expand All @@ -59,8 +59,12 @@ public void setup() {
}
if (method.equals("tree")) {
td = new AVLTreeDigest(compression);
} else {
} else if (method.equals("merge")){
td = new MergingDigest(500);
} else if (method.equals("seededTree")) {
td = new AVLTreeDigest(compression, gen);
} else {
throw new IllegalArgumentException("Method " + method + " is not supported");
}

// First values are very cheap to add, we are more interested in the steady state,
Expand Down
13 changes: 12 additions & 1 deletion benchmark/src/main/java/com/tdunning/TDigestBench.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ TDigest create(double compression) {
return new AVLTreeDigest(compression);
}

@Override
TDigest create() {
return create(20);
}
},
SEEDED_AVL_TREE {
@Override
TDigest create(double compression) {
return new AVLTreeDigest(compression, new Random());
}

@Override
TDigest create() {
return create(20);
Expand Down Expand Up @@ -106,7 +117,7 @@ AbstractDistribution create(Random random) {
@Param({"100", "300"})
double compression;

@Param({"MERGE", "AVL_TREE"})
@Param({"MERGE", "AVL_TREE", "SEEDED_AVL_TREE"})
TDigestFactory tdigestFactory;

@Param({"NORMAL", "GAMMA"})
Expand Down
113 changes: 96 additions & 17 deletions core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

package com.tdunning.math.stats;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Collections;
Expand All @@ -27,12 +32,19 @@
import static com.tdunning.math.stats.IntAVLTree.NIL;

public class AVLTreeDigest extends AbstractTDigest {
final Random gen = new Random();
private final Random rng;
private final double compression;
private AVLGroupTree summary;

private long count = 0; // package private for testing

/**
* If {@link rng} should be persisted
*/
private boolean persistRandomObject;

private final static int NUM_BYTES_FOR_RANDOM_OBJECT = 104;

/**
* A histogram structure that will record a sketch of a distribution.
*
Expand All @@ -45,6 +57,26 @@ public class AVLTreeDigest extends AbstractTDigest {
public AVLTreeDigest(double compression) {
this.compression = compression;
summary = new AVLGroupTree(false);
rng = new Random();
persistRandomObject = false;
}

/**
* Creates an AVL tree digest with a random object whose state will be maintained when serialized/deserialized.
* This has uses where the stream of random numbers should be consistent across restarts.
*
* @param compression How should accuracy be traded for size? A value of N here will give quantile errors
* almost always less than 3/N with considerably smaller errors expected for extreme
* quantiles. Conversely, you should expect to track about 5 N centroids for this
* accuracy.
* @param random The random object to use for this AVLTreeDigest
*/
@SuppressWarnings("WeakerAccess")
public AVLTreeDigest(double compression, Random random) {
this.compression = compression;
summary = new AVLGroupTree(false);
rng = random;
persistRandomObject = true;
}

@Override
Expand Down Expand Up @@ -128,7 +160,7 @@ public void add(double x, int w, List<Double> data) {
// what it does is sample uniformly from all clusters that have room
if (summary.count(neighbor) + w <= k) {
n++;
if (gen.nextDouble() < 1 / n) {
if (rng.nextDouble() < 1 / n) {
closest = neighbor;
}
}
Expand Down Expand Up @@ -500,7 +532,7 @@ public double compression() {
@Override
public int byteSize() {
compress();
return 32 + summary.size() * 12;
return 36 + NUM_BYTES_FOR_RANDOM_OBJECT + summary.size() * 12;
}

/**
Expand All @@ -527,7 +559,10 @@ public void asBytes(ByteBuffer buf) {
buf.putDouble(min);
buf.putDouble(max);
buf.putDouble((float) compression());
buf.putInt(persistRandomObject ? 1 : 0);
buf.put(serializeRandomObj(rng));
buf.putInt(summary.size());

for (Centroid centroid : summary) {
buf.putDouble(centroid.mean());
}
Expand All @@ -543,6 +578,8 @@ public void asSmallBytes(ByteBuffer buf) {
buf.putDouble(min);
buf.putDouble(max);
buf.putDouble(compression());
buf.putInt(persistRandomObject ? 1 : 0);
buf.put(serializeRandomObj(rng));
buf.putInt(summary.size());

double x = 0;
Expand All @@ -567,14 +604,21 @@ public void asSmallBytes(ByteBuffer buf) {
@SuppressWarnings("WeakerAccess")
public static AVLTreeDigest fromBytes(ByteBuffer buf) {
int encoding = buf.getInt();
double min = buf.getDouble();
double max = buf.getDouble();
double compression = buf.getDouble();
boolean persistRandomObj = buf.getInt() == 0 ? false : true;
byte [] randomObjBytes = new byte[NUM_BYTES_FOR_RANDOM_OBJECT];
buf.get(randomObjBytes);
Random rand = deserializeRandomObj(randomObjBytes);
AVLTreeDigest r = persistRandomObj ?
new AVLTreeDigest(compression, rand) :
new AVLTreeDigest(compression);
r.setMinMax(min, max);
int n = buf.getInt();
double[] means = new double[n];

if (encoding == VERBOSE_ENCODING) {
double min = buf.getDouble();
double max = buf.getDouble();
double compression = buf.getDouble();
AVLTreeDigest r = new AVLTreeDigest(compression);
r.setMinMax(min, max);
int n = buf.getInt();
double[] means = new double[n];
for (int i = 0; i < n; i++) {
means[i] = buf.getDouble();
}
Expand All @@ -583,13 +627,6 @@ public static AVLTreeDigest fromBytes(ByteBuffer buf) {
}
return r;
} else if (encoding == SMALL_ENCODING) {
double min = buf.getDouble();
double max = buf.getDouble();
double compression = buf.getDouble();
AVLTreeDigest r = new AVLTreeDigest(compression);
r.setMinMax(min, max);
int n = buf.getInt();
double[] means = new double[n];
double x = 0;
for (int i = 0; i < n; i++) {
double delta = buf.getFloat();
Expand All @@ -607,4 +644,46 @@ public static AVLTreeDigest fromBytes(ByteBuffer buf) {
}
}


private byte[] serializeRandomObj(Random r) {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
try {
ObjectOutputStream oos = new ObjectOutputStream(bos);
oos.writeObject(r);
oos.flush();
byte [] data = bos.toByteArray();
bos.close();
oos.close();
return data;
} catch (IOException e) {
e.printStackTrace();
throw new RuntimeException("Cannot serialize random object");
}
}

private static Random deserializeRandomObj(byte [] bytes) {
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
try {
ObjectInputStream ois = new ObjectInputStream(bais);
Random r = (Random)ois.readObject();
return r;
} catch (IOException e) {
e.printStackTrace();
throw new RuntimeException("Cannot deserialize random object");
} catch (ClassNotFoundException e) {
e.printStackTrace();
throw new RuntimeException("Unable to find Random class");
}
}

@Override
public boolean persistRandomValue() {
return persistRandomObject;
}

@Override
public Random getRandomNumberGenerator() {
return rng;
}

}
37 changes: 37 additions & 0 deletions core/src/main/java/com/tdunning/math/stats/TDigest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
* Adaptive histogram based on something like streaming k-means crossed with Q-digest.
Expand Down Expand Up @@ -70,6 +71,22 @@ public static TDigest createAvlTreeDigest(double compression) {
return new AVLTreeDigest(compression);
}

/**
* Creates an AVLTreeDigest with a specific random seed.
*
* This behaves very similarly to the standard AVLTreeDigest, but with the added ability to start with a specific seed.
* This has uses with allowing historic tree values to remain unchanged
*
* @param compression The compression parameter. 100 is a common value for normal uses. 1000 is extremely large.
* The number of centroids retained will be a smallish (usually less than 10) multiple of this number.
* @param random The random object to user for this TDigest
* @return the AvlTreeDigest
*/
@SuppressWarnings("WeakerAccess")
public static TDigest createAvlTreeDigestWithSeed(double compression, Random random) {
return new AVLTreeDigest(compression, random);
}

/**
* Creates a TDigest of whichever type is the currently recommended type. MergingDigest is generally the best
* known implementation right now.
Expand Down Expand Up @@ -237,4 +254,24 @@ void setMinMax(double min, double max) {
this.min = min;
this.max = max;
}

/**
* In certain TDigest implementations, there are cases where a random object might be
* serialized. This flag indicates if the TDigest is persisting the random object
*
* @return true if the TDigest has a random object that will be serialized
*/
public boolean persistRandomValue() {
return false;
}

/**
* In certain TDigest implementations, there are cases where a random object might play a significant
* role. This method returns the Random instance being used to generate these numbers
*
* @return the random instance in the TDigest if one exists, null otherwise.
*/
public Random getRandomNumberGenerator() {
return null;
}
}
42 changes: 42 additions & 0 deletions core/src/test/java/com/tdunning/math/stats/AVLTreeDigestTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
package com.tdunning.math.stats;

import org.junit.BeforeClass;
import org.junit.Test;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.Random;

public class AVLTreeDigestTest extends TDigestTest {
@BeforeClass
Expand Down Expand Up @@ -56,4 +59,43 @@ public void testSingletonInACrowd() {
public void singleSingleRange() {
// disabled for AVLTreeDigest for now
}

@Test
public void testRandomNumberGenerator() {
//The AVLTreeDigest constructor with a specific random Obj
//should generate predictable random numbers.
//Testing that the random obj is serialized/deserialized properly is done elsewhere.
//This test simply confirms that the `randomness` is consistent if random objects with specific seeds are used

int compression = 100;
int randomSeed = 42;
AVLTreeDigest seededTree1 = new AVLTreeDigest(compression, new Random(randomSeed));
AVLTreeDigest seededTree2 = new AVLTreeDigest(compression, new Random(randomSeed));
AVLTreeDigest unseededTree = new AVLTreeDigest(compression);

Random rng = new Random();

for (int i = 0; i < 100_100; i++) {
int value = rng.nextInt(100_100);
seededTree1.add(value);
seededTree2.add(value);
unseededTree.add(value);
}

//Check that the two seeded trees resulted in the same tree
//However, we cannot guarantee the unseeded tree is shares no similarity with the seeded ones
assertEquals(seededTree1.quantile(0.5), seededTree2.quantile(0.5), 0.0);
Iterator<Centroid> cx = seededTree2.centroids().iterator();
for (Centroid c1 : seededTree1.centroids()) {
Centroid c2 = cx.next();
assertEquals(c1.count(), c2.count());
assertEquals(c1.mean(), c2.mean(), 1e-10);
}

Long t1Val = seededTree1.getRandomNumberGenerator().nextLong();
Long t2Val = seededTree2.getRandomNumberGenerator().nextLong();
Long unseededVal = unseededTree.getRandomNumberGenerator().nextLong();
assertEquals(t1Val, t2Val);
assertNotSame(t1Val, unseededVal);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ public void testMerges() throws FileNotFoundException {
for (double compression : new double[]{50, 100, 200, 400}) {
MergingDigest digest1 = new MergingDigest(compression);
AVLTreeDigest digest2 = new AVLTreeDigest(compression);
AVLTreeDigest digest3 = new AVLTreeDigest(compression, new Random());
List<Double> data = new ArrayList<>();
Random gen = new Random();
for (int i = 0; i < n; i++) {
double x = gen.nextDouble();
data.add(x);
digest1.add(x);
digest2.add(x);
digest3.add(x);
}
Collections.sort(data);
List<Double> counts = new ArrayList<>();
Expand All @@ -73,6 +75,7 @@ public void testMerges() throws FileNotFoundException {
}
sizes.printf("%s, %d, %d, %.0f, %d\n", "merge", counts.size(), digest1.centroids().size(), compression, n);
sizes.printf("%s, %d, %d, %.0f, %d\n", "tree", counts.size(), digest2.centroids().size(), compression, n);
sizes.printf("%s, %d, %d, %.0f, %d\n", "tree with seed", counts.size(), digest3.centroids().size(), compression, n);
sizes.printf("%s, %d, %d, %.0f, %d\n", "ideal", counts.size(), counts.size(), compression, n);
soFar = 0;
for (Double count : counts) {
Expand All @@ -92,6 +95,12 @@ public void testMerges() throws FileNotFoundException {
soFar += c.count();
}
assertEquals(n, soFar, 0);
soFar = 0;
for (Centroid c : digest3.centroids()) {
out.printf("%s, %.0f, %d, %.3f, %d\n", "tree", compression, n, (soFar + c.count() / 2) / n, c.count());
soFar += c.count();
}
assertEquals(n, soFar, 0);
}
}
}
Expand Down
Loading

0 comments on commit 2b8891f

Please sign in to comment.