diff --git a/dev-tools/scripts/releaseWizard.py b/dev-tools/scripts/releaseWizard.py
index d599095619d4..3814ae38a789 100755
--- a/dev-tools/scripts/releaseWizard.py
+++ b/dev-tools/scripts/releaseWizard.py
@@ -49,6 +49,7 @@
from collections import OrderedDict
from datetime import datetime
from datetime import timedelta
+from datetime import timezone
try:
import holidays
@@ -99,7 +100,7 @@ def expand_jinja(text, vars=None):
'state': state,
'gpg_key' : state.get_gpg_key(),
'gradle_cmd' : 'gradlew.bat' if is_windows() else './gradlew',
- 'epoch': unix_time_millis(datetime.utcnow()),
+ 'epoch': unix_time_millis(datetime.now(tz=timezone.utc)),
'get_next_version': state.get_next_version(),
'current_git_rev': state.get_current_git_rev(),
'keys_downloaded': keys_downloaded(),
@@ -199,7 +200,7 @@ def check_prerequisites(todo=None):
return True
-epoch = datetime.utcfromtimestamp(0)
+epoch = datetime.fromtimestamp(timestamp=0, tz=timezone.utc)
def unix_time_millis(dt):
@@ -279,7 +280,7 @@ def __init__(self, config_path, release_version, script_version):
self.latest_version = None
self.previous_rcs = {}
self.rc_number = 1
- self.start_date = unix_time_millis(datetime.utcnow())
+ self.start_date = unix_time_millis(datetime.now(tz=timezone.utc))
self.script_branch = run("git rev-parse --abbrev-ref HEAD").strip()
self.mirrored_versions = None
try:
@@ -741,7 +742,7 @@ def get_vars(self):
def set_done(self, is_done):
if is_done:
- self.state['done_date'] = unix_time_millis(datetime.utcnow())
+ self.state['done_date'] = unix_time_millis(datetime.now(tz=timezone.utc))
if self.persist_vars:
for k in self.persist_vars:
self.state[k] = self.get_vars()[k]
@@ -935,7 +936,7 @@ def expand_multiline(cmd_txt, indent=0):
def unix_to_datetime(unix_stamp):
- return datetime.utcfromtimestamp(unix_stamp / 1000)
+ return datetime.fromtimestamp(timestamp=unix_stamp / 1000, tz=timezone.utc)
def generate_asciidoc():
@@ -949,7 +950,7 @@ def generate_asciidoc():
fh.write("= Lucene Release %s\n\n" % state.release_version)
fh.write("(_Generated by releaseWizard.py v%s at %s_)\n\n"
- % (getScriptVersion(), datetime.utcnow().strftime("%Y-%m-%d %H:%M UTC")))
+ % (getScriptVersion(), datetime.now(tz=timezone.utc).strftime("%Y-%m-%d %H:%M UTC")))
fh.write(":numbered:\n\n")
fh.write("%s\n\n" % template('help'))
for group in state.todo_groups:
@@ -1839,9 +1840,9 @@ def create_ical(todo): # pylint: disable=unused-argument
return True
-today = datetime.utcnow().date()
+today = datetime.now(tz=timezone.utc).date()
sundays = {(today + timedelta(days=x)): 'Sunday' for x in range(10) if (today + timedelta(days=x)).weekday() == 6}
-y = datetime.utcnow().year
+y = datetime.now(tz=timezone.utc).year
years = [y, y+1]
non_working = holidays.CA(years=years) + holidays.US(years=years) + holidays.UK(years=years) \
+ holidays.DE(years=years) + holidays.NO(years=years) + holidays.IN(years=years) + holidays.RU(years=years)
@@ -1849,7 +1850,7 @@ def create_ical(todo): # pylint: disable=unused-argument
def vote_close_72h_date():
# Voting open at least 72 hours according to ASF policy
- return datetime.utcnow() + timedelta(hours=73)
+ return datetime.now(tz=timezone.utc) + timedelta(hours=73)
def vote_close_72h_holidays():
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 786c511b1de5..31ddd83d907e 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -11,7 +11,7 @@ API Changes
New Features
---------------------
-(No changes)
+* GITHUB#14097: Binary partitioning merge policy over float-valued vector field. (Mike Sokolov)
Improvements
---------------------
@@ -41,6 +41,9 @@ API Changes
* GITHUB#14069: Added DocIdSetIterator#intoBitSet API to let implementations
optimize loading doc IDs into a bit set. (Adrien Grand)
+* GITHUB#14134: Added Bits#applyMask API to help apply live docs as a mask on a
+ bit set of matches. (Adrien Grand)
+
New Features
---------------------
(No changes)
@@ -50,6 +53,9 @@ Improvements
* GITHUB#14079: Hunspell Dictionary now supports an option to tolerate REP rule count mismatches.
(Robert Muir)
+* GITHUB#13984: Add HNSW graph checks and stats to CheckIndex
+
+* GITHUB#14113: Remove unnecessary ByteArrayDataInput allocations from `Lucene90DocValuesProducer$TermsDict.decompressBlock`. (Ankit Jain)
Optimizations
---------------------
@@ -61,11 +67,20 @@ Optimizations
Bug Fixes
---------------------
-(No changes)
+
+* GITHUB#14109: prefetch may select the wrong memory segment for
+ multi-segment slices. (Chris Hegarty)
+
+* GITHUB#14123: SortingCodecReader NPE when segment has no (points, vectors, etc...) (Mike Sokolov)
Other
---------------------
-(No changes)
+
+* GITHUB#14081: Fix urls describing why NIOFS is not recommended for Windows (Marcel Yeonghyeon Ko)
+
+* GITHUB#14116: Use CDL to block threads to avoid flaky tests. (Ao Li)
+
+* GITHUB#14091: Cover all DataType. (Lu Xugang)
======================= Lucene 10.1.0 =======================
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
index 015fad7490ce..698ccff49f7e 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -488,6 +488,11 @@ public int entryNode() {
throw new UnsupportedOperationException();
}
+ @Override
+ public int maxConn() {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public NodesIterator getNodesOnLevel(int level) {
throw new UnsupportedOperationException();
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
index 845987c2957c..835293946dd9 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java
@@ -198,6 +198,11 @@ public int entryNode() {
throw new UnsupportedOperationException();
}
+ @Override
+ public int maxConn() {
+ return maxConn;
+ }
+
@Override
public NodesIterator getNodesOnLevel(int level) {
throw new UnsupportedOperationException();
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
index e71fa66719f8..80a71196c442 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java
@@ -548,6 +548,11 @@ public int entryNode() {
return entryNode;
}
+ @Override
+ public int maxConn() {
+ return (int) bytesForConns / Integer.BYTES - 1;
+ }
+
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java
index e762e016bbff..3ebbeee8d8ae 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91OnHeapHnswGraph.java
@@ -160,6 +160,11 @@ public int entryNode() {
return entryNode;
}
+ @Override
+ public int maxConn() {
+ return maxConn;
+ }
+
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
index 034967efbaab..213b3fff4f26 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java
@@ -459,6 +459,11 @@ public int entryNode() {
return entryNode;
}
+ @Override
+ public int maxConn() {
+ return (int) bytesForConns / Integer.BYTES - 1;
+ }
+
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
index 1ad2e3023642..e386c2126bea 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java
@@ -448,6 +448,7 @@ private static final class OffHeapHnswGraph extends HnswGraph {
final int size;
final long bytesForConns;
final long bytesForConns0;
+ final int maxConn;
int arcCount;
int arcUpTo;
@@ -463,6 +464,7 @@ private static final class OffHeapHnswGraph extends HnswGraph {
this.bytesForConns = Math.multiplyExact(Math.addExact(entry.M, 1L), Integer.BYTES);
this.bytesForConns0 =
Math.multiplyExact(Math.addExact(Math.multiplyExact(entry.M, 2L), 1), Integer.BYTES);
+ maxConn = entry.M;
}
@Override
@@ -501,6 +503,11 @@ public int numLevels() {
return numLevels;
}
+ @Override
+ public int maxConn() {
+ return maxConn;
+ }
+
@Override
public int entryNode() {
return entryNode;
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java
index b5859daf9f2f..99fc82bfc54d 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java
@@ -537,6 +537,11 @@ public int entryNode() throws IOException {
return entryNode;
}
+ @Override
+ public int maxConn() {
+ return currentNeighborsBuffer.length / 2;
+ }
+
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java
index 01698da79893..ec7eb048c8f4 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java
@@ -334,6 +334,11 @@ public int numLevels() {
return graph.numLevels();
}
+ @Override
+ public int maxConn() {
+ return graph.maxConn();
+ }
+
@Override
public int entryNode() {
throw new UnsupportedOperationException("Not supported on a mock graph");
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java
index c855d8f5e073..55fb3e8ea3ec 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java
@@ -366,6 +366,11 @@ public int entryNode() {
throw new UnsupportedOperationException("Not supported on a mock graph");
}
+ @Override
+ public int maxConn() {
+ throw new UnsupportedOperationException("Not supported on a mock graph");
+ }
+
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
index 54d070fa5995..1a3a5319efa2 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java
@@ -75,7 +75,7 @@ protected KnnVectorsReader() {}
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
*
 * <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
- * FieldInfo}. The return value is never {@code null}.
+ * FieldInfo}.
*
* @param field the vector field to search
* @param target the vector-valued query
@@ -103,7 +103,7 @@ public abstract void search(
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
*
 * <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
- * FieldInfo}. The return value is never {@code null}.
+ * FieldInfo}.
*
* @param field the vector field to search
* @param target the vector-valued query
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101PostingsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101PostingsReader.java
index 94ae2c9ef23a..c72da4abc3e4 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101PostingsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene101/Lucene101PostingsReader.java
@@ -53,7 +53,6 @@
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitUtil;
-import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
@@ -967,16 +966,13 @@ public int advance(int target) throws IOException {
}
@Override
- public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
- throws IOException {
+ public void intoBitSet(int upTo, FixedBitSet bitSet, int offset) throws IOException {
if (doc >= upTo) {
return;
}
// Handle the current doc separately, it may be on the previous docBuffer.
- if (acceptDocs == null || acceptDocs.get(doc)) {
- bitSet.set(doc - offset);
- }
+ bitSet.set(doc - offset);
for (; ; ) {
if (docBufferUpto == BLOCK_SIZE) {
@@ -990,7 +986,7 @@ public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset
int start = docBufferUpto;
int end = computeBufferEndBoundary(upTo);
if (end != 0) {
- bufferIntoBitSet(start, end, acceptDocs, bitSet, offset);
+ bufferIntoBitSet(start, end, bitSet, offset);
doc = docBuffer[end - 1];
}
docBufferUpto = end;
@@ -1004,51 +1000,28 @@ public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset
break;
case UNARY:
{
- if (acceptDocs == null) {
- int sourceFrom;
- if (docBufferUpto == 0) {
- // start from beginning
- sourceFrom = 0;
- } else {
- // start after the current doc
- sourceFrom = doc - docBitSetBase + 1;
- }
-
- int destFrom = docBitSetBase - offset + sourceFrom;
-
- assert level0LastDocID != NO_MORE_DOCS;
- int sourceTo = Math.min(upTo, level0LastDocID + 1) - docBitSetBase;
-
- if (sourceTo > sourceFrom) {
- FixedBitSet.orRange(
- docBitSet, sourceFrom, bitSet, destFrom, sourceTo - sourceFrom);
- }
- if (docBitSetBase + sourceTo <= level0LastDocID) {
- // We stopped before the end of the current bit set, which means that we're done.
- // Set the current doc before returning.
- advance(docBitSetBase + sourceTo);
- return;
- }
+ int sourceFrom;
+ if (docBufferUpto == 0) {
+ // start from beginning
+ sourceFrom = 0;
} else {
- // default impl, slow-ish
- long[] bits = docBitSet.getBits();
- for (int i = 0; i < bits.length; ++i) {
- long word = bits[i];
- while (word != 0) {
- int ntz = Long.numberOfTrailingZeros(word);
- int doc = docBitSetBase + ((i << 6) | ntz);
- if (doc >= this.doc) {
- if (doc >= upTo) {
- advance(doc); // this sets docBufferUpto as a side-effect
- return;
- }
- if (acceptDocs == null || acceptDocs.get(doc)) {
- bitSet.set(doc - offset);
- }
- }
- word ^= 1L << ntz;
- }
- }
+ // start after the current doc
+ sourceFrom = doc - docBitSetBase + 1;
+ }
+
+ int destFrom = docBitSetBase - offset + sourceFrom;
+
+ assert level0LastDocID != NO_MORE_DOCS;
+ int sourceTo = Math.min(upTo, level0LastDocID + 1) - docBitSetBase;
+
+ if (sourceTo > sourceFrom) {
+ FixedBitSet.orRange(docBitSet, sourceFrom, bitSet, destFrom, sourceTo - sourceFrom);
+ }
+ if (docBitSetBase + sourceTo <= level0LastDocID) {
+ // We stopped before the end of the current bit set, which means that we're done.
+ // Set the current doc before returning.
+ advance(docBitSetBase + sourceTo);
+ return;
}
doc = level0LastDocID;
docBufferUpto = BLOCK_SIZE;
@@ -1068,15 +1041,12 @@ private int computeBufferEndBoundary(int upTo) {
}
}
- private void bufferIntoBitSet(
- int start, int end, Bits acceptDocs, FixedBitSet bitSet, int offset) throws IOException {
- // acceptDocs#get (if backed by FixedBitSet), bitSet#set and `doc - offset` get
- // auto-vectorized
+ private void bufferIntoBitSet(int start, int end, FixedBitSet bitSet, int offset)
+ throws IOException {
+ // bitSet#set and `doc - offset` get auto-vectorized
for (int i = start; i < end; ++i) {
int doc = docBuffer[i];
- if (acceptDocs == null || acceptDocs.get(doc)) {
- bitSet.set(doc - offset);
- }
+ bitSet.set(doc - offset);
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90DocValuesProducer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90DocValuesProducer.java
index 11e83b3f03c1..80dffb7b9708 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90DocValuesProducer.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90DocValuesProducer.java
@@ -1122,10 +1122,9 @@ private class TermsDict extends BaseTermsEnum {
final LongValues indexAddresses;
final RandomAccessInput indexBytes;
final BytesRef term;
+ final BytesRef blockBuffer;
+ final ByteArrayDataInput blockInput;
long ord = -1;
-
- BytesRef blockBuffer = null;
- ByteArrayDataInput blockInput = null;
long currentCompressedBlockStart = -1;
long currentCompressedBlockEnd = -1;
@@ -1149,6 +1148,7 @@ private class TermsDict extends BaseTermsEnum {
// add 7 padding bytes can help decompression run faster.
int bufferSize = entry.maxBlockLength + entry.maxTermLength + LZ4_DECOMPRESSOR_PADDING;
blockBuffer = new BytesRef(new byte[bufferSize], 0, bufferSize);
+ blockInput = new ByteArrayDataInput();
}
@Override
@@ -1324,8 +1324,7 @@ private void decompressBlock() throws IOException {
}
// Reset the buffer.
- blockInput =
- new ByteArrayDataInput(blockBuffer.bytes, blockBuffer.offset, blockBuffer.length);
+ blockInput.reset(blockBuffer.bytes, blockBuffer.offset, blockBuffer.length);
}
}
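
For context, a minimal sketch of the reuse pattern this hunk applies: allocate the `ByteArrayDataInput` once and rebind it to each decompressed block with `reset`, instead of allocating a fresh instance per block. The buffer size and the `onBlockDecompressed` hook below are illustrative placeholders, not the producer's exact code.

```java
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.util.BytesRef;

// Sketch of the allocation-free pattern used above (GITHUB#14113): the data input is
// created once and re-pointed at each decompressed block via reset(). The buffer size
// and onBlockDecompressed() method are illustrative placeholders.
class BlockInputReuseSketch {
  private final BytesRef blockBuffer = new BytesRef(new byte[4096], 0, 4096);
  private final ByteArrayDataInput blockInput = new ByteArrayDataInput();

  void onBlockDecompressed(int decompressedLength) {
    // Previously a new ByteArrayDataInput was allocated here for every block.
    blockInput.reset(blockBuffer.bytes, blockBuffer.offset, decompressedLength);
  }
}
```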
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java
index ed6388b53cb7..b29f9da5b410 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java
@@ -491,7 +491,8 @@ public void seek(int level, int targetOrd) throws IOException {
level == 0
? targetOrd
: Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
- assert targetIndex >= 0;
+ assert targetIndex >= 0
+ : "seek level=" + level + " target=" + targetOrd + " not found: " + targetIndex;
// unsafe; no bounds checking
dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level]));
arcCount = dataIn.readVInt();
@@ -526,6 +527,11 @@ public int numLevels() throws IOException {
return numLevels;
}
+ @Override
+ public int maxConn() {
+ return currentNeighborsBuffer.length >> 1;
+ }
+
@Override
public int entryNode() throws IOException {
return entryNode;
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java
index 0f4e8196d52d..a587449e2e7b 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java
@@ -93,7 +93,6 @@ public Lucene99HnswVectorsWriter(
this.numMergeWorkers = numMergeWorkers;
this.mergeExec = mergeExec;
segmentWriteState = state;
-
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
@@ -293,6 +292,11 @@ public int numLevels() {
return graph.numLevels();
}
+ @Override
+ public int maxConn() {
+ return graph.maxConn();
+ }
+
@Override
public int entryNode() {
throw new UnsupportedOperationException("Not supported on a mock graph");
@@ -448,11 +452,12 @@ private void writeMeta(
meta.writeVLong(vectorIndexLength);
meta.writeVInt(field.getVectorDimension());
meta.writeInt(count);
- meta.writeVInt(M);
// write graph nodes on each level
if (graph == null) {
+ meta.writeVInt(M);
meta.writeVInt(0);
} else {
+ meta.writeVInt(graph.maxConn());
meta.writeVInt(graph.numLevels());
long valueCount = 0;
for (int level = 0; level < graph.numLevels(); level++) {
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
index bc18b231e74f..ee9cb5cd5f52 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java
@@ -28,6 +28,7 @@
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
@@ -41,6 +42,7 @@
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.hnsw.HnswGraph;
/**
* Enables per field numeric vector support.
@@ -189,7 +191,7 @@ public long ramBytesUsed() {
}
/** VectorReader that can wrap multiple delegate readers, selected by field. */
- public static class FieldsReader extends KnnVectorsReader {
+ public static class FieldsReader extends KnnVectorsReader implements HnswGraphProvider {
private final IntObjectHashMap<KnnVectorsReader> fields = new IntObjectHashMap<>();
private final FieldInfos fieldInfos;
@@ -322,6 +324,17 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
reader.search(field, target, knnCollector, acceptDocs);
}
+ @Override
+ public HnswGraph getGraph(String field) throws IOException {
+ final FieldInfo info = fieldInfos.fieldInfo(field);
+ KnnVectorsReader knnVectorsReader = fields.get(info.number);
+ if (knnVectorsReader instanceof HnswGraphProvider) {
+ return ((HnswGraphProvider) knnVectorsReader).getGraph(field);
+ } else {
+ return null;
+ }
+ }
+
@Override
public void close() throws IOException {
List<KnnVectorsReader> readers = new ArrayList<>(fields.size());
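
A hedged sketch of what this delegation enables for callers: asking the top-level per-field vectors reader for a field's HNSW graph without knowing the concrete per-field format. The reader/field wiring is assumed for illustration.

```java
import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.util.hnsw.HnswGraph;

// Sketch only: with FieldsReader implementing HnswGraphProvider, a caller holding a
// CodecReader can request a field's graph and let the per-field reader resolve it.
// A null result means the field's format does not expose an HNSW graph.
final class GraphLookupSketch {
  static HnswGraph graphFor(CodecReader reader, String field) throws IOException {
    KnnVectorsReader vectorsReader = reader.getVectorReader();
    if (vectorsReader instanceof HnswGraphProvider provider) {
      return provider.getGraph(field);
    }
    return null;
  }
}
```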
diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
index d957af01d0a2..b3a5e4dc5d11 100644
--- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
+++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java
@@ -26,8 +26,11 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.text.NumberFormat;
+import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
+import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
@@ -52,12 +55,14 @@
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.DocumentStoredFieldVisitor;
import org.apache.lucene.index.CheckIndex.Status.DocValuesStatus;
import org.apache.lucene.index.PointValues.IntersectVisitor;
import org.apache.lucene.index.PointValues.Relation;
+import org.apache.lucene.internal.hppc.IntIntHashMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.KnnCollector;
@@ -74,6 +79,7 @@
import org.apache.lucene.store.Lock;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.ArrayUtil.ByteArrayComparator;
+import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
@@ -91,6 +97,7 @@
import org.apache.lucene.util.automaton.ByteRunAutomaton;
import org.apache.lucene.util.automaton.CompiledAutomaton;
import org.apache.lucene.util.automaton.Operations;
+import org.apache.lucene.util.hnsw.HnswGraph;
/**
* Basic tool and API to check the health of an index and write a new segments file that removes
@@ -249,6 +256,9 @@ public static class SegmentInfoStatus {
/** Status of vectors */
public VectorValuesStatus vectorValuesStatus;
+ /** Status of HNSW graph */
+ public HnswGraphsStatus hnswGraphsStatus;
+
/** Status of soft deletes */
public SoftDeletesStatus softDeletesStatus;
@@ -406,6 +416,32 @@ public static final class VectorValuesStatus {
public Throwable error;
}
+ /** Status from testing a single HNSW graph */
+ public static final class HnswGraphStatus {
+
+ HnswGraphStatus() {}
+
+ /** Number of nodes at each level */
+ public List<Integer> numNodesAtLevel;
+
+ /** Connectedness at each level represented as a fraction */
+ public List<String> connectednessAtLevel;
+ }
+
+ /** Status from testing all HNSW graphs */
+ public static final class HnswGraphsStatus {
+
+ HnswGraphsStatus() {
+ this.hnswGraphsStatusByField = new HashMap<>();
+ }
+
+ /** Status of the HNSW graph keyed with field name */
+ public Map<String, HnswGraphStatus> hnswGraphsStatusByField;
+
+ /** Exception thrown during the HNSW graph test (null on success) */
+ public Throwable error;
+ }
+
/** Status from testing index sort */
public static final class IndexSortStatus {
IndexSortStatus() {}
@@ -1085,6 +1121,9 @@ private Status.SegmentInfoStatus testSegment(
// Test FloatVectorValues and ByteVectorValues
segInfoStat.vectorValuesStatus = testVectors(reader, infoStream, failFast);
+ // Test HNSW graph
+ segInfoStat.hnswGraphsStatus = testHnswGraphs(reader, infoStream, failFast);
+
// Test Index Sort
if (indexSort != null) {
segInfoStat.indexSortStatus = testSort(reader, indexSort, infoStream, failFast);
@@ -2746,6 +2785,196 @@ public static Status.VectorValuesStatus testVectors(
return status;
}
+ /** Test the HNSW graph. */
+ public static Status.HnswGraphsStatus testHnswGraphs(
+ CodecReader reader, PrintStream infoStream, boolean failFast) throws IOException {
+ if (infoStream != null) {
+ infoStream.print(" test: hnsw graphs.........");
+ }
+ long startNS = System.nanoTime();
+ Status.HnswGraphsStatus status = new Status.HnswGraphsStatus();
+ KnnVectorsReader vectorsReader = reader.getVectorReader();
+ FieldInfos fieldInfos = reader.getFieldInfos();
+
+ try {
+ if (fieldInfos.hasVectorValues()) {
+ for (FieldInfo fieldInfo : fieldInfos) {
+ if (fieldInfo.hasVectorValues()) {
+ KnnVectorsReader fieldReader = getFieldReaderForName(vectorsReader, fieldInfo.name);
+ if (fieldReader instanceof HnswGraphProvider graphProvider) {
+ HnswGraph hnswGraph = graphProvider.getGraph(fieldInfo.name);
+ testHnswGraph(hnswGraph, fieldInfo.name, status);
+ }
+ }
+ }
+ }
+ msg(
+ infoStream,
+ String.format(
+ Locale.ROOT,
+ "OK [%d fields] [took %.3f sec]",
+ status.hnswGraphsStatusByField.size(),
+ nsToSec(System.nanoTime() - startNS)));
+ printHnswInfo(infoStream, status.hnswGraphsStatusByField);
+ } catch (Exception e) {
+ if (failFast) {
+ throw IOUtils.rethrowAlways(e);
+ }
+ msg(infoStream, "ERROR: " + e);
+ status.error = e;
+ if (infoStream != null) {
+ e.printStackTrace(infoStream);
+ }
+ }
+
+ return status;
+ }
+
+ private static KnnVectorsReader getFieldReaderForName(
+ KnnVectorsReader vectorsReader, String fieldName) {
+ if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
+ return fieldsReader.getFieldReader(fieldName);
+ } else {
+ return vectorsReader;
+ }
+ }
+
+ private static void printHnswInfo(
+ PrintStream infoStream, Map<String, CheckIndex.Status.HnswGraphStatus> fieldsStatus) {
+ for (Map.Entry<String, CheckIndex.Status.HnswGraphStatus> entry : fieldsStatus.entrySet()) {
+ String fieldName = entry.getKey();
+ CheckIndex.Status.HnswGraphStatus status = entry.getValue();
+ msg(infoStream, " hnsw field name: " + fieldName);
+
+ int numLevels = Math.min(status.numNodesAtLevel.size(), status.connectednessAtLevel.size());
+ for (int level = numLevels - 1; level >= 0; level--) {
+ int numNodes = status.numNodesAtLevel.get(level);
+ String connectedness = status.connectednessAtLevel.get(level);
+ msg(
+ infoStream,
+ String.format(
+ Locale.ROOT,
+ " level %d: %d nodes, %s connected",
+ level,
+ numNodes,
+ connectedness));
+ }
+ }
+ }
+
+ private static void testHnswGraph(
+ HnswGraph hnswGraph, String fieldName, Status.HnswGraphsStatus status)
+ throws IOException, CheckIndexException {
+ if (hnswGraph != null) {
+ status.hnswGraphsStatusByField.put(fieldName, new Status.HnswGraphStatus());
+ final int numLevels = hnswGraph.numLevels();
+ status.hnswGraphsStatusByField.get(fieldName).numNodesAtLevel =
+ new ArrayList<>(Collections.nCopies(numLevels, null));
+ status.hnswGraphsStatusByField.get(fieldName).connectednessAtLevel =
+ new ArrayList<>(Collections.nCopies(numLevels, null));
+ // Perform checks on each level of the HNSW graph
+ for (int level = numLevels - 1; level >= 0; level--) {
+ // Collect BitSet of all nodes on this level
+ BitSet nodesOnThisLevel = new FixedBitSet(hnswGraph.size());
+ HnswGraph.NodesIterator nodesIterator = hnswGraph.getNodesOnLevel(level);
+ while (nodesIterator.hasNext()) {
+ nodesOnThisLevel.set(nodesIterator.nextInt());
+ }
+
+ nodesIterator = hnswGraph.getNodesOnLevel(level);
+ // Perform checks on each node on the level
+ while (nodesIterator.hasNext()) {
+ int node = nodesIterator.nextInt();
+ if (node < 0 || node > hnswGraph.size() - 1) {
+ throw new CheckIndexException(
+ "Field \""
+ + fieldName
+ + "\" has node: "
+ + node
+ + " not in the expected range [0, "
+ + (hnswGraph.size() - 1)
+ + "]");
+ }
+
+ // Perform checks on the node's neighbors
+ hnswGraph.seek(level, node);
+ int nbr, lastNeighbor = -1, firstNeighbor = -1;
+ while ((nbr = hnswGraph.nextNeighbor()) != NO_MORE_DOCS) {
+ if (!nodesOnThisLevel.get(nbr)) {
+ throw new CheckIndexException(
+ "Field \""
+ + fieldName
+ + "\" has node: "
+ + node
+ + " with a neighbor "
+ + nbr
+ + " which is not on its level ("
+ + level
+ + ")");
+ }
+ if (firstNeighbor == -1) {
+ firstNeighbor = nbr;
+ }
+ if (nbr < lastNeighbor) {
+ throw new CheckIndexException(
+ "Field \""
+ + fieldName
+ + "\" has neighbors out of order for node "
+ + node
+ + ": "
+ + nbr
+ + "<"
+ + lastNeighbor
+ + " 1st="
+ + firstNeighbor);
+ } else if (nbr == lastNeighbor) {
+ throw new CheckIndexException(
+ "Field \""
+ + fieldName
+ + "\" has repeated neighbors of node "
+ + node
+ + " with value "
+ + nbr);
+ }
+ lastNeighbor = nbr;
+ }
+ }
+ int numNodesOnLayer = nodesIterator.size();
+ status.hnswGraphsStatusByField.get(fieldName).numNodesAtLevel.set(level, numNodesOnLayer);
+
+ // Evaluate connectedness at this level by measuring the number of nodes reachable from the
+ // entry point
+ IntIntHashMap connectedNodes = getConnectedNodesOnLevel(hnswGraph, numNodesOnLayer, level);
+ status
+ .hnswGraphsStatusByField
+ .get(fieldName)
+ .connectednessAtLevel
+ .set(level, connectedNodes.size() + "/" + numNodesOnLayer);
+ }
+ }
+ }
+
+ private static IntIntHashMap getConnectedNodesOnLevel(
+ HnswGraph hnswGraph, int numNodesOnLayer, int level) throws IOException {
+ IntIntHashMap connectedNodes = new IntIntHashMap(numNodesOnLayer);
+ int entryPoint = hnswGraph.entryNode();
+ Deque<Integer> stack = new ArrayDeque<>();
+ stack.push(entryPoint);
+ while (!stack.isEmpty()) {
+ int node = stack.pop();
+ if (connectedNodes.containsKey(node)) {
+ continue;
+ }
+ connectedNodes.put(node, 1);
+ hnswGraph.seek(level, node);
+ int friendOrd;
+ while ((friendOrd = hnswGraph.nextNeighbor()) != NO_MORE_DOCS) {
+ stack.push(friendOrd);
+ }
+ }
+ return connectedNodes;
+ }
+
private static boolean vectorsReaderSupportsSearch(CodecReader codecReader, String fieldName) {
KnnVectorsReader vectorsReader = codecReader.getVectorReader();
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader) {
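
For orientation, a hedged sketch of how the new per-field HNSW status could be consumed after a check run; the index path is a placeholder and error handling is omitted.

```java
import java.nio.file.Paths;
import org.apache.lucene.index.CheckIndex;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

// Sketch: run CheckIndex and print the per-field HNSW graph stats gathered by
// testHnswGraphs(). "/path/to/index" is a placeholder.
public class CheckHnswSketch {
  public static void main(String[] args) throws Exception {
    try (Directory dir = FSDirectory.open(Paths.get("/path/to/index"));
        CheckIndex checker = new CheckIndex(dir)) {
      checker.setInfoStream(System.out);
      CheckIndex.Status status = checker.checkIndex();
      for (CheckIndex.Status.SegmentInfoStatus segment : status.segmentInfos) {
        if (segment.hnswGraphsStatus == null) {
          continue;
        }
        segment.hnswGraphsStatus.hnswGraphsStatusByField.forEach(
            (field, graphStatus) ->
                System.out.println(
                    field
                        + ": nodes per level "
                        + graphStatus.numNodesAtLevel
                        + ", connectedness "
                        + graphStatus.connectednessAtLevel));
      }
    }
  }
}
```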
diff --git a/lucene/core/src/java/org/apache/lucene/index/IndexSorter.java b/lucene/core/src/java/org/apache/lucene/index/IndexSorter.java
index 6db379b9b227..9fa2484af39f 100644
--- a/lucene/core/src/java/org/apache/lucene/index/IndexSorter.java
+++ b/lucene/core/src/java/org/apache/lucene/index/IndexSorter.java
@@ -32,11 +32,16 @@
/**
* Handles how documents should be sorted in an index, both within a segment and between segments.
*
- * <p>Implementers must provide the following methods: {@link #getDocComparator(LeafReader,int)} -
- * an object that determines how documents within a segment are to be sorted {@link
- * #getComparableProviders(List)} - an array of objects that return a sortable long value per
- * document and segment {@link #getProviderName()} - the SPI-registered name of a {@link
- * SortFieldProvider} to serialize the sort
+ * <p>Implementers must provide the following methods:
+ *
+ * <ul>
+ *   <li>{@link #getDocComparator(LeafReader,int)} - an object that determines how documents within
+ *       a segment are to be sorted
+ *   <li>{@link #getComparableProviders(List)} - an array of objects that return a sortable long
+ *       value per document and segment
+ *   <li>{@link #getProviderName()} - the SPI-registered name of a {@link SortFieldProvider} to
+ *       serialize the sort
+ * </ul>
*
* The companion {@link SortFieldProvider} should be registered with SPI via {@code
* META-INF/services}
diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
index daec0c197d6a..ab9964026ad8 100644
--- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
@@ -314,6 +314,7 @@ private static class SortingFloatVectorValues extends FloatVectorValues {
SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
this.delegate = delegate;
+ assert delegate != null;
// SortingValuesIterator consumes the iterator and records the docs and ord mapping
iteratorSupplier = iteratorSupplier(delegate, sortMap);
}
@@ -446,6 +447,9 @@ private SortingCodecReader(
@Override
public FieldsProducer getPostingsReader() {
FieldsProducer postingsReader = in.getPostingsReader();
+ if (postingsReader == null) {
+ return null;
+ }
return new FieldsProducer() {
@Override
public void close() throws IOException {
@@ -481,6 +485,9 @@ public int size() {
@Override
public StoredFieldsReader getFieldsReader() {
StoredFieldsReader delegate = in.getFieldsReader();
+ if (delegate == null) {
+ return null;
+ }
return newStoredFieldsReader(delegate);
}
@@ -526,6 +533,9 @@ public Bits getLiveDocs() {
@Override
public PointsReader getPointsReader() {
final PointsReader delegate = in.getPointsReader();
+ if (delegate == null) {
+ return null;
+ }
return new PointsReader() {
@Override
public void checkIntegrity() throws IOException {
@@ -551,6 +561,9 @@ public void close() throws IOException {
@Override
public KnnVectorsReader getVectorReader() {
KnnVectorsReader delegate = in.getVectorReader();
+ if (delegate == null) {
+ return null;
+ }
return new KnnVectorsReader() {
@Override
public void checkIntegrity() throws IOException {
@@ -587,6 +600,9 @@ public void close() throws IOException {
@Override
public NormsProducer getNormsReader() {
final NormsProducer delegate = in.getNormsReader();
+ if (delegate == null) {
+ return null;
+ }
return new NormsProducer() {
@Override
public NumericDocValues getNorms(FieldInfo field) throws IOException {
@@ -609,6 +625,9 @@ public void close() throws IOException {
@Override
public DocValuesProducer getDocValuesReader() {
final DocValuesProducer delegate = in.getDocValuesReader();
+ if (delegate == null) {
+ return null;
+ }
return new DocValuesProducer() {
@Override
public NumericDocValues getNumeric(FieldInfo field) throws IOException {
@@ -710,6 +729,9 @@ public TermVectorsReader getTermVectorsReader() {
}
private TermVectorsReader newTermVectorsReader(TermVectorsReader delegate) {
+ if (delegate == null) {
+ return null;
+ }
return new TermVectorsReader() {
@Override
public void prefetch(int doc) throws IOException {
diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java b/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java
index 5d7dfaf8b832..a6599a57fd25 100644
--- a/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java
@@ -164,37 +164,6 @@ public long cost() {
return cost;
}
- private void scoreDisiWrapperIntoBitSet(DisiWrapper w, Bits acceptDocs, int min, int max)
- throws IOException {
- boolean needsScores = BooleanScorer.this.needsScores;
- FixedBitSet matching = BooleanScorer.this.matching;
- Bucket[] buckets = BooleanScorer.this.buckets;
-
- DocIdSetIterator it = w.iterator;
- Scorable scorer = w.scorable;
- int doc = w.doc;
- if (doc < min) {
- doc = it.advance(min);
- }
- if (buckets == null) {
- it.intoBitSet(acceptDocs, max, matching, doc & ~MASK);
- } else {
- for (; doc < max; doc = it.nextDoc()) {
- if (acceptDocs == null || acceptDocs.get(doc)) {
- final int i = doc & MASK;
- matching.set(i);
- final Bucket bucket = buckets[i];
- bucket.freq++;
- if (needsScores) {
- bucket.score += scorer.score();
- }
- }
- }
- }
-
- w.doc = it.docID();
- }
-
private void scoreWindowIntoBitSetAndReplay(
LeafCollector collector,
Bits acceptDocs,
@@ -207,7 +176,35 @@ private void scoreWindowIntoBitSetAndReplay(
for (int i = 0; i < numScorers; ++i) {
final DisiWrapper w = scorers[i];
assert w.doc < max;
- scoreDisiWrapperIntoBitSet(w, acceptDocs, min, max);
+
+ DocIdSetIterator it = w.iterator;
+ int doc = w.doc;
+ if (doc < min) {
+ doc = it.advance(min);
+ }
+ if (buckets == null) {
+ // This doesn't apply live docs, so we'll need to apply them later
+ it.intoBitSet(max, matching, base);
+ } else {
+ for (; doc < max; doc = it.nextDoc()) {
+ if (acceptDocs == null || acceptDocs.get(doc)) {
+ final int d = doc & MASK;
+ matching.set(d);
+ final Bucket bucket = buckets[d];
+ bucket.freq++;
+ if (needsScores) {
+ bucket.score += w.scorable.score();
+ }
+ }
+ }
+ }
+
+ w.doc = it.docID();
+ }
+
+ if (buckets == null && acceptDocs != null) {
+ // In this case, live docs have not been applied yet.
+ acceptDocs.applyMask(matching, base);
}
docIdStreamView.base = base;
diff --git a/lucene/core/src/java/org/apache/lucene/search/DenseConjunctionBulkScorer.java b/lucene/core/src/java/org/apache/lucene/search/DenseConjunctionBulkScorer.java
index 2acf04ba501b..121687245248 100644
--- a/lucene/core/src/java/org/apache/lucene/search/DenseConjunctionBulkScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/search/DenseConjunctionBulkScorer.java
@@ -105,7 +105,11 @@ private void scoreWindowUsingBitSet(
assert clauseWindowMatches.scanIsEmpty();
int offset = lead.docID();
- lead.intoBitSet(acceptDocs, max, windowMatches, offset);
+ lead.intoBitSet(max, windowMatches, offset);
+ if (acceptDocs != null) {
+ // Apply live docs.
+ acceptDocs.applyMask(windowMatches, offset);
+ }
int upTo = 0;
for (;
@@ -116,9 +120,7 @@ private void scoreWindowUsingBitSet(
if (other.docID() < offset) {
other.advance(offset);
}
- // No need to apply acceptDocs on other clauses since we already applied live docs on the
- // leading clause.
- other.intoBitSet(null, max, clauseWindowMatches, offset);
+ other.intoBitSet(max, clauseWindowMatches, offset);
windowMatches.and(clauseWindowMatches);
clauseWindowMatches.clear();
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java
index 6ab57c7b180c..cedababbce6b 100644
--- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java
+++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java
@@ -21,7 +21,6 @@
import java.util.Collection;
import java.util.Comparator;
import org.apache.lucene.util.ArrayUtil;
-import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
/**
@@ -150,17 +149,16 @@ public int advance(int target) throws IOException {
}
@Override
- public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
- throws IOException {
+ public void intoBitSet(int upTo, FixedBitSet bitSet, int offset) throws IOException {
while (leadTop.doc < upTo) {
- leadTop.approximation.intoBitSet(acceptDocs, upTo, bitSet, offset);
+ leadTop.approximation.intoBitSet(upTo, bitSet, offset);
leadTop.doc = leadTop.approximation.docID();
leadTop = leadIterators.updateTop();
}
minOtherDoc = Integer.MAX_VALUE;
for (DisiWrapper w : otherIterators) {
- w.approximation.intoBitSet(acceptDocs, upTo, bitSet, offset);
+ w.approximation.intoBitSet(upTo, bitSet, offset);
w.doc = w.approximation.docID();
minOtherDoc = Math.min(minOtherDoc, w.doc);
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/DocIdSetIterator.java b/lucene/core/src/java/org/apache/lucene/search/DocIdSetIterator.java
index ee30f627a56b..e0bee1da2314 100644
--- a/lucene/core/src/java/org/apache/lucene/search/DocIdSetIterator.java
+++ b/lucene/core/src/java/org/apache/lucene/search/DocIdSetIterator.java
@@ -17,7 +17,6 @@
package org.apache.lucene.search;
import java.io.IOException;
-import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
/**
@@ -220,9 +219,7 @@ protected final int slowAdvance(int target) throws IOException {
*
*
* for (int doc = docID(); doc < upTo; doc = nextDoc()) {
- * if (acceptDocs == null || acceptDocs.get(doc)) {
- * bitSet.set(doc - offset);
- * }
+ * bitSet.set(doc - offset);
* }
*
*
@@ -233,13 +230,10 @@ protected final int slowAdvance(int target) throws IOException {
*
* @lucene.internal
*/
- public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
- throws IOException {
+ public void intoBitSet(int upTo, FixedBitSet bitSet, int offset) throws IOException {
assert offset <= docID();
for (int doc = docID(); doc < upTo; doc = nextDoc()) {
- if (acceptDocs == null || acceptDocs.get(doc)) {
- bitSet.set(doc - offset);
- }
+ bitSet.set(doc - offset);
}
}
}
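
A hedged sketch of the calling convention this signature change implies: the iterator fills the window without consulting live docs, and the caller masks deletions afterwards with `Bits#applyMask`. The window size is an arbitrary illustration.

```java
import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;

// Sketch of the two-step pattern now used by callers (e.g. DenseConjunctionBulkScorer):
// collect matching docs into a window bit set, then clear deleted docs in one pass.
// WINDOW_SIZE is illustrative; the window bit set must be at least that long, and the
// iterator is assumed to be positioned on a real doc (not NO_MORE_DOCS).
final class IntoBitSetSketch {
  static final int WINDOW_SIZE = 4096;

  static void collectWindow(DocIdSetIterator it, Bits acceptDocs, FixedBitSet window)
      throws IOException {
    int offset = it.docID();         // window starts at the iterator's current doc
    int upTo = offset + WINDOW_SIZE; // exclusive upper bound
    it.intoBitSet(upTo, window, offset);
    if (acceptDocs != null) {
      acceptDocs.applyMask(window, offset); // clear bits for deleted docs
    }
  }
}
```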
diff --git a/lucene/core/src/java/org/apache/lucene/store/FSDirectory.java b/lucene/core/src/java/org/apache/lucene/store/FSDirectory.java
index 413e22c45ae8..0a49cba05e49 100644
--- a/lucene/core/src/java/org/apache/lucene/store/FSDirectory.java
+++ b/lucene/core/src/java/org/apache/lucene/store/FSDirectory.java
@@ -60,7 +60,7 @@
* post.
* {@link NIOFSDirectory} uses java.nio's FileChannel's positional io when reading to avoid
 * synchronization when reading from the same file. Unfortunately, due to a Windows-only <a
+ * href="https://bugs.java.com/bugdatabase/view_bug?bug_id=6265734">Sun JRE bug</a> this is a
* poor choice for Windows, but on all other platforms this is the preferred choice.
* Applications using {@link Thread#interrupt()} or {@link Future#cancel(boolean)} should use
* {@code RAFDirectory} instead, which is provided in the {@code misc} module. See {@link
diff --git a/lucene/core/src/java/org/apache/lucene/store/NIOFSDirectory.java b/lucene/core/src/java/org/apache/lucene/store/NIOFSDirectory.java
index 246f48082cfe..c9c92db91b40 100644
--- a/lucene/core/src/java/org/apache/lucene/store/NIOFSDirectory.java
+++ b/lucene/core/src/java/org/apache/lucene/store/NIOFSDirectory.java
@@ -36,7 +36,7 @@
* NOTE: NIOFSDirectory is not recommended on Windows because of a bug in how
* FileChannel.read is implemented in Sun's JRE. Inside of the implementation the position is
 * apparently synchronized. See <a
+ * href="https://bugs.java.com/bugdatabase/view_bug?bug_id=6265734">here</a> for details.
*
*
NOTE: Accessing this class either directly or indirectly from a thread while it's
* interrupted can close the underlying file descriptor immediately if at the same time the thread
diff --git a/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java b/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java
index 22b87f91d100..ff74c107b13c 100644
--- a/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java
+++ b/lucene/core/src/java/org/apache/lucene/util/BitSetIterator.java
@@ -99,18 +99,13 @@ public long cost() {
}
@Override
- public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
- throws IOException {
- // TODO: Can we also optimize the case when acceptDocs is not null?
- if (acceptDocs == null
- && upTo > doc
- && offset < bits.length()
- && bits instanceof FixedBitSet fixedBits) {
- upTo = Math.min(upTo, fixedBits.length());
+ public void intoBitSet(int upTo, FixedBitSet bitSet, int offset) throws IOException {
+ upTo = Math.min(upTo, bits.length());
+ if (upTo > doc && bits instanceof FixedBitSet fixedBits) {
FixedBitSet.orRange(fixedBits, doc, bitSet, doc - offset, upTo - doc);
advance(upTo); // set the current doc
} else {
- super.intoBitSet(acceptDocs, upTo, bitSet, offset);
+ super.intoBitSet(upTo, bitSet, offset);
}
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/Bits.java b/lucene/core/src/java/org/apache/lucene/util/Bits.java
index dd42ad4b1973..61757a1a34e4 100644
--- a/lucene/core/src/java/org/apache/lucene/util/Bits.java
+++ b/lucene/core/src/java/org/apache/lucene/util/Bits.java
@@ -16,6 +16,8 @@
*/
package org.apache.lucene.util;
+import org.apache.lucene.search.DocIdSetIterator;
+
/**
* Interface for Bitset-like structures.
*
@@ -34,6 +36,32 @@ public interface Bits {
/** Returns the number of bits in this set */
int length();
+ /**
+ * Apply this {@code Bits} instance to the given {@link FixedBitSet}, which starts at the given
+ * {@code offset}.
+ *
+ * <p>This should behave the same way as the default implementation, which does the following:
+ *
+ * <pre class="prettyprint">
+ * for (int i = bitSet.nextSetBit(0);
+ * i != DocIdSetIterator.NO_MORE_DOCS;
+ * i = i + 1 >= bitSet.length() ? DocIdSetIterator.NO_MORE_DOCS : bitSet.nextSetBit(i + 1)) {
+ * if (get(offset + i) == false) {
+ * bitSet.clear(i);
+ * }
+ * }
+ * </pre>
+ */
+ default void applyMask(FixedBitSet bitSet, int offset) {
+ for (int i = bitSet.nextSetBit(0);
+ i != DocIdSetIterator.NO_MORE_DOCS;
+ i = i + 1 >= bitSet.length() ? DocIdSetIterator.NO_MORE_DOCS : bitSet.nextSetBit(i + 1)) {
+ if (get(offset + i) == false) {
+ bitSet.clear(i);
+ }
+ }
+ }
+
Bits[] EMPTY_ARRAY = new Bits[0];
/** Bits impl of the specified length with all bits set. */
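
A hedged sketch showing that any `Bits` implementation picks up the default `applyMask` above; the even-docs-are-live rule below is purely illustrative, not Lucene behavior.

```java
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;

// Sketch: a toy Bits implementation that inherits the default applyMask(). The
// "even doc IDs are live" rule is an arbitrary illustration.
final class EvenDocsLive implements Bits {
  private final int maxDoc;

  EvenDocsLive(int maxDoc) {
    this.maxDoc = maxDoc;
  }

  @Override
  public boolean get(int index) {
    return (index & 1) == 0;
  }

  @Override
  public int length() {
    return maxDoc;
  }

  public static void main(String[] args) {
    FixedBitSet matches = new FixedBitSet(8);
    matches.set(0, 8); // matches for docs 16..23, expressed relative to offset 16
    new EvenDocsLive(64).applyMask(matches, 16);
    System.out.println(matches.cardinality()); // 4 -> only even doc IDs survive
  }
}
```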
diff --git a/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java b/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java
index bf3b8a2d2689..1b6954d2eb66 100644
--- a/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java
+++ b/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java
@@ -347,7 +347,7 @@ public void or(DocIdSetIterator iter) throws IOException {
} else {
checkUnpositioned(iter);
iter.nextDoc();
- iter.intoBitSet(null, DocIdSetIterator.NO_MORE_DOCS, this, 0);
+ iter.intoBitSet(DocIdSetIterator.NO_MORE_DOCS, this, 0);
}
}
@@ -396,7 +396,8 @@ public static void orRange(
// First, align `destFrom` with a word start, ie. a multiple of Long.SIZE (64)
if ((destFrom & 0x3F) != 0) {
int numBitsNeeded = Math.min(-destFrom & 0x3F, length);
- destBits[destFrom >> 6] |= readNBits(sourceBits, sourceFrom, numBitsNeeded) << destFrom;
+ long bits = readNBits(sourceBits, sourceFrom, numBitsNeeded) << destFrom;
+ destBits[destFrom >> 6] |= bits;
sourceFrom += numBitsNeeded;
destFrom += numBitsNeeded;
@@ -434,7 +435,74 @@ public static void orRange(
// Finally handle tail bits
if (length > 0) {
- destBits[destFrom >> 6] |= readNBits(sourceBits, sourceFrom, length);
+ long bits = readNBits(sourceBits, sourceFrom, length);
+ destBits[destFrom >> 6] |= bits;
+ }
+ }
+
+ /**
+ * And {@code length} bits starting at {@code sourceFrom} from {@code source} into {@code dest}
+ * starting at {@code destFrom}.
+ */
+ public static void andRange(
+ FixedBitSet source, int sourceFrom, FixedBitSet dest, int destFrom, int length) {
+ assert length >= 0 : length;
+ Objects.checkFromIndexSize(sourceFrom, length, source.length());
+ Objects.checkFromIndexSize(destFrom, length, dest.length());
+
+ if (length == 0) {
+ return;
+ }
+
+ long[] sourceBits = source.getBits();
+ long[] destBits = dest.getBits();
+
+ // First, align `destFrom` with a word start, ie. a multiple of Long.SIZE (64)
+ if ((destFrom & 0x3F) != 0) {
+ int numBitsNeeded = Math.min(-destFrom & 0x3F, length);
+ long bits = readNBits(sourceBits, sourceFrom, numBitsNeeded) << destFrom;
+ bits |= ~(((1L << numBitsNeeded) - 1) << destFrom);
+ destBits[destFrom >> 6] &= bits;
+
+ sourceFrom += numBitsNeeded;
+ destFrom += numBitsNeeded;
+ length -= numBitsNeeded;
+ }
+
+ if (length == 0) {
+ return;
+ }
+
+ assert (destFrom & 0x3F) == 0;
+
+ // Now AND at the word level
+ int numFullWords = length >> 6;
+ int sourceWordFrom = sourceFrom >> 6;
+ int destWordFrom = destFrom >> 6;
+
+ // Note: these two for loops auto-vectorize
+ if ((sourceFrom & 0x3F) == 0) {
+ // sourceFrom and destFrom are both aligned with a long[]
+ for (int i = 0; i < numFullWords; ++i) {
+ destBits[destWordFrom + i] &= sourceBits[sourceWordFrom + i];
+ }
+ } else {
+ for (int i = 0; i < numFullWords; ++i) {
+ destBits[destWordFrom + i] &=
+ (sourceBits[sourceWordFrom + i] >>> sourceFrom)
+ | (sourceBits[sourceWordFrom + i + 1] << -sourceFrom);
+ }
+ }
+
+ sourceFrom += numFullWords << 6;
+ destFrom += numFullWords << 6;
+ length -= numFullWords << 6;
+
+ // Finally handle tail bits
+ if (length > 0) {
+ long bits = readNBits(sourceBits, sourceFrom, length);
+ bits |= (~0L << length);
+ destBits[destFrom >> 6] &= bits;
}
}
@@ -730,4 +798,18 @@ public static FixedBitSet copyOf(Bits bits) {
public Bits asReadOnlyBits() {
return new FixedBits(bits, numBits);
}
+
+ @Override
+ public void applyMask(FixedBitSet bitSet, int offset) {
+ // Note: Some scorers don't track maxDoc and may thus call this method with an offset that is
+ // beyond bitSet.length()
+ int length = Math.min(bitSet.length(), length() - offset);
+ if (length >= 0) {
+ andRange(this, offset, bitSet, 0, length);
+ }
+ if (length < bitSet.length()
+ && bitSet.nextSetBit(Math.max(0, length)) != DocIdSetIterator.NO_MORE_DOCS) {
+ throw new IllegalArgumentException("Some bits are set beyond the end of live docs");
+ }
+ }
}
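
A hedged sketch of `andRange` on small bit sets; the positions and lengths are arbitrary illustration values.

```java
import org.apache.lucene.util.FixedBitSet;

// Sketch: andRange intersects `length` bits of `source` starting at sourceFrom into
// `dest` starting at destFrom; bits of dest outside that range are left untouched.
final class AndRangeSketch {
  public static void main(String[] args) {
    FixedBitSet source = new FixedBitSet(256);
    FixedBitSet dest = new FixedBitSet(128);
    source.set(10);  // live
    dest.set(2);     // maps to source bit 8 + 2 = 10, which is set -> stays set
    dest.set(3);     // maps to source bit 11, which is clear -> cleared
    dest.set(100);   // outside the 64-bit range -> untouched
    FixedBitSet.andRange(source, 8, dest, 0, 64);
    System.out.println(dest.get(2) + " " + dest.get(3) + " " + dest.get(100)); // true false true
  }
}
```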
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java
index b4688d097302..7f44bcf96aa0 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java
@@ -60,11 +60,10 @@ protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxO
} else {
initializedNodes = new FixedBitSet(maxOrd);
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorValues, initializedNodes);
- graph =
- InitializedHnswGraphBuilder.initGraph(M, initializerGraph, oldToNewOrdinalMap, maxOrd);
+ graph = InitializedHnswGraphBuilder.initGraph(initializerGraph, oldToNewOrdinalMap, maxOrd);
}
}
return new HnswConcurrentMergeBuilder(
- taskExecutor, numWorker, scorerSupplier, M, beamWidth, graph, initializedNodes);
+ taskExecutor, numWorker, scorerSupplier, beamWidth, graph, initializedNodes);
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java
index c23f56bcdc62..d2e81addc5d4 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswConcurrentMergeBuilder.java
@@ -49,7 +49,6 @@ public HnswConcurrentMergeBuilder(
TaskExecutor taskExecutor,
int numWorker,
RandomVectorScorerSupplier scorerSupplier,
- int M,
int beamWidth,
OnHeapHnswGraph hnsw,
BitSet initializedNodes)
@@ -62,7 +61,6 @@ public HnswConcurrentMergeBuilder(
workers[i] =
new ConcurrentMergeWorker(
scorerSupplier.copy(),
- M,
beamWidth,
HnswGraphBuilder.randSeed,
hnsw,
@@ -149,7 +147,6 @@ private static final class ConcurrentMergeWorker extends HnswGraphBuilder {
private ConcurrentMergeWorker(
RandomVectorScorerSupplier scorerSupplier,
- int M,
int beamWidth,
long seed,
OnHeapHnswGraph hnsw,
@@ -159,7 +156,6 @@ private ConcurrentMergeWorker(
throws IOException {
super(
scorerSupplier,
- M,
beamWidth,
seed,
hnsw,
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
index 0c38c4e2ff78..9326ace61f7f 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java
@@ -84,6 +84,9 @@ public int maxNodeId() {
/** Returns the number of levels of the graph */
public abstract int numLevels() throws IOException;
+ /** returns M, the maximum number of connections for a node. */
+ public abstract int maxConn() throws IOException;
+
/** Returns graph's entry point on the top level * */
public abstract int entryNode() throws IOException;
@@ -118,6 +121,11 @@ public int numLevels() {
return 0;
}
+ @Override
+ public int maxConn() {
+ return 0;
+ }
+
@Override
public int entryNode() {
return 0;
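
A hedged sketch of using the new accessor together with the existing traversal API, e.g. to sanity-check fanout. How the graph is obtained is left out (e.g. via `HnswGraphProvider#getGraph`); level 0 allows 2 * M connections, as the removed `HnswGraphBuilder` javadoc notes.

```java
import java.io.IOException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.hnsw.HnswGraph;

// Sketch: walk every level and confirm no node exceeds the connection limit reported
// by maxConn() (M on upper levels, 2 * M on level 0). Graph acquisition is assumed.
final class MaxConnCheckSketch {
  static void verifyFanout(HnswGraph graph) throws IOException {
    int m = graph.maxConn();
    for (int level = 0; level < graph.numLevels(); level++) {
      int limit = level == 0 ? 2 * m : m;
      HnswGraph.NodesIterator nodes = graph.getNodesOnLevel(level);
      while (nodes.hasNext()) {
        int node = nodes.nextInt();
        graph.seek(level, node);
        int neighbors = 0;
        while (graph.nextNeighbor() != DocIdSetIterator.NO_MORE_DOCS) {
          neighbors++;
        }
        if (neighbors > limit) {
          throw new IllegalStateException("node " + node + " has " + neighbors + " neighbors");
        }
      }
    }
  }
}
```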
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index 57e7e43d3d76..d3cef0bc6d10 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -100,19 +100,14 @@ public static HnswGraphBuilder create(
protected HnswGraphBuilder(
RandomVectorScorerSupplier scorerSupplier, int M, int beamWidth, long seed, int graphSize)
throws IOException {
- this(scorerSupplier, M, beamWidth, seed, new OnHeapHnswGraph(M, graphSize));
+ this(scorerSupplier, beamWidth, seed, new OnHeapHnswGraph(M, graphSize));
}
protected HnswGraphBuilder(
- RandomVectorScorerSupplier scorerSupplier,
- int M,
- int beamWidth,
- long seed,
- OnHeapHnswGraph hnsw)
+ RandomVectorScorerSupplier scorerSupplier, int beamWidth, long seed, OnHeapHnswGraph hnsw)
throws IOException {
this(
scorerSupplier,
- M,
beamWidth,
seed,
hnsw,
@@ -125,8 +120,6 @@ protected HnswGraphBuilder(
* ordinals, using the given hyperparameter settings, and returns the resulting graph.
*
* @param scorerSupplier a supplier to create vector scorer from ordinals.
- * @param M – graph fanout parameter used to calculate the maximum number of connections a node
- * can have – M on upper layers, and M * 2 on the lowest level.
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
* @param seed the seed for a random number generator used during graph construction. Provide this
* to ensure repeatable construction.
@@ -134,20 +127,19 @@ protected HnswGraphBuilder(
*/
protected HnswGraphBuilder(
RandomVectorScorerSupplier scorerSupplier,
- int M,
int beamWidth,
long seed,
OnHeapHnswGraph hnsw,
HnswLock hnswLock,
HnswGraphSearcher graphSearcher)
throws IOException {
- if (M <= 0) {
+ if (hnsw.maxConn() <= 0) {
throw new IllegalArgumentException("M (max connections) must be positive");
}
if (beamWidth <= 0) {
throw new IllegalArgumentException("beamWidth must be positive");
}
- this.M = M;
+ this.M = hnsw.maxConn();
this.scorerSupplier =
Objects.requireNonNull(scorerSupplier, "scorer supplier must not be null");
// normalization factor for level generation; currently not configurable
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java
index 7028a81808d6..d0d398be2a78 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswUtil.java
@@ -23,6 +23,7 @@
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
+import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.CodecReader;
@@ -233,13 +234,16 @@ private static int nextClearBit(FixedBitSet bits, int index) {
public static boolean graphIsRooted(IndexReader reader, String vectorField) throws IOException {
for (LeafReaderContext ctx : reader.leaves()) {
CodecReader codecReader = (CodecReader) FilterLeafReader.unwrap(ctx.reader());
- HnswGraph graph =
- ((HnswGraphProvider)
- ((PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader())
- .getFieldReader(vectorField))
- .getGraph(vectorField);
- if (isRooted(graph) == false) {
- return false;
+ KnnVectorsReader vectorsReader =
+ ((PerFieldKnnVectorsFormat.FieldsReader) codecReader.getVectorReader())
+ .getFieldReader(vectorField);
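+ // Only readers that expose an HNSW graph can be checked for rootedness; anything else is an error here.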
+ if (vectorsReader instanceof HnswGraphProvider) {
+ HnswGraph graph = ((HnswGraphProvider) vectorsReader).getGraph(vectorField);
+ if (isRooted(graph) == false) {
+ return false;
+ }
+ } else {
+ throw new IllegalArgumentException("not a graph: " + vectorsReader);
}
}
return true;
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java
index c480d53360cb..bcdb0cab04a9 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java
@@ -21,7 +21,6 @@
import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
-import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
@@ -63,33 +62,30 @@ public IncrementalHnswGraphMerger(
/**
* Adds a reader to the graph merger if it meets the following criteria: 1. Does not contain any
- * deleted docs 2. Is a HnswGraphProvider/PerFieldKnnVectorReader 3. Has the most docs of any
- * previous reader that met the above criteria
+ * deleted docs 2. Is a HnswGraphProvider 3. Has the most docs of any previous reader that met the
+ * above criteria
*/
@Override
public IncrementalHnswGraphMerger addReader(
KnnVectorsReader reader, MergeState.DocMap docMap, Bits liveDocs) throws IOException {
- KnnVectorsReader currKnnVectorsReader = reader;
- if (reader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
- currKnnVectorsReader = candidateReader.getFieldReader(fieldInfo.name);
+ if (hasDeletes(liveDocs) || !(reader instanceof HnswGraphProvider)) {
+ return this;
}
-
- if (!(currKnnVectorsReader instanceof HnswGraphProvider) || !noDeletes(liveDocs)) {
+ HnswGraph graph = ((HnswGraphProvider) reader).getGraph(fieldInfo.name);
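+ // A reader whose graph is missing or empty cannot seed the merged graph, so skip it.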
+ if (graph == null || graph.size() == 0) {
return this;
}
-
int candidateVectorCount = 0;
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> {
- ByteVectorValues byteVectorValues =
- currKnnVectorsReader.getByteVectorValues(fieldInfo.name);
+ ByteVectorValues byteVectorValues = reader.getByteVectorValues(fieldInfo.name);
if (byteVectorValues == null) {
return this;
}
candidateVectorCount = byteVectorValues.size();
}
case FLOAT32 -> {
- FloatVectorValues vectorValues = currKnnVectorsReader.getFloatVectorValues(fieldInfo.name);
+ FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldInfo.name);
if (vectorValues == null) {
return this;
}
@@ -97,7 +93,7 @@ public IncrementalHnswGraphMerger addReader(
}
}
if (candidateVectorCount > initGraphSize) {
- initReader = currKnnVectorsReader;
+ initReader = reader;
initDocMap = docMap;
initGraphSize = candidateVectorCount;
}
@@ -130,7 +126,6 @@ protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxO
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorValues, initializedNodes);
return InitializedHnswGraphBuilder.fromGraph(
scorerSupplier,
- M,
beamWidth,
HnswGraphBuilder.randSeed,
initializerGraph,
@@ -194,16 +189,16 @@ protected final int[] getNewOrdMapping(
return oldToNewOrdinalMap;
}
- private static boolean noDeletes(Bits liveDocs) {
+ private static boolean hasDeletes(Bits liveDocs) {
if (liveDocs == null) {
- return true;
+ return false;
}
for (int i = 0; i < liveDocs.length(); i++) {
if (!liveDocs.get(i)) {
- return false;
+ return true;
}
}
- return true;
+ return false;
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java
index 179d243233df..7dff036ddde4 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/InitializedHnswGraphBuilder.java
@@ -34,7 +34,6 @@ public final class InitializedHnswGraphBuilder extends HnswGraphBuilder {
* Create a new HnswGraphBuilder that is initialized with the provided HnswGraph.
*
* @param scorerSupplier the scorer to use for vectors
- * @param M the number of connections to keep per node
* @param beamWidth the number of nodes to explore in the search
* @param seed the seed for the random number generator
* @param initializerGraph the graph to initialize the new graph builder
@@ -47,7 +46,6 @@ public final class InitializedHnswGraphBuilder extends HnswGraphBuilder {
*/
public static InitializedHnswGraphBuilder fromGraph(
RandomVectorScorerSupplier scorerSupplier,
- int M,
int beamWidth,
long seed,
HnswGraph initializerGraph,
@@ -57,17 +55,15 @@ public static InitializedHnswGraphBuilder fromGraph(
throws IOException {
return new InitializedHnswGraphBuilder(
scorerSupplier,
- M,
beamWidth,
seed,
- initGraph(M, initializerGraph, newOrdMap, totalNumberOfVectors),
+ initGraph(initializerGraph, newOrdMap, totalNumberOfVectors),
initializedNodes);
}
public static OnHeapHnswGraph initGraph(
- int M, HnswGraph initializerGraph, int[] newOrdMap, int totalNumberOfVectors)
- throws IOException {
- OnHeapHnswGraph hnsw = new OnHeapHnswGraph(M, totalNumberOfVectors);
+ HnswGraph initializerGraph, int[] newOrdMap, int totalNumberOfVectors) throws IOException {
+ OnHeapHnswGraph hnsw = new OnHeapHnswGraph(initializerGraph.maxConn(), totalNumberOfVectors);
for (int level = initializerGraph.numLevels() - 1; level >= 0; level--) {
HnswGraph.NodesIterator it = initializerGraph.getNodesOnLevel(level);
while (it.hasNext()) {
@@ -93,13 +89,12 @@ public static OnHeapHnswGraph initGraph(
public InitializedHnswGraphBuilder(
RandomVectorScorerSupplier scorerSupplier,
- int M,
int beamWidth,
long seed,
OnHeapHnswGraph initializedGraph,
BitSet initializedNodes)
throws IOException {
- super(scorerSupplier, M, beamWidth, seed, initializedGraph);
+ super(scorerSupplier, beamWidth, seed, initializedGraph);
this.initializedNodes = initializedNodes;
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index a79bcd17d91f..61f301527879 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -92,7 +92,14 @@ public final class OnHeapHnswGraph extends HnswGraph implements Accountable {
public NeighborArray getNeighbors(int level, int node) {
assert node < graph.length;
assert level < graph[node].length
- : "level=" + level + ", node has only " + graph[node].length + " levels";
+ : "level="
+ + level
+ + ", node "
+ + node
+ + " has only "
+ + graph[node].length
+ + " levels for graph "
+ + this;
assert graph[node][level] != null : "node=" + node + ", level=" + level;
return graph[node][level];
}
@@ -181,6 +188,11 @@ public int numLevels() {
return entryNode.get().level + 1;
}
+ @Override
+ public int maxConn() {
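+ // nsize (the upper-level neighbor array size) was set to max-conn + 1 at construction, so recover M by subtracting one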
+ return nsize - 1;
+ }
+
/**
* Returns the graph's current entry node on the top level shown as ordinals of the nodes on 0th
* level
diff --git a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java
index 2424b53645bd..74594be5ec99 100644
--- a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java
+++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java
@@ -337,8 +337,6 @@ public void prefetch(long offset, long length) throws IOException {
ensureOpen();
- Objects.checkFromIndexSize(offset, length, length());
-
if (BitUtil.isZeroOrPowerOfTwo(consecutivePrefetchHitCount++) == false) {
// We've had enough consecutive hits on the page cache that this number is neither zero nor a
// power of two. There is a good chance that a good chunk of this index input is cached in
@@ -381,8 +379,6 @@ void advise(long offset, long length, IOConsumer<MemorySegment> advice) throws I
ensureOpen();
- Objects.checkFromIndexSize(offset, length, length());
-
final NativeAccess nativeAccess = NATIVE_ACCESS.get();
try {
@@ -818,6 +814,12 @@ public MemorySegment segmentSliceOrNull(long pos, long len) throws IOException {
throw handlePositionalIOOBE(e, "segmentSliceOrNull", pos);
}
}
+
+ @Override
+ public void prefetch(long offset, long length) throws IOException {
+ Objects.checkFromIndexSize(offset, length, this.length);
+ super.prefetch(offset, length);
+ }
}
/** This class adds offset support to MemorySegmentIndexInput, which is needed for slices. */
@@ -903,5 +905,11 @@ public MemorySegment segmentSliceOrNull(long pos, long len) throws IOException {
MemorySegmentIndexInput buildSlice(String sliceDescription, long ofs, long length) {
return super.buildSlice(sliceDescription, this.offset + ofs, length);
}
+
+ @Override
+ public void prefetch(long offset, long length) throws IOException {
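+ // Check the bounds against this slice's own length, then delegate using the absolute offset in the backing input.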
+ Objects.checkFromIndexSize(offset, length, this.length);
+ super.prefetch(this.offset + offset, length);
+ }
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestConcurrentMergeScheduler.java b/lucene/core/src/test/org/apache/lucene/index/TestConcurrentMergeScheduler.java
index e0b2c49d8548..fcf42177570e 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestConcurrentMergeScheduler.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestConcurrentMergeScheduler.java
@@ -418,7 +418,6 @@ protected void doMerge(MergeSource mergeSource, MergePolicy.OneMerge merge)
dir.close();
}
- @SuppressForbidden(reason = "Thread sleep")
public void testIntraMergeThreadPoolIsLimitedByMaxThreads() throws IOException {
ConcurrentMergeScheduler mergeScheduler = new ConcurrentMergeScheduler();
MergeScheduler.MergeSource mergeSource =
@@ -475,11 +474,12 @@ public void merge(MergePolicy.OneMerge merge) throws IOException {
Executor executor = mergeScheduler.intraMergeExecutor;
AtomicInteger threadsExecutedOnPool = new AtomicInteger();
AtomicInteger threadsExecutedOnSelf = new AtomicInteger();
- for (int i = 0; i < 4; i++) {
+ CountDownLatch latch = new CountDownLatch(1);
+ final int totalThreads = 4;
+ for (int i = 0; i < totalThreads; i++) {
mergeScheduler.mergeThreads.add(
mergeScheduler.new MergeThread(mergeSource, merge) {
@Override
- @SuppressForbidden(reason = "Thread sleep")
public void run() {
executor.execute(
() -> {
@@ -489,7 +489,7 @@ public void run() {
threadsExecutedOnPool.incrementAndGet();
}
try {
- Thread.sleep(100);
+ latch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
@@ -500,6 +500,10 @@ public void run() {
for (ConcurrentMergeScheduler.MergeThread thread : mergeScheduler.mergeThreads) {
thread.start();
}
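+ // Busy-wait until every merge thread has recorded where it ran, then release them all at once.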
+ while (threadsExecutedOnSelf.get() + threadsExecutedOnPool.get() < totalThreads) {
+ Thread.yield();
+ }
+ latch.countDown();
mergeScheduler.sync();
assertEquals(3, threadsExecutedOnSelf.get());
assertEquals(1, threadsExecutedOnPool.get());
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java
index 9663d6762554..285296d55c19 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java
@@ -25,7 +25,9 @@
import java.util.Collections;
import java.util.List;
import java.util.Locale;
+import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.TermVectorsReader;
+import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
@@ -53,6 +55,7 @@
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.hnsw.HnswGraph;
public class TestSortingCodecReader extends LuceneTestCase {
@@ -151,12 +154,16 @@ public void testSortOnAddIndicesRandom() throws IOException {
docIds.add(i);
}
Collections.shuffle(docIds, random());
+ // If true, index a vector and points for every doc
+ boolean dense = random().nextBoolean();
try (RandomIndexWriter iw = new RandomIndexWriter(random(), dir)) {
for (int i = 0; i < numDocs; i++) {
int docId = docIds.get(i);
Document doc = new Document();
doc.add(new StringField("string_id", Integer.toString(docId), Field.Store.YES));
- doc.add(new LongPoint("point_id", docId));
+ if (dense || docId % 3 == 0) {
+ doc.add(new LongPoint("point_id", docId));
+ }
String s = RandomStrings.randomRealisticUnicodeOfLength(random(), 25);
doc.add(new TextField("text_field", s, Field.Store.YES));
doc.add(new BinaryDocValuesField("text_field", new BytesRef(s)));
@@ -168,7 +175,9 @@ public void testSortOnAddIndicesRandom() throws IOException {
doc.add(new BinaryDocValuesField("binary_dv", new BytesRef(Integer.toString(docId))));
doc.add(
new SortedSetDocValuesField("sorted_set_dv", new BytesRef(Integer.toString(docId))));
- doc.add(new KnnFloatVectorField("vector", new float[] {(float) docId}));
+ if (dense || docId % 2 == 0) {
+ doc.add(new KnnFloatVectorField("vector", new float[] {(float) docId}));
+ }
doc.add(new NumericDocValuesField("foo", random().nextInt(20)));
FieldType ft = new FieldType(StringField.TYPE_NOT_STORED);
@@ -239,6 +248,13 @@ public void testSortOnAddIndicesRandom() throws IOException {
SortedSetDocValues sorted_set_dv = leaf.getSortedSetDocValues("sorted_set_dv");
SortedDocValues binary_sorted_dv = leaf.getSortedDocValues("binary_sorted_dv");
FloatVectorValues vectorValues = leaf.getFloatVectorValues("vector");
+ KnnVectorsReader vectorsReader = ((CodecReader) leaf).getVectorReader();
+ HnswGraph graph;
+ if (vectorsReader instanceof HnswGraphProvider hnswGraphProvider) {
+ graph = hnswGraphProvider.getGraph("vector");
+ } else {
+ graph = null;
+ }
NumericDocValues ids = leaf.getNumericDocValues("id");
long prevValue = -1;
boolean usingAltIds = false;
@@ -264,7 +280,14 @@ public void testSortOnAddIndicesRandom() throws IOException {
assertTrue(sorted_numeric_dv.advanceExact(idNext));
assertTrue(sorted_set_dv.advanceExact(idNext));
assertTrue(binary_sorted_dv.advanceExact(idNext));
- assertEquals(idNext, valuesIterator.advance(idNext));
+ if (dense || prevValue % 2 == 0) {
+ assertEquals(idNext, valuesIterator.advance(idNext));
+ if (graph != null) {
+ graph.seek(0, valuesIterator.index());
+ assertNotEquals(DocIdSetIterator.NO_MORE_DOCS, graph.nextNeighbor());
+ }
+ }
+
assertEquals(new BytesRef(ids.longValue() + ""), binary_dv.binaryValue());
assertEquals(
new BytesRef(ids.longValue() + ""),
@@ -276,9 +299,11 @@ public void testSortOnAddIndicesRandom() throws IOException {
assertEquals(1, sorted_numeric_dv.docValueCount());
assertEquals(ids.longValue(), sorted_numeric_dv.nextValue());
- float[] vectorValue = vectorValues.vectorValue(valuesIterator.index());
- assertEquals(1, vectorValue.length);
- assertEquals((float) ids.longValue(), vectorValue[0], 0.001f);
+ if (dense || prevValue % 2 == 0) {
+ float[] vectorValue = vectorValues.vectorValue(valuesIterator.index());
+ assertEquals(1, vectorValue.length);
+ assertEquals((float) ids.longValue(), vectorValue[0], 0.001f);
+ }
Fields termVectors = leaf.termVectors().get(idNext);
assertTrue(
@@ -291,9 +316,13 @@ public void testSortOnAddIndicesRandom() throws IOException {
leaf.storedFields().document(idNext).get("string_id"));
IndexSearcher searcher = new IndexSearcher(r);
TopDocs result =
- searcher.search(LongPoint.newExactQuery("point_id", ids.longValue()), 1);
- assertEquals(1, result.totalHits.value());
- assertEquals(idNext, result.scoreDocs[0].doc);
+ searcher.search(LongPoint.newExactQuery("point_id", ids.longValue()), 10);
+ if (dense || ids.longValue() % 3 == 0) {
+ assertEquals(1, result.totalHits.value());
+ assertEquals(idNext, result.scoreDocs[0].doc);
+ } else {
+ assertEquals(0, result.totalHits.value());
+ }
result =
searcher.search(new TermQuery(new Term("string_id", "" + ids.longValue())), 1);
diff --git a/lucene/core/src/test/org/apache/lucene/store/TestMMapDirectory.java b/lucene/core/src/test/org/apache/lucene/store/TestMMapDirectory.java
index d01d6ec50ebb..f69befca850c 100644
--- a/lucene/core/src/test/org/apache/lucene/store/TestMMapDirectory.java
+++ b/lucene/core/src/test/org/apache/lucene/store/TestMMapDirectory.java
@@ -329,4 +329,42 @@ public void testNoGroupingFunc() {
assertFalse(func.apply("segment.si").isPresent());
assertFalse(func.apply("_51a.si").isPresent());
}
+
+ public void testPrefetchWithSingleSegment() throws IOException {
+ testPrefetchWithSegments(64 * 1024);
+ }
+
+ public void testPrefetchWithMultiSegment() throws IOException {
+ testPrefetchWithSegments(16 * 1024);
+ }
+
+ static final Class<IndexOutOfBoundsException> IOOBE = IndexOutOfBoundsException.class;
+
+ // does not verify that the actual segment is prefetched, but rather exercises the code and bounds
+ void testPrefetchWithSegments(int maxChunkSize) throws IOException {
+ byte[] bytes = new byte[(maxChunkSize * 2) + 1];
+ try (Directory dir =
+ new MMapDirectory(createTempDir("testPrefetchWithSegments"), maxChunkSize)) {
+ try (IndexOutput out = dir.createOutput("test", IOContext.DEFAULT)) {
+ out.writeBytes(bytes, 0, bytes.length);
+ }
+
+ try (var in = dir.openInput("test", IOContext.READONCE)) {
+ in.prefetch(0, in.length());
+ expectThrows(IOOBE, () -> in.prefetch(1, in.length()));
+ expectThrows(IOOBE, () -> in.prefetch(in.length(), 1));
+
+ var slice1 = in.slice("slice-1", 1, in.length() - 1);
+ slice1.prefetch(0, slice1.length());
+ expectThrows(IOOBE, () -> slice1.prefetch(1, slice1.length()));
+ expectThrows(IOOBE, () -> slice1.prefetch(slice1.length(), 1));
+
+ // we sliced off all but one byte from the first complete memory segment
+ var slice2 = in.slice("slice-2", maxChunkSize - 1, in.length() - maxChunkSize + 1);
+ slice2.prefetch(0, slice2.length());
+ expectThrows(IOOBE, () -> slice2.prefetch(1, slice2.length()));
+ expectThrows(IOOBE, () -> slice2.prefetch(slice2.length(), 1));
+ }
+ }
+ }
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestFixedBitSet.java b/lucene/core/src/test/org/apache/lucene/util/TestFixedBitSet.java
index 4beb147b1e3b..39acd5ea209e 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestFixedBitSet.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestFixedBitSet.java
@@ -665,13 +665,49 @@ public void testOrRange() {
}
FixedBitSet.orRange(source, sourceFrom, dest, destFrom, length);
for (int i = 0; i < dest.length(); ++i) {
- if (i % 10 == 0) {
- assertTrue(dest.get(i));
- } else if (i >= destFrom && i < destFrom + length) {
- int sourceI = sourceFrom + (i - destFrom);
- assertEquals("" + i, source.get(sourceI) || i % 10 == 0, dest.get(i));
+ boolean destSet = i % 10 == 0;
+ if (i < destFrom || i >= destFrom + length) {
+ // Outside of the range, unmodified
+ assertEquals("" + i, destSet, dest.get(i));
} else {
- assertFalse(dest.get(i));
+ boolean sourceSet = source.get(sourceFrom + (i - destFrom));
+ assertEquals(sourceSet || destSet, dest.get(i));
+ }
+ }
+ }
+ }
+ }
+ }
+
+ public void testAndRange() {
+ FixedBitSet dest = new FixedBitSet(1_000);
+ FixedBitSet source = new FixedBitSet(10_000);
+ for (int i = 0; i < source.length(); i += 3) {
+ source.set(i);
+ }
+
+ // Test all possible alignments, and both a "short" (less than 64) and a long length.
+ for (int sourceFrom = 64; sourceFrom < 128; ++sourceFrom) {
+ for (int destFrom = 256; destFrom < 320; ++destFrom) {
+ for (int length :
+ new int[] {
+ 0,
+ TestUtil.nextInt(random(), 1, Long.SIZE - 1),
+ TestUtil.nextInt(random(), Long.SIZE, 512)
+ }) {
+ dest.clear();
+ for (int i = 0; i < dest.length(); i += 2) {
+ dest.set(i);
+ }
+ FixedBitSet.andRange(source, sourceFrom, dest, destFrom, length);
+ for (int i = 0; i < dest.length(); ++i) {
+ boolean destSet = i % 2 == 0;
+ if (i < destFrom || i >= destFrom + length) {
+ // Outside of the range, unmodified
+ assertEquals("" + i, destSet, dest.get(i));
+ } else {
+ boolean sourceSet = source.get(sourceFrom + (i - destFrom));
+ assertEquals("" + i, sourceSet && destSet, dest.get(i));
}
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
index 1da8c8169a98..c6da6d30fafc 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java
@@ -550,12 +550,11 @@ public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws
// another graph to do the assertion
OnHeapHnswGraph graphAfterInit =
InitializedHnswGraphBuilder.initGraph(
- 10, initializerGraph, initializerOrdMap, initializerGraph.size());
+ initializerGraph, initializerOrdMap, initializerGraph.size());
HnswGraphBuilder finalBuilder =
InitializedHnswGraphBuilder.fromGraph(
finalscorerSupplier,
- 10,
30,
seed,
initializerGraph,
@@ -593,7 +592,6 @@ public void testHnswGraphBuilderInitializationFromGraph_withNonZeroOffset() thro
HnswGraphBuilder finalBuilder =
InitializedHnswGraphBuilder.fromGraph(
finalscorerSupplier,
- 10,
30,
seed,
initializerGraph,
@@ -987,7 +985,7 @@ public void testConcurrentMergeBuilder() throws IOException {
HnswGraphBuilder.randSeed = random().nextLong();
HnswConcurrentMergeBuilder builder =
new HnswConcurrentMergeBuilder(
- taskExecutor, 4, scorerSupplier, 10, 30, new OnHeapHnswGraph(10, size), null);
+ taskExecutor, 4, scorerSupplier, 30, new OnHeapHnswGraph(10, size), null);
builder.setBatchSize(100);
builder.build(size);
exec.shutdownNow();
@@ -1337,6 +1335,11 @@ public int entryNode() throws IOException {
return delegate.entryNode();
}
+ @Override
+ public int maxConn() throws IOException {
+ return delegate.maxConn();
+ }
+
@Override
public NodesIterator getNodesOnLevel(int level) throws IOException {
return delegate.getNodesOnLevel(level);
diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java
index 316afff5ee25..7ea05150e14a 100644
--- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java
+++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java
@@ -250,6 +250,11 @@ public int entryNode() {
return 0;
}
+ @Override
+ public int maxConn() {
+ return 0;
+ }
+
@Override
public String toString() {
StringBuilder buf = new StringBuilder();
diff --git a/lucene/core/src/test/org/apache/lucene/util/packed/TestPackedInts.java b/lucene/core/src/test/org/apache/lucene/util/packed/TestPackedInts.java
index e87f708c8d22..b114070ba9c0 100644
--- a/lucene/core/src/test/org/apache/lucene/util/packed/TestPackedInts.java
+++ b/lucene/core/src/test/org/apache/lucene/util/packed/TestPackedInts.java
@@ -986,7 +986,7 @@ public void testPackedLongValues() {
new long[RandomNumbers.randomIntBetween(random(), 1, TEST_NIGHTLY ? 1000000 : 10000)];
float[] ratioOptions = new float[] {PackedInts.DEFAULT, PackedInts.COMPACT, PackedInts.FAST};
for (int bpv : new int[] {0, 1, 63, 64, RandomNumbers.randomIntBetween(random(), 2, 62)}) {
- for (DataType dataType : Arrays.asList(DataType.DELTA_PACKED)) {
+ for (DataType dataType : DataType.values()) {
final int pageSize = 1 << TestUtil.nextInt(random(), 6, 20);
float acceptableOverheadRatio =
ratioOptions[TestUtil.nextInt(random(), 0, ratioOptions.length - 1)];
diff --git a/lucene/misc/src/java/org/apache/lucene/misc/index/AbstractBPReorderer.java b/lucene/misc/src/java/org/apache/lucene/misc/index/AbstractBPReorderer.java
new file mode 100644
index 000000000000..3f7442a25263
--- /dev/null
+++ b/lucene/misc/src/java/org/apache/lucene/misc/index/AbstractBPReorderer.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.misc.index;
+
+/** Base class for docid-reorderers implemented using binary partitioning (BP). */
+public abstract class AbstractBPReorderer implements IndexReorderer {
+ /**
+ * Minimum size of partitions. The algorithm will stop recursing when reaching partitions below
+ * this number of documents: 32.
+ */
+ public static final int DEFAULT_MIN_PARTITION_SIZE = 32;
+
+ /**
+ * Default maximum number of iterations per recursion level: 20. Higher numbers of iterations
+ * typically don't help significantly.
+ */
+ public static final int DEFAULT_MAX_ITERS = 20;
+
+ protected int minPartitionSize = DEFAULT_MIN_PARTITION_SIZE;
+ protected int maxIters = DEFAULT_MAX_ITERS;
+ protected double ramBudgetMB;
+
+ public AbstractBPReorderer() {
+ // 10% of the available heap size by default
+ setRAMBudgetMB(Runtime.getRuntime().totalMemory() / 1024d / 1024d / 10d);
+ }
+
+ /** Set the minimum partition size, below which the algorithm stops recursing; 32 by default. */
+ public void setMinPartitionSize(int minPartitionSize) {
+ if (minPartitionSize < 1) {
+ throw new IllegalArgumentException(
+ "minPartitionSize must be at least 1, got " + minPartitionSize);
+ }
+ this.minPartitionSize = minPartitionSize;
+ }
+
+ /**
+ * Set the maximum number of iterations on each recursion level, 20 by default. Experiments
+ * suggest that values above 20 do not help much. However, values below 20 can be used to trade
+ * effectiveness for faster reordering.
+ */
+ public void setMaxIters(int maxIters) {
+ if (maxIters < 1) {
+ throw new IllegalArgumentException("maxIters must be at least 1, got " + maxIters);
+ }
+ this.maxIters = maxIters;
+ }
+
+ /**
+ * Set the amount of RAM that graph partitioning is allowed to use. More RAM allows running
+ * faster. If not enough RAM is provided, a {@link NotEnoughRAMException} will be thrown. This is
+ * 10% of the total heap size by default.
+ */
+ public void setRAMBudgetMB(double ramBudgetMB) {
+ this.ramBudgetMB = ramBudgetMB;
+ }
+
+ /** Exception that is thrown when not enough RAM is available. */
+ public static class NotEnoughRAMException extends RuntimeException {
+ NotEnoughRAMException(String message) {
+ super(message);
+ }
+ }
+}
diff --git a/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java b/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java
index f51321fe8424..6507ff025006 100644
--- a/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java
+++ b/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java
@@ -90,14 +90,7 @@
*
* Note: This is a slow operation that consumes O(maxDoc + numTerms * numThreads) memory.
*/
-public final class BPIndexReorderer {
-
- /** Exception that is thrown when not enough RAM is available. */
- public static class NotEnoughRAMException extends RuntimeException {
- private NotEnoughRAMException(String message) {
- super(message);
- }
- }
+public final class BPIndexReorderer extends AbstractBPReorderer {
/** Block size for terms in the forward index */
private static final int TERM_IDS_BLOCK_SIZE = 17;
@@ -108,33 +101,14 @@ private NotEnoughRAMException(String message) {
/** Minimum required document frequency for terms to be considered: 4,096. */
public static final int DEFAULT_MIN_DOC_FREQ = 4096;
- /**
- * Minimum size of partitions. The algorithm will stop recursing when reaching partitions below
- * this number of documents: 32.
- */
- public static final int DEFAULT_MIN_PARTITION_SIZE = 32;
-
- /**
- * Default maximum number of iterations per recursion level: 20. Higher numbers of iterations
- * typically don't help significantly.
- */
- public static final int DEFAULT_MAX_ITERS = 20;
-
private int minDocFreq;
private float maxDocFreq;
- private int minPartitionSize;
- private int maxIters;
- private double ramBudgetMB;
private Set<String> fields;
/** Constructor. */
public BPIndexReorderer() {
setMinDocFreq(DEFAULT_MIN_DOC_FREQ);
setMaxDocFreq(1f);
- setMinPartitionSize(DEFAULT_MIN_PARTITION_SIZE);
- setMaxIters(DEFAULT_MAX_ITERS);
- // 10% of the available heap size by default
- setRAMBudgetMB(Runtime.getRuntime().totalMemory() / 1024d / 1024d / 10d);
setFields(null);
}
@@ -159,36 +133,6 @@ public void setMaxDocFreq(float maxDocFreq) {
this.maxDocFreq = maxDocFreq;
}
- /** Set the minimum partition size, when the algorithm stops recursing, 32 by default. */
- public void setMinPartitionSize(int minPartitionSize) {
- if (minPartitionSize < 1) {
- throw new IllegalArgumentException(
- "minPartitionSize must be at least 1, got " + minPartitionSize);
- }
- this.minPartitionSize = minPartitionSize;
- }
-
- /**
- * Set the maximum number of iterations on each recursion level, 20 by default. Experiments
- * suggests that values above 20 do not help much. However, values below 20 can be used to trade
- * effectiveness for faster reordering.
- */
- public void setMaxIters(int maxIters) {
- if (maxIters < 1) {
- throw new IllegalArgumentException("maxIters must be at least 1, got " + maxIters);
- }
- this.maxIters = maxIters;
- }
-
- /**
- * Set the amount of RAM that graph partitioning is allowed to use. More RAM allows running
- * faster. If not enough RAM is provided, a {@link NotEnoughRAMException} will be thrown. This is
- * 10% of the total heap size by default.
- */
- public void setRAMBudgetMB(double ramBudgetMB) {
- this.ramBudgetMB = ramBudgetMB;
- }
-
/**
* Sets the fields to use to perform partitioning. A {@code null} value indicates that all indexed
* fields should be used.
@@ -830,6 +774,7 @@ public void onFinish() throws IOException {
* enable integration into {@link BPReorderingMergePolicy}, {@link #reorder(CodecReader,
* Directory, Executor)} should be preferred in general.
*/
+ @Override
public Sorter.DocMap computeDocMap(CodecReader reader, Directory tempDir, Executor executor)
throws IOException {
if (docRAMRequirements(reader.maxDoc()) >= ramBudgetMB * 1024 * 1024) {
diff --git a/lucene/misc/src/java/org/apache/lucene/misc/index/BPReorderingMergePolicy.java b/lucene/misc/src/java/org/apache/lucene/misc/index/BPReorderingMergePolicy.java
index 5cd363192fac..70fcea733eb3 100644
--- a/lucene/misc/src/java/org/apache/lucene/misc/index/BPReorderingMergePolicy.java
+++ b/lucene/misc/src/java/org/apache/lucene/misc/index/BPReorderingMergePolicy.java
@@ -27,7 +27,7 @@
import org.apache.lucene.index.SegmentCommitInfo;
import org.apache.lucene.index.SegmentInfos;
import org.apache.lucene.index.Sorter;
-import org.apache.lucene.misc.index.BPIndexReorderer.NotEnoughRAMException;
+import org.apache.lucene.misc.index.AbstractBPReorderer.NotEnoughRAMException;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.SetOnce;
@@ -42,7 +42,7 @@ public final class BPReorderingMergePolicy extends FilterMergePolicy {
/** Whether a segment has been reordered. */
static final String REORDERED = "bp.reordered";
- private final BPIndexReorderer reorderer;
+ private final IndexReorderer reorderer;
private int minNaturalMergeNumDocs = 1;
private float minNaturalMergeRatioFromBiggestSegment = 0f;
@@ -59,7 +59,7 @@ public final class BPReorderingMergePolicy extends FilterMergePolicy {
* @param in the merge policy to use to compute merges
* @param reorderer the {@link BPIndexReorderer} to use to renumber doc IDs
*/
- public BPReorderingMergePolicy(MergePolicy in, BPIndexReorderer reorderer) {
+ public BPReorderingMergePolicy(MergePolicy in, IndexReorderer reorderer) {
super(in);
this.reorderer = reorderer;
}
diff --git a/lucene/misc/src/java/org/apache/lucene/misc/index/BpVectorReorderer.java b/lucene/misc/src/java/org/apache/lucene/misc/index/BpVectorReorderer.java
new file mode 100644
index 000000000000..246109ede04c
--- /dev/null
+++ b/lucene/misc/src/java/org/apache/lucene/misc/index/BpVectorReorderer.java
@@ -0,0 +1,790 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.misc.index;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.ForkJoinWorkerThread;
+import java.util.concurrent.RecursiveAction;
+import org.apache.lucene.index.CodecReader;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.SortingCodecReader;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.TaskExecutor;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.FSDirectory;
+import org.apache.lucene.util.CloseableThreadLocal;
+import org.apache.lucene.util.IntroSelector;
+import org.apache.lucene.util.IntsRef;
+import org.apache.lucene.util.VectorUtil;
+
+/**
+ * Implementation of "recursive graph bisection", also called "bipartite graph partitioning" and
+ * often abbreviated BP, an approach to doc ID assignment that aims at reducing the sum of the log
+ * gap between consecutive neighbor node ids. See {@link BPIndexReorderer}.
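+ *
+ * <p>A rough usage sketch (the field name "vector" is only an example): wrap an existing merge
+ * policy so that newly merged segments are reordered by vector proximity:
+ *
+ * <pre class="prettyprint">
+ * BpVectorReorderer reorderer = new BpVectorReorderer("vector");
+ * IndexWriterConfig config = new IndexWriterConfig();
+ * config.setMergePolicy(new BPReorderingMergePolicy(config.getMergePolicy(), reorderer));
+ * </pre>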
+ */
+public class BpVectorReorderer extends AbstractBPReorderer {
+
+ /*
+ * Note on using centroids (mean of vectors in a partition) to maximize scores of pairs of vectors within each partition:
+ * The function used to compare the vectors must have higher values for vectors
+ * that are more similar to each other, and must preserve inequalities over sums of vectors.
+ *
+ * This property enables us to use the centroid of a collection of vectors to represent the
+ * collection. For Euclidean and inner product score functions, the centroid is the point that
+ * minimizes the sum of distances from all the points (thus maximizing the score).
+ *
+ * sum((c0 - v)^2) = n * c0^2 - 2 * c0 * sum(v) + sum(v^2) taking derivative w.r.t. c0 and
+ * setting to 0 we get sum(v) = n * c0; i.e. c0 (the centroid) is the place that minimizes the sum
+ * of (l2) distances from the vectors (thus maximizing the euclidean score function).
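+ * (For instance, in one dimension with v = {1, 3, 8}, d/dc0 sum((c0 - v)^2) = 6 * c0 - 24 = 0
+ * gives c0 = 4, the mean of the three values.)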
+ *
+ * to maximize dot-product over unit vectors, note that: sum(dot(c0, v)) = dot(c0, sum(v))
+ * which is maximized, again, when c0 = sum(v) / n. For max inner product score, vectors may not
+ * be unit vectors. In this case there is no maximum, but since all colinear vectors of whatever
+ * scale will generate the same partition for these angular scores, we are free to choose any
+ * scale and ignore the normalization factor.
+ */
+
+ /** Minimum problem size that will result in tasks being split. */
+ private static final int FORK_THRESHOLD = 8192;
+
+ /**
+ * Limits how many incremental updates we do before initiating a full recalculation. Some wasted
+ * work is done when this is exceeded, but more is saved when it is not. Setting this to zero
+ * prevents any incremental updates from being done; instead, the centroids are fully recalculated
+ * for each iteration. We're not able to make it very big since too much numerical error
+ * accumulates beyond roughly 50 updates, resulting in suboptimal reordering. It's not
+ * clear how helpful this is though; measurements vary, so it is currently disabled (= 0).
+ */
+ private static final int MAX_CENTROID_UPDATES = 0;
+
+ private final String partitionField;
+
+ /** Constructor. */
+ public BpVectorReorderer(String partitionField) {
+ setMinPartitionSize(DEFAULT_MIN_PARTITION_SIZE);
+ setMaxIters(DEFAULT_MAX_ITERS);
+ // 10% of the available heap size by default
+ setRAMBudgetMB(Runtime.getRuntime().totalMemory() / 1024d / 1024d / 10d);
+ this.partitionField = partitionField;
+ }
+
+ private static class PerThreadState {
+
+ final FloatVectorValues vectors;
+ final float[] leftCentroid;
+ final float[] rightCentroid;
+ final float[] scratch;
+
+ PerThreadState(FloatVectorValues vectors) {
+ try {
+ this.vectors = vectors.copy();
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ leftCentroid = new float[vectors.dimension()];
+ rightCentroid = new float[leftCentroid.length];
+ scratch = new float[leftCentroid.length];
+ }
+ }
+
+ private static class DocMap extends Sorter.DocMap {
+
+ private final int[] newToOld;
+ private final int[] oldToNew;
+
+ public DocMap(int[] newToOld) {
+ this.newToOld = newToOld;
+ oldToNew = new int[newToOld.length];
+ for (int i = 0; i < newToOld.length; ++i) {
+ oldToNew[newToOld[i]] = i;
+ }
+ }
+
+ @Override
+ public int size() {
+ return newToOld.length;
+ }
+
+ @Override
+ public int oldToNew(int docID) {
+ return oldToNew[docID];
+ }
+
+ @Override
+ public int newToOld(int docID) {
+ return newToOld[docID];
+ }
+ }
+
+ private abstract class BaseRecursiveAction extends RecursiveAction {
+
+ protected final TaskExecutor executor;
+ protected final int depth;
+
+ BaseRecursiveAction(TaskExecutor executor, int depth) {
+ this.executor = executor;
+ this.depth = depth;
+ }
+
+ protected final boolean shouldFork(int problemSize, int totalProblemSize) {
+ if (executor == null) {
+ return false;
+ }
+ if (getSurplusQueuedTaskCount() > 3) {
+ // Fork tasks if this worker doesn't have more queued work than other workers
+ // See javadocs of #getSurplusQueuedTaskCount for more details
+ return false;
+ }
+ if (problemSize == totalProblemSize) {
+ // Sometimes fork regardless of the problem size to make sure that unit tests also exercise
+ // forking
+ return true;
+ }
+ return problemSize > FORK_THRESHOLD;
+ }
+ }
+
+ private class ReorderTask extends BaseRecursiveAction {
+
+ private final VectorSimilarityFunction vectorScore;
+ // the ids assigned to this task, a sub-range of all the ids
+ private final IntsRef ids;
+ // the biases for the ids - a number < 0 when the doc goes left and > 0 for right
+ private final float[] biases;
+ private final CloseableThreadLocal<PerThreadState> threadLocal;
+
+ ReorderTask(
+ IntsRef ids,
+ float[] biases,
+ CloseableThreadLocal<PerThreadState> threadLocal,
+ TaskExecutor executor,
+ int depth,
+ VectorSimilarityFunction vectorScore) {
+ super(executor, depth);
+ this.ids = ids;
+ this.biases = biases;
+ this.threadLocal = threadLocal;
+ this.vectorScore = vectorScore;
+ }
+
+ @Override
+ protected void compute() {
+ if (depth > 0) {
+ Arrays.sort(ids.ints, ids.offset, ids.offset + ids.length);
+ } else {
+ assert sorted(ids);
+ }
+
+ int halfLength = ids.length >>> 1;
+ if (halfLength < minPartitionSize) {
+ return;
+ }
+
+ // split the ids in half
+ IntsRef left = new IntsRef(ids.ints, ids.offset, halfLength);
+ IntsRef right = new IntsRef(ids.ints, ids.offset + halfLength, ids.length - halfLength);
+
+ PerThreadState state = threadLocal.get();
+ FloatVectorValues vectors = state.vectors;
+ float[] leftCentroid = state.leftCentroid;
+ float[] rightCentroid = state.rightCentroid;
+ float[] scratch = state.scratch;
+
+ try {
+ computeCentroid(left, vectors, leftCentroid, vectorScore);
+ computeCentroid(right, vectors, rightCentroid, vectorScore);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+
+ for (int iter = 0; iter < maxIters; ++iter) {
+ int moved;
+ try {
+ moved = shuffle(vectors, ids, right.offset, leftCentroid, rightCentroid, scratch, biases);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ if (moved == 0) {
+ break;
+ }
+ if (moved > MAX_CENTROID_UPDATES) {
+ // if we swapped too many times we don't use the relative calculation because it
+ // introduces too much error
+ try {
+ computeCentroid(left, vectors, leftCentroid, vectorScore);
+ computeCentroid(right, vectors, rightCentroid, vectorScore);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+ }
+
+ // It is fine for all tasks to share the same docs / biases array since they all work on
+ // different slices of the array at a given point in time.
+ ReorderTask leftTask =
+ new ReorderTask(left, biases, threadLocal, executor, depth + 1, vectorScore);
+ ReorderTask rightTask =
+ new ReorderTask(right, biases, threadLocal, executor, depth + 1, vectorScore);
+
+ if (shouldFork(ids.length, ids.ints.length)) {
+ invokeAll(leftTask, rightTask);
+ } else {
+ leftTask.compute();
+ rightTask.compute();
+ }
+ }
+
+ static void computeCentroid(
+ IntsRef ids,
+ FloatVectorValues vectors,
+ float[] centroid,
+ VectorSimilarityFunction vectorSimilarity)
+ throws IOException {
+ Arrays.fill(centroid, 0);
+ for (int i = ids.offset; i < ids.offset + ids.length; i++) {
+ VectorUtil.add(centroid, vectors.vectorValue(ids.ints[i]));
+ }
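+ // Euclidean and max-inner-product use the mean of the vectors; for the angular scores
+ // (dot product, cosine) only the direction matters, so scale the sum to unit length instead.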
+ switch (vectorSimilarity) {
+ case EUCLIDEAN, MAXIMUM_INNER_PRODUCT -> vectorScalarMul(1 / (float) ids.length, centroid);
+ case DOT_PRODUCT, COSINE ->
+ vectorScalarMul(
+ 1 / (float) Math.sqrt(VectorUtil.dotProduct(centroid, centroid)), centroid);
+ }
+ }
+
+ /** Shuffle IDs across both partitions so that each partition is closer to its centroid. */
+ private int shuffle(
+ FloatVectorValues vectors,
+ IntsRef ids,
+ int midPoint,
+ float[] leftCentroid,
+ float[] rightCentroid,
+ float[] scratch,
+ float[] biases)
+ throws IOException {
+
+ /* Computing biases requires a distance calculation for each vector (document) which can be
+ * costly, especially as the vector dimension increases, so we try to parallelize it. We also
+ * have the option of performing incremental updates based on the difference of the previous and
+ * the new centroid, which can be less costly, but introduces incremental numeric error, and
+ * needs tuning to be usable. It is disabled by default (see MAX_CENTROID_UPDATES).
+ */
+ new ComputeBiasTask(
+ ids.ints,
+ biases,
+ ids.offset,
+ ids.offset + ids.length,
+ leftCentroid,
+ rightCentroid,
+ threadLocal,
+ executor,
+ depth,
+ vectorScore)
+ .compute();
+ vectorSubtract(leftCentroid, rightCentroid, scratch);
+ float scale = (float) Math.sqrt(VectorUtil.dotProduct(scratch, scratch));
+ float maxLeftBias = Float.NEGATIVE_INFINITY;
+ for (int i = ids.offset; i < midPoint; ++i) {
+ maxLeftBias = Math.max(maxLeftBias, biases[i]);
+ }
+ float minRightBias = Float.POSITIVE_INFINITY;
+ for (int i = midPoint, end = ids.offset + ids.length; i < end; ++i) {
+ minRightBias = Math.min(minRightBias, biases[i]);
+ }
+ float gain = maxLeftBias - minRightBias;
+ /* This compares the gain of swapping the doc from the left side that is most attracted to the
+ * right and the doc from the right side that is most attracted to the left against the
+ * average vector length (/500) rather than zero. 500 is an arbitrary heuristic value
+ * determined empirically - basically we stop iterating once the centroids move less than
+ * 1/500 the sum of their lengths.
+ *
+ * TODO We could try incorporating simulated annealing by including the iteration number in the
+ * formula? eg { 1000 * gain <= scale * iter }.
+ */
+
+ // System.out.printf("at depth=%d, midPoint=%d, gain=%f\n", depth, midPoint, gain);
+ if (500 * gain <= scale) {
+ return 0;
+ }
+
+ class Selector extends IntroSelector {
+ int count = 0;
+ int pivotDoc;
+ float pivotBias;
+
+ @Override
+ public void setPivot(int i) {
+ pivotDoc = ids.ints[i];
+ pivotBias = biases[i];
+ }
+
+ @Override
+ public int comparePivot(int j) {
+ int cmp = Float.compare(pivotBias, biases[j]);
+ if (cmp == 0) {
+ // Tie break on the ID to preserve ID ordering as much as possible
+ cmp = pivotDoc - ids.ints[j];
+ }
+ return cmp;
+ }
+
+ @Override
+ public void swap(int i, int j) {
+ float tmpBias = biases[i];
+ biases[i] = biases[j];
+ biases[j] = tmpBias;
+
+ if (i < midPoint == j < midPoint) {
+ int tmpDoc = ids.ints[i];
+ ids.ints[i] = ids.ints[j];
+ ids.ints[j] = tmpDoc;
+ } else {
+ // If we're swapping across the left and right sides, we need to keep centroids
+ // up-to-date.
+ count++;
+ int from = Math.min(i, j);
+ int to = Math.max(i, j);
+ try {
+ swapIdsAndCentroids(
+ ids,
+ from,
+ to,
+ midPoint,
+ vectors,
+ leftCentroid,
+ rightCentroid,
+ scratch,
+ count,
+ vectorScore);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+ }
+ }
+
+ Selector selector = new Selector();
+ selector.select(ids.offset, ids.offset + ids.length, midPoint);
+ // System.out.printf("swapped %d / %d\n", selector.count, ids.length);
+ return selector.count;
+ }
+
+ private static boolean centroidValid(
+ float[] centroid, FloatVectorValues vectors, IntsRef ids, int count) throws IOException {
+ // recompute centroid to check the incremental calculation
+ float[] check = new float[centroid.length];
+ computeCentroid(ids, vectors, check, VectorSimilarityFunction.EUCLIDEAN);
+ for (int i = 0; i < check.length; ++i) {
+ float diff = Math.abs(check[i] - centroid[i]);
+ if (diff > 1e-4) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private static void swapIdsAndCentroids(
+ IntsRef ids,
+ int from,
+ int to,
+ int midPoint,
+ FloatVectorValues vectors,
+ float[] leftCentroid,
+ float[] rightCentroid,
+ float[] scratch,
+ int count,
+ VectorSimilarityFunction vectorScore)
+ throws IOException {
+ assert from < to;
+
+ int[] idArr = ids.ints;
+ int fromId = idArr[from];
+ int toId = idArr[to];
+
+ // Now update the centroids, this makes things much faster than invalidating it and having to
+ // recompute on the next iteration. Should be faster than full recalculation if the number of
+ // swaps is reasonable.
+
+ // We want the net effect to be moving "from" left->right and "to" right->left
+
+ // (1) scratch = to - from
+ if (count <= MAX_CENTROID_UPDATES && vectorScore == VectorSimilarityFunction.EUCLIDEAN) {
+ int relativeMidpoint = midPoint - ids.offset;
+ vectorSubtract(vectors.vectorValue(toId), vectors.vectorValue(fromId), scratch);
+ // we must normalize to the proper scale by accounting for the number of points contributing
+ // to each centroid
+ // left += scratch / size(left)
+ vectorScalarMul(1 / (float) relativeMidpoint, scratch);
+ VectorUtil.add(leftCentroid, scratch);
+ // right -= scratch / size(right)
+ vectorScalarMul(-relativeMidpoint / (float) (ids.length - relativeMidpoint), scratch);
+ VectorUtil.add(rightCentroid, scratch);
+ }
+
+ idArr[from] = toId;
+ idArr[to] = fromId;
+
+ if (count <= MAX_CENTROID_UPDATES) {
+ assert centroidValid(
+ leftCentroid, vectors, new IntsRef(idArr, ids.offset, midPoint - ids.offset), count);
+ assert centroidValid(
+ rightCentroid,
+ vectors,
+ new IntsRef(ids.ints, midPoint, ids.length - midPoint + ids.offset),
+ count);
+ }
+ }
+ }
+
+ /**
+ * Subtracts the second argument from the first, writing the difference into the third argument
+ *
+ * @param u the vector to subtract from
+ * @param v the vector to subtract
+ * @param result the destination for u - v
+ */
+ static void vectorSubtract(float[] u, float[] v, float[] result) {
+ for (int i = 0; i < u.length; i++) {
+ result[i] = u[i] - v[i];
+ }
+ }
+
+ static void vectorScalarMul(float x, float[] v) {
+ for (int i = 0; i < v.length; i++) {
+ v[i] *= x;
+ }
+ }
+
+ private class ComputeBiasTask extends BaseRecursiveAction {
+
+ private final int[] ids;
+ private final float[] biases;
+ private final int start;
+ private final int end;
+ private final float[] leftCentroid;
+ private final float[] rightCentroid;
+ private final CloseableThreadLocal<PerThreadState> threadLocal;
+ private final VectorSimilarityFunction vectorScore;
+
+ ComputeBiasTask(
+ int[] ids,
+ float[] biases,
+ int start,
+ int end,
+ float[] leftCentroid,
+ float[] rightCentroid,
+ CloseableThreadLocal<PerThreadState> threadLocal,
+ TaskExecutor executor,
+ int depth,
+ VectorSimilarityFunction vectorScore) {
+ super(executor, depth);
+ this.ids = ids;
+ this.biases = biases;
+ this.start = start;
+ this.end = end;
+ this.leftCentroid = leftCentroid;
+ this.rightCentroid = rightCentroid;
+ this.threadLocal = threadLocal;
+ this.vectorScore = vectorScore;
+ }
+
+ @Override
+ protected void compute() {
+ final int problemSize = end - start;
+ if (problemSize > 1 && shouldFork(problemSize, ids.length)) {
+ final int mid = (start + end) >>> 1;
+ invokeAll(
+ new ComputeBiasTask(
+ ids,
+ biases,
+ start,
+ mid,
+ leftCentroid,
+ rightCentroid,
+ threadLocal,
+ executor,
+ depth,
+ vectorScore),
+ new ComputeBiasTask(
+ ids,
+ biases,
+ mid,
+ end,
+ leftCentroid,
+ rightCentroid,
+ threadLocal,
+ executor,
+ depth,
+ vectorScore));
+ } else {
+ FloatVectorValues vectors = threadLocal.get().vectors;
+ try {
+ for (int i = start; i < end; ++i) {
+ biases[i] = computeBias(vectors.vectorValue(ids[i]), leftCentroid, rightCentroid);
+ }
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+ }
+
+ /**
+ * Compute a float that is negative when a vector is attracted to the left and positive
+ * otherwise.
+ */
+ private float computeBias(float[] vector, float[] leftCentroid, float[] rightCentroid) {
+ return switch (vectorScore) {
+ case EUCLIDEAN ->
+ VectorUtil.squareDistance(vector, leftCentroid)
+ - VectorUtil.squareDistance(vector, rightCentroid);
+ case MAXIMUM_INNER_PRODUCT, COSINE, DOT_PRODUCT ->
+ VectorUtil.dotProduct(vector, rightCentroid)
+ - VectorUtil.dotProduct(vector, leftCentroid);
+ default -> throw new IllegalStateException("unsupported vector score: " + vectorScore);
+ };
+ }
+ }
+
+ @Override
+ public Sorter.DocMap computeDocMap(CodecReader reader, Directory tempDir, Executor executor)
+ throws IOException {
+ TaskExecutor taskExecutor;
+ if (executor == null) {
+ taskExecutor = null;
+ } else {
+ taskExecutor = new TaskExecutor(executor);
+ }
+ VectorSimilarityFunction vectorScore = checkField(reader, partitionField);
+ if (vectorScore == null) {
+ return null;
+ }
+ FloatVectorValues floats = reader.getFloatVectorValues(partitionField);
+ Sorter.DocMap valueMap = computeValueMap(floats, vectorScore, taskExecutor);
+ return valueMapToDocMap(valueMap, floats, reader.maxDoc());
+ }
+
+ /** Expert: Compute the {@link DocMap} that holds the new vector ordinal numbering. */
+ Sorter.DocMap computeValueMap(
+ FloatVectorValues vectors, VectorSimilarityFunction vectorScore, TaskExecutor executor) {
+ if (docRAMRequirements(vectors.size()) >= ramBudgetMB * 1024 * 1024) {
+ throw new NotEnoughRAMException(
+ "At least "
+ + Math.ceil(docRAMRequirements(vectors.size()) / 1024. / 1024.)
+ + "MB of RAM are required to hold metadata about documents in RAM, but current RAM budget is "
+ + ramBudgetMB
+ + "MB");
+ }
+ return new DocMap(computePermutation(vectors, vectorScore, executor));
+ }
+
+ /**
+ * Compute a permutation of the ID space that maximizes the vector score between consecutive vectors.
+ */
+ private int[] computePermutation(
+ FloatVectorValues vectors, VectorSimilarityFunction vectorScore, TaskExecutor executor) {
+ final int size = vectors.size();
+ int[] sortedIds = new int[size];
+ for (int i = 0; i < size; ++i) {
+ sortedIds[i] = i;
+ }
+ try (CloseableThreadLocal<PerThreadState> threadLocal =
+ new CloseableThreadLocal<>() {
+ @Override
+ protected PerThreadState initialValue() {
+ return new PerThreadState(vectors);
+ }
+ }) {
+ IntsRef ids = new IntsRef(sortedIds, 0, sortedIds.length);
+ new ReorderTask(ids, new float[size], threadLocal, executor, 0, vectorScore).compute();
+ }
+ return sortedIds;
+ }
+
+ /** Returns true if, and only if, the given {@link IntsRef} is sorted. */
+ private static boolean sorted(IntsRef intsRef) {
+ for (int i = 1; i < intsRef.length; ++i) {
+ if (intsRef.ints[intsRef.offset + i - 1] > intsRef.ints[intsRef.offset + i]) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private static long docRAMRequirements(int maxDoc) {
+ // We need one int per vector for the doc map, plus one float to store the bias associated with
+ // this vector.
+ return 2L * Integer.BYTES * maxDoc;
+ }
+
+ /**
+ * @param args two args: the path of a directory containing the index to reorder, and the name of
+ * the field whose contents are used for reordering
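+ * (for example: {@code /tmp/my-index vector --max-iters 10}, where the path is illustrative)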
+ */
+ @SuppressWarnings("unused")
+ public static void main(String... args) throws IOException {
+ if (args.length < 2 || args.length > 8) {
+ usage();
+ }
+ String directory = args[0];
+ String field = args[1];
+ BpVectorReorderer reorderer = new BpVectorReorderer(field);
+ int threadCount = Runtime.getRuntime().availableProcessors();
+ try {
+ for (int i = 2; i < args.length; i++) {
+ switch (args[i]) {
+ case "--max-iters" -> reorderer.setMaxIters(Integer.parseInt(args[++i]));
+ case "--min-partition-size" -> reorderer.setMinPartitionSize(Integer.parseInt(args[++i]));
+ case "--thread-count" -> threadCount = Integer.parseInt(args[++i]);
+ default -> throw new IllegalArgumentException("unknown argument: " + args[i]);
+ }
+ }
+ } catch (NumberFormatException | ArrayIndexOutOfBoundsException e) {
+ usage();
+ }
+ Executor executor;
+ if (threadCount != 1) {
+ executor = new ForkJoinPool(threadCount, p -> new ForkJoinWorkerThread(p) {}, null, false);
+ } else {
+ executor = null;
+ }
+ try (Directory dir = FSDirectory.open(Path.of(directory))) {
+ reorderer.reorderIndexDirectory(dir, executor);
+ }
+ }
+
+ void reorderIndexDirectory(Directory directory, Executor executor) throws IOException {
+ try (IndexReader reader = DirectoryReader.open(directory)) {
+ IndexWriterConfig iwc = new IndexWriterConfig();
+ iwc.setOpenMode(IndexWriterConfig.OpenMode.CREATE);
+ try (IndexWriter writer = new IndexWriter(directory, iwc)) {
+ for (LeafReaderContext ctx : reader.leaves()) {
+ CodecReader codecReader = (CodecReader) ctx.reader();
+ writer.addIndexes(
+ SortingCodecReader.wrap(
+ codecReader, computeDocMap(codecReader, null, executor), null));
+ }
+ }
+ }
+ }
+
+ private static VectorSimilarityFunction checkField(CodecReader reader, String field)
+ throws IOException {
+ FieldInfo finfo = reader.getFieldInfos().fieldInfo(field);
+ if (finfo == null) {
+ return null;
+ /*
+ throw new IllegalStateException(
+ "field not found: " + field + " in leaf " + reader.getContext().ord);
+ */
+ }
+ if (finfo.hasVectorValues() == false) {
+ return null;
+ /*
+ throw new IllegalStateException(
+ "field not a vector field: " + field + " in leaf " + reader.getContext().ord);
+ */
+ }
+ if (finfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
+ return null;
+ /*
+ throw new IllegalStateException(
+ "vector field not encoded as float32: " + field + " in leaf " + reader.getContext().ord);
+ */
+ }
+ return finfo.getVectorSimilarityFunction();
+ }
+
+ private static void usage() {
+ throw new IllegalArgumentException(
+ """
+ usage: reorder <directory> <field>
+ [--max-iters N]
+ [--min-partition-size P]
+ [--thread-count T]""");
+ }
+
+ private static Sorter.DocMap valueMapToDocMap(
+ Sorter.DocMap valueMap, FloatVectorValues values, int maxDoc) throws IOException {
+ if (maxDoc == values.size()) {
+ return valueMap;
+ }
+ // valueMap maps old ords -> new ords
+ // values maps old docs -> old ords
+ // we want an old docs -> new docs map, with docs that have no value sorted to the end
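+ // Illustrative example: maxDoc=5 where docs 0,2,4 carry vectors (ords 0,1,2) and valueMap
+ // permutes ords as 0->2, 1->0, 2->1. Then oldToNew below becomes [2, 3, 0, 4, 1]: the docs
+ // with vectors occupy new positions 0..2 following the ord permutation, and the vector-less
+ // docs 1 and 3 are appended at positions 3 and 4 in their original order.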
+ int[] newToOld = new int[maxDoc];
+ int[] oldToNew = new int[newToOld.length];
+ int docid = 0;
+ int ord = 0;
+ int nextNullDoc = values.size();
+ DocIdSetIterator it = values.iterator();
+ for (int nextDoc = it.nextDoc(); nextDoc != NO_MORE_DOCS; nextDoc = it.nextDoc()) {
+ while (docid < nextDoc) {
+ oldToNew[docid] = nextNullDoc;
+ newToOld[nextNullDoc] = docid;
+ ++docid;
+ ++nextNullDoc;
+ }
+ // the loop above advanced docid up to the next doc that has a vector value
+ assert docid == nextDoc;
+ int newOrd = valueMap.oldToNew(ord);
+ oldToNew[docid] = newOrd;
+ newToOld[newOrd] = docid;
+ ++ord;
+ ++docid;
+ }
+ while (docid < maxDoc) {
+ oldToNew[docid] = nextNullDoc;
+ newToOld[nextNullDoc] = docid;
+ ++docid;
+ ++nextNullDoc;
+ }
+
+ return new Sorter.DocMap() {
+
+ @Override
+ public int size() {
+ return newToOld.length;
+ }
+
+ @Override
+ public int oldToNew(int docID) {
+ return oldToNew[docID];
+ }
+
+ @Override
+ public int newToOld(int docID) {
+ return newToOld[docID];
+ }
+ };
+ }
+}
diff --git a/lucene/misc/src/java/org/apache/lucene/misc/index/IndexReorderer.java b/lucene/misc/src/java/org/apache/lucene/misc/index/IndexReorderer.java
new file mode 100644
index 000000000000..1fdba7a4eb03
--- /dev/null
+++ b/lucene/misc/src/java/org/apache/lucene/misc/index/IndexReorderer.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.misc.index;
+
+import java.io.IOException;
+import java.util.concurrent.Executor;
+import org.apache.lucene.index.CodecReader;
+import org.apache.lucene.index.Sorter;
+import org.apache.lucene.store.Directory;
+
+/** Interface for docid-reordering expected by {@link BPReorderingMergePolicy}. */
+public interface IndexReorderer {
+ /**
+ * Returns a mapping from old to new docids.
+ *
+ * @param reader the reader whose docs are to be reordered
+ * @param tempDir temporary files may be stored here while reordering.
+ * @param executor may be used to parallelize reordering work.
+ */
+ Sorter.DocMap computeDocMap(CodecReader reader, Directory tempDir, Executor executor)
+ throws IOException;
+}
diff --git a/lucene/misc/src/java/org/apache/lucene/misc/store/DirectIODirectory.java b/lucene/misc/src/java/org/apache/lucene/misc/store/DirectIODirectory.java
index ff7ea2341acd..8b5f4fd76b77 100644
--- a/lucene/misc/src/java/org/apache/lucene/misc/store/DirectIODirectory.java
+++ b/lucene/misc/src/java/org/apache/lucene/misc/store/DirectIODirectory.java
@@ -16,6 +16,8 @@
*/
package org.apache.lucene.misc.store;
+import static java.nio.ByteOrder.LITTLE_ENDIAN;
+
import java.io.EOFException;
import java.io.IOException;
import java.io.UncheckedIOException;
@@ -314,7 +316,7 @@ public DirectIOIndexInput(Path path, int blockSize, int bufferSize) throws IOExc
this.blockSize = blockSize;
this.channel = FileChannel.open(path, StandardOpenOption.READ, getDirectOpenOption());
- this.buffer = ByteBuffer.allocateDirect(bufferSize + blockSize - 1).alignedSlice(blockSize);
+ this.buffer = allocateBuffer(bufferSize, blockSize);
isOpen = true;
isClone = false;
@@ -329,7 +331,7 @@ private DirectIOIndexInput(DirectIOIndexInput other) throws IOException {
this.blockSize = other.blockSize;
final int bufferSize = other.buffer.capacity();
- this.buffer = ByteBuffer.allocateDirect(bufferSize + blockSize - 1).alignedSlice(blockSize);
+ this.buffer = allocateBuffer(bufferSize, blockSize);
isOpen = true;
isClone = true;
@@ -338,6 +340,12 @@ private DirectIOIndexInput(DirectIOIndexInput other) throws IOException {
seek(other.getFilePointer());
}
+ private static ByteBuffer allocateBuffer(int bufferSize, int blockSize) {
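+ // Over-allocate so that a block-aligned slice of the requested size exists (direct I/O
+ // requires block alignment), and use little-endian order to match Lucene's on-disk encoding
+ // so the typed read methods below can pull values straight out of the buffer.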
+ return ByteBuffer.allocateDirect(bufferSize + blockSize - 1)
+ .alignedSlice(blockSize)
+ .order(LITTLE_ENDIAN);
+ }
+
@Override
public void close() throws IOException {
if (isOpen && !isClone) {
@@ -389,6 +397,33 @@ public byte readByte() throws IOException {
return buffer.get();
}
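+ // The typed reads below serve values directly from the little-endian buffer when enough bytes
+ // remain, and fall back to the byte-at-a-time defaults when a value straddles the buffer end.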
+ @Override
+ public short readShort() throws IOException {
+ if (buffer.remaining() >= Short.BYTES) {
+ return buffer.getShort();
+ } else {
+ return super.readShort();
+ }
+ }
+
+ @Override
+ public int readInt() throws IOException {
+ if (buffer.remaining() >= Integer.BYTES) {
+ return buffer.getInt();
+ } else {
+ return super.readInt();
+ }
+ }
+
+ @Override
+ public long readLong() throws IOException {
+ if (buffer.remaining() >= Long.BYTES) {
+ return buffer.getLong();
+ } else {
+ return super.readLong();
+ }
+ }
+
private void refill(int bytesToRead) throws IOException {
filePos += buffer.capacity();
@@ -428,6 +463,63 @@ public void readBytes(byte[] dst, int offset, int len) throws IOException {
}
}
+ @Override
+ public void readInts(int[] dst, int offset, int len) throws IOException {
+ int remainingDst = len;
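+ // Bulk-copy as many whole ints as remain in the buffer; if a value straddles the end of the
+ // buffer, read it via readInt() (which falls back to byte-at-a-time reads across a refill),
+ // otherwise refill the buffer and continue.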
+ while (remainingDst > 0) {
+ int cnt = Math.min(buffer.remaining() / Integer.BYTES, remainingDst);
+ buffer.asIntBuffer().get(dst, offset + len - remainingDst, cnt);
+ buffer.position(buffer.position() + Integer.BYTES * cnt);
+ remainingDst -= cnt;
+ if (remainingDst > 0) {
+ if (buffer.hasRemaining()) {
+ dst[offset + len - remainingDst] = readInt();
+ --remainingDst;
+ } else {
+ refill(remainingDst * Integer.BYTES);
+ }
+ }
+ }
+ }
+
+ @Override
+ public void readFloats(float[] dst, int offset, int len) throws IOException {
+ int remainingDst = len;
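+ // Same pattern as readInts: bulk-copy whole floats, reading any value that straddles the
+ // buffer boundary as raw int bits via readInt().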
+ while (remainingDst > 0) {
+ int cnt = Math.min(buffer.remaining() / Float.BYTES, remainingDst);
+ buffer.asFloatBuffer().get(dst, offset + len - remainingDst, cnt);
+ buffer.position(buffer.position() + Float.BYTES * cnt);
+ remainingDst -= cnt;
+ if (remainingDst > 0) {
+ if (buffer.hasRemaining()) {
+ dst[offset + len - remainingDst] = Float.intBitsToFloat(readInt());
+ --remainingDst;
+ } else {
+ refill(remainingDst * Float.BYTES);
+ }
+ }
+ }
+ }
+
+ @Override
+ public void readLongs(long[] dst, int offset, int len) throws IOException {
+ int remainingDst = len;
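+ // Same pattern as readInts, with 8-byte longs.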
+ while (remainingDst > 0) {
+ int cnt = Math.min(buffer.remaining() / Long.BYTES, remainingDst);
+ buffer.asLongBuffer().get(dst, offset + len - remainingDst, cnt);
+ buffer.position(buffer.position() + Long.BYTES * cnt);
+ remainingDst -= cnt;
+ if (remainingDst > 0) {
+ if (buffer.hasRemaining()) {
+ dst[offset + len - remainingDst] = readLong();
+ --remainingDst;
+ } else {
+ refill(remainingDst * Long.BYTES);
+ }
+ }
+ }
+ }
+
@Override
public DirectIOIndexInput clone() {
try {
diff --git a/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPReorderingMergePolicy.java b/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPReorderingMergePolicy.java
index b4c68b30c85a..9d5b193632d2 100644
--- a/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPReorderingMergePolicy.java
+++ b/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPReorderingMergePolicy.java
@@ -20,6 +20,7 @@
import java.io.UncheckedIOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
+import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
@@ -35,17 +36,32 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.IOUtils;
+import org.junit.Before;
public class TestBPReorderingMergePolicy extends LuceneTestCase {
+ AbstractBPReorderer reorderer;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
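+ // randomly exercise both reorderer implementations: the postings-based BPIndexReorderer and
+ // the vector-based BpVectorReorderer (over the "vector" field these tests add to every doc)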
+ if (random().nextBoolean()) {
+ BPIndexReorderer bpIndexReorderer = new BPIndexReorderer();
+ bpIndexReorderer.setMinDocFreq(2);
+ reorderer = bpIndexReorderer;
+ } else {
+ BpVectorReorderer bpVectorReorderer = new BpVectorReorderer("vector");
+ reorderer = bpVectorReorderer;
+ }
+ reorderer.setMinPartitionSize(2);
+ }
+
public void testReorderOnMerge() throws IOException {
Directory dir1 = newDirectory();
Directory dir2 = newDirectory();
IndexWriter w1 =
new IndexWriter(dir1, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()));
- BPIndexReorderer reorderer = new BPIndexReorderer();
- reorderer.setMinDocFreq(2);
- reorderer.setMinPartitionSize(2);
BPReorderingMergePolicy mp = new BPReorderingMergePolicy(newLogMergePolicy(), reorderer);
mp.setMinNaturalMergeNumDocs(2);
IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig().setMergePolicy(mp));
@@ -54,10 +70,14 @@ public void testReorderOnMerge() throws IOException {
doc.add(idField);
StringField bodyField = new StringField("body", "", Store.YES);
doc.add(bodyField);
+ KnnFloatVectorField vectorField = new KnnFloatVectorField("vector", new float[] {0});
+ doc.add(vectorField);
for (int i = 0; i < 10000; ++i) {
idField.setStringValue(Integer.toString(i));
- bodyField.setStringValue(Integer.toString(i % 2 == 0 ? 0 : i % 10));
+ int intValue = i % 2 == 0 ? 0 : i % 10;
+ bodyField.setStringValue(Integer.toString(intValue));
+ vectorField.setVectorValue(new float[] {intValue});
w1.addDocument(doc);
w2.addDocument(doc);
@@ -131,10 +151,15 @@ public void testReorderOnAddIndexes() throws IOException {
doc.add(idField);
StringField bodyField = new StringField("body", "", Store.YES);
doc.add(bodyField);
+ KnnFloatVectorField vectorField = new KnnFloatVectorField("vector", new float[] {0});
+ doc.add(vectorField);
for (int i = 0; i < 10000; ++i) {
idField.setStringValue(Integer.toString(i));
- bodyField.setStringValue(Integer.toString(i % 2 == 0 ? 0 : i % 10));
+ int intValue = i % 2 == 0 ? 0 : i % 10;
+ bodyField.setStringValue(Integer.toString(intValue));
+ vectorField.setVectorValue(new float[] {intValue});
+
w1.addDocument(doc);
if (i % 3 == 0) {
@@ -147,9 +172,6 @@ public void testReorderOnAddIndexes() throws IOException {
}
Directory dir2 = newDirectory();
- BPIndexReorderer reorderer = new BPIndexReorderer();
- reorderer.setMinDocFreq(2);
- reorderer.setMinPartitionSize(2);
BPReorderingMergePolicy mp = new BPReorderingMergePolicy(newLogMergePolicy(), reorderer);
mp.setMinNaturalMergeNumDocs(2);
IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig().setMergePolicy(mp));
@@ -222,9 +244,6 @@ public void testReorderOnAddIndexes() throws IOException {
public void testReorderDoesntHaveEnoughRAM() throws IOException {
// This just makes sure that reordering the index on merge does not corrupt its content
Directory dir = newDirectory();
- BPIndexReorderer reorderer = new BPIndexReorderer();
- reorderer.setMinDocFreq(2);
- reorderer.setMinPartitionSize(2);
reorderer.setRAMBudgetMB(Double.MIN_VALUE);
BPReorderingMergePolicy mp = new BPReorderingMergePolicy(newLogMergePolicy(), reorderer);
mp.setMinNaturalMergeNumDocs(2);
@@ -234,10 +253,14 @@ public void testReorderDoesntHaveEnoughRAM() throws IOException {
doc.add(idField);
StringField bodyField = new StringField("body", "", Store.YES);
doc.add(bodyField);
+ KnnFloatVectorField vectorField = new KnnFloatVectorField("vector", new float[] {0});
+ doc.add(vectorField);
for (int i = 0; i < 10; ++i) {
idField.setStringValue(Integer.toString(i));
- bodyField.setStringValue(Integer.toString(i % 2 == 0 ? 0 : i % 10));
+ int intValue = i % 2 == 0 ? 0 : i % 10;
+ bodyField.setStringValue(Integer.toString(intValue));
+ vectorField.setVectorValue(new float[] {intValue});
w.addDocument(doc);
DirectoryReader.open(w).close();
}
diff --git a/lucene/misc/src/test/org/apache/lucene/misc/index/TestBpVectorReorderer.java b/lucene/misc/src/test/org/apache/lucene/misc/index/TestBpVectorReorderer.java
new file mode 100644
index 000000000000..e4398a76183a
--- /dev/null
+++ b/lucene/misc/src/test/org/apache/lucene/misc/index/TestBpVectorReorderer.java
@@ -0,0 +1,433 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.misc.index;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.ForkJoinWorkerThread;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.lucene101.Lucene101Codec;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.StoredField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.StoredFields;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.TaskExecutor;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.VectorUtil;
+
+/** Tests reordering vector values using Binary Partitioning */
+public class TestBpVectorReorderer extends LuceneTestCase {
+
+ public static final String FIELD_NAME = "knn";
+ BpVectorReorderer reorderer;
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ reorderer = new BpVectorReorderer(FIELD_NAME);
+ reorderer.setMinPartitionSize(1);
+ reorderer.setMaxIters(10);
+ }
+
+ private void createQuantizedIndex(Directory dir, List<float[]> vectors) throws IOException {
+ IndexWriterConfig cfg = new IndexWriterConfig();
+ cfg.setCodec(
+ new Lucene101Codec() {
+ @Override
+ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+ return new Lucene99HnswScalarQuantizedVectorsFormat(8, 32);
+ }
+ });
+ try (IndexWriter writer = new IndexWriter(dir, cfg)) {
+ int i = 0;
+ for (float[] vector : vectors) {
+ Document doc = new Document();
+ doc.add(new KnnFloatVectorField(FIELD_NAME, vector));
+ doc.add(new StoredField("id", i++));
+ writer.addDocument(doc);
+ }
+ }
+ }
+
+ public void testRandom() {
+ List<float[]> points = new ArrayList<>();
+ // This test may fail for small N; 100 points seem big enough for the law of large numbers to
+ // make it pass with very high probability
+ for (int i = 0; i < 100; i++) {
+ points.add(new float[] {random().nextFloat(), random().nextFloat(), random().nextFloat()});
+ }
+ double closestDistanceSum = sumClosestDistances(points);
+ // run one iter so we can see what it did
+ reorderer.setMaxIters(1);
+ Sorter.DocMap map =
+ reorderer.computeValueMap(
+ FloatVectorValues.fromFloats(points, 3), VectorSimilarityFunction.EUCLIDEAN, null);
+ List<float[]> reordered = new ArrayList<>();
+ for (int i = 0; i < points.size(); i++) {
+ reordered.add(points.get(map.newToOld(i)));
+ }
+ double reorderedClosestDistanceSum = sumClosestDistances(reordered);
+ assertTrue(
+ reorderedClosestDistanceSum + ">" + closestDistanceSum,
+ reorderedClosestDistanceSum <= closestDistanceSum);
+ }
+
+ // Compute the sum of (for each point, the absolute difference between its ordinal and the ordinal
+ // of its closest neighbor in Euclidean space) as a measure of whether the reordering successfully
+ // brought vector-space neighbors closer together in ordinal space.
+ private static double sumClosestDistances(List<float[]> points) {
+ int sum = 0;
+ for (int i = 0; i < points.size(); i++) {
+ int closest = -1;
+ double closeness = Double.MAX_VALUE;
+ for (int j = 0; j < points.size(); j++) {
+ if (j == i) {
+ continue;
+ }
+ double distance = VectorUtil.squareDistance(points.get(i), points.get(j));
+ if (distance < closeness) {
+ closest = j;
+ closeness = distance;
+ }
+ }
+ sum += Math.abs(closest - i);
+ }
+ return sum;
+ }
+
+ public void testEuclideanLinear() {
+ doTestEuclideanLinear(null);
+ }
+
+ public void testQuantizedIndex() throws Exception {
+ doTestQuantizedIndex(null);
+ }
+
+ public void testEuclideanLinearConcurrent() {
+ int concurrency = random().nextInt(7) + 1;
+ // The default ForkJoinPool implementation uses a thread factory that removes all permissions on
+ // threads, so we need to create our own to avoid tests failing with FS-based directories.
+ ForkJoinPool pool =
+ new ForkJoinPool(
+ concurrency, p -> new ForkJoinWorkerThread(p) {}, null, random().nextBoolean());
+ try {
+ doTestEuclideanLinear(pool);
+ } finally {
+ pool.shutdown();
+ }
+ }
+
+ private void doTestEuclideanLinear(Executor executor) {
+ // a set of 2d points on a line
+ List<float[]> vectors = randomLinearVectors();
+ List<float[]> shuffled = shuffleVectors(vectors);
+ TaskExecutor taskExecutor = getTaskExecutor(executor);
+ Sorter.DocMap map =
+ reorderer.computeValueMap(
+ FloatVectorValues.fromFloats(shuffled, 2),
+ VectorSimilarityFunction.EUCLIDEAN,
+ taskExecutor);
+ verifyEuclideanLinear(map, vectors, shuffled);
+ }
+
+ private static TaskExecutor getTaskExecutor(Executor executor) {
+ return executor == null ? null : new TaskExecutor(executor);
+ }
+
+ private void doTestQuantizedIndex(Executor executor) throws IOException {
+ // a set of 2d points on a line
+ List<float[]> vectors = randomLinearVectors();
+ List<float[]> shuffled = shuffleVectors(vectors);
+ try (Directory dir = newDirectory()) {
+ createQuantizedIndex(dir, shuffled);
+ reorderer.reorderIndexDirectory(dir, executor);
+ int[] newToOld = new int[vectors.size()];
+ int[] oldToNew = new int[vectors.size()];
+ try (IndexReader reader = DirectoryReader.open(dir)) {
+ LeafReader leafReader = getOnlyLeafReader(reader);
+ for (int docid = 0; docid < reader.maxDoc(); docid++) {
+ if (leafReader.getLiveDocs() == null || leafReader.getLiveDocs().get(docid)) {
+ int oldid = Integer.parseInt(leafReader.storedFields().document(docid).get("id"));
+ newToOld[docid] = oldid;
+ oldToNew[oldid] = docid;
+ } else {
+ newToOld[docid] = -1;
+ }
+ }
+ }
+ verifyEuclideanLinear(
+ new Sorter.DocMap() {
+ @Override
+ public int oldToNew(int docID) {
+ return oldToNew[docID];
+ }
+
+ @Override
+ public int newToOld(int docID) {
+ return newToOld[docID];
+ }
+
+ @Override
+ public int size() {
+ return newToOld.length;
+ }
+ },
+ vectors,
+ shuffled);
+ }
+ }
+
+ private static List<float[]> shuffleVectors(List<float[]> vectors) {
+ List<float[]> shuffled = new ArrayList<>(vectors);
+ Collections.shuffle(shuffled, random());
+ return shuffled;
+ }
+
+ private static List<float[]> randomLinearVectors() {
+ int n = random().nextInt(100) + 10;
+ List<float[]> vectors = new ArrayList<>();
+ float b = random().nextFloat();
+ float m = random().nextFloat();
+ float x = random().nextFloat();
+ for (int i = 0; i < n; i++) {
+ vectors.add(new float[] {x, m * x + b});
+ x += random().nextFloat();
+ }
+ return vectors;
+ }
+
+ private static void verifyEuclideanLinear(
+ Sorter.DocMap map, List<float[]> vectors, List<float[]> shuffled) {
+ int count = shuffled.size();
+ assertEquals(count, map.size());
+ float[] midPoint = vectors.get(count / 2);
+ float[] first = shuffled.get(map.newToOld(0));
+ boolean lowFirst = first[0] < midPoint[0];
+ for (int i = 0; i < count; i++) {
+ int oldIndex = map.newToOld(i);
+ assertEquals(i, map.oldToNew(oldIndex));
+ // check the "new" order
+ float[] v = shuffled.get(oldIndex);
+ // either the low vectors come first and then the high ones, or the other way around; within
+ // each half the order is arbitrary since the partitioning does not produce a global ordering
+ if (i < count / 2 == lowFirst) {
+ assertTrue("out of order at " + i, v[0] <= midPoint[0] && v[1] <= midPoint[1]);
+ } else {
+ assertTrue("out of order at " + i, v[0] >= midPoint[0] && v[1] >= midPoint[1]);
+ }
+ }
+ }
+
+ public void testDotProductCircular() {
+ doTestDotProductCircular(null);
+ }
+
+ public void testDotProductConcurrent() {
+ int concurrency = random().nextInt(7) + 1;
+ // The default ForkJoinPool implementation uses a thread factory that removes all permissions on
+ // threads, so we need to create our own to avoid tests failing with FS-based directories.
+ ForkJoinPool pool =
+ new ForkJoinPool(
+ concurrency, p -> new ForkJoinWorkerThread(p) {}, null, random().nextBoolean());
+ try {
+ doTestDotProductCircular(new TaskExecutor(pool));
+ } finally {
+ pool.shutdown();
+ }
+ }
+
+ public void doTestDotProductCircular(TaskExecutor executor) {
+ // a set of 2d points on the unit circle
+ int n = random().nextInt(100) + 10;
+ List<float[]> vectors = new ArrayList<>();
+ double t = random().nextDouble();
+ for (int i = 0; i < n; i++) {
+ vectors.add(new float[] {(float) Math.cos(t), (float) Math.sin(t)});
+ t += random().nextDouble();
+ }
+ Sorter.DocMap map =
+ reorderer.computeValueMap(
+ FloatVectorValues.fromFloats(vectors, 2),
+ VectorSimilarityFunction.DOT_PRODUCT,
+ executor);
+ assertEquals(n, map.size());
+ double t0min = 2 * Math.PI, t0max = 0;
+ double t1min = 2 * Math.PI, t1max = 0;
+ // find the range of the lower half and the range of the upper half
+ // they should be non-overlapping
+ for (int i = 0; i < n; i++) {
+ int oldIndex = map.newToOld(i);
+ assertEquals(i, map.oldToNew(oldIndex));
+ // check the "new" order
+ float[] v = vectors.get(oldIndex);
+ t = angle2pi(Math.atan2(v[1], v[0]));
+ if (i < n / 2) {
+ t0min = Math.min(t0min, t);
+ t0max = Math.max(t0max, t);
+ } else {
+ t1min = Math.min(t1min, t);
+ t1max = Math.max(t1max, t);
+ }
+ }
+ assertTrue(
+ "ranges overlap",
+ (angularDifference(t0min, t0max) < angularDifference(t0min, t1min)
+ && angularDifference(t0min, t0max) < angularDifference(t0min, t1max))
+ || (angularDifference(t1min, t1max) < angularDifference(t1min, t0min)
+ && angularDifference(t1min, t1max) < angularDifference(t1min, t0max)));
+ }
+
+ public void testIndexReorderDense() throws Exception {
+ List<float[]> vectors = shuffleVectors(randomLinearVectors());
+ // compute the expected ordering
+ Sorter.DocMap expected =
+ reorderer.computeValueMap(
+ FloatVectorValues.fromFloats(vectors, 2), VectorSimilarityFunction.EUCLIDEAN, null);
+ Path tmpdir = createTempDir();
+ try (Directory dir = newFSDirectory(tmpdir)) {
+ // create an index with a single leaf
+ try (IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig())) {
+ int id = 0;
+ for (float[] vector : vectors) {
+ Document doc = new Document();
+ doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN));
+ doc.add(new StoredField("id", id++));
+ writer.addDocument(doc);
+ }
+ writer.forceMerge(1);
+ }
+ // run the reordering tool single-threaded
+ int threadCount = 1;
+ // reorder using the index reordering tool
+ BpVectorReorderer.main(
+ tmpdir.toString(),
+ "f",
+ "--min-partition-size",
+ "1",
+ "--max-iters",
+ "10",
+ "--thread-count",
+ Integer.toString(threadCount));
+ // verify the ordering is the same
+ try (IndexReader reader = DirectoryReader.open(dir)) {
+ LeafReader leafReader = getOnlyLeafReader(reader);
+ FloatVectorValues values = leafReader.getFloatVectorValues("f");
+ int newId = 0;
+ StoredFields storedFields = reader.storedFields();
+ KnnVectorValues.DocIndexIterator it = values.iterator();
+ while (it.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+ int storedId = Integer.parseInt(storedFields.document(it.docID()).get("id"));
+ assertEquals(expected.oldToNew(storedId), newId);
+ float[] expectedVector = vectors.get(expected.newToOld(it.docID()));
+ float[] actualVector = values.vectorValue(it.index());
+ assertArrayEquals(
+ "values differ at index " + storedId + "->" + newId + " docid=" + it.docID(),
+ expectedVector,
+ actualVector,
+ 0);
+ newId++;
+ }
+ }
+ }
+ }
+
+ public void testIndexReorderSparse() throws Exception {
+ List<float[]> vectors = shuffleVectors(randomLinearVectors());
+ // compute the expected ordering
+ Sorter.DocMap expected =
+ reorderer.computeValueMap(
+ FloatVectorValues.fromFloats(vectors, 2), VectorSimilarityFunction.EUCLIDEAN, null);
+ Path tmpdir = createTempDir();
+ int maxDoc = 0;
+ try (Directory dir = newFSDirectory(tmpdir)) {
+ // create an index with a single leaf
+ try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
+ for (float[] vector : vectors) {
+ Document doc = new Document();
+ if (random().nextBoolean()) {
+ for (int i = 0; i < random().nextInt(3); i++) {
+ // insert some gaps -- docs with no vectors
+ writer.addDocument(doc);
+ maxDoc++;
+ }
+ }
+ doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN));
+ writer.addDocument(doc);
+ maxDoc++;
+ }
+ }
+ // reorder using the index reordering tool
+ BpVectorReorderer.main(
+ tmpdir.toString(), "f", "--min-partition-size", "1", "--max-iters", "10");
+ // verify the ordering is the same
+ try (IndexReader reader = DirectoryReader.open(dir)) {
+ LeafReader leafReader = getOnlyLeafReader(reader);
+ assertEquals(maxDoc, leafReader.maxDoc());
+ FloatVectorValues values = leafReader.getFloatVectorValues("f");
+ int lastDocID = 0;
+ KnnVectorValues.DocIndexIterator it = values.iterator();
+ while (it.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+ lastDocID = it.docID();
+ float[] expectedVector = vectors.get(expected.newToOld(lastDocID));
+ float[] actualVector = values.vectorValue(it.index());
+ assertArrayEquals(expectedVector, actualVector, 0);
+ }
+ // docs with no vectors sort at the end
+ assertEquals(vectors.size() - 1, lastDocID);
+ }
+ }
+ }
+
+ static double angularDifference(double a, double b) {
+ return angle2pi(b - a);
+ }
+
+ static double angle2pi(double a) {
+ while (a > 2 * Math.PI) {
+ a -= 2 * Math.PI;
+ }
+ while (a < 0) {
+ a += 2 * Math.PI;
+ }
+ return a;
+ }
+}
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
index bd2911fc50a8..22fa438b04b4 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java
@@ -223,7 +223,11 @@ public void close() throws IOException {
@Override
public HnswGraph getGraph(String field) throws IOException {
- return ((HnswGraphProvider) delegate).getGraph(field);
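+ // Not every delegate format exposes an HNSW graph; return null instead of failing with a
+ // ClassCastException.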
+ if (delegate instanceof HnswGraphProvider) {
+ return ((HnswGraphProvider) delegate).getGraph(field);
+ } else {
+ return null;
+ }
}
}
}
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingLiveDocsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingLiveDocsFormat.java
index f45ea821a555..e2152e45aa58 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingLiveDocsFormat.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingLiveDocsFormat.java
@@ -24,6 +24,7 @@
import org.apache.lucene.store.IOContext;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.FixedBitSet;
/** Just like the default live docs format but with additional asserts. */
public class AssertingLiveDocsFormat extends LiveDocsFormat {
@@ -88,6 +89,12 @@ public int length() {
return in.length();
}
+ @Override
+ public void applyMask(FixedBitSet bitSet, int offset) {
+ assert offset >= 0;
+ in.applyMask(bitSet, offset);
+ }
+
@Override
public String toString() {
return "Asserting(" + in + ")";
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/RandomPostingsTester.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/RandomPostingsTester.java
index 64ca9cca35f0..15c9c324732c 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/RandomPostingsTester.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/RandomPostingsTester.java
@@ -1388,10 +1388,6 @@ private void verifyEnum(
PostingsEnum pe2 = termsEnum.postings(null, flags);
FixedBitSet set1 = new FixedBitSet(1024);
FixedBitSet set2 = new FixedBitSet(1024);
- FixedBitSet acceptDocs = new FixedBitSet(maxDoc);
- for (int i = 0; i < maxDoc; i += 2) {
- acceptDocs.set(i);
- }
while (true) {
pe1.nextDoc();
@@ -1400,11 +1396,9 @@ private void verifyEnum(
int offset =
TestUtil.nextInt(random, Math.max(0, pe1.docID() - set1.length()), pe1.docID());
int upTo = offset + random.nextInt(set1.length());
- pe1.intoBitSet(acceptDocs, upTo, set1, offset);
+ pe1.intoBitSet(upTo, set1, offset);
for (int d = pe2.docID(); d < upTo; d = pe2.nextDoc()) {
- if (acceptDocs.get(d)) {
- set2.set(d - offset);
- }
+ set2.set(d - offset);
}
assertEquals(set1, set2);
diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java
index 7200d4b5f4dc..9717f738e82e 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java
@@ -24,7 +24,6 @@
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
-import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
/** Wraps a Scorer with additional checks */
@@ -196,11 +195,10 @@ public long cost() {
}
@Override
- public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
- throws IOException {
+ public void intoBitSet(int upTo, FixedBitSet bitSet, int offset) throws IOException {
assert docID() != -1;
assert offset <= docID();
- in.intoBitSet(acceptDocs, upTo, bitSet, offset);
+ in.intoBitSet(upTo, bitSet, offset);
assert docID() >= upTo;
}
};