Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Lucene's PackedInt dependency from Cuckoo filter #74736

Merged
merged 8 commits into from
Jul 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.backwards;

import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.hamcrest.Matchers;

import java.io.IOException;
import java.util.List;

/**
* Test that index enough data to trigger the creation of Cuckoo filters.
*/
public class RareTermsIT extends ESRestTestCase {

private static final String index = "idx";

private int indexDocs(int numDocs, int id) throws Exception {
final Request request = new Request("POST", "/_bulk");
final StringBuilder builder = new StringBuilder();
for (int i = 0; i < numDocs; ++i) {
builder.append("{ \"index\" : { \"_index\" : \"" + index + "\", \"_id\": \"" + id++ + "\" } }\n");
builder.append("{\"str_value\" : \"s" + i + "\"}\n");
}
request.setJsonEntity(builder.toString());
assertOK(client().performRequest(request));
return id;
}

public void testSingleValuedString() throws Exception {
final Settings.Builder settings = Settings.builder()
.put(IndexMetadata.INDEX_NUMBER_OF_SHARDS_SETTING.getKey(), 2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think its worth leaving a comment about why you chose 2 here and 12000 - 17000 and 5 below. I know they were fairly carefully chosen to get us to trigger the cuckoo filter. Its probably worth writing that out so future me won't blindly screw this up.

Oh!!!!!!!!!! Can we add to the profile results whether or not we used the cuckoo filter? Like a count of the number of times we used it? That way we can assert in this test that we hit it. Without that assertion this test could bit rot fairly easily.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment about the choice of the number of docs. Not sure about adding such implementations details into the profiler. I guess it will be nice if the agg reported the precision of the result if possible, then you will know if the egg never went into cuckoo filter if the result is fully precise.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about adding such implementations details into the profiler. I guess it will be nice if the agg reported the precision of the result if possible, then you will know if the egg never went into cuckoo filter if the result is fully precise.

That's kind of the point of the profiler - if you've ever wondered "did it do x or y" - you can just make it tell you. But, yeah, maybe it'd be nice to add a precise_results boolean or something then there isn't really a need for the profiler here at all. Either of those can wait for a follow up - but i think something to make sure the test can assert that it hit the non-precise case is important soon-ish.

.put(IndexMetadata.INDEX_NUMBER_OF_REPLICAS_SETTING.getKey(), 0);
createIndex(index, settings.build());
// We want to trigger the usage oif cuckoo filters that happen only when there are
// more than 10k distinct values in one shard.
final int numDocs = randomIntBetween(12000, 17000);
int id = 1;
// Index every value 5 times
for (int i = 0; i < 5; i++) {
id = indexDocs(numDocs, id);
refreshAllIndices();
}
// There are no rare terms that only appear in one document
assertNumRareTerms(1, 0);
// All terms have a cardinality lower than 10
assertNumRareTerms(10, numDocs);
}

private void assertNumRareTerms(int maxDocs, int rareTerms) throws IOException {
final Request request = new Request("POST", index + "/_search");
request.setJsonEntity(
"{\"aggs\" : {\"rareTerms\" : {\"rare_terms\" : {\"field\" : \"str_value.keyword\", \"max_doc_count\" : " + maxDocs + "}}}}"
);
final Response response = client().performRequest(request);
assertOK(response);
final Object o = XContentMapValues.extractValue("aggregations.rareTerms.buckets", responseAsMap(response));
assertThat(o, Matchers.instanceOf(List.class));
assertThat(((List<?>) o).size(), Matchers.equalTo(rareTerms));
}
}
264 changes: 240 additions & 24 deletions server/src/main/java/org/elasticsearch/common/util/CuckooFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.packed.PackedInts;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand Down Expand Up @@ -56,7 +58,7 @@ public class CuckooFilter implements Writeable {
private static final int MAX_EVICTIONS = 500;
static final int EMPTY = 0;

private final PackedInts.Mutable data;
private final PackedArray data;
private final int numBuckets;
private final int bitsPerEntry;
private final int fingerprintMask;
Expand All @@ -82,7 +84,7 @@ public class CuckooFilter implements Writeable {
throw new IllegalArgumentException("Attempted to create [" + numBuckets * entriesPerBucket
+ "] entries which is > Integer.MAX_VALUE");
}
this.data = PackedInts.getMutable(numBuckets * entriesPerBucket, bitsPerEntry, PackedInts.COMPACT);
this.data = new PackedArray(numBuckets * entriesPerBucket, bitsPerEntry);

// puts the bits at the right side of the mask, e.g. `0000000000001111` for bitsPerEntry = 4
this.fingerprintMask = (0x80000000 >> (bitsPerEntry - 1)) >>> (Integer.SIZE - bitsPerEntry);
Expand All @@ -106,7 +108,7 @@ public class CuckooFilter implements Writeable {
+ "] entries which is > Integer.MAX_VALUE");
}
// TODO this is probably super slow, but just used for testing atm
this.data = PackedInts.getMutable(numBuckets * entriesPerBucket, bitsPerEntry, PackedInts.COMPACT);
this.data = new PackedArray(numBuckets * entriesPerBucket, bitsPerEntry);
for (int i = 0; i < other.data.size(); i++) {
data.set(i, other.data.get(i));
}
Expand All @@ -122,17 +124,26 @@ public class CuckooFilter implements Writeable {

this.fingerprintMask = (0x80000000 >> (bitsPerEntry - 1)) >>> (Integer.SIZE - bitsPerEntry);

data = (PackedInts.Mutable) PackedInts.getReader(new DataInput() {
@Override
public byte readByte() throws IOException {
return in.readByte();
}

@Override
public void readBytes(byte[] b, int offset, int len) throws IOException {
in.readBytes(b, offset, len);
if (in.getVersion().before(Version.V_8_0_0)) {
final PackedInts.Reader reader = PackedInts.getReader(new DataInput() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PackedInts.getReader() doesn't exist in lucene 9, so is the plan to remove this further in a follow-up or do we need to rework it here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be removed once we backport the change to 7.x

@Override
public byte readByte() throws IOException {
return in.readByte();
}

@Override
public void readBytes(byte[] b, int offset, int len) throws IOException {
in.readBytes(b, offset, len);
}
});
// This is probably slow but it should only happen if we have a mixed clusters (e.g during upgrade).
data = new PackedArray(numBuckets * entriesPerBucket, bitsPerEntry);
for (int i = 0; i < reader.size(); i++) {
data.set(i, reader.get(i));
}
});
} else {
data = new PackedArray(in);
}
}

@Override
Expand All @@ -142,18 +153,26 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(entriesPerBucket);
out.writeVInt(count);
out.writeVInt(evictedFingerprint);

data.save(new DataOutput() {
@Override
public void writeByte(byte b) throws IOException {
out.writeByte(b);
}

@Override
public void writeBytes(byte[] b, int offset, int length) throws IOException {
out.writeBytes(b, offset, length);
if (out.getVersion().before(Version.V_8_0_0)) {
// This is probably slow but it should only happen if we have a mixed clusters (e.g during upgrade).
PackedInts.Mutable mutable = PackedInts.getMutable(numBuckets * entriesPerBucket, bitsPerEntry, PackedInts.COMPACT);
for (int i = 0; i < data.size(); i++) {
mutable.set(i, data.get(i));
}
});
mutable.save(new DataOutput() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, save() is no longer in the lucene 9 API

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be removed once we backport the change to 7.x

@Override
public void writeByte(byte b) throws IOException {
out.writeByte(b);
}

@Override
public void writeBytes(byte[] b, int offset, int length) throws IOException {
out.writeBytes(b, offset, length);
}
});
} else {
data.save(out);
}
}

