diff --git a/build.gradle b/build.gradle
index ba77f2ef6d5c..0eae57bd10f3 100644
--- a/build.gradle
+++ b/build.gradle
@@ -36,7 +36,7 @@ apply from: file('gradle/globals.gradle')
// Calculate project version:
version = {
// Release manager: update base version here after release:
- String baseVersion = '9.6.0'
+ String baseVersion = '9.7.0'
// On a release explicitly set release version in one go:
// -Dversion.release=x.y.z
@@ -119,7 +119,7 @@ apply from: file('gradle/ide/eclipse.gradle')
// (java, tests)
apply from: file('gradle/java/folder-layout.gradle')
apply from: file('gradle/java/javac.gradle')
-apply from: file('gradle/java/memorysegment-mrjar.gradle')
+apply from: file('gradle/java/core-mrjar.gradle')
apply from: file('gradle/testing/defaults-tests.gradle')
apply from: file('gradle/testing/randomization.gradle')
apply from: file('gradle/testing/fail-on-no-tests.gradle')
@@ -158,7 +158,7 @@ apply from: file('gradle/generation/javacc.gradle')
apply from: file('gradle/generation/forUtil.gradle')
apply from: file('gradle/generation/antlr.gradle')
apply from: file('gradle/generation/unicode-test-classes.gradle')
-apply from: file('gradle/generation/panama-foreign.gradle')
+apply from: file('gradle/generation/extract-jdk-apis.gradle')
apply from: file('gradle/datasets/external-datasets.gradle')
diff --git a/buildSrc/scriptDepVersions.gradle b/buildSrc/scriptDepVersions.gradle
index 8751da632492..a6eae860b1c0 100644
--- a/buildSrc/scriptDepVersions.gradle
+++ b/buildSrc/scriptDepVersions.gradle
@@ -22,7 +22,7 @@
ext {
scriptDepVersions = [
"apache-rat": "0.14",
- "asm": "9.4",
+ "asm": "9.5",
"commons-codec": "1.13",
"ecj": "3.30.0",
"flexmark": "0.61.24",
diff --git a/dev-tools/doap/lucene.rdf b/dev-tools/doap/lucene.rdf
index 0c749ad54d44..c0a09528cef2 100644
--- a/dev-tools/doap/lucene.rdf
+++ b/dev-tools/doap/lucene.rdf
@@ -67,6 +67,13 @@
+ <release>
+ <Version>
+ <name>lucene-9.6.0</name>
+ <created>2023-05-09</created>
+ <revision>9.6.0</revision>
+ </Version>
+ </release>
 <name>lucene-9.5.0</name>
@@ -74,7 +81,6 @@
 <revision>9.5.0</revision>
-
 <name>lucene-9.4.2</name>
diff --git a/gradle/generation/panama-foreign.gradle b/gradle/generation/extract-jdk-apis.gradle
similarity index 73%
rename from gradle/generation/panama-foreign.gradle
rename to gradle/generation/extract-jdk-apis.gradle
index 694c4656e2f6..78e74aa261a3 100644
--- a/gradle/generation/panama-foreign.gradle
+++ b/gradle/generation/extract-jdk-apis.gradle
@@ -17,10 +17,17 @@
def resources = scriptResources(buildscript)
+configure(rootProject) {
+ ext {
+ // also change this in extractor tool: ExtractJdkApis
+ vectorIncubatorJavaVersions = [ JavaVersion.VERSION_20, JavaVersion.VERSION_21 ] as Set
+ }
+}
+
configure(project(":lucene:core")) {
ext {
apijars = file('src/generated/jdk');
- panamaJavaVersions = [ 19, 20 ]
+ mrjarJavaVersions = [ 19, 20, 21 ]
}
configurations {
@@ -31,9 +38,9 @@ configure(project(":lucene:core")) {
apiextractor "org.ow2.asm:asm:${scriptDepVersions['asm']}"
}
- for (jdkVersion : panamaJavaVersions) {
- def task = tasks.create(name: "generatePanamaForeignApiJar${jdkVersion}", type: JavaExec) {
- description "Regenerate the API-only JAR file with public Panama Foreign API from JDK ${jdkVersion}"
+ for (jdkVersion : mrjarJavaVersions) {
+ def task = tasks.create(name: "generateJdkApiJar${jdkVersion}", type: JavaExec) {
+ description "Regenerate the API-only JAR file with public Panama Foreign & Vector API from JDK ${jdkVersion}"
group "generation"
javaLauncher = javaToolchains.launcherFor {
@@ -45,18 +52,21 @@ configure(project(":lucene:core")) {
javaLauncher.get()
return true
} catch (Exception e) {
- logger.warn('Launcher for Java {} is not available; skipping regeneration of Panama Foreign API JAR.', jdkVersion)
+ logger.warn('Launcher for Java {} is not available; skipping regeneration of Panama Foreign & Vector API JAR.', jdkVersion)
logger.warn('Error: {}', e.cause?.message)
logger.warn("Please make sure to point env 'JAVA{}_HOME' to exactly JDK version {} or enable Gradle toolchain auto-download.", jdkVersion, jdkVersion)
return false
}
}
-
+
classpath = configurations.apiextractor
- mainClass = file("${resources}/ExtractForeignAPI.java") as String
+ mainClass = file("${resources}/ExtractJdkApis.java") as String
+ systemProperties = [
+ 'user.timezone': 'UTC'
+ ]
args = [
jdkVersion,
- new File(apijars, "panama-foreign-jdk${jdkVersion}.apijar"),
+ new File(apijars, "jdk${jdkVersion}.apijar"),
]
}
diff --git a/gradle/generation/extract-jdk-apis/ExtractJdkApis.java b/gradle/generation/extract-jdk-apis/ExtractJdkApis.java
new file mode 100644
index 000000000000..7dfb9edfe2a0
--- /dev/null
+++ b/gradle/generation/extract-jdk-apis/ExtractJdkApis.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import java.io.IOException;
+import java.net.URI;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.PathMatcher;
+import java.nio.file.Paths;
+import java.nio.file.attribute.FileTime;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipOutputStream;
+
+import org.objectweb.asm.AnnotationVisitor;
+import org.objectweb.asm.ClassReader;
+import org.objectweb.asm.ClassVisitor;
+import org.objectweb.asm.ClassWriter;
+import org.objectweb.asm.FieldVisitor;
+import org.objectweb.asm.MethodVisitor;
+import org.objectweb.asm.Opcodes;
+import org.objectweb.asm.Type;
+
+public final class ExtractJdkApis {
+
+ private static final FileTime FIXED_FILEDATE = FileTime.from(Instant.parse("2022-01-01T00:00:00Z"));
+
+ private static final String PATTERN_PANAMA_FOREIGN = "java.base/java/{lang/foreign/*,nio/channels/FileChannel,util/Objects}";
+ private static final String PATTERN_VECTOR_INCUBATOR = "jdk.incubator.vector/jdk/incubator/vector/*";
+ private static final String PATTERN_VECTOR_VM_INTERNALS = "java.base/jdk/internal/vm/vector/VectorSupport{,$Vector,$VectorMask,$VectorPayload,$VectorShuffle}";
+
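+ // maps a JDK feature version to the glob patterns of class files extracted into its apijar: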
+ static final Map<Integer, List<String>> CLASSFILE_PATTERNS = Map.of(
+ 19, List.of(PATTERN_PANAMA_FOREIGN),
+ 20, List.of(PATTERN_PANAMA_FOREIGN, PATTERN_VECTOR_VM_INTERNALS, PATTERN_VECTOR_INCUBATOR),
+ 21, List.of(PATTERN_PANAMA_FOREIGN, PATTERN_VECTOR_VM_INTERNALS, PATTERN_VECTOR_INCUBATOR)
+ );
+
+ public static void main(String... args) throws IOException {
+ if (args.length != 2) {
+ throw new IllegalArgumentException("Need two parameters: java version, output file");
+ }
+ Integer jdk = Integer.valueOf(args[0]);
+ if (jdk.intValue() != Runtime.version().feature()) {
+ throw new IllegalStateException("Incorrect java version: " + Runtime.version().feature());
+ }
+ if (!CLASSFILE_PATTERNS.containsKey(jdk)) {
+ throw new IllegalArgumentException("No support to extract stubs from java version: " + jdk);
+ }
+ var outputPath = Paths.get(args[1]);
+
+ // create JRT filesystem and build a combined FileMatcher:
+ var jrtPath = Paths.get(URI.create("jrt:/")).toRealPath();
+ var patterns = CLASSFILE_PATTERNS.get(jdk).stream()
+ .map(pattern -> jrtPath.getFileSystem().getPathMatcher("glob:" + pattern + ".class"))
+ .toArray(PathMatcher[]::new);
+ PathMatcher pattern = p -> Arrays.stream(patterns).anyMatch(matcher -> matcher.matches(p));
+
+ // Collect all files to process:
+ final List<Path> filesToExtract;
+ try (var stream = Files.walk(jrtPath)) {
+ filesToExtract = stream.filter(p -> pattern.matches(jrtPath.relativize(p))).collect(Collectors.toList());
+ }
+
+ // Process all class files:
+ try (var out = new ZipOutputStream(Files.newOutputStream(outputPath))) {
+ process(filesToExtract, out);
+ }
+ }
+
+ private static void process(List<Path> filesToExtract, ZipOutputStream out) throws IOException {
+ var classesToInclude = new HashSet<String>();
+ var references = new HashMap<String, String[]>();
+ var processed = new TreeMap<String, byte[]>();
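+ // classesToInclude collects visible classes, references maps each class to its supertypes,
+ // and processed holds the transformed bytecode (TreeMap keeps the output deterministically sorted):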
+ System.out.println("Transforming " + filesToExtract.size() + " class files...");
+ for (Path p : filesToExtract) {
+ try (var in = Files.newInputStream(p)) {
+ var reader = new ClassReader(in);
+ var cw = new ClassWriter(0);
+ var cleaner = new Cleaner(cw, classesToInclude, references);
+ reader.accept(cleaner, ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
+ processed.put(reader.getClassName(), cw.toByteArray());
+ }
+ }
+ // recursively add all superclasses / interfaces of visible classes to classesToInclude:
+ for (Set<String> a = classesToInclude; !a.isEmpty();) {
+ a = a.stream().map(references::get).filter(Objects::nonNull).flatMap(Arrays::stream).collect(Collectors.toSet());
+ classesToInclude.addAll(a);
+ }
+ // remove all non-visible or not referenced classes:
+ processed.keySet().removeIf(Predicate.not(classesToInclude::contains));
+ System.out.println("Writing " + processed.size() + " visible classes...");
+ for (var cls : processed.entrySet()) {
+ String cn = cls.getKey();
+ System.out.println("Writing stub for class: " + cn);
+ out.putNextEntry(new ZipEntry(cn.concat(".class")).setLastModifiedTime(FIXED_FILEDATE));
+ out.write(cls.getValue());
+ out.closeEntry();
+ }
+ classesToInclude.removeIf(processed.keySet()::contains);
+ System.out.println("Referenced classes not included: " + classesToInclude);
+ }
+
+ static boolean isVisible(int access) {
+ return (access & (Opcodes.ACC_PROTECTED | Opcodes.ACC_PUBLIC)) != 0;
+ }
+
+ static class Cleaner extends ClassVisitor {
+ private static final String PREVIEW_ANN = "jdk/internal/javac/PreviewFeature";
+ private static final String PREVIEW_ANN_DESCR = Type.getObjectType(PREVIEW_ANN).getDescriptor();
+
+ private final Set<String> classesToInclude;
+ private final Map<String, String[]> references;
+
+ Cleaner(ClassWriter out, Set<String> classesToInclude, Map<String, String[]> references) {
+ super(Opcodes.ASM9, out);
+ this.classesToInclude = classesToInclude;
+ this.references = references;
+ }
+
+ @Override
+ public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
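+ // always write stubs as Java 11 class files (Opcodes.V11), regardless of the source JDK: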
+ super.visit(Opcodes.V11, access, name, signature, superName, interfaces);
+ if (isVisible(access)) {
+ classesToInclude.add(name);
+ }
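+ // record supertype references so superclasses/interfaces of visible classes can be pulled in later: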
+ references.put(name, Stream.concat(Stream.of(superName), Arrays.stream(interfaces)).toArray(String[]::new));
+ }
+
+ @Override
+ public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
+ return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
+ }
+
+ @Override
+ public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
+ if (!isVisible(access)) {
+ return null;
+ }
+ return new FieldVisitor(Opcodes.ASM9, super.visitField(access, name, descriptor, signature, value)) {
+ @Override
+ public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
+ return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
+ }
+ };
+ }
+
+ @Override
+ public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
+ if (!isVisible(access)) {
+ return null;
+ }
+ return new MethodVisitor(Opcodes.ASM9, super.visitMethod(access, name, descriptor, signature, exceptions)) {
+ @Override
+ public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
+ return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
+ }
+ };
+ }
+
+ @Override
+ public void visitInnerClass(String name, String outerName, String innerName, int access) {
+ if (!Objects.equals(outerName, PREVIEW_ANN)) {
+ super.visitInnerClass(name, outerName, innerName, access);
+ }
+ }
+
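+ // strip the PermittedSubclasses attribute; the sealed hierarchies reference internal
+ // classes that are not extracted into the stub JAR (an assumption about why it is dropped):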
+ @Override
+ public void visitPermittedSubclass(String c) {
+ }
+
+ }
+
+}
diff --git a/gradle/generation/panama-foreign/ExtractForeignAPI.java b/gradle/generation/panama-foreign/ExtractForeignAPI.java
deleted file mode 100644
index 44253ea0122b..000000000000
--- a/gradle/generation/panama-foreign/ExtractForeignAPI.java
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-import java.io.IOException;
-import java.net.URI;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.nio.file.attribute.FileTime;
-import java.time.Instant;
-import java.util.Objects;
-import java.util.stream.Collectors;
-import java.util.zip.ZipEntry;
-import java.util.zip.ZipOutputStream;
-
-import org.objectweb.asm.AnnotationVisitor;
-import org.objectweb.asm.ClassReader;
-import org.objectweb.asm.ClassVisitor;
-import org.objectweb.asm.ClassWriter;
-import org.objectweb.asm.FieldVisitor;
-import org.objectweb.asm.MethodVisitor;
-import org.objectweb.asm.Opcodes;
-import org.objectweb.asm.Type;
-
-public final class ExtractForeignAPI {
-
- private static final FileTime FIXED_FILEDATE = FileTime.from(Instant.parse("2022-01-01T00:00:00Z"));
-
- public static void main(String... args) throws IOException {
- if (args.length != 2) {
- throw new IllegalArgumentException("Need two parameters: java version, output file");
- }
- if (Integer.parseInt(args[0]) != Runtime.version().feature()) {
- throw new IllegalStateException("Incorrect java version: " + Runtime.version().feature());
- }
- var outputPath = Paths.get(args[1]);
- var javaBaseModule = Paths.get(URI.create("jrt:/")).resolve("java.base").toRealPath();
- var fileMatcher = javaBaseModule.getFileSystem().getPathMatcher("glob:java/{lang/foreign/*,nio/channels/FileChannel,util/Objects}.class");
- try (var out = new ZipOutputStream(Files.newOutputStream(outputPath)); var stream = Files.walk(javaBaseModule)) {
- var filesToExtract = stream.map(javaBaseModule::relativize).filter(fileMatcher::matches).sorted().collect(Collectors.toList());
- for (Path relative : filesToExtract) {
- System.out.println("Processing class file: " + relative);
- try (var in = Files.newInputStream(javaBaseModule.resolve(relative))) {
- final var reader = new ClassReader(in);
- final var cw = new ClassWriter(0);
- reader.accept(new Cleaner(cw), ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
- out.putNextEntry(new ZipEntry(relative.toString()).setLastModifiedTime(FIXED_FILEDATE));
- out.write(cw.toByteArray());
- out.closeEntry();
- }
- }
- }
- }
-
- static class Cleaner extends ClassVisitor {
- private static final String PREVIEW_ANN = "jdk/internal/javac/PreviewFeature";
- private static final String PREVIEW_ANN_DESCR = Type.getObjectType(PREVIEW_ANN).getDescriptor();
-
- private boolean completelyHidden = false;
-
- Cleaner(ClassWriter out) {
- super(Opcodes.ASM9, out);
- }
-
- private boolean isHidden(int access) {
- return completelyHidden || (access & (Opcodes.ACC_PROTECTED | Opcodes.ACC_PUBLIC)) == 0;
- }
-
- @Override
- public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
- super.visit(Opcodes.V11, access, name, signature, superName, interfaces);
- completelyHidden = isHidden(access);
- }
-
- @Override
- public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
- return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
- }
-
- @Override
- public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
- if (isHidden(access)) {
- return null;
- }
- return new FieldVisitor(Opcodes.ASM9, super.visitField(access, name, descriptor, signature, value)) {
- @Override
- public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
- return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
- }
- };
- }
-
- @Override
- public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
- if (isHidden(access)) {
- return null;
- }
- return new MethodVisitor(Opcodes.ASM9, super.visitMethod(access, name, descriptor, signature, exceptions)) {
- @Override
- public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
- return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible);
- }
- };
- }
-
- @Override
- public void visitInnerClass(String name, String outerName, String innerName, int access) {
- if (!Objects.equals(outerName, PREVIEW_ANN)) {
- super.visitInnerClass(name, outerName, innerName, access);
- }
- }
-
- @Override
- public void visitPermittedSubclass(String c) {
- }
-
- }
-
-}
diff --git a/gradle/java/memorysegment-mrjar.gradle b/gradle/java/core-mrjar.gradle
similarity index 82%
rename from gradle/java/memorysegment-mrjar.gradle
rename to gradle/java/core-mrjar.gradle
index 137f8a3c567d..5715e782f000 100644
--- a/gradle/java/memorysegment-mrjar.gradle
+++ b/gradle/java/core-mrjar.gradle
@@ -15,11 +15,11 @@
* limitations under the License.
*/
-// Produce an MR-JAR with Java 19+ MemorySegment implementation for MMapDirectory
+// Produce an MR-JAR with Java 19+ foreign and vector implementations
configure(project(":lucene:core")) {
plugins.withType(JavaPlugin) {
- for (jdkVersion : panamaJavaVersions) {
+ for (jdkVersion : mrjarJavaVersions) {
sourceSets.create("main${jdkVersion}") {
java {
srcDirs = ["src/java${jdkVersion}"]
@@ -29,7 +29,7 @@ configure(project(":lucene:core")) {
dependencies.add("main${jdkVersion}Implementation", sourceSets.main.output)
tasks.named("compileMain${jdkVersion}Java").configure {
- def apijar = new File(apijars, "panama-foreign-jdk${jdkVersion}.apijar")
+ def apijar = new File(apijars, "jdk${jdkVersion}.apijar")
inputs.file(apijar)
@@ -40,12 +40,14 @@ configure(project(":lucene:core")) {
"-Xlint:-options",
"--patch-module", "java.base=${apijar}",
"--add-exports", "java.base/java.lang.foreign=ALL-UNNAMED",
+ // for compilation we patch the incubator packages into java.base, this has no effect on resulting class files:
+ "--add-exports", "java.base/jdk.incubator.vector=ALL-UNNAMED",
]
}
}
tasks.named('jar').configure {
- for (jdkVersion : panamaJavaVersions) {
+ for (jdkVersion : mrjarJavaVersions) {
into("META-INF/versions/${jdkVersion}") {
from sourceSets["main${jdkVersion}"].output
}
diff --git a/gradle/template.gradle.properties b/gradle/template.gradle.properties
index a626d39f3bb5..9ac8c42e9dd4 100644
--- a/gradle/template.gradle.properties
+++ b/gradle/template.gradle.properties
@@ -102,5 +102,5 @@ tests.jvms=@TEST_JVMS@
org.gradle.java.installations.auto-download=true
# Set these to enable automatic JVM location discovery.
-org.gradle.java.installations.fromEnv=JAVA17_HOME,JAVA19_HOME,JAVA20_HOME,JAVA21_HOME,RUNTIME_JAVA_HOME
+org.gradle.java.installations.fromEnv=JAVA17_HOME,JAVA19_HOME,JAVA20_HOME,JAVA21_HOME,JAVA22_HOME,RUNTIME_JAVA_HOME
#org.gradle.java.installations.paths=(custom paths)
diff --git a/gradle/testing/defaults-tests.gradle b/gradle/testing/defaults-tests.gradle
index 9f50cda8ca79..f7a348f0b66c 100644
--- a/gradle/testing/defaults-tests.gradle
+++ b/gradle/testing/defaults-tests.gradle
@@ -47,7 +47,7 @@ allprojects {
description: "Number of forked test JVMs"],
[propName: 'tests.haltonfailure', value: true, description: "Halt processing on test failure."],
[propName: 'tests.jvmargs',
- value: { -> propertyOrEnvOrDefault("tests.jvmargs", "TEST_JVM_ARGS", "-XX:TieredStopAtLevel=1 -XX:+UseParallelGC -XX:ActiveProcessorCount=1") },
+ value: { -> propertyOrEnvOrDefault("tests.jvmargs", "TEST_JVM_ARGS", isCIBuild ? "" : "-XX:TieredStopAtLevel=1 -XX:+UseParallelGC -XX:ActiveProcessorCount=1") },
description: "Arguments passed to each forked JVM."],
// Other settings.
[propName: 'tests.neverUpToDate', value: true,
@@ -119,11 +119,16 @@ allprojects {
if (rootProject.runtimeJavaVersion < JavaVersion.VERSION_16) {
jvmArgs '--illegal-access=deny'
}
-
+
// Lucene needs two optional modules at runtime, which we want to enforce for testing
// (if the runner JVM does not support them, it will fail tests):
jvmArgs '--add-modules', 'jdk.unsupported,jdk.management'
+ // Enable the vector incubator module on supported Java versions:
+ if (rootProject.vectorIncubatorJavaVersions.contains(rootProject.runtimeJavaVersion)) {
+ jvmArgs '--add-modules', 'jdk.incubator.vector'
+ }
+
def loggingConfigFile = layout.projectDirectory.file("${resources}/logging.properties")
def tempDir = layout.projectDirectory.dir(testsTmpDir.toString())
jvmArgumentProviders.add(
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 322e8c2acae8..341101cf2a95 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -3,6 +3,106 @@ Lucene Change Log
For more information on past and future Lucene versions, please see:
http://s.apache.org/luceneversions
+======================== Lucene 9.7.0 =======================
+
+API Changes
+---------------------
+
+* GITHUB#11840, GITHUB#12304: Query rewrite now takes an IndexSearcher instead of an
+  IndexReader to enable concurrent rewriting. Please note: this is implemented in
+  a backwards-compatible way; a query overriding either of the two rewrite methods is
+  supported. To implement this backwards-compatibility layer in Lucene 9.x, the
+  RuntimePermission "accessDeclaredMembers" is needed in applications using a
+  SecurityManager. (Patrick Zhai, Ben Trent, Uwe Schindler)
+
+* GITHUB#12321: DaciukMihovAutomatonBuilder has been marked deprecated in preparation for reducing its visibility in
+  a future release. (Greg Miller)
+
+* GITHUB#12268: Add BitSet.clear() without parameters for clearing the entire set
+ (Jonathan Ellis)
+
+* GITHUB#12346: Add a new IndexWriter#updateDocuments(Query, Iterable) API to update
+  documents atomically with respect to refresh and commit, using a query. (Patrick Zhai)
+
+New Features
+---------------------
+
+* GITHUB#12257: Create OnHeapHnswGraphSearcher to let OnHeapHnswGraph be searched in a thread-safe manner. (Patrick Zhai)
+
+* GITHUB#12302, GITHUB#12311, GITHUB#12363: Add vectorized implementations of VectorUtil.dotProduct(),
+  squareDistance() and cosine() using the Java 20/21 jdk.incubator.vector APIs. Applications started
+  with the command-line parameter "java --add-modules jdk.incubator.vector" on exactly Java 20 or 21
+  will automatically use the new vectorized implementations if running on a supported platform
+  (x86 AVX2 or later, ARM NEON). This is an opt-in feature and requires an explicit Java
+  command-line flag! When enabled, Lucene logs a notice using java.util.logging. Please test
+  thoroughly and report bugs/slowness to Lucene's mailing list.
+  (Chris Hegarty, Robert Muir, Uwe Schindler)
+
+* GITHUB#12294: Add support for Java 21 foreign memory API. If Java 19 up to 21 is used,
+ MMapDirectory will mmap Lucene indexes in chunks of 16 GiB (instead of 1 GiB) and indexes
+ closed while queries are running can no longer crash the JVM. To disable this feature,
+ pass the following sysprop on Java command line:
+ "-Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false" (Uwe Schindler)
+
+* GITHUB#12252: Add function queries for computing similarity scores between knn vectors. (Elia Porciani, Alessandro Benedetti)
+
+Improvements
+---------------------
+
+* GITHUB#12245: Add support for Score Mode to `ToParentBlockJoinQuery` explain. (Marcus Eagan via Mikhail Khludnev)
+
+* GITHUB#12305: Minor cleanup and improvements to DaciukMihovAutomatonBuilder. (Greg Miller)
+
+* GITHUB#12325: Parallelize AbstractKnnVectorQuery rewrite across slices rather than segments. (Luca Cavanna)
+
+* GITHUB#12333: NumericLeafComparator#competitiveIterator makes better use of a "search after" value when paginating.
+ (Chaitanya Gohel)
+
+* GITHUB#12290: Make memory fence in ByteBufferGuard explicit using `VarHandle.fullFence()`
+
+* GITHUB#12320: Add "direct to binary" option for DaciukMihovAutomatonBuilder and use it in TermInSetQuery#visit.
+ (Greg Miller)
+
+* GITHUB#12281: Require indexed KNN float vectors and query vectors to be finite. (Jonathan Ellis, Uwe Schindler)
+
+Optimizations
+---------------------
+
+* GITHUB#12324: Speed up sparse block advanceExact with tiny step in IndexedDISI. (Guo Feng)
+
+* GITHUB#12270: Don't generate a stacktrace in CollectionTerminatedException. (Armin Braun)
+
+* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh)
+
+* GITHUB#12286: Use an iterator in toposort to avoid stack overflow. (Tang Donghai)
+
+* GITHUB#12235: Optimize HNSW diversity calculation. (Patrick Zhai)
+
+* GITHUB#12328: Optimize ConjunctionDISI.createConjunction (Armin Braun)
+
+* GITHUB#12357: Better paging when doing backwards random reads. This speeds up
+ queries relying on terms in NIOFSDirectory and SimpleFSDirectory. (Alan Woodward)
+
+* GITHUB#12339: Avoid duplicate numDeletesToMerge calculations during the merge phase. (fudongying)
+
+* GITHUB#12334: Honor the "after" value for skipping documents in PagingFieldCollector even if the queue is not full. (Chaitanya Gohel)
+
+Bug Fixes
+---------------------
+
+* GITHUB#12291: Skip blank lines from stopwords list. (Jerry Chin)
+
+* GITHUB#11350: Handle possible differences in FieldInfo when merging indices created with Lucene 8.x (Tomás Fernández Löbbe)
+
+* GITHUB#12352: [Tessellator] Improve the checks that validate the diagonal between two polygon nodes so
+ the resulting polygons are valid counter clockwise polygons. (Ignacio Vera)
+
+* LUCENE-10181: Restrict GraphTokenStreamFiniteStrings#articulationPointsRecurse recursion depth. (Chris Fournier)
+
+Other
+---------------------
+(No changes)
+
======================== Lucene 9.6.0 =======================
API Changes
@@ -32,6 +132,8 @@ New Features
crash the JVM. To disable this feature, pass the following sysprop on Java command line:
"-Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false" (Uwe Schindler)
+* GITHUB#12169: Introduce a new token filter to expand synonyms based on Word2Vec DL4j models. (Daniele Antuzi, Ilaria Petreti, Alessandro Benedetti)
+
Improvements
---------------------
@@ -45,6 +147,8 @@ Improvements
* GITHUB#12175: Remove SortedSetDocValuesSetQuery in favor of TermInSetQuery with DocValuesRewriteMethod. (Greg Miller)
+* GITHUB#12166: Remove the now unused class pointInPolygon. (Marcus Eagan via Christine Poerschke and Nick Knize)
+
* GITHUB#12126: Refactor part of IndexFileDeleter and ReplicaFileDeleter into a public common utility class
FileDeleter. (Patrick Zhai)
@@ -86,7 +190,9 @@ Bug Fixes
* GITHUB#12212: Bug fix for a DrillSideways issue where matching hits could occasionally be missed. (Frederic Thevenet)
-* GITHUB#12220: Hunspell: disallow hidden title-case entries from compound middle/end
+* GITHUB#12220: Hunspell: disallow hidden title-case entries from compound middle/end (Peter Gromov)
+
+* GITHUB#12260: Fix SynonymQuery equals implementation to take the targeted field name into account (Luca Cavanna)
Build
---------------------
@@ -157,7 +263,7 @@ API Changes
* GITHUB#11962: VectorValues#cost() now delegates to VectorValues#size().
(Adrien Grand)
-
+
* GITHUB#11984: Improved TimeLimitBulkScorer to check the timeout at an exponential rate.
(Costin Leau)
diff --git a/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java b/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java
index 8c245e7058c7..988deaf99e59 100644
--- a/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java
+++ b/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java
@@ -89,6 +89,8 @@
import org.apache.lucene.analysis.standard.StandardTokenizer;
import org.apache.lucene.analysis.stempel.StempelStemmer;
import org.apache.lucene.analysis.synonym.SynonymMap;
+import org.apache.lucene.analysis.synonym.word2vec.Word2VecModel;
+import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProvider;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.tests.analysis.MockTokenFilter;
@@ -99,8 +101,10 @@
import org.apache.lucene.tests.util.automaton.AutomatonTestUtil;
import org.apache.lucene.util.AttributeFactory;
import org.apache.lucene.util.AttributeSource;
+import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.IgnoreRandomChains;
+import org.apache.lucene.util.TermAndVector;
import org.apache.lucene.util.Version;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.CharacterRunAutomaton;
@@ -415,6 +419,27 @@ private String randomNonEmptyString(Random random) {
}
}
});
+ put(
+ Word2VecSynonymProvider.class,
+ random -> {
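+          // build a small random Word2Vec model; Word2VecSynonymProvider constructs an
+          // HNSW graph over it, so it needs at least a handful of entries: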
+ final int numEntries = atLeast(10);
+ final int vectorDimension = random.nextInt(99) + 1;
+ Word2VecModel model = new Word2VecModel(numEntries, vectorDimension);
+ for (int j = 0; j < numEntries; j++) {
+ String s = TestUtil.randomSimpleString(random, 10, 20);
+ float[] vec = new float[vectorDimension];
+ for (int i = 0; i < vectorDimension; i++) {
+ vec[i] = random.nextFloat();
+ }
+ model.addTermAndVector(new TermAndVector(new BytesRef(s), vec));
+ }
+ try {
+ return new Word2VecSynonymProvider(model);
+ } catch (IOException e) {
+ Rethrow.rethrow(e);
+ return null; // unreachable code
+ }
+ });
put(
DateFormat.class,
random -> {
diff --git a/lucene/analysis/common/src/java/module-info.java b/lucene/analysis/common/src/java/module-info.java
index 5679f0dde295..15ad5a2b1af0 100644
--- a/lucene/analysis/common/src/java/module-info.java
+++ b/lucene/analysis/common/src/java/module-info.java
@@ -78,6 +78,7 @@
exports org.apache.lucene.analysis.sr;
exports org.apache.lucene.analysis.sv;
exports org.apache.lucene.analysis.synonym;
+ exports org.apache.lucene.analysis.synonym.word2vec;
exports org.apache.lucene.analysis.ta;
exports org.apache.lucene.analysis.te;
exports org.apache.lucene.analysis.th;
@@ -256,6 +257,7 @@
org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory,
org.apache.lucene.analysis.synonym.SynonymFilterFactory,
org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory,
+ org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory,
org.apache.lucene.analysis.core.FlattenGraphFilterFactory,
org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory,
org.apache.lucene.analysis.te.TeluguStemFilterFactory,
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Dictionary.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Dictionary.java
index e94047b67db3..b7a4029a523b 100644
--- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Dictionary.java
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Dictionary.java
@@ -155,7 +155,7 @@ public class Dictionary {
boolean checkCompoundCase, checkCompoundDup, checkCompoundRep;
boolean checkCompoundTriple, simplifiedTriple;
int compoundMin = 3, compoundMax = Integer.MAX_VALUE;
- List<CompoundRule> compoundRules; // nullable
+ CompoundRule[] compoundRules; // nullable
List checkCompoundPatterns = new ArrayList<>();
// ignored characters (dictionary, affix, inputs)
@@ -601,11 +601,11 @@ private String[] splitBySpace(LineNumberReader reader, String line, int minParts
return parts;
}
- private List<CompoundRule> parseCompoundRules(LineNumberReader reader, int num)
+ private CompoundRule[] parseCompoundRules(LineNumberReader reader, int num)
throws IOException, ParseException {
- List<CompoundRule> compoundRules = new ArrayList<>();
+ CompoundRule[] compoundRules = new CompoundRule[num];
for (int i = 0; i < num; i++) {
- compoundRules.add(new CompoundRule(singleArgument(reader, reader.readLine()), this));
+ compoundRules[i] = new CompoundRule(singleArgument(reader, reader.readLine()), this);
}
return compoundRules;
}
@@ -992,7 +992,7 @@ private int mergeDictionaries(
// if we haven't seen any custom morphological data, try to parse one
if (!hasCustomMorphData) {
int morphStart = line.indexOf(MORPH_SEPARATOR);
- if (morphStart >= 0 && morphStart < line.length()) {
+ if (morphStart >= 0) {
String data = line.substring(morphStart + 1);
hasCustomMorphData =
splitMorphData(data).stream().anyMatch(s -> !s.startsWith("ph:"));
@@ -1321,14 +1321,22 @@ private List<String> splitMorphData(String morphData) {
if (morphData.isBlank()) {
return Collections.emptyList();
}
- return Arrays.stream(morphData.split("\\s+"))
- .filter(
- s ->
- s.length() > 3
- && Character.isLetter(s.charAt(0))
- && Character.isLetter(s.charAt(1))
- && s.charAt(2) == ':')
- .collect(Collectors.toList());
+
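+    // scan for whitespace-separated tokens manually instead of String.split(),
+    // avoiding the regex and intermediate array allocations: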
+    List<String> result = null;
+ int start = 0;
+ for (int i = 0; i <= morphData.length(); i++) {
+ if (i == morphData.length() || Character.isWhitespace(morphData.charAt(i))) {
+ if (i - start > 3
+ && Character.isLetter(morphData.charAt(start))
+ && Character.isLetter(morphData.charAt(start + 1))
+ && morphData.charAt(start + 2) == ':') {
+ if (result == null) result = new ArrayList<>();
+ result.add(morphData.substring(start, i));
+ }
+ start = i + 1;
+ }
+ }
+ return result == null ? List.of() : result;
}
boolean hasFlag(IntsRef forms, char flag) {
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Hunspell.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Hunspell.java
index 1e2a1add13cd..3b58e0f4f980 100644
--- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Hunspell.java
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Hunspell.java
@@ -450,7 +450,7 @@ private boolean checkCompoundRules(
if (forms != null) {
words.add(forms);
- if (dictionary.compoundRules.stream().anyMatch(r -> r.mayMatch(words))) {
+ if (mayHaveCompoundRule(words)) {
if (checkLastCompoundPart(wordChars, offset + breakPos, length - breakPos, words)) {
return true;
}
@@ -467,6 +467,15 @@ private boolean checkCompoundRules(
return false;
}
+ private boolean mayHaveCompoundRule(List<IntsRef> words) {
+ for (CompoundRule rule : dictionary.compoundRules) {
+ if (rule.mayMatch(words)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
private boolean checkLastCompoundPart(
char[] wordChars, int start, int length, List<IntsRef> words) {
IntsRef ref = new IntsRef(new int[1], 0, 1);
@@ -475,7 +484,12 @@ private boolean checkLastCompoundPart(
Stemmer.RootProcessor stopOnMatching =
(stem, formID, morphDataId, outerPrefix, innerPrefix, outerSuffix, innerSuffix) -> {
ref.ints[0] = formID;
- return dictionary.compoundRules.stream().noneMatch(r -> r.fullyMatches(words));
+ for (CompoundRule r : dictionary.compoundRules) {
+ if (r.fullyMatches(words)) {
+ return false;
+ }
+ }
+ return true;
};
boolean found = !stemmer.doStem(wordChars, start, length, COMPOUND_RULE_END, stopOnMatching);
words.remove(words.size() - 1);
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/TrigramAutomaton.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/TrigramAutomaton.java
index dfe994ccf827..f4404e4bcf02 100644
--- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/TrigramAutomaton.java
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/TrigramAutomaton.java
@@ -79,7 +79,7 @@ private int runAutomatonOnStringChars(String s) {
}
int ngramScore(CharsRef s2) {
- countedSubstrings.clear(0, countedSubstrings.length());
+ countedSubstrings.clear();
int score1 = 0, score2 = 0, score3 = 0; // scores for substrings of length 1, 2 and 3
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java
new file mode 100644
index 000000000000..f022dd8eca67
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedReader;
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.Locale;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipInputStream;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.TermAndVector;
+
+/**
+ * Dl4jModelReader reads the file generated by the library Deeplearning4j and provides a
+ * Word2VecModel with normalized vectors.
+ *
+ * <p>Dl4j Word2Vec documentation:
+ * https://deeplearning4j.konduit.ai/v/en-1.0.0-beta7/language-processing/word2vec Example to
+ * generate a model using dl4j:
+ * https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/embeddingsfromcorpus/word2vec/Word2VecRawTextExample.java
+ *
+ * @lucene.experimental
+ */
+public class Dl4jModelReader implements Closeable {
+
+ private static final String MODEL_FILE_NAME_PREFIX = "syn0";
+
+ private final ZipInputStream word2VecModelZipFile;
+
+ public Dl4jModelReader(InputStream stream) {
+ this.word2VecModelZipFile = new ZipInputStream(new BufferedInputStream(stream));
+ }
+
+ public Word2VecModel read() throws IOException {
+
+ ZipEntry entry;
+ while ((entry = word2VecModelZipFile.getNextEntry()) != null) {
+ String fileName = entry.getName();
+ if (fileName.startsWith(MODEL_FILE_NAME_PREFIX)) {
+ BufferedReader reader =
+ new BufferedReader(new InputStreamReader(word2VecModelZipFile, StandardCharsets.UTF_8));
+
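+        // the first line of the "syn0" entry is a header: "<dictionarySize> <vectorDimension>"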
+ String header = reader.readLine();
+ String[] headerValues = header.split(" ");
+ int dictionarySize = Integer.parseInt(headerValues[0]);
+ int vectorDimension = Integer.parseInt(headerValues[1]);
+
+ Word2VecModel model = new Word2VecModel(dictionarySize, vectorDimension);
+ String line = reader.readLine();
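+        // terms may be Base64-encoded by dl4j; such lines start with a "B64:"-style prefix
+        // that decodeB64Term() strips before decoding: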
+ boolean isTermB64Encoded = false;
+ if (line != null) {
+ String[] tokens = line.split(" ");
+ isTermB64Encoded =
+ tokens[0].substring(0, 3).toLowerCase(Locale.ROOT).compareTo("b64") == 0;
+ model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded));
+ }
+ while ((line = reader.readLine()) != null) {
+ String[] tokens = line.split(" ");
+ model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded));
+ }
+ return model;
+ }
+ }
+ throw new IllegalArgumentException(
+ "Cannot read Dl4j word2vec model - '"
+ + MODEL_FILE_NAME_PREFIX
+ + "' file is missing in the zip. '"
+ + MODEL_FILE_NAME_PREFIX
+ + "' is a mandatory file containing the mapping between terms and vectors generated by the DL4j library.");
+ }
+
+ private static TermAndVector extractTermAndVector(
+ String[] tokens, int vectorDimension, boolean isTermB64Encoded) {
+ BytesRef term = isTermB64Encoded ? decodeB64Term(tokens[0]) : new BytesRef((tokens[0]));
+
+ float[] vector = new float[tokens.length - 1];
+
+ if (vectorDimension != vector.length) {
+ throw new RuntimeException(
+ String.format(
+ Locale.ROOT,
+ "Word2Vec model file corrupted. "
+ + "Declared vectors of size %d but found vector of size %d for word %s (%s)",
+ vectorDimension,
+ vector.length,
+ tokens[0],
+ term.utf8ToString()));
+ }
+
+ for (int i = 1; i < tokens.length; i++) {
+ vector[i - 1] = Float.parseFloat(tokens[i]);
+ }
+ return new TermAndVector(term, vector);
+ }
+
+ static BytesRef decodeB64Term(String term) {
+ byte[] buffer = Base64.getDecoder().decode(term.substring(4));
+ return new BytesRef(buffer, 0, buffer.length);
+ }
+
+ @Override
+ public void close() throws IOException {
+ word2VecModelZipFile.close();
+ }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/TermAndBoost.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/TermAndBoost.java
new file mode 100644
index 000000000000..03fdeecb0f20
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/TermAndBoost.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import org.apache.lucene.util.BytesRef;
+
+/** Wraps a term and boost */
+public class TermAndBoost {
+ /** the term */
+ public final BytesRef term;
+ /** the boost */
+ public final float boost;
+
+ /** Creates a new TermAndBoost */
+ public TermAndBoost(BytesRef term, float boost) {
+ this.term = BytesRef.deepCopyOf(term);
+ this.boost = boost;
+ }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java
new file mode 100644
index 000000000000..6719639b67d9
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.IOException;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefHash;
+import org.apache.lucene.util.TermAndVector;
+import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
+
+/**
+ * Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each
+ * word in dictionary
+ *
+ * @lucene.experimental
+ */
+public class Word2VecModel implements RandomAccessVectorValues {
+
+ private final int dictionarySize;
+ private final int vectorDimension;
+ private final TermAndVector[] termsAndVectors;
+ private final BytesRefHash word2Vec;
+ private int loadedCount = 0;
+
+ public Word2VecModel(int dictionarySize, int vectorDimension) {
+ this.dictionarySize = dictionarySize;
+ this.vectorDimension = vectorDimension;
+ this.termsAndVectors = new TermAndVector[dictionarySize];
+ this.word2Vec = new BytesRefHash();
+ }
+
+ private Word2VecModel(
+ int dictionarySize,
+ int vectorDimension,
+ TermAndVector[] termsAndVectors,
+ BytesRefHash word2Vec) {
+ this.dictionarySize = dictionarySize;
+ this.vectorDimension = vectorDimension;
+ this.termsAndVectors = termsAndVectors;
+ this.word2Vec = word2Vec;
+ }
+
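+  // vectors are normalized on load, so the dot product over them is equivalent to
+  // cosine similarity (Word2VecSynonymProvider searches with DOT_PRODUCT):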
+ public void addTermAndVector(TermAndVector modelEntry) {
+ modelEntry.normalizeVector();
+ this.termsAndVectors[loadedCount++] = modelEntry;
+ this.word2Vec.add(modelEntry.getTerm());
+ }
+
+ @Override
+ public float[] vectorValue(int targetOrd) {
+ return termsAndVectors[targetOrd].getVector();
+ }
+
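+  // resolve the term's ordinal via the BytesRefHash; a negative ordinal means the term is unknown: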
+ public float[] vectorValue(BytesRef term) {
+ int termOrd = this.word2Vec.find(term);
+ if (termOrd < 0) return null;
+ TermAndVector entry = this.termsAndVectors[termOrd];
+ return (entry == null) ? null : entry.getVector();
+ }
+
+ public BytesRef termValue(int targetOrd) {
+ return termsAndVectors[targetOrd].getTerm();
+ }
+
+ @Override
+ public int dimension() {
+ return vectorDimension;
+ }
+
+ @Override
+ public int size() {
+ return dictionarySize;
+ }
+
+ @Override
+ public RandomAccessVectorValues copy() throws IOException {
+ return new Word2VecModel(
+ this.dictionarySize, this.vectorDimension, this.termsAndVectors, this.word2Vec);
+ }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java
new file mode 100644
index 000000000000..a8db4c4c764a
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.List;
+import org.apache.lucene.analysis.TokenFilter;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.synonym.SynonymGraphFilter;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
+import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
+import org.apache.lucene.analysis.tokenattributes.TypeAttribute;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefBuilder;
+
+/**
+ * Applies single-token synonyms from a Word2Vec trained network to an incoming {@link TokenStream}.
+ *
+ * @lucene.experimental
+ */
+public final class Word2VecSynonymFilter extends TokenFilter {
+
+ private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
+ private final PositionIncrementAttribute posIncrementAtt =
+ addAttribute(PositionIncrementAttribute.class);
+ private final PositionLengthAttribute posLenAtt = addAttribute(PositionLengthAttribute.class);
+ private final TypeAttribute typeAtt = addAttribute(TypeAttribute.class);
+
+ private final Word2VecSynonymProvider synonymProvider;
+ private final int maxSynonymsPerTerm;
+ private final float minAcceptedSimilarity;
+ private final LinkedList<TermAndBoost> synonymBuffer = new LinkedList<>();
+ private State lastState;
+
+ /**
+ * Apply previously built synonymProvider to incoming tokens.
+ *
+ * @param input input tokenstream
+ * @param synonymProvider synonym provider
+ * @param maxSynonymsPerTerm maximum number of result returned by the synonym search
+ * @param minAcceptedSimilarity minimal value of cosine similarity between the searched vector and
+ * the retrieved ones
+ */
+ public Word2VecSynonymFilter(
+ TokenStream input,
+ Word2VecSynonymProvider synonymProvider,
+ int maxSynonymsPerTerm,
+ float minAcceptedSimilarity) {
+ super(input);
+ if (synonymProvider == null) {
+ throw new IllegalArgumentException("The SynonymProvider must be non-null");
+ }
+ this.synonymProvider = synonymProvider;
+ this.maxSynonymsPerTerm = maxSynonymsPerTerm;
+ this.minAcceptedSimilarity = minAcceptedSimilarity;
+ }
+
+ @Override
+ public boolean incrementToken() throws IOException {
+
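+    // drain synonyms buffered for the previous token first; they are emitted at the
+    // same position (position increment 0) and marked with the SYNONYM type: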
+ if (!synonymBuffer.isEmpty()) {
+ TermAndBoost synonym = synonymBuffer.pollFirst();
+ clearAttributes();
+ restoreState(this.lastState);
+ termAtt.setEmpty();
+ termAtt.append(synonym.term.utf8ToString());
+ typeAtt.setType(SynonymGraphFilter.TYPE_SYNONYM);
+ posLenAtt.setPositionLength(1);
+ posIncrementAtt.setPositionIncrement(0);
+ return true;
+ }
+
+ if (input.incrementToken()) {
+ BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
+ bytesRefBuilder.copyChars(termAtt.buffer(), 0, termAtt.length());
+ BytesRef term = bytesRefBuilder.get();
+      List<TermAndBoost> synonyms =
+ this.synonymProvider.getSynonyms(term, maxSynonymsPerTerm, minAcceptedSimilarity);
+ if (synonyms.size() > 0) {
+ this.lastState = captureState();
+ this.synonymBuffer.addAll(synonyms);
+ }
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public void reset() throws IOException {
+ super.reset();
+ synonymBuffer.clear();
+ }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java
new file mode 100644
index 000000000000..32b6288926fc
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Map;
+import org.apache.lucene.analysis.TokenFilterFactory;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProviderFactory.Word2VecSupportedFormats;
+import org.apache.lucene.util.ResourceLoader;
+import org.apache.lucene.util.ResourceLoaderAware;
+
+/**
+ * Factory for {@link Word2VecSynonymFilter}.
+ *
+ * @lucene.experimental
+ * @lucene.spi {@value #NAME}
+ */
+public class Word2VecSynonymFilterFactory extends TokenFilterFactory
+ implements ResourceLoaderAware {
+
+ /** SPI name */
+ public static final String NAME = "Word2VecSynonym";
+
+ public static final int DEFAULT_MAX_SYNONYMS_PER_TERM = 5;
+ public static final float DEFAULT_MIN_ACCEPTED_SIMILARITY = 0.8f;
+
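+  // Hypothetical factory arguments, matching the keys parsed in the constructor below:
+  //   model="word2vec-model.zip" format="dl4j" maxSynonymsPerTerm="10" minAcceptedSimilarity="0.9"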
+ private final int maxSynonymsPerTerm;
+ private final float minAcceptedSimilarity;
+ private final Word2VecSupportedFormats format;
+ private final String word2vecModelFileName;
+
+ private Word2VecSynonymProvider synonymProvider;
+
+ public Word2VecSynonymFilterFactory(Map<String, String> args) {
+ super(args);
+ this.maxSynonymsPerTerm = getInt(args, "maxSynonymsPerTerm", DEFAULT_MAX_SYNONYMS_PER_TERM);
+ this.minAcceptedSimilarity =
+ getFloat(args, "minAcceptedSimilarity", DEFAULT_MIN_ACCEPTED_SIMILARITY);
+ this.word2vecModelFileName = require(args, "model");
+
+ String modelFormat = get(args, "format", "dl4j").toUpperCase(Locale.ROOT);
+ try {
+ this.format = Word2VecSupportedFormats.valueOf(modelFormat);
+ } catch (IllegalArgumentException exc) {
+ throw new IllegalArgumentException("Model format '" + modelFormat + "' not supported", exc);
+ }
+
+ if (!args.isEmpty()) {
+ throw new IllegalArgumentException("Unknown parameters: " + args);
+ }
+ if (minAcceptedSimilarity <= 0 || minAcceptedSimilarity > 1) {
+ throw new IllegalArgumentException(
+ "minAcceptedSimilarity must be in the range (0, 1]. Found: " + minAcceptedSimilarity);
+ }
+ if (maxSynonymsPerTerm <= 0) {
+ throw new IllegalArgumentException(
+ "maxSynonymsPerTerm must be a positive integer greater than 0. Found: "
+ + maxSynonymsPerTerm);
+ }
+ }
+
+ /** Default ctor for compatibility with SPI */
+ public Word2VecSynonymFilterFactory() {
+ throw defaultCtorException();
+ }
+
+ Word2VecSynonymProvider getSynonymProvider() {
+ return this.synonymProvider;
+ }
+
+ @Override
+ public TokenStream create(TokenStream input) {
+ return synonymProvider == null
+ ? input
+ : new Word2VecSynonymFilter(
+ input, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity);
+ }
+
+ @Override
+ public void inform(ResourceLoader loader) throws IOException {
+ this.synonymProvider =
+ Word2VecSynonymProviderFactory.getSynonymProvider(loader, word2vecModelFileName, format);
+ }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
new file mode 100644
index 000000000000..3089f1587a4e
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_BEAM_WIDTH;
+import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.List;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.hnsw.HnswGraphBuilder;
+import org.apache.lucene.util.hnsw.HnswGraphSearcher;
+import org.apache.lucene.util.hnsw.NeighborQueue;
+import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
+
+/**
+ * The Word2VecSynonymProvider generates the list of synonyms of a term.
+ *
+ * @lucene.experimental
+ */
+public class Word2VecSynonymProvider {
+
+ private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
+ VectorSimilarityFunction.DOT_PRODUCT;
+ private static final VectorEncoding VECTOR_ENCODING = VectorEncoding.FLOAT32;
+ private final Word2VecModel word2VecModel;
+ private final OnHeapHnswGraph hnswGraph;
+
+ /**
+ * Word2VecSynonymProvider constructor
+ *
+ * @param model the Word2VecModel containing the set of TermAndVector entries
+ */
+ public Word2VecSynonymProvider(Word2VecModel model) throws IOException {
+ word2VecModel = model;
+
+ HnswGraphBuilder builder =
+ HnswGraphBuilder.create(
+ word2VecModel,
+ VECTOR_ENCODING,
+ SIMILARITY_FUNCTION,
+ DEFAULT_MAX_CONN,
+ DEFAULT_BEAM_WIDTH,
+ HnswGraphBuilder.randSeed);
+ this.hnswGraph = builder.build(word2VecModel.copy());
+ }
+
+ public List<TermAndBoost> getSynonyms(
+ BytesRef term, int maxSynonymsPerTerm, float minAcceptedSimilarity) throws IOException {
+
+ if (term == null) {
+ throw new IllegalArgumentException("Term must not be null");
+ }
+
+ LinkedList<TermAndBoost> result = new LinkedList<>();
+ float[] query = word2VecModel.vectorValue(term);
+ if (query != null) {
+ NeighborQueue synonyms =
+ HnswGraphSearcher.search(
+ query,
+ // The query vector is in the model and is always the nearest
+ // neighbour of itself, so we search for the top k+1
+ maxSynonymsPerTerm + 1,
+ word2VecModel,
+ VECTOR_ENCODING,
+ SIMILARITY_FUNCTION,
+ hnswGraph,
+ null,
+ Integer.MAX_VALUE);
+
+ int size = synonyms.size();
+ for (int i = 0; i < size; i++) {
+ float similarity = synonyms.topScore();
+ int id = synonyms.pop();
+
+ BytesRef synonym = word2VecModel.termValue(id);
+ // Skip the original query term and any synonym below the similarity threshold
+ if (!synonym.equals(term) && similarity >= minAcceptedSimilarity) {
+ result.addFirst(new TermAndBoost(synonym, similarity));
+ }
+ }
+ }
+ return result;
+ }
+}
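A minimal sketch of the provider in action, mirroring the tests added later in this patch:
build a small in-memory model, then ask for the neighbours of a term above a similarity
threshold (the query term itself is filtered out, as shown above):

// Assumes the types of this package plus org.apache.lucene.util.TermAndVector.
Word2VecModel model = new Word2VecModel(3, 2); // 3 terms, 2 dimensions
model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-1, 10}));

Word2VecSynonymProvider provider = new Word2VecSynonymProvider(model);
List<TermAndBoost> synonyms = provider.getSynonyms(new BytesRef("a"), 10, 0.9f);
for (TermAndBoost candidate : synonyms) {
  // the boost is the dot-product similarity of the normalized vectors
  System.out.println(candidate.term.utf8ToString() + " -> " + candidate.boost);
}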
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java
new file mode 100644
index 000000000000..ea849e653cd6
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.lucene.util.ResourceLoader;
+
+/**
+ * Supplies a cache of Word2VecSynonymProvider instances, so that multiple instances of
+ * Word2VecSynonymFilterFactory backed by the same model file share a single SynonymProvider.
+ * Assumes SynonymProvider implementations are thread-safe.
+ */
+public class Word2VecSynonymProviderFactory {
+
+ enum Word2VecSupportedFormats {
+ DL4J
+ }
+
+ private static Map<String, Word2VecSynonymProvider> word2vecSynonymProviders =
+ new ConcurrentHashMap<>();
+
+ public static Word2VecSynonymProvider getSynonymProvider(
+ ResourceLoader loader, String modelFileName, Word2VecSupportedFormats format)
+ throws IOException {
+ Word2VecSynonymProvider synonymProvider = word2vecSynonymProviders.get(modelFileName);
+ if (synonymProvider == null) {
+ try (InputStream stream = loader.openResource(modelFileName)) {
+ try (Dl4jModelReader reader = getModelReader(format, stream)) {
+ synonymProvider = new Word2VecSynonymProvider(reader.read());
+ }
+ }
+ word2vecSynonymProviders.put(modelFileName, synonymProvider);
+ }
+ return synonymProvider;
+ }
+
+ private static Dl4jModelReader getModelReader(
+ Word2VecSupportedFormats format, InputStream stream) {
+ switch (format) {
+ case DL4J:
+ return new Dl4jModelReader(stream);
+ }
+ return null;
+ }
+}
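The caching contract in a nutshell, as a sketch (the model resource name is hypothetical;
the formats enum is package-private, hence the qualified reference):

ResourceLoader loader = new ClasspathResourceLoader(Word2VecSynonymProviderFactory.class);
Word2VecSynonymProvider first =
    Word2VecSynonymProviderFactory.getSynonymProvider(
        loader, "word2vec-model.zip", Word2VecSynonymProviderFactory.Word2VecSupportedFormats.DL4J);
Word2VecSynonymProvider second =
    Word2VecSynonymProviderFactory.getSynonymProvider(
        loader, "word2vec-model.zip", Word2VecSynonymProviderFactory.Word2VecSupportedFormats.DL4J);
// The second lookup hits the ConcurrentHashMap: the provider (and its HNSW graph)
// is built only once per model file name.
assert first == second;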
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java
new file mode 100644
index 000000000000..e8d69ab3cf9b
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Analysis components for Synonyms using Word2Vec model. */
+package org.apache.lucene.analysis.synonym.word2vec;
diff --git a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory
index 19a34b7840a8..1e4e17eaeadf 100644
--- a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory
+++ b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory
@@ -118,6 +118,7 @@ org.apache.lucene.analysis.sv.SwedishLightStemFilterFactory
org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory
org.apache.lucene.analysis.synonym.SynonymFilterFactory
org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory
+org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory
org.apache.lucene.analysis.core.FlattenGraphFilterFactory
org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory
org.apache.lucene.analysis.te.TeluguStemFilterFactory
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestFlattenGraphFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestFlattenGraphFilter.java
index 7b35f56016cb..7fa901ac5bf4 100644
--- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestFlattenGraphFilter.java
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestFlattenGraphFilter.java
@@ -40,8 +40,8 @@
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.CharsRefBuilder;
+import org.apache.lucene.util.automaton.Automata;
import org.apache.lucene.util.automaton.Automaton;
-import org.apache.lucene.util.automaton.DaciukMihovAutomatonBuilder;
import org.apache.lucene.util.automaton.Operations;
import org.apache.lucene.util.automaton.Transition;
@@ -780,7 +780,7 @@ public void testPathsNotLost() throws IOException {
acceptStrings.sort(Comparator.naturalOrder());
acceptStrings = acceptStrings.stream().limit(wordCount).collect(Collectors.toList());
- Automaton nonFlattenedAutomaton = DaciukMihovAutomatonBuilder.build(acceptStrings);
+ Automaton nonFlattenedAutomaton = Automata.makeStringUnion(acceptStrings);
TokenStream ts = AutomatonToTokenStream.toTokenStream(nonFlattenedAutomaton);
TokenStream flattenedTokenStream = new FlattenGraphFilter(ts);
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/hunspell/TestHunspell.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/hunspell/TestHunspell.java
index e160cd7851ce..7bee301837f8 100644
--- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/hunspell/TestHunspell.java
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/hunspell/TestHunspell.java
@@ -31,6 +31,7 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CancellationException;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
@@ -74,7 +75,8 @@ public void testCustomCheckCanceledGivesPartialResult() throws Exception {
};
Hunspell hunspell = new Hunspell(dictionary, RETURN_PARTIAL_RESULT, checkCanceled);
- assertEquals(expected, hunspell.suggest("apac"));
+ // pass a long timeout so that slower CI servers are more predictable.
+ assertEquals(expected, hunspell.suggest("apac", TimeUnit.DAYS.toMillis(1)));
counter.set(0);
var e =
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java
new file mode 100644
index 000000000000..213dcdaccd33
--- /dev/null
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.BytesRef;
+import org.junit.Test;
+
+public class TestDl4jModelReader extends LuceneTestCase {
+
+ private static final String MODEL_FILE = "word2vec-model.zip";
+ private static final String MODEL_EMPTY_FILE = "word2vec-empty-model.zip";
+ private static final String CORRUPTED_VECTOR_DIMENSION_MODEL_FILE =
+ "word2vec-corrupted-vector-dimension-model.zip";
+
+ InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_FILE);
+ Dl4jModelReader unit = new Dl4jModelReader(stream);
+
+ @Test
+ public void read_zipFileWithMetadata_shouldReturnDictionarySize() throws Exception {
+ Word2VecModel model = unit.read();
+ long expectedDictionarySize = 235;
+ assertEquals(expectedDictionarySize, model.size());
+ }
+
+ @Test
+ public void read_zipFileWithMetadata_shouldReturnVectorLength() throws Exception {
+ Word2VecModel model = unit.read();
+ int expectedVectorDimension = 100;
+ assertEquals(expectedVectorDimension, model.dimension());
+ }
+
+ @Test
+ public void read_zipFile_shouldReturnDecodedTerm() throws Exception {
+ Word2VecModel model = unit.read();
+ BytesRef expectedDecodedFirstTerm = new BytesRef("it");
+ assertEquals(expectedDecodedFirstTerm, model.termValue(0));
+ }
+
+ @Test
+ public void decodeTerm_encodedTerm_shouldReturnDecodedTerm() throws Exception {
+ byte[] originalInput = "lucene".getBytes(StandardCharsets.UTF_8);
+ String B64encodedLuceneTerm = Base64.getEncoder().encodeToString(originalInput);
+ String word2vecEncodedLuceneTerm = "B64:" + B64encodedLuceneTerm;
+ assertEquals(new BytesRef("lucene"), Dl4jModelReader.decodeB64Term(word2vecEncodedLuceneTerm));
+ }
+
+ @Test
+ public void read_EmptyZipFile_shouldThrowException() throws Exception {
+ try (InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_EMPTY_FILE)) {
+ Dl4jModelReader unit = new Dl4jModelReader(stream);
+ expectThrows(IllegalArgumentException.class, unit::read);
+ }
+ }
+
+ @Test
+ public void read_corruptedVectorDimensionModelFile_shouldThrowException() throws Exception {
+ try (InputStream stream =
+ TestDl4jModelReader.class.getResourceAsStream(CORRUPTED_VECTOR_DIMENSION_MODEL_FILE)) {
+ Dl4jModelReader unit = new Dl4jModelReader(stream);
+ expectThrows(RuntimeException.class, unit::read);
+ }
+ }
+}
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java
new file mode 100644
index 000000000000..3999931dd758
--- /dev/null
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.Tokenizer;
+import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
+import org.apache.lucene.tests.analysis.MockTokenizer;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.TermAndVector;
+import org.junit.Test;
+
+public class TestWord2VecSynonymFilter extends BaseTokenStreamTestCase {
+
+ @Test
+ public void synonymExpansion_oneCandidate_shouldBeExpandedWithinThreshold() throws Exception {
+ int maxSynonymPerTerm = 10;
+ float minAcceptedSimilarity = 0.9f;
+ Word2VecModel model = new Word2VecModel(6, 2);
+ model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));
+
+ Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
+
+ Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity);
+ assertAnalyzesTo(
+ a,
+ "pre a post", // input
+ new String[] {"pre", "a", "d", "e", "c", "b", "post"}, // output
+ new int[] {0, 4, 4, 4, 4, 4, 6}, // start offset
+ new int[] {3, 5, 5, 5, 5, 5, 10}, // end offset
+ new String[] {"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word"}, // types
+ new int[] {1, 1, 0, 0, 0, 0, 1}, // posIncrements
+ new int[] {1, 1, 1, 1, 1, 1, 1}); // posLengths
+ a.close();
+ }
+
+ @Test
+ public void synonymExpansion_oneCandidate_shouldBeExpandedWithTopKSynonyms() throws Exception {
+ int maxSynonymPerTerm = 2;
+ float minAcceptedSimilarity = 0.9f;
+ Word2VecModel model = new Word2VecModel(5, 2);
+ model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
+
+ Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
+
+ Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity);
+ assertAnalyzesTo(
+ a,
+ "pre a post", // input
+ new String[] {"pre", "a", "d", "e", "post"}, // output
+ new int[] {0, 4, 4, 4, 6}, // start offset
+ new int[] {3, 5, 5, 5, 10}, // end offset
+ new String[] {"word", "word", "SYNONYM", "SYNONYM", "word"}, // types
+ new int[] {1, 1, 0, 0, 1}, // posIncrements
+ new int[] {1, 1, 1, 1, 1}); // posLengths
+ a.close();
+ }
+
+ @Test
+ public void synonymExpansion_twoCandidates_shouldBothBeExpanded() throws Exception {
+ Word2VecModel model = new Word2VecModel(8, 2);
+ model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {1, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("post"), new float[] {-10, -11}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("after"), new float[] {-8, -10}));
+
+ Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
+
+ Analyzer a = getAnalyzer(synonymProvider, 10, 0.9f);
+ assertAnalyzesTo(
+ a,
+ "pre a post", // input
+ new String[] {"pre", "a", "d", "e", "c", "b", "post", "after"}, // output
+ new int[] {0, 4, 4, 4, 4, 4, 6, 6}, // start offset
+ new int[] {3, 5, 5, 5, 5, 5, 10, 10}, // end offset
+ new String[] { // types
+ "word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word", "SYNONYM"
+ },
+ new int[] {1, 1, 0, 0, 0, 0, 1, 0}, // posIncrements
+ new int[] {1, 1, 1, 1, 1, 1, 1, 1}); // posLengths
+ a.close();
+ }
+
+ @Test
+ public void synonymExpansion_forMinAcceptedSimilarity_shouldExpandToNoneSynonyms()
+ throws Exception {
+ Word2VecModel model = new Word2VecModel(4, 2);
+ model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, -10}));
+
+ Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model);
+
+ Analyzer a = getAnalyzer(synonymProvider, 10, 0.8f);
+ assertAnalyzesTo(
+ a,
+ "pre a post", // input
+ new String[] {"pre", "a", "post"}, // output
+ new int[] {0, 4, 6}, // start offset
+ new int[] {3, 5, 10}, // end offset
+ new String[] {"word", "word", "word"}, // types
+ new int[] {1, 1, 1}, // posIncrements
+ new int[] {1, 1, 1}); // posLengths
+ a.close();
+ }
+
+ private Analyzer getAnalyzer(
+ Word2VecSynonymProvider synonymProvider,
+ int maxSynonymsPerTerm,
+ float minAcceptedSimilarity) {
+ return new Analyzer() {
+ @Override
+ protected TokenStreamComponents createComponents(String fieldName) {
+ Tokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false);
+ // Make a local variable so testRandomHuge doesn't share it across threads!
+ Word2VecSynonymFilter synFilter =
+ new Word2VecSynonymFilter(
+ tokenizer, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity);
+ return new TokenStreamComponents(tokenizer, synFilter);
+ }
+ };
+ }
+}
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java
new file mode 100644
index 000000000000..007fedf4abed
--- /dev/null
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import org.apache.lucene.tests.analysis.BaseTokenStreamFactoryTestCase;
+import org.apache.lucene.util.ClasspathResourceLoader;
+import org.apache.lucene.util.ResourceLoader;
+import org.junit.Test;
+
+public class TestWord2VecSynonymFilterFactory extends BaseTokenStreamFactoryTestCase {
+
+ public static final String FACTORY_NAME = "Word2VecSynonym";
+ private static final String WORD2VEC_MODEL_FILE = "word2vec-model.zip";
+
+ @Test
+ public void testInform() throws Exception {
+ ResourceLoader loader = new ClasspathResourceLoader(getClass());
+ assertTrue("loader is null and it shouldn't be", loader != null);
+ Word2VecSynonymFilterFactory factory =
+ (Word2VecSynonymFilterFactory)
+ tokenFilterFactory(
+ FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "minAcceptedSimilarity", "0.7");
+
+ Word2VecSynonymProvider synonymProvider = factory.getSynonymProvider();
+ assertNotEquals(null, synonymProvider);
+ }
+
+ @Test
+ public void missingRequiredArgument_shouldThrowException() throws Exception {
+ IllegalArgumentException expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME,
+ "format",
+ "dl4j",
+ "minAcceptedSimilarity",
+ "0.7",
+ "maxSynonymsPerTerm",
+ "10");
+ });
+ assertTrue(expected.getMessage().contains("Configuration Error: missing parameter 'model'"));
+ }
+
+ @Test
+ public void unsupportedModelFormat_shouldThrowException() throws Exception {
+ IllegalArgumentException expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "format", "bogusValue");
+ });
+ assertTrue(expected.getMessage().contains("Model format 'BOGUSVALUE' not supported"));
+ }
+
+ @Test
+ public void bogusArgument_shouldThrowException() throws Exception {
+ IllegalArgumentException expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "bogusArg", "bogusValue");
+ });
+ assertTrue(expected.getMessage().contains("Unknown parameters"));
+ }
+
+ @Test
+ public void illegalArguments_shouldThrowException() throws Exception {
+ IllegalArgumentException expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME,
+ "model",
+ WORD2VEC_MODEL_FILE,
+ "minAcceptedSimilarity",
+ "2",
+ "maxSynonymsPerTerm",
+ "10");
+ });
+ assertTrue(
+ expected
+ .getMessage()
+ .contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 2"));
+
+ expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME,
+ "model",
+ WORD2VEC_MODEL_FILE,
+ "minAcceptedSimilarity",
+ "0",
+ "maxSynonymsPerTerm",
+ "10");
+ });
+ assertTrue(
+ expected
+ .getMessage()
+ .contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 0"));
+
+ expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME,
+ "model",
+ WORD2VEC_MODEL_FILE,
+ "minAcceptedSimilarity",
+ "0.7",
+ "maxSynonymsPerTerm",
+ "-1");
+ });
+ assertTrue(
+ expected
+ .getMessage()
+ .contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: -1"));
+
+ expected =
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> {
+ tokenFilterFactory(
+ FACTORY_NAME,
+ "model",
+ WORD2VEC_MODEL_FILE,
+ "minAcceptedSimilarity",
+ "0.7",
+ "maxSynonymsPerTerm",
+ "0");
+ });
+ assertTrue(
+ expected
+ .getMessage()
+ .contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: 0"));
+ }
+}
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java
new file mode 100644
index 000000000000..3e7e6bce07a3
--- /dev/null
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.analysis.synonym.word2vec;
+
+import java.io.IOException;
+import java.util.List;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.TermAndVector;
+import org.junit.Test;
+
+public class TestWord2VecSynonymProvider extends LuceneTestCase {
+
+ private static final int MAX_SYNONYMS_PER_TERM = 10;
+ private static final float MIN_ACCEPTED_SIMILARITY = 0.85f;
+
+ private final Word2VecSynonymProvider unit;
+
+ public TestWord2VecSynonymProvider() throws IOException {
+ Word2VecModel model = new Word2VecModel(2, 3);
+ model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {0.24f, 0.78f, 0.28f}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {0.44f, 0.01f, 0.81f}));
+ unit = new Word2VecSynonymProvider(model);
+ }
+
+ @Test
+ public void getSynonyms_nullToken_shouldThrowException() {
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> unit.getSynonyms(null, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY));
+ }
+
+ @Test
+ public void getSynonyms_shouldReturnSynonymsBasedOnMinAcceptedSimilarity() throws Exception {
+ Word2VecModel model = new Word2VecModel(6, 2);
+ model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));
+
+ Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model);
+
+ BytesRef inputTerm = new BytesRef("a");
+ String[] expectedSynonyms = {"d", "e", "c", "b"};
+ List<TermAndBoost> actualSynonymsResults =
+ unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);
+
+ assertEquals(4, actualSynonymsResults.size());
+ for (int i = 0; i < expectedSynonyms.length; i++) {
+ assertEquals(new BytesRef(expectedSynonyms[i]), actualSynonymsResults.get(i).term);
+ }
+ }
+
+ @Test
+ public void getSynonyms_shouldReturnSynonymsBoost() throws Exception {
+ Word2VecModel model = new Word2VecModel(3, 2);
+ model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {1, 1}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {99, 101}));
+
+ Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model);
+
+ BytesRef inputTerm = new BytesRef("a");
+ List<TermAndBoost> actualSynonymsResults =
+ unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);
+
+ BytesRef expectedFirstSynonymTerm = new BytesRef("b");
+ double expectedFirstSynonymBoost = 1.0;
+ assertEquals(expectedFirstSynonymTerm, actualSynonymsResults.get(0).term);
+ assertEquals(expectedFirstSynonymBoost, actualSynonymsResults.get(0).boost, 0.001f);
+ }
+
+ @Test
+ public void noSynonymsWithinAcceptedSimilarity_shouldReturnNoSynonyms() throws Exception {
+ Word2VecModel model = new Word2VecModel(4, 2);
+ model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {6, -6}));
+
+ Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model);
+
+ BytesRef inputTerm = newBytesRef("a");
+ List<TermAndBoost> actualSynonymsResults =
+ unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY);
+ assertEquals(0, actualSynonymsResults.size());
+ }
+
+ @Test
+ public void testModel_shouldReturnNormalizedVectors() {
+ Word2VecModel model = new Word2VecModel(4, 2);
+ model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10}));
+ model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10}));
+
+ float[] vectorIdA = model.vectorValue(new BytesRef("a"));
+ float[] vectorIdF = model.vectorValue(new BytesRef("f"));
+ assertArrayEquals(new float[] {0.70710f, 0.70710f}, vectorIdA, 0.001f);
+ assertArrayEquals(new float[] {-0.0995f, 0.99503f}, vectorIdF, 0.001f);
+ }
+
+ @Test
+ public void normalizedVector_shouldReturnModule1() {
+ TermAndVector synonymTerm = new TermAndVector(new BytesRef("a"), new float[] {10, 10});
+ synonymTerm.normalizeVector();
+ float[] vector = synonymTerm.getVector();
+ float len = 0;
+ for (int i = 0; i < vector.length; i++) {
+ len += vector[i] * vector[i];
+ }
+ assertEquals(1, Math.sqrt(len), 0.0001f);
+ }
+}
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip
new file mode 100644
index 000000000000..e25693dd83cf
Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip differ
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip
new file mode 100644
index 000000000000..57d7832dd787
Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip differ
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip
new file mode 100644
index 000000000000..6d31b8d5a3fa
Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip differ
diff --git a/lucene/analysis/smartcn/src/resources/org/apache/lucene/analysis/cn/smart/stopwords.txt b/lucene/analysis/smartcn/src/resources/org/apache/lucene/analysis/cn/smart/stopwords.txt
index fb0d71ad7d2c..65bcfd4e1b65 100644
--- a/lucene/analysis/smartcn/src/resources/org/apache/lucene/analysis/cn/smart/stopwords.txt
+++ b/lucene/analysis/smartcn/src/resources/org/apache/lucene/analysis/cn/smart/stopwords.txt
@@ -53,7 +53,5 @@ $
●
// the line below contains an IDEOGRAPHIC SPACE character (Used as a space in Chinese)
-
//////////////// English Stop Words ////////////////
-
//////////////// Chinese Stop Words ////////////////
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java
index fc82ce588860..639bdbd73339 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java
@@ -219,7 +219,7 @@ static short writeBitSet(DocIdSetIterator it, IndexOutput out, byte denseRankPow
// Flush block
flush(prevBlock, buffer, blockCardinality, denseRankPower, out);
// Reset for next block
- buffer.clear(0, buffer.length());
+ buffer.clear();
totalCardinality += blockCardinality;
blockCardinality = 0;
}
@@ -233,7 +233,7 @@ static short writeBitSet(DocIdSetIterator it, IndexOutput out, byte denseRankPow
jumps, out.getFilePointer() - origo, totalCardinality, jumpBlockIndex, prevBlock + 1);
totalCardinality += blockCardinality;
flush(prevBlock, buffer, blockCardinality, denseRankPower, out);
- buffer.clear(0, buffer.length());
+ buffer.clear();
prevBlock++;
}
final int lastBlock =
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
index e7f16b4f3fc9..4b1f7068a5f2 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java
@@ -183,10 +183,10 @@ private void addDiverseNeighbors(int node, NeighborQueue candidates) throws IOEx
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node()[i];
- Lucene90NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
- nbrNbr.add(node, neighbors.score()[i]);
- if (nbrNbr.size() > maxConn) {
- diversityUpdate(nbrNbr);
+ Lucene90NeighborArray nbrsOfNbr = hnsw.getNeighbors(nbr);
+ nbrsOfNbr.add(node, neighbors.score()[i]);
+ if (nbrsOfNbr.size() > maxConn) {
+ diversityUpdate(nbrsOfNbr);
}
}
}
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
index b0e9d160457b..c82920181cc4 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java
@@ -204,10 +204,10 @@ private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
- Lucene91NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
- nbrNbr.add(node, neighbors.score[i]);
- if (nbrNbr.size() > maxConn) {
- diversityUpdate(nbrNbr);
+ Lucene91NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
+ nbrsOfNbr.add(node, neighbors.score[i]);
+ if (nbrsOfNbr.size() > maxConn) {
+ diversityUpdate(nbrsOfNbr);
}
}
}
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBackwardsCompatibility.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBackwardsCompatibility.java
index d030777e3783..40541f505bea 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBackwardsCompatibility.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBackwardsCompatibility.java
@@ -419,7 +419,9 @@ public void testCreateEmptyIndex() throws Exception {
"9.4.2-cfs",
"9.4.2-nocfs",
"9.5.0-cfs",
- "9.5.0-nocfs"
+ "9.5.0-nocfs",
+ "9.6.0-cfs",
+ "9.6.0-nocfs"
};
public static String[] getOldNames() {
@@ -459,7 +461,8 @@ public static String[] getOldNames() {
"sorted.9.4.0",
"sorted.9.4.1",
"sorted.9.4.2",
- "sorted.9.5.0"
+ "sorted.9.5.0",
+ "sorted.9.6.0"
};
public static String[] getOldSortedNames() {
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-cfs.zip b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-cfs.zip
new file mode 100644
index 000000000000..bc40e99b3d77
Binary files /dev/null and b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-cfs.zip differ
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-nocfs.zip b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-nocfs.zip
new file mode 100644
index 000000000000..8e97a46c1485
Binary files /dev/null and b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-nocfs.zip differ
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/sorted.9.6.0.zip b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/sorted.9.6.0.zip
new file mode 100644
index 000000000000..f3af2e4011ee
Binary files /dev/null and b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/sorted.9.6.0.zip differ
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java
index c020d7487cb3..12d11e59b9d7 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java
@@ -35,6 +35,7 @@
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.FuzzyTermsEnum;
+import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.TermQuery;
@@ -214,7 +215,8 @@ private Query newTermQuery(IndexReader reader, Term term) throws IOException {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ IndexReader reader = indexSearcher.getIndexReader();
ScoreTermQueue q = new ScoreTermQueue(MAX_NUM_TERMS);
// load up the list of possible terms
for (FieldVals f : fieldVals) {
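This is the first of many occurrences of the same mechanical migration in this patch:
Query#rewrite now takes an IndexSearcher instead of an IndexReader. A minimal sketch of
the pattern for a hypothetical custom query:

import java.io.IOException;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;

abstract class MyCustomQuery extends Query { // hypothetical subclass
  @Override
  public Query rewrite(IndexSearcher indexSearcher) throws IOException {
    // queries that still need reader statistics fetch the reader from the searcher
    IndexReader reader = indexSearcher.getIndexReader();
    // ... same reader-based rewriting logic as the old rewrite(IndexReader) ...
    return super.rewrite(indexSearcher);
  }
}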
diff --git a/lucene/core/src/generated/jdk/README.md b/lucene/core/src/generated/jdk/README.md
index 371bbebf8518..48a014b992b8 100644
--- a/lucene/core/src/generated/jdk/README.md
+++ b/lucene/core/src/generated/jdk/README.md
@@ -40,4 +40,4 @@ to point the Lucene build system to missing JDK versions. The regeneration task
a warning if a specific JDK is missing, leaving the already existing `.apijar` file
untouched.
-The extraction is done with the ASM library, see `ExtractForeignAPI.java` source code.
+The extraction is done with the ASM library, see `ExtractJdkApis.java` source code.
diff --git a/lucene/core/src/generated/jdk/jdk19.apijar b/lucene/core/src/generated/jdk/jdk19.apijar
new file mode 100644
index 000000000000..4a04f1440e4a
Binary files /dev/null and b/lucene/core/src/generated/jdk/jdk19.apijar differ
diff --git a/lucene/core/src/generated/jdk/jdk20.apijar b/lucene/core/src/generated/jdk/jdk20.apijar
new file mode 100644
index 000000000000..942ddef057b7
Binary files /dev/null and b/lucene/core/src/generated/jdk/jdk20.apijar differ
diff --git a/lucene/core/src/generated/jdk/jdk21.apijar b/lucene/core/src/generated/jdk/jdk21.apijar
new file mode 100644
index 000000000000..3ded0aaaed4e
Binary files /dev/null and b/lucene/core/src/generated/jdk/jdk21.apijar differ
diff --git a/lucene/core/src/generated/jdk/panama-foreign-jdk19.apijar b/lucene/core/src/generated/jdk/panama-foreign-jdk19.apijar
deleted file mode 100644
index c9b73d9193bf..000000000000
Binary files a/lucene/core/src/generated/jdk/panama-foreign-jdk19.apijar and /dev/null differ
diff --git a/lucene/core/src/generated/jdk/panama-foreign-jdk20.apijar b/lucene/core/src/generated/jdk/panama-foreign-jdk20.apijar
deleted file mode 100644
index 03baf38a1938..000000000000
Binary files a/lucene/core/src/generated/jdk/panama-foreign-jdk20.apijar and /dev/null differ
diff --git a/lucene/core/src/java/org/apache/lucene/analysis/WordlistLoader.java b/lucene/core/src/java/org/apache/lucene/analysis/WordlistLoader.java
index 30ada92eb39b..8e18f4ad76d3 100644
--- a/lucene/core/src/java/org/apache/lucene/analysis/WordlistLoader.java
+++ b/lucene/core/src/java/org/apache/lucene/analysis/WordlistLoader.java
@@ -40,9 +40,9 @@ public class WordlistLoader {
private WordlistLoader() {}
/**
- * Reads lines from a Reader and adds every line as an entry to a CharArraySet (omitting leading
- * and trailing whitespace). Every line of the Reader should contain only one word. The words need
- * to be in lowercase if you make use of an Analyzer which uses LowerCaseFilter (like
+ * Reads lines from a Reader and adds every non-blank line as an entry to a CharArraySet (omitting
+ * leading and trailing whitespace). Every line of the Reader should contain only one word. The
+ * words need to be in lowercase if you make use of an Analyzer which uses LowerCaseFilter (like
* StandardAnalyzer).
*
* @param reader Reader containing the wordlist
@@ -53,7 +53,10 @@ public static CharArraySet getWordSet(Reader reader, CharArraySet result) throws
try (BufferedReader br = getBufferedReader(reader)) {
String word = null;
while ((word = br.readLine()) != null) {
- result.add(word.trim());
+ word = word.trim();
+ // skip blank lines
+ if (word.isEmpty()) continue;
+ result.add(word);
}
}
return result;
@@ -101,10 +104,10 @@ public static CharArraySet getWordSet(InputStream stream, Charset charset) throw
}
/**
- * Reads lines from a Reader and adds every non-comment line as an entry to a CharArraySet
- * (omitting leading and trailing whitespace). Every line of the Reader should contain only one
- * word. The words need to be in lowercase if you make use of an Analyzer which uses
- * LowerCaseFilter (like StandardAnalyzer).
+ * Reads lines from a Reader and adds every non-blank non-comment line as an entry to a
+ * CharArraySet (omitting leading and trailing whitespace). Every line of the Reader should
+ * contain only one word. The words need to be in lowercase if you make use of an Analyzer which
+ * uses LowerCaseFilter (like StandardAnalyzer).
*
* @param reader Reader containing the wordlist
* @param comment The string representing a comment.
@@ -117,7 +120,10 @@ public static CharArraySet getWordSet(Reader reader, String comment, CharArraySe
String word = null;
while ((word = br.readLine()) != null) {
if (word.startsWith(comment) == false) {
- result.add(word.trim());
+ word = word.trim();
+ // skip blank lines
+ if (word.isEmpty()) continue;
+ result.add(word);
}
}
}
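A small sketch of the behavioural change, assuming an in-memory word list: blank and
whitespace-only lines used to end up as empty entries in the set and are now skipped.

import java.io.IOException;
import java.io.StringReader;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.analysis.WordlistLoader;

class WordlistExample {
  static CharArraySet load() throws IOException {
    // the blank and whitespace-only lines no longer add "" to the set;
    // the result contains exactly "foo" and "bar"
    return WordlistLoader.getWordSet(
        new StringReader("foo\n\n   \nbar\n"), new CharArraySet(4, true));
  }
}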
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java
index 8da289e3ad35..512ab4b1e556 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java
@@ -110,12 +110,12 @@ public final class IndexedDISI extends DocIdSetIterator {
private static void flush(
int block, FixedBitSet buffer, int cardinality, byte denseRankPower, IndexOutput out)
throws IOException {
- assert block >= 0 && block < 65536;
+ assert block >= 0 && block < BLOCK_SIZE;
out.writeShort((short) block);
- assert cardinality > 0 && cardinality <= 65536;
+ assert cardinality > 0 && cardinality <= BLOCK_SIZE;
out.writeShort((short) (cardinality - 1));
if (cardinality > MAX_ARRAY_LENGTH) {
- if (cardinality != 65536) { // all docs are set
+ if (cardinality != BLOCK_SIZE) { // all docs are set
if (denseRankPower != -1) {
final byte[] rank = createRank(buffer, denseRankPower);
out.writeBytes(rank, rank.length);
@@ -220,7 +220,7 @@ public static short writeBitSet(DocIdSetIterator it, IndexOutput out, byte dense
// Flush block
flush(prevBlock, buffer, blockCardinality, denseRankPower, out);
// Reset for next block
- buffer.clear(0, buffer.length());
+ buffer.clear();
totalCardinality += blockCardinality;
blockCardinality = 0;
}
@@ -234,7 +234,7 @@ public static short writeBitSet(DocIdSetIterator it, IndexOutput out, byte dense
jumps, out.getFilePointer() - origo, totalCardinality, jumpBlockIndex, prevBlock + 1);
totalCardinality += blockCardinality;
flush(prevBlock, buffer, blockCardinality, denseRankPower, out);
- buffer.clear(0, buffer.length());
+ buffer.clear();
prevBlock++;
}
final int lastBlock =
@@ -418,6 +418,7 @@ public static RandomAccessInput createJumpTable(
// SPARSE variables
boolean exists;
+ int nextExistDocInBlock = -1;
// DENSE variables
long word;
@@ -495,7 +496,8 @@ private void readBlockHeader() throws IOException {
if (numValues <= MAX_ARRAY_LENGTH) {
method = Method.SPARSE;
blockEnd = slice.getFilePointer() + (numValues << 1);
- } else if (numValues == 65536) {
+ nextExistDocInBlock = -1;
+ } else if (numValues == BLOCK_SIZE) {
method = Method.ALL;
blockEnd = slice.getFilePointer();
gap = block - index - 1;
@@ -550,6 +552,7 @@ boolean advanceWithinBlock(IndexedDISI disi, int target) throws IOException {
if (doc >= targetInBlock) {
disi.doc = disi.block | doc;
disi.exists = true;
+ disi.nextExistDocInBlock = doc;
return true;
}
}
@@ -560,6 +563,10 @@ boolean advanceWithinBlock(IndexedDISI disi, int target) throws IOException {
boolean advanceExactWithinBlock(IndexedDISI disi, int target) throws IOException {
final int targetInBlock = target & 0xFFFF;
// TODO: binary search
+ if (disi.nextExistDocInBlock > targetInBlock) {
+ assert !disi.exists;
+ return false;
+ }
if (target == disi.doc) {
return disi.exists;
}
@@ -567,6 +574,7 @@ boolean advanceExactWithinBlock(IndexedDISI disi, int target) throws IOException
int doc = Short.toUnsignedInt(disi.slice.readShort());
disi.index++;
if (doc >= targetInBlock) {
+ disi.nextExistDocInBlock = doc;
if (doc != targetInBlock) {
disi.index--;
disi.slice.seek(disi.slice.getFilePointer() - Short.BYTES);
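A standalone sketch (hypothetical class, not the real IndexedDISI) of what the new
nextExistDocInBlock field buys: a SPARSE block stores its existing docs as sorted shorts,
so once a scan has stopped on a doc beyond some target, a later advanceExact to a target
below that doc can fail fast without re-reading the slice.

final class SparseBlockCursor {
  private final short[] docs; // sorted doc ids within one SPARSE block
  private int pos;
  private int nextExistDocInBlock = -1;

  SparseBlockCursor(short[] docs) {
    this.docs = docs;
  }

  boolean advanceExact(int targetInBlock) {
    if (nextExistDocInBlock > targetInBlock) {
      return false; // target lies in a gap a previous scan already passed
    }
    while (pos < docs.length) {
      int doc = Short.toUnsignedInt(docs[pos++]);
      if (doc >= targetInBlock) {
        nextExistDocInBlock = doc;
        if (doc != targetInBlock) {
          pos--; // rewind so the overshot doc can still match a later target
          return false;
        }
        return true;
      }
    }
    return false;
  }
}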
diff --git a/lucene/core/src/java/org/apache/lucene/document/BinaryRangeFieldRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/BinaryRangeFieldRangeQuery.java
index ec814a4b0a7e..b225f4d98fe2 100644
--- a/lucene/core/src/java/org/apache/lucene/document/BinaryRangeFieldRangeQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/BinaryRangeFieldRangeQuery.java
@@ -21,7 +21,6 @@
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.DocValues;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.ConstantScoreScorer;
@@ -90,8 +89,8 @@ public void visit(QueryVisitor visitor) {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- return super.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ return super.rewrite(indexSearcher);
}
private BinaryRangeDocValues getValues(LeafReader reader, String field) throws IOException {
diff --git a/lucene/core/src/java/org/apache/lucene/document/DoubleRangeSlowRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/DoubleRangeSlowRangeQuery.java
index 6406b1be81a2..f0f9040d08c9 100644
--- a/lucene/core/src/java/org/apache/lucene/document/DoubleRangeSlowRangeQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/DoubleRangeSlowRangeQuery.java
@@ -20,7 +20,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
-import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
@@ -79,8 +79,8 @@ public String toString(String field) {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- return super.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ return super.rewrite(indexSearcher);
}
private static byte[] encodeRanges(double[] min, double[] max) {
diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
index 27e689644594..85c3bad35199 100644
--- a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
+++ b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java
@@ -31,6 +31,7 @@
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FieldDoc;
+import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.similarities.BM25Similarity;
@@ -223,7 +224,7 @@ abstract static class FeatureFunction {
abstract Explanation explain(String field, String feature, float w, int freq);
- FeatureFunction rewrite(IndexReader reader) throws IOException {
+ FeatureFunction rewrite(IndexSearcher indexSearcher) throws IOException {
return this;
}
}
@@ -340,11 +341,11 @@ static final class SaturationFunction extends FeatureFunction {
}
@Override
- public FeatureFunction rewrite(IndexReader reader) throws IOException {
+ public FeatureFunction rewrite(IndexSearcher indexSearcher) throws IOException {
if (pivot != null) {
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
- float newPivot = computePivotFeatureValue(reader, field, feature);
+ float newPivot = computePivotFeatureValue(indexSearcher.getIndexReader(), field, feature);
return new SaturationFunction(field, feature, newPivot);
}
diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
index de899a96d4ea..8b1ea8c47258 100644
--- a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
@@ -20,7 +20,6 @@
import java.util.Objects;
import org.apache.lucene.document.FeatureField.FeatureFunction;
import org.apache.lucene.index.ImpactsEnum;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
@@ -52,12 +51,12 @@ final class FeatureQuery extends Query {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- FeatureFunction rewritten = function.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ FeatureFunction rewritten = function.rewrite(indexSearcher);
if (function != rewritten) {
return new FeatureQuery(fieldName, featureName, rewritten);
}
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/document/FloatRangeSlowRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/FloatRangeSlowRangeQuery.java
index 9c011621e066..62981f4870aa 100644
--- a/lucene/core/src/java/org/apache/lucene/document/FloatRangeSlowRangeQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/FloatRangeSlowRangeQuery.java
@@ -20,7 +20,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
-import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
@@ -79,8 +79,8 @@ public String toString(String field) {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- return super.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ return super.rewrite(indexSearcher);
}
private static byte[] encodeRanges(float[] min, float[] max) {
diff --git a/lucene/core/src/java/org/apache/lucene/document/IntRangeSlowRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/IntRangeSlowRangeQuery.java
index cd0714a52909..99d5fa212f67 100644
--- a/lucene/core/src/java/org/apache/lucene/document/IntRangeSlowRangeQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/IntRangeSlowRangeQuery.java
@@ -19,7 +19,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
-import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
@@ -77,8 +77,8 @@ public String toString(String field) {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- return super.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ return super.rewrite(indexSearcher);
}
private static byte[] encodeRanges(int[] min, int[] max) {
diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java
index fabbc5259e3b..87cb6a9f056e 100644
--- a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java
+++ b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java
@@ -17,6 +17,7 @@
package org.apache.lucene.document;
+import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@@ -100,7 +101,7 @@ public static FieldType createFieldType(
public KnnByteVectorField(
String name, byte[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
- fieldsData = vector;
+ fieldsData = vector; // null-check done above
}
/**
@@ -136,6 +137,11 @@ public KnnByteVectorField(String name, byte[] vector, FieldType fieldType) {
+ " using byte[] but the field encoding is "
+ fieldType.vectorEncoding());
}
+ Objects.requireNonNull(vector, "vector value must not be null");
+ if (vector.length != fieldType.vectorDimension()) {
+ throw new IllegalArgumentException(
+ "The number of vector dimensions does not match the field type");
+ }
fieldsData = vector;
}
diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java
index d6673293c720..9d1cd02c013e 100644
--- a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java
+++ b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java
@@ -17,6 +17,7 @@
package org.apache.lucene.document;
+import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@@ -101,7 +102,7 @@ public static Query newVectorQuery(String field, float[] queryVector, int k) {
public KnnFloatVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
- fieldsData = vector;
+ fieldsData = VectorUtil.checkFinite(vector); // null check done above
}
/**
@@ -137,7 +138,12 @@ public KnnFloatVectorField(String name, float[] vector, FieldType fieldType) {
+ " using float[] but the field encoding is "
+ fieldType.vectorEncoding());
}
- fieldsData = vector;
+ Objects.requireNonNull(vector, "vector value must not be null");
+ if (vector.length != fieldType.vectorDimension()) {
+ throw new IllegalArgumentException(
+ "The number of vector dimensions does not match the field type");
+ }
+ fieldsData = VectorUtil.checkFinite(vector);
}
/** Return the vector value of this field */
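
VectorUtil.checkFinite applies the same fail-fast idea to the float variant: NaN and infinite components are rejected in the constructor. A short sketch:

```java
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;

public class KnnFloatVectorFieldCheckDemo {
  public static void main(String[] args) {
    new KnnFloatVectorField("vec", new float[] {0.5f, 0.25f}, VectorSimilarityFunction.COSINE); // ok
    try {
      new KnnFloatVectorField("vec", new float[] {Float.NaN, 0f}, VectorSimilarityFunction.COSINE);
    } catch (IllegalArgumentException expected) {
      // VectorUtil.checkFinite rejects NaN/infinite components up front
    }
  }
}
```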
diff --git a/lucene/core/src/java/org/apache/lucene/document/LongRangeSlowRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/LongRangeSlowRangeQuery.java
index a4c164524fc5..b86d64bef62c 100644
--- a/lucene/core/src/java/org/apache/lucene/document/LongRangeSlowRangeQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/LongRangeSlowRangeQuery.java
@@ -20,7 +20,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
-import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
@@ -79,8 +79,8 @@ public String toString(String field) {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- return super.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ return super.rewrite(indexSearcher);
}
private static byte[] encodeRanges(long[] min, long[] max) {
diff --git a/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesRangeQuery.java
index fb23056126bb..21264409136f 100644
--- a/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesRangeQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesRangeQuery.java
@@ -19,7 +19,6 @@
import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.index.DocValues;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
@@ -84,11 +83,11 @@ public String toString(String field) {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (lowerValue == Long.MIN_VALUE && upperValue == Long.MAX_VALUE) {
return new FieldExistsQuery(field);
}
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
@Override
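
This chunk and the neighboring doc-values queries keep their short-circuits under the new signature: an unbounded range rewrites to FieldExistsQuery, an empty set to MatchNoDocsQuery. The effect is observable through the searcher's fixpoint rewrite; a hedged sketch (field name illustrative):

```java
import java.io.IOException;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;

public class RewriteShortCircuitDemo {
  static void demo(Directory dir) throws IOException {
    try (DirectoryReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      // an unbounded doc-values range asks for "any value at all" ...
      Query range =
          SortedNumericDocValuesField.newSlowRangeQuery("price", Long.MIN_VALUE, Long.MAX_VALUE);
      // ... so the fixpoint rewrite loop turns it into a FieldExistsQuery
      Query rewritten = searcher.rewrite(range);
      assert rewritten instanceof FieldExistsQuery;
    }
  }
}
```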
diff --git a/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesSetQuery.java b/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesSetQuery.java
index 72d5d42b30a3..fe5e8c3afedd 100644
--- a/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesSetQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesSetQuery.java
@@ -20,7 +20,6 @@
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.DocValues;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
@@ -85,11 +84,11 @@ public long ramBytesUsed() {
}
@Override
- public Query rewrite(IndexReader indexReader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (numbers.size() == 0) {
return new MatchNoDocsQuery();
}
- return super.rewrite(indexReader);
+ return super.rewrite(indexSearcher);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/document/SortedSetDocValuesRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/SortedSetDocValuesRangeQuery.java
index f7eab990d3d5..928257cbd1fa 100644
--- a/lucene/core/src/java/org/apache/lucene/document/SortedSetDocValuesRangeQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/SortedSetDocValuesRangeQuery.java
@@ -19,7 +19,6 @@
import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.index.DocValues;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedSetDocValues;
@@ -98,11 +97,11 @@ public String toString(String field) {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (lowerValue == null && upperValue == null) {
return new FieldExistsQuery(field);
}
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/geo/Tessellator.java b/lucene/core/src/java/org/apache/lucene/geo/Tessellator.java
index 5179037fa54f..ae0190f9d459 100644
--- a/lucene/core/src/java/org/apache/lucene/geo/Tessellator.java
+++ b/lucene/core/src/java/org/apache/lucene/geo/Tessellator.java
@@ -1090,17 +1090,22 @@ private static final Node splitPolygon(final Node a, final Node b, boolean edgeF
* Determines whether a diagonal between two polygon nodes lies within a polygon interior. (This
* determines the validity of the ray.) *
*/
- private static final boolean isValidDiagonal(final Node a, final Node b) {
+ private static boolean isValidDiagonal(final Node a, final Node b) {
+ if (a.next.idx == b.idx
+ || a.previous.idx == b.idx
+ // check next edges are locally visible
+ || isLocallyInside(a.previous, b) == false
+ || isLocallyInside(b.next, a) == false
+ // check polygons are CCW on both sides
+ || isCWPolygon(a, b) == false
+ || isCWPolygon(b, a) == false) {
+ return false;
+ }
if (isVertexEquals(a, b)) {
- // If points are equal then use it if they are valid polygons
- return isCWPolygon(a, b);
+ return true;
}
- return a.next.idx != b.idx
- && a.previous.idx != b.idx
- && isLocallyInside(a, b)
+ return isLocallyInside(a, b)
&& isLocallyInside(b, a)
- && isLocallyInside(a.previous, b)
- && isLocallyInside(b.next, a)
&& middleInsert(a, a.getX(), a.getY(), b.getX(), b.getY())
// make sure we don't introduce collinear lines
&& area(a.previous.getX(), a.previous.getY(), a.getX(), a.getY(), b.getX(), b.getY()) != 0
@@ -1114,7 +1119,7 @@ && area(a.getX(), a.getY(), b.getX(), b.getY(), b.previous.getX(), b.previous.ge
/** Determine whether the polygon defined between node start and node end is CW */
private static boolean isCWPolygon(final Node start, final Node end) {
// The polygon must be CW
- return (signedArea(start, end) < 0) ? true : false;
+ return signedArea(start, end) < 0;
}
/** Determine the signed area between node start and node end */
diff --git a/lucene/core/src/java/org/apache/lucene/index/CachingMergeContext.java b/lucene/core/src/java/org/apache/lucene/index/CachingMergeContext.java
new file mode 100644
index 000000000000..a9974c64e0b1
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/index/CachingMergeContext.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Set;
+import org.apache.lucene.util.InfoStream;
+
+/**
+ * a wrapper of IndexWriter MergeContext. Try to cache the {@link
+ * #numDeletesToMerge(SegmentCommitInfo)} result in merge phase, to avoid duplicate calculation
+ */
+class CachingMergeContext implements MergePolicy.MergeContext {
+ final MergePolicy.MergeContext mergeContext;
+ final HashMap<SegmentCommitInfo, Integer> cachedNumDeletesToMerge = new HashMap<>();
+
+ CachingMergeContext(MergePolicy.MergeContext mergeContext) {
+ this.mergeContext = mergeContext;
+ }
+
+ @Override
+ public final int numDeletesToMerge(SegmentCommitInfo info) throws IOException {
+ Integer numDeletesToMerge = cachedNumDeletesToMerge.get(info);
+ if (numDeletesToMerge != null) {
+ return numDeletesToMerge;
+ }
+ numDeletesToMerge = mergeContext.numDeletesToMerge(info);
+ cachedNumDeletesToMerge.put(info, numDeletesToMerge);
+ return numDeletesToMerge;
+ }
+
+ @Override
+ public final int numDeletedDocs(SegmentCommitInfo info) {
+ return mergeContext.numDeletedDocs(info);
+ }
+
+ @Override
+ public final InfoStream getInfoStream() {
+ return mergeContext.getInfoStream();
+ }
+
+ @Override
+ public final Set<SegmentCommitInfo> getMergingSegments() {
+ return mergeContext.getMergingSegments();
+ }
+}
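
numDeletesToMerge can be costly (soft-deletes policies may scan doc values), and a merge policy may consult it for the same segment several times while scoring candidate merges; the wrapper memoizes it for the duration of one find*Merges call. A sketch of the contract (demo class is hypothetical, placed in org.apache.lucene.index because the wrapper is package-private):

```java
package org.apache.lucene.index;

import java.io.IOException;

// Hypothetical demo, not part of the patch.
class CachingMergeContextDemo {
  static void demo(MergePolicy.MergeContext ctx, SegmentCommitInfo info) throws IOException {
    CachingMergeContext caching = new CachingMergeContext(ctx);
    int first = caching.numDeletesToMerge(info); // computed once via the delegate
    int second = caching.numDeletesToMerge(info); // served from the HashMap cache
    assert first == second;
  }
}
```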
diff --git a/lucene/core/src/java/org/apache/lucene/index/DocumentsWriterDeleteQueue.java b/lucene/core/src/java/org/apache/lucene/index/DocumentsWriterDeleteQueue.java
index ed86416a2b19..4d8bfd054654 100644
--- a/lucene/core/src/java/org/apache/lucene/index/DocumentsWriterDeleteQueue.java
+++ b/lucene/core/src/java/org/apache/lucene/index/DocumentsWriterDeleteQueue.java
@@ -142,6 +142,10 @@ static Node newNode(Term term) {
return new TermNode(term);
}
+ static Node newNode(Query query) {
+ return new QueryNode(query);
+ }
+
static Node newNode(DocValuesUpdate... updates) {
return new DocValuesUpdatesNode(updates);
}
@@ -437,6 +441,23 @@ public String toString() {
}
}
+ private static final class QueryNode extends Node {
+
+ QueryNode(Query query) {
+ super(query);
+ }
+
+ @Override
+ void apply(BufferedUpdates bufferedDeletes, int docIDUpto) {
+ bufferedDeletes.addQuery(item, docIDUpto);
+ }
+
+ @Override
+ public String toString() {
+ return "del=" + item;
+ }
+ }
+
private static final class QueryArrayNode extends Node {
QueryArrayNode(Query[] query) {
super(query);
diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
index b180d3bdda3d..a572b2258af5 100644
--- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
@@ -36,7 +36,7 @@
*/
public class ExitableDirectoryReader extends FilterDirectoryReader {
- private QueryTimeout queryTimeout;
+ private final QueryTimeout queryTimeout;
/** Exception that is thrown to prematurely terminate a term enumeration. */
@SuppressWarnings("serial")
@@ -50,7 +50,7 @@ public ExitingReaderException(String msg) {
/** Wrapper class for a SubReaderWrapper that is used by the ExitableDirectoryReader. */
public static class ExitableSubReaderWrapper extends SubReaderWrapper {
- private QueryTimeout queryTimeout;
+ private final QueryTimeout queryTimeout;
/** Constructor * */
public ExitableSubReaderWrapper(QueryTimeout queryTimeout) {
@@ -810,7 +810,7 @@ public static class ExitableTermsEnum extends FilterTermsEnum {
// Create bit mask in the form of 0000 1111 for efficient checking
private static final int NUM_CALLS_PER_TIMEOUT_CHECK = (1 << 4) - 1; // 15
private int calls;
- private QueryTimeout queryTimeout;
+ private final QueryTimeout queryTimeout;
/** Constructor * */
public ExitableTermsEnum(TermsEnum termsEnum, QueryTimeout queryTimeout) {
diff --git a/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java b/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java
index c48b851ae526..df2054c18846 100644
--- a/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java
+++ b/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java
@@ -400,6 +400,150 @@ static void verifySameVectorOptions(
}
}
+ /*
+ This method will create a new instance of FieldInfo if any attribute changes (and changes in a compatible way).
+ It is intended only to be used in indices where schema validation is not strict (legacy indices). It will return null
+ if no changes were made to this FieldInfo.
+ */
+ FieldInfo handleLegacySupportedUpdates(FieldInfo otherFi) {
+ IndexOptions newIndexOptions = this.indexOptions;
+ boolean newStoreTermVector = this.storeTermVector;
+ boolean newOmitNorms = this.omitNorms;
+ boolean newStorePayloads = this.storePayloads;
+ DocValuesType newDocValues = this.docValuesType;
+ int newPointDimensionCount = this.pointDimensionCount;
+ int newPointNumBytes = this.pointNumBytes;
+ int newPointIndexDimensionCount = this.pointIndexDimensionCount;
+ long newDvGen = this.dvGen;
+
+ boolean fieldInfoChanges = false;
+ // System.out.println("FI.update field=" + name + " indexed=" + indexed + " omitNorms=" +
+ // omitNorms + " this.omitNorms=" + this.omitNorms);
+ if (this.indexOptions != otherFi.indexOptions) {
+ if (this.indexOptions == IndexOptions.NONE) {
+ newIndexOptions = otherFi.indexOptions;
+ fieldInfoChanges = true;
+ } else if (otherFi.indexOptions != IndexOptions.NONE) {
+ throw new IllegalArgumentException(
+ "cannot change field \""
+ + name
+ + "\" from index options="
+ + this.indexOptions
+ + " to inconsistent index options="
+ + otherFi.indexOptions);
+ }
+ }
+
+ if (this.pointDimensionCount != otherFi.pointDimensionCount
+ && otherFi.pointDimensionCount != 0) {
+ if (this.pointDimensionCount == 0) {
+ fieldInfoChanges = true;
+ newPointDimensionCount = otherFi.pointDimensionCount;
+ } else {
+ throw new IllegalArgumentException(
+ "cannot change field \""
+ + name
+ + "\" from points dimensionCount="
+ + this.pointDimensionCount
+ + " to inconsistent dimensionCount="
+ + otherFi.pointDimensionCount);
+ }
+ }
+
+ if (this.pointIndexDimensionCount != otherFi.pointIndexDimensionCount
+ && otherFi.pointIndexDimensionCount != 0) {
+ if (this.pointIndexDimensionCount == 0) {
+ fieldInfoChanges = true;
+ newPointIndexDimensionCount = otherFi.pointIndexDimensionCount;
+ } else {
+ throw new IllegalArgumentException(
+ "cannot change field \""
+ + name
+ + "\" from points indexDimensionCount="
+ + this.pointIndexDimensionCount
+ + " to inconsistent indexDimensionCount="
+ + otherFi.pointIndexDimensionCount);
+ }
+ }
+
+ if (this.pointNumBytes != otherFi.pointNumBytes && otherFi.pointNumBytes != 0) {
+ if (this.pointNumBytes == 0) {
+ fieldInfoChanges = true;
+ newPointNumBytes = otherFi.pointNumBytes;
+ } else {
+ throw new IllegalArgumentException(
+ "cannot change field \""
+ + name
+ + "\" from points numBytes="
+ + this.pointNumBytes
+ + " to inconsistent numBytes="
+ + otherFi.pointNumBytes);
+ }
+ }
+
+ if (newIndexOptions
+ != IndexOptions.NONE) { // if updated field data is not for indexing, leave the updates out
+ if (this.storeTermVector != otherFi.storeTermVector && this.storeTermVector == false) {
+ fieldInfoChanges = true;
+ newStoreTermVector = true; // once vector, always vector
+ }
+ if (this.storePayloads != otherFi.storePayloads
+ && this.storePayloads == false
+ && newIndexOptions.compareTo(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS) >= 0) {
+ fieldInfoChanges = true;
+ newStorePayloads = true;
+ }
+
+ // Awkward: only drop norms if incoming update is indexed:
+ if (otherFi.indexOptions != IndexOptions.NONE
+ && this.omitNorms != otherFi.omitNorms
+ && this.omitNorms == false) {
+ fieldInfoChanges = true;
+ newOmitNorms = true; // if omitNorms is required even once, norms remain omitted for life
+ }
+ }
+
+ if (otherFi.docValuesType != DocValuesType.NONE
+ && otherFi.docValuesType != this.docValuesType) {
+ if (this.docValuesType == DocValuesType.NONE) {
+ fieldInfoChanges = true;
+ newDocValues = otherFi.docValuesType;
+ newDvGen = otherFi.dvGen;
+ } else {
+ throw new IllegalArgumentException(
+ "cannot change DocValues type from "
+ + docValuesType
+ + " to "
+ + otherFi.docValuesType
+ + " for field \""
+ + name
+ + "\"");
+ }
+ }
+
+ if (!fieldInfoChanges) {
+ return null;
+ }
+ return new FieldInfo(
+ this.name,
+ this.number,
+ newStoreTermVector,
+ newOmitNorms,
+ newStorePayloads,
+ newIndexOptions,
+ newDocValues,
+ newDvGen,
+ this.attributes, // attributes don't need to be handled here because they are handled for
+ // the non-legacy case in FieldInfos
+ newPointDimensionCount,
+ newPointIndexDimensionCount,
+ newPointNumBytes,
+ this.vectorDimension,
+ this.vectorEncoding,
+ this.vectorSimilarityFunction,
+ this.softDeletesField);
+ }
+
/**
* Record that this field is indexed with points, with the specified number of dimensions and
* bytes per dimension.
diff --git a/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java b/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java
index 7029303992ed..85aac5439d3c 100644
--- a/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java
+++ b/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java
@@ -628,6 +628,48 @@ FieldInfo constructFieldInfo(String fieldName, DocValuesType dvType, int newFiel
isSoftDeletesField);
}
+ synchronized void setDocValuesType(int number, String name, DocValuesType dvType) {
+ verifyConsistent(number, name, dvType);
+ docValuesType.put(name, dvType);
+ }
+
+ synchronized void verifyConsistent(Integer number, String name, DocValuesType dvType) {
+ if (name.equals(numberToName.get(number)) == false) {
+ throw new IllegalArgumentException(
+ "field number "
+ + number
+ + " is already mapped to field name \""
+ + numberToName.get(number)
+ + "\", not \""
+ + name
+ + "\"");
+ }
+ if (number.equals(nameToNumber.get(name)) == false) {
+ throw new IllegalArgumentException(
+ "field name \""
+ + name
+ + "\" is already mapped to field number \""
+ + nameToNumber.get(name)
+ + "\", not \""
+ + number
+ + "\"");
+ }
+ DocValuesType currentDVType = docValuesType.get(name);
+ if (dvType != DocValuesType.NONE
+ && currentDVType != null
+ && currentDVType != DocValuesType.NONE
+ && dvType != currentDVType) {
+ throw new IllegalArgumentException(
+ "cannot change DocValues type from "
+ + currentDVType
+ + " to "
+ + dvType
+ + " for field \""
+ + name
+ + "\"");
+ }
+ }
+
synchronized Set<String> getFieldNames() {
return Set.copyOf(nameToNumber.keySet());
}
@@ -708,6 +750,26 @@ FieldInfo add(FieldInfo fi, long dvGen) {
final FieldInfo curFi = fieldInfo(fi.getName());
if (curFi != null) {
curFi.verifySameSchema(fi, globalFieldNumbers.strictlyConsistent);
+
+ if (!globalFieldNumbers.strictlyConsistent) {
+ // For the not strictly consistent case (legacy index), we may need to merge the
+ // FieldInfo instances
+ FieldInfo updatedFieldInfo = curFi.handleLegacySupportedUpdates(fi);
+ if (updatedFieldInfo != null) {
+ if (curFi.getDocValuesType() == DocValuesType.NONE
+ && updatedFieldInfo.getDocValuesType() != DocValuesType.NONE) {
+ // Must also update docValuesType map so it's
+ // aware of this field's DocValuesType. This will throw IllegalArgumentException if
+ // an illegal type change was attempted.
+ globalFieldNumbers.setDocValuesType(
+ updatedFieldInfo.number,
+ updatedFieldInfo.getName(),
+ updatedFieldInfo.getDocValuesType());
+ }
+ // Since the FieldInfo changed, update in map
+ byName.put(fi.getName(), updatedFieldInfo);
+ }
+ }
if (fi.attributes() != null) {
fi.attributes().forEach((k, v) -> curFi.putAttribute(k, v));
}
diff --git a/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java b/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java
index f488e1f83f62..07dca8945548 100644
--- a/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java
@@ -1522,6 +1522,19 @@ public long updateDocuments(
delTerm == null ? null : DocumentsWriterDeleteQueue.newNode(delTerm), docs);
}
+ /**
+ * Similar to {@link #updateDocuments(Term, Iterable)}, but takes a query instead of a term to
+ * identify the documents to be updated.
+ *
+ * @lucene.experimental
+ */
+ public long updateDocuments(
+ Query delQuery, Iterable<? extends Iterable<? extends IndexableField>> docs)
+ throws IOException {
+ return updateDocuments(
+ delQuery == null ? null : DocumentsWriterDeleteQueue.newNode(delQuery), docs);
+ }
+
private long updateDocuments(
final DocumentsWriterDeleteQueue.Node<?> delNode,
Iterable<? extends Iterable<? extends IndexableField>> docs)
@@ -2202,10 +2215,11 @@ public void forceMergeDeletes(boolean doWait) throws IOException {
}
final MergePolicy mergePolicy = config.getMergePolicy();
+ final CachingMergeContext cachingMergeContext = new CachingMergeContext(this);
MergePolicy.MergeSpecification spec;
boolean newMergesFound = false;
synchronized (this) {
- spec = mergePolicy.findForcedDeletesMerges(segmentInfos, this);
+ spec = mergePolicy.findForcedDeletesMerges(segmentInfos, cachingMergeContext);
newMergesFound = spec != null;
if (newMergesFound) {
final int numMerges = spec.merges.size();
@@ -2315,6 +2329,7 @@ private synchronized MergePolicy.MergeSpecification updatePendingMerges(
}
final MergePolicy.MergeSpecification spec;
+ final CachingMergeContext cachingMergeContext = new CachingMergeContext(this);
if (maxNumSegments != UNBOUNDED_MAX_MERGE_SEGMENTS) {
assert trigger == MergeTrigger.EXPLICIT || trigger == MergeTrigger.MERGE_FINISHED
: "Expected EXPLICT or MERGE_FINISHED as trigger even with maxNumSegments set but was: "
@@ -2322,7 +2337,10 @@ private synchronized MergePolicy.MergeSpecification updatePendingMerges(
spec =
mergePolicy.findForcedMerges(
- segmentInfos, maxNumSegments, Collections.unmodifiableMap(segmentsToMerge), this);
+ segmentInfos,
+ maxNumSegments,
+ Collections.unmodifiableMap(segmentsToMerge),
+ cachingMergeContext);
if (spec != null) {
final int numMerges = spec.merges.size();
for (int i = 0; i < numMerges; i++) {
@@ -2334,7 +2352,7 @@ private synchronized MergePolicy.MergeSpecification updatePendingMerges(
switch (trigger) {
case GET_READER:
case COMMIT:
- spec = mergePolicy.findFullFlushMerges(trigger, segmentInfos, this);
+ spec = mergePolicy.findFullFlushMerges(trigger, segmentInfos, cachingMergeContext);
break;
case ADD_INDEXES:
throw new IllegalStateException(
@@ -2346,7 +2364,7 @@ private synchronized MergePolicy.MergeSpecification updatePendingMerges(
case SEGMENT_FLUSH:
case CLOSING:
default:
- spec = mergePolicy.findMerges(trigger, segmentInfos, this);
+ spec = mergePolicy.findMerges(trigger, segmentInfos, cachingMergeContext);
}
}
if (spec != null) {
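
Together with the QueryNode added to DocumentsWriterDeleteQueue above, the new overload gives an atomic delete-by-query-then-add primitive. A usage sketch against an in-memory index (schema illustrative; the overload is @lucene.experimental):

```java
import java.io.IOException;
import java.util.List;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.store.ByteBuffersDirectory;

public class UpdateByQueryDemo {
  public static void main(String[] args) throws IOException {
    try (ByteBuffersDirectory dir = new ByteBuffersDirectory();
        IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig(new StandardAnalyzer()))) {
      Document fresh = new Document();
      fresh.add(new StringField("id", "42", Field.Store.YES));
      fresh.add(new LongPoint("version", 7));
      // atomically: delete all docs with version in [0, 6], then add the replacement
      writer.updateDocuments(LongPoint.newRangeQuery("version", 0, 6), List.of(fresh));
      writer.commit();
    }
  }
}
```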
diff --git a/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java b/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java
index 0c64f4c2c9ac..f1e543423670 100644
--- a/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java
+++ b/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java
@@ -17,14 +17,17 @@
package org.apache.lucene.index;
/**
- * Base for query timeout implementations, which will provide a {@code shouldExit()} method, used
- * with {@link ExitableDirectoryReader}.
+ * Query timeout abstraction that controls whether a query should continue or be stopped. Can be set
+ * to the searcher through {@link org.apache.lucene.search.IndexSearcher#setTimeout(QueryTimeout)},
+ * in which case bulk scoring will be time-bound. Can also be used in combination with {@link
+ * ExitableDirectoryReader}.
*/
public interface QueryTimeout {
/**
- * Called from {@link ExitableDirectoryReader.ExitableTermsEnum#next()} to determine whether to
- * stop processing a query.
+ * Called to determine whether to stop processing a query.
+ *
+ * @return true if the query should stop, false otherwise
*/
- public abstract boolean shouldExit();
+ boolean shouldExit();
}
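
Since shouldExit() is now the interface's single abstract method, a lambda suffices. A minimal sketch of a deadline-based timeout wired into the searcher (budget handling illustrative):

```java
import java.io.IOException;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.store.Directory;

public class QueryTimeoutDemo {
  static void searchWithBudget(Directory dir, long budgetNanos) throws IOException {
    try (DirectoryReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      long deadline = System.nanoTime() + budgetNanos;
      QueryTimeout timeout = () -> System.nanoTime() - deadline > 0;
      searcher.setTimeout(timeout); // bulk scoring becomes time-bound
      searcher.search(new MatchAllDocsQuery(), 10);
      if (searcher.timedOut()) {
        // results are partial; scoring stopped once shouldExit() returned true
      }
    }
  }
}
```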
diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
index 3646cf65584b..8a515cb79fc9 100644
--- a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
+++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java
@@ -90,8 +90,8 @@ public float compare(byte[] v1, byte[] v2) {
/**
* Calculates a similarity score between the two vectors with a specified function. Higher
- * similarity scores correspond to closer vectors. The offsets and lengths of the BytesRefs
- * determine the vector data that is compared. Each (signed) byte represents a vector dimension.
+ * similarity scores correspond to closer vectors. Each (signed) byte represents a vector
+ * dimension.
*
* @param v1 a vector
* @param v2 another vector, of the same dimension
diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index eaa7cc1e8337..072846785723 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -19,9 +19,13 @@
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
+import java.util.List;
import java.util.Objects;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.FutureTask;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
@@ -29,6 +33,7 @@
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.ThreadInterruptedException;
/**
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
@@ -51,7 +56,7 @@ abstract class AbstractKnnVectorQuery extends Query {
private final Query filter;
public AbstractKnnVectorQuery(String field, int k, Query filter) {
- this.field = field;
+ this.field = Objects.requireNonNull(field, "field");
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
@@ -60,12 +65,11 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ IndexReader reader = indexSearcher.getIndexReader();
- Weight filterWeight = null;
+ final Weight filterWeight;
if (filter != null) {
- IndexSearcher indexSearcher = new IndexSearcher(reader);
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
.add(filter, BooleanClause.Occur.FILTER)
@@ -73,17 +77,17 @@ public Query rewrite(IndexReader reader) throws IOException {
.build();
Query rewritten = indexSearcher.rewrite(booleanQuery);
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
+ } else {
+ filterWeight = null;
}
- for (LeafReaderContext ctx : reader.leaves()) {
- TopDocs results = searchLeaf(ctx, filterWeight);
- if (ctx.docBase > 0) {
- for (ScoreDoc scoreDoc : results.scoreDocs) {
- scoreDoc.doc += ctx.docBase;
- }
- }
- perLeafResults[ctx.ord] = results;
- }
+ SliceExecutor sliceExecutor = indexSearcher.getSliceExecutor();
+ // in case of parallel execution, the leaf results are not ordered by leaf context's ordinal
+ TopDocs[] perLeafResults =
+ (sliceExecutor == null)
+ ? sequentialSearch(reader.leaves(), filterWeight)
+ : parallelSearch(indexSearcher.getSlices(), filterWeight, sliceExecutor);
+
// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
if (topK.scoreDocs.length == 0) {
@@ -92,7 +96,67 @@ public Query rewrite(IndexReader reader) throws IOException {
return createRewrittenQuery(reader, topK);
}
+ private TopDocs[] sequentialSearch(
+ List<LeafReaderContext> leafReaderContexts, Weight filterWeight) {
+ try {
+ TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
+ for (LeafReaderContext ctx : leafReaderContexts) {
+ perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight);
+ }
+ return perLeafResults;
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private TopDocs[] parallelSearch(
+ IndexSearcher.LeafSlice[] slices, Weight filterWeight, SliceExecutor sliceExecutor) {
+
+ List<FutureTask<TopDocs[]>> tasks = new ArrayList<>(slices.length);
+ int segmentsCount = 0;
+ for (IndexSearcher.LeafSlice slice : slices) {
+ segmentsCount += slice.leaves.length;
+ tasks.add(
+ new FutureTask<>(
+ () -> {
+ TopDocs[] results = new TopDocs[slice.leaves.length];
+ int i = 0;
+ for (LeafReaderContext context : slice.leaves) {
+ results[i++] = searchLeaf(context, filterWeight);
+ }
+ return results;
+ }));
+ }
+
+ sliceExecutor.invokeAll(tasks);
+
+ TopDocs[] topDocs = new TopDocs[segmentsCount];
+ int i = 0;
+ for (FutureTask<TopDocs[]> task : tasks) {
+ try {
+ for (TopDocs docs : task.get()) {
+ topDocs[i++] = docs;
+ }
+ } catch (ExecutionException e) {
+ throw new RuntimeException(e.getCause());
+ } catch (InterruptedException e) {
+ throw new ThreadInterruptedException(e);
+ }
+ }
+ return topDocs;
+ }
+
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
+ TopDocs results = getLeafResults(ctx, filterWeight);
+ if (ctx.docBase > 0) {
+ for (ScoreDoc scoreDoc : results.scoreDocs) {
+ scoreDoc.doc += ctx.docBase;
+ }
+ }
+ return results;
+ }
+
+ private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
Bits liveDocs = ctx.reader().getLiveDocs();
int maxDoc = ctx.reader().maxDoc();
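
From the caller's side nothing changes except how the IndexSearcher is constructed: give it an executor and the kNN rewrite fans out per slice instead of looping over leaves sequentially. A hedged sketch (pool sizing and field name illustrative):

```java
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;

public class ParallelKnnDemo {
  static TopDocs search(Directory dir, float[] queryVector) throws IOException {
    ExecutorService executor = Executors.newFixedThreadPool(4); // illustrative sizing
    try (DirectoryReader reader = DirectoryReader.open(dir)) {
      // an executor gives the searcher slices; the kNN rewrite then runs
      // searchLeaf() per slice as FutureTasks on the SliceExecutor
      IndexSearcher searcher = new IndexSearcher(reader, executor);
      return searcher.search(new KnnFloatVectorQuery("vec", queryVector, 10), 10);
    } finally {
      executor.shutdown();
    }
  }
}
```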
diff --git a/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java b/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java
index 8b73eb5b27a9..2c7e41ac9719 100644
--- a/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java
@@ -19,7 +19,6 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
@@ -268,11 +267,12 @@ public String toString(String field) {
}
@Override
- public final Query rewrite(IndexReader reader) throws IOException {
+ public final Query rewrite(IndexSearcher indexSearcher) throws IOException {
final TermStates[] contexts = ArrayUtil.copyOfSubArray(this.contexts, 0, this.contexts.length);
for (int i = 0; i < contexts.length; ++i) {
- if (contexts[i] == null || contexts[i].wasBuiltFor(reader.getContext()) == false) {
- contexts[i] = TermStates.build(reader.getContext(), terms[i], true);
+ if (contexts[i] == null
+ || contexts[i].wasBuiltFor(indexSearcher.getTopReaderContext()) == false) {
+ contexts[i] = TermStates.build(indexSearcher.getTopReaderContext(), terms[i], true);
}
}
@@ -287,7 +287,7 @@ public final Query rewrite(IndexReader reader) throws IOException {
}
for (int i = 0; i < contexts.length; ++i) {
- contexts[i] = adjustFrequencies(reader.getContext(), contexts[i], df, ttf);
+ contexts[i] = adjustFrequencies(indexSearcher.getTopReaderContext(), contexts[i], df, ttf);
}
Query[] termQueries = new Query[terms.length];
diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java
index f5c69621b15c..07823ae5c4d7 100644
--- a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java
@@ -30,7 +30,6 @@
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause.Occur;
/**
@@ -247,7 +246,7 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (clauses.size() == 0) {
return new MatchNoDocsQuery("empty BooleanQuery");
}
@@ -286,12 +285,12 @@ public Query rewrite(IndexReader reader) throws IOException {
Query rewritten;
if (occur == Occur.FILTER || occur == Occur.MUST_NOT) {
// Clauses that are not involved in scoring can get some extra simplifications
- rewritten = new ConstantScoreQuery(query).rewrite(reader);
+ rewritten = new ConstantScoreQuery(query).rewrite(indexSearcher);
if (rewritten instanceof ConstantScoreQuery) {
rewritten = ((ConstantScoreQuery) rewritten).getQuery();
}
} else {
- rewritten = query.rewrite(reader);
+ rewritten = query.rewrite(indexSearcher);
}
if (rewritten != query || query.getClass() == MatchNoDocsQuery.class) {
// rewrite clause
@@ -566,7 +565,7 @@ public Query rewrite(IndexReader reader) throws IOException {
}
}
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java b/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java
index 47c0f5f6f285..375269637000 100644
--- a/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java
@@ -18,7 +18,6 @@
import java.io.IOException;
import java.util.Objects;
-import org.apache.lucene.index.IndexReader;
/**
* A {@link Query} wrapper that allows to give a boost to the wrapped query. Boost values that are
@@ -73,8 +72,8 @@ public int hashCode() {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- final Query rewritten = query.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ final Query rewritten = query.rewrite(indexSearcher);
if (boost == 1f) {
return rewritten;
@@ -99,7 +98,7 @@ public Query rewrite(IndexReader reader) throws IOException {
return new BoostQuery(rewritten, boost);
}
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java b/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java
index 2a7e04447481..89f14fff20bc 100644
--- a/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java
+++ b/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java
@@ -31,4 +31,10 @@ public final class CollectionTerminatedException extends RuntimeException {
public CollectionTerminatedException() {
super();
}
+
+ @Override
+ public Throwable fillInStackTrace() {
+ // never re-thrown so we can save the expensive stacktrace
+ return this;
+ }
}
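
CollectionTerminatedException is pure control flow (leaf collectors throw it to stop a leaf early), so its stack trace is never read; skipping fillInStackTrace makes construction nearly free. The same pattern works for any flow-control exception, as a generic sketch:

```java
/** Sketch of the pattern: an exception used purely for control flow. */
class StopIteration extends RuntimeException {
  StopIteration() {
    super();
  }

  @Override
  public Throwable fillInStackTrace() {
    // skipping the stack walk makes throwing this almost as cheap as a branch
    return this;
  }
}
```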
diff --git a/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java b/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java
index b70224f1ec50..03091b748b1c 100644
--- a/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java
+++ b/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java
@@ -20,7 +20,6 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.Comparator;
import java.util.List;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitSet;
@@ -99,20 +98,22 @@ static DocIdSetIterator createConjunction(
allIterators.size() > 0
? allIterators.get(0).docID()
: twoPhaseIterators.get(0).approximation.docID();
- boolean iteratorsOnTheSameDoc = allIterators.stream().allMatch(it -> it.docID() == curDoc);
- iteratorsOnTheSameDoc =
- iteratorsOnTheSameDoc
- && twoPhaseIterators.stream().allMatch(it -> it.approximation().docID() == curDoc);
- if (iteratorsOnTheSameDoc == false) {
- throw new IllegalArgumentException(
- "Sub-iterators of ConjunctionDISI are not on the same document!");
+ long minCost = Long.MAX_VALUE;
+ for (DocIdSetIterator allIterator : allIterators) {
+ if (allIterator.docID() != curDoc) {
+ throwSubIteratorsNotOnSameDocument();
+ }
+ minCost = Math.min(allIterator.cost(), minCost);
+ }
+ for (TwoPhaseIterator it : twoPhaseIterators) {
+ if (it.approximation().docID() != curDoc) {
+ throwSubIteratorsNotOnSameDocument();
+ }
}
-
- long minCost = allIterators.stream().mapToLong(DocIdSetIterator::cost).min().getAsLong();
List<BitSetIterator> bitSetIterators = new ArrayList<>();
List<DocIdSetIterator> iterators = new ArrayList<>();
for (DocIdSetIterator iterator : allIterators) {
- if (iterator.cost() > minCost && iterator instanceof BitSetIterator) {
+ if (iterator instanceof BitSetIterator && iterator.cost() > minCost) {
// we put all bitset iterators into bitSetIterators
// except if they have the minimum cost, since we need
// them to lead the iteration in that case
@@ -142,6 +143,11 @@ static DocIdSetIterator createConjunction(
return disi;
}
+ private static void throwSubIteratorsNotOnSameDocument() {
+ throw new IllegalArgumentException(
+ "Sub-iterators of ConjunctionDISI are not on the same document!");
+ }
+
final DocIdSetIterator lead1, lead2;
final DocIdSetIterator[] others;
@@ -150,14 +156,7 @@ private ConjunctionDISI(List<? extends DocIdSetIterator> iterators) {
// Sort the array the first time to allow the least frequent DocsEnum to
// lead the matching.
- CollectionUtil.timSort(
- iterators,
- new Comparator<DocIdSetIterator>() {
- @Override
- public int compare(DocIdSetIterator o1, DocIdSetIterator o2) {
- return Long.compare(o1.cost(), o2.cost());
- }
- });
+ CollectionUtil.timSort(iterators, (o1, o2) -> Long.compare(o1.cost(), o2.cost()));
lead1 = iterators.get(0);
lead2 = iterators.get(1);
others = iterators.subList(2, iterators.size()).toArray(new DocIdSetIterator[0]);
@@ -326,13 +325,7 @@ private ConjunctionTwoPhaseIterator(
assert twoPhaseIterators.size() > 0;
CollectionUtil.timSort(
- twoPhaseIterators,
- new Comparator<TwoPhaseIterator>() {
- @Override
- public int compare(TwoPhaseIterator o1, TwoPhaseIterator o2) {
- return Float.compare(o1.matchCost(), o2.matchCost());
- }
- });
+ twoPhaseIterators, (o1, o2) -> Float.compare(o1.matchCost(), o2.matchCost()));
this.twoPhaseIterators =
twoPhaseIterators.toArray(new TwoPhaseIterator[twoPhaseIterators.size()]);
diff --git a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java
index 2225cc109444..48f763e16460 100644
--- a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java
@@ -18,7 +18,6 @@
import java.io.IOException;
import java.util.Objects;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.Bits;
@@ -40,8 +39,9 @@ public Query getQuery() {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- Query rewritten = query.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+
+ Query rewritten = query.rewrite(indexSearcher);
// Do some extra simplifications that are legal since scores are not needed on the wrapped
// query.
@@ -70,7 +70,7 @@ public Query rewrite(IndexReader reader) throws IOException {
return new ConstantScoreQuery(((BoostQuery) rewritten).getQuery());
}
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java
index 8c41c0a37cf0..1ab4f3b50818 100644
--- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java
@@ -23,7 +23,6 @@
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
/**
@@ -209,11 +208,10 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
/**
* Optimize our representation and our subqueries representations
*
- * @param reader the IndexReader we query
* @return an optimized copy of us (which may not be a copy if there is nothing to optimize)
*/
@Override
- public Query rewrite(IndexReader reader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (disjuncts.isEmpty()) {
return new MatchNoDocsQuery("empty DisjunctionMaxQuery");
}
@@ -233,7 +231,7 @@ public Query rewrite(IndexReader reader) throws IOException {
boolean actuallyRewritten = false;
List rewrittenDisjuncts = new ArrayList<>();
for (Query sub : disjuncts) {
- Query rewrittenSub = sub.rewrite(reader);
+ Query rewrittenSub = sub.rewrite(indexSearcher);
actuallyRewritten |= rewrittenSub != sub;
rewrittenDisjuncts.add(rewrittenSub);
}
@@ -242,7 +240,7 @@ public Query rewrite(IndexReader reader) throws IOException {
return new DisjunctionMaxQuery(rewrittenDisjuncts, tieBreakerMultiplier);
}
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java
index 303dc516af8a..f27b791dd95a 100644
--- a/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java
+++ b/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java
@@ -22,7 +22,6 @@
import java.util.function.DoubleToLongFunction;
import java.util.function.LongToDoubleFunction;
import org.apache.lucene.index.DocValues;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.search.comparators.DoubleComparator;
@@ -85,7 +84,7 @@ public Explanation explain(LeafReaderContext ctx, int docId, Explanation scoreEx
*
*
* <p>Queries that use DoubleValuesSource objects should call rewrite() during {@link
* Query#createWeight(IndexSearcher, ScoreMode, float)} rather than during {@link
- * Query#rewrite(IndexReader)} to avoid IndexReader reference leakage.
+ * Query#rewrite(IndexSearcher)} to avoid IndexReader reference leakage.
*
*
* <p>For the same reason, implementations that cache references to the IndexSearcher should
* return a new object from this method.
diff --git a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java
index 98c65b9ddd66..3d78df8e07c3 100644
--- a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java
@@ -108,7 +108,8 @@ public int hashCode() {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ IndexReader reader = indexSearcher.getIndexReader();
boolean allReadersRewritable = true;
for (LeafReaderContext context : reader.leaves()) {
@@ -172,7 +173,7 @@ public Query rewrite(IndexReader reader) throws IOException {
if (allReadersRewritable) {
return new MatchAllDocsQuery();
}
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java b/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java
index 9ba52eb674f4..599608f0d842 100644
--- a/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java
@@ -19,7 +19,6 @@
import java.io.IOException;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.document.SortedNumericDocValuesField;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
/**
@@ -101,9 +100,9 @@ public int hashCode() {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- Query indexRewrite = indexQuery.rewrite(reader);
- Query dvRewrite = dvQuery.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ Query indexRewrite = indexQuery.rewrite(indexSearcher);
+ Query dvRewrite = dvQuery.rewrite(indexSearcher);
if (indexRewrite.getClass() == MatchAllDocsQuery.class
|| dvRewrite.getClass() == MatchAllDocsQuery.class) {
return new MatchAllDocsQuery();
diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
index 5798ea5d2cff..f167b7161123 100644
--- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java
@@ -796,9 +796,9 @@ protected void search(List<LeafReaderContext> leaves, Weight weight, Collector c
*/
public Query rewrite(Query original) throws IOException {
Query query = original;
- for (Query rewrittenQuery = query.rewrite(reader);
+ for (Query rewrittenQuery = query.rewrite(this);
rewrittenQuery != query;
- rewrittenQuery = query.rewrite(reader)) {
+ rewrittenQuery = query.rewrite(this)) {
query = rewrittenQuery;
}
query.visit(getNumClausesCheckVisitor());
@@ -998,6 +998,10 @@ public Executor getExecutor() {
return executor;
}
+ SliceExecutor getSliceExecutor() {
+ return sliceExecutor;
+ }
+
/**
* Thrown when an attempt is made to add more than {@link #getMaxClauseCount()} clauses. This
* typically happens if a PrefixQuery, FuzzyQuery, WildcardQuery, or TermRangeQuery is expanded to
diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSortSortedNumericDocValuesRangeQuery.java b/lucene/core/src/java/org/apache/lucene/search/IndexSortSortedNumericDocValuesRangeQuery.java
index eb2cf562994c..b678d34192ce 100644
--- a/lucene/core/src/java/org/apache/lucene/search/IndexSortSortedNumericDocValuesRangeQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/IndexSortSortedNumericDocValuesRangeQuery.java
@@ -23,7 +23,6 @@
import org.apache.lucene.document.IntPoint;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.DocValues;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
@@ -133,12 +132,12 @@ public String toString(String field) {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (lowerValue == Long.MIN_VALUE && upperValue == Long.MAX_VALUE) {
return new FieldExistsQuery(field);
}
- Query rewrittenFallback = fallbackQuery.rewrite(reader);
+ Query rewrittenFallback = fallbackQuery.rewrite(indexSearcher);
if (rewrittenFallback.getClass() == MatchAllDocsQuery.class) {
return new MatchAllDocsQuery();
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
index 4ec617c24470..10345cd7adf4 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java
@@ -71,7 +71,7 @@ public KnnByteVectorQuery(String field, byte[] target, int k) {
*/
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
super(field, k, filter);
- this.target = target;
+ this.target = Objects.requireNonNull(target, "target");
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java
index 2b1b3a69582e..3036e7c45162 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java
@@ -18,6 +18,7 @@
import java.io.IOException;
import java.util.Arrays;
+import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
@@ -25,6 +26,7 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.VectorUtil;
/**
* Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest
@@ -70,7 +72,7 @@ public KnnFloatVectorQuery(String field, float[] target, int k) {
*/
public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) {
super(field, k, filter);
- this.target = target;
+ this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target"));
}
@Override
diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java b/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java
index 1bf34569c395..27819235f644 100644
--- a/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java
@@ -24,7 +24,6 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
@@ -185,7 +184,7 @@ public int[] getPositions() {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (termArrays.length == 0) {
return new MatchNoDocsQuery("empty MultiPhraseQuery");
} else if (termArrays.length == 1) { // optimize one-term case
@@ -196,7 +195,7 @@ public Query rewrite(IndexReader reader) throws IOException {
}
return builder.build();
} else {
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java b/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java
index 0dc6671dcccb..44bd4279aba2 100644
--- a/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java
@@ -18,9 +18,9 @@
import java.io.IOException;
import java.util.Objects;
-import org.apache.lucene.index.FilteredTermsEnum; // javadocs
+import org.apache.lucene.index.FilteredTermsEnum;
import org.apache.lucene.index.IndexReader;
-import org.apache.lucene.index.SingleTermsEnum; // javadocs
+import org.apache.lucene.index.SingleTermsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.Terms;
@@ -321,8 +321,8 @@ public long getTermsCount() throws IOException {
* AttributeSource)}. For example, to rewrite to a single term, return a {@link SingleTermsEnum}
*/
@Override
- public final Query rewrite(IndexReader reader) throws IOException {
- return rewriteMethod.rewrite(reader, this);
+ public final Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ return rewriteMethod.rewrite(indexSearcher.getIndexReader(), this);
}
public RewriteMethod getRewriteMethod() {
diff --git a/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java b/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java
index 41b0decba7c7..e5023bf032d6 100644
--- a/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java
@@ -18,14 +18,13 @@
import java.io.IOException;
import java.util.Objects;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
/**
* This is a {@link PhraseQuery} which is optimized for n-gram phrase query. For example, when you
* query "ABCD" on a 2-gram field, you may want to use NGramPhraseQuery rather than {@link
- * PhraseQuery}, because NGramPhraseQuery will {@link #rewrite(IndexReader)} the query to "AB/0
- * CD/2", while {@link PhraseQuery} will query "AB/0 BC/1 CD/2" (where term/position).
+ * PhraseQuery}, because NGramPhraseQuery will {@link Query#rewrite(IndexSearcher)} the query to
+ * "AB/0 CD/2", while {@link PhraseQuery} will query "AB/0 BC/1 CD/2" (where term/position).
*/
public class NGramPhraseQuery extends Query {
@@ -44,7 +43,7 @@ public NGramPhraseQuery(int n, PhraseQuery query) {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
final Term[] terms = phraseQuery.getTerms();
final int[] positions = phraseQuery.getPositions();
@@ -63,7 +62,7 @@ public Query rewrite(IndexReader reader) throws IOException {
}
if (isOptimizable == false) {
- return phraseQuery.rewrite(reader);
+ return phraseQuery.rewrite(indexSearcher);
}
PhraseQuery.Builder builder = new PhraseQuery.Builder();
diff --git a/lucene/core/src/java/org/apache/lucene/search/NamedMatches.java b/lucene/core/src/java/org/apache/lucene/search/NamedMatches.java
index 9a24f9433bae..d0ec5c3a2124 100644
--- a/lucene/core/src/java/org/apache/lucene/search/NamedMatches.java
+++ b/lucene/core/src/java/org/apache/lucene/search/NamedMatches.java
@@ -25,7 +25,6 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
/**
@@ -113,8 +112,8 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- Query rewritten = in.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ Query rewritten = in.rewrite(indexSearcher);
if (rewritten != in) {
return new NamedQuery(name, rewritten);
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java b/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java
index 93a2ace64531..643861651367 100644
--- a/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java
@@ -24,7 +24,6 @@
import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PostingsReader;
import org.apache.lucene.index.ImpactsEnum;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
@@ -284,7 +283,7 @@ public int[] getPositions() {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (terms.length == 0) {
return new MatchNoDocsQuery("empty PhraseQuery");
} else if (terms.length == 1) {
@@ -296,7 +295,7 @@ public Query rewrite(IndexReader reader) throws IOException {
}
return new PhraseQuery(slop, terms, newPositions);
} else {
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/Query.java b/lucene/core/src/java/org/apache/lucene/search/Query.java
index 4f04728395d9..c872c076b76f 100644
--- a/lucene/core/src/java/org/apache/lucene/search/Query.java
+++ b/lucene/core/src/java/org/apache/lucene/search/Query.java
@@ -17,7 +17,10 @@
package org.apache.lucene.search;
import java.io.IOException;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.util.VirtualMethod;
/**
* The abstract base class for queries.
@@ -41,10 +44,21 @@
*
*
*
* <p>See also additional queries available in the <a
+ * href="{@docRoot}/../queries/overview-summary.html">Queries module</a>.
*/
public abstract class Query {
+ private static final VirtualMethod<Query> oldMethod =
+ new VirtualMethod<>(Query.class, "rewrite", IndexReader.class);
+ private static final VirtualMethod<Query> newMethod =
+ new VirtualMethod<>(Query.class, "rewrite", IndexSearcher.class);
+ private final boolean isDeprecatedRewriteMethodOverridden =
+ AccessController.doPrivileged(
+ (PrivilegedAction<Boolean>)
+ () ->
+ VirtualMethod.compareImplementationDistance(this.getClass(), oldMethod, newMethod)
+ > 0);
+
/**
* Prints a query to a string, with field assumed to be the default field and
* omitted.
@@ -77,14 +91,36 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
* <p>Callers are expected to call rewrite multiple times if necessary, until the
* rewritten query is the same as the original query.
*
+ * @deprecated Use {@link Query#rewrite(IndexSearcher)}
* @see IndexSearcher#rewrite(Query)
*/
+ @Deprecated
public Query rewrite(IndexReader reader) throws IOException {
- return this;
+ return isDeprecatedRewriteMethodOverridden ? this : rewrite(new IndexSearcher(reader));
+ }
+
+ /**
+ * Expert: called to re-write queries into primitive queries. For example, a PrefixQuery will be
+ * rewritten into a BooleanQuery that consists of TermQuerys.
+ *
+ * <p>Callers are expected to call rewrite multiple times if necessary, until the
+ * rewritten query is the same as the original query.
+ *
+ * <p>The rewrite process may be able to make use of IndexSearcher's executor and be executed in
+ * parallel if the executor is provided.
+ *
+ * <p>However, if any of the intermediary queries do not satisfy the new API, parallel rewrite is
+ * not possible for any subsequent sub-queries. To take advantage of this API, the entire query
+ * tree must override this method.
+ *
+ * @see IndexSearcher#rewrite(Query)
+ */
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ return isDeprecatedRewriteMethodOverridden ? rewrite(indexSearcher.getIndexReader()) : this;
}
/**
- * Recurse through the query tree, visiting any child queries
+ * Recurse through the query tree, visiting any child queries.
*
* @param visitor a QueryVisitor to be called by each query in the tree
*/
@@ -95,8 +131,8 @@ public Query rewrite(IndexReader reader) throws IOException {
* that {@link QueryCache} works properly.
*
* <p>Typically a query will be equal to another only if it's an instance of the same class and
- * its document-filtering properties are identical that other instance. Utility methods are
- * provided for certain repetitive code.
+ * its document-filtering properties are identical to those of the other instance. Utility methods
+ * are provided for certain repetitive code.
*
* @see #sameClassAs(Object)
* @see #classHash()
@@ -119,7 +155,7 @@ public Query rewrite(IndexReader reader) throws IOException {
*
* <p>When this method is used in an implementation of {@link #equals(Object)}, consider using
* {@link #classHash()} in the implementation of {@link #hashCode} to differentiate different
- * class
+ * class.
*/
protected final boolean sameClassAs(Object other) {
return other != null && getClass() == other.getClass();
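A minimal sketch (not part of the patch; WrappingQuery is hypothetical) of how a delegating query outside this codebase can adopt the new entry point: override only rewrite(IndexSearcher), so the VirtualMethod shim never falls back to the deprecated path and parallel rewrite stays possible down the query tree.

import java.io.IOException;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;

class WrappingQuery extends Query {
  private final Query in;

  WrappingQuery(Query in) {
    this.in = in;
  }

  @Override
  public Query rewrite(IndexSearcher indexSearcher) throws IOException {
    // delegate; only re-wrap when the child actually changed
    Query rewritten = in.rewrite(indexSearcher);
    return rewritten == in ? this : new WrappingQuery(rewritten);
  }

  @Override
  public String toString(String field) {
    return "wrapping(" + in.toString(field) + ")";
  }

  @Override
  public void visit(QueryVisitor visitor) {
    in.visit(visitor);
  }

  @Override
  public boolean equals(Object other) {
    return sameClassAs(other) && in.equals(((WrappingQuery) other).in);
  }

  @Override
  public int hashCode() {
    return 31 * classHash() + in.hashCode();
  }
}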
diff --git a/lucene/core/src/java/org/apache/lucene/search/QueueSizeBasedExecutor.java b/lucene/core/src/java/org/apache/lucene/search/QueueSizeBasedExecutor.java
index a76b81f5da19..65ba1ea5573d 100644
--- a/lucene/core/src/java/org/apache/lucene/search/QueueSizeBasedExecutor.java
+++ b/lucene/core/src/java/org/apache/lucene/search/QueueSizeBasedExecutor.java
@@ -17,7 +17,6 @@
package org.apache.lucene.search;
-import java.util.Collection;
import java.util.concurrent.ThreadPoolExecutor;
/**
@@ -30,31 +29,15 @@ class QueueSizeBasedExecutor extends SliceExecutor {
private final ThreadPoolExecutor threadPoolExecutor;
- public QueueSizeBasedExecutor(ThreadPoolExecutor threadPoolExecutor) {
+ QueueSizeBasedExecutor(ThreadPoolExecutor threadPoolExecutor) {
super(threadPoolExecutor);
this.threadPoolExecutor = threadPoolExecutor;
}
@Override
- public void invokeAll(Collection<? extends Runnable> tasks) {
- int i = 0;
-
- for (Runnable task : tasks) {
- boolean shouldExecuteOnCallerThread = false;
-
- // Execute last task on caller thread
- if (i == tasks.size() - 1) {
- shouldExecuteOnCallerThread = true;
- }
-
- if (threadPoolExecutor.getQueue().size()
- >= (threadPoolExecutor.getMaximumPoolSize() * LIMITING_FACTOR)) {
- shouldExecuteOnCallerThread = true;
- }
-
- processTask(task, shouldExecuteOnCallerThread);
-
- ++i;
- }
+ boolean shouldExecuteOnCallerThread(int index, int numTasks) {
+ return super.shouldExecuteOnCallerThread(index, numTasks)
+ || threadPoolExecutor.getQueue().size()
+ >= (threadPoolExecutor.getMaximumPoolSize() * LIMITING_FACTOR);
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java b/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java
index debc3efae154..5873b83075b6 100644
--- a/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java
+++ b/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java
@@ -101,8 +101,7 @@ public Query rewrite(IndexReader reader, MultiTermQuery query) throws IOExceptio
protected abstract void checkMaxClauseCount(int count) throws IOException;
@Override
- public final Query rewrite(final IndexReader reader, final MultiTermQuery query)
- throws IOException {
+ public final Query rewrite(IndexReader reader, final MultiTermQuery query) throws IOException {
final B builder = getTopLevelBuilder();
final ParallelArraysTermCollector col = new ParallelArraysTermCollector();
collectTerms(reader, query, col);
diff --git a/lucene/core/src/java/org/apache/lucene/search/SliceExecutor.java b/lucene/core/src/java/org/apache/lucene/search/SliceExecutor.java
index c84beeb5fb78..0e593740914d 100644
--- a/lucene/core/src/java/org/apache/lucene/search/SliceExecutor.java
+++ b/lucene/core/src/java/org/apache/lucene/search/SliceExecutor.java
@@ -18,6 +18,7 @@
package org.apache.lucene.search;
import java.util.Collection;
+import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
@@ -28,54 +29,30 @@
class SliceExecutor {
private final Executor executor;
- public SliceExecutor(Executor executor) {
- this.executor = executor;
+ SliceExecutor(Executor executor) {
+ this.executor = Objects.requireNonNull(executor, "Executor is null");
}
- public void invokeAll(Collection<? extends Runnable> tasks) {
-
- if (tasks == null) {
- throw new IllegalArgumentException("Tasks is null");
- }
-
- if (executor == null) {
- throw new IllegalArgumentException("Executor is null");
- }
-
+ final void invokeAll(Collection<? extends Runnable> tasks) {
int i = 0;
-
for (Runnable task : tasks) {
- boolean shouldExecuteOnCallerThread = false;
-
- // Execute last task on caller thread
- if (i == tasks.size() - 1) {
- shouldExecuteOnCallerThread = true;
+ if (shouldExecuteOnCallerThread(i, tasks.size())) {
+ task.run();
+ } else {
+ try {
+ executor.execute(task);
+ } catch (
+ @SuppressWarnings("unused")
+ RejectedExecutionException e) {
+ task.run();
+ }
}
-
- processTask(task, shouldExecuteOnCallerThread);
++i;
}
- ;
}
- // Helper method to execute a single task
- protected void processTask(final Runnable task, final boolean shouldExecuteOnCallerThread) {
- if (task == null) {
- throw new IllegalArgumentException("Input is null");
- }
-
- if (!shouldExecuteOnCallerThread) {
- try {
- executor.execute(task);
-
- return;
- } catch (
- @SuppressWarnings("unused")
- RejectedExecutionException e) {
- // Execute on caller thread
- }
- }
-
- task.run();
+ boolean shouldExecuteOnCallerThread(int index, int numTasks) {
+ // Execute last task on caller thread
+ return index == numTasks - 1;
}
}
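The refactor turns invokeAll into a final template method with shouldExecuteOnCallerThread as the single extension point. A hypothetical subclass (same package, since the class is package-private) showing how a custom back-pressure policy now composes with the default last-task rule:

package org.apache.lucene.search;

import java.util.concurrent.ThreadPoolExecutor;

class SaturationAwareSliceExecutor extends SliceExecutor {
  private final ThreadPoolExecutor pool;

  SaturationAwareSliceExecutor(ThreadPoolExecutor pool) {
    super(pool);
    this.pool = pool;
  }

  @Override
  boolean shouldExecuteOnCallerThread(int index, int numTasks) {
    // keep the default "last task runs on the caller thread", add a saturation check
    return super.shouldExecuteOnCallerThread(index, numTasks)
        || pool.getActiveCount() >= pool.getMaximumPoolSize();
  }
}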
diff --git a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java
index 8131c879d3e9..7e314870f9c7 100644
--- a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java
@@ -29,7 +29,6 @@
import org.apache.lucene.index.Impacts;
import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.ImpactsSource;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SlowImpactsEnum;
@@ -114,7 +113,7 @@ public SynonymQuery build() {
*/
private SynonymQuery(TermAndBoost[] terms, String field) {
this.terms = Objects.requireNonNull(terms);
- this.field = field;
+ this.field = Objects.requireNonNull(field);
}
public List<Term> getTerms() {
@@ -147,11 +146,13 @@ public int hashCode() {
@Override
public boolean equals(Object other) {
- return sameClassAs(other) && Arrays.equals(terms, ((SynonymQuery) other).terms);
+ return sameClassAs(other)
+ && field.equals(((SynonymQuery) other).field)
+ && Arrays.equals(terms, ((SynonymQuery) other).terms);
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
// optimize zero and non-boosted single term cases
if (terms.length == 0) {
return new BooleanQuery.Builder().build();
diff --git a/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java b/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java
index 2e0a297c7463..ed268751bba4 100644
--- a/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java
@@ -17,11 +17,10 @@
package org.apache.lucene.search;
import java.io.IOException;
-import java.util.ArrayList;
+import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
-import java.util.List;
import java.util.SortedSet;
import org.apache.lucene.index.FilteredTermsEnum;
import org.apache.lucene.index.PrefixCodedTerms;
@@ -38,7 +37,6 @@
import org.apache.lucene.util.automaton.Automata;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.ByteRunAutomaton;
-import org.apache.lucene.util.automaton.CompiledAutomaton;
import org.apache.lucene.util.automaton.Operations;
/**
@@ -150,13 +148,17 @@ public void visit(QueryVisitor visitor) {
}
}
+ // TODO: This is pretty heavy-weight. If we have TermInSetQuery directly extend AutomatonQuery
+ // we won't have to do this (see GH#12176).
private ByteRunAutomaton asByteRunAutomaton() {
- TermIterator iterator = termData.iterator();
- List<Automaton> automata = new ArrayList<>();
- for (BytesRef term = iterator.next(); term != null; term = iterator.next()) {
- automata.add(Automata.makeBinary(term));
+ try {
+ Automaton a = Automata.makeBinaryStringUnion(termData.iterator());
+ return new ByteRunAutomaton(a, true, Operations.DEFAULT_DETERMINIZE_WORK_LIMIT);
+ } catch (IOException e) {
+ // Shouldn't happen since termData.iterator() provides an iterator implementation that
+ // never throws:
+ throw new UncheckedIOException(e);
}
- return new CompiledAutomaton(Operations.union(automata)).runAutomaton;
}
@Override
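A sketch of the one-pass construction that replaces the per-term union, assuming a sorted term set (binary string unions require sorted UTF-8 input):

import java.util.List;
import java.util.SortedSet;
import java.util.TreeSet;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.Automata;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.ByteRunAutomaton;
import org.apache.lucene.util.automaton.Operations;

class BinaryUnionSketch {
  static ByteRunAutomaton fromTerms() {
    // TreeSet keeps the terms in the sorted order makeBinaryStringUnion requires
    SortedSet<BytesRef> sorted =
        new TreeSet<>(List.of(new BytesRef("apache"), new BytesRef("lucene")));
    Automaton a = Automata.makeBinaryStringUnion(sorted);
    return new ByteRunAutomaton(a, true, Operations.DEFAULT_DETERMINIZE_WORK_LIMIT);
  }
}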
diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingBulkScorer.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingBulkScorer.java
index 5e33884ea4ff..517f0a0e77b2 100644
--- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingBulkScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingBulkScorer.java
@@ -41,6 +41,12 @@ static class TimeExceededException extends RuntimeException {
private TimeExceededException() {
super("TimeLimit Exceeded");
}
+
+ @Override
+ public Throwable fillInStackTrace() {
+ // never re-thrown so we can save the expensive stacktrace
+ return this;
+ }
}
private final BulkScorer in;
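The fillInStackTrace override is a general trick for exceptions used purely as control flow; a hypothetical standalone version of the same pattern:

// Always caught internally and never re-thrown, so capturing a stack trace
// at construction time would be wasted work.
class EarlyTerminationException extends RuntimeException {
  EarlyTerminationException() {
    super("early termination");
  }

  @Override
  public Throwable fillInStackTrace() {
    return this; // skip the expensive stack walk
  }
}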
diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java
index 4a208a3f0c5f..c50f3a372e97 100644
--- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java
+++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java
@@ -33,9 +33,9 @@ public class TimeLimitingCollector implements Collector {
/** Thrown when elapsed search time exceeds allowed search time. */
@SuppressWarnings("serial")
public static class TimeExceededException extends RuntimeException {
- private long timeAllowed;
- private long timeElapsed;
- private int lastDocCollected;
+ private final long timeAllowed;
+ private final long timeElapsed;
+ private final int lastDocCollected;
private TimeExceededException(long timeAllowed, long timeElapsed, int lastDocCollected) {
super(
diff --git a/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java b/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java
index 067867ca7348..b5c52ba7cedf 100644
--- a/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java
+++ b/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java
@@ -61,8 +61,7 @@ public int getSize() {
protected abstract int getMaxSize();
@Override
- public final Query rewrite(final IndexReader reader, final MultiTermQuery query)
- throws IOException {
+ public final Query rewrite(IndexReader reader, final MultiTermQuery query) throws IOException {
final int maxSize = Math.min(size, getMaxSize());
final PriorityQueue<ScoreTerm> stQueue = new PriorityQueue<>();
collectTerms(
diff --git a/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java b/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java
index 0ac859d40f14..ea75530fb2b4 100644
--- a/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java
+++ b/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java
@@ -91,8 +91,8 @@ public abstract class NumericLeafComparator implements LeafFieldComparator {
// if skipping functionality should be enabled on this segment
private final boolean enableSkipping;
private final int maxDoc;
- private final byte[] minValueAsBytes;
- private final byte[] maxValueAsBytes;
+ private byte[] minValueAsBytes;
+ private byte[] maxValueAsBytes;
private DocIdSetIterator competitiveIterator;
private long iteratorCost = -1;
@@ -128,16 +128,10 @@ public NumericLeafComparator(LeafReaderContext context) throws IOException {
}
this.enableSkipping = true; // skipping is enabled when points are available
this.maxDoc = context.reader().maxDoc();
- this.maxValueAsBytes =
- reverse == false ? new byte[bytesCount] : topValueSet ? new byte[bytesCount] : null;
- this.minValueAsBytes =
- reverse ? new byte[bytesCount] : topValueSet ? new byte[bytesCount] : null;
this.competitiveIterator = DocIdSetIterator.all(maxDoc);
} else {
this.enableSkipping = false;
this.maxDoc = 0;
- this.maxValueAsBytes = null;
- this.minValueAsBytes = null;
}
}
@@ -191,7 +185,9 @@ public void setHitsThresholdReached() throws IOException {
// update its iterator to include possibly only docs that are "stronger" than the current bottom
// entry
private void updateCompetitiveIterator() throws IOException {
- if (enableSkipping == false || hitsThresholdReached == false || queueFull == false) return;
+ if (enableSkipping == false
+ || hitsThresholdReached == false
+ || (queueFull == false && topValueSet == false)) return;
// if some documents have missing points, check that missing values prohibits optimization
if ((pointValues.getDocCount() < maxDoc) && isMissingValueCompetitive()) {
return; // we can't filter out documents, as documents with missing values are competitive
@@ -204,13 +200,21 @@ private void updateCompetitiveIterator() throws IOException {
return;
}
if (reverse == false) {
- encodeBottom(maxValueAsBytes);
+ if (queueFull) { // bottom is available only when queue is full
+ maxValueAsBytes = maxValueAsBytes == null ? new byte[bytesCount] : maxValueAsBytes;
+ encodeBottom(maxValueAsBytes);
+ }
if (topValueSet) {
+ minValueAsBytes = minValueAsBytes == null ? new byte[bytesCount] : minValueAsBytes;
encodeTop(minValueAsBytes);
}
} else {
- encodeBottom(minValueAsBytes);
+ if (queueFull) { // bottom is available only when queue is full
+ minValueAsBytes = minValueAsBytes == null ? new byte[bytesCount] : minValueAsBytes;
+ encodeBottom(minValueAsBytes);
+ }
if (topValueSet) {
+ maxValueAsBytes = maxValueAsBytes == null ? new byte[bytesCount] : maxValueAsBytes;
encodeTop(maxValueAsBytes);
}
}
diff --git a/lucene/core/src/java/org/apache/lucene/search/package-info.java b/lucene/core/src/java/org/apache/lucene/search/package-info.java
index 4f606597d0a3..31a9d4018c8f 100644
--- a/lucene/core/src/java/org/apache/lucene/search/package-info.java
+++ b/lucene/core/src/java/org/apache/lucene/search/package-info.java
@@ -357,11 +357,11 @@
* each Query implementation must provide an implementation of Weight. See the subsection on
* The Weight Interface below for details on implementing the
* Weight interface.
- * <li>{@link org.apache.lucene.search.Query#rewrite(org.apache.lucene.index.IndexReader)
- * rewrite(IndexReader reader)} — Rewrites queries into primitive queries. Primitive
- * queries are: {@link org.apache.lucene.search.TermQuery TermQuery}, {@link
- * org.apache.lucene.search.BooleanQuery BooleanQuery}, and other queries that
- * implement {@link org.apache.lucene.search.Query#createWeight(IndexSearcher,ScoreMode,float)
+ * <li>{@link org.apache.lucene.search.Query#rewrite(IndexSearcher) rewrite(IndexSearcher
+ * searcher)} — Rewrites queries into primitive queries. Primitive queries are: {@link
+ * org.apache.lucene.search.TermQuery TermQuery}, {@link org.apache.lucene.search.BooleanQuery
+ * BooleanQuery}, and other queries that implement {@link
+ * org.apache.lucene.search.Query#createWeight(IndexSearcher,ScoreMode,float)
* createWeight(IndexSearcher searcher,ScoreMode scoreMode, float boost)}
*
*
diff --git a/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java b/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java
index 974ad0c68747..442dbd7af422 100644
--- a/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java
+++ b/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java
@@ -43,7 +43,7 @@ public abstract class BufferedIndexInput extends IndexInput implements RandomAcc
/** A buffer size for merges set to {@value #MERGE_BUFFER_SIZE}. */
public static final int MERGE_BUFFER_SIZE = 4096;
- private int bufferSize = BUFFER_SIZE;
+ private final int bufferSize;
private ByteBuffer buffer = EMPTY_BYTEBUFFER;
@@ -72,7 +72,7 @@ public BufferedIndexInput(String resourceDesc, int bufferSize) {
this.bufferSize = bufferSize;
}
- /** Returns buffer size. @see #setBufferSize */
+ /** Returns buffer size. */
public final int getBufferSize() {
return bufferSize;
}
@@ -220,55 +220,50 @@ public final long readVLong() throws IOException {
}
}
- @Override
- public final byte readByte(long pos) throws IOException {
+ private long resolvePositionInBuffer(long pos, int width) throws IOException {
long index = pos - bufferStart;
- if (index < 0 || index >= buffer.limit()) {
+ if (index >= 0 && index <= buffer.limit() - width) {
+ return index;
+ }
+ if (index < 0) {
+ // if we're moving backwards, then try and fill up the previous page rather than
+ // starting again at the current pos, to avoid successive backwards reads reloading
+ // the same data over and over again. We also check that we can read `width`
+ // bytes without going over the end of the buffer
+ bufferStart = Math.max(bufferStart - bufferSize, pos + width - bufferSize);
+ bufferStart = Math.max(bufferStart, 0);
+ bufferStart = Math.min(bufferStart, pos);
+ } else {
+ // we're moving forwards, reset the buffer to start at pos
bufferStart = pos;
- buffer.limit(0); // trigger refill() on read
- seekInternal(pos);
- refill();
- index = 0;
}
+ buffer.limit(0); // trigger refill() on read
+ seekInternal(bufferStart);
+ refill();
+ return pos - bufferStart;
+ }
+
+ @Override
+ public final byte readByte(long pos) throws IOException {
+ long index = resolvePositionInBuffer(pos, Byte.BYTES);
return buffer.get((int) index);
}
@Override
public final short readShort(long pos) throws IOException {
- long index = pos - bufferStart;
- if (index < 0 || index >= buffer.limit() - 1) {
- bufferStart = pos;
- buffer.limit(0); // trigger refill() on read
- seekInternal(pos);
- refill();
- index = 0;
- }
+ long index = resolvePositionInBuffer(pos, Short.BYTES);
return buffer.getShort((int) index);
}
@Override
public final int readInt(long pos) throws IOException {
- long index = pos - bufferStart;
- if (index < 0 || index >= buffer.limit() - 3) {
- bufferStart = pos;
- buffer.limit(0); // trigger refill() on read
- seekInternal(pos);
- refill();
- index = 0;
- }
+ long index = resolvePositionInBuffer(pos, Integer.BYTES);
return buffer.getInt((int) index);
}
@Override
public final long readLong(long pos) throws IOException {
- long index = pos - bufferStart;
- if (index < 0 || index >= buffer.limit() - 7) {
- bufferStart = pos;
- buffer.limit(0); // trigger refill() on read
- seekInternal(pos);
- refill();
- index = 0;
- }
+ long index = resolvePositionInBuffer(pos, Long.BYTES);
return buffer.getLong((int) index);
}
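The consolidated helper serves RandomAccessInput-style absolute reads; a usage sketch (the path and file name are placeholders) where the second read goes backwards and now refills the previous buffer page instead of restarting at pos:

import java.io.IOException;
import java.nio.file.Path;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;

class RandomAccessReadSketch {
  static long readTail(Path dir, String file) throws IOException {
    try (Directory d = FSDirectory.open(dir);
        IndexInput in = d.openInput(file, IOContext.READ)) {
      RandomAccessInput ra = in.randomAccessSlice(0, in.length());
      long last = ra.readLong(in.length() - Long.BYTES);
      long beforeLast = ra.readLong(in.length() - 2L * Long.BYTES); // backwards read
      return last ^ beforeLast;
    }
  }
}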
diff --git a/lucene/core/src/java/org/apache/lucene/store/ByteBufferGuard.java b/lucene/core/src/java/org/apache/lucene/store/ByteBufferGuard.java
index 2d75597f9deb..a9e65ffa06c2 100644
--- a/lucene/core/src/java/org/apache/lucene/store/ByteBufferGuard.java
+++ b/lucene/core/src/java/org/apache/lucene/store/ByteBufferGuard.java
@@ -17,11 +17,11 @@
package org.apache.lucene.store;
import java.io.IOException;
+import java.lang.invoke.VarHandle;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
-import java.util.concurrent.atomic.AtomicInteger;
/**
* A guard that is created for every {@link ByteBufferIndexInput} that tries on best effort to
@@ -49,9 +49,6 @@ static interface BufferCleaner {
/** Not volatile; see comments on visibility below! */
private boolean invalidated = false;
- /** Used as a store-store barrier; see comments below! */
- private final AtomicInteger barrier = new AtomicInteger();
-
/**
* Creates an instance to be used for a single {@link ByteBufferIndexInput} which must be shared
* by all of its clones.
@@ -69,10 +66,9 @@ public void invalidateAndUnmap(ByteBuffer... bufs) throws IOException {
// the "invalidated" field update visible to other threads. We specifically
// don't make "invalidated" field volatile for performance reasons, hoping the
// JVM won't optimize away reads of that field and hardware should ensure
- // caches are in sync after this call. This isn't entirely "fool-proof"
- // (see LUCENE-7409 discussion), but it has been shown to work in practice
- // and we count on this behavior.
- barrier.lazySet(0);
+ // caches are in sync after this call.
+ // For previous implementation (based on `AtomicInteger#lazySet(0)`) see LUCENE-7409.
+ VarHandle.fullFence();
// we give other threads a bit of time to finish reads on their ByteBuffer...:
Thread.yield();
// finally unmap the ByteBuffers:
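A distilled version of the publication idiom this class relies on, with the fence swapped in for the removed AtomicInteger barrier:

import java.lang.invoke.VarHandle;

class InvalidationFlag {
  private boolean invalidated; // intentionally not volatile: read on every access

  void invalidate() {
    invalidated = true;
    VarHandle.fullFence(); // replaces the AtomicInteger#lazySet store-store barrier
  }

  boolean isInvalidated() {
    return invalidated; // best-effort visibility by design
  }
}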
diff --git a/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java b/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java
index 5d23fb2f1ae0..30acb7023f03 100644
--- a/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java
+++ b/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java
@@ -76,9 +76,9 @@
*
- * <p>On exactly Java 19 this class will use the modern {@code MemorySegment} API which
- * allows to safely unmap (if you discover any problems with this preview API, you can disable it by
- * using system property {@link #ENABLE_MEMORY_SEGMENTS_SYSPROP}).
+ * <p>On exactly Java 19 / 20 / 21 this class will use the modern {@code MemorySegment} API
+ * which allows to safely unmap (if you discover any problems with this preview API, you can disable
+ * it by using system property {@link #ENABLE_MEMORY_SEGMENTS_SYSPROP}).
*
* <p><b>NOTE:</b> Accessing this class either directly or indirectly from a thread while it's
* interrupted can close the underlying channel immediately if at the same time the thread is
@@ -123,7 +123,7 @@ public class MMapDirectory extends FSDirectory {
* Default max chunk size:
*
* <ul>
- * <li>16 GiBytes for 64 bit Java 19 JVMs
+ * <li>16 GiBytes for 64 bit Java 19 / 20 / 21 JVMs
* <li>1 GiBytes for other 64 bit JVMs
* <li>256 MiBytes for 32 bit JVMs
* </ul>
@@ -220,9 +220,9 @@ public MMapDirectory(Path path, LockFactory lockFactory, int maxChunkSize) throw
* files cannot be mapped. Using a lower chunk size makes the directory implementation a little
* bit slower (as the correct chunk may be resolved on lots of seeks) but the chance is higher
* that mmap does not fail. On 64 bit Java platforms, this parameter should always be large (like
- * 1 GiBytes, or even larger with Java 19), as the address space is big enough. If it is larger,
- * fragmentation of address space increases, but number of file handles and mappings is lower for
- * huge installations with many open indexes.
+ * 1 GiBytes, or even larger with recent Java versions), as the address space is big enough. If it
+ * is larger, fragmentation of address space increases, but number of file handles and mappings is
+ * lower for huge installations with many open indexes.
*
* <p><b>Please note:</b> The chunk size is always rounded down to a power of 2.
*
@@ -417,7 +417,7 @@ private static MMapIndexInputProvider lookupProvider() {
}
final var lookup = MethodHandles.lookup();
final int runtimeVersion = Runtime.version().feature();
- if (runtimeVersion == 19 || runtimeVersion == 20) {
+ if (runtimeVersion >= 19 && runtimeVersion <= 21) {
try {
final var cls = lookup.findClass("org.apache.lucene.store.MemorySegmentIndexInputProvider");
// we use method handles, so we do not need to deal with setAccessible as we have private
@@ -437,9 +437,9 @@ private static MMapIndexInputProvider lookupProvider() {
throw new LinkageError(
"MemorySegmentIndexInputProvider is missing in Lucene JAR file", cnfe);
}
- } else if (runtimeVersion >= 21) {
+ } else if (runtimeVersion >= 22) {
LOG.warning(
- "You are running with Java 21 or later. To make full use of MMapDirectory, please update Apache Lucene.");
+ "You are running with Java 22 or later. To make full use of MMapDirectory, please update Apache Lucene.");
}
return new MappedByteBufferIndexInputProvider();
}
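A condensed sketch of the MR-JAR gating above; it assumes the provider keeps a no-arg constructor and is only meant to show the shape of the MethodHandles-based lookup:

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

class ProviderLookupSketch {
  static Object lookupProvider() throws Throwable {
    int feature = Runtime.version().feature();
    if (feature >= 19 && feature <= 21) {
      // class only present for these runtimes in the MR-JAR section
      var lookup = MethodHandles.lookup();
      var cls = lookup.findClass("org.apache.lucene.store.MemorySegmentIndexInputProvider");
      return lookup.findConstructor(cls, MethodType.methodType(void.class)).invoke();
    }
    return null; // caller falls back to MappedByteBufferIndexInputProvider
  }
}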
diff --git a/lucene/core/src/java/org/apache/lucene/util/BitSet.java b/lucene/core/src/java/org/apache/lucene/util/BitSet.java
index f8b8ba65a591..c5d84833b28a 100644
--- a/lucene/core/src/java/org/apache/lucene/util/BitSet.java
+++ b/lucene/core/src/java/org/apache/lucene/util/BitSet.java
@@ -43,6 +43,16 @@ public static BitSet of(DocIdSetIterator it, int maxDoc) throws IOException {
return set;
}
+ /**
+ * Clear all the bits of the set.
+ *
+ * <p>Depending on the implementation, this may be significantly faster than clear(0, length).
+ */
+ public void clear() {
+ // default implementation for compatibility
+ clear(0, length());
+ }
+
/** Set the bit at i. */
public abstract void set(int i);
diff --git a/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java b/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java
index 5566bd0c4833..ebf626a777da 100644
--- a/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java
+++ b/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java
@@ -147,6 +147,11 @@ public FixedBitSet(long[] storedBits, int numBits) {
assert verifyGhostBitsClear();
}
+ @Override
+ public void clear() {
+ Arrays.fill(bits, 0L);
+ }
+
/**
* Checks if the bits past numBits are clear. Some methods rely on this implicit assumption:
* search for "Depends on the ghost bits being clear!"
diff --git a/lucene/core/src/java/org/apache/lucene/util/SparseFixedBitSet.java b/lucene/core/src/java/org/apache/lucene/util/SparseFixedBitSet.java
index 49d61614e86d..b4ebe3cfc59a 100644
--- a/lucene/core/src/java/org/apache/lucene/util/SparseFixedBitSet.java
+++ b/lucene/core/src/java/org/apache/lucene/util/SparseFixedBitSet.java
@@ -17,6 +17,7 @@
package org.apache.lucene.util;
import java.io.IOException;
+import java.util.Arrays;
import org.apache.lucene.search.DocIdSetIterator;
/**
@@ -73,6 +74,17 @@ public SparseFixedBitSet(int length) {
+ RamUsageEstimator.shallowSizeOf(bits);
}
+ @Override
+ public void clear() {
+ Arrays.fill(bits, null);
+ Arrays.fill(indices, 0L);
+ nonZeroLongCount = 0;
+ ramBytesUsed =
+ BASE_RAM_BYTES_USED
+ + RamUsageEstimator.sizeOf(indices)
+ + RamUsageEstimator.shallowSizeOf(bits);
+ }
+
@Override
public int length() {
return length;
diff --git a/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java b/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java
new file mode 100644
index 000000000000..1ade19a19803
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util;
+
+import java.util.Locale;
+
+/**
+ * Word2Vec unit composed of a term and its associated vector
+ *
+ * @lucene.experimental
+ */
+public class TermAndVector {
+
+ private final BytesRef term;
+ private final float[] vector;
+
+ public TermAndVector(BytesRef term, float[] vector) {
+ this.term = term;
+ this.vector = vector;
+ }
+
+ public BytesRef getTerm() {
+ return this.term;
+ }
+
+ public float[] getVector() {
+ return this.vector;
+ }
+
+ public int size() {
+ return vector.length;
+ }
+
+ public void normalizeVector() {
+ float vectorLength = 0;
+ for (int i = 0; i < vector.length; i++) {
+ vectorLength += vector[i] * vector[i];
+ }
+ vectorLength = (float) Math.sqrt(vectorLength);
+ for (int i = 0; i < vector.length; i++) {
+ vector[i] /= vectorLength;
+ }
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder builder = new StringBuilder(this.term.utf8ToString());
+ builder.append(" [");
+ if (vector.length > 0) {
+ for (int i = 0; i < vector.length - 1; i++) {
+ builder.append(String.format(Locale.ROOT, "%.3f,", vector[i]));
+ }
+ builder.append(String.format(Locale.ROOT, "%.3f]", vector[vector.length - 1]));
+ }
+ return builder.toString();
+ }
+}
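A small usage sketch for the new holder; the numbers are illustrative:

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TermAndVector;

class TermAndVectorSketch {
  public static void main(String[] args) {
    TermAndVector tv = new TermAndVector(new BytesRef("lucene"), new float[] {3f, 4f});
    tv.normalizeVector(); // rescales in place to unit length: [0.6, 0.8]
    System.out.println(tv); // lucene [0.600,0.800]
  }
}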
diff --git a/lucene/core/src/java/org/apache/lucene/util/UnicodeUtil.java b/lucene/core/src/java/org/apache/lucene/util/UnicodeUtil.java
index a9e9d815eb49..9cbd4910271b 100644
--- a/lucene/core/src/java/org/apache/lucene/util/UnicodeUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/UnicodeUtil.java
@@ -477,40 +477,67 @@ public static int UTF8toUTF32(final BytesRef utf8, final int[] ints) {
int utf8Upto = utf8.offset;
final byte[] bytes = utf8.bytes;
final int utf8Limit = utf8.offset + utf8.length;
+ UTF8CodePoint reuse = null;
while (utf8Upto < utf8Limit) {
- final int numBytes = utf8CodeLength[bytes[utf8Upto] & 0xFF];
- int v = 0;
- switch (numBytes) {
- case 1:
- ints[utf32Count++] = bytes[utf8Upto++];
- continue;
- case 2:
- // 5 useful bits
- v = bytes[utf8Upto++] & 31;
- break;
- case 3:
- // 4 useful bits
- v = bytes[utf8Upto++] & 15;
- break;
- case 4:
- // 3 useful bits
- v = bytes[utf8Upto++] & 7;
- break;
- default:
- throw new IllegalArgumentException("invalid utf8");
- }
-
- // TODO: this may read past utf8's limit.
- final int limit = utf8Upto + numBytes - 1;
- while (utf8Upto < limit) {
- v = v << 6 | bytes[utf8Upto++] & 63;
- }
- ints[utf32Count++] = v;
+ reuse = codePointAt(bytes, utf8Upto, reuse);
+ ints[utf32Count++] = reuse.codePoint;
+ utf8Upto += reuse.numBytes;
}
return utf32Count;
}
+ /**
+ * Computes the codepoint and codepoint length (in bytes) at the specified {@code offset} in the
+ * provided {@code utf8} byte array, assuming UTF8 encoding. As with other related methods in this
+ * class, this assumes valid UTF8 input and does not perform full UTF8
+ * validation. Passing invalid UTF8 or a position that is not a valid header byte position may
+ * result in undefined behavior. This makes no attempt to synchronize or validate.
+ */
+ public static UTF8CodePoint codePointAt(byte[] utf8, int pos, UTF8CodePoint reuse) {
+ if (reuse == null) {
+ reuse = new UTF8CodePoint();
+ }
+
+ int leadByte = utf8[pos] & 0xFF;
+ int numBytes = utf8CodeLength[leadByte];
+ reuse.numBytes = numBytes;
+ int v;
+ switch (numBytes) {
+ case 1:
+ reuse.codePoint = leadByte;
+ return reuse;
+ case 2:
+ v = leadByte & 31; // 5 useful bits
+ break;
+ case 3:
+ v = leadByte & 15; // 4 useful bits
+ break;
+ case 4:
+ v = leadByte & 7; // 3 useful bits
+ break;
+ default:
+ throw new IllegalArgumentException(
+ "Invalid UTF8 header byte: 0x" + Integer.toHexString(leadByte));
+ }
+
+ // TODO: this may read past utf8's limit.
+ final int limit = pos + numBytes;
+ pos++;
+ while (pos < limit) {
+ v = v << 6 | utf8[pos++] & 63;
+ }
+ reuse.codePoint = v;
+
+ return reuse;
+ }
+
+ /** Holds a codepoint along with the number of bytes required to represent it in UTF8 */
+ public static final class UTF8CodePoint {
+ public int codePoint;
+ public int numBytes;
+ }
+
/** Shift value for lead surrogate to form a supplementary character. */
private static final int LEAD_SURROGATE_SHIFT_ = 10;
/** Mask to retrieve the significant value from a trail surrogate. */
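A sketch of streaming codepoints through the reusable holder, which is the loop UTF8toUTF32 now runs internally:

import java.util.Locale;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.UnicodeUtil;

class CodePointSketch {
  static void dump(BytesRef utf8) {
    UnicodeUtil.UTF8CodePoint reuse = null; // allocated once, reused per codepoint
    for (int pos = utf8.offset, limit = utf8.offset + utf8.length; pos < limit; ) {
      reuse = UnicodeUtil.codePointAt(utf8.bytes, pos, reuse);
      System.out.printf(Locale.ROOT, "U+%04X (%d byte(s))%n", reuse.codePoint, reuse.numBytes);
      pos += reuse.numBytes;
    }
  }
}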
diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
index 2a08436ec0b0..0921bb75a664 100644
--- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -20,6 +20,8 @@
/** Utilities for computations with numeric arrays */
public final class VectorUtil {
+ private static final VectorUtilProvider PROVIDER = VectorUtilProvider.lookup(false);
+
private VectorUtil() {}
/**
@@ -31,68 +33,9 @@ public static float dotProduct(float[] a, float[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
- float res = 0f;
- /*
- * If length of vector is larger than 8, we use unrolled dot product to accelerate the
- * calculation.
- */
- int i;
- for (i = 0; i < a.length % 8; i++) {
- res += b[i] * a[i];
- }
- if (a.length < 8) {
- return res;
- }
- for (; i + 31 < a.length; i += 32) {
- res +=
- b[i + 0] * a[i + 0]
- + b[i + 1] * a[i + 1]
- + b[i + 2] * a[i + 2]
- + b[i + 3] * a[i + 3]
- + b[i + 4] * a[i + 4]
- + b[i + 5] * a[i + 5]
- + b[i + 6] * a[i + 6]
- + b[i + 7] * a[i + 7];
- res +=
- b[i + 8] * a[i + 8]
- + b[i + 9] * a[i + 9]
- + b[i + 10] * a[i + 10]
- + b[i + 11] * a[i + 11]
- + b[i + 12] * a[i + 12]
- + b[i + 13] * a[i + 13]
- + b[i + 14] * a[i + 14]
- + b[i + 15] * a[i + 15];
- res +=
- b[i + 16] * a[i + 16]
- + b[i + 17] * a[i + 17]
- + b[i + 18] * a[i + 18]
- + b[i + 19] * a[i + 19]
- + b[i + 20] * a[i + 20]
- + b[i + 21] * a[i + 21]
- + b[i + 22] * a[i + 22]
- + b[i + 23] * a[i + 23];
- res +=
- b[i + 24] * a[i + 24]
- + b[i + 25] * a[i + 25]
- + b[i + 26] * a[i + 26]
- + b[i + 27] * a[i + 27]
- + b[i + 28] * a[i + 28]
- + b[i + 29] * a[i + 29]
- + b[i + 30] * a[i + 30]
- + b[i + 31] * a[i + 31];
- }
- for (; i + 7 < a.length; i += 8) {
- res +=
- b[i + 0] * a[i + 0]
- + b[i + 1] * a[i + 1]
- + b[i + 2] * a[i + 2]
- + b[i + 3] * a[i + 3]
- + b[i + 4] * a[i + 4]
- + b[i + 5] * a[i + 5]
- + b[i + 6] * a[i + 6]
- + b[i + 7] * a[i + 7];
- }
- return res;
+ float r = PROVIDER.dotProduct(a, b);
+ assert Float.isFinite(r);
+ return r;
}
/**
@@ -100,42 +43,21 @@ public static float dotProduct(float[] a, float[] b) {
*
* @throws IllegalArgumentException if the vectors' dimensions differ.
*/
- public static float cosine(float[] v1, float[] v2) {
- if (v1.length != v2.length) {
- throw new IllegalArgumentException(
- "vector dimensions differ: " + v1.length + "!=" + v2.length);
- }
-
- float sum = 0.0f;
- float norm1 = 0.0f;
- float norm2 = 0.0f;
- int dim = v1.length;
-
- for (int i = 0; i < dim; i++) {
- float elem1 = v1[i];
- float elem2 = v2[i];
- sum += elem1 * elem2;
- norm1 += elem1 * elem1;
- norm2 += elem2 * elem2;
+ public static float cosine(float[] a, float[] b) {
+ if (a.length != b.length) {
+ throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
- return (float) (sum / Math.sqrt(norm1 * norm2));
+ float r = PROVIDER.cosine(a, b);
+ assert Float.isFinite(r);
+ return r;
}
/** Returns the cosine similarity between the two vectors. */
public static float cosine(byte[] a, byte[] b) {
- // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
- int sum = 0;
- int norm1 = 0;
- int norm2 = 0;
-
- for (int i = 0; i < a.length; i++) {
- byte elem1 = a[i];
- byte elem2 = b[i];
- sum += elem1 * elem2;
- norm1 += elem1 * elem1;
- norm2 += elem2 * elem2;
+ if (a.length != b.length) {
+ throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
- return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
+ return PROVIDER.cosine(a, b);
}
/**
@@ -143,52 +65,21 @@ public static float cosine(byte[] a, byte[] b) {
*
* @throws IllegalArgumentException if the vectors' dimensions differ.
*/
- public static float squareDistance(float[] v1, float[] v2) {
- if (v1.length != v2.length) {
- throw new IllegalArgumentException(
- "vector dimensions differ: " + v1.length + "!=" + v2.length);
- }
- float squareSum = 0.0f;
- int dim = v1.length;
- int i;
- for (i = 0; i + 8 <= dim; i += 8) {
- squareSum += squareDistanceUnrolled(v1, v2, i);
- }
- for (; i < dim; i++) {
- float diff = v1[i] - v2[i];
- squareSum += diff * diff;
+ public static float squareDistance(float[] a, float[] b) {
+ if (a.length != b.length) {
+ throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
- return squareSum;
- }
-
- private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
- float diff0 = v1[index + 0] - v2[index + 0];
- float diff1 = v1[index + 1] - v2[index + 1];
- float diff2 = v1[index + 2] - v2[index + 2];
- float diff3 = v1[index + 3] - v2[index + 3];
- float diff4 = v1[index + 4] - v2[index + 4];
- float diff5 = v1[index + 5] - v2[index + 5];
- float diff6 = v1[index + 6] - v2[index + 6];
- float diff7 = v1[index + 7] - v2[index + 7];
- return diff0 * diff0
- + diff1 * diff1
- + diff2 * diff2
- + diff3 * diff3
- + diff4 * diff4
- + diff5 * diff5
- + diff6 * diff6
- + diff7 * diff7;
+ float r = PROVIDER.squareDistance(a, b);
+ assert Float.isFinite(r);
+ return r;
}
/** Returns the sum of squared differences of the two vectors. */
public static int squareDistance(byte[] a, byte[] b) {
- // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
- int squareSum = 0;
- for (int i = 0; i < a.length; i++) {
- int diff = a[i] - b[i];
- squareSum += diff * diff;
+ if (a.length != b.length) {
+ throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
- return squareSum;
+ return PROVIDER.squareDistance(a, b);
}
/**
@@ -250,12 +141,10 @@ public static void add(float[] u, float[] v) {
* @return the value of the dot product of the two vectors
*/
public static int dotProduct(byte[] a, byte[] b) {
- assert a.length == b.length;
- int total = 0;
- for (int i = 0; i < a.length; i++) {
- total += a[i] * b[i];
+ if (a.length != b.length) {
+ throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
- return total;
+ return PROVIDER.dotProduct(a, b);
}
/**
@@ -270,4 +159,20 @@ public static float dotProductScore(byte[] a, byte[] b) {
float denom = (float) (a.length * (1 << 15));
return 0.5f + dotProduct(a, b) / denom;
}
+
+ /**
+ * Checks if a float vector only has finite components.
+ *
+ * @param v the vector to check
+ * @return the vector for call-chaining
+ * @throws IllegalArgumentException if any component of vector is not finite
+ */
+ public static float[] checkFinite(float[] v) {
+ for (int i = 0; i < v.length; i++) {
+ if (!Float.isFinite(v[i])) {
+ throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + v[i]);
+ }
+ }
+ return v;
+ }
}
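Caller-facing behavior after the change, in brief: every entry point validates dimensions up front, then delegates to the runtime-selected provider. A usage sketch:

import org.apache.lucene.util.VectorUtil;

class VectorUtilSketch {
  public static void main(String[] args) {
    float[] q = VectorUtil.checkFinite(new float[] {0.1f, 0.2f, 0.3f});
    float[] d = {0.3f, 0.2f, 0.1f};
    System.out.println(VectorUtil.dotProduct(q, d)); // scalar or Panama, same result
    System.out.println(VectorUtil.cosine(q, d));
    System.out.println(VectorUtil.squareDistance(q, d));
  }
}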
diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java
new file mode 100644
index 000000000000..665181e86788
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util;
+
+/** The default VectorUtil provider implementation. */
+final class VectorUtilDefaultProvider implements VectorUtilProvider {
+
+ VectorUtilDefaultProvider() {}
+
+ @Override
+ public float dotProduct(float[] a, float[] b) {
+ float res = 0f;
+ /*
+ * If length of vector is larger than 8, we use unrolled dot product to accelerate the
+ * calculation.
+ */
+ int i;
+ for (i = 0; i < a.length % 8; i++) {
+ res += b[i] * a[i];
+ }
+ if (a.length < 8) {
+ return res;
+ }
+ for (; i + 31 < a.length; i += 32) {
+ res +=
+ b[i + 0] * a[i + 0]
+ + b[i + 1] * a[i + 1]
+ + b[i + 2] * a[i + 2]
+ + b[i + 3] * a[i + 3]
+ + b[i + 4] * a[i + 4]
+ + b[i + 5] * a[i + 5]
+ + b[i + 6] * a[i + 6]
+ + b[i + 7] * a[i + 7];
+ res +=
+ b[i + 8] * a[i + 8]
+ + b[i + 9] * a[i + 9]
+ + b[i + 10] * a[i + 10]
+ + b[i + 11] * a[i + 11]
+ + b[i + 12] * a[i + 12]
+ + b[i + 13] * a[i + 13]
+ + b[i + 14] * a[i + 14]
+ + b[i + 15] * a[i + 15];
+ res +=
+ b[i + 16] * a[i + 16]
+ + b[i + 17] * a[i + 17]
+ + b[i + 18] * a[i + 18]
+ + b[i + 19] * a[i + 19]
+ + b[i + 20] * a[i + 20]
+ + b[i + 21] * a[i + 21]
+ + b[i + 22] * a[i + 22]
+ + b[i + 23] * a[i + 23];
+ res +=
+ b[i + 24] * a[i + 24]
+ + b[i + 25] * a[i + 25]
+ + b[i + 26] * a[i + 26]
+ + b[i + 27] * a[i + 27]
+ + b[i + 28] * a[i + 28]
+ + b[i + 29] * a[i + 29]
+ + b[i + 30] * a[i + 30]
+ + b[i + 31] * a[i + 31];
+ }
+ for (; i + 7 < a.length; i += 8) {
+ res +=
+ b[i + 0] * a[i + 0]
+ + b[i + 1] * a[i + 1]
+ + b[i + 2] * a[i + 2]
+ + b[i + 3] * a[i + 3]
+ + b[i + 4] * a[i + 4]
+ + b[i + 5] * a[i + 5]
+ + b[i + 6] * a[i + 6]
+ + b[i + 7] * a[i + 7];
+ }
+ return res;
+ }
+
+ @Override
+ public float cosine(float[] a, float[] b) {
+ float sum = 0.0f;
+ float norm1 = 0.0f;
+ float norm2 = 0.0f;
+ int dim = a.length;
+
+ for (int i = 0; i < dim; i++) {
+ float elem1 = a[i];
+ float elem2 = b[i];
+ sum += elem1 * elem2;
+ norm1 += elem1 * elem1;
+ norm2 += elem2 * elem2;
+ }
+ return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
+ }
+
+ @Override
+ public float squareDistance(float[] a, float[] b) {
+ float squareSum = 0.0f;
+ int dim = a.length;
+ int i;
+ for (i = 0; i + 8 <= dim; i += 8) {
+ squareSum += squareDistanceUnrolled(a, b, i);
+ }
+ for (; i < dim; i++) {
+ float diff = a[i] - b[i];
+ squareSum += diff * diff;
+ }
+ return squareSum;
+ }
+
+ private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
+ float diff0 = v1[index + 0] - v2[index + 0];
+ float diff1 = v1[index + 1] - v2[index + 1];
+ float diff2 = v1[index + 2] - v2[index + 2];
+ float diff3 = v1[index + 3] - v2[index + 3];
+ float diff4 = v1[index + 4] - v2[index + 4];
+ float diff5 = v1[index + 5] - v2[index + 5];
+ float diff6 = v1[index + 6] - v2[index + 6];
+ float diff7 = v1[index + 7] - v2[index + 7];
+ return diff0 * diff0
+ + diff1 * diff1
+ + diff2 * diff2
+ + diff3 * diff3
+ + diff4 * diff4
+ + diff5 * diff5
+ + diff6 * diff6
+ + diff7 * diff7;
+ }
+
+ @Override
+ public int dotProduct(byte[] a, byte[] b) {
+ int total = 0;
+ for (int i = 0; i < a.length; i++) {
+ total += a[i] * b[i];
+ }
+ return total;
+ }
+
+ @Override
+ public float cosine(byte[] a, byte[] b) {
+ // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
+ int sum = 0;
+ int norm1 = 0;
+ int norm2 = 0;
+
+ for (int i = 0; i < a.length; i++) {
+ byte elem1 = a[i];
+ byte elem2 = b[i];
+ sum += elem1 * elem2;
+ norm1 += elem1 * elem1;
+ norm2 += elem2 * elem2;
+ }
+ return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
+ }
+
+ @Override
+ public int squareDistance(byte[] a, byte[] b) {
+ // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
+ int squareSum = 0;
+ for (int i = 0; i < a.length; i++) {
+ int diff = a[i] - b[i];
+ squareSum += diff * diff;
+ }
+ return squareSum;
+ }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilProvider.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtilProvider.java
new file mode 100644
index 000000000000..3fd29c2bf349
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilProvider.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util;
+
+import java.lang.Runtime.Version;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.MethodType;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.logging.Logger;
+
+/** A provider of VectorUtil implementations. */
+interface VectorUtilProvider {
+
+ /** Calculates the dot product of the given float arrays. */
+ float dotProduct(float[] a, float[] b);
+
+ /** Returns the cosine similarity between the two vectors. */
+ float cosine(float[] v1, float[] v2);
+
+ /** Returns the sum of squared differences of the two vectors. */
+ float squareDistance(float[] a, float[] b);
+
+ /** Returns the dot product computed over signed bytes. */
+ int dotProduct(byte[] a, byte[] b);
+
+ /** Returns the cosine similarity between the two byte vectors. */
+ float cosine(byte[] a, byte[] b);
+
+ /** Returns the sum of squared differences of the two byte vectors. */
+ int squareDistance(byte[] a, byte[] b);
+
+ // -- provider lookup mechanism
+
+ static final Logger LOG = Logger.getLogger(VectorUtilProvider.class.getName());
+
+ /** The minimal version of Java that has the bugfix for JDK-8301190. */
+ static final Version VERSION_JDK8301190_FIXED = Version.parse("20.0.2");
+
+ static VectorUtilProvider lookup(boolean testMode) {
+ final int runtimeVersion = Runtime.version().feature();
+ if (runtimeVersion >= 20 && runtimeVersion <= 21) {
+ // is locale sane (only buggy in Java 20)
+ if (isAffectedByJDK8301190()) {
+ LOG.warning(
+ "Java runtime is using a buggy default locale; Java vector incubator API can't be enabled: "
+ + Locale.getDefault());
+ return new VectorUtilDefaultProvider();
+ }
+ // is the incubator module present and readable (JVM vendors may exclude it, or the
+ // runtime may be a jlink build without it)
+ if (!vectorModulePresentAndReadable()) {
+ LOG.warning(
+ "Java vector incubator module is not readable. For optimal vector performance, pass '--add-modules jdk.incubator.vector' to enable Vector API.");
+ return new VectorUtilDefaultProvider();
+ }
+ if (!testMode && isClientVM()) {
+ LOG.warning("C2 compiler is disabled; Java vector incubator API can't be enabled");
+ return new VectorUtilDefaultProvider();
+ }
+ try {
+ // we use method handles with lookup, so we do not need to deal with setAccessible as we
+ // have private access through the lookup:
+ final var lookup = MethodHandles.lookup();
+ final var cls = lookup.findClass("org.apache.lucene.util.VectorUtilPanamaProvider");
+ final var constr =
+ lookup.findConstructor(cls, MethodType.methodType(void.class, boolean.class));
+ try {
+ return (VectorUtilProvider) constr.invoke(testMode);
+ } catch (UnsupportedOperationException uoe) {
+ // not supported because preferred vector size too small or similar
+ LOG.warning("Java vector incubator API was not enabled. " + uoe.getMessage());
+ return new VectorUtilDefaultProvider();
+ } catch (RuntimeException | Error e) {
+ throw e;
+ } catch (Throwable th) {
+ throw new AssertionError(th);
+ }
+ } catch (NoSuchMethodException | IllegalAccessException e) {
+ throw new LinkageError(
+ "VectorUtilPanamaProvider is missing correctly typed constructor", e);
+ } catch (ClassNotFoundException cnfe) {
+ throw new LinkageError("VectorUtilPanamaProvider is missing in Lucene JAR file", cnfe);
+ }
+ } else if (runtimeVersion >= 22) {
+ LOG.warning(
+ "You are running with Java 22 or later. To make full use of the Vector API, please update Apache Lucene.");
+ }
+ return new VectorUtilDefaultProvider();
+ }
+
+ private static boolean vectorModulePresentAndReadable() {
+ var opt =
+ ModuleLayer.boot().modules().stream()
+ .filter(m -> m.getName().equals("jdk.incubator.vector"))
+ .findFirst();
+ if (opt.isPresent()) {
+ VectorUtilProvider.class.getModule().addReads(opt.get());
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * Check if the runtime is affected by JDK-8301190 (avoids an assertion failure when the default
+ * language is, say, "tr").
+ */
+ private static boolean isAffectedByJDK8301190() {
+ return VERSION_JDK8301190_FIXED.compareToIgnoreOptional(Runtime.version()) > 0
+ && !Objects.equals("I", "i".toUpperCase(Locale.getDefault()));
+ }
+
+ @SuppressWarnings("removal")
+ @SuppressForbidden(reason = "security manager")
+ private static boolean isClientVM() {
+ try {
+ final PrivilegedAction<Boolean> action =
+ () -> System.getProperty("java.vm.info", "").contains("emulated-client");
+ return AccessController.doPrivileged(action);
+ } catch (
+ @SuppressWarnings("unused")
+ SecurityException e) {
+ LOG.warning(
+ "SecurityManager denies permission to 'java.vm.info' system property, so state of C2 compiler can't be detected. "
+ + "In case of performance issues allow access to this property.");
+ return false;
+ }
+ }
+}
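A sketch of the module probe the lookup performs; starting the JVM with --add-modules jdk.incubator.vector on Java 20/21 is what makes the Panama provider eligible:

class VectorModuleProbeSketch {
  public static void main(String[] args) {
    boolean present =
        ModuleLayer.boot().modules().stream()
            .anyMatch(m -> "jdk.incubator.vector".equals(m.getName()));
    System.out.println(present ? "vector module present" : "scalar fallback");
  }
}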
diff --git a/lucene/core/src/java/org/apache/lucene/util/Version.java b/lucene/core/src/java/org/apache/lucene/util/Version.java
index 3bc4be34d5a1..bd37efbd8187 100644
--- a/lucene/core/src/java/org/apache/lucene/util/Version.java
+++ b/lucene/core/src/java/org/apache/lucene/util/Version.java
@@ -259,11 +259,16 @@ public final class Version {
@Deprecated public static final Version LUCENE_9_5_0 = new Version(9, 5, 0);
/**
- * Match settings and bugs in Lucene's 9.6.0 release.
+ * @deprecated (9.7.0) Use latest
+ */
+ @Deprecated public static final Version LUCENE_9_6_0 = new Version(9, 6, 0);
+
+ /**
+ * Match settings and bugs in Lucene's 9.7.0 release.
*
*
Use this to get the latest & greatest settings, bug fixes, etc, for Lucene.
*/
- public static final Version LUCENE_9_6_0 = new Version(9, 6, 0);
+ public static final Version LUCENE_9_7_0 = new Version(9, 7, 0);
// To add a new version:
// * Only add above this comment
@@ -279,7 +284,7 @@ public final class Version {
* re-test your entire application to ensure it behaves as expected, as some defaults may
* have changed and may break functionality in your application.
*/
- public static final Version LATEST = LUCENE_9_6_0;
+ public static final Version LATEST = LUCENE_9_7_0;
/**
* Constant for backwards compatibility.
diff --git a/lucene/core/src/java/org/apache/lucene/util/VirtualMethod.java b/lucene/core/src/java/org/apache/lucene/util/VirtualMethod.java
index 05eef2a86617..a7c2a71cc1c8 100644
--- a/lucene/core/src/java/org/apache/lucene/util/VirtualMethod.java
+++ b/lucene/core/src/java/org/apache/lucene/util/VirtualMethod.java
@@ -17,6 +17,8 @@
package org.apache.lucene.util;
import java.lang.reflect.Method;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
@@ -49,13 +51,20 @@
*
+ * <p>It is important to use {@link AccessController#doPrivileged(PrivilegedAction)} for the actual
+ * call to get the implementation distance because the subclass may be in a different package. The
+ * static constructors do not need to use {@code AccessController} because it just initializes our
+ * own method reference. The caller should have access to all declared members in its own class.
+ *
* <p>{@link #getImplementationDistance} returns the distance of the subclass that overrides this
* method. The one with the larger distance should be used preferentially. This way also more
* complicated method rename scenarios can be handled (think of 2.9 {@code TokenStream}
diff --git a/lucene/core/src/java/org/apache/lucene/util/automaton/Automata.java b/lucene/core/src/java/org/apache/lucene/util/automaton/Automata.java
index 9a642338f09b..c829429d3c9b 100644
--- a/lucene/core/src/java/org/apache/lucene/util/automaton/Automata.java
+++ b/lucene/core/src/java/org/apache/lucene/util/automaton/Automata.java
@@ -29,9 +29,11 @@
package org.apache.lucene.util.automaton;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefIterator;
import org.apache.lucene.util.StringHelper;
/**
@@ -40,6 +42,11 @@
* @lucene.experimental
*/
public final class Automata {
+ /**
+ * {@link #makeStringUnion(Collection)} limits terms to this maximum length to ensure the stack
+ * doesn't overflow while building, since our algorithm currently relies on recursion.
+ */
+ public static final int MAX_STRING_UNION_TERM_LENGTH = 1000;
private Automata() {}
@@ -573,7 +580,49 @@ public static Automaton makeStringUnion(Collection<BytesRef> utf8Strings) {
if (utf8Strings.isEmpty()) {
return makeEmpty();
} else {
- return DaciukMihovAutomatonBuilder.build(utf8Strings);
+ return DaciukMihovAutomatonBuilder.build(utf8Strings, false);
+ }
+ }
+
+ /**
+ * Returns a new (deterministic and minimal) automaton that accepts the union of the given
+ * collection of {@link BytesRef}s representing UTF-8 encoded strings. The resulting automaton
+ * will be built in a binary representation.
+ *
+ * @param utf8Strings The input strings, UTF-8 encoded. The collection must be in sorted order.
+ * @return An {@link Automaton} accepting all input strings. The resulting automaton is binary
+ * based (UTF-8 encoded byte transition labels).
+ */
+ public static Automaton makeBinaryStringUnion(Collection<BytesRef> utf8Strings) {
+ if (utf8Strings.isEmpty()) {
+ return makeEmpty();
+ } else {
+ return DaciukMihovAutomatonBuilder.build(utf8Strings, true);
}
}
+
+ /**
+ * Returns a new (deterministic and minimal) automaton that accepts the union of the given
+ * iterator of {@link BytesRef}s representing UTF-8 encoded strings.
+ *
+ * @param utf8Strings The input strings, UTF-8 encoded. The iterator must be in sorted order.
+ * @return An {@link Automaton} accepting all input strings. The resulting automaton is codepoint
+ * based (full unicode codepoints on transitions).
+ */
+ public static Automaton makeStringUnion(BytesRefIterator utf8Strings) throws IOException {
+ return DaciukMihovAutomatonBuilder.build(utf8Strings, false);
+ }
+
+ /**
+ * Returns a new (deterministic and minimal) automaton that accepts the union of the given
+ * iterator of {@link BytesRef}s representing UTF-8 encoded strings. The resulting automaton will
+ * be built in a binary representation.
+ *
+ * @param utf8Strings The input strings, UTF-8 encoded. The iterator must be in sorted order.
+ * @return An {@link Automaton} accepting all input strings. The resulting automaton is binary
+ * based (UTF-8 encoded byte transition labels).
+ */
+ public static Automaton makeBinaryStringUnion(BytesRefIterator utf8Strings) throws IOException {
+ return DaciukMihovAutomatonBuilder.build(utf8Strings, true);
+ }
}
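A minimal usage sketch of the new iterator-based entry points (the wrapper class and
method names below are illustrative, not part of this change). Any sorted
BytesRefIterator, e.g. a TermsEnum, can now be streamed into the builder without
materializing a Collection first:

    import java.io.IOException;
    import org.apache.lucene.util.BytesRefIterator;
    import org.apache.lucene.util.automaton.Automata;
    import org.apache.lucene.util.automaton.Automaton;

    class TermAutomata {
      // Stream already-sorted terms straight into the builder.
      static Automaton unionOf(BytesRefIterator sortedTerms) throws IOException {
        return Automata.makeStringUnion(sortedTerms); // codepoint transition labels
      }

      // Same input, but UTF-8 byte transition labels:
      static Automaton binaryUnionOf(BytesRefIterator sortedTerms) throws IOException {
        return Automata.makeBinaryStringUnion(sortedTerms);
      }
    }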
diff --git a/lucene/core/src/java/org/apache/lucene/util/automaton/DaciukMihovAutomatonBuilder.java b/lucene/core/src/java/org/apache/lucene/util/automaton/DaciukMihovAutomatonBuilder.java
index 94002b04a40b..7048d4538d15 100644
--- a/lucene/core/src/java/org/apache/lucene/util/automaton/DaciukMihovAutomatonBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/automaton/DaciukMihovAutomatonBuilder.java
@@ -16,14 +16,15 @@
*/
package org.apache.lucene.util.automaton;
+import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
-import java.util.Comparator;
import java.util.HashMap;
import java.util.IdentityHashMap;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.CharsRef;
+import org.apache.lucene.util.BytesRefBuilder;
+import org.apache.lucene.util.BytesRefIterator;
import org.apache.lucene.util.UnicodeUtil;
/**
@@ -32,14 +33,22 @@
*
* @see #build(Collection)
* @see Automata#makeStringUnion(Collection)
+ * @see Automata#makeBinaryStringUnion(Collection)
+ * @see Automata#makeStringUnion(BytesRefIterator)
+ * @see Automata#makeBinaryStringUnion(BytesRefIterator)
+ * @deprecated Visibility of this class will be reduced in a future release. Users can access this
+ * functionality directly through {@link Automata#makeStringUnion(Collection)}
*/
+@Deprecated
public final class DaciukMihovAutomatonBuilder {
/**
* This builder rejects terms that are more than 1k chars long since it then uses recursion based
* on the length of the string, which might cause stack overflows.
+ *
+ * @deprecated See {@link Automata#MAX_STRING_UNION_TERM_LENGTH}
*/
- public static final int MAX_TERM_LENGTH = 1_000;
+ @Deprecated public static final int MAX_TERM_LENGTH = 1_000;
/** The default constructor is private. Use static methods directly. */
private DaciukMihovAutomatonBuilder() {
@@ -179,56 +188,18 @@ private static boolean referenceEquals(Object[] a1, Object[] a2) {
private HashMap<State, State> stateRegistry = new HashMap<>();
/** Root automaton state. */
- private State root = new State();
-
- /** Previous sequence added to the automaton in {@link #add(CharsRef)}. */
- private CharsRef previous;
+ private final State root = new State();
- /** A comparator used for enforcing sorted UTF8 order, used in assertions only. */
- @SuppressWarnings("deprecation")
- private static final Comparator<CharsRef> comparator = CharsRef.getUTF16SortedAsUTF8Comparator();
+ /** Used for input order checking (only through assertions right now) */
+ private BytesRefBuilder previous;
- /**
- * Add another character sequence to this automaton. The sequence must be lexicographically larger
- * or equal compared to any previous sequences added to this automaton (the input must be sorted).
- */
- public void add(CharsRef current) {
- if (current.length > MAX_TERM_LENGTH) {
- throw new IllegalArgumentException(
- "This builder doesn't allow terms that are larger than 1,000 characters, got " + current);
- }
- assert stateRegistry != null : "Automaton already built.";
- assert previous == null || comparator.compare(previous, current) <= 0
- : "Input must be in sorted UTF-8 order: " + previous + " >= " + current;
- assert setPrevious(current);
-
- // Descend in the automaton (find matching prefix).
- int pos = 0, max = current.length();
- State next, state = root;
- while (pos < max && (next = state.lastChild(Character.codePointAt(current, pos))) != null) {
- state = next;
- // todo, optimize me
- pos += Character.charCount(Character.codePointAt(current, pos));
+ /** Copy current into an internal buffer. */
+ private boolean setPrevious(BytesRef current) {
+ if (previous == null) {
+ previous = new BytesRefBuilder();
}
-
- if (state.hasChildren()) replaceOrRegister(state);
-
- addSuffix(state, current, pos);
- }
-
- /**
- * Finalize the automaton and return the root state. No more strings can be added to the builder
- * after this call.
- *
- * @return Root automaton state.
- */
- public State complete() {
- if (this.stateRegistry == null) throw new IllegalStateException();
-
- if (root.hasChildren()) replaceOrRegister(root);
-
- stateRegistry = null;
- return root;
+ previous.copyBytes(current);
+ return true;
}
/** Internal recursive traversal for conversion. */
@@ -253,35 +224,116 @@ private static int convert(
return converted;
}
+ /**
+ * Called after adding all terms. Performs final minimization and converts to a standard {@link
+ * Automaton} instance.
+ */
+ private Automaton completeAndConvert() {
+ // Final minimization:
+ if (this.stateRegistry == null) throw new IllegalStateException();
+ if (root.hasChildren()) replaceOrRegister(root);
+ stateRegistry = null;
+
+ // Convert:
+ Automaton.Builder a = new Automaton.Builder();
+ convert(a, root, new IdentityHashMap<>());
+ return a.finish();
+ }
+
/**
* Build a minimal, deterministic automaton from a sorted list of {@link BytesRef} representing
* strings in UTF-8. These strings must be binary-sorted.
+ *
+ * @deprecated Please see {@link Automata#makeStringUnion(Collection)} instead
*/
+ @Deprecated
public static Automaton build(Collection<BytesRef> input) {
+ return build(input, false);
+ }
+
+ /**
+ * Build a minimal, deterministic automaton from a sorted list of {@link BytesRef} representing
+ * strings in UTF-8. These strings must be binary-sorted.
+ */
+ static Automaton build(Collection<BytesRef> input, boolean asBinary) {
final DaciukMihovAutomatonBuilder builder = new DaciukMihovAutomatonBuilder();
- char[] chars = new char[0];
- CharsRef ref = new CharsRef();
for (BytesRef b : input) {
- chars = ArrayUtil.grow(chars, b.length);
- final int len = UnicodeUtil.UTF8toUTF16(b, chars);
- ref.chars = chars;
- ref.length = len;
- builder.add(ref);
+ builder.add(b, asBinary);
}
- Automaton.Builder a = new Automaton.Builder();
- convert(a, builder.complete(), new IdentityHashMap<>());
+ return builder.completeAndConvert();
+ }
- return a.finish();
+ /**
+ * Build a minimal, deterministic automaton from a sorted list of {@link BytesRef} representing
+ * strings in UTF-8. These strings must be binary-sorted. Creates an {@link Automaton} with either
+ * Unicode codepoints as transition labels or binary (compiled) transition labels based on {@code
+ * asBinary}.
+ */
+ static Automaton build(BytesRefIterator input, boolean asBinary) throws IOException {
+ final DaciukMihovAutomatonBuilder builder = new DaciukMihovAutomatonBuilder();
+
+ for (BytesRef b = input.next(); b != null; b = input.next()) {
+ builder.add(b, asBinary);
+ }
+
+ return builder.completeAndConvert();
}
- /** Copy current into an internal buffer. */
- private boolean setPrevious(CharsRef current) {
- // don't need to copy, once we fix https://issues.apache.org/jira/browse/LUCENE-3277
- // still, called only from assert
- previous = CharsRef.deepCopyOf(current);
- return true;
+ private void add(BytesRef current, boolean asBinary) {
+ if (current.length > Automata.MAX_STRING_UNION_TERM_LENGTH) {
+ throw new IllegalArgumentException(
+ "This builder doesn't allow terms that are larger than "
+ + Automata.MAX_STRING_UNION_TERM_LENGTH
+ + " characters, got "
+ + current);
+ }
+ assert stateRegistry != null : "Automaton already built.";
+ assert previous == null || previous.get().compareTo(current) <= 0
+ : "Input must be in sorted UTF-8 order: " + previous.get() + " >= " + current;
+ assert setPrevious(current);
+
+ // Reusable codepoint information if we're building a non-binary based automaton
+ UnicodeUtil.UTF8CodePoint codePoint = null;
+
+ // Descend in the automaton (find matching prefix).
+ byte[] bytes = current.bytes;
+ int pos = current.offset, max = current.offset + current.length;
+ State next, state = root;
+ if (asBinary) {
+ while (pos < max && (next = state.lastChild(bytes[pos] & 0xff)) != null) {
+ state = next;
+ pos++;
+ }
+ } else {
+ while (pos < max) {
+ codePoint = UnicodeUtil.codePointAt(bytes, pos, codePoint);
+ next = state.lastChild(codePoint.codePoint);
+ if (next == null) {
+ break;
+ }
+ state = next;
+ pos += codePoint.numBytes;
+ }
+ }
+
+ if (state.hasChildren()) replaceOrRegister(state);
+
+ // Add suffix
+ if (asBinary) {
+ while (pos < max) {
+ state = state.newState(bytes[pos] & 0xff);
+ pos++;
+ }
+ } else {
+ while (pos < max) {
+ codePoint = UnicodeUtil.codePointAt(bytes, pos, codePoint);
+ state = state.newState(codePoint.codePoint);
+ pos += codePoint.numBytes;
+ }
+ }
+ state.is_final = true;
}
/**
@@ -300,18 +352,4 @@ private void replaceOrRegister(State state) {
stateRegistry.put(child, child);
}
}
-
- /**
- * Add a suffix of current starting at fromIndex (inclusive) to state
- * state.
- */
- private void addSuffix(State state, CharSequence current, int fromIndex) {
- final int len = current.length();
- while (fromIndex < len) {
- int cp = Character.codePointAt(current, fromIndex);
- state = state.newState(cp);
- fromIndex += Character.charCount(cp);
- }
- state.is_final = true;
- }
}
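The asBinary flag only changes the transition alphabet, not the accepted language. A
concrete illustration (a sketch, not part of the patch): "é" (U+00E9) encodes to the two
UTF-8 bytes 0xC3 0xA9, so the binary automaton walks two byte-labelled transitions where
the codepoint automaton walks one:

    import java.util.List;
    import org.apache.lucene.util.BytesRef;
    import org.apache.lucene.util.automaton.Automata;
    import org.apache.lucene.util.automaton.Automaton;

    BytesRef term = new BytesRef("é"); // UTF-8 bytes: 0xC3, 0xA9
    // Binary mode: two transitions, labelled 0xC3 and 0xA9.
    Automaton binary = Automata.makeBinaryStringUnion(List.of(term));
    // Codepoint mode: one transition, labelled U+00E9.
    Automaton codepoints = Automata.makeStringUnion(List.of(term));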
diff --git a/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java b/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java
index 6db58e7119ea..9ebe5b998037 100644
--- a/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java
+++ b/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java
@@ -1135,7 +1135,7 @@ public static String getCommonPrefix(Automaton a) {
FixedBitSet tmp = current;
current = next;
next = tmp;
- next.clear(0, next.length());
+ next.clear();
}
return builder.toString();
}
@@ -1311,9 +1311,22 @@ static Automaton totalize(Automaton a) {
}
/**
- * Returns the topological sort of all states reachable from the initial state. Behavior is
- * undefined if this automaton has cycles. CPU cost is O(numTransitions), and the implementation
- * is recursive so an automaton matching long strings may exhaust the java stack.
+ * Returns the topological sort of all states reachable from the initial state. This method
+ * assumes that the automaton does not contain cycles, and will throw an IllegalArgumentException
+ * if a cycle is detected. The CPU cost is O(numTransitions), and the implementation is
+ * non-recursive, so it will not exhaust the Java stack for an automaton matching long strings. If
+ * there are dead states in the automaton, they will be removed from the returned array.
+ *
+ * <p>Note: This method uses a deque to iterate the states, which could potentially consume a
+ * lot of heap space for some automatons. Specifically, automatons with a deep level of states
+ * (i.e., a large number of transitions from the initial state to the final state) may
+ * particularly contribute to high memory usage. The memory consumption of this method can be
+ * considered as O(N), where N is the depth of the automaton (the maximum number of transitions
+ * from the initial state to any state). However, as this method detects cycles, it will never
+ * attempt to use infinite RAM.
+ *
+ * @param a the Automaton to be sorted
+ * @return the topologically sorted array of state ids
*/
public static int[] topoSortStates(Automaton a) {
if (a.getNumStates() == 0) {
@@ -1321,8 +1334,7 @@ public static int[] topoSortStates(Automaton a) {
}
int numStates = a.getNumStates();
int[] states = new int[numStates];
- final BitSet visited = new BitSet(numStates);
- int upto = topoSortStatesRecurse(a, visited, states, 0, 0, 0);
+ int upto = topoSortStates(a, states);
if (upto < states.length) {
// There were dead states
@@ -1341,24 +1353,49 @@ public static int[] topoSortStates(Automaton a) {
return states;
}
- // TODO: not great that this is recursive... in theory a
- // large automata could exceed java's stack so the maximum level of recursion is bounded to 1000
- private static int topoSortStatesRecurse(
- Automaton a, BitSet visited, int[] states, int upto, int state, int level) {
- if (level > MAX_RECURSION_LEVEL) {
- throw new IllegalArgumentException("input automaton is too large: " + level);
- }
+ /**
+ * Performs a topological sort on the states of the given Automaton.
+ *
+ * @param a The automaton whose states are to be topologically sorted.
+ * @param states An int array which stores the states.
+ * @return the number of states in the final sorted list.
+ * @throws IllegalArgumentException if the input automaton has a cycle.
+ */
+ private static int topoSortStates(Automaton a, int[] states) {
+ BitSet onStack = new BitSet(a.getNumStates());
+ BitSet visited = new BitSet(a.getNumStates());
+ var stack = new ArrayDeque<Integer>();
+ stack.push(0); // Assuming that the initial state is 0.
+ int upto = 0;
Transition t = new Transition();
- int count = a.initTransition(state, t);
- for (int i = 0; i < count; i++) {
- a.getNextTransition(t);
- if (!visited.get(t.dest)) {
- visited.set(t.dest);
- upto = topoSortStatesRecurse(a, visited, states, upto, t.dest, level + 1);
+
+ while (!stack.isEmpty()) {
+ int state = stack.peek(); // Just peek, don't remove the state yet
+
+ int count = a.initTransition(state, t);
+ boolean pushed = false;
+ for (int i = 0; i < count; i++) {
+ a.getNextTransition(t);
+ if (!visited.get(t.dest)) {
+ visited.set(t.dest);
+ stack.push(t.dest); // Push the next unvisited state onto the stack
+ onStack.set(state);
+ pushed = true;
+ break; // Exit the loop, we'll continue from here in the next iteration
+ } else if (onStack.get(t.dest)) {
+ // If the state is on the current recursion stack, we have detected a cycle
+ throw new IllegalArgumentException("Input automaton has a cycle.");
+ }
+ }
+
+ // If we haven't pushed any new state onto the stack, we're done with this state
+ if (!pushed) {
+ onStack.clear(state); // remove the node from the current recursion stack
+ stack.pop();
+ states[upto] = state;
+ upto++;
}
}
- states[upto] = state;
- upto++;
return upto;
}
}
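Under the new contract a cyclic automaton fails fast instead of recursing toward a depth
limit. A small sketch (states and labels are arbitrary):

    Automaton a = new Automaton();
    int s0 = a.createState();
    int s1 = a.createState();
    a.addTransition(s0, s1, 'x');
    a.addTransition(s1, s0, 'y'); // back edge: introduces a cycle
    a.finishState();
    // The iterative DFS detects a destination state that is still on its stack
    // and throws IllegalArgumentException instead of recursing forever:
    int[] sorted = Operations.topoSortStates(a);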
diff --git a/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java b/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java
index 6711dfb6230f..321c6ff133a0 100644
--- a/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java
+++ b/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java
@@ -45,6 +45,8 @@
* different paths of the {@link Automaton}.
*/
public final class GraphTokenStreamFiniteStrings {
+ /** Maximum level of recursion allowed in recursive operations. */
+ private static final int MAX_RECURSION_LEVEL = 1000;
private AttributeSource[] tokens = new AttributeSource[4];
private final Automaton det;
@@ -271,7 +273,12 @@ private static void articulationPointsRecurse(
a.getNextTransition(t);
if (visited.get(t.dest) == false) {
parent[t.dest] = state;
- articulationPointsRecurse(a, t.dest, d + 1, depth, low, parent, visited, points);
+ if (d < MAX_RECURSION_LEVEL) {
+ articulationPointsRecurse(a, t.dest, d + 1, depth, low, parent, visited, points);
+ } else {
+ throw new IllegalArgumentException(
+ "Exceeded maximum recursion level during graph analysis");
+ }
childCount++;
if (low[t.dest] >= depth[state]) {
isArticulation = true;
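articulationPointsRecurse stays recursive, so the MAX_RECURSION_LEVEL = 1000 guard that
topoSortStates no longer needs is introduced here instead. The general shape of the
pattern (a sketch with placeholder names, not the actual method):

    // Depth-limited recursion: pathological token graphs surface as an
    // IllegalArgumentException instead of an unpredictable StackOverflowError.
    void visit(int node, int depth) {
      if (depth >= MAX_RECURSION_LEVEL) {
        throw new IllegalArgumentException("graph is too deep to analyze");
      }
      for (int child : childrenOf(node)) { // childrenOf() is a placeholder
        visit(child, depth + 1);
      }
    }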
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index 2c5e84be2859..5738cc8c75d8 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -41,8 +41,17 @@
*/
public final class HnswGraphBuilder {
+ /** Default number of maximum connections per node */
+ public static final int DEFAULT_MAX_CONN = 16;
+
+ /**
+ * Default size of the queue maintained while searching during graph construction.
+ */
+ public static final int DEFAULT_BEAM_WIDTH = 100;
+
/** Default random seed for level generation * */
private static final long DEFAULT_RAND_SEED = 42;
+
/** A name for the HNSW component for the info-stream * */
public static final String HNSW_COMPONENT = "HNSW";
@@ -220,7 +229,9 @@ private void initializeFromGraph(
binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor));
break;
}
- newNeighbors.insertSorted(newNeighbor, score);
+ // we are not sure whether the previous graph contains
+ // unchecked nodes, so we have to assume they're all unchecked
+ newNeighbors.addOutOfOrder(newNeighbor, score);
}
}
}
@@ -316,11 +327,11 @@ private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
- NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
- nbrNbr.insertSorted(node, neighbors.score[i]);
- if (nbrNbr.size() > maxConnOnLevel) {
- int indexToRemove = findWorstNonDiverse(nbrNbr);
- nbrNbr.removeIndex(indexToRemove);
+ NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
+ nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
+ if (nbrsOfNbr.size() > maxConnOnLevel) {
+ int indexToRemove = findWorstNonDiverse(nbrsOfNbr);
+ nbrsOfNbr.removeIndex(indexToRemove);
}
}
}
@@ -335,7 +346,7 @@ private void selectAndLinkDiverse(
float cScore = candidates.score[i];
assert cNode < hnsw.size();
if (diversityCheck(cNode, cScore, neighbors)) {
- neighbors.add(cNode, cScore);
+ neighbors.addInOrder(cNode, cScore);
}
}
}
@@ -347,7 +358,7 @@ private void popToScratch(NeighborQueue candidates) {
// sorted from worst to best
for (int i = 0; i < candidateCount; i++) {
float maxSimilarity = candidates.topScore();
- scratch.add(candidates.pop(), maxSimilarity);
+ scratch.addInOrder(candidates.pop(), maxSimilarity);
}
}
@@ -405,53 +416,119 @@ private boolean isDiverse(byte[] candidate, NeighborArray neighbors, float score
* neighbours
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
+ int[] uncheckedIndexes = neighbors.sort();
+ if (uncheckedIndexes == null) {
+ // all nodes are checked, we will directly return the most distant one
+ return neighbors.size() - 1;
+ }
+ int uncheckedCursor = uncheckedIndexes.length - 1;
for (int i = neighbors.size() - 1; i > 0; i--) {
- if (isWorstNonDiverse(i, neighbors)) {
+ if (uncheckedCursor < 0) {
+ // no unchecked node left
+ break;
+ }
+ if (isWorstNonDiverse(i, neighbors, uncheckedIndexes, uncheckedCursor)) {
return i;
}
+ if (i == uncheckedIndexes[uncheckedCursor]) {
+ uncheckedCursor--;
+ }
}
return neighbors.size() - 1;
}
- private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
+ private boolean isWorstNonDiverse(
+ int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor)
throws IOException {
int candidateNode = neighbors.node[candidateIndex];
switch (vectorEncoding) {
case BYTE:
return isWorstNonDiverse(
- candidateIndex, (byte[]) vectors.vectorValue(candidateNode), neighbors);
+ candidateIndex,
+ (byte[]) vectors.vectorValue(candidateNode),
+ neighbors,
+ uncheckedIndexes,
+ uncheckedCursor);
default:
case FLOAT32:
return isWorstNonDiverse(
- candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors);
+ candidateIndex,
+ (float[]) vectors.vectorValue(candidateNode),
+ neighbors,
+ uncheckedIndexes,
+ uncheckedCursor);
}
}
private boolean isWorstNonDiverse(
- int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
+ int candidateIndex,
+ float[] candidateVector,
+ NeighborArray neighbors,
+ int[] uncheckedIndexes,
+ int uncheckedCursor)
+ throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
- for (int i = candidateIndex - 1; i >= 0; i--) {
- float neighborSimilarity =
- similarityFunction.compare(
- candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
- // candidate node is too similar to node i given its score relative to the base node
- if (neighborSimilarity >= minAcceptedSimilarity) {
- return true;
+ if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
+ // the candidate itself is unchecked
+ for (int i = candidateIndex - 1; i >= 0; i--) {
+ float neighborSimilarity =
+ similarityFunction.compare(
+ candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
+ // candidate node is too similar to node i given its score relative to the base node
+ if (neighborSimilarity >= minAcceptedSimilarity) {
+ return true;
+ }
+ }
+ } else {
+ // else we just need to make sure candidate does not violate diversity with the (newly
+ // inserted) unchecked nodes
+ assert candidateIndex > uncheckedIndexes[uncheckedCursor];
+ for (int i = uncheckedCursor; i >= 0; i--) {
+ float neighborSimilarity =
+ similarityFunction.compare(
+ candidateVector,
+ (float[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
+ // candidate node is too similar to node i given its score relative to the base node
+ if (neighborSimilarity >= minAcceptedSimilarity) {
+ return true;
+ }
}
}
return false;
}
private boolean isWorstNonDiverse(
- int candidateIndex, byte[] candidateVector, NeighborArray neighbors) throws IOException {
+ int candidateIndex,
+ byte[] candidateVector,
+ NeighborArray neighbors,
+ int[] uncheckedIndexes,
+ int uncheckedCursor)
+ throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
- for (int i = candidateIndex - 1; i >= 0; i--) {
- float neighborSimilarity =
- similarityFunction.compare(
- candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
- // candidate node is too similar to node i given its score relative to the base node
- if (neighborSimilarity >= minAcceptedSimilarity) {
- return true;
+ if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
+ // the candidate itself is unchecked
+ for (int i = candidateIndex - 1; i >= 0; i--) {
+ float neighborSimilarity =
+ similarityFunction.compare(
+ candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
+ // candidate node is too similar to node i given its score relative to the base node
+ if (neighborSimilarity >= minAcceptedSimilarity) {
+ return true;
+ }
+ }
+ } else {
+ // else we just need to make sure candidate does not violate diversity with the (newly
+ // inserted) unchecked nodes
+ assert candidateIndex > uncheckedIndexes[uncheckedCursor];
+ for (int i = uncheckedCursor; i >= 0; i--) {
+ float neighborSimilarity =
+ similarityFunction.compare(
+ candidateVector,
+ (byte[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
+ // candidate node is too similar to node i given its score relative to the base node
+ if (neighborSimilarity >= minAcceptedSimilarity) {
+ return true;
+ }
}
}
return false;
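The diversity check now leans on the sorted/unchecked split that NeighborArray maintains:
addOutOfOrder appends candidates unsorted, and findWorstNonDiverse first calls sort(),
which returns the indexes where the newly inserted ("unchecked") nodes landed. Only pairs
involving an unchecked node can have become non-diverse since the last check, so the
re-check shrinks to those comparisons. An illustrative trace (scores made up, descending
order):

    // scores after addOutOfOrder: [0.9, 0.7, 0.4 | 0.8, 0.5]  (3 sorted | 2 unchecked)
    // sort() merges the tail -> [0.9, 0.8, 0.7, 0.5, 0.4] and returns [1, 3],
    // the new positions of the unchecked entries.
    // findWorstNonDiverse() then walks from the worst entry upward, running the
    // pairwise similarity test only where an unchecked node is involved.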
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
index 4857d5b9d577..5bc718169466 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
@@ -100,28 +100,31 @@ public static NeighborQueue search(
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
- NeighborQueue results;
+ return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
+ }
- int initialEp = graph.entryNode();
- if (initialEp == -1) {
- return new NeighborQueue(1, true);
- }
- int[] eps = new int[] {initialEp};
- int numVisited = 0;
- for (int level = graph.numLevels() - 1; level >= 1; level--) {
- results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
- numVisited += results.visitedCount();
- visitedLimit -= results.visitedCount();
- if (results.incomplete()) {
- results.setVisitedCount(numVisited);
- return results;
- }
- eps[0] = results.pop();
- }
- results =
- graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
- results.setVisitedCount(results.visitedCount() + numVisited);
- return results;
+ /**
+ * Search an {@link OnHeapHnswGraph}. This method is thread safe; for parameters please refer to
+ * {@link #search(float[], int, RandomAccessVectorValues, VectorEncoding,
+ * VectorSimilarityFunction, HnswGraph, Bits, int)}
+ */
+ public static NeighborQueue search(
+ float[] query,
+ int topK,
+ RandomAccessVectorValues<float[]> vectors,
+ VectorEncoding vectorEncoding,
+ VectorSimilarityFunction similarityFunction,
+ OnHeapHnswGraph graph,
+ Bits acceptOrds,
+ int visitedLimit)
+ throws IOException {
+ OnHeapHnswGraphSearcher<float[]> graphSearcher =
+ new OnHeapHnswGraphSearcher<>(
+ vectorEncoding,
+ similarityFunction,
+ new NeighborQueue(topK, true),
+ new SparseFixedBitSet(vectors.size()));
+ return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}
/**
@@ -161,6 +164,46 @@ public static NeighborQueue search(
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
+ return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
+ }
+
+ /**
+ * Search an {@link OnHeapHnswGraph}. This method is thread safe; for parameters please refer to
+ * {@link #search(byte[], int, RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction,
+ * HnswGraph, Bits, int)}
+ */
+ public static NeighborQueue search(
+ byte[] query,
+ int topK,
+ RandomAccessVectorValues<byte[]> vectors,
+ VectorEncoding vectorEncoding,
+ VectorSimilarityFunction similarityFunction,
+ OnHeapHnswGraph graph,
+ Bits acceptOrds,
+ int visitedLimit)
+ throws IOException {
+ OnHeapHnswGraphSearcher<byte[]> graphSearcher =
+ new OnHeapHnswGraphSearcher<>(
+ vectorEncoding,
+ similarityFunction,
+ new NeighborQueue(topK, true),
+ new SparseFixedBitSet(vectors.size()));
+ return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
+ }
+
+ private static <T> NeighborQueue search(
+ T query,
+ int topK,
+ RandomAccessVectorValues<T> vectors,
+ HnswGraph graph,
+ HnswGraphSearcher<T> graphSearcher,
+ Bits acceptOrds,
+ int visitedLimit)
+ throws IOException {
+ int initialEp = graph.entryNode();
+ if (initialEp == -1) {
+ return new NeighborQueue(1, true);
+ }
NeighborQueue results;
int[] eps = new int[] {graph.entryNode()};
int numVisited = 0;
@@ -252,9 +295,9 @@ private NeighborQueue searchLevel(
}
int topCandidateNode = candidates.pop();
- graph.seek(level, topCandidateNode);
+ graphSeek(graph, level, topCandidateNode);
int friendOrd;
- while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
+ while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
if (visited.getAndSet(friendOrd)) {
continue;
@@ -296,6 +339,62 @@ private void prepareScratchState(int capacity) {
if (visited.length() < capacity) {
visited = FixedBitSet.ensureCapacity((FixedBitSet) visited, capacity);
}
- visited.clear(0, visited.length());
+ visited.clear();
+ }
+
+ /**
+ * Seek a specific node in the given graph. The default implementation will just call {@link
+ * HnswGraph#seek(int, int)}
+ *
+ * @throws IOException if an I/O error occurs while seeking the graph
+ */
+ void graphSeek(HnswGraph graph, int level, int targetNode) throws IOException {
+ graph.seek(level, targetNode);
+ }
+
+ /**
+ * Get the next neighbor from the graph. You must call {@link #graphSeek(HnswGraph, int, int)}
+ * before calling this method. The default implementation will just call {@link
+ * HnswGraph#nextNeighbor()}
+ *
+ * @return see {@link HnswGraph#nextNeighbor()}
+ * @throws IOException if an I/O error occurs while advancing to the next neighbor
+ */
+ int graphNextNeighbor(HnswGraph graph) throws IOException {
+ return graph.nextNeighbor();
+ }
+
+ /**
+ * This class allows {@link OnHeapHnswGraph} to be searched in a thread-safe manner.
+ *
+ * <p>Note that the class itself is NOT thread safe, but since each search will create one new
+ * graph searcher, the search method is thread safe.
+ */
+ private static class OnHeapHnswGraphSearcher<C> extends HnswGraphSearcher<C> {
+
+ private NeighborArray cur;
+ private int upto;
+
+ private OnHeapHnswGraphSearcher(
+ VectorEncoding vectorEncoding,
+ VectorSimilarityFunction similarityFunction,
+ NeighborQueue candidates,
+ BitSet visited) {
+ super(vectorEncoding, similarityFunction, candidates, visited);
+ }
+
+ @Override
+ void graphSeek(HnswGraph graph, int level, int targetNode) {
+ cur = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode);
+ upto = -1;
+ }
+
+ @Override
+ int graphNextNeighbor(HnswGraph graph) {
+ if (++upto < cur.size()) {
+ return cur.node[upto];
+ }
+ return NO_MORE_DOCS;
+ }
}
}
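A hedged usage sketch of the new OnHeapHnswGraph overload (the query, vectors, and graph
variables are assumed to exist; names are illustrative). Each call builds a fresh
OnHeapHnswGraphSearcher internally, so concurrent queries against the same graph do not
share seek state:

    NeighborQueue results =
        HnswGraphSearcher.search(
            queryVector,            // float[] query (assumed)
            10,                     // topK
            vectors,                // RandomAccessVectorValues<float[]> (assumed)
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.DOT_PRODUCT,
            onHeapGraph,            // OnHeapHnswGraph (assumed)
            null,                   // acceptOrds: no filtering
            Integer.MAX_VALUE);     // visitedLimit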
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
index ec1b5ec3e897..b44f7da8b8ad 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
@@ -34,6 +34,7 @@ public class NeighborArray {
float[] score;
int[] node;
+ private int sortedNodeSize;
public NeighborArray(int maxSize, boolean descOrder) {
node = new int[maxSize];
@@ -43,9 +44,10 @@ public NeighborArray(int maxSize, boolean descOrder) {
/**
* Add a new node to the NeighborArray. The new node must be worse than all previously stored
- * nodes.
+ * nodes. This cannot be called after {@link #addOutOfOrder(int, float)}.
*/
- public void add(int newNode, float newScore) {
+ public void addInOrder(int newNode, float newScore) {
+ assert size == sortedNodeSize : "cannot call addInOrder after addOutOfOrder";
if (size == node.length) {
node = ArrayUtil.grow(node);
score = ArrayUtil.growExact(score, node.length);
@@ -54,28 +56,80 @@ public void add(int newNode, float newScore) {
float previousScore = score[size - 1];
assert ((scoresDescOrder && (previousScore >= newScore))
|| (scoresDescOrder == false && (previousScore <= newScore)))
- : "Nodes are added in the incorrect order!";
+ : "Nodes are added in the incorrect order! Comparing "
+ + newScore
+ + " to "
+ + Arrays.toString(ArrayUtil.copyOfSubArray(score, 0, size));
}
node[size] = newNode;
score[size] = newScore;
++size;
+ ++sortedNodeSize;
}
- /** Add a new node to the NeighborArray into a correct sort position according to its score. */
- public void insertSorted(int newNode, float newScore) {
+ /** Add node and score but do not insert as sorted */
+ public void addOutOfOrder(int newNode, float newScore) {
if (size == node.length) {
node = ArrayUtil.grow(node);
score = ArrayUtil.growExact(score, node.length);
}
+ node[size] = newNode;
+ score[size] = newScore;
+ size++;
+ }
+
+ /**
+ * Sort the array according to scores, and return the sorted indexes of previously unsorted nodes
+ * (unchecked nodes)
+ *
+ * @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is
+ * already fully sorted
+ */
+ public int[] sort() {
+ if (size == sortedNodeSize) {
+ // all nodes checked and sorted
+ return null;
+ }
+ assert sortedNodeSize < size;
+ int[] uncheckedIndexes = new int[size - sortedNodeSize];
+ int count = 0;
+ while (sortedNodeSize != size) {
+ uncheckedIndexes[count] = insertSortedInternal(); // sortedNodeSize is increased inside
+ for (int i = 0; i < count; i++) {
+ if (uncheckedIndexes[i] >= uncheckedIndexes[count]) {
+ // the previously inserted nodes have been shifted
+ uncheckedIndexes[i]++;
+ }
+ }
+ count++;
+ }
+ Arrays.sort(uncheckedIndexes);
+ return uncheckedIndexes;
+ }
+
+ /** insert the first unsorted node into its sorted position */
+ private int insertSortedInternal() {
+ assert sortedNodeSize < size : "Call this method only when there's unsorted node";
+ int tmpNode = node[sortedNodeSize];
+ float tmpScore = score[sortedNodeSize];
int insertionPoint =
scoresDescOrder
- ? descSortFindRightMostInsertionPoint(newScore)
- : ascSortFindRightMostInsertionPoint(newScore);
- System.arraycopy(node, insertionPoint, node, insertionPoint + 1, size - insertionPoint);
- System.arraycopy(score, insertionPoint, score, insertionPoint + 1, size - insertionPoint);
- node[insertionPoint] = newNode;
- score[insertionPoint] = newScore;
- ++size;
+ ? descSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize)
+ : ascSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize);
+ System.arraycopy(
+ node, insertionPoint, node, insertionPoint + 1, sortedNodeSize - insertionPoint);
+ System.arraycopy(
+ score, insertionPoint, score, insertionPoint + 1, sortedNodeSize - insertionPoint);
+ node[insertionPoint] = tmpNode;
+ score[insertionPoint] = tmpScore;
+ ++sortedNodeSize;
+ return insertionPoint;
+ }
+
+ /** This method is for testing only. */
+ void insertSorted(int newNode, float newScore) {
+ addOutOfOrder(newNode, newScore);
+ insertSortedInternal();
}
public int size() {
@@ -97,15 +151,20 @@ public float[] score() {
public void clear() {
size = 0;
+ sortedNodeSize = 0;
}
public void removeLast() {
size--;
+ sortedNodeSize = Math.min(sortedNodeSize, size);
}
public void removeIndex(int idx) {
System.arraycopy(node, idx + 1, node, idx, size - idx - 1);
System.arraycopy(score, idx + 1, score, idx, size - idx - 1);
+ if (idx < sortedNodeSize) {
+ sortedNodeSize--;
+ }
size--;
}
@@ -114,11 +173,11 @@ public String toString() {
return "NeighborArray[" + size + "]";
}
- private int ascSortFindRightMostInsertionPoint(float newScore) {
- int insertionPoint = Arrays.binarySearch(score, 0, size, newScore);
+ private int ascSortFindRightMostInsertionPoint(float newScore, int bound) {
+ int insertionPoint = Arrays.binarySearch(score, 0, bound, newScore);
if (insertionPoint >= 0) {
// find the right most position with the same score
- while ((insertionPoint < size - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
+ while ((insertionPoint < bound - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
insertionPoint++;
}
insertionPoint++;
@@ -128,9 +187,9 @@ private int ascSortFindRightMostInsertionPoint(float newScore) {
return insertionPoint;
}
- private int descSortFindRightMostInsertionPoint(float newScore) {
+ private int descSortFindRightMostInsertionPoint(float newScore, int bound) {
int start = 0;
- int end = size - 1;
+ int end = bound - 1;
while (start <= end) {
int mid = (start + end) / 2;
if (score[mid] < newScore) end = mid - 1;
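The new contract in one sketch (descending scores; values illustrative): in-order adds
must stay monotonic, out-of-order adds are appended raw, and sort() reports where the raw
tail ended up:

    NeighborArray na = new NeighborArray(8, true); // scores in descending order
    na.addInOrder(5, 0.9f);
    na.addInOrder(7, 0.6f);      // OK: 0.6 <= 0.9
    na.addOutOfOrder(9, 0.8f);   // appended unsorted; addInOrder would now assert
    int[] unchecked = na.sort(); // -> [1]: node 9 moved between the two sorted entries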
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
index 9862536de08c..75a77af056a2 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java
@@ -170,14 +170,14 @@ public NodesIterator getNodesOnLevel(int level) {
public long ramBytesUsed() {
long neighborArrayBytes0 =
nsize0 * (Integer.BYTES + Float.BYTES)
- + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
- + RamUsageEstimator.NUM_BYTES_OBJECT_REF
- + Integer.BYTES * 2;
+ + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
+ + Integer.BYTES * 3;
long neighborArrayBytes =
nsize * (Integer.BYTES + Float.BYTES)
- + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2
- + RamUsageEstimator.NUM_BYTES_OBJECT_REF
- + Integer.BYTES * 2;
+ + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
+ + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2
+ + Integer.BYTES * 3;
long total = 0;
for (int l = 0; l < numLevels; l++) {
if (l == 0) {
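For reference, the per-NeighborArray accounting before and after this change. The extra
Integer.BYTES matches the sortedNodeSize field added to NeighborArray above; the
header/reference split appears to be a corrected estimate of the node[]/score[] array
fields:

    // before: 2 * NUM_BYTES_ARRAY_HEADER + 1 * NUM_BYTES_OBJECT_REF + 2 * Integer.BYTES
    // after:  1 * NUM_BYTES_ARRAY_HEADER + 2 * NUM_BYTES_OBJECT_REF + 3 * Integer.BYTES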
diff --git a/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java
new file mode 100644
index 000000000000..a1a5a404223f
--- /dev/null
+++ b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java
@@ -0,0 +1,499 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util;
+
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+import java.util.logging.Logger;
+import jdk.incubator.vector.ByteVector;
+import jdk.incubator.vector.FloatVector;
+import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.ShortVector;
+import jdk.incubator.vector.Vector;
+import jdk.incubator.vector.VectorOperators;
+import jdk.incubator.vector.VectorShape;
+import jdk.incubator.vector.VectorSpecies;
+
+/** A VectorUtil provider implementation that leverages the Panama Vector API. */
+final class VectorUtilPanamaProvider implements VectorUtilProvider {
+
+ private static final int INT_SPECIES_PREF_BIT_SIZE = IntVector.SPECIES_PREFERRED.vectorBitSize();
+
+ private static final VectorSpecies<Float> PREF_FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED;
+ private static final VectorSpecies<Byte> PREF_BYTE_SPECIES;
+ private static final VectorSpecies<Short> PREF_SHORT_SPECIES;
+
+ /**
+ * x86 and less than 256-bit vectors.
+ *
+ * <p>It could be that it has only AVX1 and integer vectors are fast. It could also be that it has
+ * no AVX and integer vectors are extremely slow. Don't use integer vectors to avoid landmines.
+ */
+ private final boolean hasFastIntegerVectors;
+
+ static {
+ if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
+ PREF_BYTE_SPECIES =
+ ByteVector.SPECIES_MAX.withShape(
+ VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 2));
+ PREF_SHORT_SPECIES =
+ ShortVector.SPECIES_MAX.withShape(
+ VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 1));
+ } else {
+ PREF_BYTE_SPECIES = null;
+ PREF_SHORT_SPECIES = null;
+ }
+ }
+
+ // Extracted to a method to be able to apply the SuppressForbidden annotation
+ @SuppressWarnings("removal")
+ @SuppressForbidden(reason = "security manager")
+ private static <T> T doPrivileged(PrivilegedAction<T> action) {
+ return AccessController.doPrivileged(action);
+ }
+
+ VectorUtilPanamaProvider(boolean testMode) {
+ if (!testMode && INT_SPECIES_PREF_BIT_SIZE < 128) {
+ throw new UnsupportedOperationException(
+ "Vector bit size is less than 128: " + INT_SPECIES_PREF_BIT_SIZE);
+ }
+
+ // hack to work around JDK-8309727:
+ try {
+ doPrivileged(
+ () ->
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, new float[PREF_FLOAT_SPECIES.length()], 0));
+ } catch (SecurityException se) {
+ throw new UnsupportedOperationException(
+ "We hit initialization failure described in JDK-8309727: " + se);
+ }
+
+ // check if the system is x86 and less than 256-bit vectors:
+ var isAMD64withoutAVX2 = Constants.OS_ARCH.equals("amd64") && INT_SPECIES_PREF_BIT_SIZE < 256;
+ this.hasFastIntegerVectors = testMode || false == isAMD64withoutAVX2;
+
+ var log = Logger.getLogger(getClass().getName());
+ log.info(
+ "Java vector incubator API enabled"
+ + (testMode ? " (test mode)" : "")
+ + "; uses preferredBitSize="
+ + INT_SPECIES_PREF_BIT_SIZE);
+ }
+
+ @Override
+ public float dotProduct(float[] a, float[] b) {
+ int i = 0;
+ float res = 0;
+ // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
+ if (a.length > 2 * PREF_FLOAT_SPECIES.length()) {
+ // vector loop is unrolled 4x (4 accumulators in parallel)
+ FloatVector acc1 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector acc2 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector acc3 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector acc4 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length());
+ for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) {
+ FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
+ FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
+ acc1 = acc1.add(va.mul(vb));
+ FloatVector vc =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length());
+ FloatVector vd =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length());
+ acc2 = acc2.add(vc.mul(vd));
+ FloatVector ve =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length());
+ FloatVector vf =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length());
+ acc3 = acc3.add(ve.mul(vf));
+ FloatVector vg =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length());
+ FloatVector vh =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length());
+ acc4 = acc4.add(vg.mul(vh));
+ }
+ // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
+ upperBound = PREF_FLOAT_SPECIES.loopBound(a.length);
+ for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) {
+ FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
+ FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
+ acc1 = acc1.add(va.mul(vb));
+ }
+ // reduce
+ FloatVector res1 = acc1.add(acc2);
+ FloatVector res2 = acc3.add(acc4);
+ res += res1.add(res2).reduceLanes(VectorOperators.ADD);
+ }
+
+ for (; i < a.length; i++) {
+ res += b[i] * a[i];
+ }
+ return res;
+ }
+
+ @Override
+ public float cosine(float[] a, float[] b) {
+ int i = 0;
+ float sum = 0;
+ float norm1 = 0;
+ float norm2 = 0;
+ // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
+ if (a.length > 2 * PREF_FLOAT_SPECIES.length()) {
+ // vector loop is unrolled 4x (4 accumulators in parallel)
+ FloatVector sum1 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector sum2 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector sum3 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector sum4 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector norm1_1 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector norm1_2 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector norm1_3 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector norm1_4 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector norm2_1 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector norm2_2 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector norm2_3 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector norm2_4 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length());
+ for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) {
+ FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
+ FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
+ sum1 = sum1.add(va.mul(vb));
+ norm1_1 = norm1_1.add(va.mul(va));
+ norm2_1 = norm2_1.add(vb.mul(vb));
+ FloatVector vc =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length());
+ FloatVector vd =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length());
+ sum2 = sum2.add(vc.mul(vd));
+ norm1_2 = norm1_2.add(vc.mul(vc));
+ norm2_2 = norm2_2.add(vd.mul(vd));
+ FloatVector ve =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length());
+ FloatVector vf =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length());
+ sum3 = sum3.add(ve.mul(vf));
+ norm1_3 = norm1_3.add(ve.mul(ve));
+ norm2_3 = norm2_3.add(vf.mul(vf));
+ FloatVector vg =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length());
+ FloatVector vh =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length());
+ sum4 = sum4.add(vg.mul(vh));
+ norm1_4 = norm1_4.add(vg.mul(vg));
+ norm2_4 = norm2_4.add(vh.mul(vh));
+ }
+ // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
+ upperBound = PREF_FLOAT_SPECIES.loopBound(a.length);
+ for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) {
+ FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
+ FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
+ sum1 = sum1.add(va.mul(vb));
+ norm1_1 = norm1_1.add(va.mul(va));
+ norm2_1 = norm2_1.add(vb.mul(vb));
+ }
+ // reduce
+ FloatVector sumres1 = sum1.add(sum2);
+ FloatVector sumres2 = sum3.add(sum4);
+ FloatVector norm1res1 = norm1_1.add(norm1_2);
+ FloatVector norm1res2 = norm1_3.add(norm1_4);
+ FloatVector norm2res1 = norm2_1.add(norm2_2);
+ FloatVector norm2res2 = norm2_3.add(norm2_4);
+ sum += sumres1.add(sumres2).reduceLanes(VectorOperators.ADD);
+ norm1 += norm1res1.add(norm1res2).reduceLanes(VectorOperators.ADD);
+ norm2 += norm2res1.add(norm2res2).reduceLanes(VectorOperators.ADD);
+ }
+
+ for (; i < a.length; i++) {
+ float elem1 = a[i];
+ float elem2 = b[i];
+ sum += elem1 * elem2;
+ norm1 += elem1 * elem1;
+ norm2 += elem2 * elem2;
+ }
+ return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
+ }
+
+ @Override
+ public float squareDistance(float[] a, float[] b) {
+ int i = 0;
+ float res = 0;
+ // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
+ if (a.length > 2 * PREF_FLOAT_SPECIES.length()) {
+ // vector loop is unrolled 4x (4 accumulators in parallel)
+ FloatVector acc1 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector acc2 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector acc3 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ FloatVector acc4 = FloatVector.zero(PREF_FLOAT_SPECIES);
+ int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length());
+ for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) {
+ FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
+ FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
+ FloatVector diff1 = va.sub(vb);
+ acc1 = acc1.add(diff1.mul(diff1));
+ FloatVector vc =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length());
+ FloatVector vd =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length());
+ FloatVector diff2 = vc.sub(vd);
+ acc2 = acc2.add(diff2.mul(diff2));
+ FloatVector ve =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length());
+ FloatVector vf =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length());
+ FloatVector diff3 = ve.sub(vf);
+ acc3 = acc3.add(diff3.mul(diff3));
+ FloatVector vg =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length());
+ FloatVector vh =
+ FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length());
+ FloatVector diff4 = vg.sub(vh);
+ acc4 = acc4.add(diff4.mul(diff4));
+ }
+ // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
+ upperBound = PREF_FLOAT_SPECIES.loopBound(a.length);
+ for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) {
+ FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i);
+ FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i);
+ FloatVector diff = va.sub(vb);
+ acc1 = acc1.add(diff.mul(diff));
+ }
+ // reduce
+ FloatVector res1 = acc1.add(acc2);
+ FloatVector res2 = acc3.add(acc4);
+ res += res1.add(res2).reduceLanes(VectorOperators.ADD);
+ }
+
+ for (; i < a.length; i++) {
+ float diff = a[i] - b[i];
+ res += diff * diff;
+ }
+ return res;
+ }
+
+ // Binary functions, these all follow a general pattern like this:
+ //
+ // short intermediate = a * b;
+ // int accumulator = accumulator + intermediate;
+ //
+ // 256 or 512 bit vectors can process 64 or 128 bits at a time, respectively
+ // intermediate results use 128 or 256 bit vectors, respectively
+ // final accumulator uses 256 or 512 bit vectors, respectively
+ //
+ // We also support 128 bit vectors, using two 128 bit accumulators.
+ // This is slower but still faster than not vectorizing at all.
+
+ @Override
+ public int dotProduct(byte[] a, byte[] b) {
+ int i = 0;
+ int res = 0;
+ // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
+ // vectors (256-bit on intel to dodge performance landmines)
+ if (a.length >= 16 && hasFastIntegerVectors) {
+ // compute vectorized dot product consistent with VPDPBUSD instruction
+ if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
+ // optimized 256/512 bit implementation, processes 8/16 bytes at a time
+ int upperBound = PREF_BYTE_SPECIES.loopBound(a.length);
+ IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
+ for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) {
+ ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i);
+ ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i);
+ Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
+ Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
+ Vector<Short> prod16 = va16.mul(vb16);
+ Vector<Integer> prod32 =
+ prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0);
+ acc = acc.add(prod32);
+ }
+ // reduce
+ res += acc.reduceLanes(VectorOperators.ADD);
+ } else {
+ // 128-bit implementation, which must "split up" vectors due to widening conversions
+ int upperBound = ByteVector.SPECIES_64.loopBound(a.length);
+ IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
+ IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
+ for (; i < upperBound; i += ByteVector.SPECIES_64.length()) {
+ ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
+ ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
+ // expand each byte vector into short vector and multiply
+ Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
+ Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
+ Vector<Short> prod16 = va16.mul(vb16);
+ // split each short vector into two int vectors and add
+ Vector<Integer> prod32_1 =
+ prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
+ Vector<Integer> prod32_2 =
+ prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
+ acc1 = acc1.add(prod32_1);
+ acc2 = acc2.add(prod32_2);
+ }
+ // reduce
+ res += acc1.add(acc2).reduceLanes(VectorOperators.ADD);
+ }
+ }
+
+ for (; i < a.length; i++) {
+ res += b[i] * a[i];
+ }
+ return res;
+ }
+
+ @Override
+ public float cosine(byte[] a, byte[] b) {
+ int i = 0;
+ int sum = 0;
+ int norm1 = 0;
+ int norm2 = 0;
+ // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
+ // vectors (256-bit on intel to dodge performance landmines)
+ if (a.length >= 16 && hasFastIntegerVectors) {
+ if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
+ // optimized 256/512 bit implementation, processes 8/16 bytes at a time
+ int upperBound = PREF_BYTE_SPECIES.loopBound(a.length);
+ IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
+ IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
+ IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
+ for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) {
+ ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i);
+ ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i);
+ Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
+ Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
+ Vector<Short> prod16 = va16.mul(vb16);
+ Vector<Short> norm1_16 = va16.mul(va16);
+ Vector<Short> norm2_16 = vb16.mul(vb16);
+ Vector<Integer> prod32 =
+ prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0);
+ Vector<Integer> norm1_32 =
+ norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0);
+ Vector<Integer> norm2_32 =
+ norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0);
+ accSum = accSum.add(prod32);
+ accNorm1 = accNorm1.add(norm1_32);
+ accNorm2 = accNorm2.add(norm2_32);
+ }
+ // reduce
+ sum += accSum.reduceLanes(VectorOperators.ADD);
+ norm1 += accNorm1.reduceLanes(VectorOperators.ADD);
+ norm2 += accNorm2.reduceLanes(VectorOperators.ADD);
+ } else {
+ // 128-bit implementation, which must "split up" vectors due to widening conversions
+ int upperBound = ByteVector.SPECIES_64.loopBound(a.length);
+ IntVector accSum1 = IntVector.zero(IntVector.SPECIES_128);
+ IntVector accSum2 = IntVector.zero(IntVector.SPECIES_128);
+ IntVector accNorm1_1 = IntVector.zero(IntVector.SPECIES_128);
+ IntVector accNorm1_2 = IntVector.zero(IntVector.SPECIES_128);
+ IntVector accNorm2_1 = IntVector.zero(IntVector.SPECIES_128);
+ IntVector accNorm2_2 = IntVector.zero(IntVector.SPECIES_128);
+ for (; i < upperBound; i += ByteVector.SPECIES_64.length()) {
+ ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
+ ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
+ // expand each byte vector into short vector and perform multiplications
+ Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
+ Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
+ Vector<Short> prod16 = va16.mul(vb16);
+ Vector<Short> norm1_16 = va16.mul(va16);
+ Vector<Short> norm2_16 = vb16.mul(vb16);
+ // split each short vector into two int vectors and add
+ Vector<Integer> prod32_1 =
+ prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
+ Vector<Integer> prod32_2 =
+ prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
+ Vector<Integer> norm1_32_1 =
+ norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
+ Vector<Integer> norm1_32_2 =
+ norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
+ Vector<Integer> norm2_32_1 =
+ norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
+ Vector<Integer> norm2_32_2 =
+ norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
+ accSum1 = accSum1.add(prod32_1);
+ accSum2 = accSum2.add(prod32_2);
+ accNorm1_1 = accNorm1_1.add(norm1_32_1);
+ accNorm1_2 = accNorm1_2.add(norm1_32_2);
+ accNorm2_1 = accNorm2_1.add(norm2_32_1);
+ accNorm2_2 = accNorm2_2.add(norm2_32_2);
+ }
+ // reduce
+ sum += accSum1.add(accSum2).reduceLanes(VectorOperators.ADD);
+ norm1 += accNorm1_1.add(accNorm1_2).reduceLanes(VectorOperators.ADD);
+ norm2 += accNorm2_1.add(accNorm2_2).reduceLanes(VectorOperators.ADD);
+ }
+ }
+
+ for (; i < a.length; i++) {
+ byte elem1 = a[i];
+ byte elem2 = b[i];
+ sum += elem1 * elem2;
+ norm1 += elem1 * elem1;
+ norm2 += elem2 * elem2;
+ }
+ return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
+ }
+
+ @Override
+ public int squareDistance(byte[] a, byte[] b) {
+ int i = 0;
+ int res = 0;
+ // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
+ // vectors (256-bit on intel to dodge performance landmines)
+ if (a.length >= 16 && hasFastIntegerVectors) {
+ if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
+ // optimized 256/512 bit implementation, processes 8/16 bytes at a time
+ int upperBound = PREF_BYTE_SPECIES.loopBound(a.length);
+ IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
+ for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) {
+ ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i);
+ ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i);
+ Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
+ Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0);
+ Vector<Short> diff16 = va16.sub(vb16);
+ Vector<Integer> diff32 =
+ diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0);
+ acc = acc.add(diff32.mul(diff32));
+ }
+ // reduce
+ res += acc.reduceLanes(VectorOperators.ADD);
+ } else {
+ // 128-bit implementation, which must "split up" vectors due to widening conversions
+ int upperBound = ByteVector.SPECIES_64.loopBound(a.length);
+ IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
+ IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
+ for (; i < upperBound; i += ByteVector.SPECIES_64.length()) {
+ ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
+ ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
+ // expand each byte vector into short vector and subtract
+ Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
+ Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
+ Vector<Short> diff16 = va16.sub(vb16);
+ // split each short vector into two int vectors, square, and add
+ Vector<Integer> diff32_1 =
+ diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
+ Vector<Integer> diff32_2 =
+ diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
+ acc1 = acc1.add(diff32_1.mul(diff32_1));
+ acc2 = acc2.add(diff32_2.mul(diff32_2));
+ }
+ // reduce
+ res += acc1.add(acc2).reduceLanes(VectorOperators.ADD);
+ }
+ }
+
+ for (; i < a.length; i++) {
+ int diff = a[i] - b[i];
+ res += diff * diff;
+ }
+ return res;
+ }
+}
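The loop-bound arithmetic behind the "vector tail" comments, worked through once
(assuming a 256-bit platform, i.e. PREF_FLOAT_SPECIES.length() == 8, and a hypothetical
input length of 100):

    // unrolled bound: loopBound(100 - 3 * 8) = loopBound(76) = 72
    //   -> i = 0, 32, 64; each pass consumes 4 vectors = 32 floats (96 in total)
    // tail bound:     loopBound(100) = 96
    //   -> i is already 96 after the unrolled loop, so the tail loop is skipped here
    // scalar tail:    indexes 96..99 (the remaining 4 floats)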
diff --git a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java
new file mode 100644
index 000000000000..7b2216add785
--- /dev/null
+++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java
@@ -0,0 +1,588 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.store;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.lang.foreign.Arena;
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+import java.nio.ByteOrder;
+import java.util.Arrays;
+import java.util.Objects;
+import org.apache.lucene.util.ArrayUtil;
+
+/**
+ * Base IndexInput implementation that uses an array of MemorySegments to represent a file.
+ *
+ * <p>For efficiency, this class requires that the segment sizes are a power-of-two (<code>
+ * chunkSizePower</code>).
+ */
+@SuppressWarnings("preview")
+abstract class MemorySegmentIndexInput extends IndexInput implements RandomAccessInput {
+ static final ValueLayout.OfByte LAYOUT_BYTE = ValueLayout.JAVA_BYTE;
+ static final ValueLayout.OfShort LAYOUT_LE_SHORT =
+ ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+ static final ValueLayout.OfInt LAYOUT_LE_INT =
+ ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+ static final ValueLayout.OfLong LAYOUT_LE_LONG =
+ ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+ static final ValueLayout.OfFloat LAYOUT_LE_FLOAT =
+ ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+
+ final long length;
+ final long chunkSizeMask;
+ final int chunkSizePower;
+ final Arena arena;
+ final MemorySegment[] segments;
+
+ int curSegmentIndex = -1;
+ MemorySegment
+ curSegment; // redundant for speed: segments[curSegmentIndex], also marker if closed!
+ long curPosition; // relative to curSegment, not globally
+
+ public static MemorySegmentIndexInput newInstance(
+ String resourceDescription,
+ Arena arena,
+ MemorySegment[] segments,
+ long length,
+ int chunkSizePower) {
+ assert Arrays.stream(segments).map(MemorySegment::scope).allMatch(arena.scope()::equals);
+ if (segments.length == 1) {
+ return new SingleSegmentImpl(resourceDescription, arena, segments[0], length, chunkSizePower);
+ } else {
+ return new MultiSegmentImpl(resourceDescription, arena, segments, 0, length, chunkSizePower);
+ }
+ }
+
+ private MemorySegmentIndexInput(
+ String resourceDescription,
+ Arena arena,
+ MemorySegment[] segments,
+ long length,
+ int chunkSizePower) {
+ super(resourceDescription);
+ this.arena = arena;
+ this.segments = segments;
+ this.length = length;
+ this.chunkSizePower = chunkSizePower;
+ this.chunkSizeMask = (1L << chunkSizePower) - 1L;
+ this.curSegment = segments[0];
+ }
+
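+ // Addressing scheme: a global file position maps to a (segment, offset) pair, where
+ // segmentIndex = pos >>> chunkSizePower and offset = pos & chunkSizeMask. For example, with
+ // chunkSizePower = 30 (1 GiB chunks), position 3 * (1L << 30) + 7 lands in segments[3] at
+ // offset 7.
+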
+ void ensureOpen() {
+ if (curSegment == null) {
+ throw alreadyClosed(null);
+ }
+ }
+
+ // the unused parameter is just to silence javac about unused variables
+ RuntimeException handlePositionalIOOBE(RuntimeException unused, String action, long pos)
+ throws IOException {
+ if (pos < 0L) {
+ return new IllegalArgumentException(action + " negative position (pos=" + pos + "): " + this);
+ } else {
+ throw new EOFException(action + " past EOF (pos=" + pos + "): " + this);
+ }
+ }
+
+ // the unused parameter is just to silence javac about unused variables
+ AlreadyClosedException alreadyClosed(RuntimeException unused) {
+ return new AlreadyClosedException("Already closed: " + this);
+ }
+
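+ // The sequential read methods below read optimistically from the current segment and treat
+ // an IndexOutOfBoundsException as the (rare) signal that a segment boundary was crossed,
+ // keeping explicit bounds checks off the hot path.
+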
+ @Override
+ public final byte readByte() throws IOException {
+ try {
+ final byte v = curSegment.get(LAYOUT_BYTE, curPosition);
+ curPosition++;
+ return v;
+ } catch (
+ @SuppressWarnings("unused")
+ IndexOutOfBoundsException e) {
+ do {
+ curSegmentIndex++;
+ if (curSegmentIndex >= segments.length) {
+ throw new EOFException("read past EOF: " + this);
+ }
+ curSegment = segments[curSegmentIndex];
+ curPosition = 0L;
+ } while (curSegment.byteSize() == 0L);
+ final byte v = curSegment.get(LAYOUT_BYTE, curPosition);
+ curPosition++;
+ return v;
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public final void readBytes(byte[] b, int offset, int len) throws IOException {
+ try {
+ MemorySegment.copy(curSegment, LAYOUT_BYTE, curPosition, b, offset, len);
+ curPosition += len;
+ } catch (
+ @SuppressWarnings("unused")
+ IndexOutOfBoundsException e) {
+ readBytesBoundary(b, offset, len);
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ private void readBytesBoundary(byte[] b, int offset, int len) throws IOException {
+ try {
+ long curAvail = curSegment.byteSize() - curPosition;
+ while (len > curAvail) {
+ MemorySegment.copy(curSegment, LAYOUT_BYTE, curPosition, b, offset, (int) curAvail);
+ len -= curAvail;
+ offset += curAvail;
+ curSegmentIndex++;
+ if (curSegmentIndex >= segments.length) {
+ throw new EOFException("read past EOF: " + this);
+ }
+ curSegment = segments[curSegmentIndex];
+ curPosition = 0L;
+ curAvail = curSegment.byteSize();
+ }
+ MemorySegment.copy(curSegment, LAYOUT_BYTE, curPosition, b, offset, len);
+ curPosition += len;
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public void readInts(int[] dst, int offset, int length) throws IOException {
+ try {
+ MemorySegment.copy(curSegment, LAYOUT_LE_INT, curPosition, dst, offset, length);
+ curPosition += Integer.BYTES * (long) length;
+ } catch (
+ @SuppressWarnings("unused")
+ IndexOutOfBoundsException iobe) {
+ super.readInts(dst, offset, length);
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public void readLongs(long[] dst, int offset, int length) throws IOException {
+ try {
+ MemorySegment.copy(curSegment, LAYOUT_LE_LONG, curPosition, dst, offset, length);
+ curPosition += Long.BYTES * (long) length;
+ } catch (
+ @SuppressWarnings("unused")
+ IndexOutOfBoundsException iobe) {
+ super.readLongs(dst, offset, length);
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public void readFloats(float[] dst, int offset, int length) throws IOException {
+ try {
+ MemorySegment.copy(curSegment, LAYOUT_LE_FLOAT, curPosition, dst, offset, length);
+ curPosition += Float.BYTES * (long) length;
+ } catch (
+ @SuppressWarnings("unused")
+ IndexOutOfBoundsException iobe) {
+ super.readFloats(dst, offset, length);
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public final short readShort() throws IOException {
+ try {
+ final short v = curSegment.get(LAYOUT_LE_SHORT, curPosition);
+ curPosition += Short.BYTES;
+ return v;
+ } catch (
+ @SuppressWarnings("unused")
+ IndexOutOfBoundsException e) {
+ return super.readShort();
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public final int readInt() throws IOException {
+ try {
+ final int v = curSegment.get(LAYOUT_LE_INT, curPosition);
+ curPosition += Integer.BYTES;
+ return v;
+ } catch (
+ @SuppressWarnings("unused")
+ IndexOutOfBoundsException e) {
+ return super.readInt();
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public final long readLong() throws IOException {
+ try {
+ final long v = curSegment.get(LAYOUT_LE_LONG, curPosition);
+ curPosition += Long.BYTES;
+ return v;
+ } catch (
+ @SuppressWarnings("unused")
+ IndexOutOfBoundsException e) {
+ return super.readLong();
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public long getFilePointer() {
+ ensureOpen();
+ return (((long) curSegmentIndex) << chunkSizePower) + curPosition;
+ }
+
+ @Override
+ public void seek(long pos) throws IOException {
+ ensureOpen();
+ // we use >> here to preserve negative, so we will catch AIOOBE,
+ // in case pos + offset overflows.
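+ // (e.g. pos = -1 yields si = -1, so segments[si] throws and handlePositionalIOOBE reports
+ // the negative position; an unsigned >>> would silently map it to a large positive index)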
+ final int si = (int) (pos >> chunkSizePower);
+ try {
+ if (si != curSegmentIndex) {
+ final MemorySegment seg = segments[si];
+ // write values; on exception all is unchanged
+ this.curSegmentIndex = si;
+ this.curSegment = seg;
+ }
+ this.curPosition = Objects.checkIndex(pos & chunkSizeMask, curSegment.byteSize() + 1);
+ } catch (IndexOutOfBoundsException e) {
+ throw handlePositionalIOOBE(e, "seek", pos);
+ }
+ }
+
+ @Override
+ public byte readByte(long pos) throws IOException {
+ try {
+ final int si = (int) (pos >> chunkSizePower);
+ return segments[si].get(LAYOUT_BYTE, pos & chunkSizeMask);
+ } catch (IndexOutOfBoundsException ioobe) {
+ throw handlePositionalIOOBE(ioobe, "read", pos);
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ // used only by random access methods to handle reads across boundaries
+ private void setPos(long pos, int si) throws IOException {
+ try {
+ final MemorySegment seg = segments[si];
+ // write values; on exception above, all is unchanged
+ this.curPosition = pos & chunkSizeMask;
+ this.curSegmentIndex = si;
+ this.curSegment = seg;
+ } catch (IndexOutOfBoundsException ioobe) {
+ throw handlePositionalIOOBE(ioobe, "read", pos);
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public short readShort(long pos) throws IOException {
+ final int si = (int) (pos >> chunkSizePower);
+ try {
+ return segments[si].get(LAYOUT_LE_SHORT, pos & chunkSizeMask);
+ } catch (
+ @SuppressWarnings("unused")
+ IndexOutOfBoundsException ioobe) {
+ // either a boundary crossing or a read past EOF; fall back:
+ setPos(pos, si);
+ return readShort();
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public int readInt(long pos) throws IOException {
+ final int si = (int) (pos >> chunkSizePower);
+ try {
+ return segments[si].get(LAYOUT_LE_INT, pos & chunkSizeMask);
+ } catch (
+ @SuppressWarnings("unused")
+ IndexOutOfBoundsException ioobe) {
+ // either a boundary crossing or a read past EOF; fall back:
+ setPos(pos, si);
+ return readInt();
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public long readLong(long pos) throws IOException {
+ final int si = (int) (pos >> chunkSizePower);
+ try {
+ return segments[si].get(LAYOUT_LE_LONG, pos & chunkSizeMask);
+ } catch (
+ @SuppressWarnings("unused")
+ IndexOutOfBoundsException ioobe) {
+ // either a boundary crossing or a read past EOF; fall back:
+ setPos(pos, si);
+ return readLong();
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public final long length() {
+ return length;
+ }
+
+ @Override
+ public final MemorySegmentIndexInput clone() {
+ final MemorySegmentIndexInput clone = buildSlice((String) null, 0L, this.length);
+ try {
+ clone.seek(getFilePointer());
+ } catch (IOException ioe) {
+ throw new AssertionError(ioe);
+ }
+
+ return clone;
+ }
+
+ /**
+ * Creates a slice of this index input, with the given description, offset, and length. The slice
+ * is seeked to the beginning.
+ */
+ @Override
+ public final MemorySegmentIndexInput slice(String sliceDescription, long offset, long length) {
+ if (offset < 0 || length < 0 || offset + length > this.length) {
+ throw new IllegalArgumentException(
+ "slice() "
+ + sliceDescription
+ + " out of bounds: offset="
+ + offset
+ + ",length="
+ + length
+ + ",fileLength="
+ + this.length
+ + ": "
+ + this);
+ }
+
+ return buildSlice(sliceDescription, offset, length);
+ }
+
+ /** Builds the actual sliced IndexInput (may apply extra offset in subclasses). */
+ MemorySegmentIndexInput buildSlice(String sliceDescription, long offset, long length) {
+ ensureOpen();
+
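+ // Worked example, assuming chunkSizePower = 30 (1 GiB chunks): offset = 3.5 GiB and
+ // length = 1 GiB give sliceEnd = 4.5 GiB, startIndex = 3, endIndex = 4; we copy
+ // segments[3..4], truncate the last one to 0.5 GiB, and the slice starts 0.5 GiB into its
+ // first segment.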
+ final long sliceEnd = offset + length;
+ final int startIndex = (int) (offset >>> chunkSizePower);
+ final int endIndex = (int) (sliceEnd >>> chunkSizePower);
+
+ // we always allocate one extra slice; the last one may be a 0-byte one after truncating
+ // with asSlice():
+ final MemorySegment slices[] = ArrayUtil.copyOfSubArray(segments, startIndex, endIndex + 1);
+
+ // set the last segment's limit for the sliced view.
+ slices[slices.length - 1] = slices[slices.length - 1].asSlice(0L, sliceEnd & chunkSizeMask);
+
+ offset = offset & chunkSizeMask;
+
+ final String newResourceDescription = getFullSliceDescription(sliceDescription);
+ if (slices.length == 1) {
+ return new SingleSegmentImpl(
+ newResourceDescription,
+ null, // clones don't have an Arena, as they can't close
+ slices[0].asSlice(offset, length),
+ length,
+ chunkSizePower);
+ } else {
+ return new MultiSegmentImpl(
+ newResourceDescription,
+ null, // clones don't have an Arena, as they can't close
+ slices,
+ offset,
+ length,
+ chunkSizePower);
+ }
+ }
+
+ @Override
+ public final void close() throws IOException {
+ if (curSegment == null) {
+ return;
+ }
+
+ // make sure all accesses to this IndexInput instance throw NPE:
+ curSegment = null;
+ Arrays.fill(segments, null);
+
+ // the master IndexInput has an Arena and is able
+ // to release all resources (unmap segments) - a
+ // side effect is that other threads still using clones
+ // will throw IllegalStateException
+ if (arena != null) {
+ arena.close();
+ }
+ }
+
+ /** Optimization of MemorySegmentIndexInput for when there is only one segment. */
+ static final class SingleSegmentImpl extends MemorySegmentIndexInput {
+
+ SingleSegmentImpl(
+ String resourceDescription,
+ Arena arena,
+ MemorySegment segment,
+ long length,
+ int chunkSizePower) {
+ super(resourceDescription, arena, new MemorySegment[] {segment}, length, chunkSizePower);
+ this.curSegmentIndex = 0;
+ }
+
+ @Override
+ public void seek(long pos) throws IOException {
+ ensureOpen();
+ try {
+ curPosition = Objects.checkIndex(pos, length + 1);
+ } catch (IndexOutOfBoundsException e) {
+ throw handlePositionalIOOBE(e, "seek", pos);
+ }
+ }
+
+ @Override
+ public long getFilePointer() {
+ ensureOpen();
+ return curPosition;
+ }
+
+ @Override
+ public byte readByte(long pos) throws IOException {
+ try {
+ return curSegment.get(LAYOUT_BYTE, pos);
+ } catch (IndexOutOfBoundsException e) {
+ throw handlePositionalIOOBE(e, "read", pos);
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public short readShort(long pos) throws IOException {
+ try {
+ return curSegment.get(LAYOUT_LE_SHORT, pos);
+ } catch (IndexOutOfBoundsException e) {
+ throw handlePositionalIOOBE(e, "read", pos);
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public int readInt(long pos) throws IOException {
+ try {
+ return curSegment.get(LAYOUT_LE_INT, pos);
+ } catch (IndexOutOfBoundsException e) {
+ throw handlePositionalIOOBE(e, "read", pos);
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+
+ @Override
+ public long readLong(long pos) throws IOException {
+ try {
+ return curSegment.get(LAYOUT_LE_LONG, pos);
+ } catch (IndexOutOfBoundsException e) {
+ throw handlePositionalIOOBE(e, "read", pos);
+ } catch (NullPointerException | IllegalStateException e) {
+ throw alreadyClosed(e);
+ }
+ }
+ }
+
+ /** This class adds offset support to MemorySegmentIndexInput, which is needed for slices. */
+ static final class MultiSegmentImpl extends MemorySegmentIndexInput {
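+ // start offset of this slice within the shared segments; every absolute position handed to
+ // the positional read methods is shifted by it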
+ private final long offset;
+
+ MultiSegmentImpl(
+ String resourceDescription,
+ Arena arena,
+ MemorySegment[] segments,
+ long offset,
+ long length,
+ int chunkSizePower) {
+ super(resourceDescription, arena, segments, length, chunkSizePower);
+ this.offset = offset;
+ try {
+ seek(0L);
+ } catch (IOException ioe) {
+ throw new AssertionError(ioe);
+ }
+ assert curSegment != null && curSegmentIndex >= 0;
+ }
+
+ @Override
+ RuntimeException handlePositionalIOOBE(RuntimeException unused, String action, long pos)
+ throws IOException {
+ return super.handlePositionalIOOBE(unused, action, pos - offset);
+ }
+
+ @Override
+ public void seek(long pos) throws IOException {
+ assert pos >= 0L : "negative position";
+ super.seek(pos + offset);
+ }
+
+ @Override
+ public long getFilePointer() {
+ return super.getFilePointer() - offset;
+ }
+
+ @Override
+ public byte readByte(long pos) throws IOException {
+ return super.readByte(pos + offset);
+ }
+
+ @Override
+ public short readShort(long pos) throws IOException {
+ return super.readShort(pos + offset);
+ }
+
+ @Override
+ public int readInt(long pos) throws IOException {
+ return super.readInt(pos + offset);
+ }
+
+ @Override
+ public long readLong(long pos) throws IOException {
+ return super.readLong(pos + offset);
+ }
+
+ @Override
+ MemorySegmentIndexInput buildSlice(String sliceDescription, long ofs, long length) {
+ return super.buildSlice(sliceDescription, this.offset + ofs, length);
+ }
+ }
+}
diff --git a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInputProvider.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInputProvider.java
new file mode 100644
index 000000000000..e994c2dddfff
--- /dev/null
+++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInputProvider.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.store;
+
+import java.io.IOException;
+import java.lang.foreign.Arena;
+import java.lang.foreign.MemorySegment;
+import java.nio.channels.FileChannel;
+import java.nio.channels.FileChannel.MapMode;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.logging.Logger;
+import org.apache.lucene.util.Constants;
+import org.apache.lucene.util.Unwrappable;
+
+@SuppressWarnings("preview")
+final class MemorySegmentIndexInputProvider implements MMapDirectory.MMapIndexInputProvider {
+
+ public MemorySegmentIndexInputProvider() {
+ var log = Logger.getLogger(getClass().getName());
+ log.info(
+ "Using MemorySegmentIndexInput with Java 21; to disable start with -D"
+ + MMapDirectory.ENABLE_MEMORY_SEGMENTS_SYSPROP
+ + "=false");
+ }
+
+ @Override
+ public IndexInput openInput(Path path, IOContext context, int chunkSizePower, boolean preload)
+ throws IOException {
+ final String resourceDescription = "MemorySegmentIndexInput(path=\"" + path.toString() + "\")";
+
+ // Workaround for JDK-8259028: we need to unwrap our test-only file system layers
+ path = Unwrappable.unwrapAll(path);
+
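+ // if we fail before handing out the IndexInput, close the arena so any segments mapped so
+ // far are released eagerly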
+ boolean success = false;
+ final Arena arena = Arena.ofShared();
+ try (var fc = FileChannel.open(path, StandardOpenOption.READ)) {
+ final long fileSize = fc.size();
+ final IndexInput in =
+ MemorySegmentIndexInput.newInstance(
+ resourceDescription,
+ arena,
+ map(arena, resourceDescription, fc, chunkSizePower, preload, fileSize),
+ fileSize,
+ chunkSizePower);
+ success = true;
+ return in;
+ } finally {
+ if (success == false) {
+ arena.close();
+ }
+ }
+ }
+
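+ // 1L << 34 = 16 GiB per mapped chunk on 64-bit JVMs, 1L << 28 = 256 MiB on 32-bit ones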
+ @Override
+ public long getDefaultMaxChunkSize() {
+ return Constants.JRE_IS_64BIT ? (1L << 34) : (1L << 28);
+ }
+
+ @Override
+ public boolean isUnmapSupported() {
+ return true;
+ }
+
+ @Override
+ public String getUnmapNotSupportedReason() {
+ return null;
+ }
+
+ private final MemorySegment[] map(
+ Arena arena,
+ String resourceDescription,
+ FileChannel fc,
+ int chunkSizePower,
+ boolean preload,
+ long length)
+ throws IOException {
+ if ((length >>> chunkSizePower) >= Integer.MAX_VALUE)
+ throw new IllegalArgumentException("File too big for chunk size: " + resourceDescription);
+
+ final long chunkSize = 1L << chunkSizePower;
+
+ // we always allocate one extra segment; the last one may be zero-length
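+ // e.g. a 2.5 GiB file with 1 GiB chunks yields 3 segments (1 GiB, 1 GiB, 0.5 GiB); a file
+ // that is an exact multiple of the chunk size gets a trailing zero-length segment instead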
+ final int nrSegments = (int) (length >>> chunkSizePower) + 1;
+
+ final MemorySegment[] segments = new MemorySegment[nrSegments];
+
+ long startOffset = 0L;
+ for (int segNr = 0; segNr < nrSegments; segNr++) {
+ final long segSize =
+ (length > (startOffset + chunkSize)) ? chunkSize : (length - startOffset);
+ final MemorySegment segment;
+ try {
+ segment = fc.map(MapMode.READ_ONLY, startOffset, segSize, arena);
+ } catch (IOException ioe) {
+ throw convertMapFailedIOException(ioe, resourceDescription, segSize);
+ }
+ if (preload) {
+ segment.load();
+ }
+ segments[segNr] = segment;
+ startOffset += segSize;
+ }
+ return segments;
+ }
+}
diff --git a/lucene/core/src/java21/org/apache/lucene/util/VectorUtilPanamaProvider.txt b/lucene/core/src/java21/org/apache/lucene/util/VectorUtilPanamaProvider.txt
new file mode 100644
index 000000000000..db75951206ef
--- /dev/null
+++ b/lucene/core/src/java21/org/apache/lucene/util/VectorUtilPanamaProvider.txt
@@ -0,0 +1,2 @@
+The version of VectorUtilPanamaProvider for Java 21 is identical to that of Java 20.
+As such, there is no specific Java 21 version - the Java 20 version will be loaded from the MRJAR.
\ No newline at end of file
diff --git a/lucene/core/src/test/org/apache/lucene/analysis/TestAutomatonToTokenStream.java b/lucene/core/src/test/org/apache/lucene/analysis/TestAutomatonToTokenStream.java
index 27d2c72ddb09..d558ad6d3697 100644
--- a/lucene/core/src/test/org/apache/lucene/analysis/TestAutomatonToTokenStream.java
+++ b/lucene/core/src/test/org/apache/lucene/analysis/TestAutomatonToTokenStream.java
@@ -22,8 +22,8 @@
import java.util.List;
import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.automaton.Automata;
import org.apache.lucene.util.automaton.Automaton;
-import org.apache.lucene.util.automaton.DaciukMihovAutomatonBuilder;
public class TestAutomatonToTokenStream extends BaseTokenStreamTestCase {
@@ -31,7 +31,7 @@ public void testSinglePath() throws IOException {
List<BytesRef> acceptStrings = new ArrayList<>();
acceptStrings.add(new BytesRef("abc"));
- Automaton flatPathAutomaton = DaciukMihovAutomatonBuilder.build(acceptStrings);
+ Automaton flatPathAutomaton = Automata.makeStringUnion(acceptStrings);
TokenStream ts = AutomatonToTokenStream.toTokenStream(flatPathAutomaton);
assertTokenStreamContents(
ts,
@@ -48,7 +48,7 @@ public void testParallelPaths() throws IOException {
acceptStrings.add(new BytesRef("123"));
acceptStrings.add(new BytesRef("abc"));
- Automaton flatPathAutomaton = DaciukMihovAutomatonBuilder.build(acceptStrings);
+ Automaton flatPathAutomaton = Automata.makeStringUnion(acceptStrings);
TokenStream ts = AutomatonToTokenStream.toTokenStream(flatPathAutomaton);
assertTokenStreamContents(
ts,
@@ -65,7 +65,7 @@ public void testForkedPath() throws IOException {
acceptStrings.add(new BytesRef("ab3"));
acceptStrings.add(new BytesRef("abc"));
- Automaton flatPathAutomaton = DaciukMihovAutomatonBuilder.build(acceptStrings);
+ Automaton flatPathAutomaton = Automata.makeStringUnion(acceptStrings);
TokenStream ts = AutomatonToTokenStream.toTokenStream(flatPathAutomaton);
assertTokenStreamContents(
ts,
diff --git a/lucene/core/src/test/org/apache/lucene/analysis/TestWordlistLoader.java b/lucene/core/src/test/org/apache/lucene/analysis/TestWordlistLoader.java
index 7af64c0011eb..4747c86834ed 100644
--- a/lucene/core/src/test/org/apache/lucene/analysis/TestWordlistLoader.java
+++ b/lucene/core/src/test/org/apache/lucene/analysis/TestWordlistLoader.java
@@ -24,7 +24,7 @@
public class TestWordlistLoader extends LuceneTestCase {
public void testWordlistLoading() throws IOException {
- String s = "ONE\n two \nthree";
+ String s = "ONE\n two \nthree\n\n";
CharArraySet wordSet1 = WordlistLoader.getWordSet(new StringReader(s));
checkSet(wordSet1);
CharArraySet wordSet2 = WordlistLoader.getWordSet(new BufferedReader(new StringReader(s)));
diff --git a/lucene/core/src/test/org/apache/lucene/geo/TestTessellator.java b/lucene/core/src/test/org/apache/lucene/geo/TestTessellator.java
index 950772f32b75..44c17b48440c 100644
--- a/lucene/core/src/test/org/apache/lucene/geo/TestTessellator.java
+++ b/lucene/core/src/test/org/apache/lucene/geo/TestTessellator.java
@@ -915,6 +915,28 @@ public void testComplexPolygon54() throws Exception {
}
}
+ public void testComplexPolygon55() throws Exception {
+ String geoJson = GeoTestUtil.readShape("github-12352-1.geojson.gz");
+ Polygon[] polygons = Polygon.fromGeoJSON(geoJson);
+ for (Polygon polygon : polygons) {
+ List<Tessellator.Triangle> tessellation =
+ Tessellator.tessellate(polygon, random().nextBoolean());
+ assertEquals(area(polygon), area(tessellation), 0.0);
+ // don't check edges as it takes several minutes
+ }
+ }
+
+ public void testComplexPolygon56() throws Exception {
+ String geoJson = GeoTestUtil.readShape("github-12352-2.geojson.gz");
+ Polygon[] polygons = Polygon.fromGeoJSON(geoJson);
+ for (Polygon polygon : polygons) {
+ List<Tessellator.Triangle> tessellation =
+ Tessellator.tessellate(polygon, random().nextBoolean());
+ assertEquals(area(polygon), area(tessellation), 0.0);
+ // don't check edges as it takes several minutes
+ }
+ }
+
private static class TestCountingMonitor implements Tessellator.Monitor {
private int count = 0;
private int splitsStarted = 0;
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestCachingMergeContext.java b/lucene/core/src/test/org/apache/lucene/index/TestCachingMergeContext.java
new file mode 100644
index 000000000000..3210d54540d7
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/index/TestCachingMergeContext.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+import java.io.IOException;
+import java.util.Set;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.util.InfoStream;
+
+public class TestCachingMergeContext extends LuceneTestCase {
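+ // CachingMergeContext should consult its delegate once per SegmentCommitInfo and answer
+ // subsequent numDeletesToMerge calls from its cache.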
+ public void testNumDeletesToMerge() throws IOException {
+ MockMergeContext mergeContext = new MockMergeContext();
+ CachingMergeContext cachingMergeContext = new CachingMergeContext(mergeContext);
+ assertEquals(cachingMergeContext.numDeletesToMerge(null), 1);
+ assertEquals(cachingMergeContext.cachedNumDeletesToMerge.size(), 1);
+ assertEquals(
+ cachingMergeContext.cachedNumDeletesToMerge.getOrDefault(null, -1), Integer.valueOf(1));
+ assertEquals(mergeContext.count, 1);
+
+ // increase the mock count
+ mergeContext.numDeletesToMerge(null);
+ assertEquals(mergeContext.count, 2);
+
+ // assert the cached result is still one
+ assertEquals(cachingMergeContext.numDeletesToMerge(null), 1);
+ assertEquals(cachingMergeContext.cachedNumDeletesToMerge.size(), 1);
+ assertEquals(
+ cachingMergeContext.cachedNumDeletesToMerge.getOrDefault(null, -1), Integer.valueOf(1));
+ }
+
+ private static final class MockMergeContext implements MergePolicy.MergeContext {
+ int count = 0;
+
+ @Override
+ public final int numDeletesToMerge(SegmentCommitInfo info) throws IOException {
+ this.count += 1;
+ return this.count;
+ }
+
+ @Override
+ public int numDeletedDocs(SegmentCommitInfo info) {
+ return 0;
+ }
+
+ @Override
+ public InfoStream getInfoStream() {
+ return null;
+ }
+
+ @Override
+ public Set<SegmentCommitInfo> getMergingSegments() {
+ return null;
+ }
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java
index 5dc11a52fb49..4a6365b53094 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java
@@ -40,6 +40,7 @@
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.TestVectorUtil;
/**
* Test that uses a default/lucene Implementation of {@link QueryTimeout} to exit out long running
@@ -463,13 +464,21 @@ public void testVectorValues() throws IOException {
ExitingReaderException.class,
() ->
leaf.searchNearestVectors(
- "vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE));
+ "vector",
+ TestVectorUtil.randomVector(dimension),
+ 5,
+ leaf.getLiveDocs(),
+ Integer.MAX_VALUE));
} else {
DocIdSetIterator iter = leaf.getFloatVectorValues("vector");
scanAndRetrieve(leaf, iter);
leaf.searchNearestVectors(
- "vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE);
+ "vector",
+ TestVectorUtil.randomVector(dimension),
+ 5,
+ leaf.getLiveDocs(),
+ Integer.MAX_VALUE);
}
reader.close();
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestFieldInfo.java b/lucene/core/src/test/org/apache/lucene/index/TestFieldInfo.java
new file mode 100644
index 000000000000..42f2c55fbf88
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/index/TestFieldInfo.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.index;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import org.apache.lucene.tests.util.LuceneTestCase;
+
+public class TestFieldInfo extends LuceneTestCase {
+
+ public void testHandleLegacySupportedUpdatesValidIndexInfoChange() {
+
+ FieldInfo fi1 = new FieldInfoTestBuilder().setIndexOptions(IndexOptions.NONE).get();
+ FieldInfo fi2 = new FieldInfoTestBuilder().setOmitNorms(true).setStoreTermVector(false).get();
+
+ FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2);
+ assertNotNull(updatedFi);
+ assertEquals(updatedFi.getIndexOptions(), IndexOptions.DOCS);
+ // fi2 omits norms and fi1 was not indexed, so it's OK that the merged FieldInfo
+ // returns hasNorms == false
+ assertFalse(updatedFi.hasNorms());
+ compareAttributes(fi1, fi2, Set.of("getIndexOptions", "hasNorms"));
+ compareAttributes(fi1, updatedFi, Set.of("getIndexOptions", "hasNorms"));
+ compareAttributes(fi2, updatedFi, Set.of());
+
+ // The reverse returns null since fi2 wouldn't change
+ assertNull(fi2.handleLegacySupportedUpdates(fi1));
+ }
+
+ public void testHandleLegacySupportedUpdatesInvalidIndexInfoChange() {
+ FieldInfo fi1 = new FieldInfoTestBuilder().setIndexOptions(IndexOptions.DOCS).get();
+ FieldInfo fi2 =
+ new FieldInfoTestBuilder().setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS).get();
+ assertThrows(IllegalArgumentException.class, () -> fi1.handleLegacySupportedUpdates(fi2));
+ assertThrows(IllegalArgumentException.class, () -> fi2.handleLegacySupportedUpdates(fi1));
+ }
+
+ public void testHandleLegacySupportedUpdatesValidPointDimensionCount() {
+ FieldInfo fi1 = new FieldInfoTestBuilder().get();
+ FieldInfo fi2 = new FieldInfoTestBuilder().setPointDimensionCount(2).setPointNumBytes(2).get();
+ FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2);
+ assertNotNull(updatedFi);
+ assertEquals(2, updatedFi.getPointDimensionCount());
+ compareAttributes(fi1, fi2, Set.of("getPointDimensionCount", "getPointNumBytes"));
+ compareAttributes(fi1, updatedFi, Set.of("getPointDimensionCount", "getPointNumBytes"));
+ compareAttributes(fi2, updatedFi, Set.of());
+
+ // The reverse returns null since fi2 wouldn't change
+ assertNull(fi2.handleLegacySupportedUpdates(fi1));
+ }
+
+ public void testHandleLegacySupportedUpdatesInvalidPointDimensionCount() {
+ FieldInfo fi1 = new FieldInfoTestBuilder().setPointDimensionCount(3).setPointNumBytes(2).get();
+ FieldInfo fi2 = new FieldInfoTestBuilder().setPointDimensionCount(2).setPointNumBytes(2).get();
+ FieldInfo fi3 = new FieldInfoTestBuilder().setPointDimensionCount(2).setPointNumBytes(3).get();
+ assertThrows(IllegalArgumentException.class, () -> fi1.handleLegacySupportedUpdates(fi2));
+ assertThrows(IllegalArgumentException.class, () -> fi2.handleLegacySupportedUpdates(fi1));
+
+ assertThrows(IllegalArgumentException.class, () -> fi1.handleLegacySupportedUpdates(fi3));
+ assertThrows(IllegalArgumentException.class, () -> fi3.handleLegacySupportedUpdates(fi1));
+
+ assertThrows(IllegalArgumentException.class, () -> fi2.handleLegacySupportedUpdates(fi3));
+ assertThrows(IllegalArgumentException.class, () -> fi3.handleLegacySupportedUpdates(fi2));
+ }
+
+ public void testHandleLegacySupportedUpdatesValidPointIndexDimensionCount() {
+ FieldInfo fi1 = new FieldInfoTestBuilder().get();
+ FieldInfo fi2 =
+ new FieldInfoTestBuilder()
+ .setPointIndexDimensionCount(2)
+ .setPointDimensionCount(2)
+ .setPointNumBytes(2)
+ .get();
+ FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2);
+ assertNotNull(updatedFi);
+ assertEquals(2, updatedFi.getPointDimensionCount());
+ compareAttributes(
+ fi1,
+ fi2,
+ Set.of("getPointDimensionCount", "getPointNumBytes", "getPointIndexDimensionCount"));
+ compareAttributes(
+ fi1,
+ updatedFi,
+ Set.of("getPointDimensionCount", "getPointNumBytes", "getPointIndexDimensionCount"));
+ compareAttributes(fi2, updatedFi, Set.of());
+
+ // The reverse returns null since fi2 wouldn't change
+ assertNull(fi2.handleLegacySupportedUpdates(fi1));
+ }
+
+ public void testHandleLegacySupportedUpdatesInvalidPointIndexDimensionCount() {
+ FieldInfo fi1 =
+ new FieldInfoTestBuilder()
+ .setPointDimensionCount(2)
+ .setPointIndexDimensionCount(2)
+ .setPointNumBytes(2)
+ .get();
+ FieldInfo fi2 =
+ new FieldInfoTestBuilder()
+ .setPointDimensionCount(2)
+ .setPointIndexDimensionCount(3)
+ .setPointNumBytes(2)
+ .get();
+ assertThrows(IllegalArgumentException.class, () -> fi1.handleLegacySupportedUpdates(fi2));
+ assertThrows(IllegalArgumentException.class, () -> fi2.handleLegacySupportedUpdates(fi1));
+ }
+
+ public void testHandleLegacySupportedUpdatesValidStoreTermVectors() {
+ FieldInfo fi1 = new FieldInfoTestBuilder().setStoreTermVector(false).get();
+ FieldInfo fi2 = new FieldInfoTestBuilder().setStoreTermVector(true).get();
+ FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2);
+ assertNotNull(updatedFi);
+ assertTrue(updatedFi.hasVectors());
+ compareAttributes(fi1, fi2, Set.of("hasVectors"));
+ compareAttributes(fi1, updatedFi, Set.of("hasVectors"));
+ compareAttributes(fi2, updatedFi, Set.of());
+
+ // The reverse returns null since fi2 wouldn't change
+ assertNull(fi2.handleLegacySupportedUpdates(fi1));
+ }
+
+ public void testHandleLegacySupportedUpdatesValidStorePayloads() {
+ FieldInfo fi1 =
+ new FieldInfoTestBuilder()
+ .setStorePayloads(false)
+ .setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS)
+ .get();
+ FieldInfo fi2 =
+ new FieldInfoTestBuilder()
+ .setStorePayloads(true)
+ .setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS)
+ .get();
+ FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2);
+ assertNotNull(updatedFi);
+ assertTrue(updatedFi.hasPayloads());
+ compareAttributes(fi1, fi2, Set.of("hasPayloads"));
+ compareAttributes(fi1, updatedFi, Set.of("hasPayloads"));
+ compareAttributes(fi2, updatedFi, Set.of());
+
+ // The reverse returns null since fi2 wouldn't change
+ assertNull(fi2.handleLegacySupportedUpdates(fi1));
+ }
+
+ public void testHandleLegacySupportedUpdatesValidOmitNorms() {
+ FieldInfo fi1 =
+ new FieldInfoTestBuilder().setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS).get();
+ FieldInfo fi2 =
+ new FieldInfoTestBuilder()
+ .setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS)
+ .setOmitNorms(true)
+ .get();
+ FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2);
+ assertNotNull(updatedFi);
+ assertFalse(updatedFi.hasNorms()); // Once norms are omitted, they are always omitted
+ compareAttributes(fi1, fi2, Set.of("hasNorms"));
+ compareAttributes(fi1, updatedFi, Set.of("hasNorms"));
+ compareAttributes(fi2, updatedFi, Set.of());
+
+ // The reverse returns null since fi2 wouldn't change
+ assertNull(fi2.handleLegacySupportedUpdates(fi1));
+ }
+
+ public void testHandleLegacySupportedUpdatesValidDocValuesType() {
+
+ FieldInfo fi1 = new FieldInfoTestBuilder().get();
+ FieldInfo fi2 = new FieldInfoTestBuilder().setDocValues(DocValuesType.SORTED).setDvGen(1).get();
+
+ FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2);
+ assertNotNull(updatedFi);
+ assertEquals(DocValuesType.SORTED, updatedFi.getDocValuesType());
+ assertEquals(1, updatedFi.getDocValuesGen());
+ compareAttributes(fi1, fi2, Set.of("getDocValuesType", "getDocValuesGen"));
+ compareAttributes(fi1, updatedFi, Set.of("getDocValuesType", "getDocValuesGen"));
+ compareAttributes(fi2, updatedFi, Set.of());
+
+ // The reverse returns null since fi2 wouldn't change
+ assertNull(fi2.handleLegacySupportedUpdates(fi1));
+
+ FieldInfo fi3 = new FieldInfoTestBuilder().setDocValues(DocValuesType.SORTED).setDvGen(2).get();
+ // Changes to the docValues generation alone are ignored
+ assertNull(fi2.handleLegacySupportedUpdates(fi3));
+ assertNull(fi3.handleLegacySupportedUpdates(fi2));
+ }
+
+ public void testHandleLegacySupportedUpdatesInvalidDocValuesTypeChange() {
+ FieldInfo fi1 = new FieldInfoTestBuilder().setDocValues(DocValuesType.SORTED).get();
+ FieldInfo fi2 = new FieldInfoTestBuilder().setDocValues(DocValuesType.SORTED_SET).get();
+ assertThrows(IllegalArgumentException.class, () -> fi1.handleLegacySupportedUpdates(fi2));
+ assertThrows(IllegalArgumentException.class, () -> fi2.handleLegacySupportedUpdates(fi1));
+ }
+
+ private static class FieldInfoTestBuilder {
+ String name = "f1";
+ int number = 1;
+ boolean storeTermVector = true;
+ boolean omitNorms = false;
+ boolean storePayloads = false;
+ IndexOptions indexOptions = IndexOptions.DOCS;
+ DocValuesType docValues = DocValuesType.NONE;
+ long dvGen = -1;
+ Map<String, String> attributes = new HashMap<>();
+ int pointDimensionCount = 0;
+ int pointIndexDimensionCount = 0;
+ int pointNumBytes = 0;
+ int vectorDimension = 10;
+ VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
+ VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+ boolean softDeletesField = false;
+
+ FieldInfo get() {
+ return new FieldInfo(
+ name,
+ number,
+ storeTermVector,
+ omitNorms,
+ storePayloads,
+ indexOptions,
+ docValues,
+ dvGen,
+ attributes,
+ pointDimensionCount,
+ pointIndexDimensionCount,
+ pointNumBytes,
+ vectorDimension,
+ vectorEncoding,
+ vectorSimilarityFunction,
+ softDeletesField);
+ }
+
+ public FieldInfoTestBuilder setStoreTermVector(boolean storeTermVector) {
+ this.storeTermVector = storeTermVector;
+ return this;
+ }
+
+ public FieldInfoTestBuilder setOmitNorms(boolean omitNorms) {
+ this.omitNorms = omitNorms;
+ return this;
+ }
+
+ public FieldInfoTestBuilder setStorePayloads(boolean storePayloads) {
+ this.storePayloads = storePayloads;
+ return this;
+ }
+
+ public FieldInfoTestBuilder setIndexOptions(IndexOptions indexOptions) {
+ this.indexOptions = indexOptions;
+ return this;
+ }
+
+ public FieldInfoTestBuilder setDocValues(DocValuesType docValues) {
+ this.docValues = docValues;
+ return this;
+ }
+
+ public FieldInfoTestBuilder setDvGen(long dvGen) {
+ this.dvGen = dvGen;
+ return this;
+ }
+
+ public FieldInfoTestBuilder setPointDimensionCount(int pointDimensionCount) {
+ this.pointDimensionCount = pointDimensionCount;
+ return this;
+ }
+
+ public FieldInfoTestBuilder setPointIndexDimensionCount(int pointIndexDimensionCount) {
+ this.pointIndexDimensionCount = pointIndexDimensionCount;
+ return this;
+ }
+
+ public FieldInfoTestBuilder setPointNumBytes(int pointNumBytes) {
+ this.pointNumBytes = pointNumBytes;
+ return this;
+ }
+ }
+
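+ // Reflectively compares every no-arg get*/has* accessor of the two FieldInfos, skipping
+ // the excluded method names, so newly added attributes are checked automatically.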
+ private void compareAttributes(FieldInfo fi1, FieldInfo fi2, Set<String> exclude) {
+ assertNotNull(fi1);
+ assertNotNull(fi2);
+ Arrays.stream(FieldInfo.class.getMethods())
+ .filter(
+ m ->
+ (m.getName().startsWith("get") || m.getName().startsWith("has"))
+ && !m.getName().equals("hashCode")
+ && !exclude.contains(m.getName())
+ && m.getParameterCount() == 0)
+ .forEach(
+ m -> {
+ try {
+ assertEquals(
+ "Unexpected difference in FieldInfo for method: " + m.getName(),
+ m.invoke(fi1),
+ m.invoke(fi2));
+ } catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ });
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java b/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java
index b5a9ac5034b7..a4272cf56334 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java
@@ -361,4 +361,43 @@ public void testRelaxConsistencyCheckForOldIndices() throws IOException {
}
}
}
+
+ public void testFieldInfosMergeBehaviorOnOldIndices() throws IOException {
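+ // Force-merging segments whose FieldInfos for "my_field" disagree (IndexOptions.NONE vs
+ // DOCS) must succeed on an index created with major version 8.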
+ try (Directory dir = newDirectory()) {
+ IndexWriterConfig config =
+ new IndexWriterConfig()
+ .setIndexCreatedVersionMajor(8)
+ .setMergeScheduler(new SerialMergeScheduler())
+ .setOpenMode(IndexWriterConfig.OpenMode.CREATE);
+ FieldType ft1 = new FieldType();
+ ft1.setIndexOptions(IndexOptions.NONE);
+ ft1.setStored(true);
+ FieldType ft2 = new FieldType();
+ ft2.setIndexOptions(IndexOptions.DOCS);
+ ft2.setStored(true);
+
+ try (IndexWriter writer = new IndexWriter(dir, config)) {
+ Document d1 = new Document();
+ // Document 1 has my_field with IndexOptions.NONE
+ d1.add(new Field("my_field", "first", ft1));
+ writer.addDocument(d1);
+ for (int i = 0; i < 100; i++) {
+ // Add some more docs to make sure segment 0 is the biggest one
+ Document d = new Document();
+ d.add(new Field("foo", "bar" + i, ft2));
+ writer.addDocument(d);
+ }
+ writer.flush();
+
+ Document d2 = new Document();
+ // Document 2 has my_field with IndexOptions.DOCS
+ d2.add(new Field("my_field", "first", ft2));
+ writer.addDocument(d2);
+ writer.flush();
+
+ writer.commit();
+ writer.forceMerge(1);
+ }
+ }
+ }
}
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestIndexWriter.java b/lucene/core/src/test/org/apache/lucene/index/TestIndexWriter.java
index e9a2fd9141dd..9c7b7c8a5290 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestIndexWriter.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestIndexWriter.java
@@ -3473,7 +3473,12 @@ public int numDeletesToMerge(
Document doc = new Document();
doc.add(new StringField("id", id, Field.Store.YES));
if (mixDeletes && random().nextBoolean()) {
- writer.updateDocuments(new Term("id", id), Arrays.asList(doc, doc));
+ if (random().nextBoolean()) {
+ writer.updateDocuments(new Term("id", id), Arrays.asList(doc, doc));
+ } else {
+ writer.updateDocuments(
+ new TermQuery(new Term("id", id)), Arrays.asList(doc, doc));
+ }
} else {
writer.softUpdateDocuments(
new Term("id", id),
diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
index 08f089430ba5..372384df5572 100644
--- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
+++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java
@@ -54,6 +54,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
+import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.junit.After;
import org.junit.Before;
@@ -62,7 +63,7 @@ public class TestKnnGraph extends LuceneTestCase {
private static final String KNN_GRAPH_FIELD = "vector";
- private static int M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN;
+ private static int M = HnswGraphBuilder.DEFAULT_MAX_CONN;
private Codec codec;
private Codec float32Codec;
@@ -80,7 +81,7 @@ public void setup() {
new Lucene95Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
+ return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
@@ -92,7 +93,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
new Lucene95Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
+ return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
@@ -103,7 +104,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
new Lucene95Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
+ return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
}
@@ -115,7 +116,7 @@ private VectorEncoding randomVectorEncoding() {
@After
public void cleanup() {
- M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN;
+ M = HnswGraphBuilder.DEFAULT_MAX_CONN;
}
/** Basic test of creating documents in a graph */
diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
index ce1fa6d7abbe..24de1a463c73 100644
--- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java
@@ -210,7 +210,10 @@ public void testDimensionMismatch() throws IOException {
IndexSearcher searcher = newSearcher(reader);
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
IllegalArgumentException e =
- expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
+ expectThrows(
+ RuntimeException.class,
+ IllegalArgumentException.class,
+ () -> searcher.search(kvq, 10));
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
}
}
@@ -236,7 +239,7 @@ public void testDifferentReader() throws IOException {
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
- Query dasq = query.rewrite(reader);
+ Query dasq = query.rewrite(newSearcher(reader));
IndexSearcher leafSearcher = newSearcher(reader.leaves().get(0).reader());
expectThrows(
IllegalStateException.class,
@@ -256,7 +259,7 @@ public void testAdvanceShallow() throws IOException {
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
- Query dasq = query.rewrite(reader);
+ Query dasq = query.rewrite(searcher);
Scorer scorer =
dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
// before advancing the iterator
@@ -283,7 +286,7 @@ public void testScoreEuclidean() throws IOException {
IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
- Query rewritten = query.rewrite(reader);
+ Query rewritten = query.rewrite(searcher);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
@@ -322,7 +325,7 @@ public void testScoreCosine() throws IOException {
assertEquals(1, reader.leaves().size());
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3);
- Query rewritten = query.rewrite(reader);
+ Query rewritten = query.rewrite(searcher);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
@@ -529,6 +532,7 @@ public void testRandomWithFilter() throws IOException {
assertEquals(9, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
+ RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
@@ -543,6 +547,7 @@ public void testRandomWithFilter() throws IOException {
assertEquals(5, results.totalHits.value);
assertEquals(results.totalHits.value, results.scoreDocs.length);
expectThrows(
+ RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
@@ -570,6 +575,7 @@ public void testRandomWithFilter() throws IOException {
// Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
expectThrows(
+ RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
@@ -742,6 +748,7 @@ public void testBitSetQuery() throws IOException {
Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
expectThrows(
+ RuntimeException.class,
UnsupportedOperationException.class,
() ->
searcher.search(
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java
index 74badd25e777..87b8068d2f1f 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java
@@ -83,7 +83,7 @@ public void testSingleFilterClause() throws IOException {
query1.add(new TermQuery(new Term("field", "a")), Occur.FILTER);
// Single clauses rewrite to a term query
- final Query rewritten1 = query1.build().rewrite(reader);
+ final Query rewritten1 = query1.build().rewrite(searcher);
assertTrue(rewritten1 instanceof BoostQuery);
assertEquals(0f, ((BoostQuery) rewritten1).getBoost(), 0f);
@@ -946,7 +946,7 @@ public String toString(String field) {
}
@Override
- public Query rewrite(IndexReader indexReader) {
+ public Query rewrite(IndexSearcher indexSearcher) {
numRewrites++;
return this;
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestFieldExistsQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestFieldExistsQuery.java
index 307d6800aa57..209ad510889b 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestFieldExistsQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestFieldExistsQuery.java
@@ -63,7 +63,8 @@ public void testDocValuesRewriteWithTermsPresent() throws IOException {
final IndexReader reader = iw.getReader();
iw.close();
- assertTrue((new FieldExistsQuery("f")).rewrite(reader) instanceof MatchAllDocsQuery);
+ assertTrue(
+ (new FieldExistsQuery("f")).rewrite(newSearcher(reader)) instanceof MatchAllDocsQuery);
reader.close();
dir.close();
}
@@ -82,7 +83,8 @@ public void testDocValuesRewriteWithPointValuesPresent() throws IOException {
final IndexReader reader = iw.getReader();
iw.close();
- assertTrue(new FieldExistsQuery("dim").rewrite(reader) instanceof MatchAllDocsQuery);
+ assertTrue(
+ new FieldExistsQuery("dim").rewrite(newSearcher(reader)) instanceof MatchAllDocsQuery);
reader.close();
dir.close();
}
@@ -106,9 +108,10 @@ public void testDocValuesNoRewrite() throws IOException {
iw.commit();
final IndexReader reader = iw.getReader();
iw.close();
+ final IndexSearcher searcher = newSearcher(reader);
- assertFalse((new FieldExistsQuery("dim")).rewrite(reader) instanceof MatchAllDocsQuery);
- assertFalse((new FieldExistsQuery("f")).rewrite(reader) instanceof MatchAllDocsQuery);
+ assertFalse((new FieldExistsQuery("dim")).rewrite(searcher) instanceof MatchAllDocsQuery);
+ assertFalse((new FieldExistsQuery("f")).rewrite(searcher) instanceof MatchAllDocsQuery);
reader.close();
dir.close();
}
@@ -127,10 +130,11 @@ public void testDocValuesNoRewriteWithDocValues() throws IOException {
iw.commit();
final IndexReader reader = iw.getReader();
iw.close();
+ final IndexSearcher searcher = newSearcher(reader);
- assertFalse((new FieldExistsQuery("dv1")).rewrite(reader) instanceof MatchAllDocsQuery);
- assertFalse((new FieldExistsQuery("dv2")).rewrite(reader) instanceof MatchAllDocsQuery);
- assertFalse((new FieldExistsQuery("dv3")).rewrite(reader) instanceof MatchAllDocsQuery);
+ assertFalse((new FieldExistsQuery("dv1")).rewrite(searcher) instanceof MatchAllDocsQuery);
+ assertFalse((new FieldExistsQuery("dv2")).rewrite(searcher) instanceof MatchAllDocsQuery);
+ assertFalse((new FieldExistsQuery("dv3")).rewrite(searcher) instanceof MatchAllDocsQuery);
reader.close();
dir.close();
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestIndexSearcher.java b/lucene/core/src/test/org/apache/lucene/search/TestIndexSearcher.java
index c3bcda9d47a6..7e8693bd8b31 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestIndexSearcher.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestIndexSearcher.java
@@ -453,20 +453,15 @@ private void runSliceExecutorTest(ThreadPoolExecutor service, boolean useRandomS
}
}
- private class RandomBlockingSliceExecutor extends SliceExecutor {
+ private static class RandomBlockingSliceExecutor extends SliceExecutor {
- public RandomBlockingSliceExecutor(Executor executor) {
+ RandomBlockingSliceExecutor(Executor executor) {
super(executor);
}
@Override
- public void invokeAll(Collection<? extends Runnable> tasks) {
-
- for (Runnable task : tasks) {
- boolean shouldExecuteOnCallerThread = random().nextBoolean();
-
- processTask(task, shouldExecuteOnCallerThread);
- }
+ boolean shouldExecuteOnCallerThread(int index, int numTasks) {
+ return random().nextBoolean();
}
}
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestIndexSortSortedNumericDocValuesRangeQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestIndexSortSortedNumericDocValuesRangeQuery.java
index 8df8dcb1917e..cc5d8800e8a2 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestIndexSortSortedNumericDocValuesRangeQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestIndexSortSortedNumericDocValuesRangeQuery.java
@@ -339,7 +339,7 @@ public void testRewriteExhaustiveRange() throws IOException {
IndexReader reader = writer.getReader();
Query query = createQuery("field", Long.MIN_VALUE, Long.MAX_VALUE);
- Query rewrittenQuery = query.rewrite(reader);
+ Query rewrittenQuery = query.rewrite(newSearcher(reader));
assertEquals(new FieldExistsQuery("field"), rewrittenQuery);
writer.close();
@@ -357,7 +357,7 @@ public void testRewriteFallbackQuery() throws IOException {
Query fallbackQuery = new BooleanQuery.Builder().build();
Query query = new IndexSortSortedNumericDocValuesRangeQuery("field", 1, 42, fallbackQuery);
- Query rewrittenQuery = query.rewrite(reader);
+ Query rewrittenQuery = query.rewrite(newSearcher(reader));
assertNotEquals(query, rewrittenQuery);
assertThat(rewrittenQuery, instanceOf(IndexSortSortedNumericDocValuesRangeQuery.class));
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
index 684e260f12b7..8467ac506ae4 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java
@@ -95,7 +95,7 @@ public void testScoreNegativeDotProduct() throws IOException {
assertEquals(1, reader.leaves().size());
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {1, 0}, 2);
- Query rewritten = query.rewrite(reader);
+ Query rewritten = query.rewrite(searcher);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
@@ -126,7 +126,7 @@ public void testScoreDotProduct() throws IOException {
IndexSearcher searcher = new IndexSearcher(reader);
AbstractKnnVectorQuery query =
getKnnVectorQuery("field", VectorUtil.l2normalize(new float[] {2, 3}), 3);
- Query rewritten = query.rewrite(reader);
+ Query rewritten = query.rewrite(searcher);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestMatchNoDocsQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestMatchNoDocsQuery.java
index e1facc81e61a..43db42cc1d34 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestMatchNoDocsQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestMatchNoDocsQuery.java
@@ -44,7 +44,7 @@ public void testSimple() throws Exception {
assertEquals(query.toString(), "MatchNoDocsQuery(\"\")");
query = new MatchNoDocsQuery("field 'title' not found");
assertEquals(query.toString(), "MatchNoDocsQuery(\"field 'title' not found\")");
- Query rewrite = query.rewrite(null);
+ Query rewrite = query.rewrite((IndexSearcher) null);
assertTrue(rewrite instanceof MatchNoDocsQuery);
assertEquals(rewrite.toString(), "MatchNoDocsQuery(\"field 'title' not found\")");
}
@@ -87,7 +87,7 @@ public void testQuery() throws Exception {
assertEquals(query.toString(), "key:one MatchNoDocsQuery(\"field not found\")");
assertEquals(searcher.count(query), 1);
hits = searcher.search(query, 1000).scoreDocs;
- Query rewrite = query.rewrite(ir);
+ Query rewrite = query.rewrite(searcher);
assertEquals(1, hits.length);
assertEquals(rewrite.toString(), "key:one");
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestMatchesIterator.java b/lucene/core/src/test/org/apache/lucene/search/TestMatchesIterator.java
index b62d1ec69f15..33acb57c8a5d 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestMatchesIterator.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestMatchesIterator.java
@@ -499,7 +499,7 @@ public void testMinimalSeekingWithWildcards() throws IOException {
SeekCountingLeafReader reader = new SeekCountingLeafReader(getOnlyLeafReader(this.reader));
this.searcher = new IndexSearcher(reader);
Query query = new PrefixQuery(new Term(FIELD_WITH_OFFSETS, "w"));
- Weight w = searcher.createWeight(query.rewrite(reader), ScoreMode.COMPLETE, 1);
+ Weight w = searcher.createWeight(query.rewrite(searcher), ScoreMode.COMPLETE, 1);
// docs 0-3 match several different terms here, but we only seek to the first term and
// then short-cut return; other terms are ignored until we try and iterate over matches
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestNGramPhraseQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestNGramPhraseQuery.java
index f42a576ce54f..905b43b93ed1 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestNGramPhraseQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestNGramPhraseQuery.java
@@ -29,6 +29,7 @@ public class TestNGramPhraseQuery extends LuceneTestCase {
private static IndexReader reader;
private static Directory directory;
+ private static IndexSearcher searcher;
@BeforeClass
public static void beforeClass() throws Exception {
@@ -36,6 +37,7 @@ public static void beforeClass() throws Exception {
RandomIndexWriter writer = new RandomIndexWriter(random(), directory);
writer.close();
reader = DirectoryReader.open(directory);
+ searcher = new IndexSearcher(reader);
}
@AfterClass
@@ -50,8 +52,8 @@ public void testRewrite() throws Exception {
// bi-gram test ABC => AB/BC => AB/BC
NGramPhraseQuery pq1 = new NGramPhraseQuery(2, new PhraseQuery("f", "AB", "BC"));
- Query q = pq1.rewrite(reader);
- assertSame(q.rewrite(reader), q);
+ Query q = pq1.rewrite(searcher);
+ assertSame(q.rewrite(searcher), q);
PhraseQuery rewritten1 = (PhraseQuery) q;
assertArrayEquals(new Term[] {new Term("f", "AB"), new Term("f", "BC")}, rewritten1.getTerms());
assertArrayEquals(new int[] {0, 1}, rewritten1.getPositions());
@@ -59,7 +61,7 @@ public void testRewrite() throws Exception {
// bi-gram test ABCD => AB/BC/CD => AB//CD
NGramPhraseQuery pq2 = new NGramPhraseQuery(2, new PhraseQuery("f", "AB", "BC", "CD"));
- q = pq2.rewrite(reader);
+ q = pq2.rewrite(searcher);
assertTrue(q instanceof PhraseQuery);
assertNotSame(pq2, q);
PhraseQuery rewritten2 = (PhraseQuery) q;
@@ -70,7 +72,7 @@ public void testRewrite() throws Exception {
NGramPhraseQuery pq3 =
new NGramPhraseQuery(3, new PhraseQuery("f", "ABC", "BCD", "CDE", "DEF", "EFG", "FGH"));
- q = pq3.rewrite(reader);
+ q = pq3.rewrite(searcher);
assertTrue(q instanceof PhraseQuery);
assertNotSame(pq3, q);
PhraseQuery rewritten3 = (PhraseQuery) q;
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java b/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java
index aa2f3c892122..03b585710868 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java
@@ -139,10 +139,10 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- Query in2 = in.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ Query in2 = in.rewrite(indexSearcher);
if (in2 == in) {
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
} else {
return new AssertNeedsScores(in2, value);
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPhraseQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestPhraseQuery.java
index da93b3a19510..f721f5e4fe21 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestPhraseQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestPhraseQuery.java
@@ -567,7 +567,7 @@ public void testEmptyPhraseQuery() throws Throwable {
/* test that a single term is rewritten to a term query */
public void testRewrite() throws IOException {
PhraseQuery pq = new PhraseQuery("foo", "bar");
- Query rewritten = pq.rewrite(searcher.getIndexReader());
+ Query rewritten = pq.rewrite(searcher);
assertTrue(rewritten instanceof TermQuery);
}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestQueryRewriteBackwardsCompatibility.java b/lucene/core/src/test/org/apache/lucene/search/TestQueryRewriteBackwardsCompatibility.java
new file mode 100644
index 000000000000..a9534b7f191e
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/search/TestQueryRewriteBackwardsCompatibility.java
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.Objects;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.LuceneTestCase;
+
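+/**
+ * Ensures the deprecated {@code rewrite(IndexReader)} and the new {@code rewrite(IndexSearcher)}
+ * entry points stay interchangeable: no matter which variant a query overrides, each rewrite call
+ * must execute that override exactly once.
+ */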
+public class TestQueryRewriteBackwardsCompatibility extends LuceneTestCase {
+
+ public void testQueryRewriteNoOverrides() throws IOException {
+ Directory directory = newDirectory();
+ RandomIndexWriter w = new RandomIndexWriter(random(), directory, newIndexWriterConfig());
+ IndexReader reader = w.getReader();
+ w.close();
+ IndexSearcher searcher = newSearcher(reader);
+ Query queryNoOverrides = new TestQueryNoOverrides();
+ assertSame(queryNoOverrides, searcher.rewrite(queryNoOverrides));
+ assertSame(queryNoOverrides, queryNoOverrides.rewrite(searcher));
+ assertSame(queryNoOverrides, queryNoOverrides.rewrite(reader));
+ reader.close();
+ directory.close();
+ }
+
+ public void testSingleQueryRewrite() throws IOException {
+ Directory directory = newDirectory();
+ RandomIndexWriter w = new RandomIndexWriter(random(), directory, newIndexWriterConfig());
+ IndexReader reader = w.getReader();
+ w.close();
+ IndexSearcher searcher = newSearcher(reader);
+
+ RewriteCountingQuery oldQuery = new OldQuery(null);
+ RewriteCountingQuery newQuery = new NewQuery(null);
+
+ oldQuery.rewrite(searcher);
+ oldQuery.rewrite(reader);
+
+ newQuery.rewrite(searcher);
+ newQuery.rewrite(reader);
+
+ assertEquals(2, oldQuery.rewriteCount);
+ assertEquals(2, newQuery.rewriteCount);
+ reader.close();
+ directory.close();
+ }
+
+ public void testNestedQueryRewrite() throws IOException {
+ Directory dir = newDirectory();
+ RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig());
+ IndexReader reader = w.getReader();
+ w.close();
+ IndexSearcher searcher = newSearcher(reader);
+
+ RewriteCountingQuery query = random().nextBoolean() ? new NewQuery(null) : new OldQuery(null);
+
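+ // wrap the base query in a random chain of old-style and new-style queries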
+ for (int i = 0; i < 5 + random().nextInt(5); i++) {
+ query = random().nextBoolean() ? new NewQuery(query) : new OldQuery(query);
+ }
+
+ query.rewrite(searcher);
+ query.rewrite(reader);
+
+ RewriteCountingQuery innerQuery = query;
+ while (innerQuery != null) {
+ assertEquals(2, innerQuery.rewriteCount);
+ innerQuery = innerQuery.getInnerQuery();
+ }
+ reader.close();
+ dir.close();
+ }
+
+ public void testRewriteQueryInheritance() throws IOException {
+ Directory directory = newDirectory();
+ RandomIndexWriter w = new RandomIndexWriter(random(), directory, newIndexWriterConfig());
+ IndexReader reader = w.getReader();
+ w.close();
+ IndexSearcher searcher = newSearcher(reader);
+ NewRewritableCallingSuper oneRewrite = new NewRewritableCallingSuper();
+ NewRewritableCallingSuper twoRewrites = new OldNewRewritableCallingSuper();
+ NewRewritableCallingSuper threeRewrites = new OldOldNewRewritableCallingSuper();
+
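+ // each override in the hierarchy must increment the counter exactly once: the searcher-based
+ // override runs first, and Query's default rewrite(IndexSearcher) then delegates to the
+ // deprecated reader-based overrides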
+ searcher.rewrite(oneRewrite);
+ searcher.rewrite(twoRewrites);
+ searcher.rewrite(threeRewrites);
+ assertEquals(1, oneRewrite.rewriteCount);
+ assertEquals(2, twoRewrites.rewriteCount);
+ assertEquals(3, threeRewrites.rewriteCount);
+
+ reader.close();
+ directory.close();
+ }
+
+ private static class NewRewritableCallingSuper extends RewriteCountingQuery {
+
+ @Override
+ public Query rewrite(IndexSearcher searcher) throws IOException {
+ rewriteCount++;
+ return super.rewrite(searcher);
+ }
+
+ @Override
+ public String toString(String field) {
+ return "NewRewritableCallingSuper";
+ }
+
+ @Override
+ public void visit(QueryVisitor visitor) {}
+
+ @Override
+ public boolean equals(Object obj) {
+ return obj instanceof NewRewritableCallingSuper;
+ }
+
+ @Override
+ public int hashCode() {
+ return 1;
+ }
+
+ @Override
+ RewriteCountingQuery getInnerQuery() {
+ return null;
+ }
+ }
+
+ private static class OldNewRewritableCallingSuper extends NewRewritableCallingSuper {
+ @Override
+ public Query rewrite(IndexReader reader) throws IOException {
+ rewriteCount++;
+ return super.rewrite(reader);
+ }
+ }
+
+ private static class OldOldNewRewritableCallingSuper extends OldNewRewritableCallingSuper {
+ @Override
+ public Query rewrite(IndexReader reader) throws IOException {
+ rewriteCount++;
+ return super.rewrite(reader);
+ }
+ }
+
+ private abstract static class RewriteCountingQuery extends Query {
+ int rewriteCount = 0;
+
+ abstract RewriteCountingQuery getInnerQuery();
+ }
+
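+ /** A query that only overrides the deprecated {@code rewrite(IndexReader)}. */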
+ private static class OldQuery extends RewriteCountingQuery {
+ private final RewriteCountingQuery optionalSubQuery;
+
+ private OldQuery(RewriteCountingQuery optionalSubQuery) {
+ this.optionalSubQuery = optionalSubQuery;
+ }
+
+ @Override
+ public Query rewrite(IndexReader reader) throws IOException {
+ if (this.optionalSubQuery != null) {
+ this.optionalSubQuery.rewrite(reader);
+ }
+ rewriteCount++;
+ return this;
+ }
+
+ @Override
+ public String toString(String field) {
+ return "OldQuery";
+ }
+
+ @Override
+ public void visit(QueryVisitor visitor) {}
+
+ @Override
+ public boolean equals(Object obj) {
+ return obj instanceof OldQuery
+ && Objects.equals(((OldQuery) obj).optionalSubQuery, optionalSubQuery);
+ }
+
+ @Override
+ public int hashCode() {
+ return 42 ^ Objects.hash(optionalSubQuery);
+ }
+
+ @Override
+ RewriteCountingQuery getInnerQuery() {
+ return optionalSubQuery;
+ }
+ }
+
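+ /** A query that only overrides the new {@code rewrite(IndexSearcher)}. */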
+ private static class NewQuery extends RewriteCountingQuery {
+ private final RewriteCountingQuery optionalSubQuery;
+
+ private NewQuery(RewriteCountingQuery optionalSubQuery) {
+ this.optionalSubQuery = optionalSubQuery;
+ }
+
+ @Override
+ public Query rewrite(IndexSearcher searcher) throws IOException {
+ if (this.optionalSubQuery != null) {
+ this.optionalSubQuery.rewrite(searcher);
+ }
+ rewriteCount++;
+ return this;
+ }
+
+ @Override
+ public String toString(String field) {
+ return "NewQuery";
+ }
+
+ @Override
+ public void visit(QueryVisitor visitor) {}
+
+ @Override
+ public boolean equals(Object obj) {
+ return obj instanceof NewQuery
+ && Objects.equals(((NewQuery) obj).optionalSubQuery, optionalSubQuery);
+ }
+
+ @Override
+ public int hashCode() {
+ return 73 ^ Objects.hash(optionalSubQuery);
+ }
+
+ @Override
+ RewriteCountingQuery getInnerQuery() {
+ return optionalSubQuery;
+ }
+ }
+
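+ /** A query overriding neither rewrite variant; rewriting it must return the same instance. */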
+ private static class TestQueryNoOverrides extends Query {
+
+ private final int randomHash = random().nextInt();
+
+ @Override
+ public String toString(String field) {
+ return "TestQueryNoOverrides";
+ }
+
+ @Override
+ public void visit(QueryVisitor visitor) {}
+
+ @Override
+ public boolean equals(Object obj) {
+ return obj instanceof TestQueryNoOverrides
+ && randomHash == ((TestQueryNoOverrides) obj).randomHash;
+ }
+
+ @Override
+ public int hashCode() {
+ return randomHash;
+ }
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java
index 123307e75f8e..ae026a87367f 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java
@@ -73,6 +73,18 @@ public void testEquals() {
.addTerm(new Term("field", "c"), 0.2f)
.addTerm(new Term("field", "d"))
.build());
+
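+ // synonym queries that differ in term, boost, or field must not compare equal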
+ QueryUtils.checkUnequal(
+ new SynonymQuery.Builder("field").addTerm(new Term("field", "a"), 0.4f).build(),
+ new SynonymQuery.Builder("field").addTerm(new Term("field", "b"), 0.4f).build());
+
+ QueryUtils.checkUnequal(
+ new SynonymQuery.Builder("field").addTerm(new Term("field", "a"), 0.2f).build(),
+ new SynonymQuery.Builder("field").addTerm(new Term("field", "a"), 0.4f).build());
+
+ QueryUtils.checkUnequal(
+ new SynonymQuery.Builder("field1").addTerm(new Term("field1", "b"), 0.4f).build(),
+ new SynonymQuery.Builder("field2").addTerm(new Term("field2", "b"), 0.4f).build());
}
public void testBogusParams() {
@@ -127,6 +139,12 @@ public void testBogusParams() {
() -> {
new SynonymQuery.Builder("field1").addTerm(new Term("field1", "a"), -0f);
});
+
+ expectThrows(
+ NullPointerException.class,
+ () -> new SynonymQuery.Builder(null).addTerm(new Term("field1", "a"), -0f));
+
+ expectThrows(NullPointerException.class, () -> new SynonymQuery.Builder(null).build());
}
public void testToString() {
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java
index 936ae8f5834a..abc6d7138f36 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java
@@ -950,12 +950,12 @@ public int hashCode() {
}
@Override
- public Query rewrite(IndexReader reader) throws IOException {
- Query rewritten = query.rewrite(reader);
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ Query rewritten = query.rewrite(indexSearcher);
if (rewritten != query) {
return new MaxScoreWrapperQuery(rewritten, maxRange, maxScore);
}
- return super.rewrite(reader);
+ return super.rewrite(indexSearcher);
}
@Override
diff --git a/lucene/core/src/test/org/apache/lucene/store/TestBufferedIndexInput.java b/lucene/core/src/test/org/apache/lucene/store/TestBufferedIndexInput.java
index 546a22493975..d19b2058e07b 100644
--- a/lucene/core/src/test/org/apache/lucene/store/TestBufferedIndexInput.java
+++ b/lucene/core/src/test/org/apache/lucene/store/TestBufferedIndexInput.java
@@ -18,6 +18,7 @@
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.util.Random;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
@@ -142,6 +143,72 @@ public void testEOF() throws Exception {
});
}
+ // Test that when reading backwards, we page backwards rather than refilling
+ // on every call
+ public void testBackwardsByteReads() throws IOException {
+ MyBufferedIndexInput input = new MyBufferedIndexInput(1024 * 8);
+ for (int i = 2048; i > 0; i -= random().nextInt(16)) {
+ assertEquals(byten(i), input.readByte(i));
+ }
+ assertEquals(3, input.readCount);
+ }
+
+ public void testBackwardsShortReads() throws IOException {
+ MyBufferedIndexInput input = new MyBufferedIndexInput(1024 * 8);
+ ByteBuffer bb = ByteBuffer.allocate(2);
+ bb.order(ByteOrder.LITTLE_ENDIAN);
+ for (int i = 2048; i > 0; i -= (random().nextInt(16) + 1)) {
+ bb.clear();
+ bb.put(byten(i));
+ bb.put(byten(i + 1));
+ assertEquals(bb.getShort(0), input.readShort(i));
+ }
+ // readCount can be three or four, depending on whether or not we had to adjust the bufferStart
+ // to include a whole short
+ assertTrue(
+ "Expected 4 or 3, got " + input.readCount, input.readCount == 4 || input.readCount == 3);
+ }
+
+ public void testBackwardsIntReads() throws IOException {
+ MyBufferedIndexInput input = new MyBufferedIndexInput(1024 * 8);
+ ByteBuffer bb = ByteBuffer.allocate(4);
+ bb.order(ByteOrder.LITTLE_ENDIAN);
+ for (int i = 2048; i > 0; i -= (random().nextInt(16) + 3)) {
+ bb.clear();
+ bb.put(byten(i));
+ bb.put(byten(i + 1));
+ bb.put(byten(i + 2));
+ bb.put(byten(i + 3));
+ assertEquals(bb.getInt(0), input.readInt(i));
+ }
+ // readCount can be three or four, depending on whether or not we had to adjust the bufferStart
+ // to include a whole int
+ assertTrue(
+ "Expected 4 or 3, got " + input.readCount, input.readCount == 4 || input.readCount == 3);
+ }
+
+ public void testBackwardsLongReads() throws IOException {
+ MyBufferedIndexInput input = new MyBufferedIndexInput(1024 * 8);
+ ByteBuffer bb = ByteBuffer.allocate(8);
+ bb.order(ByteOrder.LITTLE_ENDIAN);
+ for (int i = 2048; i > 0; i -= (random().nextInt(16) + 7)) {
+ bb.clear();
+ bb.put(byten(i));
+ bb.put(byten(i + 1));
+ bb.put(byten(i + 2));
+ bb.put(byten(i + 3));
+ bb.put(byten(i + 4));
+ bb.put(byten(i + 5));
+ bb.put(byten(i + 6));
+ bb.put(byten(i + 7));
+ assertEquals(bb.getLong(0), input.readLong(i));
+ }
+ // readCount can be three or four, depending on whether or not we had to adjust the bufferStart
+ // to include a whole long
+ assertTrue(
+ "Expected 4 or 3, got " + input.readCount, input.readCount == 4 || input.readCount == 3);
+ }
+
// byten emulates a file - byten(n) returns the n'th byte in that file.
// MyBufferedIndexInput reads this "file".
private static byte byten(long n) {
@@ -150,7 +217,8 @@ private static byte byten(long n) {
private static class MyBufferedIndexInput extends BufferedIndexInput {
private long pos;
- private long len;
+ private final long len;
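+ // number of readInternal calls, used to assert how often the buffer was (re)filled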
+ private long readCount = 0;
public MyBufferedIndexInput(long len) {
super("MyBufferedIndexInput(len=" + len + ")", BufferedIndexInput.BUFFER_SIZE);
@@ -164,14 +232,15 @@ public MyBufferedIndexInput() {
}
@Override
- protected void readInternal(ByteBuffer b) throws IOException {
+ protected void readInternal(ByteBuffer b) {
+ readCount++;
while (b.hasRemaining()) {
b.put(byten(pos++));
}
}
@Override
- protected void seekInternal(long pos) throws IOException {
+ protected void seekInternal(long pos) {
this.pos = pos;
}
diff --git a/lucene/core/src/test/org/apache/lucene/store/TestMmapDirectory.java b/lucene/core/src/test/org/apache/lucene/store/TestMmapDirectory.java
index dc638f6a529e..5597161a3a98 100644
--- a/lucene/core/src/test/org/apache/lucene/store/TestMmapDirectory.java
+++ b/lucene/core/src/test/org/apache/lucene/store/TestMmapDirectory.java
@@ -48,9 +48,9 @@ private static boolean isMemorySegmentImpl() {
public void testCorrectImplementation() {
final int runtimeVersion = Runtime.version().feature();
- if (runtimeVersion == 19 || runtimeVersion == 20) {
+ if (runtimeVersion >= 19 && runtimeVersion <= 21) {
assertTrue(
- "on Java 19 and Java 20 we should use MemorySegmentIndexInputProvider to create mmap IndexInputs",
+ "on Java 19, 20, and 21 we should use MemorySegmentIndexInputProvider to create mmap IndexInputs",
isMemorySegmentImpl());
} else {
assertSame(MappedByteBufferIndexInputProvider.class, MMapDirectory.PROVIDER.getClass());
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestUnicodeUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestUnicodeUtil.java
index 8e0348b706c6..2f309e945b74 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestUnicodeUtil.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestUnicodeUtil.java
@@ -166,6 +166,28 @@ public void testUTF8toUTF32() {
}
}
+ public void testUTF8CodePointAt() {
+ int num = atLeast(50000);
+ UnicodeUtil.UTF8CodePoint reuse = null;
+ for (int i = 0; i < num; i++) {
+ final String s = TestUtil.randomUnicodeString(random());
+ final byte[] utf8 = new byte[UnicodeUtil.maxUTF8Length(s.length())];
+ final int utf8Len = UnicodeUtil.UTF16toUTF8(s, 0, s.length(), utf8);
+
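+ // decode the UTF-8 bytes one code point at a time and compare against String#codePoints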
+ int[] expected = s.codePoints().toArray();
+ int pos = 0;
+ int expectedUpto = 0;
+ while (pos < utf8Len) {
+ reuse = UnicodeUtil.codePointAt(utf8, pos, reuse);
+ assertEquals(expected[expectedUpto], reuse.codePoint);
+ expectedUpto++;
+ pos += reuse.numBytes;
+ }
+ assertEquals(utf8Len, pos);
+ assertEquals(expected.length, expectedUpto);
+ }
+ }
+
public void testNewString() {
final int[] codePoints = {
Character.toCodePoint(Character.MIN_HIGH_SURROGATE, Character.MAX_LOW_SURROGATE),
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtilProviders.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtilProviders.java
new file mode 100644
index 000000000000..be1205890bd8
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtilProviders.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util;
+
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+import java.util.function.ToDoubleFunction;
+import java.util.function.ToIntFunction;
+import java.util.stream.IntStream;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.junit.BeforeClass;
+
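+/**
+ * Checks that the default (scalar) VectorUtil provider and the JDK-specific provider returned by
+ * {@code VectorUtilProvider.lookup(true)} agree within a small delta across many vector sizes.
+ */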
+public class TestVectorUtilProviders extends LuceneTestCase {
+
+ private static final double DELTA = 1e-3;
+ private static final VectorUtilProvider LUCENE_PROVIDER = new VectorUtilDefaultProvider();
+ private static final VectorUtilProvider JDK_PROVIDER = VectorUtilProvider.lookup(true);
+
+ private static final int[] VECTOR_SIZES = {
+ 1, 4, 6, 8, 13, 16, 25, 32, 64, 100, 128, 207, 256, 300, 512, 702, 1024
+ };
+
+ private final int size;
+
+ public TestVectorUtilProviders(int size) {
+ this.size = size;
+ }
+
+ @ParametersFactory
+ public static Iterable