change that breaks BWC in CuckooFilter #73585

Closed
wants to merge 16 commits
@@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.backwards;

import org.elasticsearch.client.Request;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.rest.ESRestTestCase;

public class RareTermsIT extends ESRestTestCase {
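// Added note: rare_terms is backed by a cuckoo filter that is serialized between nodes while
// shard results are reduced, so this REST test, which runs against the mixed-version cluster of
// the backwards-compatibility suite, exercises the cross-version CuckooFilter serialization
// paths touched by this change.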

private int indexDocs(int numDocs, int id) throws Exception {
final Request request = new Request("POST", "/_bulk");
final StringBuilder builder = new StringBuilder();
for (int i = 0; i < numDocs; ++i) {
builder.append("{ \"index\" : { \"_index\" : \"idx\", \"_id\": \"" + id++ + "\" } }\n");
builder.append("{\"str_value\" : \"s" + i + "\"}\n");
}
request.setJsonEntity(builder.toString());
assertOK(client().performRequest(request));
return id;
}

public void testSingleValuedString() throws Exception {
final Settings.Builder settings = Settings.builder()
.put(IndexMetadata.INDEX_NUMBER_OF_SHARDS_SETTING.getKey(), 2)
.put(IndexMetadata.INDEX_NUMBER_OF_REPLICAS_SETTING.getKey(), 0);
final String index = "idx";
createIndex(index, settings.build());

final int numDocs = 15000;
int id = 1;
for (int i = 0; i < 5; i++) {
id = indexDocs(numDocs, id);
refreshAllIndices();
}

final Request request = new Request("POST", "idx/_search");
request.setJsonEntity("{\"size\": 0,\"aggs\":{\"rareTerms\":{\"rare_terms\" : {\"field\": \"str_value.keyword\"}}}}");
assertOK(client().performRequest(request));
}
}
269 changes: 245 additions & 24 deletions server/src/main/java/org/elasticsearch/common/util/CuckooFilter.java
@@ -9,7 +9,9 @@

import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.packed.PackedInts;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@@ -56,7 +58,7 @@ public class CuckooFilter implements Writeable {
private static final int MAX_EVICTIONS = 500;
static final int EMPTY = 0;

private final PackedInts.Mutable data;
private final Mutable data;
private final int numBuckets;
private final int bitsPerEntry;
private final int fingerprintMask;
@@ -82,7 +84,7 @@ public class CuckooFilter implements Writeable {
throw new IllegalArgumentException("Attempted to create [" + numBuckets * entriesPerBucket
+ "] entries which is > Integer.MAX_VALUE");
}
this.data = PackedInts.getMutable(numBuckets * entriesPerBucket, bitsPerEntry, PackedInts.COMPACT);
this.data = new Mutable(numBuckets * entriesPerBucket, bitsPerEntry);

// puts the bits at the right side of the mask, e.g. `0000000000001111` for bitsPerEntry = 4
this.fingerprintMask = (0x80000000 >> (bitsPerEntry - 1)) >>> (Integer.SIZE - bitsPerEntry);
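// Added worked example for bitsPerEntry = 4:
//   0x80000000 >> 3   == 0xF0000000  (arithmetic shift drags the sign bit across the top 4 bits)
//   0xF0000000 >>> 28 == 0x0000000F  (logical shift moves those 4 set bits to the low end)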
@@ -106,7 +108,7 @@ public class CuckooFilter implements Writeable {
+ "] entries which is > Integer.MAX_VALUE");
}
// TODO this is probably super slow, but just used for testing atm
this.data = PackedInts.getMutable(numBuckets * entriesPerBucket, bitsPerEntry, PackedInts.COMPACT);
this.data = new Mutable(numBuckets * entriesPerBucket, bitsPerEntry);
for (int i = 0; i < other.data.size(); i++) {
data.set(i, other.data.get(i));
}
@@ -122,17 +124,26 @@ public class CuckooFilter implements Writeable {

this.fingerprintMask = (0x80000000 >> (bitsPerEntry - 1)) >>> (Integer.SIZE - bitsPerEntry);

data = (PackedInts.Mutable) PackedInts.getReader(new DataInput() {
@Override
public byte readByte() throws IOException {
return in.readByte();
}

@Override
public void readBytes(byte[] b, int offset, int len) throws IOException {
in.readBytes(b, offset, len);
if (in.getVersion().before(Version.V_8_0_0)) {
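// Added note: pre-8.0 peers send the stream written by Lucene's PackedInts serialization, so it
// is decoded with PackedInts.getReader and copied value by value into the new Mutable
// representation; 8.0+ peers send Mutable's own format (bitsPerValue, valueCount, number of long
// blocks, then the raw blocks), which the else branch below reads directly.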
final PackedInts.Reader reader = PackedInts.getReader(new DataInput() {
@Override
public byte readByte() throws IOException {
return in.readByte();
}

@Override
public void readBytes(byte[] b, int offset, int len) throws IOException {
in.readBytes(b, offset, len);
}
});
// This is pretty inefficient, but it should only happen if we have a mixed cluster (e.g. during an upgrade).
data = new Mutable(numBuckets * entriesPerBucket, bitsPerEntry);
for (int i = 0; i < count; i++) {
data.set(i, reader.get(i));
}
});
} else {
data = new Mutable(in);
}
}

@Override
@@ -142,18 +153,26 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(entriesPerBucket);
out.writeVInt(count);
out.writeVInt(evictedFingerprint);

data.save(new DataOutput() {
@Override
public void writeByte(byte b) throws IOException {
out.writeByte(b);
}

@Override
public void writeBytes(byte[] b, int offset, int length) throws IOException {
out.writeBytes(b, offset, length);
if (out.getVersion().before(Version.V_8_0_0)) {
// This is pretty inefficient, but it should only happen if we have a mixed cluster (e.g. during an upgrade).
PackedInts.Mutable mutable = PackedInts.getMutable(numBuckets * entriesPerBucket, bitsPerEntry, PackedInts.COMPACT);
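// Added note: COMPACT is the same PackedInts format the pre-8.0 code allocated (see the removed
// PackedInts.getMutable call in the constructor above), so an older node can keep decoding this
// stream with PackedInts.getReader.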
for (int i = 0; i < count; i++) {
mutable.set(i, data.get(i));
}
});
mutable.save(new DataOutput() {
@Override
public void writeByte(byte b) throws IOException {
out.writeByte(b);
}

@Override
public void writeBytes(byte[] b, int offset, int length) throws IOException {
out.writeBytes(b, offset, length);
}
});
} else {
data.save(out);
}
}

/**
@@ -507,4 +526,206 @@ public boolean equals(Object other) {
&& Objects.equals(this.count, that.count)
&& Objects.equals(this.evictedFingerprint, that.evictedFingerprint);
}

// Version of Lucene's Packed64 class that can be read from and written to Elasticsearch streams.
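// A minimal save/load round-trip sketch (added illustration, not part of this change); since
// Mutable is a private nested class, something like this could only live inside CuckooFilter
// itself or a test with access to it:
//
//   Mutable original = new Mutable(128, 5);          // room for 128 values of 5 bits each
//   original.set(7, 21);                             // any value that fits in 5 bits
//   BytesStreamOutput out = new BytesStreamOutput(); // org.elasticsearch.common.io.stream
//   original.save(out);                              // writes the new 8.0+ wire format
//   Mutable copy = new Mutable(out.bytes().streamInput());
//   assert copy.get(7) == 21 && copy.size() == 128;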
private static class Mutable {
private static final int BLOCK_SIZE = 64; // 32 = int, 64 = long
private static final int BLOCK_BITS = 6; // The #bits representing BLOCK_SIZE
private static final int MOD_MASK = BLOCK_SIZE - 1; // x % BLOCK_SIZE

/**
* Values are stored contiguously in the blocks array.
*/
private final long[] blocks;
/**
* A right-aligned mask of width bitsPerValue used by {@link #get(int)}.
*/
private final long maskRight;
/**
* Optimization: Saves one lookup in {@link #get(int)}.
*/
private final int bpvMinusBlockSize;

private final int bitsPerValue;
private final int valueCount;

Mutable(int valueCount, int bitsPerValue) {
this.bitsPerValue = bitsPerValue;
this.valueCount = valueCount;
final PackedInts.Format format = PackedInts.Format.PACKED;
final int longCount = format.longCount(PackedInts.VERSION_CURRENT, valueCount, bitsPerValue);
this.blocks = new long[longCount];
maskRight = ~0L << (BLOCK_SIZE-bitsPerValue) >>> (BLOCK_SIZE-bitsPerValue);
bpvMinusBlockSize = bitsPerValue - BLOCK_SIZE;
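// Added worked example for bitsPerValue = 5: maskRight == 0x1F (the low 5 bits) and
// bpvMinusBlockSize == -59, so endBits in get()/set() only becomes positive when a value
// actually straddles two 64-bit blocks.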
}

Mutable(StreamInput in) throws IOException {
this.bitsPerValue = in.readVInt();
this.valueCount = in.readVInt();
this.blocks = new long[in.readVInt()];
for (int i = 0; i < blocks.length; ++i) {
blocks[i] = in.readLong();
}
maskRight = ~0L << (BLOCK_SIZE - bitsPerValue) >>> (BLOCK_SIZE - bitsPerValue);
bpvMinusBlockSize = bitsPerValue - BLOCK_SIZE;
}

public void save(StreamOutput out) throws IOException {
assert valueCount != -1;
out.writeVInt(bitsPerValue);
out.writeVInt(valueCount);
out.writeVInt(blocks.length);
for (int i = 0; i < blocks.length; ++i) {
out.writeLong(blocks[i]);
}
}

public int size() {
return valueCount;
}

public long get(final int index) {
// The abstract index in a bit stream
final long majorBitPos = (long)index * bitsPerValue;
// The index in the backing long-array
final int elementPos = (int)(majorBitPos >>> BLOCK_BITS);
// The number of value-bits in the second long
final long endBits = (majorBitPos & MOD_MASK) + bpvMinusBlockSize;
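// Added worked example, bitsPerValue = 13: for index = 9, majorBitPos = 117, elementPos = 1 and
// endBits = (117 & 63) + (13 - 64) = 2 > 0, so the value straddles blocks[1] and blocks[2];
// for index = 2, majorBitPos = 26 and endBits = -25, so the value sits entirely in blocks[0]
// and is extracted with (blocks[0] >>> 25) & maskRight.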

if (endBits <= 0) { // Single block
return (blocks[elementPos] >>> -endBits) & maskRight;
}
// Two blocks
return ((blocks[elementPos] << endBits)
| (blocks[elementPos+1] >>> (BLOCK_SIZE - endBits)))
& maskRight;
}

public int get(int index, long[] arr, int off, int len) {
assert len > 0 : "len must be > 0 (got " + len + ")";
assert index >= 0 && index < valueCount;
len = Math.min(len, valueCount - index);
assert off + len <= arr.length;

final int originalIndex = index;
final PackedInts.Decoder decoder = PackedInts.getDecoder(PackedInts.Format.PACKED, PackedInts.VERSION_CURRENT, bitsPerValue);
// go to the next block where the value does not span across two blocks
final int offsetInBlocks = index % decoder.longValueCount();
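// Added note: longValueCount() values fill a whole number of 64-bit blocks, so once index is a
// multiple of it the remaining reads start exactly on a block boundary and can be decoded in
// bulk below (the same alignment trick is used by the bulk set further down).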
if (offsetInBlocks != 0) {
for (int i = offsetInBlocks; i < decoder.longValueCount() && len > 0; ++i) {
arr[off++] = get(index++);
--len;
}
if (len == 0) {
return index - originalIndex;
}
}

// bulk get
assert index % decoder.longValueCount() == 0;
int blockIndex = (int) (((long) index * bitsPerValue) >>> BLOCK_BITS);
assert (((long)index * bitsPerValue) & MOD_MASK) == 0;
final int iterations = len / decoder.longValueCount();
decoder.decode(blocks, blockIndex, arr, off, iterations);
final int gotValues = iterations * decoder.longValueCount();
index += gotValues;
len -= gotValues;
assert len >= 0;

if (index > originalIndex) {
// stay at the block boundary
return index - originalIndex;
} else {
// no progress so far => already at a block boundary but no full block to get
assert index == originalIndex;
assert len > 0 : "len must be > 0 (got " + len + ")";
assert index >= 0 && index < size();
assert off + len <= arr.length;

final int gets = Math.min(size() - index, len);
for (int i = index, o = off, end = index + gets; i < end; ++i, ++o) {
arr[o] = get(i);
}
return gets;
}
}

public void set(final int index, final long value) {
// The abstract index in a contiguous bit stream
final long majorBitPos = (long)index * bitsPerValue;
// The index in the backing long-array
final int elementPos = (int)(majorBitPos >>> BLOCK_BITS); // / BLOCK_SIZE
// The number of value-bits in the second long
final long endBits = (majorBitPos & MOD_MASK) + bpvMinusBlockSize;
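// Added note: each branch below first masks out the destination bit range and then ORs the new
// value in, so set() never leaks bits from a previously stored value.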

if (endBits <= 0) { // Single block
blocks[elementPos] = blocks[elementPos] & ~(maskRight << -endBits)
| (value << -endBits);
return;
}
// Two blocks
blocks[elementPos] = blocks[elementPos] & ~(maskRight >>> endBits)
| (value >>> endBits);
blocks[elementPos+1] = blocks[elementPos+1] & (~0L >>> endBits)
| (value << (BLOCK_SIZE - endBits));
}

public int set(int index, long[] arr, int off, int len) {
assert len > 0 : "len must be > 0 (got " + len + ")";
assert index >= 0 && index < valueCount;
len = Math.min(len, valueCount - index);
assert off + len <= arr.length;

final int originalIndex = index;
final PackedInts.Encoder encoder = PackedInts.getEncoder(PackedInts.Format.PACKED, PackedInts.VERSION_CURRENT, bitsPerValue);

// go to the next block where the value does not span across two blocks
final int offsetInBlocks = index % encoder.longValueCount();
if (offsetInBlocks != 0) {
for (int i = offsetInBlocks; i < encoder.longValueCount() && len > 0; ++i) {
set(index++, arr[off++]);
--len;
}
if (len == 0) {
return index - originalIndex;
}
}

// bulk set
assert index % encoder.longValueCount() == 0;
int blockIndex = (int) (((long) index * bitsPerValue) >>> BLOCK_BITS);
assert (((long)index * bitsPerValue) & MOD_MASK) == 0;
final int iterations = len / encoder.longValueCount();
encoder.encode(arr, off, blocks, blockIndex, iterations);
final int setValues = iterations * encoder.longValueCount();
index += setValues;
len -= setValues;
assert len >= 0;

if (index > originalIndex) {
// stay at the block boundary
return index - originalIndex;
} else {
// no progress so far => already at a block boundary but no full block to set
assert index == originalIndex;
len = Math.min(len, size() - index);
assert off + len <= arr.length;

for (int i = index, o = off, end = index + len; i < end; ++i, ++o) {
set(i, arr[o]);
}
return len;
}
}

public long ramBytesUsed() {
return RamUsageEstimator.alignObjectSize(
RamUsageEstimator.NUM_BYTES_OBJECT_HEADER
+ 3 * Integer.BYTES // bpvMinusBlockSize, valueCount, bitsPerValue
+ Long.BYTES // maskRight
+ RamUsageEstimator.NUM_BYTES_OBJECT_REF) // blocks ref
+ RamUsageEstimator.sizeOf(blocks);
}
}
}