/**
Expand Down Expand Up @@ -507,4 +526,201 @@ public boolean equals(Object other) {
&& Objects.equals(this.count, that.count)
&& Objects.equals(this.evictedFingerprint, that.evictedFingerprint);
}

/**
* Forked from Lucene's Packed64 class. The main difference is that this version
* can be read from / write to Elasticsearch streams.
*/
private static class PackedArray {
private static final int BLOCK_SIZE = 64; // 32 = int, 64 = long
private static final int BLOCK_BITS = 6; // The #bits representing BLOCK_SIZE
private static final int MOD_MASK = BLOCK_SIZE - 1; // x % BLOCK_SIZE

/**
* Values are stores contiguously in the blocks array.
*/
private final long[] blocks;
/**
* A right-aligned mask of width BitsPerValue used by {@link #get(int)}.
*/
private final long maskRight;
/**
* Optimization: Saves one lookup in {@link #get(int)}.
*/
private final int bpvMinusBlockSize;

private final int bitsPerValue;
private final int valueCount;

PackedArray(int valueCount, int bitsPerValue) {
this.bitsPerValue = bitsPerValue;
this.valueCount = valueCount;
final int longCount = PackedInts.Format.PACKED.longCount(PackedInts.VERSION_CURRENT, valueCount, bitsPerValue);
this.blocks = new long[longCount];
maskRight = ~0L << (BLOCK_SIZE-bitsPerValue) >>> (BLOCK_SIZE-bitsPerValue);
bpvMinusBlockSize = bitsPerValue - BLOCK_SIZE;
}

PackedArray(StreamInput in)
throws IOException {
this.bitsPerValue = in.readVInt();
this.valueCount = in.readVInt();
this.blocks = in.readLongArray();
maskRight = ~0L << (BLOCK_SIZE - bitsPerValue) >>> (BLOCK_SIZE - bitsPerValue);
bpvMinusBlockSize = bitsPerValue - BLOCK_SIZE;
}

public void save(StreamOutput out) throws IOException {
out.writeVInt(bitsPerValue);
out.writeVInt(valueCount);
out.writeLongArray(blocks);
}

public int size() {
return valueCount;
}

public long get(final int index) {
// The abstract index in a bit stream
final long majorBitPos = (long)index * bitsPerValue;
// The index in the backing long-array
final int elementPos = (int)(majorBitPos >>> BLOCK_BITS);
// The number of value-bits in the second long
final long endBits = (majorBitPos & MOD_MASK) + bpvMinusBlockSize;

if (endBits <= 0) { // Single block
return (blocks[elementPos] >>> -endBits) & maskRight;
}
// Two blocks
return ((blocks[elementPos] << endBits)
| (blocks[elementPos+1] >>> (BLOCK_SIZE - endBits)))
& maskRight;
}

public int get(int index, long[] arr, int off, int len) {
assert len > 0 : "len must be > 0 (got " + len + ")";
assert index >= 0 && index < valueCount;
len = Math.min(len, valueCount - index);
assert off + len <= arr.length;

final int originalIndex = index;
final PackedInts.Decoder decoder = PackedInts.getDecoder(PackedInts.Format.PACKED, PackedInts.VERSION_CURRENT, bitsPerValue);
// go to the next block where the value does not span across two blocks
final int offsetInBlocks = index % decoder.longValueCount();
if (offsetInBlocks != 0) {
for (int i = offsetInBlocks; i < decoder.longValueCount() && len > 0; ++i) {
arr[off++] = get(index++);
--len;
}
if (len == 0) {
return index - originalIndex;
}
}

// bulk get
assert index % decoder.longValueCount() == 0;
int blockIndex = (int) (((long) index * bitsPerValue) >>> BLOCK_BITS);
assert (((long)index * bitsPerValue) & MOD_MASK) == 0;
final int iterations = len / decoder.longValueCount();
decoder.decode(blocks, blockIndex, arr, off, iterations);
final int gotValues = iterations * decoder.longValueCount();
index += gotValues;
len -= gotValues;
assert len >= 0;

if (index > originalIndex) {
// stay at the block boundary
return index - originalIndex;
} else {
// no progress so far => already at a block boundary but no full block to get
assert index == originalIndex;
assert len > 0 : "len must be > 0 (got " + len + ")";
assert index >= 0 && index < size();
assert off + len <= arr.length;

final int gets = Math.min(size() - index, len);
for (int i = index, o = off, end = index + gets; i < end; ++i, ++o) {
arr[o] = get(i);
}
return gets;
}
}

public void set(final int index, final long value) {
// The abstract index in a contiguous bit stream
final long majorBitPos = (long)index * bitsPerValue;
// The index in the backing long-array
final int elementPos = (int)(majorBitPos >>> BLOCK_BITS); // / BLOCK_SIZE
// The number of value-bits in the second long
final long endBits = (majorBitPos & MOD_MASK) + bpvMinusBlockSize;

if (endBits <= 0) { // Single block
blocks[elementPos] = blocks[elementPos] & ~(maskRight << -endBits)
| (value << -endBits);
return;
}
// Two blocks
blocks[elementPos] = blocks[elementPos] & ~(maskRight >>> endBits)
| (value >>> endBits);
blocks[elementPos+1] = blocks[elementPos+1] & (~0L >>> endBits)
| (value << (BLOCK_SIZE - endBits));
}

public int set(int index, long[] arr, int off, int len) {
assert len > 0 : "len must be > 0 (got " + len + ")";
assert index >= 0 && index < valueCount;
len = Math.min(len, valueCount - index);
assert off + len <= arr.length;

final int originalIndex = index;
final PackedInts.Encoder encoder = PackedInts.getEncoder(PackedInts.Format.PACKED, PackedInts.VERSION_CURRENT, bitsPerValue);

// go to the next block where the value does not span across two blocks
final int offsetInBlocks = index % encoder.longValueCount();
if (offsetInBlocks != 0) {
for (int i = offsetInBlocks; i < encoder.longValueCount() && len > 0; ++i) {
set(index++, arr[off++]);
--len;
}
if (len == 0) {
return index - originalIndex;
}
}

// bulk set
assert index % encoder.longValueCount() == 0;
int blockIndex = (int) (((long) index * bitsPerValue) >>> BLOCK_BITS);
assert (((long)index * bitsPerValue) & MOD_MASK) == 0;
final int iterations = len / encoder.longValueCount();
encoder.encode(arr, off, blocks, blockIndex, iterations);
final int setValues = iterations * encoder.longValueCount();
index += setValues;
len -= setValues;
assert len >= 0;

if (index > originalIndex) {
// stay at the block boundary
return index - originalIndex;
} else {
// no progress so far => already at a block boundary but no full block to get
assert index == originalIndex;
len = Math.min(len, size() - index);
assert off + len <= arr.length;

for (int i = index, o = off, end = index + len; i < end; ++i, ++o) {
set(i, arr[o]);
}
return len;
}
}

public long ramBytesUsed() {
return RamUsageEstimator.alignObjectSize(
RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ 3 * Integer.BYTES // bpvMinusBlockSize,valueCount,bitsPerValue
+ Long.BYTES // maskRight
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF) // blocks ref
+ RamUsageEstimator.sizeOf(blocks);
}
}
}