diff --git a/build.gradle b/build.gradle index ba77f2ef6d5c..0eae57bd10f3 100644 --- a/build.gradle +++ b/build.gradle @@ -36,7 +36,7 @@ apply from: file('gradle/globals.gradle') // Calculate project version: version = { // Release manager: update base version here after release: - String baseVersion = '9.6.0' + String baseVersion = '9.7.0' // On a release explicitly set release version in one go: // -Dversion.release=x.y.z @@ -119,7 +119,7 @@ apply from: file('gradle/ide/eclipse.gradle') // (java, tests) apply from: file('gradle/java/folder-layout.gradle') apply from: file('gradle/java/javac.gradle') -apply from: file('gradle/java/memorysegment-mrjar.gradle') +apply from: file('gradle/java/core-mrjar.gradle') apply from: file('gradle/testing/defaults-tests.gradle') apply from: file('gradle/testing/randomization.gradle') apply from: file('gradle/testing/fail-on-no-tests.gradle') @@ -158,7 +158,7 @@ apply from: file('gradle/generation/javacc.gradle') apply from: file('gradle/generation/forUtil.gradle') apply from: file('gradle/generation/antlr.gradle') apply from: file('gradle/generation/unicode-test-classes.gradle') -apply from: file('gradle/generation/panama-foreign.gradle') +apply from: file('gradle/generation/extract-jdk-apis.gradle') apply from: file('gradle/datasets/external-datasets.gradle') diff --git a/buildSrc/scriptDepVersions.gradle b/buildSrc/scriptDepVersions.gradle index 8751da632492..a6eae860b1c0 100644 --- a/buildSrc/scriptDepVersions.gradle +++ b/buildSrc/scriptDepVersions.gradle @@ -22,7 +22,7 @@ ext { scriptDepVersions = [ "apache-rat": "0.14", - "asm": "9.4", + "asm": "9.5", "commons-codec": "1.13", "ecj": "3.30.0", "flexmark": "0.61.24", diff --git a/dev-tools/doap/lucene.rdf b/dev-tools/doap/lucene.rdf index 0c749ad54d44..c0a09528cef2 100644 --- a/dev-tools/doap/lucene.rdf +++ b/dev-tools/doap/lucene.rdf @@ -67,6 +67,13 @@ + <release> + <Version> + <name>lucene-9.6.0</name> + <created>2023-05-09</created> + <revision>9.6.0</revision> + </Version> + </release> <name>lucene-9.5.0</name> @@ -74,7 +81,6 @@ <revision>9.5.0</revision> - <name>lucene-9.4.2</name> diff --git a/gradle/generation/panama-foreign.gradle b/gradle/generation/extract-jdk-apis.gradle similarity index 73% rename from gradle/generation/panama-foreign.gradle rename to gradle/generation/extract-jdk-apis.gradle index 694c4656e2f6..78e74aa261a3 100644 --- a/gradle/generation/panama-foreign.gradle +++ b/gradle/generation/extract-jdk-apis.gradle @@ -17,10 +17,17 @@ def resources = scriptResources(buildscript) +configure(rootProject) { + ext { + // also change this in extractor tool: ExtractJdkApis + vectorIncubatorJavaVersions = [ JavaVersion.VERSION_20, JavaVersion.VERSION_21 ] as Set + } +} + configure(project(":lucene:core")) { ext { apijars = file('src/generated/jdk'); - panamaJavaVersions = [ 19, 20 ] + mrjarJavaVersions = [ 19, 20, 21 ] } configurations { @@ -31,9 +38,9 @@ configure(project(":lucene:core")) { apiextractor "org.ow2.asm:asm:${scriptDepVersions['asm']}" } - for (jdkVersion : panamaJavaVersions) { - def task = tasks.create(name: "generatePanamaForeignApiJar${jdkVersion}", type: JavaExec) { - description "Regenerate the API-only JAR file with public Panama Foreign API from JDK ${jdkVersion}" + for (jdkVersion : mrjarJavaVersions) { + def task = tasks.create(name: "generateJdkApiJar${jdkVersion}", type: JavaExec) { + description "Regenerate the API-only JAR file with public Panama Foreign & Vector API from JDK ${jdkVersion}" group "generation" javaLauncher = javaToolchains.launcherFor { @@ -45,18 +52,21 @@ configure(project(":lucene:core")) { javaLauncher.get() return true } catch (Exception e) { - logger.warn('Launcher for 
Java {} is not available; skipping regeneration of Panama Foreign API JAR.', jdkVersion) + logger.warn('Launcher for Java {} is not available; skipping regeneration of Panama Foreign & Vector API JAR.', jdkVersion) logger.warn('Error: {}', e.cause?.message) logger.warn("Please make sure to point env 'JAVA{}_HOME' to exactly JDK version {} or enable Gradle toolchain auto-download.", jdkVersion, jdkVersion) return false } } - + classpath = configurations.apiextractor - mainClass = file("${resources}/ExtractForeignAPI.java") as String + mainClass = file("${resources}/ExtractJdkApis.java") as String + systemProperties = [ + 'user.timezone': 'UTC' + ] args = [ jdkVersion, - new File(apijars, "panama-foreign-jdk${jdkVersion}.apijar"), + new File(apijars, "jdk${jdkVersion}.apijar"), ] } diff --git a/gradle/generation/extract-jdk-apis/ExtractJdkApis.java b/gradle/generation/extract-jdk-apis/ExtractJdkApis.java new file mode 100644 index 000000000000..7dfb9edfe2a0 --- /dev/null +++ b/gradle/generation/extract-jdk-apis/ExtractJdkApis.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +import java.io.IOException; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.PathMatcher; +import java.nio.file.Paths; +import java.nio.file.attribute.FileTime; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; + +import org.objectweb.asm.AnnotationVisitor; +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.FieldVisitor; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; + +public final class ExtractJdkApis { + + private static final FileTime FIXED_FILEDATE = FileTime.from(Instant.parse("2022-01-01T00:00:00Z")); + + private static final String PATTERN_PANAMA_FOREIGN = "java.base/java/{lang/foreign/*,nio/channels/FileChannel,util/Objects}"; + private static final String PATTERN_VECTOR_INCUBATOR = "jdk.incubator.vector/jdk/incubator/vector/*"; + private static final String PATTERN_VECTOR_VM_INTERNALS = "java.base/jdk/internal/vm/vector/VectorSupport{,$Vector,$VectorMask,$VectorPayload,$VectorShuffle}"; + + static final Map<Integer, List<String>> CLASSFILE_PATTERNS = Map.of( + 19, List.of(PATTERN_PANAMA_FOREIGN), + 20, List.of(PATTERN_PANAMA_FOREIGN, PATTERN_VECTOR_VM_INTERNALS, PATTERN_VECTOR_INCUBATOR), + 21, List.of(PATTERN_PANAMA_FOREIGN) + ); + + public static void main(String... args) throws IOException { + if (args.length != 2) { + throw new IllegalArgumentException("Need two parameters: java version, output file"); + } + Integer jdk = Integer.valueOf(args[0]); + if (jdk.intValue() != Runtime.version().feature()) { + throw new IllegalStateException("Incorrect java version: " + Runtime.version().feature()); + } + if (!CLASSFILE_PATTERNS.containsKey(jdk)) { + throw new IllegalArgumentException("No support to extract stubs from java version: " + jdk); + } + var outputPath = Paths.get(args[1]); + + // create JRT filesystem and build a combined FileMatcher: + var jrtPath = Paths.get(URI.create("jrt:/")).toRealPath(); + var patterns = CLASSFILE_PATTERNS.get(jdk).stream() + .map(pattern -> jrtPath.getFileSystem().getPathMatcher("glob:" + pattern + ".class")) + .toArray(PathMatcher[]::new); + PathMatcher pattern = p -> Arrays.stream(patterns).anyMatch(matcher -> matcher.matches(p)); + + // Collect all files to process: + final List<Path> filesToExtract; + try (var stream = Files.walk(jrtPath)) { + filesToExtract = stream.filter(p -> pattern.matches(jrtPath.relativize(p))).collect(Collectors.toList()); + } + + // Process all class files: + try (var out = new ZipOutputStream(Files.newOutputStream(outputPath))) { + process(filesToExtract, out); + } + } + + private static void process(List<Path> filesToExtract, ZipOutputStream out) throws IOException { + var classesToInclude = new HashSet<String>(); + var references = new HashMap<String, String[]>(); + var processed = new TreeMap<String, byte[]>(); + System.out.println("Transforming " + filesToExtract.size() + " class files..."); + for (Path p : filesToExtract) { + try (var in = Files.newInputStream(p)) { + var reader = new ClassReader(in); + var cw = new ClassWriter(0); + var cleaner = new Cleaner(cw, classesToInclude, references); + reader.accept(cleaner, ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES); + processed.put(reader.getClassName(), cw.toByteArray()); + } + } + // recursively add all superclasses / interfaces of visible classes to classesToInclude: + for (Set<String> a = classesToInclude; !a.isEmpty();) { + a = a.stream().map(references::get).filter(Objects::nonNull).flatMap(Arrays::stream).collect(Collectors.toSet()); + classesToInclude.addAll(a); + } + // remove all non-visible or not referenced classes: + processed.keySet().removeIf(Predicate.not(classesToInclude::contains)); + System.out.println("Writing " + processed.size() + " visible classes..."); + for (var cls : processed.entrySet()) { + String cn = cls.getKey(); + System.out.println("Writing stub for class: " + cn); + out.putNextEntry(new ZipEntry(cn.concat(".class")).setLastModifiedTime(FIXED_FILEDATE)); + out.write(cls.getValue()); + out.closeEntry(); + } + classesToInclude.removeIf(processed.keySet()::contains); + System.out.println("Referenced classes not included: " + classesToInclude); + } + + static boolean isVisible(int access) { + return (access & (Opcodes.ACC_PROTECTED | Opcodes.ACC_PUBLIC)) != 0; + } + + static class Cleaner extends ClassVisitor { + private static final String PREVIEW_ANN = "jdk/internal/javac/PreviewFeature"; + private static final String PREVIEW_ANN_DESCR = Type.getObjectType(PREVIEW_ANN).getDescriptor(); + + private final Set<String> classesToInclude; + private final Map<String, String[]> references; + + Cleaner(ClassWriter out, Set<String> classesToInclude, Map<String, String[]> references) { + super(Opcodes.ASM9, out); + this.classesToInclude = classesToInclude; + this.references = references; + } + + @Override + public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { + super.visit(Opcodes.V11, access, name, signature, superName, interfaces); + if (isVisible(access)) { + classesToInclude.add(name); + } + references.put(name, Stream.concat(Stream.of(superName), Arrays.stream(interfaces)).toArray(String[]::new)); + } + + @Override + public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { + return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible); + } + + @Override + public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) { + if (!isVisible(access)) { + return null; + } + return new FieldVisitor(Opcodes.ASM9, super.visitField(access, name, descriptor, signature, value)) { + @Override + public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { + return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible); + } + }; + } + + @Override + public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) { + if (!isVisible(access)) { + return null; + } + return new MethodVisitor(Opcodes.ASM9, super.visitMethod(access, name, descriptor, signature, exceptions)) { + @Override + public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { + return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? 
null : super.visitAnnotation(descriptor, visible); + } + }; + } + + @Override + public void visitInnerClass(String name, String outerName, String innerName, int access) { + if (!Objects.equals(outerName, PREVIEW_ANN)) { + super.visitInnerClass(name, outerName, innerName, access); + } + } + + @Override + public void visitPermittedSubclass​(String c) { + } + + } + +} diff --git a/gradle/generation/panama-foreign/ExtractForeignAPI.java b/gradle/generation/panama-foreign/ExtractForeignAPI.java deleted file mode 100644 index 44253ea0122b..000000000000 --- a/gradle/generation/panama-foreign/ExtractForeignAPI.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import java.io.IOException; -import java.net.URI; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.nio.file.attribute.FileTime; -import java.time.Instant; -import java.util.Objects; -import java.util.stream.Collectors; -import java.util.zip.ZipEntry; -import java.util.zip.ZipOutputStream; - -import org.objectweb.asm.AnnotationVisitor; -import org.objectweb.asm.ClassReader; -import org.objectweb.asm.ClassVisitor; -import org.objectweb.asm.ClassWriter; -import org.objectweb.asm.FieldVisitor; -import org.objectweb.asm.MethodVisitor; -import org.objectweb.asm.Opcodes; -import org.objectweb.asm.Type; - -public final class ExtractForeignAPI { - - private static final FileTime FIXED_FILEDATE = FileTime.from(Instant.parse("2022-01-01T00:00:00Z")); - - public static void main(String... 
args) throws IOException { - if (args.length != 2) { - throw new IllegalArgumentException("Need two parameters: java version, output file"); - } - if (Integer.parseInt(args[0]) != Runtime.version().feature()) { - throw new IllegalStateException("Incorrect java version: " + Runtime.version().feature()); - } - var outputPath = Paths.get(args[1]); - var javaBaseModule = Paths.get(URI.create("jrt:/")).resolve("java.base").toRealPath(); - var fileMatcher = javaBaseModule.getFileSystem().getPathMatcher("glob:java/{lang/foreign/*,nio/channels/FileChannel,util/Objects}.class"); - try (var out = new ZipOutputStream(Files.newOutputStream(outputPath)); var stream = Files.walk(javaBaseModule)) { - var filesToExtract = stream.map(javaBaseModule::relativize).filter(fileMatcher::matches).sorted().collect(Collectors.toList()); - for (Path relative : filesToExtract) { - System.out.println("Processing class file: " + relative); - try (var in = Files.newInputStream(javaBaseModule.resolve(relative))) { - final var reader = new ClassReader(in); - final var cw = new ClassWriter(0); - reader.accept(new Cleaner(cw), ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES); - out.putNextEntry(new ZipEntry(relative.toString()).setLastModifiedTime(FIXED_FILEDATE)); - out.write(cw.toByteArray()); - out.closeEntry(); - } - } - } - } - - static class Cleaner extends ClassVisitor { - private static final String PREVIEW_ANN = "jdk/internal/javac/PreviewFeature"; - private static final String PREVIEW_ANN_DESCR = Type.getObjectType(PREVIEW_ANN).getDescriptor(); - - private boolean completelyHidden = false; - - Cleaner(ClassWriter out) { - super(Opcodes.ASM9, out); - } - - private boolean isHidden(int access) { - return completelyHidden || (access & (Opcodes.ACC_PROTECTED | Opcodes.ACC_PUBLIC)) == 0; - } - - @Override - public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { - super.visit(Opcodes.V11, access, name, signature, superName, interfaces); - completelyHidden = isHidden(access); - } - - @Override - public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { - return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible); - } - - @Override - public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) { - if (isHidden(access)) { - return null; - } - return new FieldVisitor(Opcodes.ASM9, super.visitField(access, name, descriptor, signature, value)) { - @Override - public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { - return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? null : super.visitAnnotation(descriptor, visible); - } - }; - } - - @Override - public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) { - if (isHidden(access)) { - return null; - } - return new MethodVisitor(Opcodes.ASM9, super.visitMethod(access, name, descriptor, signature, exceptions)) { - @Override - public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { - return Objects.equals(descriptor, PREVIEW_ANN_DESCR) ? 
null : super.visitAnnotation(descriptor, visible); - } - }; - } - - @Override - public void visitInnerClass(String name, String outerName, String innerName, int access) { - if (!Objects.equals(outerName, PREVIEW_ANN)) { - super.visitInnerClass(name, outerName, innerName, access); - } - } - - @Override - public void visitPermittedSubclass​(String c) { - } - - } - -} diff --git a/gradle/java/memorysegment-mrjar.gradle b/gradle/java/core-mrjar.gradle similarity index 82% rename from gradle/java/memorysegment-mrjar.gradle rename to gradle/java/core-mrjar.gradle index 137f8a3c567d..5715e782f000 100644 --- a/gradle/java/memorysegment-mrjar.gradle +++ b/gradle/java/core-mrjar.gradle @@ -15,11 +15,11 @@ * limitations under the License. */ -// Produce an MR-JAR with Java 19+ MemorySegment implementation for MMapDirectory +// Produce an MR-JAR with Java 19+ foreign and vector implementations configure(project(":lucene:core")) { plugins.withType(JavaPlugin) { - for (jdkVersion : panamaJavaVersions) { + for (jdkVersion : mrjarJavaVersions) { sourceSets.create("main${jdkVersion}") { java { srcDirs = ["src/java${jdkVersion}"] @@ -29,7 +29,7 @@ configure(project(":lucene:core")) { dependencies.add("main${jdkVersion}Implementation", sourceSets.main.output) tasks.named("compileMain${jdkVersion}Java").configure { - def apijar = new File(apijars, "panama-foreign-jdk${jdkVersion}.apijar") + def apijar = new File(apijars, "jdk${jdkVersion}.apijar") inputs.file(apijar) @@ -40,12 +40,14 @@ configure(project(":lucene:core")) { "-Xlint:-options", "--patch-module", "java.base=${apijar}", "--add-exports", "java.base/java.lang.foreign=ALL-UNNAMED", + // for compilation we patch the incubator packages into java.base, this has no effect on resulting class files: + "--add-exports", "java.base/jdk.incubator.vector=ALL-UNNAMED", ] } } tasks.named('jar').configure { - for (jdkVersion : panamaJavaVersions) { + for (jdkVersion : mrjarJavaVersions) { into("META-INF/versions/${jdkVersion}") { from sourceSets["main${jdkVersion}"].output } diff --git a/gradle/template.gradle.properties b/gradle/template.gradle.properties index a626d39f3bb5..9ac8c42e9dd4 100644 --- a/gradle/template.gradle.properties +++ b/gradle/template.gradle.properties @@ -102,5 +102,5 @@ tests.jvms=@TEST_JVMS@ org.gradle.java.installations.auto-download=true # Set these to enable automatic JVM location discovery. -org.gradle.java.installations.fromEnv=JAVA17_HOME,JAVA19_HOME,JAVA20_HOME,JAVA21_HOME,RUNTIME_JAVA_HOME +org.gradle.java.installations.fromEnv=JAVA17_HOME,JAVA19_HOME,JAVA20_HOME,JAVA21_HOME,JAVA22_HOME,RUNTIME_JAVA_HOME #org.gradle.java.installations.paths=(custom paths) diff --git a/gradle/testing/defaults-tests.gradle b/gradle/testing/defaults-tests.gradle index 9f50cda8ca79..f7a348f0b66c 100644 --- a/gradle/testing/defaults-tests.gradle +++ b/gradle/testing/defaults-tests.gradle @@ -47,7 +47,7 @@ allprojects { description: "Number of forked test JVMs"], [propName: 'tests.haltonfailure', value: true, description: "Halt processing on test failure."], [propName: 'tests.jvmargs', - value: { -> propertyOrEnvOrDefault("tests.jvmargs", "TEST_JVM_ARGS", "-XX:TieredStopAtLevel=1 -XX:+UseParallelGC -XX:ActiveProcessorCount=1") }, + value: { -> propertyOrEnvOrDefault("tests.jvmargs", "TEST_JVM_ARGS", isCIBuild ? "" : "-XX:TieredStopAtLevel=1 -XX:+UseParallelGC -XX:ActiveProcessorCount=1") }, description: "Arguments passed to each forked JVM."], // Other settings. 
[propName: 'tests.neverUpToDate', value: true, @@ -119,11 +119,16 @@ allprojects { if (rootProject.runtimeJavaVersion < JavaVersion.VERSION_16) { jvmArgs '--illegal-access=deny' } - + // Lucene needs two optional modules at runtime, which we want to enforce for testing // (if the runner JVM does not support them, it will fail tests): jvmArgs '--add-modules', 'jdk.unsupported,jdk.management' + // Enable the vector incubator module on supported Java versions: + if (rootProject.vectorIncubatorJavaVersions.contains(rootProject.runtimeJavaVersion)) { + jvmArgs '--add-modules', 'jdk.incubator.vector' + } + def loggingConfigFile = layout.projectDirectory.file("${resources}/logging.properties") def tempDir = layout.projectDirectory.dir(testsTmpDir.toString()) jvmArgumentProviders.add( diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 322e8c2acae8..341101cf2a95 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -3,6 +3,106 @@ Lucene Change Log For more information on past and future Lucene versions, please see: http://s.apache.org/luceneversions +======================== Lucene 9.7.0 ======================= + +API Changes +--------------------- + +* GITHUB#11840, GITHUB#12304: Query rewrite now takes an IndexSearcher instead of + IndexReader to enable concurrent rewriting. Please note: This is implemented in + a backwards compatible way. A query overriding either of the two rewrite methods is + supported. To implement this backwards layer in Lucene 9.x the + RuntimePermission "accessDeclaredMembers" is needed in applications using + SecurityManager. (Patrick Zhai, Ben Trent, Uwe Schindler) + +* GITHUB#12321: DaciukMihovAutomatonBuilder has been marked deprecated in preparation for reducing its visibility in + a future release. (Greg Miller) + +* GITHUB#12268: Add BitSet.clear() without parameters for clearing the entire set. + (Jonathan Ellis) + +* GITHUB#12346: Add a new IndexWriter#updateDocuments(Query, Iterable) API + to update documents atomically with respect to refresh and commit, using a query. (Patrick Zhai) + +New Features +--------------------- + +* GITHUB#12257: Create OnHeapHnswGraphSearcher to let OnHeapHnswGraph be searched in a thread-safe manner. (Patrick Zhai) + +* GITHUB#12302, GITHUB#12311, GITHUB#12363: Add vectorized implementations of VectorUtil.dotProduct(), + squareDistance(), cosine() with Java 20 or 21 jdk.incubator.vector APIs. Applications started + with the command line parameter "java --add-modules jdk.incubator.vector" on exactly Java 20 or 21 + will automatically use the new vectorized implementations if running on a supported platform + (x86 AVX2 or later, ARM NEON). This is an opt-in feature and requires an explicit Java + command line flag! When enabled, Lucene logs a notice using java.util.logging. Please test + thoroughly and report bugs/slowness to Lucene's mailing list. + (Chris Hegarty, Robert Muir, Uwe Schindler) + +* GITHUB#12294: Add support for Java 21 foreign memory API. If Java 19 up to 21 is used, + MMapDirectory will mmap Lucene indexes in chunks of 16 GiB (instead of 1 GiB) and indexes + closed while queries are running can no longer crash the JVM. To disable this feature, + pass the following sysprop on the Java command line: + "-Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false" (Uwe Schindler) + +* GITHUB#12252: Add function queries for computing similarity scores between knn vectors. 
(Elia Porciani, Alessandro Benedetti) + +Improvements +--------------------- + +* GITHUB#12245: Add support for Score Mode to `ToParentBlockJoinQuery` explain. (Marcus Eagan via Mikhail Khludnev) + +* GITHUB#12305: Minor cleanup and improvements to DaciukMihovAutomatonBuilder. (Greg Miller) + +* GITHUB#12325: Parallelize AbstractKnnVectorQuery rewrite across slices rather than segments. (Luca Cavanna) + +* GITHUB#12333: NumericLeafComparator#competitiveIterator makes better use of a "search after" value when paginating. + (Chaitanya Gohel) + +* GITHUB#12290: Make memory fence in ByteBufferGuard explicit using `VarHandle.fullFence()`. + +* GITHUB#12320: Add "direct to binary" option for DaciukMihovAutomatonBuilder and use it in TermInSetQuery#visit. + (Greg Miller) + +* GITHUB#12281: Require indexed KNN float vectors and query vectors to be finite. (Jonathan Ellis, Uwe Schindler) + +Optimizations +--------------------- + +* GITHUB#12324: Speed up sparse block advanceExact with tiny step in IndexedDISI. (Guo Feng) + +* GITHUB#12270: Don't generate a stacktrace in CollectionTerminatedException. (Armin Braun) + +* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh) + +* GITHUB#12286: Toposort uses an iterator to avoid stack overflow. (Tang Donghai) + +* GITHUB#12235: Optimize HNSW diversity calculation. (Patrick Zhai) + +* GITHUB#12328: Optimize ConjunctionDISI.createConjunction. (Armin Braun) + +* GITHUB#12357: Better paging when doing backwards random reads. This speeds up + queries relying on terms in NIOFSDirectory and SimpleFSDirectory. (Alan Woodward) + +* GITHUB#12339: Avoid duplicated numDeletesToMerge calculations in the merge phase. (fudongying) + +* GITHUB#12334: Honor the after value for skipping documents in PagingFieldCollector even if the queue is not full. (Chaitanya Gohel) + +Bug Fixes +--------------------- + +* GITHUB#12291: Skip blank lines from stopwords list. (Jerry Chin) + +* GITHUB#11350: Handle possible differences in FieldInfo when merging indices created with Lucene 8.x. (Tomás Fernández Löbbe) + +* GITHUB#12352: [Tessellator] Improve the checks that validate the diagonal between two polygon nodes so + the resulting polygons are valid counter-clockwise polygons. (Ignacio Vera) + +* LUCENE-10181: Restrict GraphTokenStreamFiniteStrings#articulationPointsRecurse recursion depth. (Chris Fournier) + +Other +--------------------- +(No changes) + ======================== Lucene 9.6.0 ======================= API Changes @@ -32,6 +132,8 @@ New Features crash the JVM. To disable this feature, pass the following sysprop on Java command line: "-Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false" (Uwe Schindler) +* GITHUB#12169: Introduce a new token filter to expand synonyms based on Word2Vec DL4j models. (Daniele Antuzi, Ilaria Petreti, Alessandro Benedetti) + Improvements --------------------- @@ -45,6 +147,8 @@ Improvements * GITHUB#12175: Remove SortedSetDocValuesSetQuery in favor of TermInSetQuery with DocValuesRewriteMethod. (Greg Miller) +* GITHUB#12166: Remove the now unused class pointInPolygon. (Marcus Eagan via Christine Poerschke and Nick Knize) + * GITHUB#12126: Refactor part of IndexFileDeleter and ReplicaFileDeleter into a public common utility class FileDeleter. (Patrick Zhai) @@ -86,7 +190,9 @@ Bug Fixes * GITHUB#12212: Bug fix for a DrillSideways issue where matching hits could occasionally be missed. 
(Frederic Thevenet) -* GITHUB#12220: Hunspell: disallow hidden title-case entries from compound middle/end +* GITHUB#12220: Hunspell: disallow hidden title-case entries from compound middle/end (Peter Gromov) + +* GITHUB#12260: Fix SynonymQuery equals implementation to take the targeted field name into account. (Luca Cavanna) Build --------------------- @@ -157,7 +263,7 @@ API Changes * GITHUB#11962: VectorValues#cost() now delegates to VectorValues#size(). (Adrien Grand) - + * GITHUB#11984: Improved TimeLimitBulkScorer to check the timeout at an exponential rate. (Costin Leau) diff --git a/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java b/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java index 8c245e7058c7..988deaf99e59 100644 --- a/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java +++ b/lucene/analysis.tests/src/test/org/apache/lucene/analysis/tests/TestRandomChains.java @@ -89,6 +89,8 @@ import org.apache.lucene.analysis.standard.StandardTokenizer; import org.apache.lucene.analysis.stempel.StempelStemmer; import org.apache.lucene.analysis.synonym.SynonymMap; +import org.apache.lucene.analysis.synonym.word2vec.Word2VecModel; +import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProvider; import org.apache.lucene.store.ByteBuffersDirectory; import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase; import org.apache.lucene.tests.analysis.MockTokenFilter; @@ -99,8 +101,10 @@ import org.apache.lucene.tests.util.automaton.AutomatonTestUtil; import org.apache.lucene.util.AttributeFactory; import org.apache.lucene.util.AttributeSource; +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.CharsRef; import org.apache.lucene.util.IgnoreRandomChains; +import org.apache.lucene.util.TermAndVector; import org.apache.lucene.util.Version; import org.apache.lucene.util.automaton.Automaton; import org.apache.lucene.util.automaton.CharacterRunAutomaton; @@ -415,6 +419,27 @@ private String randomNonEmptyString(Random random) { } } }); + put( + Word2VecSynonymProvider.class, + random -> { + final int numEntries = atLeast(10); + final int vectorDimension = random.nextInt(99) + 1; + Word2VecModel model = new Word2VecModel(numEntries, vectorDimension); + for (int j = 0; j < numEntries; j++) { + String s = TestUtil.randomSimpleString(random, 10, 20); + float[] vec = new float[vectorDimension]; + for (int i = 0; i < vectorDimension; i++) { + vec[i] = random.nextFloat(); + } + model.addTermAndVector(new TermAndVector(new BytesRef(s), vec)); + } + try { + return new Word2VecSynonymProvider(model); + } catch (IOException e) { + Rethrow.rethrow(e); + return null; // unreachable code + } + }); put( DateFormat.class, random -> { diff --git a/lucene/analysis/common/src/java/module-info.java b/lucene/analysis/common/src/java/module-info.java index 5679f0dde295..15ad5a2b1af0 100644 --- a/lucene/analysis/common/src/java/module-info.java +++ b/lucene/analysis/common/src/java/module-info.java @@ -78,6 +78,7 @@ exports org.apache.lucene.analysis.sr; exports org.apache.lucene.analysis.sv; exports org.apache.lucene.analysis.synonym; + exports org.apache.lucene.analysis.synonym.word2vec; exports org.apache.lucene.analysis.ta; exports org.apache.lucene.analysis.te; exports org.apache.lucene.analysis.th; @@ -256,6 +257,7 @@ org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory, org.apache.lucene.analysis.synonym.SynonymFilterFactory, 
org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory, + org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory, org.apache.lucene.analysis.core.FlattenGraphFilterFactory, org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory, org.apache.lucene.analysis.te.TeluguStemFilterFactory, diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Dictionary.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Dictionary.java index e94047b67db3..b7a4029a523b 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Dictionary.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Dictionary.java @@ -155,7 +155,7 @@ public class Dictionary { boolean checkCompoundCase, checkCompoundDup, checkCompoundRep; boolean checkCompoundTriple, simplifiedTriple; int compoundMin = 3, compoundMax = Integer.MAX_VALUE; - List<CompoundRule> compoundRules; // nullable + CompoundRule[] compoundRules; // nullable List<CheckCompoundPattern> checkCompoundPatterns = new ArrayList<>(); // ignored characters (dictionary, affix, inputs) @@ -601,11 +601,11 @@ private String[] splitBySpace(LineNumberReader reader, String line, int minParts return parts; } - private List<CompoundRule> parseCompoundRules(LineNumberReader reader, int num) + private CompoundRule[] parseCompoundRules(LineNumberReader reader, int num) throws IOException, ParseException { - List<CompoundRule> compoundRules = new ArrayList<>(); + CompoundRule[] compoundRules = new CompoundRule[num]; for (int i = 0; i < num; i++) { - compoundRules.add(new CompoundRule(singleArgument(reader, reader.readLine()), this)); + compoundRules[i] = new CompoundRule(singleArgument(reader, reader.readLine()), this); } return compoundRules; } @@ -992,7 +992,7 @@ private int mergeDictionaries( // if we haven't seen any custom morphological data, try to parse one if (!hasCustomMorphData) { int morphStart = line.indexOf(MORPH_SEPARATOR); - if (morphStart >= 0 && morphStart < line.length()) { + if (morphStart >= 0) { String data = line.substring(morphStart + 1); hasCustomMorphData = splitMorphData(data).stream().anyMatch(s -> !s.startsWith("ph:")); @@ -1321,14 +1321,22 @@ private List<String> splitMorphData(String morphData) { if (morphData.isBlank()) { return Collections.emptyList(); } - return Arrays.stream(morphData.split("\\s+")) - .filter( - s -> - s.length() > 3 - && Character.isLetter(s.charAt(0)) - && Character.isLetter(s.charAt(1)) - && s.charAt(2) == ':') - .collect(Collectors.toList()); + + List<String> result = null; + int start = 0; + for (int i = 0; i <= morphData.length(); i++) { + if (i == morphData.length() || Character.isWhitespace(morphData.charAt(i))) { + if (i - start > 3 + && Character.isLetter(morphData.charAt(start)) + && Character.isLetter(morphData.charAt(start + 1)) + && morphData.charAt(start + 2) == ':') { + if (result == null) result = new ArrayList<>(); + result.add(morphData.substring(start, i)); + } + start = i + 1; + } + } + return result == null ? 
List.of() : result; } boolean hasFlag(IntsRef forms, char flag) { diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Hunspell.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Hunspell.java index 1e2a1add13cd..3b58e0f4f980 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Hunspell.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/Hunspell.java @@ -450,7 +450,7 @@ private boolean checkCompoundRules( if (forms != null) { words.add(forms); - if (dictionary.compoundRules.stream().anyMatch(r -> r.mayMatch(words))) { + if (mayHaveCompoundRule(words)) { if (checkLastCompoundPart(wordChars, offset + breakPos, length - breakPos, words)) { return true; } @@ -467,6 +467,15 @@ return false; } + private boolean mayHaveCompoundRule(List<IntsRef> words) { + for (CompoundRule rule : dictionary.compoundRules) { + if (rule.mayMatch(words)) { + return true; + } + } + return false; + } + private boolean checkLastCompoundPart( char[] wordChars, int start, int length, List<IntsRef> words) { IntsRef ref = new IntsRef(new int[1], 0, 1); @@ -475,7 +484,12 @@ private boolean checkLastCompoundPart( Stemmer.RootProcessor stopOnMatching = (stem, formID, morphDataId, outerPrefix, innerPrefix, outerSuffix, innerSuffix) -> { ref.ints[0] = formID; - return dictionary.compoundRules.stream().noneMatch(r -> r.fullyMatches(words)); + for (CompoundRule r : dictionary.compoundRules) { + if (r.fullyMatches(words)) { + return false; + } + } + return true; }; boolean found = !stemmer.doStem(wordChars, start, length, COMPOUND_RULE_END, stopOnMatching); words.remove(words.size() - 1); diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/TrigramAutomaton.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/TrigramAutomaton.java index dfe994ccf827..f4404e4bcf02 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/TrigramAutomaton.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/hunspell/TrigramAutomaton.java @@ -79,7 +79,7 @@ private int runAutomatonOnStringChars(String s) { } int ngramScore(CharsRef s2) { - countedSubstrings.clear(0, countedSubstrings.length()); + countedSubstrings.clear(); int score1 = 0, score2 = 0, score3 = 0; // scores for substrings of length 1, 2 and 3 diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java new file mode 100644 index 000000000000..f022dd8eca67 --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Dl4jModelReader.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.BufferedInputStream; +import java.io.BufferedReader; +import java.io.Closeable; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Locale; +import java.util.zip.ZipEntry; +import java.util.zip.ZipInputStream; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TermAndVector; + +/** + * Dl4jModelReader reads the file generated by the library Deeplearning4j and provides a + * Word2VecModel with normalized vectors. + * + * <p>Dl4j Word2Vec documentation: + * https://deeplearning4j.konduit.ai/v/en-1.0.0-beta7/language-processing/word2vec Example to + * generate a model using dl4j: + * https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/embeddingsfromcorpus/word2vec/Word2VecRawTextExample.java + * + * @lucene.experimental + */ +public class Dl4jModelReader implements Closeable { + + private static final String MODEL_FILE_NAME_PREFIX = "syn0"; + + private final ZipInputStream word2VecModelZipFile; + + public Dl4jModelReader(InputStream stream) { + this.word2VecModelZipFile = new ZipInputStream(new BufferedInputStream(stream)); + } + + public Word2VecModel read() throws IOException { + + ZipEntry entry; + while ((entry = word2VecModelZipFile.getNextEntry()) != null) { + String fileName = entry.getName(); + if (fileName.startsWith(MODEL_FILE_NAME_PREFIX)) { + BufferedReader reader = + new BufferedReader(new InputStreamReader(word2VecModelZipFile, StandardCharsets.UTF_8)); + + String header = reader.readLine(); + String[] headerValues = header.split(" "); + int dictionarySize = Integer.parseInt(headerValues[0]); + int vectorDimension = Integer.parseInt(headerValues[1]); + + Word2VecModel model = new Word2VecModel(dictionarySize, vectorDimension); + String line = reader.readLine(); + boolean isTermB64Encoded = false; + if (line != null) { + String[] tokens = line.split(" "); + isTermB64Encoded = + tokens[0].substring(0, 3).toLowerCase(Locale.ROOT).compareTo("b64") == 0; + model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded)); + } + while ((line = reader.readLine()) != null) { + String[] tokens = line.split(" "); + model.addTermAndVector(extractTermAndVector(tokens, vectorDimension, isTermB64Encoded)); + } + return model; + } + } + throw new IllegalArgumentException( + "Cannot read Dl4j word2vec model - '" + + MODEL_FILE_NAME_PREFIX + + "' file is missing in the zip. '" + + MODEL_FILE_NAME_PREFIX + + "' is a mandatory file containing the mapping between terms and vectors generated by the DL4j library."); + } + + private static TermAndVector extractTermAndVector( + String[] tokens, int vectorDimension, boolean isTermB64Encoded) { + BytesRef term = isTermB64Encoded ? decodeB64Term(tokens[0]) : new BytesRef((tokens[0])); + + float[] vector = new float[tokens.length - 1]; + + if (vectorDimension != vector.length) { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Word2Vec model file corrupted. 
" + + "Declared vectors of size %d but found vector of size %d for word %s (%s)", + vectorDimension, + vector.length, + tokens[0], + term.utf8ToString())); + } + + for (int i = 1; i < tokens.length; i++) { + vector[i - 1] = Float.parseFloat(tokens[i]); + } + return new TermAndVector(term, vector); + } + + static BytesRef decodeB64Term(String term) { + byte[] buffer = Base64.getDecoder().decode(term.substring(4)); + return new BytesRef(buffer, 0, buffer.length); + } + + @Override + public void close() throws IOException { + word2VecModelZipFile.close(); + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/TermAndBoost.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/TermAndBoost.java new file mode 100644 index 000000000000..03fdeecb0f20 --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/TermAndBoost.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.analysis.synonym.word2vec; + +import org.apache.lucene.util.BytesRef; + +/** Wraps a term and boost */ +public class TermAndBoost { + /** the term */ + public final BytesRef term; + /** the boost */ + public final float boost; + + /** Creates a new TermAndBoost */ + public TermAndBoost(BytesRef term, float boost) { + this.term = BytesRef.deepCopyOf(term); + this.boost = boost; + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java new file mode 100644 index 000000000000..6719639b67d9 --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.IOException; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefHash; +import org.apache.lucene.util.TermAndVector; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; + +/** + * Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each + * word in the dictionary + * + * @lucene.experimental + */ +public class Word2VecModel implements RandomAccessVectorValues<float[]> { + + private final int dictionarySize; + private final int vectorDimension; + private final TermAndVector[] termsAndVectors; + private final BytesRefHash word2Vec; + private int loadedCount = 0; + + public Word2VecModel(int dictionarySize, int vectorDimension) { + this.dictionarySize = dictionarySize; + this.vectorDimension = vectorDimension; + this.termsAndVectors = new TermAndVector[dictionarySize]; + this.word2Vec = new BytesRefHash(); + } + + private Word2VecModel( + int dictionarySize, + int vectorDimension, + TermAndVector[] termsAndVectors, + BytesRefHash word2Vec) { + this.dictionarySize = dictionarySize; + this.vectorDimension = vectorDimension; + this.termsAndVectors = termsAndVectors; + this.word2Vec = word2Vec; + } + + public void addTermAndVector(TermAndVector modelEntry) { + modelEntry.normalizeVector(); + this.termsAndVectors[loadedCount++] = modelEntry; + this.word2Vec.add(modelEntry.getTerm()); + } + + @Override + public float[] vectorValue(int targetOrd) { + return termsAndVectors[targetOrd].getVector(); + } + + public float[] vectorValue(BytesRef term) { + int termOrd = this.word2Vec.find(term); + if (termOrd < 0) return null; + TermAndVector entry = this.termsAndVectors[termOrd]; + return (entry == null) ? null : entry.getVector(); + } + + public BytesRef termValue(int targetOrd) { + return termsAndVectors[targetOrd].getTerm(); + } + + @Override + public int dimension() { + return vectorDimension; + } + + @Override + public int size() { + return dictionarySize; + } + + @Override + public RandomAccessVectorValues<float[]> copy() throws IOException { + return new Word2VecModel( + this.dictionarySize, this.vectorDimension, this.termsAndVectors, this.word2Vec); + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java new file mode 100644 index 000000000000..a8db4c4c764a --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilter.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import org.apache.lucene.analysis.TokenFilter; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.synonym.SynonymGraphFilter; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute; +import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute; +import org.apache.lucene.analysis.tokenattributes.TypeAttribute; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefBuilder; + +/** + * Applies single-token synonyms from a Word2Vec trained network to an incoming {@link TokenStream}. + * + * @lucene.experimental + */ +public final class Word2VecSynonymFilter extends TokenFilter { + + private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class); + private final PositionIncrementAttribute posIncrementAtt = + addAttribute(PositionIncrementAttribute.class); + private final PositionLengthAttribute posLenAtt = addAttribute(PositionLengthAttribute.class); + private final TypeAttribute typeAtt = addAttribute(TypeAttribute.class); + + private final Word2VecSynonymProvider synonymProvider; + private final int maxSynonymsPerTerm; + private final float minAcceptedSimilarity; + private final LinkedList<TermAndBoost> synonymBuffer = new LinkedList<>(); + private State lastState; + + /** + * Apply previously built synonymProvider to incoming tokens. + * + * @param input input tokenstream + * @param synonymProvider synonym provider + * @param maxSynonymsPerTerm maximum number of results returned by the synonym search + * @param minAcceptedSimilarity minimal value of cosine similarity between the searched vector and + * the retrieved ones + */ + public Word2VecSynonymFilter( + TokenStream input, + Word2VecSynonymProvider synonymProvider, + int maxSynonymsPerTerm, + float minAcceptedSimilarity) { + super(input); + if (synonymProvider == null) { + throw new IllegalArgumentException("The SynonymProvider must be non-null"); + } + this.synonymProvider = synonymProvider; + this.maxSynonymsPerTerm = maxSynonymsPerTerm; + this.minAcceptedSimilarity = minAcceptedSimilarity; + } + + @Override + public boolean incrementToken() throws IOException { + + if (!synonymBuffer.isEmpty()) { + TermAndBoost synonym = synonymBuffer.pollFirst(); + clearAttributes(); + restoreState(this.lastState); + termAtt.setEmpty(); + termAtt.append(synonym.term.utf8ToString()); + typeAtt.setType(SynonymGraphFilter.TYPE_SYNONYM); + posLenAtt.setPositionLength(1); + posIncrementAtt.setPositionIncrement(0); + return true; + } + + if (input.incrementToken()) { + BytesRefBuilder bytesRefBuilder = new BytesRefBuilder(); + bytesRefBuilder.copyChars(termAtt.buffer(), 0, termAtt.length()); + BytesRef term = bytesRefBuilder.get(); + List<TermAndBoost> synonyms = + this.synonymProvider.getSynonyms(term, maxSynonymsPerTerm, minAcceptedSimilarity); + if (synonyms.size() > 0) { + this.lastState = captureState(); + this.synonymBuffer.addAll(synonyms); + } + return true; + } + return false; + } + + @Override + public void reset() throws IOException { + super.reset(); + synonymBuffer.clear(); + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java new file mode 100644 index 
000000000000..32b6288926fc --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymFilterFactory.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; +import org.apache.lucene.analysis.TokenFilterFactory; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymProviderFactory.Word2VecSupportedFormats; +import org.apache.lucene.util.ResourceLoader; +import org.apache.lucene.util.ResourceLoaderAware; + +/** + * Factory for {@link Word2VecSynonymFilter}. + * + * @lucene.experimental + * @lucene.spi {@value #NAME} + */ +public class Word2VecSynonymFilterFactory extends TokenFilterFactory + implements ResourceLoaderAware { + + /** SPI name */ + public static final String NAME = "Word2VecSynonym"; + + public static final int DEFAULT_MAX_SYNONYMS_PER_TERM = 5; + public static final float DEFAULT_MIN_ACCEPTED_SIMILARITY = 0.8f; + + private final int maxSynonymsPerTerm; + private final float minAcceptedSimilarity; + private final Word2VecSupportedFormats format; + private final String word2vecModelFileName; + + private Word2VecSynonymProvider synonymProvider; + + public Word2VecSynonymFilterFactory(Map<String, String> args) { + super(args); + this.maxSynonymsPerTerm = getInt(args, "maxSynonymsPerTerm", DEFAULT_MAX_SYNONYMS_PER_TERM); + this.minAcceptedSimilarity = + getFloat(args, "minAcceptedSimilarity", DEFAULT_MIN_ACCEPTED_SIMILARITY); + this.word2vecModelFileName = require(args, "model"); + + String modelFormat = get(args, "format", "dl4j").toUpperCase(Locale.ROOT); + try { + this.format = Word2VecSupportedFormats.valueOf(modelFormat); + } catch (IllegalArgumentException exc) { + throw new IllegalArgumentException("Model format '" + modelFormat + "' not supported", exc); + } + + if (!args.isEmpty()) { + throw new IllegalArgumentException("Unknown parameters: " + args); + } + if (minAcceptedSimilarity <= 0 || minAcceptedSimilarity > 1) { + throw new IllegalArgumentException( + "minAcceptedSimilarity must be in the range (0, 1]. Found: " + minAcceptedSimilarity); + } + if (maxSynonymsPerTerm <= 0) { + throw new IllegalArgumentException( + "maxSynonymsPerTerm must be a positive integer greater than 0. Found: " + + maxSynonymsPerTerm); + } + } + + /** Default ctor for compatibility with SPI */ + public Word2VecSynonymFilterFactory() { + throw defaultCtorException(); + } + + Word2VecSynonymProvider getSynonymProvider() { + return this.synonymProvider; + } + + @Override + public TokenStream create(TokenStream input) { + return synonymProvider == null + ? 
input + : new Word2VecSynonymFilter( + input, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity); + } + + @Override + public void inform(ResourceLoader loader) throws IOException { + this.synonymProvider = + Word2VecSynonymProviderFactory.getSynonymProvider(loader, word2vecModelFileName, format); + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java new file mode 100644 index 000000000000..3089f1587a4e --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.util.hnsw.HnswGraphBuilder.DEFAULT_MAX_CONN; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.hnsw.HnswGraphBuilder; +import org.apache.lucene.util.hnsw.HnswGraphSearcher; +import org.apache.lucene.util.hnsw.NeighborQueue; +import org.apache.lucene.util.hnsw.OnHeapHnswGraph; + +/** + * The Word2VecSynonymProvider generates the list of sysnonyms of a term. + * + * @lucene.experimental + */ +public class Word2VecSynonymProvider { + + private static final VectorSimilarityFunction SIMILARITY_FUNCTION = + VectorSimilarityFunction.DOT_PRODUCT; + private static final VectorEncoding VECTOR_ENCODING = VectorEncoding.FLOAT32; + private final Word2VecModel word2VecModel; + private final OnHeapHnswGraph hnswGraph; + + /** + * Word2VecSynonymProvider constructor + * + * @param model containing the set of TermAndVector entries + */ + public Word2VecSynonymProvider(Word2VecModel model) throws IOException { + word2VecModel = model; + + HnswGraphBuilder builder = + HnswGraphBuilder.create( + word2VecModel, + VECTOR_ENCODING, + SIMILARITY_FUNCTION, + DEFAULT_MAX_CONN, + DEFAULT_BEAM_WIDTH, + HnswGraphBuilder.randSeed); + this.hnswGraph = builder.build(word2VecModel.copy()); + } + + public List getSynonyms( + BytesRef term, int maxSynonymsPerTerm, float minAcceptedSimilarity) throws IOException { + + if (term == null) { + throw new IllegalArgumentException("Term must not be null"); + } + + LinkedList result = new LinkedList<>(); + float[] query = word2VecModel.vectorValue(term); + if (query != null) { + NeighborQueue synonyms = + HnswGraphSearcher.search( + query, + // The query vector is in the model. 
When looking for the top-k + // it is always the nearest neighbour of itself, so we look for the top-k+1 + maxSynonymsPerTerm + 1, + word2VecModel, + VECTOR_ENCODING, + SIMILARITY_FUNCTION, + hnswGraph, + null, + Integer.MAX_VALUE); + + int size = synonyms.size(); + for (int i = 0; i < size; i++) { + float similarity = synonyms.topScore(); + int id = synonyms.pop(); + + BytesRef synonym = word2VecModel.termValue(id); + // We remove the original query term + if (!synonym.equals(term) && similarity >= minAcceptedSimilarity) { + result.addFirst(new TermAndBoost(synonym, similarity)); + } + } + } + return result; + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java new file mode 100644 index 000000000000..ea849e653cd6 --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProviderFactory.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.lucene.util.ResourceLoader; + +/** + * Supplies a cache of Word2VecSynonymProvider instances, so that multiple instances of + * Word2VecSynonymFilterFactory do not instantiate multiple providers for the same model. + * Assumes SynonymProvider implementations are thread-safe.
+ */ +public class Word2VecSynonymProviderFactory { + + enum Word2VecSupportedFormats { + DL4J + } + + private static Map<String, Word2VecSynonymProvider> word2vecSynonymProviders = + new ConcurrentHashMap<>(); + + public static Word2VecSynonymProvider getSynonymProvider( + ResourceLoader loader, String modelFileName, Word2VecSupportedFormats format) + throws IOException { + Word2VecSynonymProvider synonymProvider = word2vecSynonymProviders.get(modelFileName); + if (synonymProvider == null) { + try (InputStream stream = loader.openResource(modelFileName)) { + try (Dl4jModelReader reader = getModelReader(format, stream)) { + synonymProvider = new Word2VecSynonymProvider(reader.read()); + } + } + word2vecSynonymProviders.put(modelFileName, synonymProvider); + } + return synonymProvider; + } + + private static Dl4jModelReader getModelReader( + Word2VecSupportedFormats format, InputStream stream) { + switch (format) { + case DL4J: + return new Dl4jModelReader(stream); + } + return null; + } +} diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java new file mode 100644 index 000000000000..e8d69ab3cf9b --- /dev/null +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Analysis components for synonyms using a Word2Vec model.
*/ +package org.apache.lucene.analysis.synonym.word2vec; diff --git a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory index 19a34b7840a8..1e4e17eaeadf 100644 --- a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory +++ b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.TokenFilterFactory @@ -118,6 +118,7 @@ org.apache.lucene.analysis.sv.SwedishLightStemFilterFactory org.apache.lucene.analysis.sv.SwedishMinimalStemFilterFactory org.apache.lucene.analysis.synonym.SynonymFilterFactory org.apache.lucene.analysis.synonym.SynonymGraphFilterFactory +org.apache.lucene.analysis.synonym.word2vec.Word2VecSynonymFilterFactory org.apache.lucene.analysis.core.FlattenGraphFilterFactory org.apache.lucene.analysis.te.TeluguNormalizationFilterFactory org.apache.lucene.analysis.te.TeluguStemFilterFactory diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestFlattenGraphFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestFlattenGraphFilter.java index 7b35f56016cb..7fa901ac5bf4 100644 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestFlattenGraphFilter.java +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/core/TestFlattenGraphFilter.java @@ -40,8 +40,8 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.CharsRef; import org.apache.lucene.util.CharsRefBuilder; +import org.apache.lucene.util.automaton.Automata; import org.apache.lucene.util.automaton.Automaton; -import org.apache.lucene.util.automaton.DaciukMihovAutomatonBuilder; import org.apache.lucene.util.automaton.Operations; import org.apache.lucene.util.automaton.Transition; @@ -780,7 +780,7 @@ public void testPathsNotLost() throws IOException { acceptStrings.sort(Comparator.naturalOrder()); acceptStrings = acceptStrings.stream().limit(wordCount).collect(Collectors.toList()); - Automaton nonFlattenedAutomaton = DaciukMihovAutomatonBuilder.build(acceptStrings); + Automaton nonFlattenedAutomaton = Automata.makeStringUnion(acceptStrings); TokenStream ts = AutomatonToTokenStream.toTokenStream(nonFlattenedAutomaton); TokenStream flattenedTokenStream = new FlattenGraphFilter(ts); diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/hunspell/TestHunspell.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/hunspell/TestHunspell.java index e160cd7851ce..7bee301837f8 100644 --- a/lucene/analysis/common/src/test/org/apache/lucene/analysis/hunspell/TestHunspell.java +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/hunspell/TestHunspell.java @@ -31,6 +31,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.CancellationException; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @@ -74,7 +75,8 @@ public void testCustomCheckCanceledGivesPartialResult() throws Exception { }; Hunspell hunspell = new Hunspell(dictionary, RETURN_PARTIAL_RESULT, checkCanceled); - assertEquals(expected, hunspell.suggest("apac")); + // pass a long timeout so that the test stays predictable on slower CI servers.
+ assertEquals(expected, hunspell.suggest("apac", TimeUnit.DAYS.toMillis(1))); counter.set(0); var e = diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java new file mode 100644 index 000000000000..213dcdaccd33 --- /dev/null +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestDl4jModelReader.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.BytesRef; +import org.junit.Test; + +public class TestDl4jModelReader extends LuceneTestCase { + + private static final String MODEL_FILE = "word2vec-model.zip"; + private static final String MODEL_EMPTY_FILE = "word2vec-empty-model.zip"; + private static final String CORRUPTED_VECTOR_DIMENSION_MODEL_FILE = + "word2vec-corrupted-vector-dimension-model.zip"; + + InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_FILE); + Dl4jModelReader unit = new Dl4jModelReader(stream); + + @Test + public void read_zipFileWithMetadata_shouldReturnDictionarySize() throws Exception { + Word2VecModel model = unit.read(); + long expectedDictionarySize = 235; + assertEquals(expectedDictionarySize, model.size()); + } + + @Test + public void read_zipFileWithMetadata_shouldReturnVectorLength() throws Exception { + Word2VecModel model = unit.read(); + int expectedVectorDimension = 100; + assertEquals(expectedVectorDimension, model.dimension()); + } + + @Test + public void read_zipFile_shouldReturnDecodedTerm() throws Exception { + Word2VecModel model = unit.read(); + BytesRef expectedDecodedFirstTerm = new BytesRef("it"); + assertEquals(expectedDecodedFirstTerm, model.termValue(0)); + } + + @Test + public void decodeTerm_encodedTerm_shouldReturnDecodedTerm() throws Exception { + byte[] originalInput = "lucene".getBytes(StandardCharsets.UTF_8); + String B64encodedLuceneTerm = Base64.getEncoder().encodeToString(originalInput); + String word2vecEncodedLuceneTerm = "B64:" + B64encodedLuceneTerm; + assertEquals(new BytesRef("lucene"), Dl4jModelReader.decodeB64Term(word2vecEncodedLuceneTerm)); + } + + @Test + public void read_EmptyZipFile_shouldThrowException() throws Exception { + try (InputStream stream = TestDl4jModelReader.class.getResourceAsStream(MODEL_EMPTY_FILE)) { + Dl4jModelReader unit = new Dl4jModelReader(stream); + expectThrows(IllegalArgumentException.class, unit::read); + } + } + + @Test + public void read_corruptedVectorDimensionModelFile_shouldThrowException() throws Exception { + try (InputStream stream = + TestDl4jModelReader.class.getResourceAsStream(CORRUPTED_VECTOR_DIMENSION_MODEL_FILE)) { + Dl4jModelReader unit = new Dl4jModelReader(stream); + expectThrows(RuntimeException.class, unit::read); + } + } +} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java new file mode 100644 index 000000000000..3999931dd758 --- /dev/null +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilter.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.Tokenizer; +import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase; +import org.apache.lucene.tests.analysis.MockTokenizer; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TermAndVector; +import org.junit.Test; + +public class TestWord2VecSynonymFilter extends BaseTokenStreamTestCase { + + @Test + public void synonymExpansion_oneCandidate_shouldBeExpandedWithinThreshold() throws Exception { + int maxSynonymPerTerm = 10; + float minAcceptedSimilarity = 0.9f; + Word2VecModel model = new Word2VecModel(6, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); + model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); + model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10})); + + Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); + + Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity); + assertAnalyzesTo( + a, + "pre a post", // input + new String[] {"pre", "a", "d", "e", "c", "b", "post"}, // output + new int[] {0, 4, 4, 4, 4, 4, 6}, // start offset + new int[] {3, 5, 5, 5, 5, 5, 10}, // end offset + new String[] {"word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word"}, // types + new int[] {1, 1, 0, 0, 0, 0, 1}, // posIncrements + new int[] {1, 1, 1, 1, 1, 1, 1}); // posLengths + a.close(); + } + + @Test + public void synonymExpansion_oneCandidate_shouldBeExpandedWithTopKSynonyms() throws Exception { + int maxSynonymPerTerm = 2; + float minAcceptedSimilarity = 0.9f; + Word2VecModel model = new Word2VecModel(5, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); + model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); + + Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); + + Analyzer a = getAnalyzer(synonymProvider, maxSynonymPerTerm, minAcceptedSimilarity); + assertAnalyzesTo( + a, + "pre a post", // input + new String[] {"pre", "a", "d", "e", "post"}, // output + new int[] {0, 4, 4, 4, 6}, // start offset + new int[] {3, 5, 5, 5, 10}, // end offset + new String[] {"word", "word", "SYNONYM", "SYNONYM", "word"}, // types + new int[] {1, 1, 0, 0, 1}, // posIncrements + new int[] {1, 1, 1, 1, 1}); // posLengths + a.close(); + } + + @Test + public void synonymExpansion_twoCandidates_shouldBothBeExpanded() throws Exception { + Word2VecModel model = new
Word2VecModel(8, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); + model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); + model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {1, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("post"), new float[] {-10, -11})); + model.addTermAndVector(new TermAndVector(new BytesRef("after"), new float[] {-8, -10})); + + Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); + + Analyzer a = getAnalyzer(synonymProvider, 10, 0.9f); + assertAnalyzesTo( + a, + "pre a post", // input + new String[] {"pre", "a", "d", "e", "c", "b", "post", "after"}, // output + new int[] {0, 4, 4, 4, 4, 4, 6, 6}, // start offset + new int[] {3, 5, 5, 5, 5, 5, 10, 10}, // end offset + new String[] { // types + "word", "word", "SYNONYM", "SYNONYM", "SYNONYM", "SYNONYM", "word", "SYNONYM" + }, + new int[] {1, 1, 0, 0, 0, 0, 1, 0}, // posIncrements + new int[] {1, 1, 1, 1, 1, 1, 1, 1}); // posLengths + a.close(); + } + + @Test + public void synonymExpansion_forMinAcceptedSimilarity_shouldExpandToNoneSynonyms() + throws Exception { + Word2VecModel model = new Word2VecModel(4, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10})); + model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, -10})); + + Word2VecSynonymProvider synonymProvider = new Word2VecSynonymProvider(model); + + Analyzer a = getAnalyzer(synonymProvider, 10, 0.8f); + assertAnalyzesTo( + a, + "pre a post", // input + new String[] {"pre", "a", "post"}, // output + new int[] {0, 4, 6}, // start offset + new int[] {3, 5, 10}, // end offset + new String[] {"word", "word", "word"}, // types + new int[] {1, 1, 1}, // posIncrements + new int[] {1, 1, 1}); // posLengths + a.close(); + } + + private Analyzer getAnalyzer( + Word2VecSynonymProvider synonymProvider, + int maxSynonymsPerTerm, + float minAcceptedSimilarity) { + return new Analyzer() { + @Override + protected TokenStreamComponents createComponents(String fieldName) { + Tokenizer tokenizer = new MockTokenizer(MockTokenizer.WHITESPACE, false); + // Make a local variable so testRandomHuge doesn't share it across threads! + Word2VecSynonymFilter synFilter = + new Word2VecSynonymFilter( + tokenizer, synonymProvider, maxSynonymsPerTerm, minAcceptedSimilarity); + return new TokenStreamComponents(tokenizer, synFilter); + } + }; + } +} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java new file mode 100644 index 000000000000..007fedf4abed --- /dev/null +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymFilterFactory.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import org.apache.lucene.tests.analysis.BaseTokenStreamFactoryTestCase; +import org.apache.lucene.util.ClasspathResourceLoader; +import org.apache.lucene.util.ResourceLoader; +import org.junit.Test; + +public class TestWord2VecSynonymFilterFactory extends BaseTokenStreamFactoryTestCase { + + public static final String FACTORY_NAME = "Word2VecSynonym"; + private static final String WORD2VEC_MODEL_FILE = "word2vec-model.zip"; + + @Test + public void testInform() throws Exception { + ResourceLoader loader = new ClasspathResourceLoader(getClass()); + assertTrue("loader is null and it shouldn't be", loader != null); + Word2VecSynonymFilterFactory factory = + (Word2VecSynonymFilterFactory) + tokenFilterFactory( + FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "minAcceptedSimilarity", "0.7"); + + Word2VecSynonymProvider synonymProvider = factory.getSynonymProvider(); + assertNotEquals(null, synonymProvider); + } + + @Test + public void missingRequiredArgument_shouldThrowException() throws Exception { + IllegalArgumentException expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, + "format", + "dl4j", + "minAcceptedSimilarity", + "0.7", + "maxSynonymsPerTerm", + "10"); + }); + assertTrue(expected.getMessage().contains("Configuration Error: missing parameter 'model'")); + } + + @Test + public void unsupportedModelFormat_shouldThrowException() throws Exception { + IllegalArgumentException expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "format", "bogusValue"); + }); + assertTrue(expected.getMessage().contains("Model format 'BOGUSVALUE' not supported")); + } + + @Test + public void bogusArgument_shouldThrowException() throws Exception { + IllegalArgumentException expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, "model", WORD2VEC_MODEL_FILE, "bogusArg", "bogusValue"); + }); + assertTrue(expected.getMessage().contains("Unknown parameters")); + } + + @Test + public void illegalArguments_shouldThrowException() throws Exception { + IllegalArgumentException expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, + "model", + WORD2VEC_MODEL_FILE, + "minAcceptedSimilarity", + "2", + "maxSynonymsPerTerm", + "10"); + }); + assertTrue( + expected + .getMessage() + .contains("minAcceptedSimilarity must be in the range (0, 1]. 
Found: 2")); + + expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, + "model", + WORD2VEC_MODEL_FILE, + "minAcceptedSimilarity", + "0", + "maxSynonymsPerTerm", + "10"); + }); + assertTrue( + expected + .getMessage() + .contains("minAcceptedSimilarity must be in the range (0, 1]. Found: 0")); + + expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, + "model", + WORD2VEC_MODEL_FILE, + "minAcceptedSimilarity", + "0.7", + "maxSynonymsPerTerm", + "-1"); + }); + assertTrue( + expected + .getMessage() + .contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: -1")); + + expected = + expectThrows( + IllegalArgumentException.class, + () -> { + tokenFilterFactory( + FACTORY_NAME, + "model", + WORD2VEC_MODEL_FILE, + "minAcceptedSimilarity", + "0.7", + "maxSynonymsPerTerm", + "0"); + }); + assertTrue( + expected + .getMessage() + .contains("maxSynonymsPerTerm must be a positive integer greater than 0. Found: 0")); + } +} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java new file mode 100644 index 000000000000..3e7e6bce07a3 --- /dev/null +++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/TestWord2VecSynonymProvider.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.analysis.synonym.word2vec; + +import java.io.IOException; +import java.util.List; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TermAndVector; +import org.junit.Test; + +public class TestWord2VecSynonymProvider extends LuceneTestCase { + + private static final int MAX_SYNONYMS_PER_TERM = 10; + private static final float MIN_ACCEPTED_SIMILARITY = 0.85f; + + private final Word2VecSynonymProvider unit; + + public TestWord2VecSynonymProvider() throws IOException { + Word2VecModel model = new Word2VecModel(2, 3); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {0.24f, 0.78f, 0.28f})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {0.44f, 0.01f, 0.81f})); + unit = new Word2VecSynonymProvider(model); + } + + @Test + public void getSynonyms_nullToken_shouldThrowException() { + expectThrows( + IllegalArgumentException.class, + () -> unit.getSynonyms(null, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY)); + } + + @Test + public void getSynonyms_shouldReturnSynonymsBasedOnMinAcceptedSimilarity() throws Exception { + Word2VecModel model = new Word2VecModel(6, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("d"), new float[] {1, 1})); + model.addTermAndVector(new TermAndVector(new BytesRef("e"), new float[] {99, 101})); + model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10})); + + Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model); + + BytesRef inputTerm = new BytesRef("a"); + String[] expectedSynonyms = {"d", "e", "c", "b"}; + List<TermAndBoost> actualSynonymsResults = + unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY); + + assertEquals(4, actualSynonymsResults.size()); + for (int i = 0; i < expectedSynonyms.length; i++) { + assertEquals(new BytesRef(expectedSynonyms[i]), actualSynonymsResults.get(i).term); + } + } + + @Test + public void getSynonyms_shouldReturnSynonymsBoost() throws Exception { + Word2VecModel model = new Word2VecModel(3, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {1, 1})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {99, 101})); + + Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model); + + BytesRef inputTerm = new BytesRef("a"); + List<TermAndBoost> actualSynonymsResults = + unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY); + + BytesRef expectedFirstSynonymTerm = new BytesRef("b"); + double expectedFirstSynonymBoost = 1.0; + assertEquals(expectedFirstSynonymTerm, actualSynonymsResults.get(0).term); + assertEquals(expectedFirstSynonymBoost, actualSynonymsResults.get(0).boost, 0.001f); + } + + @Test + public void noSynonymsWithinAcceptedSimilarity_shouldReturnNoSynonyms() throws Exception { + Word2VecModel model = new Word2VecModel(4, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {-10, -8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {-9, -10})); + model.addTermAndVector(new
TermAndVector(new BytesRef("d"), new float[] {6, -6})); + + Word2VecSynonymProvider unit = new Word2VecSynonymProvider(model); + + BytesRef inputTerm = newBytesRef("a"); + List<TermAndBoost> actualSynonymsResults = + unit.getSynonyms(inputTerm, MAX_SYNONYMS_PER_TERM, MIN_ACCEPTED_SIMILARITY); + assertEquals(0, actualSynonymsResults.size()); + } + + @Test + public void testModel_shouldReturnNormalizedVectors() { + Word2VecModel model = new Word2VecModel(4, 2); + model.addTermAndVector(new TermAndVector(new BytesRef("a"), new float[] {10, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("b"), new float[] {10, 8})); + model.addTermAndVector(new TermAndVector(new BytesRef("c"), new float[] {9, 10})); + model.addTermAndVector(new TermAndVector(new BytesRef("f"), new float[] {-1, 10})); + + float[] vectorIdA = model.vectorValue(new BytesRef("a")); + float[] vectorIdF = model.vectorValue(new BytesRef("f")); + assertArrayEquals(new float[] {0.70710f, 0.70710f}, vectorIdA, 0.001f); + assertArrayEquals(new float[] {-0.0995f, 0.99503f}, vectorIdF, 0.001f); + } + + @Test + public void normalizedVector_shouldReturnModule1() { + TermAndVector synonymTerm = new TermAndVector(new BytesRef("a"), new float[] {10, 10}); + synonymTerm.normalizeVector(); + float[] vector = synonymTerm.getVector(); + float len = 0; + for (int i = 0; i < vector.length; i++) { + len += vector[i] * vector[i]; + } + assertEquals(1, Math.sqrt(len), 0.0001f); + } +} diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip new file mode 100644 index 000000000000..e25693dd83cf Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-corrupted-vector-dimension-model.zip differ diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip new file mode 100644 index 000000000000..57d7832dd787 Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-empty-model.zip differ diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip new file mode 100644 index 000000000000..6d31b8d5a3fa Binary files /dev/null and b/lucene/analysis/common/src/test/org/apache/lucene/analysis/synonym/word2vec/word2vec-model.zip differ diff --git a/lucene/analysis/smartcn/src/resources/org/apache/lucene/analysis/cn/smart/stopwords.txt b/lucene/analysis/smartcn/src/resources/org/apache/lucene/analysis/cn/smart/stopwords.txt index fb0d71ad7d2c..65bcfd4e1b65 100644 --- a/lucene/analysis/smartcn/src/resources/org/apache/lucene/analysis/cn/smart/stopwords.txt +++ b/lucene/analysis/smartcn/src/resources/org/apache/lucene/analysis/cn/smart/stopwords.txt @@ -53,7 +53,5 @@ $ ● // the line below contains an IDEOGRAPHIC SPACE character (Used as a space in Chinese) 　 - //////////////// English Stop Words //////////////// - //////////////// Chinese Stop Words //////////////// diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java
index fc82ce588860..639bdbd73339 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java @@ -219,7 +219,7 @@ static short writeBitSet(DocIdSetIterator it, IndexOutput out, byte denseRankPow // Flush block flush(prevBlock, buffer, blockCardinality, denseRankPower, out); // Reset for next block - buffer.clear(0, buffer.length()); + buffer.clear(); totalCardinality += blockCardinality; blockCardinality = 0; } @@ -233,7 +233,7 @@ static short writeBitSet(DocIdSetIterator it, IndexOutput out, byte denseRankPow jumps, out.getFilePointer() - origo, totalCardinality, jumpBlockIndex, prevBlock + 1); totalCardinality += blockCardinality; flush(prevBlock, buffer, blockCardinality, denseRankPower, out); - buffer.clear(0, buffer.length()); + buffer.clear(); prevBlock++; } final int lastBlock = diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java index e7f16b4f3fc9..4b1f7068a5f2 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java @@ -183,10 +183,10 @@ private void addDiverseNeighbors(int node, NeighborQueue candidates) throws IOEx int size = neighbors.size(); for (int i = 0; i < size; i++) { int nbr = neighbors.node()[i]; - Lucene90NeighborArray nbrNbr = hnsw.getNeighbors(nbr); - nbrNbr.add(node, neighbors.score()[i]); - if (nbrNbr.size() > maxConn) { - diversityUpdate(nbrNbr); + Lucene90NeighborArray nbrsOfNbr = hnsw.getNeighbors(nbr); + nbrsOfNbr.add(node, neighbors.score()[i]); + if (nbrsOfNbr.size() > maxConn) { + diversityUpdate(nbrsOfNbr); } } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java index b0e9d160457b..c82920181cc4 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java @@ -204,10 +204,10 @@ private void addDiverseNeighbors(int level, int node, NeighborQueue candidates) int size = neighbors.size(); for (int i = 0; i < size; i++) { int nbr = neighbors.node[i]; - Lucene91NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr); - nbrNbr.add(node, neighbors.score[i]); - if (nbrNbr.size() > maxConn) { - diversityUpdate(nbrNbr); + Lucene91NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr); + nbrsOfNbr.add(node, neighbors.score[i]); + if (nbrsOfNbr.size() > maxConn) { + diversityUpdate(nbrsOfNbr); } } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBackwardsCompatibility.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBackwardsCompatibility.java index d030777e3783..40541f505bea 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBackwardsCompatibility.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBackwardsCompatibility.java @@ -419,7 +419,9 @@ public void testCreateEmptyIndex() throws Exception { "9.4.2-cfs", "9.4.2-nocfs", "9.5.0-cfs", 
- "9.5.0-nocfs" + "9.5.0-nocfs", + "9.6.0-cfs", + "9.6.0-nocfs" }; public static String[] getOldNames() { @@ -459,7 +461,8 @@ public static String[] getOldNames() { "sorted.9.4.0", "sorted.9.4.1", "sorted.9.4.2", - "sorted.9.5.0" + "sorted.9.5.0", + "sorted.9.6.0" }; public static String[] getOldSortedNames() { diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-cfs.zip b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-cfs.zip new file mode 100644 index 000000000000..bc40e99b3d77 Binary files /dev/null and b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-cfs.zip differ diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-nocfs.zip b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-nocfs.zip new file mode 100644 index 000000000000..8e97a46c1485 Binary files /dev/null and b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/index.9.6.0-nocfs.zip differ diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/sorted.9.6.0.zip b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/sorted.9.6.0.zip new file mode 100644 index 000000000000..f3af2e4011ee Binary files /dev/null and b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/sorted.9.6.0.zip differ diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java index c020d7487cb3..12d11e59b9d7 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java @@ -35,6 +35,7 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.FuzzyTermsEnum; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.TermQuery; @@ -214,7 +215,8 @@ private Query newTermQuery(IndexReader reader, Term term) throws IOException { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + IndexReader reader = indexSearcher.getIndexReader(); ScoreTermQueue q = new ScoreTermQueue(MAX_NUM_TERMS); // load up the list of possible terms for (FieldVals f : fieldVals) { diff --git a/lucene/core/src/generated/jdk/README.md b/lucene/core/src/generated/jdk/README.md index 371bbebf8518..48a014b992b8 100644 --- a/lucene/core/src/generated/jdk/README.md +++ b/lucene/core/src/generated/jdk/README.md @@ -40,4 +40,4 @@ to point the Lucene build system to missing JDK versions. The regeneration task a warning if a specific JDK is missing, leaving the already existing `.apijar` file untouched. -The extraction is done with the ASM library, see `ExtractForeignAPI.java` source code. +The extraction is done with the ASM library, see `ExtractJdkApis.java` source code. 
diff --git a/lucene/core/src/generated/jdk/jdk19.apijar b/lucene/core/src/generated/jdk/jdk19.apijar new file mode 100644 index 000000000000..4a04f1440e4a Binary files /dev/null and b/lucene/core/src/generated/jdk/jdk19.apijar differ diff --git a/lucene/core/src/generated/jdk/jdk20.apijar b/lucene/core/src/generated/jdk/jdk20.apijar new file mode 100644 index 000000000000..942ddef057b7 Binary files /dev/null and b/lucene/core/src/generated/jdk/jdk20.apijar differ diff --git a/lucene/core/src/generated/jdk/jdk21.apijar b/lucene/core/src/generated/jdk/jdk21.apijar new file mode 100644 index 000000000000..3ded0aaaed4e Binary files /dev/null and b/lucene/core/src/generated/jdk/jdk21.apijar differ diff --git a/lucene/core/src/generated/jdk/panama-foreign-jdk19.apijar b/lucene/core/src/generated/jdk/panama-foreign-jdk19.apijar deleted file mode 100644 index c9b73d9193bf..000000000000 Binary files a/lucene/core/src/generated/jdk/panama-foreign-jdk19.apijar and /dev/null differ diff --git a/lucene/core/src/generated/jdk/panama-foreign-jdk20.apijar b/lucene/core/src/generated/jdk/panama-foreign-jdk20.apijar deleted file mode 100644 index 03baf38a1938..000000000000 Binary files a/lucene/core/src/generated/jdk/panama-foreign-jdk20.apijar and /dev/null differ diff --git a/lucene/core/src/java/org/apache/lucene/analysis/WordlistLoader.java b/lucene/core/src/java/org/apache/lucene/analysis/WordlistLoader.java index 30ada92eb39b..8e18f4ad76d3 100644 --- a/lucene/core/src/java/org/apache/lucene/analysis/WordlistLoader.java +++ b/lucene/core/src/java/org/apache/lucene/analysis/WordlistLoader.java @@ -40,9 +40,9 @@ public class WordlistLoader { private WordlistLoader() {} /** - * Reads lines from a Reader and adds every line as an entry to a CharArraySet (omitting leading - * and trailing whitespace). Every line of the Reader should contain only one word. The words need - * to be in lowercase if you make use of an Analyzer which uses LowerCaseFilter (like + * Reads lines from a Reader and adds every non-blank line as an entry to a CharArraySet (omitting + * leading and trailing whitespace). Every line of the Reader should contain only one word. The + * words need to be in lowercase if you make use of an Analyzer which uses LowerCaseFilter (like * StandardAnalyzer). * * @param reader Reader containing the wordlist @@ -53,7 +53,10 @@ public static CharArraySet getWordSet(Reader reader, CharArraySet result) throws try (BufferedReader br = getBufferedReader(reader)) { String word = null; while ((word = br.readLine()) != null) { - result.add(word.trim()); + word = word.trim(); + // skip blank lines + if (word.isEmpty()) continue; + result.add(word); } } return result; @@ -101,10 +104,10 @@ public static CharArraySet getWordSet(InputStream stream, Charset charset) throw } /** - * Reads lines from a Reader and adds every non-comment line as an entry to a CharArraySet - * (omitting leading and trailing whitespace). Every line of the Reader should contain only one - * word. The words need to be in lowercase if you make use of an Analyzer which uses - * LowerCaseFilter (like StandardAnalyzer). + * Reads lines from a Reader and adds every non-blank non-comment line as an entry to a + * CharArraySet (omitting leading and trailing whitespace). Every line of the Reader should + * contain only one word. The words need to be in lowercase if you make use of an Analyzer which + * uses LowerCaseFilter (like StandardAnalyzer). * * @param reader Reader containing the wordlist * @param comment The string representing a comment. 
@@ -117,7 +120,10 @@ public static CharArraySet getWordSet(Reader reader, String comment, CharArraySe String word = null; while ((word = br.readLine()) != null) { if (word.startsWith(comment) == false) { - result.add(word.trim()); + word = word.trim(); + // skip blank lines + if (word.isEmpty()) continue; + result.add(word); } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java index 8da289e3ad35..512ab4b1e556 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java @@ -110,12 +110,12 @@ public final class IndexedDISI extends DocIdSetIterator { private static void flush( int block, FixedBitSet buffer, int cardinality, byte denseRankPower, IndexOutput out) throws IOException { - assert block >= 0 && block < 65536; + assert block >= 0 && block < BLOCK_SIZE; out.writeShort((short) block); - assert cardinality > 0 && cardinality <= 65536; + assert cardinality > 0 && cardinality <= BLOCK_SIZE; out.writeShort((short) (cardinality - 1)); if (cardinality > MAX_ARRAY_LENGTH) { - if (cardinality != 65536) { // all docs are set + if (cardinality != BLOCK_SIZE) { // all docs are set if (denseRankPower != -1) { final byte[] rank = createRank(buffer, denseRankPower); out.writeBytes(rank, rank.length); @@ -220,7 +220,7 @@ public static short writeBitSet(DocIdSetIterator it, IndexOutput out, byte dense // Flush block flush(prevBlock, buffer, blockCardinality, denseRankPower, out); // Reset for next block - buffer.clear(0, buffer.length()); + buffer.clear(); totalCardinality += blockCardinality; blockCardinality = 0; } @@ -234,7 +234,7 @@ public static short writeBitSet(DocIdSetIterator it, IndexOutput out, byte dense jumps, out.getFilePointer() - origo, totalCardinality, jumpBlockIndex, prevBlock + 1); totalCardinality += blockCardinality; flush(prevBlock, buffer, blockCardinality, denseRankPower, out); - buffer.clear(0, buffer.length()); + buffer.clear(); prevBlock++; } final int lastBlock = @@ -418,6 +418,7 @@ public static RandomAccessInput createJumpTable( // SPARSE variables boolean exists; + int nextExistDocInBlock = -1; // DENSE variables long word; @@ -495,7 +496,8 @@ private void readBlockHeader() throws IOException { if (numValues <= MAX_ARRAY_LENGTH) { method = Method.SPARSE; blockEnd = slice.getFilePointer() + (numValues << 1); - } else if (numValues == 65536) { + nextExistDocInBlock = -1; + } else if (numValues == BLOCK_SIZE) { method = Method.ALL; blockEnd = slice.getFilePointer(); gap = block - index - 1; @@ -550,6 +552,7 @@ boolean advanceWithinBlock(IndexedDISI disi, int target) throws IOException { if (doc >= targetInBlock) { disi.doc = disi.block | doc; disi.exists = true; + disi.nextExistDocInBlock = doc; return true; } } @@ -560,6 +563,10 @@ boolean advanceWithinBlock(IndexedDISI disi, int target) throws IOException { boolean advanceExactWithinBlock(IndexedDISI disi, int target) throws IOException { final int targetInBlock = target & 0xFFFF; // TODO: binary search + if (disi.nextExistDocInBlock > targetInBlock) { + assert !disi.exists; + return false; + } if (target == disi.doc) { return disi.exists; } @@ -567,6 +574,7 @@ boolean advanceExactWithinBlock(IndexedDISI disi, int target) throws IOException int doc = Short.toUnsignedInt(disi.slice.readShort()); disi.index++; if (doc >= targetInBlock) { + disi.nextExistDocInBlock = doc; if (doc != targetInBlock) { disi.index--; 
disi.slice.seek(disi.slice.getFilePointer() - Short.BYTES); diff --git a/lucene/core/src/java/org/apache/lucene/document/BinaryRangeFieldRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/BinaryRangeFieldRangeQuery.java index ec814a4b0a7e..b225f4d98fe2 100644 --- a/lucene/core/src/java/org/apache/lucene/document/BinaryRangeFieldRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/BinaryRangeFieldRangeQuery.java @@ -21,7 +21,6 @@ import java.util.Arrays; import java.util.Objects; import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.ConstantScoreScorer; @@ -90,8 +89,8 @@ public void visit(QueryVisitor visitor) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - return super.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return super.rewrite(indexSearcher); } private BinaryRangeDocValues getValues(LeafReader reader, String field) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/document/DoubleRangeSlowRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/DoubleRangeSlowRangeQuery.java index 6406b1be81a2..f0f9040d08c9 100644 --- a/lucene/core/src/java/org/apache/lucene/document/DoubleRangeSlowRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/DoubleRangeSlowRangeQuery.java @@ -20,7 +20,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Objects; -import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -79,8 +79,8 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - return super.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return super.rewrite(indexSearcher); } private static byte[] encodeRanges(double[] min, double[] max) { diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java index 27e689644594..85c3bad35199 100644 --- a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java +++ b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java @@ -31,6 +31,7 @@ import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.SortField; import org.apache.lucene.search.similarities.BM25Similarity; @@ -223,7 +224,7 @@ abstract static class FeatureFunction { abstract Explanation explain(String field, String feature, float w, int freq); - FeatureFunction rewrite(IndexReader reader) throws IOException { + FeatureFunction rewrite(IndexSearcher indexSearcher) throws IOException { return this; } } @@ -340,11 +341,11 @@ static final class SaturationFunction extends FeatureFunction { } @Override - public FeatureFunction rewrite(IndexReader reader) throws IOException { + public FeatureFunction rewrite(IndexSearcher indexSearcher) throws IOException { if (pivot != null) { - return super.rewrite(reader); + return super.rewrite(indexSearcher); } - float newPivot = computePivotFeatureValue(reader, field, feature); + float newPivot = 
computePivotFeatureValue(indexSearcher.getIndexReader(), field, feature); return new SaturationFunction(field, feature, newPivot); } diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java index de899a96d4ea..8b1ea8c47258 100644 --- a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java @@ -20,7 +20,6 @@ import java.util.Objects; import org.apache.lucene.document.FeatureField.FeatureFunction; import org.apache.lucene.index.ImpactsEnum; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.Term; @@ -52,12 +51,12 @@ final class FeatureQuery extends Query { } @Override - public Query rewrite(IndexReader reader) throws IOException { - FeatureFunction rewritten = function.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + FeatureFunction rewritten = function.rewrite(indexSearcher); if (function != rewritten) { return new FeatureQuery(fieldName, featureName, rewritten); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/document/FloatRangeSlowRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/FloatRangeSlowRangeQuery.java index 9c011621e066..62981f4870aa 100644 --- a/lucene/core/src/java/org/apache/lucene/document/FloatRangeSlowRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/FloatRangeSlowRangeQuery.java @@ -20,7 +20,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Objects; -import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -79,8 +79,8 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - return super.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return super.rewrite(indexSearcher); } private static byte[] encodeRanges(float[] min, float[] max) { diff --git a/lucene/core/src/java/org/apache/lucene/document/IntRangeSlowRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/IntRangeSlowRangeQuery.java index cd0714a52909..99d5fa212f67 100644 --- a/lucene/core/src/java/org/apache/lucene/document/IntRangeSlowRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/IntRangeSlowRangeQuery.java @@ -19,7 +19,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Objects; -import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -77,8 +77,8 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - return super.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return super.rewrite(indexSearcher); } private static byte[] encodeRanges(int[] min, int[] max) { diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java index fabbc5259e3b..87cb6a9f056e 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java 
+++ b/lucene/core/src/java/org/apache/lucene/document/KnnByteVectorField.java @@ -17,6 +17,7 @@ package org.apache.lucene.document; +import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -100,7 +101,7 @@ public static FieldType createFieldType( public KnnByteVectorField( String name, byte[] vector, VectorSimilarityFunction similarityFunction) { super(name, createType(vector, similarityFunction)); - fieldsData = vector; + fieldsData = vector; // null-check done above } /** @@ -136,6 +137,11 @@ public KnnByteVectorField(String name, byte[] vector, FieldType fieldType) { + " using byte[] but the field encoding is " + fieldType.vectorEncoding()); } + Objects.requireNonNull(vector, "vector value must not be null"); + if (vector.length != fieldType.vectorDimension()) { + throw new IllegalArgumentException( + "The number of vector dimensions does not match the field type"); + } fieldsData = vector; } diff --git a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java index d6673293c720..9d1cd02c013e 100644 --- a/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/KnnFloatVectorField.java @@ -17,6 +17,7 @@ package org.apache.lucene.document; +import java.util.Objects; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -101,7 +102,7 @@ public static Query newVectorQuery(String field, float[] queryVector, int k) { public KnnFloatVectorField( String name, float[] vector, VectorSimilarityFunction similarityFunction) { super(name, createType(vector, similarityFunction)); - fieldsData = vector; + fieldsData = VectorUtil.checkFinite(vector); // null check done above } /** @@ -137,7 +138,12 @@ public KnnFloatVectorField(String name, float[] vector, FieldType fieldType) { + " using float[] but the field encoding is " + fieldType.vectorEncoding()); } - fieldsData = vector; + Objects.requireNonNull(vector, "vector value must not be null"); + if (vector.length != fieldType.vectorDimension()) { + throw new IllegalArgumentException( + "The number of vector dimensions does not match the field type"); + } + fieldsData = VectorUtil.checkFinite(vector); } /** Return the vector value of this field */ diff --git a/lucene/core/src/java/org/apache/lucene/document/LongRangeSlowRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/LongRangeSlowRangeQuery.java index a4c164524fc5..b86d64bef62c 100644 --- a/lucene/core/src/java/org/apache/lucene/document/LongRangeSlowRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/LongRangeSlowRangeQuery.java @@ -20,7 +20,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Objects; -import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -79,8 +79,8 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - return super.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return super.rewrite(indexSearcher); } private static byte[] encodeRanges(long[] min, long[] max) { diff --git 
a/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesRangeQuery.java index fb23056126bb..21264409136f 100644 --- a/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesRangeQuery.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Objects; import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.SortedNumericDocValues; @@ -84,11 +83,11 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (lowerValue == Long.MIN_VALUE && upperValue == Long.MAX_VALUE) { return new FieldExistsQuery(field); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesSetQuery.java b/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesSetQuery.java index 72d5d42b30a3..fe5e8c3afedd 100644 --- a/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesSetQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/SortedNumericDocValuesSetQuery.java @@ -20,7 +20,6 @@ import java.util.Arrays; import java.util.Objects; import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.SortedNumericDocValues; @@ -85,11 +84,11 @@ public long ramBytesUsed() { } @Override - public Query rewrite(IndexReader indexReader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (numbers.size() == 0) { return new MatchNoDocsQuery(); } - return super.rewrite(indexReader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/document/SortedSetDocValuesRangeQuery.java b/lucene/core/src/java/org/apache/lucene/document/SortedSetDocValuesRangeQuery.java index f7eab990d3d5..928257cbd1fa 100644 --- a/lucene/core/src/java/org/apache/lucene/document/SortedSetDocValuesRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/SortedSetDocValuesRangeQuery.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Objects; import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedSetDocValues; @@ -98,11 +97,11 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (lowerValue == null && upperValue == null) { return new FieldExistsQuery(field); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/geo/Tessellator.java b/lucene/core/src/java/org/apache/lucene/geo/Tessellator.java index 5179037fa54f..ae0190f9d459 100644 --- a/lucene/core/src/java/org/apache/lucene/geo/Tessellator.java +++ 
b/lucene/core/src/java/org/apache/lucene/geo/Tessellator.java @@ -1090,17 +1090,22 @@ private static final Node splitPolygon(final Node a, final Node b, boolean edgeF * Determines whether a diagonal between two polygon nodes lies within a polygon interior. (This * determines the validity of the ray.) * */ - private static final boolean isValidDiagonal(final Node a, final Node b) { + private static boolean isValidDiagonal(final Node a, final Node b) { + if (a.next.idx == b.idx + || a.previous.idx == b.idx + // check next edges are locally visible + || isLocallyInside(a.previous, b) == false + || isLocallyInside(b.next, a) == false + // check polygons are CCW in both sides + || isCWPolygon(a, b) == false + || isCWPolygon(b, a) == false) { + return false; + } if (isVertexEquals(a, b)) { - // If points are equal then use it if they are valid polygons - return isCWPolygon(a, b); + return true; } - return a.next.idx != b.idx - && a.previous.idx != b.idx - && isLocallyInside(a, b) + return isLocallyInside(a, b) && isLocallyInside(b, a) - && isLocallyInside(a.previous, b) - && isLocallyInside(b.next, a) && middleInsert(a, a.getX(), a.getY(), b.getX(), b.getY()) // make sure we don't introduce collinear lines && area(a.previous.getX(), a.previous.getY(), a.getX(), a.getY(), b.getX(), b.getY()) != 0 @@ -1114,7 +1119,7 @@ && area(a.getX(), a.getY(), b.getX(), b.getY(), b.previous.getX(), b.previous.ge /** Determine whether the polygon defined between node start and node end is CW */ private static boolean isCWPolygon(final Node start, final Node end) { // The polygon must be CW - return (signedArea(start, end) < 0) ? true : false; + return signedArea(start, end) < 0; } /** Determine the signed area between node start and node end */ diff --git a/lucene/core/src/java/org/apache/lucene/index/CachingMergeContext.java b/lucene/core/src/java/org/apache/lucene/index/CachingMergeContext.java new file mode 100644 index 000000000000..a9974c64e0b1 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/index/CachingMergeContext.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.index; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Set; +import org.apache.lucene.util.InfoStream; + +/** + * a wrapper of IndexWriter MergeContext. 
Try to cache the {@link + * #numDeletesToMerge(SegmentCommitInfo)} result during the merge phase, to avoid duplicate calculation + */ +class CachingMergeContext implements MergePolicy.MergeContext { + final MergePolicy.MergeContext mergeContext; + final HashMap<SegmentCommitInfo, Integer> cachedNumDeletesToMerge = new HashMap<>(); + + CachingMergeContext(MergePolicy.MergeContext mergeContext) { + this.mergeContext = mergeContext; + } + + @Override + public final int numDeletesToMerge(SegmentCommitInfo info) throws IOException { + Integer numDeletesToMerge = cachedNumDeletesToMerge.get(info); + if (numDeletesToMerge != null) { + return numDeletesToMerge; + } + numDeletesToMerge = mergeContext.numDeletesToMerge(info); + cachedNumDeletesToMerge.put(info, numDeletesToMerge); + return numDeletesToMerge; + } + + @Override + public final int numDeletedDocs(SegmentCommitInfo info) { + return mergeContext.numDeletedDocs(info); + } + + @Override + public final InfoStream getInfoStream() { + return mergeContext.getInfoStream(); + } + + @Override + public final Set<SegmentCommitInfo> getMergingSegments() { + return mergeContext.getMergingSegments(); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/index/DocumentsWriterDeleteQueue.java b/lucene/core/src/java/org/apache/lucene/index/DocumentsWriterDeleteQueue.java index ed86416a2b19..4d8bfd054654 100644 --- a/lucene/core/src/java/org/apache/lucene/index/DocumentsWriterDeleteQueue.java +++ b/lucene/core/src/java/org/apache/lucene/index/DocumentsWriterDeleteQueue.java @@ -142,6 +142,10 @@ static Node<Term> newNode(Term term) { return new TermNode(term); } + static Node<Query> newNode(Query query) { + return new QueryNode(query); + } + static Node<DocValuesUpdate[]> newNode(DocValuesUpdate... updates) { return new DocValuesUpdatesNode(updates); } @@ -437,6 +441,23 @@ public String toString() { } } + private static final class QueryNode extends Node<Query> { + + QueryNode(Query query) { + super(query); + } + + @Override + void apply(BufferedUpdates bufferedDeletes, int docIDUpto) { + bufferedDeletes.addQuery(item, docIDUpto); + } + + @Override + public String toString() { + return "del=" + item; + } + } + private static final class QueryArrayNode extends Node<Query[]> { QueryArrayNode(Query[] query) { super(query); diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index b180d3bdda3d..a572b2258af5 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -36,7 +36,7 @@ */ public class ExitableDirectoryReader extends FilterDirectoryReader { - private QueryTimeout queryTimeout; + private final QueryTimeout queryTimeout; /** Exception that is thrown to prematurely terminate a term enumeration. */ @SuppressWarnings("serial") @@ -50,7 +50,7 @@ public ExitingReaderException(String msg) { /** Wrapper class for a SubReaderWrapper that is used by the ExitableDirectoryReader.
*/ public static class ExitableSubReaderWrapper extends SubReaderWrapper { - private QueryTimeout queryTimeout; + private final QueryTimeout queryTimeout; /** Constructor * */ public ExitableSubReaderWrapper(QueryTimeout queryTimeout) { @@ -810,7 +810,7 @@ public static class ExitableTermsEnum extends FilterTermsEnum { // Create bit mask in the form of 0000 1111 for efficient checking private static final int NUM_CALLS_PER_TIMEOUT_CHECK = (1 << 4) - 1; // 15 private int calls; - private QueryTimeout queryTimeout; + private final QueryTimeout queryTimeout; /** Constructor * */ public ExitableTermsEnum(TermsEnum termsEnum, QueryTimeout queryTimeout) { diff --git a/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java b/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java index c48b851ae526..df2054c18846 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java +++ b/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java @@ -400,6 +400,150 @@ static void verifySameVectorOptions( } } + /* + This method will create a new instance of FieldInfo if any attribute changes (and it changes in a compatible way). + It is intended only to be used in indices where schema validation is not strict (legacy indices). It will return null + if no changes are done on this FieldInfo + */ + FieldInfo handleLegacySupportedUpdates(FieldInfo otherFi) { + IndexOptions newIndexOptions = this.indexOptions; + boolean newStoreTermVector = this.storeTermVector; + boolean newOmitNorms = this.omitNorms; + boolean newStorePayloads = this.storePayloads; + DocValuesType newDocValues = this.docValuesType; + int newPointDimensionCount = this.pointDimensionCount; + int newPointNumBytes = this.pointNumBytes; + int newPointIndexDimensionCount = this.pointIndexDimensionCount; + long newDvGen = this.dvGen; + + boolean fieldInfoChanges = false; + // System.out.println("FI.update field=" + name + " indexed=" + indexed + " omitNorms=" + + // omitNorms + " this.omitNorms=" + this.omitNorms); + if (this.indexOptions != otherFi.indexOptions) { + if (this.indexOptions == IndexOptions.NONE) { + newIndexOptions = otherFi.indexOptions; + fieldInfoChanges = true; + } else if (otherFi.indexOptions != IndexOptions.NONE) { + throw new IllegalArgumentException( + "cannot change field \"" + + name + + "\" from index options=" + + this.indexOptions + + " to inconsistent index options=" + + otherFi.indexOptions); + } + } + + if (this.pointDimensionCount != otherFi.pointDimensionCount + && otherFi.pointDimensionCount != 0) { + if (this.pointDimensionCount == 0) { + fieldInfoChanges = true; + newPointDimensionCount = otherFi.pointDimensionCount; + } else { + throw new IllegalArgumentException( + "cannot change field \"" + + name + + "\" from points dimensionCount=" + + this.pointDimensionCount + + " to inconsistent dimensionCount=" + + otherFi.pointDimensionCount); + } + } + + if (this.pointIndexDimensionCount != otherFi.pointIndexDimensionCount + && otherFi.pointIndexDimensionCount != 0) { + if (this.pointIndexDimensionCount == 0) { + fieldInfoChanges = true; + newPointIndexDimensionCount = otherFi.pointIndexDimensionCount; + } else { + throw new IllegalArgumentException( + "cannot change field \"" + + name + + "\" from points indexDimensionCount=" + + this.pointIndexDimensionCount + + " to inconsistent indexDimensionCount=" + + otherFi.pointIndexDimensionCount); + } + } + + if (this.pointNumBytes != otherFi.pointNumBytes && otherFi.pointNumBytes != 0) { + if (this.pointNumBytes == 0) { + fieldInfoChanges = true; + 
newPointNumBytes = otherFi.pointNumBytes; + } else { + throw new IllegalArgumentException( + "cannot change field \"" + + name + + "\" from points numBytes=" + + this.pointNumBytes + + " to inconsistent numBytes=" + + otherFi.pointNumBytes); + } + } + + if (newIndexOptions + != IndexOptions.NONE) { // if updated field data is not for indexing, leave the updates out + if (this.storeTermVector != otherFi.storeTermVector && this.storeTermVector == false) { + fieldInfoChanges = true; + newStoreTermVector = true; // once vector, always vector + } + if (this.storePayloads != otherFi.storePayloads + && this.storePayloads == false + && newIndexOptions.compareTo(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS) >= 0) { + fieldInfoChanges = true; + newStorePayloads = true; + } + + // Awkward: only drop norms if incoming update is indexed: + if (otherFi.indexOptions != IndexOptions.NONE + && this.omitNorms != otherFi.omitNorms + && this.omitNorms == false) { + fieldInfoChanges = true; + newOmitNorms = true; // if one require omitNorms at least once, it remains off for life + } + } + + if (otherFi.docValuesType != DocValuesType.NONE + && otherFi.docValuesType != this.docValuesType) { + if (this.docValuesType == DocValuesType.NONE) { + fieldInfoChanges = true; + newDocValues = otherFi.docValuesType; + newDvGen = otherFi.dvGen; + } else { + throw new IllegalArgumentException( + "cannot change DocValues type from " + + docValuesType + + " to " + + otherFi.docValuesType + + " for field \"" + + name + + "\""); + } + } + + if (!fieldInfoChanges) { + return null; + } + return new FieldInfo( + this.name, + this.number, + newStoreTermVector, + newOmitNorms, + newStorePayloads, + newIndexOptions, + newDocValues, + newDvGen, + this.attributes, // attributes don't need to be handled here because they are handled for + // the non-legacy case in FieldInfos + newPointDimensionCount, + newPointIndexDimensionCount, + newPointNumBytes, + this.vectorDimension, + this.vectorEncoding, + this.vectorSimilarityFunction, + this.softDeletesField); + } + /** * Record that this field is indexed with points, with the specified number of dimensions and * bytes per dimension. 
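The legacy-update checks above all apply the same promote-or-throw rule: an attribute may move away from its unset default exactly once, and any transition between two concrete values is rejected as an illegal schema change. A minimal sketch of that rule as a generic helper; the helper name and shape are hypothetical, only the rule itself comes from the patch:

  // Hypothetical helper illustrating the promotion rule in handleLegacySupportedUpdates:
  // promote from the "unset" default once, otherwise reject the change.
  static <T> T promoteOrThrow(String field, String what, T current, T incoming, T unset) {
    if (incoming.equals(unset) || incoming.equals(current)) {
      return current; // no change requested
    }
    if (current.equals(unset)) {
      return incoming; // legal one-time promotion; the caller records fieldInfoChanges
    }
    throw new IllegalArgumentException(
        "cannot change field \"" + field + "\" " + what + " from " + current + " to " + incoming);
  }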
diff --git a/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java b/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java index 7029303992ed..85aac5439d3c 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java +++ b/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java @@ -628,6 +628,48 @@ FieldInfo constructFieldInfo(String fieldName, DocValuesType dvType, int newFiel isSoftDeletesField); } + synchronized void setDocValuesType(int number, String name, DocValuesType dvType) { + verifyConsistent(number, name, dvType); + docValuesType.put(name, dvType); + } + + synchronized void verifyConsistent(Integer number, String name, DocValuesType dvType) { + if (name.equals(numberToName.get(number)) == false) { + throw new IllegalArgumentException( + "field number " + + number + + " is already mapped to field name \"" + + numberToName.get(number) + + "\", not \"" + + name + + "\""); + } + if (number.equals(nameToNumber.get(name)) == false) { + throw new IllegalArgumentException( + "field name \"" + + name + + "\" is already mapped to field number \"" + + nameToNumber.get(name) + + "\", not \"" + + number + + "\""); + } + DocValuesType currentDVType = docValuesType.get(name); + if (dvType != DocValuesType.NONE + && currentDVType != null + && currentDVType != DocValuesType.NONE + && dvType != currentDVType) { + throw new IllegalArgumentException( + "cannot change DocValues type from " + + currentDVType + + " to " + + dvType + + " for field \"" + + name + + "\""); + } + } + synchronized Set<String> getFieldNames() { return Set.copyOf(nameToNumber.keySet()); } @@ -708,6 +750,26 @@ FieldInfo add(FieldInfo fi, long dvGen) { final FieldInfo curFi = fieldInfo(fi.getName()); if (curFi != null) { curFi.verifySameSchema(fi, globalFieldNumbers.strictlyConsistent); + + if (!globalFieldNumbers.strictlyConsistent) { + // For the not strictly consistent case (legacy index), we may need to merge the + // FieldInfo instances + FieldInfo updatedFieldInfo = curFi.handleLegacySupportedUpdates(fi); + if (updatedFieldInfo != null) { + if (curFi.getDocValuesType() == DocValuesType.NONE + && updatedFieldInfo.getDocValuesType() != DocValuesType.NONE) { + // Must also update docValuesType map so it's + // aware of this field's DocValuesType. This will throw IllegalArgumentException if + // an illegal type change was attempted. + globalFieldNumbers.setDocValuesType( + updatedFieldInfo.number, + updatedFieldInfo.getName(), + updatedFieldInfo.getDocValuesType()); + } + // Since the FieldInfo changed, update in map + byName.put(fi.getName(), updatedFieldInfo); + } + } if (fi.attributes() != null) { fi.attributes().forEach((k, v) -> curFi.putAttribute(k, v)); } diff --git a/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java b/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java index f488e1f83f62..07dca8945548 100644 --- a/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java +++ b/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java @@ -1522,6 +1522,19 @@ public long updateDocuments( delTerm == null ? null : DocumentsWriterDeleteQueue.newNode(delTerm), docs); } + /** + * Similar to {@link #updateDocuments(Term, Iterable)}, but takes a query instead of a term to + * identify the documents to be updated + * + * @lucene.experimental + */ + public long updateDocuments( + Query delQuery, Iterable<? extends Iterable<? extends IndexableField>> docs) + throws IOException { + return updateDocuments( + delQuery == null ?
null : DocumentsWriterDeleteQueue.newNode(delQuery), docs); + } + private long updateDocuments( final DocumentsWriterDeleteQueue.Node delNode, Iterable> docs) @@ -2202,10 +2215,11 @@ public void forceMergeDeletes(boolean doWait) throws IOException { } final MergePolicy mergePolicy = config.getMergePolicy(); + final CachingMergeContext cachingMergeContext = new CachingMergeContext(this); MergePolicy.MergeSpecification spec; boolean newMergesFound = false; synchronized (this) { - spec = mergePolicy.findForcedDeletesMerges(segmentInfos, this); + spec = mergePolicy.findForcedDeletesMerges(segmentInfos, cachingMergeContext); newMergesFound = spec != null; if (newMergesFound) { final int numMerges = spec.merges.size(); @@ -2315,6 +2329,7 @@ private synchronized MergePolicy.MergeSpecification updatePendingMerges( } final MergePolicy.MergeSpecification spec; + final CachingMergeContext cachingMergeContext = new CachingMergeContext(this); if (maxNumSegments != UNBOUNDED_MAX_MERGE_SEGMENTS) { assert trigger == MergeTrigger.EXPLICIT || trigger == MergeTrigger.MERGE_FINISHED : "Expected EXPLICT or MERGE_FINISHED as trigger even with maxNumSegments set but was: " @@ -2322,7 +2337,10 @@ private synchronized MergePolicy.MergeSpecification updatePendingMerges( spec = mergePolicy.findForcedMerges( - segmentInfos, maxNumSegments, Collections.unmodifiableMap(segmentsToMerge), this); + segmentInfos, + maxNumSegments, + Collections.unmodifiableMap(segmentsToMerge), + cachingMergeContext); if (spec != null) { final int numMerges = spec.merges.size(); for (int i = 0; i < numMerges; i++) { @@ -2334,7 +2352,7 @@ private synchronized MergePolicy.MergeSpecification updatePendingMerges( switch (trigger) { case GET_READER: case COMMIT: - spec = mergePolicy.findFullFlushMerges(trigger, segmentInfos, this); + spec = mergePolicy.findFullFlushMerges(trigger, segmentInfos, cachingMergeContext); break; case ADD_INDEXES: throw new IllegalStateException( @@ -2346,7 +2364,7 @@ private synchronized MergePolicy.MergeSpecification updatePendingMerges( case SEGMENT_FLUSH: case CLOSING: default: - spec = mergePolicy.findMerges(trigger, segmentInfos, this); + spec = mergePolicy.findMerges(trigger, segmentInfos, cachingMergeContext); } } if (spec != null) { diff --git a/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java b/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java index 0c64f4c2c9ac..f1e543423670 100644 --- a/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java +++ b/lucene/core/src/java/org/apache/lucene/index/QueryTimeout.java @@ -17,14 +17,17 @@ package org.apache.lucene.index; /** - * Base for query timeout implementations, which will provide a {@code shouldExit()} method, used - * with {@link ExitableDirectoryReader}. + * Query timeout abstraction that controls whether a query should continue or be stopped. Can be set + * to the searcher through {@link org.apache.lucene.search.IndexSearcher#setTimeout(QueryTimeout)}, + * in which case bulk scoring will be time-bound. Can also be used in combination with {@link + * ExitableDirectoryReader}. */ public interface QueryTimeout { /** - * Called from {@link ExitableDirectoryReader.ExitableTermsEnum#next()} to determine whether to - * stop processing a query. 
+ * Called to determine whether to stop processing a query + * + * @return true if the query should stop, false otherwise */ - public abstract boolean shouldExit(); + boolean shouldExit(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java index 3646cf65584b..8a515cb79fc9 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java @@ -90,8 +90,8 @@ public float compare(byte[] v1, byte[] v2) { /** * Calculates a similarity score between the two vectors with a specified function. Higher - * similarity scores correspond to closer vectors. The offsets and lengths of the BytesRefs - * determine the vector data that is compared. Each (signed) byte represents a vector dimension. + * similarity scores correspond to closer vectors. Each (signed) byte represents a vector + * dimension. * * @param v1 a vector * @param v2 another vector, of the same dimension diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index eaa7cc1e8337..072846785723 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -19,9 +19,13 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.List; import java.util.Objects; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.FutureTask; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexReader; @@ -29,6 +33,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.ThreadInterruptedException; /** * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search. 
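The reworked QueryTimeout javadoc above reflects that a timeout no longer requires wrapping the reader in ExitableDirectoryReader; it can be attached to the searcher directly. A minimal usage sketch, assuming an open Directory named dir and using the stock QueryTimeoutImpl wall-clock implementation:

  // Attach a time budget to the searcher; bulk scoring stops once shouldExit() returns true.
  try (DirectoryReader reader = DirectoryReader.open(dir)) {
    IndexSearcher searcher = new IndexSearcher(reader);
    searcher.setTimeout(new QueryTimeoutImpl(100)); // roughly a 100 ms budget
    TopDocs hits = searcher.search(new TermQuery(new Term("body", "lucene")), 10);
    if (searcher.timedOut()) {
      // hits contains only the documents collected before the timeout fired
    }
  }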
@@ -51,7 +56,7 @@ abstract class AbstractKnnVectorQuery extends Query { private final Query filter; public AbstractKnnVectorQuery(String field, int k, Query filter) { - this.field = field; + this.field = Objects.requireNonNull(field, "field"); this.k = k; if (k < 1) { throw new IllegalArgumentException("k must be at least 1, got: " + k); @@ -60,12 +65,11 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()]; + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + IndexReader reader = indexSearcher.getIndexReader(); - Weight filterWeight = null; + final Weight filterWeight; if (filter != null) { - IndexSearcher indexSearcher = new IndexSearcher(reader); BooleanQuery booleanQuery = new BooleanQuery.Builder() .add(filter, BooleanClause.Occur.FILTER) @@ -73,17 +77,17 @@ public Query rewrite(IndexReader reader) throws IOException { .build(); Query rewritten = indexSearcher.rewrite(booleanQuery); filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f); + } else { + filterWeight = null; } - for (LeafReaderContext ctx : reader.leaves()) { - TopDocs results = searchLeaf(ctx, filterWeight); - if (ctx.docBase > 0) { - for (ScoreDoc scoreDoc : results.scoreDocs) { - scoreDoc.doc += ctx.docBase; - } - } - perLeafResults[ctx.ord] = results; - } + SliceExecutor sliceExecutor = indexSearcher.getSliceExecutor(); + // in case of parallel execution, the leaf results are not ordered by leaf context's ordinal + TopDocs[] perLeafResults = + (sliceExecutor == null) + ? sequentialSearch(reader.leaves(), filterWeight) + : parallelSearch(indexSearcher.getSlices(), filterWeight, sliceExecutor); + // Merge sort the results TopDocs topK = TopDocs.merge(k, perLeafResults); if (topK.scoreDocs.length == 0) { @@ -92,7 +96,67 @@ public Query rewrite(IndexReader reader) throws IOException { return createRewrittenQuery(reader, topK); } + private TopDocs[] sequentialSearch( + List<LeafReaderContext> leafReaderContexts, Weight filterWeight) { + try { + TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()]; + for (LeafReaderContext ctx : leafReaderContexts) { + perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight); + } + return perLeafResults; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private TopDocs[] parallelSearch( + IndexSearcher.LeafSlice[] slices, Weight filterWeight, SliceExecutor sliceExecutor) { + + List<FutureTask<TopDocs[]>> tasks = new ArrayList<>(slices.length); + int segmentsCount = 0; + for (IndexSearcher.LeafSlice slice : slices) { + segmentsCount += slice.leaves.length; + tasks.add( + new FutureTask<>( + () -> { + TopDocs[] results = new TopDocs[slice.leaves.length]; + int i = 0; + for (LeafReaderContext context : slice.leaves) { + results[i++] = searchLeaf(context, filterWeight); + } + return results; + })); + } + + sliceExecutor.invokeAll(tasks); + + TopDocs[] topDocs = new TopDocs[segmentsCount]; + int i = 0; + for (FutureTask<TopDocs[]> task : tasks) { + try { + for (TopDocs docs : task.get()) { + topDocs[i++] = docs; + } + } catch (ExecutionException e) { + throw new RuntimeException(e.getCause()); + } catch (InterruptedException e) { + throw new ThreadInterruptedException(e); + } + } + return topDocs; + } + private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException { + TopDocs results = getLeafResults(ctx, filterWeight); + if (ctx.docBase > 0) { + for (ScoreDoc scoreDoc :
results.scoreDocs) { + scoreDoc.doc += ctx.docBase; + } + } + return results; + } + + private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException { Bits liveDocs = ctx.reader().getLiveDocs(); int maxDoc = ctx.reader().maxDoc(); diff --git a/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java b/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java index 8b73eb5b27a9..2c7e41ac9719 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; @@ -268,11 +267,12 @@ public String toString(String field) { } @Override - public final Query rewrite(IndexReader reader) throws IOException { + public final Query rewrite(IndexSearcher indexSearcher) throws IOException { final TermStates[] contexts = ArrayUtil.copyOfSubArray(this.contexts, 0, this.contexts.length); for (int i = 0; i < contexts.length; ++i) { - if (contexts[i] == null || contexts[i].wasBuiltFor(reader.getContext()) == false) { - contexts[i] = TermStates.build(reader.getContext(), terms[i], true); + if (contexts[i] == null + || contexts[i].wasBuiltFor(indexSearcher.getTopReaderContext()) == false) { + contexts[i] = TermStates.build(indexSearcher.getTopReaderContext(), terms[i], true); } } @@ -287,7 +287,7 @@ public final Query rewrite(IndexReader reader) throws IOException { } for (int i = 0; i < contexts.length; ++i) { - contexts[i] = adjustFrequencies(reader.getContext(), contexts[i], df, ttf); + contexts[i] = adjustFrequencies(indexSearcher.getTopReaderContext(), contexts[i], df, ttf); } Query[] termQueries = new Query[terms.length]; diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java index f5c69621b15c..07823ae5c4d7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java @@ -30,7 +30,6 @@ import java.util.Objects; import java.util.Set; import java.util.function.Predicate; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.BooleanClause.Occur; /** @@ -247,7 +246,7 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (clauses.size() == 0) { return new MatchNoDocsQuery("empty BooleanQuery"); } @@ -286,12 +285,12 @@ public Query rewrite(IndexReader reader) throws IOException { Query rewritten; if (occur == Occur.FILTER || occur == Occur.MUST_NOT) { // Clauses that are not involved in scoring can get some extra simplifications - rewritten = new ConstantScoreQuery(query).rewrite(reader); + rewritten = new ConstantScoreQuery(query).rewrite(indexSearcher); if (rewritten instanceof ConstantScoreQuery) { rewritten = ((ConstantScoreQuery) rewritten).getQuery(); } } else { - rewritten = query.rewrite(reader); + rewritten = query.rewrite(indexSearcher); } if (rewritten != query || query.getClass() == MatchNoDocsQuery.class) { // rewrite clause @@ -566,7 +565,7 @@ public Query rewrite(IndexReader reader) throws IOException { } } - return 
super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java b/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java index 47c0f5f6f285..375269637000 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java @@ -18,7 +18,6 @@ import java.io.IOException; import java.util.Objects; -import org.apache.lucene.index.IndexReader; /** * A {@link Query} wrapper that allows to give a boost to the wrapped query. Boost values that are @@ -73,8 +72,8 @@ public int hashCode() { } @Override - public Query rewrite(IndexReader reader) throws IOException { - final Query rewritten = query.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + final Query rewritten = query.rewrite(indexSearcher); if (boost == 1f) { return rewritten; @@ -99,7 +98,7 @@ public Query rewrite(IndexReader reader) throws IOException { return new BoostQuery(rewritten, boost); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java b/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java index 2a7e04447481..89f14fff20bc 100644 --- a/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java +++ b/lucene/core/src/java/org/apache/lucene/search/CollectionTerminatedException.java @@ -31,4 +31,10 @@ public final class CollectionTerminatedException extends RuntimeException { public CollectionTerminatedException() { super(); } + + @Override + public Throwable fillInStackTrace() { + // never re-thrown so we can save the expensive stacktrace + return this; + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java b/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java index b70224f1ec50..03091b748b1c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java +++ b/lucene/core/src/java/org/apache/lucene/search/ConjunctionDISI.java @@ -20,7 +20,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.Comparator; import java.util.List; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BitSet; @@ -99,20 +98,22 @@ static DocIdSetIterator createConjunction( allIterators.size() > 0 ? 
allIterators.get(0).docID() : twoPhaseIterators.get(0).approximation.docID(); - boolean iteratorsOnTheSameDoc = allIterators.stream().allMatch(it -> it.docID() == curDoc); - iteratorsOnTheSameDoc = - iteratorsOnTheSameDoc - && twoPhaseIterators.stream().allMatch(it -> it.approximation().docID() == curDoc); - if (iteratorsOnTheSameDoc == false) { - throw new IllegalArgumentException( - "Sub-iterators of ConjunctionDISI are not on the same document!"); + long minCost = Long.MAX_VALUE; + for (DocIdSetIterator allIterator : allIterators) { + if (allIterator.docID() != curDoc) { + throwSubIteratorsNotOnSameDocument(); + } + minCost = Math.min(allIterator.cost(), minCost); + } + for (TwoPhaseIterator it : twoPhaseIterators) { + if (it.approximation().docID() != curDoc) { + throwSubIteratorsNotOnSameDocument(); + } } - - long minCost = allIterators.stream().mapToLong(DocIdSetIterator::cost).min().getAsLong(); List<BitSetIterator> bitSetIterators = new ArrayList<>(); List<DocIdSetIterator> iterators = new ArrayList<>(); for (DocIdSetIterator iterator : allIterators) { - if (iterator.cost() > minCost && iterator instanceof BitSetIterator) { + if (iterator instanceof BitSetIterator && iterator.cost() > minCost) { // we put all bitset iterators into bitSetIterators // except if they have the minimum cost, since we need // them to lead the iteration in that case @@ -142,6 +143,11 @@ static DocIdSetIterator createConjunction( return disi; } + private static void throwSubIteratorsNotOnSameDocument() { + throw new IllegalArgumentException( + "Sub-iterators of ConjunctionDISI are not on the same document!"); + } + final DocIdSetIterator lead1, lead2; final DocIdSetIterator[] others; @@ -150,14 +156,7 @@ private ConjunctionDISI(List<DocIdSetIterator> iterators) { // Sort the array the first time to allow the least frequent DocsEnum to // lead the matching.
- CollectionUtil.timSort( - iterators, - new Comparator<DocIdSetIterator>() { - @Override - public int compare(DocIdSetIterator o1, DocIdSetIterator o2) { - return Long.compare(o1.cost(), o2.cost()); - } - }); + CollectionUtil.timSort(iterators, (o1, o2) -> Long.compare(o1.cost(), o2.cost())); lead1 = iterators.get(0); lead2 = iterators.get(1); others = iterators.subList(2, iterators.size()).toArray(new DocIdSetIterator[0]); @@ -326,13 +325,7 @@ private ConjunctionTwoPhaseIterator( assert twoPhaseIterators.size() > 0; CollectionUtil.timSort( - twoPhaseIterators, - new Comparator<TwoPhaseIterator>() { - @Override - public int compare(TwoPhaseIterator o1, TwoPhaseIterator o2) { - return Float.compare(o1.matchCost(), o2.matchCost()); - } - }); + twoPhaseIterators, (o1, o2) -> Float.compare(o1.matchCost(), o2.matchCost())); this.twoPhaseIterators = twoPhaseIterators.toArray(new TwoPhaseIterator[twoPhaseIterators.size()]); diff --git a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java index 2225cc109444..48f763e16460 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java @@ -18,7 +18,6 @@ import java.io.IOException; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.Bits; @@ -40,8 +39,9 @@ public Query getQuery() { } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query rewritten = query.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + + Query rewritten = query.rewrite(indexSearcher); // Do some extra simplifications that are legal since scores are not needed on the wrapped // query.
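The ConjunctionDISI changes above preserve the invariant the removed stream code enforced: all sub-iterators must be positioned on the same document, and the lowest-cost iterator leads the intersection. An illustrative standalone sketch (not code from this patch) of why cost ordering matters in the leapfrog loop:

  // Leapfrog intersection of two iterators: the cheaper, more selective iterator leads,
  // so the expensive one only advances to candidate documents instead of every document.
  static int nextMatch(DocIdSetIterator lead, DocIdSetIterator other) throws IOException {
    int doc = lead.nextDoc();
    while (doc != DocIdSetIterator.NO_MORE_DOCS) {
      int otherDoc = other.advance(doc);
      if (otherDoc == doc) {
        return doc; // both iterators landed on the same document
      }
      doc = lead.advance(otherDoc); // skip the leader past the mismatch
    }
    return DocIdSetIterator.NO_MORE_DOCS;
  }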
@@ -70,7 +70,7 @@ public Query rewrite(IndexReader reader) throws IOException { return new ConstantScoreQuery(((BoostQuery) rewritten).getQuery()); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java index 8c41c0a37cf0..1ab4f3b50818 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java @@ -23,7 +23,6 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; /** @@ -209,11 +208,10 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo /** * Optimize our representation and our subqueries representations * - * @param reader the IndexReader we query * @return an optimized copy of us (which may not be a copy if there is nothing to optimize) */ @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (disjuncts.isEmpty()) { return new MatchNoDocsQuery("empty DisjunctionMaxQuery"); } @@ -233,7 +231,7 @@ public Query rewrite(IndexReader reader) throws IOException { boolean actuallyRewritten = false; List<Query> rewrittenDisjuncts = new ArrayList<>(); for (Query sub : disjuncts) { - Query rewrittenSub = sub.rewrite(reader); + Query rewrittenSub = sub.rewrite(indexSearcher); actuallyRewritten |= rewrittenSub != sub; rewrittenDisjuncts.add(rewrittenSub); } @@ -242,7 +240,7 @@ public Query rewrite(IndexReader reader) throws IOException { return new DisjunctionMaxQuery(rewrittenDisjuncts, tieBreakerMultiplier); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java b/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java index 303dc516af8a..f27b791dd95a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java +++ b/lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java @@ -22,7 +22,6 @@ import java.util.function.DoubleToLongFunction; import java.util.function.LongToDoubleFunction; import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.search.comparators.DoubleComparator; @@ -85,7 +84,7 @@ public Explanation explain(LeafReaderContext ctx, int docId, Explanation scoreEx * *
<p>
Queries that use DoubleValuesSource objects should call rewrite() during {@link * Query#createWeight(IndexSearcher, ScoreMode, float)} rather than during {@link - * Query#rewrite(IndexReader)} to avoid IndexReader reference leakage. + * Query#rewrite(IndexSearcher)} to avoid IndexReader reference leakage. * *
<p>
For the same reason, implementations that cache references to the IndexSearcher should * return a new object from this method. diff --git a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java index 98c65b9ddd66..3d78df8e07c3 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java @@ -108,7 +108,8 @@ public int hashCode() { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + IndexReader reader = indexSearcher.getIndexReader(); boolean allReadersRewritable = true; for (LeafReaderContext context : reader.leaves()) { @@ -172,7 +173,7 @@ public Query rewrite(IndexReader reader) throws IOException { if (allReadersRewritable) { return new MatchAllDocsQuery(); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java b/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java index 9ba52eb674f4..599608f0d842 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java @@ -19,7 +19,6 @@ import java.io.IOException; import org.apache.lucene.document.LongPoint; import org.apache.lucene.document.SortedNumericDocValuesField; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; /** @@ -101,9 +100,9 @@ public int hashCode() { } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query indexRewrite = indexQuery.rewrite(reader); - Query dvRewrite = dvQuery.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query indexRewrite = indexQuery.rewrite(indexSearcher); + Query dvRewrite = dvQuery.rewrite(indexSearcher); if (indexRewrite.getClass() == MatchAllDocsQuery.class || dvRewrite.getClass() == MatchAllDocsQuery.class) { return new MatchAllDocsQuery(); diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java index 5798ea5d2cff..f167b7161123 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java @@ -796,9 +796,9 @@ protected void search(List<LeafReaderContext> leaves, Weight weight, Collector c */ public Query rewrite(Query original) throws IOException { Query query = original; - for (Query rewrittenQuery = query.rewrite(reader); + for (Query rewrittenQuery = query.rewrite(this); rewrittenQuery != query; - rewrittenQuery = query.rewrite(reader)) { + rewrittenQuery = query.rewrite(this)) { query = rewrittenQuery; } query.visit(getNumClausesCheckVisitor()); @@ -998,6 +998,10 @@ public Executor getExecutor() { return executor; } + SliceExecutor getSliceExecutor() { + return sliceExecutor; + } + /** * Thrown when an attempt is made to add more than {@link #getMaxClauseCount()} clauses.
This * typically happens if a PrefixQuery, FuzzyQuery, WildcardQuery, or TermRangeQuery is expanded to diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexSortSortedNumericDocValuesRangeQuery.java b/lucene/core/src/java/org/apache/lucene/search/IndexSortSortedNumericDocValuesRangeQuery.java index eb2cf562994c..b678d34192ce 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexSortSortedNumericDocValuesRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexSortSortedNumericDocValuesRangeQuery.java @@ -23,7 +23,6 @@ import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.LongPoint; import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.NumericDocValues; @@ -133,12 +132,12 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (lowerValue == Long.MIN_VALUE && upperValue == Long.MAX_VALUE) { return new FieldExistsQuery(field); } - Query rewrittenFallback = fallbackQuery.rewrite(reader); + Query rewrittenFallback = fallbackQuery.rewrite(indexSearcher); if (rewrittenFallback.getClass() == MatchAllDocsQuery.class) { return new MatchAllDocsQuery(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 4ec617c24470..10345cd7adf4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -71,7 +71,7 @@ public KnnByteVectorQuery(String field, byte[] target, int k) { */ public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) { super(field, k, filter); - this.target = target; + this.target = Objects.requireNonNull(target, "target"); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 2b1b3a69582e..3036e7c45162 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Objects; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.FieldInfo; @@ -25,6 +26,7 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.VectorUtil; /** * Uses {@link KnnVectorsReader#search(String, float[], int, Bits, int)} to perform nearest @@ -70,7 +72,7 @@ public KnnFloatVectorQuery(String field, float[] target, int k) { */ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) { super(field, k, filter); - this.target = target; + this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target")); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java b/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java index 1bf34569c395..27819235f644 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java +++ 
b/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java @@ -24,7 +24,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; @@ -185,7 +184,7 @@ public int[] getPositions() { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (termArrays.length == 0) { return new MatchNoDocsQuery("empty MultiPhraseQuery"); } else if (termArrays.length == 1) { // optimize one-term case @@ -196,7 +195,7 @@ public Query rewrite(IndexReader reader) throws IOException { } return builder.build(); } else { - return super.rewrite(reader); + return super.rewrite(indexSearcher); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java b/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java index 0dc6671dcccb..44bd4279aba2 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiTermQuery.java @@ -18,9 +18,9 @@ import java.io.IOException; import java.util.Objects; -import org.apache.lucene.index.FilteredTermsEnum; // javadocs +import org.apache.lucene.index.FilteredTermsEnum; import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.SingleTermsEnum; // javadocs +import org.apache.lucene.index.SingleTermsEnum; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; import org.apache.lucene.index.Terms; @@ -321,8 +321,8 @@ public long getTermsCount() throws IOException { * AttributeSource)}. For example, to rewrite to a single term, return a {@link SingleTermsEnum} */ @Override - public final Query rewrite(IndexReader reader) throws IOException { - return rewriteMethod.rewrite(reader, this); + public final Query rewrite(IndexSearcher indexSearcher) throws IOException { + return rewriteMethod.rewrite(indexSearcher.getIndexReader(), this); } public RewriteMethod getRewriteMethod() { diff --git a/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java b/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java index 41b0decba7c7..e5023bf032d6 100644 --- a/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java @@ -18,14 +18,13 @@ import java.io.IOException; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; /** * This is a {@link PhraseQuery} which is optimized for n-gram phrase query. For example, when you * query "ABCD" on a 2-gram field, you may want to use NGramPhraseQuery rather than {@link - * PhraseQuery}, because NGramPhraseQuery will {@link #rewrite(IndexReader)} the query to "AB/0 - * CD/2", while {@link PhraseQuery} will query "AB/0 BC/1 CD/2" (where term/position). + * PhraseQuery}, because NGramPhraseQuery will {@link Query#rewrite(IndexSearcher)} the query to + * "AB/0 CD/2", while {@link PhraseQuery} will query "AB/0 BC/1 CD/2" (where term/position). 
*/ public class NGramPhraseQuery extends Query { @@ -44,7 +43,7 @@ public NGramPhraseQuery(int n, PhraseQuery query) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { final Term[] terms = phraseQuery.getTerms(); final int[] positions = phraseQuery.getPositions(); @@ -63,7 +62,7 @@ public Query rewrite(IndexReader reader) throws IOException { } if (isOptimizable == false) { - return phraseQuery.rewrite(reader); + return phraseQuery.rewrite(indexSearcher); } PhraseQuery.Builder builder = new PhraseQuery.Builder(); diff --git a/lucene/core/src/java/org/apache/lucene/search/NamedMatches.java b/lucene/core/src/java/org/apache/lucene/search/NamedMatches.java index 9a24f9433bae..d0ec5c3a2124 100644 --- a/lucene/core/src/java/org/apache/lucene/search/NamedMatches.java +++ b/lucene/core/src/java/org/apache/lucene/search/NamedMatches.java @@ -25,7 +25,6 @@ import java.util.LinkedList; import java.util.List; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; /** @@ -113,8 +112,8 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query rewritten = in.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query rewritten = in.rewrite(indexSearcher); if (rewritten != in) { return new NamedQuery(name, rewritten); } diff --git a/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java b/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java index 93a2ace64531..643861651367 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java @@ -24,7 +24,6 @@ import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat; import org.apache.lucene.codecs.lucene90.Lucene90PostingsReader; import org.apache.lucene.index.ImpactsEnum; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; @@ -284,7 +283,7 @@ public int[] getPositions() { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (terms.length == 0) { return new MatchNoDocsQuery("empty PhraseQuery"); } else if (terms.length == 1) { @@ -296,7 +295,7 @@ public Query rewrite(IndexReader reader) throws IOException { } return new PhraseQuery(slop, terms, newPositions); } else { - return super.rewrite(reader); + return super.rewrite(indexSearcher); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/Query.java b/lucene/core/src/java/org/apache/lucene/search/Query.java index 4f04728395d9..c872c076b76f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/Query.java +++ b/lucene/core/src/java/org/apache/lucene/search/Query.java @@ -17,7 +17,10 @@ package org.apache.lucene.search; import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.util.VirtualMethod; /** * The abstract base class for queries. @@ -41,10 +44,21 @@ * * *
<p>
See also additional queries available in the <a - * href="{@docRoot}/../queries/overview-summary.html">Queries module</a> + * href="{@docRoot}/../queries/overview-summary.html">Queries module</a>. */ public abstract class Query { + private static final VirtualMethod<Query> oldMethod = + new VirtualMethod<>(Query.class, "rewrite", IndexReader.class); + private static final VirtualMethod<Query> newMethod = + new VirtualMethod<>(Query.class, "rewrite", IndexSearcher.class); + private final boolean isDeprecatedRewriteMethodOverridden = + AccessController.doPrivileged( + (PrivilegedAction<Boolean>) + () -> + VirtualMethod.compareImplementationDistance(this.getClass(), oldMethod, newMethod) + > 0); + /** * Prints a query to a string, with <code>field</code> assumed to be the default field and * omitted. @@ -77,14 +91,36 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo *
<p>
Callers are expected to call rewrite multiple times if necessary, until the * rewritten query is the same as the original query. * + * @deprecated Use {@link Query#rewrite(IndexSearcher)} * @see IndexSearcher#rewrite(Query) */ + @Deprecated public Query rewrite(IndexReader reader) throws IOException { - return this; + return isDeprecatedRewriteMethodOverridden ? this : rewrite(new IndexSearcher(reader)); + } + + /** + * Expert: called to re-write queries into primitive queries. For example, a PrefixQuery will be + * rewritten into a BooleanQuery that consists of TermQuerys. + * + *
<p>
Callers are expected to call rewrite multiple times if necessary, until the + * rewritten query is the same as the original query. + * + *
<p>
The rewrite process may be able to make use of IndexSearcher's executor and be executed in + * parallel if the executor is provided. + * + *
<p>
However, if any of the intermediary queries do not satisfy the new API, parallel rewrite is + * not possible for any subsequent sub-queries. To take advantage of this API, the entire query + * tree must override this method. + * + * @see IndexSearcher#rewrite(Query) + */ + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return isDeprecatedRewriteMethodOverridden ? rewrite(indexSearcher.getIndexReader()) : this; } /** - * Recurse through the query tree, visiting any child queries + * Recurse through the query tree, visiting any child queries. * * @param visitor a QueryVisitor to be called by each query in the tree */ @@ -95,8 +131,8 @@ public Query rewrite(IndexReader reader) throws IOException { * that {@link QueryCache} works properly. * *
<p>
Typically a query will be equal to another only if it's an instance of the same class and - * its document-filtering properties are identical that other instance. Utility methods are - * provided for certain repetitive code. + * its document-filtering properties are identical to those of the other instance. Utility methods + * are provided for certain repetitive code. * * @see #sameClassAs(Object) * @see #classHash() @@ -119,7 +155,7 @@ public Query rewrite(IndexReader reader) throws IOException { * *
<p>
When this method is used in an implementation of {@link #equals(Object)}, consider using * {@link #classHash()} in the implementation of {@link #hashCode} to differentiate different - * class + * class. */ protected final boolean sameClassAs(Object other) { return other != null && getClass() == other.getClass(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/QueueSizeBasedExecutor.java b/lucene/core/src/java/org/apache/lucene/search/QueueSizeBasedExecutor.java index a76b81f5da19..65ba1ea5573d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/QueueSizeBasedExecutor.java +++ b/lucene/core/src/java/org/apache/lucene/search/QueueSizeBasedExecutor.java @@ -17,7 +17,6 @@ package org.apache.lucene.search; -import java.util.Collection; import java.util.concurrent.ThreadPoolExecutor; /** @@ -30,31 +29,15 @@ class QueueSizeBasedExecutor extends SliceExecutor { private final ThreadPoolExecutor threadPoolExecutor; - public QueueSizeBasedExecutor(ThreadPoolExecutor threadPoolExecutor) { + QueueSizeBasedExecutor(ThreadPoolExecutor threadPoolExecutor) { super(threadPoolExecutor); this.threadPoolExecutor = threadPoolExecutor; } @Override - public void invokeAll(Collection<Runnable> tasks) { - int i = 0; - - for (Runnable task : tasks) { - boolean shouldExecuteOnCallerThread = false; - - // Execute last task on caller thread - if (i == tasks.size() - 1) { - shouldExecuteOnCallerThread = true; - } - - if (threadPoolExecutor.getQueue().size() - >= (threadPoolExecutor.getMaximumPoolSize() * LIMITING_FACTOR)) { - shouldExecuteOnCallerThread = true; - } - - processTask(task, shouldExecuteOnCallerThread); - - ++i; - } + boolean shouldExecuteOnCallerThread(int index, int numTasks) { + return super.shouldExecuteOnCallerThread(index, numTasks) + || threadPoolExecutor.getQueue().size() + >= (threadPoolExecutor.getMaximumPoolSize() * LIMITING_FACTOR); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java b/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java index debc3efae154..5873b83075b6 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java +++ b/lucene/core/src/java/org/apache/lucene/search/ScoringRewrite.java @@ -101,8 +101,7 @@ public Query rewrite(IndexReader reader, MultiTermQuery query) throws IOExceptio protected abstract void checkMaxClauseCount(int count) throws IOException; @Override - public final Query rewrite(final IndexReader reader, final MultiTermQuery query) - throws IOException { + public final Query rewrite(IndexReader reader, final MultiTermQuery query) throws IOException { final B builder = getTopLevelBuilder(); final ParallelArraysTermCollector col = new ParallelArraysTermCollector(); collectTerms(reader, query, col); diff --git a/lucene/core/src/java/org/apache/lucene/search/SliceExecutor.java b/lucene/core/src/java/org/apache/lucene/search/SliceExecutor.java index c84beeb5fb78..0e593740914d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SliceExecutor.java +++ b/lucene/core/src/java/org/apache/lucene/search/SliceExecutor.java @@ -18,6 +18,7 @@ package org.apache.lucene.search; import java.util.Collection; +import java.util.Objects; import java.util.concurrent.Executor; import java.util.concurrent.RejectedExecutionException; @@ -28,54 +29,30 @@ class SliceExecutor { private final Executor executor; - public SliceExecutor(Executor executor) { - this.executor = executor; + SliceExecutor(Executor executor) { + this.executor = Objects.requireNonNull(executor, "Executor is null"); } - public void
invokeAll(Collection tasks) { - - if (tasks == null) { - throw new IllegalArgumentException("Tasks is null"); - } - - if (executor == null) { - throw new IllegalArgumentException("Executor is null"); - } - + final void invokeAll(Collection tasks) { int i = 0; - for (Runnable task : tasks) { - boolean shouldExecuteOnCallerThread = false; - - // Execute last task on caller thread - if (i == tasks.size() - 1) { - shouldExecuteOnCallerThread = true; + if (shouldExecuteOnCallerThread(i, tasks.size())) { + task.run(); + } else { + try { + executor.execute(task); + } catch ( + @SuppressWarnings("unused") + RejectedExecutionException e) { + task.run(); + } } - - processTask(task, shouldExecuteOnCallerThread); ++i; } - ; } - // Helper method to execute a single task - protected void processTask(final Runnable task, final boolean shouldExecuteOnCallerThread) { - if (task == null) { - throw new IllegalArgumentException("Input is null"); - } - - if (!shouldExecuteOnCallerThread) { - try { - executor.execute(task); - - return; - } catch ( - @SuppressWarnings("unused") - RejectedExecutionException e) { - // Execute on caller thread - } - } - - task.run(); + boolean shouldExecuteOnCallerThread(int index, int numTasks) { + // Execute last task on caller thread + return index == numTasks - 1; } } diff --git a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java index 8131c879d3e9..7e314870f9c7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java @@ -29,7 +29,6 @@ import org.apache.lucene.index.Impacts; import org.apache.lucene.index.ImpactsEnum; import org.apache.lucene.index.ImpactsSource; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.SlowImpactsEnum; @@ -114,7 +113,7 @@ public SynonymQuery build() { */ private SynonymQuery(TermAndBoost[] terms, String field) { this.terms = Objects.requireNonNull(terms); - this.field = field; + this.field = Objects.requireNonNull(field); } public List getTerms() { @@ -147,11 +146,13 @@ public int hashCode() { @Override public boolean equals(Object other) { - return sameClassAs(other) && Arrays.equals(terms, ((SynonymQuery) other).terms); + return sameClassAs(other) + && field.equals(((SynonymQuery) other).field) + && Arrays.equals(terms, ((SynonymQuery) other).terms); } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { // optimize zero and non-boosted single term cases if (terms.length == 0) { return new BooleanQuery.Builder().build(); diff --git a/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java b/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java index 2e0a297c7463..ed268751bba4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java @@ -17,11 +17,10 @@ package org.apache.lucene.search; import java.io.IOException; -import java.util.ArrayList; +import java.io.UncheckedIOException; import java.util.Arrays; import java.util.Collection; import java.util.Collections; -import java.util.List; import java.util.SortedSet; import org.apache.lucene.index.FilteredTermsEnum; import org.apache.lucene.index.PrefixCodedTerms; @@ -38,7 +37,6 @@ import 
org.apache.lucene.util.automaton.Automata; import org.apache.lucene.util.automaton.Automaton; import org.apache.lucene.util.automaton.ByteRunAutomaton; -import org.apache.lucene.util.automaton.CompiledAutomaton; import org.apache.lucene.util.automaton.Operations; /** @@ -150,13 +148,17 @@ public void visit(QueryVisitor visitor) { } } + // TODO: This is pretty heavy-weight. If we have TermInSetQuery directly extend AutomatonQuery + // we won't have to do this (see GH#12176). private ByteRunAutomaton asByteRunAutomaton() { - TermIterator iterator = termData.iterator(); - List<Automaton> automata = new ArrayList<>(); - for (BytesRef term = iterator.next(); term != null; term = iterator.next()) { - automata.add(Automata.makeBinary(term)); + try { + Automaton a = Automata.makeBinaryStringUnion(termData.iterator()); + return new ByteRunAutomaton(a, true, Operations.DEFAULT_DETERMINIZE_WORK_LIMIT); + } catch (IOException e) { + // Shouldn't happen since termData.iterator() provides an iterator implementation that + // never throws: + throw new UncheckedIOException(e); } - return new CompiledAutomaton(Operations.union(automata)).runAutomaton; } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingBulkScorer.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingBulkScorer.java index 5e33884ea4ff..517f0a0e77b2 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingBulkScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingBulkScorer.java @@ -41,6 +41,12 @@ static class TimeExceededException extends RuntimeException { private TimeExceededException() { super("TimeLimit Exceeded"); } + + @Override + public Throwable fillInStackTrace() { + // never re-thrown so we can save the expensive stacktrace + return this; + } } private final BulkScorer in; diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java index 4a208a3f0c5f..c50f3a372e97 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingCollector.java @@ -33,9 +33,9 @@ public class TimeLimitingCollector implements Collector { /** Thrown when elapsed search time exceeds allowed search time.
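The fillInStackTrace() override added to TimeLimitingBulkScorer.TimeExceededException above is a standard trick for exceptions used purely as control flow: filling in the stack trace is the most expensive part of constructing a throwable, and that work is wasted when the exception is caught internally and never rethrown. A minimal standalone sketch of the pattern (the class name is illustrative, not part of this patch):

    final class TimeExceeded extends RuntimeException {
      TimeExceeded() {
        super("time limit exceeded");
      }

      @Override
      public synchronized Throwable fillInStackTrace() {
        // Skip the expensive stack capture; this exception never escapes to callers.
        return this;
      }
    }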
*/ @SuppressWarnings("serial") public static class TimeExceededException extends RuntimeException { - private long timeAllowed; - private long timeElapsed; - private int lastDocCollected; + private final long timeAllowed; + private final long timeElapsed; + private final int lastDocCollected; private TimeExceededException(long timeAllowed, long timeElapsed, int lastDocCollected) { super( diff --git a/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java b/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java index 067867ca7348..b5c52ba7cedf 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java +++ b/lucene/core/src/java/org/apache/lucene/search/TopTermsRewrite.java @@ -61,8 +61,7 @@ public int getSize() { protected abstract int getMaxSize(); @Override - public final Query rewrite(final IndexReader reader, final MultiTermQuery query) - throws IOException { + public final Query rewrite(IndexReader reader, final MultiTermQuery query) throws IOException { final int maxSize = Math.min(size, getMaxSize()); final PriorityQueue stQueue = new PriorityQueue<>(); collectTerms( diff --git a/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java b/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java index 0ac859d40f14..ea75530fb2b4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java +++ b/lucene/core/src/java/org/apache/lucene/search/comparators/NumericComparator.java @@ -91,8 +91,8 @@ public abstract class NumericLeafComparator implements LeafFieldComparator { // if skipping functionality should be enabled on this segment private final boolean enableSkipping; private final int maxDoc; - private final byte[] minValueAsBytes; - private final byte[] maxValueAsBytes; + private byte[] minValueAsBytes; + private byte[] maxValueAsBytes; private DocIdSetIterator competitiveIterator; private long iteratorCost = -1; @@ -128,16 +128,10 @@ public NumericLeafComparator(LeafReaderContext context) throws IOException { } this.enableSkipping = true; // skipping is enabled when points are available this.maxDoc = context.reader().maxDoc(); - this.maxValueAsBytes = - reverse == false ? new byte[bytesCount] : topValueSet ? new byte[bytesCount] : null; - this.minValueAsBytes = - reverse ? new byte[bytesCount] : topValueSet ? 
new byte[bytesCount] : null; this.competitiveIterator = DocIdSetIterator.all(maxDoc); } else { this.enableSkipping = false; this.maxDoc = 0; - this.maxValueAsBytes = null; - this.minValueAsBytes = null; } } @@ -191,7 +185,9 @@ public void setHitsThresholdReached() throws IOException { // update its iterator to include possibly only docs that are "stronger" than the current bottom // entry private void updateCompetitiveIterator() throws IOException { - if (enableSkipping == false || hitsThresholdReached == false || queueFull == false) return; + if (enableSkipping == false + || hitsThresholdReached == false + || (queueFull == false && topValueSet == false)) return; // if some documents have missing points, check that missing values prohibit optimization if ((pointValues.getDocCount() < maxDoc) && isMissingValueCompetitive()) { return; // we can't filter out documents, as documents with missing values are competitive @@ -204,13 +200,21 @@ private void updateCompetitiveIterator() throws IOException { return; } if (reverse == false) { - encodeBottom(maxValueAsBytes); + if (queueFull) { // bottom is available only when queue is full + maxValueAsBytes = maxValueAsBytes == null ? new byte[bytesCount] : maxValueAsBytes; + encodeBottom(maxValueAsBytes); + } if (topValueSet) { + minValueAsBytes = minValueAsBytes == null ? new byte[bytesCount] : minValueAsBytes; encodeTop(minValueAsBytes); } } else { - encodeBottom(minValueAsBytes); + if (queueFull) { // bottom is available only when queue is full + minValueAsBytes = minValueAsBytes == null ? new byte[bytesCount] : minValueAsBytes; + encodeBottom(minValueAsBytes); + } if (topValueSet) { + maxValueAsBytes = maxValueAsBytes == null ? new byte[bytesCount] : maxValueAsBytes; encodeTop(maxValueAsBytes); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/package-info.java b/lucene/core/src/java/org/apache/lucene/search/package-info.java index 4f606597d0a3..31a9d4018c8f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/package-info.java +++ b/lucene/core/src/java/org/apache/lucene/search/package-info.java @@ -357,11 +357,11 @@ * each Query implementation must provide an implementation of Weight. See the subsection on * The Weight Interface below for details on implementing the * Weight interface. - *

  • {@link org.apache.lucene.search.Query#rewrite(org.apache.lucene.index.IndexReader) - * rewrite(IndexReader reader)} — Rewrites queries into primitive queries. Primitive - * queries are: {@link org.apache.lucene.search.TermQuery TermQuery}, {@link - * org.apache.lucene.search.BooleanQuery BooleanQuery}, and other queries that - * implement {@link org.apache.lucene.search.Query#createWeight(IndexSearcher,ScoreMode,float) + *
  • {@link org.apache.lucene.search.Query#rewrite(IndexSearcher) rewrite(IndexSearcher + * searcher)} — Rewrites queries into primitive queries. Primitive queries are: {@link + * org.apache.lucene.search.TermQuery TermQuery}, {@link org.apache.lucene.search.BooleanQuery + * BooleanQuery}, and other queries that implement {@link + * org.apache.lucene.search.Query#createWeight(IndexSearcher,ScoreMode,float) * createWeight(IndexSearcher searcher,ScoreMode scoreMode, float boost)} * * diff --git a/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java b/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java index 974ad0c68747..442dbd7af422 100644 --- a/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java +++ b/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java @@ -43,7 +43,7 @@ public abstract class BufferedIndexInput extends IndexInput implements RandomAcc /** A buffer size for merges set to {@value #MERGE_BUFFER_SIZE}. */ public static final int MERGE_BUFFER_SIZE = 4096; - private int bufferSize = BUFFER_SIZE; + private final int bufferSize; private ByteBuffer buffer = EMPTY_BYTEBUFFER; @@ -72,7 +72,7 @@ public BufferedIndexInput(String resourceDesc, int bufferSize) { this.bufferSize = bufferSize; } - /** Returns buffer size. @see #setBufferSize */ + /** Returns buffer size */ public final int getBufferSize() { return bufferSize; } @@ -220,55 +220,50 @@ public final long readVLong() throws IOException { } } - @Override - public final byte readByte(long pos) throws IOException { + private long resolvePositionInBuffer(long pos, int width) throws IOException { long index = pos - bufferStart; - if (index < 0 || index >= buffer.limit()) { + if (index >= 0 && index <= buffer.limit() - width) { + return index; + } + if (index < 0) { + // if we're moving backwards, then try and fill up the previous page rather than + // starting again at the current pos, to avoid successive backwards reads reloading + // the same data over and over again. 
We also check that we can read `width` + // bytes without going over the end of the buffer + bufferStart = Math.max(bufferStart - bufferSize, pos + width - bufferSize); + bufferStart = Math.max(bufferStart, 0); + bufferStart = Math.min(bufferStart, pos); + } else { + // we're moving forwards, reset the buffer to start at pos bufferStart = pos; - buffer.limit(0); // trigger refill() on read - seekInternal(pos); - refill(); - index = 0; } + buffer.limit(0); // trigger refill() on read + seekInternal(bufferStart); + refill(); + return pos - bufferStart; + } + + @Override + public final byte readByte(long pos) throws IOException { + long index = resolvePositionInBuffer(pos, Byte.BYTES); return buffer.get((int) index); } @Override public final short readShort(long pos) throws IOException { - long index = pos - bufferStart; - if (index < 0 || index >= buffer.limit() - 1) { - bufferStart = pos; - buffer.limit(0); // trigger refill() on read - seekInternal(pos); - refill(); - index = 0; - } + long index = resolvePositionInBuffer(pos, Short.BYTES); return buffer.getShort((int) index); } @Override public final int readInt(long pos) throws IOException { - long index = pos - bufferStart; - if (index < 0 || index >= buffer.limit() - 3) { - bufferStart = pos; - buffer.limit(0); // trigger refill() on read - seekInternal(pos); - refill(); - index = 0; - } + long index = resolvePositionInBuffer(pos, Integer.BYTES); return buffer.getInt((int) index); } @Override public final long readLong(long pos) throws IOException { - long index = pos - bufferStart; - if (index < 0 || index >= buffer.limit() - 7) { - bufferStart = pos; - buffer.limit(0); // trigger refill() on read - seekInternal(pos); - refill(); - index = 0; - } + long index = resolvePositionInBuffer(pos, Long.BYTES); return buffer.getLong((int) index); } diff --git a/lucene/core/src/java/org/apache/lucene/store/ByteBufferGuard.java b/lucene/core/src/java/org/apache/lucene/store/ByteBufferGuard.java index 2d75597f9deb..a9e65ffa06c2 100644 --- a/lucene/core/src/java/org/apache/lucene/store/ByteBufferGuard.java +++ b/lucene/core/src/java/org/apache/lucene/store/ByteBufferGuard.java @@ -17,11 +17,11 @@ package org.apache.lucene.store; import java.io.IOException; +import java.lang.invoke.VarHandle; import java.nio.ByteBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; -import java.util.concurrent.atomic.AtomicInteger; /** * A guard that is created for every {@link ByteBufferIndexInput} that tries on best effort to @@ -49,9 +49,6 @@ static interface BufferCleaner { /** Not volatile; see comments on visibility below! */ private boolean invalidated = false; - /** Used as a store-store barrier; see comments below! */ - private final AtomicInteger barrier = new AtomicInteger(); - /** * Creates an instance to be used for a single {@link ByteBufferIndexInput} which must be shared * by all of its clones. @@ -69,10 +66,9 @@ public void invalidateAndUnmap(ByteBuffer... bufs) throws IOException { // the "invalidated" field update visible to other threads. We specifically // don't make "invalidated" field volatile for performance reasons, hoping the // JVM won't optimize away reads of that field and hardware should ensure - // caches are in sync after this call. This isn't entirely "fool-proof" - // (see LUCENE-7409 discussion), but it has been shown to work in practice - // and we count on this behavior. - barrier.lazySet(0); + // caches are in sync after this call. 
+ // For previous implementation (based on `AtomicInteger#lazySet(0)`) see LUCENE-7409. + VarHandle.fullFence(); // we give other threads a bit of time to finish reads on their ByteBuffer...: Thread.yield(); // finally unmap the ByteBuffers: diff --git a/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java b/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java index 5d23fb2f1ae0..30acb7023f03 100644 --- a/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java +++ b/lucene/core/src/java/org/apache/lucene/store/MMapDirectory.java @@ -76,9 +76,9 @@ *
  • {@code permission java.lang.RuntimePermission "accessClassInPackage.sun.misc";} * * - *

    On exactly Java 19 this class will use the modern {@code MemorySegment} API which - * allows to safely unmap (if you discover any problems with this preview API, you can disable it by - * using system property {@link #ENABLE_MEMORY_SEGMENTS_SYSPROP}). + *

On exactly Java 19 / 20 / 21 this class will use the modern {@code MemorySegment} API + * which allows safe unmapping (if you discover any problems with this preview API, you can disable + * it by using system property {@link #ENABLE_MEMORY_SEGMENTS_SYSPROP}). * *

    NOTE: Accessing this class either directly or indirectly from a thread while it's * interrupted can close the underlying channel immediately if at the same time the thread is @@ -123,7 +123,7 @@ public class MMapDirectory extends FSDirectory { * Default max chunk size: * *

      - *
    • 16 GiBytes for 64 bit Java 19 JVMs + *
    • 16 GiBytes for 64 bit Java 19 / 20 / 21 JVMs *
    • 1 GiBytes for other 64 bit JVMs *
    • 256 MiBytes for 32 bit JVMs *
    @@ -220,9 +220,9 @@ public MMapDirectory(Path path, LockFactory lockFactory, int maxChunkSize) throw * files cannot be mapped. Using a lower chunk size makes the directory implementation a little * bit slower (as the correct chunk may be resolved on lots of seeks) but the chance is higher * that mmap does not fail. On 64 bit Java platforms, this parameter should always be large (like - * 1 GiBytes, or even larger with Java 19), as the address space is big enough. If it is larger, - * fragmentation of address space increases, but number of file handles and mappings is lower for - * huge installations with many open indexes. + * 1 GiBytes, or even larger with recent Java versions), as the address space is big enough. If it + * is larger, fragmentation of address space increases, but number of file handles and mappings is + * lower for huge installations with many open indexes. * *
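As a usage sketch for the three-argument constructor named in the hunk header above (the index path and the 1 GiB chunk size are illustrative assumptions):

    import java.io.IOException;
    import java.nio.file.Path;
    import org.apache.lucene.store.FSLockFactory;
    import org.apache.lucene.store.MMapDirectory;

    static MMapDirectory openIndex(Path indexPath) throws IOException {
      // Explicit 1 GiB max chunk size; the directory implementation rounds the
      // value down to a power of two.
      return new MMapDirectory(indexPath, FSLockFactory.getDefault(), 1 << 30);
    }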

    Please note: The chunk size is always rounded down to a power of 2. * @@ -417,7 +417,7 @@ private static MMapIndexInputProvider lookupProvider() { } final var lookup = MethodHandles.lookup(); final int runtimeVersion = Runtime.version().feature(); - if (runtimeVersion == 19 || runtimeVersion == 20) { + if (runtimeVersion >= 19 && runtimeVersion <= 21) { try { final var cls = lookup.findClass("org.apache.lucene.store.MemorySegmentIndexInputProvider"); // we use method handles, so we do not need to deal with setAccessible as we have private @@ -437,9 +437,9 @@ private static MMapIndexInputProvider lookupProvider() { throw new LinkageError( "MemorySegmentIndexInputProvider is missing in Lucene JAR file", cnfe); } - } else if (runtimeVersion >= 21) { + } else if (runtimeVersion >= 22) { LOG.warning( - "You are running with Java 21 or later. To make full use of MMapDirectory, please update Apache Lucene."); + "You are running with Java 22 or later. To make full use of MMapDirectory, please update Apache Lucene."); } return new MappedByteBufferIndexInputProvider(); } diff --git a/lucene/core/src/java/org/apache/lucene/util/BitSet.java b/lucene/core/src/java/org/apache/lucene/util/BitSet.java index f8b8ba65a591..c5d84833b28a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/BitSet.java +++ b/lucene/core/src/java/org/apache/lucene/util/BitSet.java @@ -43,6 +43,16 @@ public static BitSet of(DocIdSetIterator it, int maxDoc) throws IOException { return set; } + /** + * Clear all the bits of the set. + * + *

    Depending on the implementation, this may be significantly faster than clear(0, length). + */ + public void clear() { + // default implementation for compatibility + clear(0, length()); + } + /** Set the bit at i. */ public abstract void set(int i); diff --git a/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java b/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java index 5566bd0c4833..ebf626a777da 100644 --- a/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java +++ b/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java @@ -147,6 +147,11 @@ public FixedBitSet(long[] storedBits, int numBits) { assert verifyGhostBitsClear(); } + @Override + public void clear() { + Arrays.fill(bits, 0L); + } + /** * Checks if the bits past numBits are clear. Some methods rely on this implicit assumption: * search for "Depends on the ghost bits being clear!" diff --git a/lucene/core/src/java/org/apache/lucene/util/SparseFixedBitSet.java b/lucene/core/src/java/org/apache/lucene/util/SparseFixedBitSet.java index 49d61614e86d..b4ebe3cfc59a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/SparseFixedBitSet.java +++ b/lucene/core/src/java/org/apache/lucene/util/SparseFixedBitSet.java @@ -17,6 +17,7 @@ package org.apache.lucene.util; import java.io.IOException; +import java.util.Arrays; import org.apache.lucene.search.DocIdSetIterator; /** @@ -73,6 +74,17 @@ public SparseFixedBitSet(int length) { + RamUsageEstimator.shallowSizeOf(bits); } + @Override + public void clear() { + Arrays.fill(bits, null); + Arrays.fill(indices, 0L); + nonZeroLongCount = 0; + ramBytesUsed = + BASE_RAM_BYTES_USED + + RamUsageEstimator.sizeOf(indices) + + RamUsageEstimator.shallowSizeOf(bits); + } + @Override public int length() { return length; diff --git a/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java b/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java new file mode 100644 index 000000000000..1ade19a19803 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/TermAndVector.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.util; + +import java.util.Locale; + +/** + * Word2Vec unit composed of a term and its associated vector + * + * @lucene.experimental + */ +public class TermAndVector { + + private final BytesRef term; + private final float[] vector; + + public TermAndVector(BytesRef term, float[] vector) { + this.term = term; + this.vector = vector; + } + + public BytesRef getTerm() { + return this.term; + } + + public float[] getVector() { + return this.vector; + } + + public int size() { + return vector.length; + } + + public void normalizeVector() { + float vectorLength = 0; + for (int i = 0; i < vector.length; i++) { + vectorLength += vector[i] * vector[i]; + } + vectorLength = (float) Math.sqrt(vectorLength); + for (int i = 0; i < vector.length; i++) { + vector[i] /= vectorLength; + } + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(this.term.utf8ToString()); + builder.append(" ["); + if (vector.length > 0) { + for (int i = 0; i < vector.length - 1; i++) { + builder.append(String.format(Locale.ROOT, "%.3f,", vector[i])); + } + builder.append(String.format(Locale.ROOT, "%.3f]", vector[vector.length - 1])); + } + return builder.toString(); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/UnicodeUtil.java b/lucene/core/src/java/org/apache/lucene/util/UnicodeUtil.java index a9e9d815eb49..9cbd4910271b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/UnicodeUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/UnicodeUtil.java @@ -477,40 +477,67 @@ public static int UTF8toUTF32(final BytesRef utf8, final int[] ints) { int utf8Upto = utf8.offset; final byte[] bytes = utf8.bytes; final int utf8Limit = utf8.offset + utf8.length; + UTF8CodePoint reuse = null; while (utf8Upto < utf8Limit) { - final int numBytes = utf8CodeLength[bytes[utf8Upto] & 0xFF]; - int v = 0; - switch (numBytes) { - case 1: - ints[utf32Count++] = bytes[utf8Upto++]; - continue; - case 2: - // 5 useful bits - v = bytes[utf8Upto++] & 31; - break; - case 3: - // 4 useful bits - v = bytes[utf8Upto++] & 15; - break; - case 4: - // 3 useful bits - v = bytes[utf8Upto++] & 7; - break; - default: - throw new IllegalArgumentException("invalid utf8"); - } - - // TODO: this may read past utf8's limit. - final int limit = utf8Upto + numBytes - 1; - while (utf8Upto < limit) { - v = v << 6 | bytes[utf8Upto++] & 63; - } - ints[utf32Count++] = v; + reuse = codePointAt(bytes, utf8Upto, reuse); + ints[utf32Count++] = reuse.codePoint; + utf8Upto += reuse.numBytes; } return utf32Count; } + /** + * Computes the codepoint and codepoint length (in bytes) at the specified {@code pos} in the + * provided {@code utf8} byte array, assuming UTF8 encoding. As with other related methods in this + * class, this assumes valid UTF8 input and does not perform full UTF8 + * validation. Passing invalid UTF8 or a position that is not a valid header byte position may + * result in undefined behavior. This makes no attempt to synchronize or validate.
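The reuse pattern is visible in the rewritten UTF8toUTF32 loop above; distilled to a sketch (utf8Bytes and utf8Len stand in for the caller's buffer and its length):

    UnicodeUtil.UTF8CodePoint reuse = null;
    for (int pos = 0; pos < utf8Len; ) {
      reuse = UnicodeUtil.codePointAt(utf8Bytes, pos, reuse); // one holder object, reused
      int cp = reuse.codePoint;                               // decoded codepoint
      pos += reuse.numBytes;                                  // advance by its encoded length
    }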
+ */ + public static UTF8CodePoint codePointAt(byte[] utf8, int pos, UTF8CodePoint reuse) { + if (reuse == null) { + reuse = new UTF8CodePoint(); + } + + int leadByte = utf8[pos] & 0xFF; + int numBytes = utf8CodeLength[leadByte]; + reuse.numBytes = numBytes; + int v; + switch (numBytes) { + case 1: + reuse.codePoint = leadByte; + return reuse; + case 2: + v = leadByte & 31; // 5 useful bits + break; + case 3: + v = leadByte & 15; // 4 useful bits + break; + case 4: + v = leadByte & 7; // 3 useful bits + break; + default: + throw new IllegalArgumentException( + "Invalid UTF8 header byte: 0x" + Integer.toHexString(leadByte)); + } + + // TODO: this may read past utf8's limit. + final int limit = pos + numBytes; + pos++; + while (pos < limit) { + v = v << 6 | utf8[pos++] & 63; + } + reuse.codePoint = v; + + return reuse; + } + + /** Holds a codepoint along with the number of bytes required to represent it in UTF8 */ + public static final class UTF8CodePoint { + public int codePoint; + public int numBytes; + } + /** Shift value for lead surrogate to form a supplementary character. */ private static final int LEAD_SURROGATE_SHIFT_ = 10; /** Mask to retrieve the significant value from a trail surrogate. */ diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 2a08436ec0b0..0921bb75a664 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -20,6 +20,8 @@ /** Utilities for computations with numeric arrays */ public final class VectorUtil { + private static final VectorUtilProvider PROVIDER = VectorUtilProvider.lookup(false); + private VectorUtil() {} /** @@ -31,68 +33,9 @@ public static float dotProduct(float[] a, float[] b) { if (a.length != b.length) { throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - float res = 0f; - /* - * If length of vector is larger than 8, we use unrolled dot product to accelerate the - * calculation. 
- */ - int i; - for (i = 0; i < a.length % 8; i++) { - res += b[i] * a[i]; - } - if (a.length < 8) { - return res; - } - for (; i + 31 < a.length; i += 32) { - res += - b[i + 0] * a[i + 0] - + b[i + 1] * a[i + 1] - + b[i + 2] * a[i + 2] - + b[i + 3] * a[i + 3] - + b[i + 4] * a[i + 4] - + b[i + 5] * a[i + 5] - + b[i + 6] * a[i + 6] - + b[i + 7] * a[i + 7]; - res += - b[i + 8] * a[i + 8] - + b[i + 9] * a[i + 9] - + b[i + 10] * a[i + 10] - + b[i + 11] * a[i + 11] - + b[i + 12] * a[i + 12] - + b[i + 13] * a[i + 13] - + b[i + 14] * a[i + 14] - + b[i + 15] * a[i + 15]; - res += - b[i + 16] * a[i + 16] - + b[i + 17] * a[i + 17] - + b[i + 18] * a[i + 18] - + b[i + 19] * a[i + 19] - + b[i + 20] * a[i + 20] - + b[i + 21] * a[i + 21] - + b[i + 22] * a[i + 22] - + b[i + 23] * a[i + 23]; - res += - b[i + 24] * a[i + 24] - + b[i + 25] * a[i + 25] - + b[i + 26] * a[i + 26] - + b[i + 27] * a[i + 27] - + b[i + 28] * a[i + 28] - + b[i + 29] * a[i + 29] - + b[i + 30] * a[i + 30] - + b[i + 31] * a[i + 31]; - } - for (; i + 7 < a.length; i += 8) { - res += - b[i + 0] * a[i + 0] - + b[i + 1] * a[i + 1] - + b[i + 2] * a[i + 2] - + b[i + 3] * a[i + 3] - + b[i + 4] * a[i + 4] - + b[i + 5] * a[i + 5] - + b[i + 6] * a[i + 6] - + b[i + 7] * a[i + 7]; - } - return res; + float r = PROVIDER.dotProduct(a, b); + assert Float.isFinite(r); + return r; } /** @@ -100,42 +43,21 @@ public static float dotProduct(float[] a, float[] b) { * * @throws IllegalArgumentException if the vectors' dimensions differ. */ - public static float cosine(float[] v1, float[] v2) { - if (v1.length != v2.length) { - throw new IllegalArgumentException( - "vector dimensions differ: " + v1.length + "!=" + v2.length); - } - - float sum = 0.0f; - float norm1 = 0.0f; - float norm2 = 0.0f; - int dim = v1.length; - - for (int i = 0; i < dim; i++) { - float elem1 = v1[i]; - float elem2 = v2[i]; - sum += elem1 * elem2; - norm1 += elem1 * elem1; - norm2 += elem2 * elem2; + public static float cosine(float[] a, float[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return (float) (sum / Math.sqrt(norm1 * norm2)); + float r = PROVIDER.cosine(a, b); + assert Float.isFinite(r); + return r; } /** Returns the cosine similarity between the two vectors. */ public static float cosine(byte[] a, byte[] b) { - // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. - int sum = 0; - int norm1 = 0; - int norm2 = 0; - - for (int i = 0; i < a.length; i++) { - byte elem1 = a[i]; - byte elem2 = b[i]; - sum += elem1 * elem2; - norm1 += elem1 * elem1; - norm2 += elem2 * elem2; + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); + return PROVIDER.cosine(a, b); } /** @@ -143,52 +65,21 @@ public static float cosine(byte[] a, byte[] b) { * * @throws IllegalArgumentException if the vectors' dimensions differ. 
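Taken together, each public entry point in this file now keeps its dimension check and delegates the arithmetic to the looked-up provider. A small usage sketch (the vectors are illustrative):

    float[] q = {0.1f, 0.2f, 0.3f};
    float[] d = {0.3f, 0.1f, 0.0f};
    float dot = VectorUtil.dotProduct(q, d); // delegates to the looked-up provider
    float cos = VectorUtil.cosine(q, d);     // likewise
    // Either call throws IllegalArgumentException if q.length != d.length.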
*/ - public static float squareDistance(float[] v1, float[] v2) { - if (v1.length != v2.length) { - throw new IllegalArgumentException( - "vector dimensions differ: " + v1.length + "!=" + v2.length); - } - float squareSum = 0.0f; - int dim = v1.length; - int i; - for (i = 0; i + 8 <= dim; i += 8) { - squareSum += squareDistanceUnrolled(v1, v2, i); - } - for (; i < dim; i++) { - float diff = v1[i] - v2[i]; - squareSum += diff * diff; + public static float squareDistance(float[] a, float[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return squareSum; - } - - private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) { - float diff0 = v1[index + 0] - v2[index + 0]; - float diff1 = v1[index + 1] - v2[index + 1]; - float diff2 = v1[index + 2] - v2[index + 2]; - float diff3 = v1[index + 3] - v2[index + 3]; - float diff4 = v1[index + 4] - v2[index + 4]; - float diff5 = v1[index + 5] - v2[index + 5]; - float diff6 = v1[index + 6] - v2[index + 6]; - float diff7 = v1[index + 7] - v2[index + 7]; - return diff0 * diff0 - + diff1 * diff1 - + diff2 * diff2 - + diff3 * diff3 - + diff4 * diff4 - + diff5 * diff5 - + diff6 * diff6 - + diff7 * diff7; + float r = PROVIDER.squareDistance(a, b); + assert Float.isFinite(r); + return r; } /** Returns the sum of squared differences of the two vectors. */ public static int squareDistance(byte[] a, byte[] b) { - // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. - int squareSum = 0; - for (int i = 0; i < a.length; i++) { - int diff = a[i] - b[i]; - squareSum += diff * diff; + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return squareSum; + return PROVIDER.squareDistance(a, b); } /** @@ -250,12 +141,10 @@ public static void add(float[] u, float[] v) { * @return the value of the dot product of the two vectors */ public static int dotProduct(byte[] a, byte[] b) { - assert a.length == b.length; - int total = 0; - for (int i = 0; i < a.length; i++) { - total += a[i] * b[i]; + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); } - return total; + return PROVIDER.dotProduct(a, b); } /** @@ -270,4 +159,20 @@ public static float dotProductScore(byte[] a, byte[] b) { float denom = (float) (a.length * (1 << 15)); return 0.5f + dotProduct(a, b) / denom; } + + /** + * Checks if a float vector only has finite components. + * + * @param v the vector to check + * @return the vector for call-chaining + * @throws IllegalArgumentException if any component of vector is not finite + */ + public static float[] checkFinite(float[] v) { + for (int i = 0; i < v.length; i++) { + if (!Float.isFinite(v[i])) { + throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + v[i]); + } + } + return v; + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java new file mode 100644 index 000000000000..665181e86788 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilDefaultProvider.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util; + +/** The default VectorUtil provider implementation. */ +final class VectorUtilDefaultProvider implements VectorUtilProvider { + + VectorUtilDefaultProvider() {} + + @Override + public float dotProduct(float[] a, float[] b) { + float res = 0f; + /* + * If length of vector is larger than 8, we use unrolled dot product to accelerate the + * calculation. + */ + int i; + for (i = 0; i < a.length % 8; i++) { + res += b[i] * a[i]; + } + if (a.length < 8) { + return res; + } + for (; i + 31 < a.length; i += 32) { + res += + b[i + 0] * a[i + 0] + + b[i + 1] * a[i + 1] + + b[i + 2] * a[i + 2] + + b[i + 3] * a[i + 3] + + b[i + 4] * a[i + 4] + + b[i + 5] * a[i + 5] + + b[i + 6] * a[i + 6] + + b[i + 7] * a[i + 7]; + res += + b[i + 8] * a[i + 8] + + b[i + 9] * a[i + 9] + + b[i + 10] * a[i + 10] + + b[i + 11] * a[i + 11] + + b[i + 12] * a[i + 12] + + b[i + 13] * a[i + 13] + + b[i + 14] * a[i + 14] + + b[i + 15] * a[i + 15]; + res += + b[i + 16] * a[i + 16] + + b[i + 17] * a[i + 17] + + b[i + 18] * a[i + 18] + + b[i + 19] * a[i + 19] + + b[i + 20] * a[i + 20] + + b[i + 21] * a[i + 21] + + b[i + 22] * a[i + 22] + + b[i + 23] * a[i + 23]; + res += + b[i + 24] * a[i + 24] + + b[i + 25] * a[i + 25] + + b[i + 26] * a[i + 26] + + b[i + 27] * a[i + 27] + + b[i + 28] * a[i + 28] + + b[i + 29] * a[i + 29] + + b[i + 30] * a[i + 30] + + b[i + 31] * a[i + 31]; + } + for (; i + 7 < a.length; i += 8) { + res += + b[i + 0] * a[i + 0] + + b[i + 1] * a[i + 1] + + b[i + 2] * a[i + 2] + + b[i + 3] * a[i + 3] + + b[i + 4] * a[i + 4] + + b[i + 5] * a[i + 5] + + b[i + 6] * a[i + 6] + + b[i + 7] * a[i + 7]; + } + return res; + } + + @Override + public float cosine(float[] a, float[] b) { + float sum = 0.0f; + float norm1 = 0.0f; + float norm2 = 0.0f; + int dim = a.length; + + for (int i = 0; i < dim; i++) { + float elem1 = a[i]; + float elem2 = b[i]; + sum += elem1 * elem2; + norm1 += elem1 * elem1; + norm2 += elem2 * elem2; + } + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); + } + + @Override + public float squareDistance(float[] a, float[] b) { + float squareSum = 0.0f; + int dim = a.length; + int i; + for (i = 0; i + 8 <= dim; i += 8) { + squareSum += squareDistanceUnrolled(a, b, i); + } + for (; i < dim; i++) { + float diff = a[i] - b[i]; + squareSum += diff * diff; + } + return squareSum; + } + + private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) { + float diff0 = v1[index + 0] - v2[index + 0]; + float diff1 = v1[index + 1] - v2[index + 1]; + float diff2 = v1[index + 2] - v2[index + 2]; + float diff3 = v1[index + 3] - v2[index + 3]; + float diff4 = v1[index + 4] - v2[index + 4]; + float diff5 = v1[index + 5] - v2[index + 5]; + float diff6 = v1[index + 6] - v2[index + 6]; + float diff7 = v1[index + 7] - v2[index + 7]; + return diff0 * diff0 + + diff1 * diff1 + + diff2 * diff2 + + diff3 * diff3 + + diff4 * diff4 + + diff5 * diff5 + + diff6 * diff6 + + diff7 * 
diff7; + } + + @Override + public int dotProduct(byte[] a, byte[] b) { + int total = 0; + for (int i = 0; i < a.length; i++) { + total += a[i] * b[i]; + } + return total; + } + + @Override + public float cosine(byte[] a, byte[] b) { + // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. + int sum = 0; + int norm1 = 0; + int norm2 = 0; + + for (int i = 0; i < a.length; i++) { + byte elem1 = a[i]; + byte elem2 = b[i]; + sum += elem1 * elem2; + norm1 += elem1 * elem1; + norm2 += elem2 * elem2; + } + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); + } + + @Override + public int squareDistance(byte[] a, byte[] b) { + // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. + int squareSum = 0; + for (int i = 0; i < a.length; i++) { + int diff = a[i] - b[i]; + squareSum += diff * diff; + } + return squareSum; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtilProvider.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtilProvider.java new file mode 100644 index 000000000000..3fd29c2bf349 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtilProvider.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util; + +import java.lang.Runtime.Version; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Locale; +import java.util.Objects; +import java.util.logging.Logger; + +/** A provider of VectorUtil implementations. */ +interface VectorUtilProvider { + + /** Calculates the dot product of the given float arrays. */ + float dotProduct(float[] a, float[] b); + + /** Returns the cosine similarity between the two vectors. */ + float cosine(float[] v1, float[] v2); + + /** Returns the sum of squared differences of the two vectors. */ + float squareDistance(float[] a, float[] b); + + /** Returns the dot product computed over signed bytes. */ + int dotProduct(byte[] a, byte[] b); + + /** Returns the cosine similarity between the two byte vectors. */ + float cosine(byte[] a, byte[] b); + + /** Returns the sum of squared differences of the two byte vectors. */ + int squareDistance(byte[] a, byte[] b); + + // -- provider lookup mechanism + + static final Logger LOG = Logger.getLogger(VectorUtilProvider.class.getName()); + + /** The minimal version of Java that has the bugfix for JDK-8301190. 
*/ + static final Version VERSION_JDK8301190_FIXED = Version.parse("20.0.2"); + + static VectorUtilProvider lookup(boolean testMode) { + final int runtimeVersion = Runtime.version().feature(); + if (runtimeVersion >= 20 && runtimeVersion <= 21) { + // is locale sane (only buggy in Java 20) + if (isAffectedByJDK8301190()) { + LOG.warning( + "Java runtime is using a buggy default locale; Java vector incubator API can't be enabled: " + + Locale.getDefault()); + return new VectorUtilDefaultProvider(); + } + // is the incubator module present and readable (JVM vendors may exclude it, or the runtime + // may have been built with jlink) + if (!vectorModulePresentAndReadable()) { + LOG.warning( + "Java vector incubator module is not readable. For optimal vector performance, pass '--add-modules jdk.incubator.vector' to enable Vector API."); + return new VectorUtilDefaultProvider(); + } + if (!testMode && isClientVM()) { + LOG.warning("C2 compiler is disabled; Java vector incubator API can't be enabled"); + return new VectorUtilDefaultProvider(); + } + try { + // we use method handles with lookup, so we do not need to deal with setAccessible as we + // have private access through the lookup: + final var lookup = MethodHandles.lookup(); + final var cls = lookup.findClass("org.apache.lucene.util.VectorUtilPanamaProvider"); + final var constr = + lookup.findConstructor(cls, MethodType.methodType(void.class, boolean.class)); + try { + return (VectorUtilProvider) constr.invoke(testMode); + } catch (UnsupportedOperationException uoe) { + // not supported because preferred vector size too small or similar + LOG.warning("Java vector incubator API was not enabled. " + uoe.getMessage()); + return new VectorUtilDefaultProvider(); + } catch (RuntimeException | Error e) { + throw e; + } catch (Throwable th) { + throw new AssertionError(th); + } } catch (NoSuchMethodException | IllegalAccessException e) { + throw new LinkageError( + "VectorUtilPanamaProvider is missing correctly typed constructor", e); + } catch (ClassNotFoundException cnfe) { + throw new LinkageError("VectorUtilPanamaProvider is missing in Lucene JAR file", cnfe); + } + } else if (runtimeVersion >= 22) { + LOG.warning( + "You are running with Java 22 or later. To make full use of the Vector API, please update Apache Lucene."); + } + return new VectorUtilDefaultProvider(); + } + + private static boolean vectorModulePresentAndReadable() { + var opt = + ModuleLayer.boot().modules().stream() + .filter(m -> m.getName().equals("jdk.incubator.vector")) + .findFirst(); + if (opt.isPresent()) { + VectorUtilProvider.class.getModule().addReads(opt.get()); + return true; + } + return false; + } + + /** + * Check if runtime is affected by JDK-8301190 (avoids an assertion when the default language is, + * say, "tr"). + */ + private static boolean isAffectedByJDK8301190() { + return VERSION_JDK8301190_FIXED.compareToIgnoreOptional(Runtime.version()) > 0 + && !Objects.equals("I", "i".toUpperCase(Locale.getDefault())); + } + + @SuppressWarnings("removal") + @SuppressForbidden(reason = "security manager") + private static boolean isClientVM() { + try { + final PrivilegedAction<Boolean> action = + () -> System.getProperty("java.vm.info", "").contains("emulated-client"); + return AccessController.doPrivileged(action); + } catch ( + @SuppressWarnings("unused") + SecurityException e) { + LOG.warning( + "SecurityManager denies permission to 'java.vm.info' system property, so state of C2 compiler can't be detected.
" + + "In case of performance issues allow access to this property."); + return false; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/Version.java b/lucene/core/src/java/org/apache/lucene/util/Version.java index 3bc4be34d5a1..bd37efbd8187 100644 --- a/lucene/core/src/java/org/apache/lucene/util/Version.java +++ b/lucene/core/src/java/org/apache/lucene/util/Version.java @@ -259,11 +259,16 @@ public final class Version { @Deprecated public static final Version LUCENE_9_5_0 = new Version(9, 5, 0); /** - * Match settings and bugs in Lucene's 9.6.0 release. + * @deprecated (9.7.0) Use latest + */ + @Deprecated public static final Version LUCENE_9_6_0 = new Version(9, 6, 0); + + /** + * Match settings and bugs in Lucene's 9.7.0 release. * *

    Use this to get the latest & greatest settings, bug fixes, etc, for Lucene. */ - public static final Version LUCENE_9_6_0 = new Version(9, 6, 0); + public static final Version LUCENE_9_7_0 = new Version(9, 7, 0); // To add a new version: // * Only add above this comment @@ -279,7 +284,7 @@ public final class Version { * re-test your entire application to ensure it behaves as expected, as some defaults may * have changed and may break functionality in your application. */ - public static final Version LATEST = LUCENE_9_6_0; + public static final Version LATEST = LUCENE_9_7_0; /** * Constant for backwards compatibility. diff --git a/lucene/core/src/java/org/apache/lucene/util/VirtualMethod.java b/lucene/core/src/java/org/apache/lucene/util/VirtualMethod.java index 05eef2a86617..a7c2a71cc1c8 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VirtualMethod.java +++ b/lucene/core/src/java/org/apache/lucene/util/VirtualMethod.java @@ -17,6 +17,8 @@ package org.apache.lucene.util; import java.lang.reflect.Method; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.Collections; import java.util.HashSet; import java.util.Set; @@ -49,13 +51,20 @@ * *

      *  final boolean isDeprecatedMethodOverridden =
    - *   oldMethod.getImplementationDistance(this.getClass()) > newMethod.getImplementationDistance(this.getClass());
    + *   AccessController.doPrivileged((PrivilegedAction<Boolean>) () ->
    + *    (oldMethod.getImplementationDistance(this.getClass()) > newMethod.getImplementationDistance(this.getClass())));
      *
      *  // alternatively (more readable):
      *  final boolean isDeprecatedMethodOverridden =
    - *   VirtualMethod.compareImplementationDistance(this.getClass(), oldMethod, newMethod) > 0
    + *   AccessController.doPrivileged((PrivilegedAction<Boolean>) () ->
    + *    VirtualMethod.compareImplementationDistance(this.getClass(), oldMethod, newMethod) > 0);
      * 
    * + *

It is important to use {@link AccessController#doPrivileged(PrivilegedAction)} for the actual + * call to get the implementation distance because the subclass may be in a different package. The + * static constructors do not need to use {@code AccessController} because they just initialize our + * own method reference. The caller should have access to all declared members in its own class. + *
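For context, such VirtualMethod constants are declared once per method as static singletons, per the class contract; a hedged sketch (MyBase is a hypothetical base class declaring both a deprecated doWork() and a replacement doWork(String)):

    private static final VirtualMethod<MyBase> oldMethod =
        new VirtualMethod<>(MyBase.class, "doWork");
    private static final VirtualMethod<MyBase> newMethod =
        new VirtualMethod<>(MyBase.class, "doWork", String.class);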

{@link #getImplementationDistance} returns the distance of the subclass that overrides this * method. The one with the larger distance should be used preferably. This way also more * complicated method rename scenarios can be handled (think of 2.9 {@code TokenStream} diff --git a/lucene/core/src/java/org/apache/lucene/util/automaton/Automata.java b/lucene/core/src/java/org/apache/lucene/util/automaton/Automata.java index 9a642338f09b..c829429d3c9b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/automaton/Automata.java +++ b/lucene/core/src/java/org/apache/lucene/util/automaton/Automata.java @@ -29,9 +29,11 @@ package org.apache.lucene.util.automaton; +import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefIterator; import org.apache.lucene.util.StringHelper; /** @@ -40,6 +42,11 @@ * @lucene.experimental */ public final class Automata { + /** + * {@link #makeStringUnion(Collection)} limits terms to this max length to ensure the stack + * doesn't overflow while building, since our algorithm currently relies on recursion. + */ + public static final int MAX_STRING_UNION_TERM_LENGTH = 1000; private Automata() {} @@ -573,7 +580,49 @@ public static Automaton makeStringUnion(Collection<BytesRef> utf8Strings) { if (utf8Strings.isEmpty()) { return makeEmpty(); } else { - return DaciukMihovAutomatonBuilder.build(utf8Strings); + return DaciukMihovAutomatonBuilder.build(utf8Strings, false); + } + } + + /** + * Returns a new (deterministic and minimal) automaton that accepts the union of the given + * collection of {@link BytesRef}s representing UTF-8 encoded strings. The resulting automaton + * will be built in a binary representation. + * + * @param utf8Strings The input strings, UTF-8 encoded. The collection must be in sorted order. + * @return An {@link Automaton} accepting all input strings. The resulting automaton is binary + * based (UTF-8 encoded byte transition labels). + */ + public static Automaton makeBinaryStringUnion(Collection<BytesRef> utf8Strings) { + if (utf8Strings.isEmpty()) { + return makeEmpty(); + } else { + return DaciukMihovAutomatonBuilder.build(utf8Strings, true); } } + + /** + * Returns a new (deterministic and minimal) automaton that accepts the union of the given + * iterator of {@link BytesRef}s representing UTF-8 encoded strings. + * + * @param utf8Strings The input strings, UTF-8 encoded. The iterator must be in sorted order. + * @return An {@link Automaton} accepting all input strings. The resulting automaton is codepoint + * based (full unicode codepoints on transitions). + */ + public static Automaton makeStringUnion(BytesRefIterator utf8Strings) throws IOException { + return DaciukMihovAutomatonBuilder.build(utf8Strings, false); + } + + /** + * Returns a new (deterministic and minimal) automaton that accepts the union of the given + * iterator of {@link BytesRef}s representing UTF-8 encoded strings. The resulting automaton will + * be built in a binary representation. + * + * @param utf8Strings The input strings, UTF-8 encoded. The iterator must be in sorted order. + * @return An {@link Automaton} accepting all input strings. The resulting automaton is binary + * based (UTF-8 encoded byte transition labels).
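A usage sketch tying these new factory methods back to the TermInSetQuery change earlier in this patch (the terms are illustrative; input must be binary-sorted):

    import java.util.List;
    import org.apache.lucene.util.BytesRef;
    import org.apache.lucene.util.automaton.Automata;
    import org.apache.lucene.util.automaton.Automaton;
    import org.apache.lucene.util.automaton.ByteRunAutomaton;
    import org.apache.lucene.util.automaton.Operations;

    List<BytesRef> sorted = List.of(new BytesRef("bar"), new BytesRef("baz"), new BytesRef("foo"));
    Automaton a = Automata.makeBinaryStringUnion(sorted);
    ByteRunAutomaton run = new ByteRunAutomaton(a, true, Operations.DEFAULT_DETERMINIZE_WORK_LIMIT);
    BytesRef probe = new BytesRef("baz");
    boolean matches = run.run(probe.bytes, probe.offset, probe.length); // true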
+ */ + public static Automaton makeBinaryStringUnion(BytesRefIterator utf8Strings) throws IOException { + return DaciukMihovAutomatonBuilder.build(utf8Strings, true); + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/automaton/DaciukMihovAutomatonBuilder.java b/lucene/core/src/java/org/apache/lucene/util/automaton/DaciukMihovAutomatonBuilder.java index 94002b04a40b..7048d4538d15 100644 --- a/lucene/core/src/java/org/apache/lucene/util/automaton/DaciukMihovAutomatonBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/automaton/DaciukMihovAutomatonBuilder.java @@ -16,14 +16,15 @@ */ package org.apache.lucene.util.automaton; +import java.io.IOException; import java.util.Arrays; import java.util.Collection; -import java.util.Comparator; import java.util.HashMap; import java.util.IdentityHashMap; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.CharsRef; +import org.apache.lucene.util.BytesRefBuilder; +import org.apache.lucene.util.BytesRefIterator; import org.apache.lucene.util.UnicodeUtil; /** @@ -32,14 +33,22 @@ * * @see #build(Collection) * @see Automata#makeStringUnion(Collection) + * @see Automata#makeBinaryStringUnion(Collection) + * @see Automata#makeStringUnion(BytesRefIterator) + * @see Automata#makeBinaryStringUnion(BytesRefIterator) + * @deprecated Visibility of this class will be reduced in a future release. Users can access this + * functionality directly through {@link Automata#makeStringUnion(Collection)} */ +@Deprecated public final class DaciukMihovAutomatonBuilder { /** * This builder rejects terms that are more than 1k chars long since it then uses recursion based * on the length of the string, which might cause stack overflows. + * + * @deprecated See {@link Automata#MAX_STRING_UNION_TERM_LENGTH} */ - public static final int MAX_TERM_LENGTH = 1_000; + @Deprecated public static final int MAX_TERM_LENGTH = 1_000; /** The default constructor is private. Use static methods directly. */ private DaciukMihovAutomatonBuilder() { @@ -179,56 +188,18 @@ private static boolean referenceEquals(Object[] a1, Object[] a2) { private HashMap stateRegistry = new HashMap<>(); /** Root automaton state. */ - private State root = new State(); - - /** Previous sequence added to the automaton in {@link #add(CharsRef)}. */ - private CharsRef previous; + private final State root = new State(); - /** A comparator used for enforcing sorted UTF8 order, used in assertions only. */ - @SuppressWarnings("deprecation") - private static final Comparator comparator = CharsRef.getUTF16SortedAsUTF8Comparator(); + /** Used for input order checking (only through assertions right now) */ + private BytesRefBuilder previous; - /** - * Add another character sequence to this automaton. The sequence must be lexicographically larger - * or equal compared to any previous sequences added to this automaton (the input must be sorted). - */ - public void add(CharsRef current) { - if (current.length > MAX_TERM_LENGTH) { - throw new IllegalArgumentException( - "This builder doesn't allow terms that are larger than 1,000 characters, got " + current); - } - assert stateRegistry != null : "Automaton already built."; - assert previous == null || comparator.compare(previous, current) <= 0 - : "Input must be in sorted UTF-8 order: " + previous + " >= " + current; - assert setPrevious(current); - - // Descend in the automaton (find matching prefix). 
- int pos = 0, max = current.length(); - State next, state = root; - while (pos < max && (next = state.lastChild(Character.codePointAt(current, pos))) != null) { - state = next; - // todo, optimize me - pos += Character.charCount(Character.codePointAt(current, pos)); + /** Copy current into an internal buffer. */ + private boolean setPrevious(BytesRef current) { + if (previous == null) { + previous = new BytesRefBuilder(); } - - if (state.hasChildren()) replaceOrRegister(state); - - addSuffix(state, current, pos); - } - - /** - * Finalize the automaton and return the root state. No more strings can be added to the builder - * after this call. - * - * @return Root automaton state. - */ - public State complete() { - if (this.stateRegistry == null) throw new IllegalStateException(); - - if (root.hasChildren()) replaceOrRegister(root); - - stateRegistry = null; - return root; + previous.copyBytes(current); + return true; } /** Internal recursive traversal for conversion. */ @@ -253,35 +224,116 @@ private static int convert( return converted; } + /** + * Called after adding all terms. Performs final minimization and converts to a standard {@link + * Automaton} instance. + */ + private Automaton completeAndConvert() { + // Final minimization: + if (this.stateRegistry == null) throw new IllegalStateException(); + if (root.hasChildren()) replaceOrRegister(root); + stateRegistry = null; + + // Convert: + Automaton.Builder a = new Automaton.Builder(); + convert(a, root, new IdentityHashMap<>()); + return a.finish(); + } + /** * Build a minimal, deterministic automaton from a sorted list of {@link BytesRef} representing * strings in UTF-8. These strings must be binary-sorted. + * + * @deprecated Please see {@link Automata#makeStringUnion(Collection)} instead */ + @Deprecated public static Automaton build(Collection input) { + return build(input, false); + } + + /** + * Build a minimal, deterministic automaton from a sorted list of {@link BytesRef} representing + * strings in UTF-8. These strings must be binary-sorted. + */ + static Automaton build(Collection input, boolean asBinary) { final DaciukMihovAutomatonBuilder builder = new DaciukMihovAutomatonBuilder(); - char[] chars = new char[0]; - CharsRef ref = new CharsRef(); for (BytesRef b : input) { - chars = ArrayUtil.grow(chars, b.length); - final int len = UnicodeUtil.UTF8toUTF16(b, chars); - ref.chars = chars; - ref.length = len; - builder.add(ref); + builder.add(b, asBinary); } - Automaton.Builder a = new Automaton.Builder(); - convert(a, builder.complete(), new IdentityHashMap()); + return builder.completeAndConvert(); + } - return a.finish(); + /** + * Build a minimal, deterministic automaton from a sorted list of {@link BytesRef} representing + * strings in UTF-8. These strings must be binary-sorted. Creates an {@link Automaton} with either + * UTF-8 codepoints as transition labels or binary (compiled) transition labels based on {@code + * asBinary}. + */ + static Automaton build(BytesRefIterator input, boolean asBinary) throws IOException { + final DaciukMihovAutomatonBuilder builder = new DaciukMihovAutomatonBuilder(); + + for (BytesRef b = input.next(); b != null; b = input.next()) { + builder.add(b, asBinary); + } + + return builder.completeAndConvert(); } - /** Copy current into an internal buffer. 
*/ - private boolean setPrevious(CharsRef current) { - // don't need to copy, once we fix https://issues.apache.org/jira/browse/LUCENE-3277 - // still, called only from assert - previous = CharsRef.deepCopyOf(current); - return true; + private void add(BytesRef current, boolean asBinary) { + if (current.length > Automata.MAX_STRING_UNION_TERM_LENGTH) { + throw new IllegalArgumentException( + "This builder doesn't allow terms that are larger than " + + Automata.MAX_STRING_UNION_TERM_LENGTH + + " characters, got " + + current); + } + assert stateRegistry != null : "Automaton already built."; + assert previous == null || previous.get().compareTo(current) <= 0 + : "Input must be in sorted UTF-8 order: " + previous.get() + " >= " + current; + assert setPrevious(current); + + // Reusable codepoint information if we're building a non-binary based automaton + UnicodeUtil.UTF8CodePoint codePoint = null; + + // Descend in the automaton (find matching prefix). + byte[] bytes = current.bytes; + int pos = current.offset, max = current.offset + current.length; + State next, state = root; + if (asBinary) { + while (pos < max && (next = state.lastChild(bytes[pos] & 0xff)) != null) { + state = next; + pos++; + } + } else { + while (pos < max) { + codePoint = UnicodeUtil.codePointAt(bytes, pos, codePoint); + next = state.lastChild(codePoint.codePoint); + if (next == null) { + break; + } + state = next; + pos += codePoint.numBytes; + } + } + + if (state.hasChildren()) replaceOrRegister(state); + + // Add suffix + if (asBinary) { + while (pos < max) { + state = state.newState(bytes[pos] & 0xff); + pos++; + } + } else { + while (pos < max) { + codePoint = UnicodeUtil.codePointAt(bytes, pos, codePoint); + state = state.newState(codePoint.codePoint); + pos += codePoint.numBytes; + } + } + state.is_final = true; } /** @@ -300,18 +352,4 @@ private void replaceOrRegister(State state) { stateRegistry.put(child, child); } } - - /** - * Add a suffix of current starting at fromIndex (inclusive) to state - * state. - */ - private void addSuffix(State state, CharSequence current, int fromIndex) { - final int len = current.length(); - while (fromIndex < len) { - int cp = Character.codePointAt(current, fromIndex); - state = state.newState(cp); - fromIndex += Character.charCount(cp); - } - state.is_final = true; - } } diff --git a/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java b/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java index 6db58e7119ea..9ebe5b998037 100644 --- a/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java +++ b/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java @@ -1135,7 +1135,7 @@ public static String getCommonPrefix(Automaton a) { FixedBitSet tmp = current; current = next; next = tmp; - next.clear(0, next.length()); + next.clear(); } return builder.toString(); } @@ -1311,9 +1311,22 @@ static Automaton totalize(Automaton a) { } /** - * Returns the topological sort of all states reachable from the initial state. Behavior is - * undefined if this automaton has cycles. CPU cost is O(numTransitions), and the implementation - * is recursive so an automaton matching long strings may exhaust the java stack. + * Returns the topological sort of all states reachable from the initial state. This method + * assumes that the automaton does not contain cycles, and will throw an IllegalArgumentException + * if a cycle is detected. 
The CPU cost is O(numTransitions), and the implementation is + * non-recursive, so it will not exhaust the java stack even for automata matching long strings. If + * there are dead states in the automaton, they will be removed from the returned array. + * + *

Note: This method uses a deque to iterate over the states, which could potentially consume a + * lot of heap space for some automata. Specifically, automata with a deep chain of states + * (i.e., a large number of transitions from the initial state to the final state) are + * particularly prone to high memory usage. The memory consumption of this method can be + * considered O(N), where N is the depth of the automaton (the maximum number of transitions + * from the initial state to any state). However, as this method detects cycles, it will never + * attempt to use infinite RAM. + * + * @param a the Automaton to be sorted + * @return the topologically sorted array of state ids */ public static int[] topoSortStates(Automaton a) { if (a.getNumStates() == 0) { @@ -1321,8 +1334,7 @@ public static int[] topoSortStates(Automaton a) { } int numStates = a.getNumStates(); int[] states = new int[numStates]; - final BitSet visited = new BitSet(numStates); - int upto = topoSortStatesRecurse(a, visited, states, 0, 0, 0); + int upto = topoSortStates(a, states); if (upto < states.length) { // There were dead states @@ -1341,24 +1353,49 @@ public static int[] topoSortStates(Automaton a) { return states; } - // TODO: not great that this is recursive... in theory a - // large automata could exceed java's stack so the maximum level of recursion is bounded to 1000 - private static int topoSortStatesRecurse( - Automaton a, BitSet visited, int[] states, int upto, int state, int level) { - if (level > MAX_RECURSION_LEVEL) { - throw new IllegalArgumentException("input automaton is too large: " + level); - } + /** + * Performs a topological sort on the states of the given Automaton. + * + * @param a The automaton whose states are to be topologically sorted. + * @param states An int array which stores the states. + * @return the number of states in the final sorted list. + * @throws IllegalArgumentException if the input automaton has a cycle. + */ + private static int topoSortStates(Automaton a, int[] states) { + BitSet onStack = new BitSet(a.getNumStates()); + BitSet visited = new BitSet(a.getNumStates()); + var stack = new ArrayDeque(); + stack.push(0); // Assuming that the initial state is 0.
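// (Editorial note, not part of the original patch: the loop below implements an
// iterative depth-first search with an explicit stack. Each pass peeks at the top
// state and pushes at most one unvisited successor; `onStack` marks the states on
// the current DFS path, so hitting a destination that is visited and still on
// that path is a back edge, i.e. a cycle. A state is popped and appended to
// `states` only after all of its successors have been fully explored.)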
+ int upto = 0; Transition t = new Transition(); - int count = a.initTransition(state, t); - for (int i = 0; i < count; i++) { - a.getNextTransition(t); - if (!visited.get(t.dest)) { - visited.set(t.dest); - upto = topoSortStatesRecurse(a, visited, states, upto, t.dest, level + 1); + + while (!stack.isEmpty()) { + int state = stack.peek(); // Just peek, don't remove the state yet + + int count = a.initTransition(state, t); + boolean pushed = false; + for (int i = 0; i < count; i++) { + a.getNextTransition(t); + if (!visited.get(t.dest)) { + visited.set(t.dest); + stack.push(t.dest); // Push the next unvisited state onto the stack + onStack.set(state); + pushed = true; + break; // Exit the loop, we'll continue from here in the next iteration + } else if (onStack.get(t.dest)) { + // If the state is on the current recursion stack, we have detected a cycle + throw new IllegalArgumentException("Input automaton has a cycle."); + } + } + + // If we haven't pushed any new state onto the stack, we're done with this state + if (!pushed) { + onStack.clear(state); // remove the node from the current recursion stack + stack.pop(); + states[upto] = state; + upto++; + } } - states[upto] = state; - upto++; return upto; } } diff --git a/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java b/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java index 6711dfb6230f..321c6ff133a0 100644 --- a/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java +++ b/lucene/core/src/java/org/apache/lucene/util/graph/GraphTokenStreamFiniteStrings.java @@ -45,6 +45,8 @@ * different paths of the {@link Automaton}. */ public final class GraphTokenStreamFiniteStrings { + /** Maximum level of recursion allowed in recursive operations. */ + private static final int MAX_RECURSION_LEVEL = 1000; private AttributeSource[] tokens = new AttributeSource[4]; private final Automaton det; @@ -271,7 +273,12 @@ private static void articulationPointsRecurse( a.getNextTransition(t); if (visited.get(t.dest) == false) { parent[t.dest] = state; - articulationPointsRecurse(a, t.dest, d + 1, depth, low, parent, visited, points); + if (d < MAX_RECURSION_LEVEL) { + articulationPointsRecurse(a, t.dest, d + 1, depth, low, parent, visited, points); + } else { + throw new IllegalArgumentException( + "Exceeded maximum recursion level during graph analysis"); + } childCount++; if (low[t.dest] >= depth[state]) { isArticulation = true; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index 2c5e84be2859..5738cc8c75d8 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -41,8 +41,17 @@ */ public final class HnswGraphBuilder { + /** Default number of maximum connections per node */ + public static final int DEFAULT_MAX_CONN = 16; + + /** + * Default size of the queue maintained while searching during graph construction.
+ */ + public static final int DEFAULT_BEAM_WIDTH = 100; + /** Default random seed for level generation * */ private static final long DEFAULT_RAND_SEED = 42; + /** A name for the HNSW component for the info-stream * */ public static final String HNSW_COMPONENT = "HNSW"; @@ -220,7 +229,9 @@ private void initializeFromGraph( binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor)); break; } - newNeighbors.insertSorted(newNeighbor, score); + // we are not sure whether the previous graph contains + // unchecked nodes, so we have to assume they're all unchecked + newNeighbors.addOutOfOrder(newNeighbor, score); } } } @@ -316,11 +327,11 @@ private void addDiverseNeighbors(int level, int node, NeighborQueue candidates) int size = neighbors.size(); for (int i = 0; i < size; i++) { int nbr = neighbors.node[i]; - NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr); - nbrNbr.insertSorted(node, neighbors.score[i]); - if (nbrNbr.size() > maxConnOnLevel) { - int indexToRemove = findWorstNonDiverse(nbrNbr); - nbrNbr.removeIndex(indexToRemove); + NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr); + nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]); + if (nbrsOfNbr.size() > maxConnOnLevel) { + int indexToRemove = findWorstNonDiverse(nbrsOfNbr); + nbrsOfNbr.removeIndex(indexToRemove); } } } @@ -335,7 +346,7 @@ private void selectAndLinkDiverse( float cScore = candidates.score[i]; assert cNode < hnsw.size(); if (diversityCheck(cNode, cScore, neighbors)) { - neighbors.add(cNode, cScore); + neighbors.addInOrder(cNode, cScore); } } } @@ -347,7 +358,7 @@ private void popToScratch(NeighborQueue candidates) { // sorted from worst to best for (int i = 0; i < candidateCount; i++) { float maxSimilarity = candidates.topScore(); - scratch.add(candidates.pop(), maxSimilarity); + scratch.addInOrder(candidates.pop(), maxSimilarity); } } @@ -405,53 +416,119 @@ private boolean isDiverse(byte[] candidate, NeighborArray neighbors, float score * neighbours */ private int findWorstNonDiverse(NeighborArray neighbors) throws IOException { + int[] uncheckedIndexes = neighbors.sort(); + if (uncheckedIndexes == null) { + // all nodes are checked, we will directly return the most distant one + return neighbors.size() - 1; + } + int uncheckedCursor = uncheckedIndexes.length - 1; for (int i = neighbors.size() - 1; i > 0; i--) { - if (isWorstNonDiverse(i, neighbors)) { + if (uncheckedCursor < 0) { + // no unchecked node left + break; + } + if (isWorstNonDiverse(i, neighbors, uncheckedIndexes, uncheckedCursor)) { return i; } + if (i == uncheckedIndexes[uncheckedCursor]) { + uncheckedCursor--; + } } return neighbors.size() - 1; } - private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors) + private boolean isWorstNonDiverse( + int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor) throws IOException { int candidateNode = neighbors.node[candidateIndex]; switch (vectorEncoding) { case BYTE: return isWorstNonDiverse( - candidateIndex, (byte[]) vectors.vectorValue(candidateNode), neighbors); + candidateIndex, + (byte[]) vectors.vectorValue(candidateNode), + neighbors, + uncheckedIndexes, + uncheckedCursor); default: case FLOAT32: return isWorstNonDiverse( - candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors); + candidateIndex, + (float[]) vectors.vectorValue(candidateNode), + neighbors, + uncheckedIndexes, + uncheckedCursor); } } private boolean isWorstNonDiverse( - int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException { + 
int candidateIndex, + float[] candidateVector, + NeighborArray neighbors, + int[] uncheckedIndexes, + int uncheckedCursor) + throws IOException { float minAcceptedSimilarity = neighbors.score[candidateIndex]; - for (int i = candidateIndex - 1; i >= 0; i--) { - float neighborSimilarity = - similarityFunction.compare( - candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i])); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { - return true; + if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { + // the candidate itself is unchecked + for (int i = candidateIndex - 1; i >= 0; i--) { + float neighborSimilarity = + similarityFunction.compare( + candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i])); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity >= minAcceptedSimilarity) { + return true; + } + } + } else { + // else we just need to make sure candidate does not violate diversity with the (newly + // inserted) unchecked nodes + assert candidateIndex > uncheckedIndexes[uncheckedCursor]; + for (int i = uncheckedCursor; i >= 0; i--) { + float neighborSimilarity = + similarityFunction.compare( + candidateVector, + (float[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]])); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity >= minAcceptedSimilarity) { + return true; + } } } return false; } private boolean isWorstNonDiverse( - int candidateIndex, byte[] candidateVector, NeighborArray neighbors) throws IOException { + int candidateIndex, + byte[] candidateVector, + NeighborArray neighbors, + int[] uncheckedIndexes, + int uncheckedCursor) + throws IOException { float minAcceptedSimilarity = neighbors.score[candidateIndex]; - for (int i = candidateIndex - 1; i >= 0; i--) { - float neighborSimilarity = - similarityFunction.compare( - candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i])); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { - return true; + if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { + // the candidate itself is unchecked + for (int i = candidateIndex - 1; i >= 0; i--) { + float neighborSimilarity = + similarityFunction.compare( + candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i])); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity >= minAcceptedSimilarity) { + return true; + } + } + } else { + // else we just need to make sure candidate does not violate diversity with the (newly + // inserted) unchecked nodes + assert candidateIndex > uncheckedIndexes[uncheckedCursor]; + for (int i = uncheckedCursor; i >= 0; i--) { + float neighborSimilarity = + similarityFunction.compare( + candidateVector, + (byte[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]])); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity >= minAcceptedSimilarity) { + return true; + } } } return false; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 4857d5b9d577..5bc718169466 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ 
b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -100,28 +100,31 @@ public static NeighborQueue search( similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); - NeighborQueue results; + return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + } - int initialEp = graph.entryNode(); - if (initialEp == -1) { - return new NeighborQueue(1, true); - } - int[] eps = new int[] {initialEp}; - int numVisited = 0; - for (int level = graph.numLevels() - 1; level >= 1; level--) { - results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit); - numVisited += results.visitedCount(); - visitedLimit -= results.visitedCount(); - if (results.incomplete()) { - results.setVisitedCount(numVisited); - return results; - } - eps[0] = results.pop(); - } - results = - graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit); - results.setVisitedCount(results.visitedCount() + numVisited); - return results; + /** + * Search {@link OnHeapHnswGraph}; this method is thread safe. For parameters, please refer to + * {@link #search(float[], int, RandomAccessVectorValues, VectorEncoding, + * VectorSimilarityFunction, HnswGraph, Bits, int)} + */ + public static NeighborQueue search( + float[] query, + int topK, + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + OnHeapHnswGraph graph, + Bits acceptOrds, + int visitedLimit) + throws IOException { + OnHeapHnswGraphSearcher graphSearcher = + new OnHeapHnswGraphSearcher<>( + vectorEncoding, + similarityFunction, + new NeighborQueue(topK, true), + new SparseFixedBitSet(vectors.size())); + return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); } /** @@ -161,6 +164,46 @@ public static NeighborQueue search( similarityFunction, new NeighborQueue(topK, true), new SparseFixedBitSet(vectors.size())); + return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + } + + /** + * Search {@link OnHeapHnswGraph}; this method is thread safe. For parameters, please refer to + * {@link #search(byte[], int, RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction, + * HnswGraph, Bits, int)} + */ + public static NeighborQueue search( + byte[] query, + int topK, + RandomAccessVectorValues vectors, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + OnHeapHnswGraph graph, + Bits acceptOrds, + int visitedLimit) + throws IOException { + OnHeapHnswGraphSearcher graphSearcher = + new OnHeapHnswGraphSearcher<>( + vectorEncoding, + similarityFunction, + new NeighborQueue(topK, true), + new SparseFixedBitSet(vectors.size())); + return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit); + } + + private static NeighborQueue search( + T query, + int topK, + RandomAccessVectorValues vectors, + HnswGraph graph, + HnswGraphSearcher graphSearcher, + Bits acceptOrds, + int visitedLimit) + throws IOException { + int initialEp = graph.entryNode(); + if (initialEp == -1) { + return new NeighborQueue(1, true); + } NeighborQueue results; int[] eps = new int[] {graph.entryNode()}; int numVisited = 0; @@ -252,9 +295,9 @@ private NeighborQueue searchLevel( } int topCandidateNode = candidates.pop(); - graph.seek(level, topCandidateNode); + graphSeek(graph, level, topCandidateNode); int friendOrd; - while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) { + while ((friendOrd = 
graphNextNeighbor(graph)) != NO_MORE_DOCS) { assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size; if (visited.getAndSet(friendOrd)) { continue; @@ -296,6 +339,62 @@ private void prepareScratchState(int capacity) { if (visited.length() < capacity) { visited = FixedBitSet.ensureCapacity((FixedBitSet) visited, capacity); } - visited.clear(0, visited.length()); + visited.clear(); + } + + /** + * Seek a specific node in the given graph. The default implementation will just call {@link + * HnswGraph#seek(int, int)} + * + * @throws IOException if seeking the graph fails + */ + void graphSeek(HnswGraph graph, int level, int targetNode) throws IOException { + graph.seek(level, targetNode); + } + + /** + * Get the next neighbor from the graph; you must call {@link #graphSeek(HnswGraph, int, int)} + * before calling this method. The default implementation will just call {@link + * HnswGraph#nextNeighbor()} + * + * @return see {@link HnswGraph#nextNeighbor()} + * @throws IOException if advancing the neighbors fails + */ + int graphNextNeighbor(HnswGraph graph) throws IOException { + return graph.nextNeighbor(); + } + + /** + * This class allows {@link OnHeapHnswGraph} to be searched in a thread-safe manner. + * + *

    Note the class itself is NOT thread safe, but since each search will create one new graph + * searcher the search method is thread safe. + */ + private static class OnHeapHnswGraphSearcher extends HnswGraphSearcher { + + private NeighborArray cur; + private int upto; + + private OnHeapHnswGraphSearcher( + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + NeighborQueue candidates, + BitSet visited) { + super(vectorEncoding, similarityFunction, candidates, visited); + } + + @Override + void graphSeek(HnswGraph graph, int level, int targetNode) { + cur = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode); + upto = -1; + } + + @Override + int graphNextNeighbor(HnswGraph graph) { + if (++upto < cur.size()) { + return cur.node[upto]; + } + return NO_MORE_DOCS; + } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index ec1b5ec3e897..b44f7da8b8ad 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -34,6 +34,7 @@ public class NeighborArray { float[] score; int[] node; + private int sortedNodeSize; public NeighborArray(int maxSize, boolean descOrder) { node = new int[maxSize]; @@ -43,9 +44,10 @@ public NeighborArray(int maxSize, boolean descOrder) { /** * Add a new node to the NeighborArray. The new node must be worse than all previously stored - * nodes. + * nodes. This cannot be called after {@link #addOutOfOrder(int, float)} */ - public void add(int newNode, float newScore) { + public void addInOrder(int newNode, float newScore) { + assert size == sortedNodeSize : "cannot call addInOrder after addOutOfOrder"; if (size == node.length) { node = ArrayUtil.grow(node); score = ArrayUtil.growExact(score, node.length); @@ -54,28 +56,80 @@ public void add(int newNode, float newScore) { float previousScore = score[size - 1]; assert ((scoresDescOrder && (previousScore >= newScore)) || (scoresDescOrder == false && (previousScore <= newScore))) - : "Nodes are added in the incorrect order!"; + : "Nodes are added in the incorrect order! Comparing " + + newScore + + " to " + + Arrays.toString(ArrayUtil.copyOfSubArray(score, 0, size)); } node[size] = newNode; score[size] = newScore; ++size; + ++sortedNodeSize; } - /** Add a new node to the NeighborArray into a correct sort position according to its score. 
*/ - public void insertSorted(int newNode, float newScore) { + /** Add node and score but do not insert as sorted */ + public void addOutOfOrder(int newNode, float newScore) { if (size == node.length) { node = ArrayUtil.grow(node); score = ArrayUtil.growExact(score, node.length); } + node[size] = newNode; + score[size] = newScore; + size++; + } + + /** + * Sort the array according to scores, and return the sorted indexes of previously unsorted nodes + * (unchecked nodes) + * + * @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is + * already fully sorted + */ + public int[] sort() { + if (size == sortedNodeSize) { + // all nodes checked and sorted + return null; + } + assert sortedNodeSize < size; + int[] uncheckedIndexes = new int[size - sortedNodeSize]; + int count = 0; + while (sortedNodeSize != size) { + uncheckedIndexes[count] = insertSortedInternal(); // sortedNodeSize is increased inside + for (int i = 0; i < count; i++) { + if (uncheckedIndexes[i] >= uncheckedIndexes[count]) { + // the previously inserted nodes have been shifted + uncheckedIndexes[i]++; + } + } + count++; + } + Arrays.sort(uncheckedIndexes); + return uncheckedIndexes; + } + + /** insert the first unsorted node into its sorted position */ + private int insertSortedInternal() { + assert sortedNodeSize < size : "Call this method only when there's an unsorted node"; + int tmpNode = node[sortedNodeSize]; + float tmpScore = score[sortedNodeSize]; int insertionPoint = scoresDescOrder - ? descSortFindRightMostInsertionPoint(newScore) - : ascSortFindRightMostInsertionPoint(newScore); - System.arraycopy(node, insertionPoint, node, insertionPoint + 1, size - insertionPoint); - System.arraycopy(score, insertionPoint, score, insertionPoint + 1, size - insertionPoint); - node[insertionPoint] = newNode; - score[insertionPoint] = newScore; - ++size; + ? descSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize) + : ascSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize); + System.arraycopy( + node, insertionPoint, node, insertionPoint + 1, sortedNodeSize - insertionPoint); + System.arraycopy( + score, insertionPoint, score, insertionPoint + 1, sortedNodeSize - insertionPoint); + node[insertionPoint] = tmpNode; + score[insertionPoint] = tmpScore; + ++sortedNodeSize; + return insertionPoint; + } + + /** This method is for test only.
*/ + void insertSorted(int newNode, float newScore) { + addOutOfOrder(newNode, newScore); + insertSortedInternal(); } public int size() { @@ -97,15 +151,20 @@ public float[] score() { public void clear() { size = 0; + sortedNodeSize = 0; } public void removeLast() { size--; + sortedNodeSize = Math.min(sortedNodeSize, size); } public void removeIndex(int idx) { System.arraycopy(node, idx + 1, node, idx, size - idx - 1); System.arraycopy(score, idx + 1, score, idx, size - idx - 1); + if (idx < sortedNodeSize) { + sortedNodeSize--; + } size--; } @@ -114,11 +173,11 @@ public String toString() { return "NeighborArray[" + size + "]"; } - private int ascSortFindRightMostInsertionPoint(float newScore) { - int insertionPoint = Arrays.binarySearch(score, 0, size, newScore); + private int ascSortFindRightMostInsertionPoint(float newScore, int bound) { + int insertionPoint = Arrays.binarySearch(score, 0, bound, newScore); if (insertionPoint >= 0) { // find the right most position with the same score - while ((insertionPoint < size - 1) && (score[insertionPoint + 1] == score[insertionPoint])) { + while ((insertionPoint < bound - 1) && (score[insertionPoint + 1] == score[insertionPoint])) { insertionPoint++; } insertionPoint++; @@ -128,9 +187,9 @@ private int ascSortFindRightMostInsertionPoint(float newScore) { return insertionPoint; } - private int descSortFindRightMostInsertionPoint(float newScore) { + private int descSortFindRightMostInsertionPoint(float newScore, int bound) { int start = 0; - int end = size - 1; + int end = bound - 1; while (start <= end) { int mid = (start + end) / 2; if (score[mid] < newScore) end = mid - 1; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java index 9862536de08c..75a77af056a2 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OnHeapHnswGraph.java @@ -170,14 +170,14 @@ public NodesIterator getNodesOnLevel(int level) { public long ramBytesUsed() { long neighborArrayBytes0 = nsize0 * (Integer.BYTES + Float.BYTES) - + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2 - + RamUsageEstimator.NUM_BYTES_OBJECT_REF - + Integer.BYTES * 2; + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2 + + Integer.BYTES * 3; long neighborArrayBytes = nsize * (Integer.BYTES + Float.BYTES) - + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER * 2 - + RamUsageEstimator.NUM_BYTES_OBJECT_REF - + Integer.BYTES * 2; + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + + RamUsageEstimator.NUM_BYTES_OBJECT_REF * 2 + + Integer.BYTES * 3; long total = 0; for (int l = 0; l < numLevels; l++) { if (l == 0) { diff --git a/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java new file mode 100644 index 000000000000..a1a5a404223f --- /dev/null +++ b/lucene/core/src/java20/org/apache/lucene/util/VectorUtilPanamaProvider.java @@ -0,0 +1,499 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.logging.Logger; +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.ShortVector; +import jdk.incubator.vector.Vector; +import jdk.incubator.vector.VectorOperators; +import jdk.incubator.vector.VectorShape; +import jdk.incubator.vector.VectorSpecies; + +/** A VectorUtil provider implementation that leverages the Panama Vector API. */ +final class VectorUtilPanamaProvider implements VectorUtilProvider { + + private static final int INT_SPECIES_PREF_BIT_SIZE = IntVector.SPECIES_PREFERRED.vectorBitSize(); + + private static final VectorSpecies PREF_FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED; + private static final VectorSpecies PREF_BYTE_SPECIES; + private static final VectorSpecies PREF_SHORT_SPECIES; + + /** + * x86 and less than 256-bit vectors. + * + *

    it could be that it has only AVX1 and integer vectors are fast. it could also be that it has + * no AVX and integer vectors are extremely slow. don't use integer vectors to avoid landmines. + */ + private final boolean hasFastIntegerVectors; + + static { + if (INT_SPECIES_PREF_BIT_SIZE >= 256) { + PREF_BYTE_SPECIES = + ByteVector.SPECIES_MAX.withShape( + VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 2)); + PREF_SHORT_SPECIES = + ShortVector.SPECIES_MAX.withShape( + VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 1)); + } else { + PREF_BYTE_SPECIES = null; + PREF_SHORT_SPECIES = null; + } + } + + // Extracted to a method to be able to apply the SuppressForbidden annotation + @SuppressWarnings("removal") + @SuppressForbidden(reason = "security manager") + private static T doPrivileged(PrivilegedAction action) { + return AccessController.doPrivileged(action); + } + + VectorUtilPanamaProvider(boolean testMode) { + if (!testMode && INT_SPECIES_PREF_BIT_SIZE < 128) { + throw new UnsupportedOperationException( + "Vector bit size is less than 128: " + INT_SPECIES_PREF_BIT_SIZE); + } + + // hack to work around for JDK-8309727: + try { + doPrivileged( + () -> + FloatVector.fromArray(PREF_FLOAT_SPECIES, new float[PREF_FLOAT_SPECIES.length()], 0)); + } catch (SecurityException se) { + throw new UnsupportedOperationException( + "We hit initialization failure described in JDK-8309727: " + se); + } + + // check if the system is x86 and less than 256-bit vectors: + var isAMD64withoutAVX2 = Constants.OS_ARCH.equals("amd64") && INT_SPECIES_PREF_BIT_SIZE < 256; + this.hasFastIntegerVectors = testMode || false == isAMD64withoutAVX2; + + var log = Logger.getLogger(getClass().getName()); + log.info( + "Java vector incubator API enabled" + + (testMode ? 
" (test mode)" : "") + + "; uses preferredBitSize=" + + INT_SPECIES_PREF_BIT_SIZE); + } + + @Override + public float dotProduct(float[] a, float[] b) { + int i = 0; + float res = 0; + // if the array size is large (> 2x platform vector size), its worth the overhead to vectorize + if (a.length > 2 * PREF_FLOAT_SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector acc1 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc2 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc3 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc4 = FloatVector.zero(PREF_FLOAT_SPECIES); + int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length()); + for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) { + FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + acc1 = acc1.add(va.mul(vb)); + FloatVector vc = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length()); + FloatVector vd = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length()); + acc2 = acc2.add(vc.mul(vd)); + FloatVector ve = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length()); + FloatVector vf = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length()); + acc3 = acc3.add(ve.mul(vf)); + FloatVector vg = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length()); + FloatVector vh = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length()); + acc4 = acc4.add(vg.mul(vh)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = PREF_FLOAT_SPECIES.loopBound(a.length); + for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) { + FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + acc1 = acc1.add(va.mul(vb)); + } + // reduce + FloatVector res1 = acc1.add(acc2); + FloatVector res2 = acc3.add(acc4); + res += res1.add(res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + res += b[i] * a[i]; + } + return res; + } + + @Override + public float cosine(float[] a, float[] b) { + int i = 0; + float sum = 0; + float norm1 = 0; + float norm2 = 0; + // if the array size is large (> 2x platform vector size), its worth the overhead to vectorize + if (a.length > 2 * PREF_FLOAT_SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector sum1 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector sum2 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector sum3 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector sum4 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm1_1 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm1_2 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm1_3 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm1_4 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm2_1 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm2_2 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm2_3 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector norm2_4 = FloatVector.zero(PREF_FLOAT_SPECIES); + int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length()); + for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) { + FloatVector va = 
FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + sum1 = sum1.add(va.mul(vb)); + norm1_1 = norm1_1.add(va.mul(va)); + norm2_1 = norm2_1.add(vb.mul(vb)); + FloatVector vc = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length()); + FloatVector vd = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length()); + sum2 = sum2.add(vc.mul(vd)); + norm1_2 = norm1_2.add(vc.mul(vc)); + norm2_2 = norm2_2.add(vd.mul(vd)); + FloatVector ve = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length()); + FloatVector vf = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length()); + sum3 = sum3.add(ve.mul(vf)); + norm1_3 = norm1_3.add(ve.mul(ve)); + norm2_3 = norm2_3.add(vf.mul(vf)); + FloatVector vg = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length()); + FloatVector vh = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length()); + sum4 = sum4.add(vg.mul(vh)); + norm1_4 = norm1_4.add(vg.mul(vg)); + norm2_4 = norm2_4.add(vh.mul(vh)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = PREF_FLOAT_SPECIES.loopBound(a.length); + for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) { + FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + sum1 = sum1.add(va.mul(vb)); + norm1_1 = norm1_1.add(va.mul(va)); + norm2_1 = norm2_1.add(vb.mul(vb)); + } + // reduce + FloatVector sumres1 = sum1.add(sum2); + FloatVector sumres2 = sum3.add(sum4); + FloatVector norm1res1 = norm1_1.add(norm1_2); + FloatVector norm1res2 = norm1_3.add(norm1_4); + FloatVector norm2res1 = norm2_1.add(norm2_2); + FloatVector norm2res2 = norm2_3.add(norm2_4); + sum += sumres1.add(sumres2).reduceLanes(VectorOperators.ADD); + norm1 += norm1res1.add(norm1res2).reduceLanes(VectorOperators.ADD); + norm2 += norm2res1.add(norm2res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + float elem1 = a[i]; + float elem2 = b[i]; + sum += elem1 * elem2; + norm1 += elem1 * elem1; + norm2 += elem2 * elem2; + } + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); + } + + @Override + public float squareDistance(float[] a, float[] b) { + int i = 0; + float res = 0; + // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize + if (a.length > 2 * PREF_FLOAT_SPECIES.length()) { + // vector loop is unrolled 4x (4 accumulators in parallel) + FloatVector acc1 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc2 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc3 = FloatVector.zero(PREF_FLOAT_SPECIES); + FloatVector acc4 = FloatVector.zero(PREF_FLOAT_SPECIES); + int upperBound = PREF_FLOAT_SPECIES.loopBound(a.length - 3 * PREF_FLOAT_SPECIES.length()); + for (; i < upperBound; i += 4 * PREF_FLOAT_SPECIES.length()) { + FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + FloatVector diff1 = va.sub(vb); + acc1 = acc1.add(diff1.mul(diff1)); + FloatVector vc = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + PREF_FLOAT_SPECIES.length()); + FloatVector vd = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + PREF_FLOAT_SPECIES.length()); + FloatVector diff2 = vc.sub(vd); + acc2 = acc2.add(diff2.mul(diff2)); + FloatVector ve = + 
FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 2 * PREF_FLOAT_SPECIES.length()); + FloatVector vf = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 2 * PREF_FLOAT_SPECIES.length()); + FloatVector diff3 = ve.sub(vf); + acc3 = acc3.add(diff3.mul(diff3)); + FloatVector vg = + FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i + 3 * PREF_FLOAT_SPECIES.length()); + FloatVector vh = + FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i + 3 * PREF_FLOAT_SPECIES.length()); + FloatVector diff4 = vg.sub(vh); + acc4 = acc4.add(diff4.mul(diff4)); + } + // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + upperBound = PREF_FLOAT_SPECIES.loopBound(a.length); + for (; i < upperBound; i += PREF_FLOAT_SPECIES.length()) { + FloatVector va = FloatVector.fromArray(PREF_FLOAT_SPECIES, a, i); + FloatVector vb = FloatVector.fromArray(PREF_FLOAT_SPECIES, b, i); + FloatVector diff = va.sub(vb); + acc1 = acc1.add(diff.mul(diff)); + } + // reduce + FloatVector res1 = acc1.add(acc2); + FloatVector res2 = acc3.add(acc4); + res += res1.add(res2).reduceLanes(VectorOperators.ADD); + } + + for (; i < a.length; i++) { + float diff = a[i] - b[i]; + res += diff * diff; + } + return res; + } + + // Binary functions, these all follow a general pattern like this: + // + // short intermediate = a * b; + // int accumulator = accumulator + intermediate; + // + // 256 or 512 bit vectors can process 64 or 128 bits at a time, respectively + // intermediate results use 128 or 256 bit vectors, respectively + // final accumulator uses 256 or 512 bit vectors, respectively + // + // We also support 128 bit vectors, using two 128 bit accumulators. + // This is slower but still faster than not vectorizing at all. + + @Override + public int dotProduct(byte[] a, byte[] b) { + int i = 0; + int res = 0; + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors (256-bit on intel to dodge performance landmines) + if (a.length >= 16 && hasFastIntegerVectors) { + // compute vectorized dot product consistent with VPDPBUSD instruction + if (INT_SPECIES_PREF_BIT_SIZE >= 256) { + // optimized 256/512 bit implementation, processes 8/16 bytes at a time + int upperBound = PREF_BYTE_SPECIES.loopBound(a.length); + IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED); + for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) { + ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i); + ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i); + Vector va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector prod16 = va16.mul(vb16); + Vector prod32 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + acc = acc.add(prod32); + } + // reduce + res += acc.reduceLanes(VectorOperators.ADD); + } else { + // 128-bit implementation, which must "split up" vectors due to widening conversions + int upperBound = ByteVector.SPECIES_64.loopBound(a.length); + IntVector acc1 = IntVector.zero(IntVector.SPECIES_128); + IntVector acc2 = IntVector.zero(IntVector.SPECIES_128); + for (; i < upperBound; i += ByteVector.SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + // expand each byte vector into short vector and multiply + Vector va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector vb16 = 
vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector prod16 = va16.mul(vb16); + // split each short vector into two int vectors and add + Vector prod32_1 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector prod32_2 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + acc1 = acc1.add(prod32_1); + acc2 = acc2.add(prod32_2); + } + // reduce + res += acc1.add(acc2).reduceLanes(VectorOperators.ADD); + } + } + + for (; i < a.length; i++) { + res += b[i] * a[i]; + } + return res; + } + + @Override + public float cosine(byte[] a, byte[] b) { + int i = 0; + int sum = 0; + int norm1 = 0; + int norm2 = 0; + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors (256-bit on intel to dodge performance landmines) + if (a.length >= 16 && hasFastIntegerVectors) { + if (INT_SPECIES_PREF_BIT_SIZE >= 256) { + // optimized 256/512 bit implementation, processes 8/16 bytes at a time + int upperBound = PREF_BYTE_SPECIES.loopBound(a.length); + IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED); + IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED); + IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED); + for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) { + ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i); + ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i); + Vector va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector prod16 = va16.mul(vb16); + Vector norm1_16 = va16.mul(va16); + Vector norm2_16 = vb16.mul(vb16); + Vector prod32 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + Vector norm1_32 = + norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + Vector norm2_32 = + norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + accSum = accSum.add(prod32); + accNorm1 = accNorm1.add(norm1_32); + accNorm2 = accNorm2.add(norm2_32); + } + // reduce + sum += accSum.reduceLanes(VectorOperators.ADD); + norm1 += accNorm1.reduceLanes(VectorOperators.ADD); + norm2 += accNorm2.reduceLanes(VectorOperators.ADD); + } else { + // 128-bit implementation, which must "split up" vectors due to widening conversions + int upperBound = ByteVector.SPECIES_64.loopBound(a.length); + IntVector accSum1 = IntVector.zero(IntVector.SPECIES_128); + IntVector accSum2 = IntVector.zero(IntVector.SPECIES_128); + IntVector accNorm1_1 = IntVector.zero(IntVector.SPECIES_128); + IntVector accNorm1_2 = IntVector.zero(IntVector.SPECIES_128); + IntVector accNorm2_1 = IntVector.zero(IntVector.SPECIES_128); + IntVector accNorm2_2 = IntVector.zero(IntVector.SPECIES_128); + for (; i < upperBound; i += ByteVector.SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + // expand each byte vector into short vector and perform multiplications + Vector va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector prod16 = va16.mul(vb16); + Vector norm1_16 = va16.mul(va16); + Vector norm2_16 = vb16.mul(vb16); + // split each short vector into two int vectors and add + Vector prod32_1 = + prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector prod32_2 = + 
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + Vector norm1_32_1 = + norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector norm1_32_2 = + norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + Vector norm2_32_1 = + norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector norm2_32_2 = + norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + accSum1 = accSum1.add(prod32_1); + accSum2 = accSum2.add(prod32_2); + accNorm1_1 = accNorm1_1.add(norm1_32_1); + accNorm1_2 = accNorm1_2.add(norm1_32_2); + accNorm2_1 = accNorm2_1.add(norm2_32_1); + accNorm2_2 = accNorm2_2.add(norm2_32_2); + } + // reduce + sum += accSum1.add(accSum2).reduceLanes(VectorOperators.ADD); + norm1 += accNorm1_1.add(accNorm1_2).reduceLanes(VectorOperators.ADD); + norm2 += accNorm2_1.add(accNorm2_2).reduceLanes(VectorOperators.ADD); + } + } + + for (; i < a.length; i++) { + byte elem1 = a[i]; + byte elem2 = b[i]; + sum += elem1 * elem2; + norm1 += elem1 * elem1; + norm2 += elem2 * elem2; + } + return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); + } + + @Override + public int squareDistance(byte[] a, byte[] b) { + int i = 0; + int res = 0; + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors (256-bit on intel to dodge performance landmines) + if (a.length >= 16 && hasFastIntegerVectors) { + if (INT_SPECIES_PREF_BIT_SIZE >= 256) { + // optimized 256/512 bit implementation, processes 8/16 bytes at a time + int upperBound = PREF_BYTE_SPECIES.loopBound(a.length); + IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED); + for (; i < upperBound; i += PREF_BYTE_SPECIES.length()) { + ByteVector va8 = ByteVector.fromArray(PREF_BYTE_SPECIES, a, i); + ByteVector vb8 = ByteVector.fromArray(PREF_BYTE_SPECIES, b, i); + Vector va16 = va8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector vb16 = vb8.convertShape(VectorOperators.B2S, PREF_SHORT_SPECIES, 0); + Vector diff16 = va16.sub(vb16); + Vector diff32 = + diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_PREFERRED, 0); + acc = acc.add(diff32.mul(diff32)); + } + // reduce + res += acc.reduceLanes(VectorOperators.ADD); + } else { + // 128-bit implementation, which must "split up" vectors due to widening conversions + int upperBound = ByteVector.SPECIES_64.loopBound(a.length); + IntVector acc1 = IntVector.zero(IntVector.SPECIES_128); + IntVector acc2 = IntVector.zero(IntVector.SPECIES_128); + for (; i < upperBound; i += ByteVector.SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i); + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i); + // expand each byte vector into short vector and subtract + Vector va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0); + Vector diff16 = va16.sub(vb16); + // split each short vector into two int vectors, square, and add + Vector diff32_1 = + diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0); + Vector diff32_2 = + diff16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1); + acc1 = acc1.add(diff32_1.mul(diff32_1)); + acc2 = acc2.add(diff32_2.mul(diff32_2)); + } + // reduce + res += acc1.add(acc2).reduceLanes(VectorOperators.ADD); + } + } + + for (; i < a.length; i++) { + int diff = a[i] - b[i]; + res += diff * diff; + } + return res; + } +} diff --git 
a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java new file mode 100644 index 000000000000..7b2216add785 --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java @@ -0,0 +1,588 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.store; + +import java.io.EOFException; +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.Objects; +import org.apache.lucene.util.ArrayUtil; + +/** + * Base IndexInput implementation that uses an array of MemorySegments to represent a file. + * + *

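+ * (Editorial illustration, not in the original javadoc: with a {@code chunkSizePower} of 30, i.e. 1 GiB segments, a global position {@code pos} resolves to segment {@code (int) (pos >> 30)} at offset {@code pos & ((1L << 30) - 1)}; the {@code seek} and positional read methods below compute exactly this.)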
For efficiency, this class requires that the segment sizes are a power of two ( + * chunkSizePower). + */ +@SuppressWarnings("preview") +abstract class MemorySegmentIndexInput extends IndexInput implements RandomAccessInput { + static final ValueLayout.OfByte LAYOUT_BYTE = ValueLayout.JAVA_BYTE; + static final ValueLayout.OfShort LAYOUT_LE_SHORT = + ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + static final ValueLayout.OfInt LAYOUT_LE_INT = + ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + static final ValueLayout.OfLong LAYOUT_LE_LONG = + ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + static final ValueLayout.OfFloat LAYOUT_LE_FLOAT = + ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + + final long length; + final long chunkSizeMask; + final int chunkSizePower; + final Arena arena; + final MemorySegment[] segments; + + int curSegmentIndex = -1; + MemorySegment + curSegment; // redundant for speed: segments[curSegmentIndex], also marker if closed! + long curPosition; // relative to curSegment, not globally + + public static MemorySegmentIndexInput newInstance( + String resourceDescription, + Arena arena, + MemorySegment[] segments, + long length, + int chunkSizePower) { + assert Arrays.stream(segments).map(MemorySegment::scope).allMatch(arena.scope()::equals); + if (segments.length == 1) { + return new SingleSegmentImpl(resourceDescription, arena, segments[0], length, chunkSizePower); + } else { + return new MultiSegmentImpl(resourceDescription, arena, segments, 0, length, chunkSizePower); + } + } + + private MemorySegmentIndexInput( + String resourceDescription, + Arena arena, + MemorySegment[] segments, + long length, + int chunkSizePower) { + super(resourceDescription); + this.arena = arena; + this.segments = segments; + this.length = length; + this.chunkSizePower = chunkSizePower; + this.chunkSizeMask = (1L << chunkSizePower) - 1L; + this.curSegment = segments[0]; + } + + void ensureOpen() { + if (curSegment == null) { + throw alreadyClosed(null); + } + } + + // the unused parameter is just to silence javac about unused variables + RuntimeException handlePositionalIOOBE(RuntimeException unused, String action, long pos) + throws IOException { + if (pos < 0L) { + return new IllegalArgumentException(action + " negative position (pos=" + pos + "): " + this); + } else { + throw new EOFException(action + " past EOF (pos=" + pos + "): " + this); + } + } + + // the unused parameter is just to silence javac about unused variables + AlreadyClosedException alreadyClosed(RuntimeException unused) { + return new AlreadyClosedException("Already closed: " + this); + } + + @Override + public final byte readByte() throws IOException { + try { + final byte v = curSegment.get(LAYOUT_BYTE, curPosition); + curPosition++; + return v; + } catch ( + @SuppressWarnings("unused") + IndexOutOfBoundsException e) { + do { + curSegmentIndex++; + if (curSegmentIndex >= segments.length) { + throw new EOFException("read past EOF: " + this); + } + curSegment = segments[curSegmentIndex]; + curPosition = 0L; + } while (curSegment.byteSize() == 0L); + final byte v = curSegment.get(LAYOUT_BYTE, curPosition); + curPosition++; + return v; + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public final void readBytes(byte[] b, int offset, int len) throws IOException { + try { + MemorySegment.copy(curSegment, LAYOUT_BYTE, curPosition, b, offset, len); + curPosition += len; + }
catch ( + @SuppressWarnings("unused") + IndexOutOfBoundsException e) { + readBytesBoundary(b, offset, len); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + private void readBytesBoundary(byte[] b, int offset, int len) throws IOException { + try { + long curAvail = curSegment.byteSize() - curPosition; + while (len > curAvail) { + MemorySegment.copy(curSegment, LAYOUT_BYTE, curPosition, b, offset, (int) curAvail); + len -= curAvail; + offset += curAvail; + curSegmentIndex++; + if (curSegmentIndex >= segments.length) { + throw new EOFException("read past EOF: " + this); + } + curSegment = segments[curSegmentIndex]; + curPosition = 0L; + curAvail = curSegment.byteSize(); + } + MemorySegment.copy(curSegment, LAYOUT_BYTE, curPosition, b, offset, len); + curPosition += len; + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public void readInts(int[] dst, int offset, int length) throws IOException { + try { + MemorySegment.copy(curSegment, LAYOUT_LE_INT, curPosition, dst, offset, length); + curPosition += Integer.BYTES * (long) length; + } catch ( + @SuppressWarnings("unused") + IndexOutOfBoundsException iobe) { + super.readInts(dst, offset, length); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public void readLongs(long[] dst, int offset, int length) throws IOException { + try { + MemorySegment.copy(curSegment, LAYOUT_LE_LONG, curPosition, dst, offset, length); + curPosition += Long.BYTES * (long) length; + } catch ( + @SuppressWarnings("unused") + IndexOutOfBoundsException iobe) { + super.readLongs(dst, offset, length); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public void readFloats(float[] dst, int offset, int length) throws IOException { + try { + MemorySegment.copy(curSegment, LAYOUT_LE_FLOAT, curPosition, dst, offset, length); + curPosition += Float.BYTES * (long) length; + } catch ( + @SuppressWarnings("unused") + IndexOutOfBoundsException iobe) { + super.readFloats(dst, offset, length); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public final short readShort() throws IOException { + try { + final short v = curSegment.get(LAYOUT_LE_SHORT, curPosition); + curPosition += Short.BYTES; + return v; + } catch ( + @SuppressWarnings("unused") + IndexOutOfBoundsException e) { + return super.readShort(); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public final int readInt() throws IOException { + try { + final int v = curSegment.get(LAYOUT_LE_INT, curPosition); + curPosition += Integer.BYTES; + return v; + } catch ( + @SuppressWarnings("unused") + IndexOutOfBoundsException e) { + return super.readInt(); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public final long readLong() throws IOException { + try { + final long v = curSegment.get(LAYOUT_LE_LONG, curPosition); + curPosition += Long.BYTES; + return v; + } catch ( + @SuppressWarnings("unused") + IndexOutOfBoundsException e) { + return super.readLong(); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public long getFilePointer() { + ensureOpen(); + return (((long) curSegmentIndex) << chunkSizePower) + curPosition; + } + + @Override + public void 
seek(long pos) throws IOException { + ensureOpen(); + // we use >> here to preserve negative, so we will catch AIOOBE, + // in case pos + offset overflows. + final int si = (int) (pos >> chunkSizePower); + try { + if (si != curSegmentIndex) { + final MemorySegment seg = segments[si]; + // write values, on exception all is unchanged + this.curSegmentIndex = si; + this.curSegment = seg; + } + this.curPosition = Objects.checkIndex(pos & chunkSizeMask, curSegment.byteSize() + 1); + } catch (IndexOutOfBoundsException e) { + throw handlePositionalIOOBE(e, "seek", pos); + } + } + + @Override + public byte readByte(long pos) throws IOException { + try { + final int si = (int) (pos >> chunkSizePower); + return segments[si].get(LAYOUT_BYTE, pos & chunkSizeMask); + } catch (IndexOutOfBoundsException ioobe) { + throw handlePositionalIOOBE(ioobe, "read", pos); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + // used only by random access methods to handle reads across boundaries + private void setPos(long pos, int si) throws IOException { + try { + final MemorySegment seg = segments[si]; + // write values, on exception above all is unchanged + this.curPosition = pos & chunkSizeMask; + this.curSegmentIndex = si; + this.curSegment = seg; + } catch (IndexOutOfBoundsException ioobe) { + throw handlePositionalIOOBE(ioobe, "read", pos); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public short readShort(long pos) throws IOException { + final int si = (int) (pos >> chunkSizePower); + try { + return segments[si].get(LAYOUT_LE_SHORT, pos & chunkSizeMask); + } catch ( + @SuppressWarnings("unused") + IndexOutOfBoundsException ioobe) { + // either it's a boundary, or read past EOF, fall back: + setPos(pos, si); + return readShort(); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public int readInt(long pos) throws IOException { + final int si = (int) (pos >> chunkSizePower); + try { + return segments[si].get(LAYOUT_LE_INT, pos & chunkSizeMask); + } catch ( + @SuppressWarnings("unused") + IndexOutOfBoundsException ioobe) { + // either it's a boundary, or read past EOF, fall back: + setPos(pos, si); + return readInt(); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public long readLong(long pos) throws IOException { + final int si = (int) (pos >> chunkSizePower); + try { + return segments[si].get(LAYOUT_LE_LONG, pos & chunkSizeMask); + } catch ( + @SuppressWarnings("unused") + IndexOutOfBoundsException ioobe) { + // either it's a boundary, or read past EOF, fall back: + setPos(pos, si); + return readLong(); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public final long length() { + return length; + } + + @Override + public final MemorySegmentIndexInput clone() { + final MemorySegmentIndexInput clone = buildSlice((String) null, 0L, this.length); + try { + clone.seek(getFilePointer()); + } catch (IOException ioe) { + throw new AssertionError(ioe); + } + + return clone; + } + + /** + * Creates a slice of this index input, with the given description, offset, and length. The slice + * is seeked to the beginning. 
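+ * The slice must lie fully within this input: {@code 0 <= offset}, {@code 0 <= length}, and {@code offset + length <= length()}.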
+ */ + @Override + public final MemorySegmentIndexInput slice(String sliceDescription, long offset, long length) { + if (offset < 0 || length < 0 || offset + length > this.length) { + throw new IllegalArgumentException( + "slice() " + + sliceDescription + + " out of bounds: offset=" + + offset + + ",length=" + + length + + ",fileLength=" + + this.length + + ": " + + this); + } + + return buildSlice(sliceDescription, offset, length); + } + + /** Builds the actual sliced IndexInput (may apply extra offset in subclasses). */ + MemorySegmentIndexInput buildSlice(String sliceDescription, long offset, long length) { + ensureOpen(); + + final long sliceEnd = offset + length; + final int startIndex = (int) (offset >>> chunkSizePower); + final int endIndex = (int) (sliceEnd >>> chunkSizePower); + + // we always allocate one more slice; the last one may be a 0-byte one after truncating with + // asSlice(): + final MemorySegment[] slices = ArrayUtil.copyOfSubArray(segments, startIndex, endIndex + 1); + + // set the last segment's limit for the sliced view. + slices[slices.length - 1] = slices[slices.length - 1].asSlice(0L, sliceEnd & chunkSizeMask); + + offset = offset & chunkSizeMask; + + final String newResourceDescription = getFullSliceDescription(sliceDescription); + if (slices.length == 1) { + return new SingleSegmentImpl( + newResourceDescription, + null, // clones don't have an Arena, as they can't close + slices[0].asSlice(offset, length), + length, + chunkSizePower); + } else { + return new MultiSegmentImpl( + newResourceDescription, + null, // clones don't have an Arena, as they can't close + slices, + offset, + length, + chunkSizePower); + } + } + + @Override + public final void close() throws IOException { + if (curSegment == null) { + return; + } + + // make sure all accesses to this IndexInput instance throw NPE: + curSegment = null; + Arrays.fill(segments, null); + + // the master IndexInput has an Arena and is able + // to release all resources (unmap segments) - a + // side effect is that other threads still using clones + // will throw IllegalStateException + if (arena != null) { + arena.close(); + } + } + + /** Optimization of MemorySegmentIndexInput for when there is only one segment.
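Absolute positions map directly into the single segment, so no chunk index arithmetic is needed.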
*/ + static final class SingleSegmentImpl extends MemorySegmentIndexInput { + + SingleSegmentImpl( + String resourceDescription, + Arena arena, + MemorySegment segment, + long length, + int chunkSizePower) { + super(resourceDescription, arena, new MemorySegment[] {segment}, length, chunkSizePower); + this.curSegmentIndex = 0; + } + + @Override + public void seek(long pos) throws IOException { + ensureOpen(); + try { + curPosition = Objects.checkIndex(pos, length + 1); + } catch (IndexOutOfBoundsException e) { + throw handlePositionalIOOBE(e, "seek", pos); + } + } + + @Override + public long getFilePointer() { + ensureOpen(); + return curPosition; + } + + @Override + public byte readByte(long pos) throws IOException { + try { + return curSegment.get(LAYOUT_BYTE, pos); + } catch (IndexOutOfBoundsException e) { + throw handlePositionalIOOBE(e, "read", pos); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public short readShort(long pos) throws IOException { + try { + return curSegment.get(LAYOUT_LE_SHORT, pos); + } catch (IndexOutOfBoundsException e) { + throw handlePositionalIOOBE(e, "read", pos); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public int readInt(long pos) throws IOException { + try { + return curSegment.get(LAYOUT_LE_INT, pos); + } catch (IndexOutOfBoundsException e) { + throw handlePositionalIOOBE(e, "read", pos); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + + @Override + public long readLong(long pos) throws IOException { + try { + return curSegment.get(LAYOUT_LE_LONG, pos); + } catch (IndexOutOfBoundsException e) { + throw handlePositionalIOOBE(e, "read", pos); + } catch (NullPointerException | IllegalStateException e) { + throw alreadyClosed(e); + } + } + } + + /** This class adds offset support to MemorySegmentIndexInput, which is needed for slices. 
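Every absolute position is shifted by the slice offset before delegating to the base class.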
*/ + static final class MultiSegmentImpl extends MemorySegmentIndexInput { + private final long offset; + + MultiSegmentImpl( + String resourceDescription, + Arena arena, + MemorySegment[] segments, + long offset, + long length, + int chunkSizePower) { + super(resourceDescription, arena, segments, length, chunkSizePower); + this.offset = offset; + try { + seek(0L); + } catch (IOException ioe) { + throw new AssertionError(ioe); + } + assert curSegment != null && curSegmentIndex >= 0; + } + + @Override + RuntimeException handlePositionalIOOBE(RuntimeException unused, String action, long pos) + throws IOException { + return super.handlePositionalIOOBE(unused, action, pos - offset); + } + + @Override + public void seek(long pos) throws IOException { + assert pos >= 0L : "negative position"; + super.seek(pos + offset); + } + + @Override + public long getFilePointer() { + return super.getFilePointer() - offset; + } + + @Override + public byte readByte(long pos) throws IOException { + return super.readByte(pos + offset); + } + + @Override + public short readShort(long pos) throws IOException { + return super.readShort(pos + offset); + } + + @Override + public int readInt(long pos) throws IOException { + return super.readInt(pos + offset); + } + + @Override + public long readLong(long pos) throws IOException { + return super.readLong(pos + offset); + } + + @Override + MemorySegmentIndexInput buildSlice(String sliceDescription, long ofs, long length) { + return super.buildSlice(sliceDescription, this.offset + ofs, length); + } + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInputProvider.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInputProvider.java new file mode 100644 index 000000000000..e994c2dddfff --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInputProvider.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.store; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.logging.Logger; +import org.apache.lucene.util.Constants; +import org.apache.lucene.util.Unwrappable; + +@SuppressWarnings("preview") +final class MemorySegmentIndexInputProvider implements MMapDirectory.MMapIndexInputProvider { + + public MemorySegmentIndexInputProvider() { + var log = Logger.getLogger(getClass().getName()); + log.info( + "Using MemorySegmentIndexInput with Java 21; to disable start with -D" + + MMapDirectory.ENABLE_MEMORY_SEGMENTS_SYSPROP + + "=false"); + } + + @Override + public IndexInput openInput(Path path, IOContext context, int chunkSizePower, boolean preload) + throws IOException { + final String resourceDescription = "MemorySegmentIndexInput(path=\"" + path.toString() + "\")"; + + // Workaround for JDK-8259028: we need to unwrap our test-only file system layers + path = Unwrappable.unwrapAll(path); + + boolean success = false; + final Arena arena = Arena.ofShared(); + try (var fc = FileChannel.open(path, StandardOpenOption.READ)) { + final long fileSize = fc.size(); + final IndexInput in = + MemorySegmentIndexInput.newInstance( + resourceDescription, + arena, + map(arena, resourceDescription, fc, chunkSizePower, preload, fileSize), + fileSize, + chunkSizePower); + success = true; + return in; + } finally { + if (success == false) { + arena.close(); + } + } + } + + @Override + public long getDefaultMaxChunkSize() { + return Constants.JRE_IS_64BIT ? (1L << 34) : (1L << 28); + } + + @Override + public boolean isUnmapSupported() { + return true; + } + + @Override + public String getUnmapNotSupportedReason() { + return null; + } + + private final MemorySegment[] map( + Arena arena, + String resourceDescription, + FileChannel fc, + int chunkSizePower, + boolean preload, + long length) + throws IOException { + if ((length >>> chunkSizePower) >= Integer.MAX_VALUE) + throw new IllegalArgumentException("File too big for chunk size: " + resourceDescription); + + final long chunkSize = 1L << chunkSizePower; + + // we always allocate one more segment; the last one may be a 0-byte one + final int nrSegments = (int) (length >>> chunkSizePower) + 1; + + final MemorySegment[] segments = new MemorySegment[nrSegments]; + + long startOffset = 0L; + for (int segNr = 0; segNr < nrSegments; segNr++) { + final long segSize = + (length > (startOffset + chunkSize)) ? chunkSize : (length - startOffset); + final MemorySegment segment; + try { + segment = fc.map(MapMode.READ_ONLY, startOffset, segSize, arena); + } catch (IOException ioe) { + throw convertMapFailedIOException(ioe, resourceDescription, segSize); + } + if (preload) { + segment.load(); + } + segments[segNr] = segment; + startOffset += segSize; + } + return segments; + } +} diff --git a/lucene/core/src/java21/org/apache/lucene/util/VectorUtilPanamaProvider.txt b/lucene/core/src/java21/org/apache/lucene/util/VectorUtilPanamaProvider.txt new file mode 100644 index 000000000000..db75951206ef --- /dev/null +++ b/lucene/core/src/java21/org/apache/lucene/util/VectorUtilPanamaProvider.txt @@ -0,0 +1,2 @@ +The version of VectorUtilPanamaProvider for Java 21 is identical to that of Java 20. +As such, there is no specific 21 version - the Java 20 version will be loaded from the MRJAR.
\ No newline at end of file diff --git a/lucene/core/src/test/org/apache/lucene/analysis/TestAutomatonToTokenStream.java b/lucene/core/src/test/org/apache/lucene/analysis/TestAutomatonToTokenStream.java index 27d2c72ddb09..d558ad6d3697 100644 --- a/lucene/core/src/test/org/apache/lucene/analysis/TestAutomatonToTokenStream.java +++ b/lucene/core/src/test/org/apache/lucene/analysis/TestAutomatonToTokenStream.java @@ -22,8 +22,8 @@ import java.util.List; import org.apache.lucene.tests.analysis.BaseTokenStreamTestCase; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.automaton.Automata; import org.apache.lucene.util.automaton.Automaton; -import org.apache.lucene.util.automaton.DaciukMihovAutomatonBuilder; public class TestAutomatonToTokenStream extends BaseTokenStreamTestCase { @@ -31,7 +31,7 @@ public void testSinglePath() throws IOException { List<BytesRef> acceptStrings = new ArrayList<>(); acceptStrings.add(new BytesRef("abc")); - Automaton flatPathAutomaton = DaciukMihovAutomatonBuilder.build(acceptStrings); + Automaton flatPathAutomaton = Automata.makeStringUnion(acceptStrings); TokenStream ts = AutomatonToTokenStream.toTokenStream(flatPathAutomaton); assertTokenStreamContents( ts, @@ -48,7 +48,7 @@ public void testParallelPaths() throws IOException { acceptStrings.add(new BytesRef("123")); acceptStrings.add(new BytesRef("abc")); - Automaton flatPathAutomaton = DaciukMihovAutomatonBuilder.build(acceptStrings); + Automaton flatPathAutomaton = Automata.makeStringUnion(acceptStrings); TokenStream ts = AutomatonToTokenStream.toTokenStream(flatPathAutomaton); assertTokenStreamContents( ts, @@ -65,7 +65,7 @@ public void testForkedPath() throws IOException { acceptStrings.add(new BytesRef("ab3")); acceptStrings.add(new BytesRef("abc")); - Automaton flatPathAutomaton = DaciukMihovAutomatonBuilder.build(acceptStrings); + Automaton flatPathAutomaton = Automata.makeStringUnion(acceptStrings); TokenStream ts = AutomatonToTokenStream.toTokenStream(flatPathAutomaton); assertTokenStreamContents( ts, diff --git a/lucene/core/src/test/org/apache/lucene/analysis/TestWordlistLoader.java b/lucene/core/src/test/org/apache/lucene/analysis/TestWordlistLoader.java index 7af64c0011eb..4747c86834ed 100644 --- a/lucene/core/src/test/org/apache/lucene/analysis/TestWordlistLoader.java +++ b/lucene/core/src/test/org/apache/lucene/analysis/TestWordlistLoader.java @@ -24,7 +24,7 @@ public class TestWordlistLoader extends LuceneTestCase { public void testWordlistLoading() throws IOException { - String s = "ONE\n two \nthree"; + String s = "ONE\n two \nthree\n\n"; CharArraySet wordSet1 = WordlistLoader.getWordSet(new StringReader(s)); checkSet(wordSet1); CharArraySet wordSet2 = WordlistLoader.getWordSet(new BufferedReader(new StringReader(s))); diff --git a/lucene/core/src/test/org/apache/lucene/geo/TestTessellator.java b/lucene/core/src/test/org/apache/lucene/geo/TestTessellator.java index 950772f32b75..44c17b48440c 100644 --- a/lucene/core/src/test/org/apache/lucene/geo/TestTessellator.java +++ b/lucene/core/src/test/org/apache/lucene/geo/TestTessellator.java @@ -915,6 +915,28 @@ public void testComplexPolygon54() throws Exception { } } + + public void testComplexPolygon55() throws Exception { + String geoJson = GeoTestUtil.readShape("github-12352-1.geojson.gz"); + Polygon[] polygons = Polygon.fromGeoJSON(geoJson); + for (Polygon polygon : polygons) { + List<Tessellator.Triangle> tessellation = + Tessellator.tessellate(polygon, random().nextBoolean()); + assertEquals(area(polygon), area(tessellation), 0.0); + // don't check edges
as it takes several minutes + } + } + + public void testComplexPolygon56() throws Exception { + String geoJson = GeoTestUtil.readShape("github-12352-2.geojson.gz"); + Polygon[] polygons = Polygon.fromGeoJSON(geoJson); + for (Polygon polygon : polygons) { + List<Tessellator.Triangle> tessellation = + Tessellator.tessellate(polygon, random().nextBoolean()); + assertEquals(area(polygon), area(tessellation), 0.0); + // don't check edges as it takes several minutes + } + } + private static class TestCountingMonitor implements Tessellator.Monitor { private int count = 0; private int splitsStarted = 0; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestCachingMergeContext.java b/lucene/core/src/test/org/apache/lucene/index/TestCachingMergeContext.java new file mode 100644 index 000000000000..3210d54540d7 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/index/TestCachingMergeContext.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.index; + +import java.io.IOException; +import java.util.Set; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.InfoStream; + +public class TestCachingMergeContext extends LuceneTestCase { + public void testNumDeletesToMerge() throws IOException { + MockMergeContext mergeContext = new MockMergeContext(); + CachingMergeContext cachingMergeContext = new CachingMergeContext(mergeContext); + assertEquals(cachingMergeContext.numDeletesToMerge(null), 1); + assertEquals(cachingMergeContext.cachedNumDeletesToMerge.size(), 1); + assertEquals( + cachingMergeContext.cachedNumDeletesToMerge.getOrDefault(null, -1), Integer.valueOf(1)); + assertEquals(mergeContext.count, 1); + + // increase the mock count + mergeContext.numDeletesToMerge(null); + assertEquals(mergeContext.count, 2); + + // assert the cached result is still one + assertEquals(cachingMergeContext.numDeletesToMerge(null), 1); + assertEquals(cachingMergeContext.cachedNumDeletesToMerge.size(), 1); + assertEquals( + cachingMergeContext.cachedNumDeletesToMerge.getOrDefault(null, -1), Integer.valueOf(1)); + } + + private static final class MockMergeContext implements MergePolicy.MergeContext { + int count = 0; + + @Override + public final int numDeletesToMerge(SegmentCommitInfo info) throws IOException { + this.count += 1; + return this.count; + } + + @Override + public int numDeletedDocs(SegmentCommitInfo info) { + return 0; + } + + @Override + public InfoStream getInfoStream() { + return null; + } + + @Override + public Set<SegmentCommitInfo> getMergingSegments() { + return null; + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index 5dc11a52fb49..4a6365b53094 100644 ---
a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java @@ -40,6 +40,7 @@ import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.TestVectorUtil; /** * Test that uses a default/lucene Implementation of {@link QueryTimeout} to exit out long running @@ -463,13 +464,21 @@ public void testVectorValues() throws IOException { ExitingReaderException.class, () -> leaf.searchNearestVectors( - "vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE)); + "vector", + TestVectorUtil.randomVector(dimension), + 5, + leaf.getLiveDocs(), + Integer.MAX_VALUE)); } else { DocIdSetIterator iter = leaf.getFloatVectorValues("vector"); scanAndRetrieve(leaf, iter); leaf.searchNearestVectors( - "vector", new float[dimension], 5, leaf.getLiveDocs(), Integer.MAX_VALUE); + "vector", + TestVectorUtil.randomVector(dimension), + 5, + leaf.getLiveDocs(), + Integer.MAX_VALUE); } reader.close(); diff --git a/lucene/core/src/test/org/apache/lucene/index/TestFieldInfo.java b/lucene/core/src/test/org/apache/lucene/index/TestFieldInfo.java new file mode 100644 index 000000000000..42f2c55fbf88 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/index/TestFieldInfo.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.index; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestFieldInfo extends LuceneTestCase { + + public void testHandleLegacySupportedUpdatesValidIndexInfoChange() { + + FieldInfo fi1 = new FieldInfoTestBuilder().setIndexOptions(IndexOptions.NONE).get(); + FieldInfo fi2 = new FieldInfoTestBuilder().setOmitNorms(true).setStoreTermVector(false).get(); + + FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2); + assertNotNull(updatedFi); + assertEquals(updatedFi.getIndexOptions(), IndexOptions.DOCS); + assertFalse( + updatedFi + .hasNorms()); // fi2 is set to omitNorms and fi1 was not indexed, so it's OK that the final + // FieldInfo returns hasNorms == false + compareAttributes(fi1, fi2, Set.of("getIndexOptions", "hasNorms")); + compareAttributes(fi1, updatedFi, Set.of("getIndexOptions", "hasNorms")); + compareAttributes(fi2, updatedFi, Set.of()); + + // The reverse returns null since fi2 wouldn't change + assertNull(fi2.handleLegacySupportedUpdates(fi1)); + } + + public void testHandleLegacySupportedUpdatesInvalidIndexInfoChange() { + FieldInfo fi1 = new FieldInfoTestBuilder().setIndexOptions(IndexOptions.DOCS).get(); + FieldInfo fi2 = + new FieldInfoTestBuilder().setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS).get(); + assertThrows(IllegalArgumentException.class, () -> fi1.handleLegacySupportedUpdates(fi2)); + assertThrows(IllegalArgumentException.class, () -> fi2.handleLegacySupportedUpdates(fi1)); + } + + public void testHandleLegacySupportedUpdatesValidPointDimensionCount() { + FieldInfo fi1 = new FieldInfoTestBuilder().get(); + FieldInfo fi2 = new FieldInfoTestBuilder().setPointDimensionCount(2).setPointNumBytes(2).get(); + FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2); + assertNotNull(updatedFi); + assertEquals(2, updatedFi.getPointDimensionCount()); + compareAttributes(fi1, fi2, Set.of("getPointDimensionCount", "getPointNumBytes")); + compareAttributes(fi1, updatedFi, Set.of("getPointDimensionCount", "getPointNumBytes")); + compareAttributes(fi2, updatedFi, Set.of()); + + // The reverse returns null since fi2 wouldn't change + assertNull(fi2.handleLegacySupportedUpdates(fi1)); + } + + public void testHandleLegacySupportedUpdatesInvalidPointDimensionCount() { + FieldInfo fi1 = new FieldInfoTestBuilder().setPointDimensionCount(3).setPointNumBytes(2).get(); + FieldInfo fi2 = new FieldInfoTestBuilder().setPointDimensionCount(2).setPointNumBytes(2).get(); + FieldInfo fi3 = new FieldInfoTestBuilder().setPointDimensionCount(2).setPointNumBytes(3).get(); + assertThrows(IllegalArgumentException.class, () -> fi1.handleLegacySupportedUpdates(fi2)); + assertThrows(IllegalArgumentException.class, () -> fi2.handleLegacySupportedUpdates(fi1)); + + assertThrows(IllegalArgumentException.class, () -> fi1.handleLegacySupportedUpdates(fi3)); + assertThrows(IllegalArgumentException.class, () -> fi3.handleLegacySupportedUpdates(fi1)); + + assertThrows(IllegalArgumentException.class, () -> fi2.handleLegacySupportedUpdates(fi3)); + assertThrows(IllegalArgumentException.class, () -> fi3.handleLegacySupportedUpdates(fi2)); + } + + public void testHandleLegacySupportedUpdatesValidPointIndexDimensionCount() { + FieldInfo fi1 = new FieldInfoTestBuilder().get(); + FieldInfo fi2 = + new FieldInfoTestBuilder() + .setPointIndexDimensionCount(2) + .setPointDimensionCount(2) + .setPointNumBytes(2) + .get(); + FieldInfo updatedFi =
fi1.handleLegacySupportedUpdates(fi2); + assertNotNull(updatedFi); + assertEquals(2, updatedFi.getPointDimensionCount()); + compareAttributes( + fi1, + fi2, + Set.of("getPointDimensionCount", "getPointNumBytes", "getPointIndexDimensionCount")); + compareAttributes( + fi1, + updatedFi, + Set.of("getPointDimensionCount", "getPointNumBytes", "getPointIndexDimensionCount")); + compareAttributes(fi2, updatedFi, Set.of()); + + // The reverse returns null since fi2 wouldn't change + assertNull(fi2.handleLegacySupportedUpdates(fi1)); + } + + public void testHandleLegacySupportedUpdatesInvalidPointIndexDimensionCount() { + FieldInfo fi1 = + new FieldInfoTestBuilder() + .setPointDimensionCount(2) + .setPointIndexDimensionCount(2) + .setPointNumBytes(2) + .get(); + FieldInfo fi2 = + new FieldInfoTestBuilder() + .setPointDimensionCount(2) + .setPointIndexDimensionCount(3) + .setPointNumBytes(2) + .get(); + assertThrows(IllegalArgumentException.class, () -> fi1.handleLegacySupportedUpdates(fi2)); + assertThrows(IllegalArgumentException.class, () -> fi2.handleLegacySupportedUpdates(fi1)); + } + + public void testHandleLegacySupportedUpdatesValidStoreTermVectors() { + FieldInfo fi1 = new FieldInfoTestBuilder().setStoreTermVector(false).get(); + FieldInfo fi2 = new FieldInfoTestBuilder().setStoreTermVector(true).get(); + FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2); + assertNotNull(updatedFi); + assertTrue(updatedFi.hasVectors()); + compareAttributes(fi1, fi2, Set.of("hasVectors")); + compareAttributes(fi1, updatedFi, Set.of("hasVectors")); + compareAttributes(fi2, updatedFi, Set.of()); + + // The reverse returns null since fi2 wouldn't change + assertNull(fi2.handleLegacySupportedUpdates(fi1)); + } + + public void testHandleLegacySupportedUpdatesValidStorePayloads() { + FieldInfo fi1 = + new FieldInfoTestBuilder() + .setStorePayloads(false) + .setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS) + .get(); + FieldInfo fi2 = + new FieldInfoTestBuilder() + .setStorePayloads(true) + .setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS) + .get(); + FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2); + assertNotNull(updatedFi); + assertTrue(updatedFi.hasPayloads()); + compareAttributes(fi1, fi2, Set.of("hasPayloads")); + compareAttributes(fi1, updatedFi, Set.of("hasPayloads")); + compareAttributes(fi2, updatedFi, Set.of()); + + // The reverse returns null since fi2 wouldn't change + assertNull(fi2.handleLegacySupportedUpdates(fi1)); + } + + public void testHandleLegacySupportedUpdatesValidOmitNorms() { + FieldInfo fi1 = + new FieldInfoTestBuilder().setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS).get(); + FieldInfo fi2 = + new FieldInfoTestBuilder() + .setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS) + .setOmitNorms(true) + .get(); + FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2); + assertNotNull(updatedFi); + assertFalse(updatedFi.hasNorms()); // Once norms are omitted, they are always omitted + compareAttributes(fi1, fi2, Set.of("hasNorms")); + compareAttributes(fi1, updatedFi, Set.of("hasNorms")); + compareAttributes(fi2, updatedFi, Set.of()); + + // The reverse returns null since fi2 wouldn't change + assertNull(fi2.handleLegacySupportedUpdates(fi1)); + } + + public void testHandleLegacySupportedUpdatesValidDocValuesType() { + + FieldInfo fi1 = new FieldInfoTestBuilder().get(); + FieldInfo fi2 = new FieldInfoTestBuilder().setDocValues(DocValuesType.SORTED).setDvGen(1).get(); + + FieldInfo updatedFi = fi1.handleLegacySupportedUpdates(fi2); +
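// the merged FieldInfo should carry over both the DocValues type and its generation from fi2 +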
assertNotNull(updatedFi); + assertEquals(DocValuesType.SORTED, updatedFi.getDocValuesType()); + assertEquals(1, updatedFi.getDocValuesGen()); + compareAttributes(fi1, fi2, Set.of("getDocValuesType", "getDocValuesGen")); + compareAttributes(fi1, updatedFi, Set.of("getDocValuesType", "getDocValuesGen")); + compareAttributes(fi2, updatedFi, Set.of()); + + // The reverse returns null since fi2 wouldn't change + assertNull(fi2.handleLegacySupportedUpdates(fi1)); + + FieldInfo fi3 = new FieldInfoTestBuilder().setDocValues(DocValuesType.SORTED).setDvGen(2).get(); + // Changes in the DocValues generation alone are ignored + assertNull(fi2.handleLegacySupportedUpdates(fi3)); + assertNull(fi3.handleLegacySupportedUpdates(fi2)); + } + + public void testHandleLegacySupportedUpdatesInvalidDocValuesTypeChange() { + FieldInfo fi1 = new FieldInfoTestBuilder().setDocValues(DocValuesType.SORTED).get(); + FieldInfo fi2 = new FieldInfoTestBuilder().setDocValues(DocValuesType.SORTED_SET).get(); + assertThrows(IllegalArgumentException.class, () -> fi1.handleLegacySupportedUpdates(fi2)); + assertThrows(IllegalArgumentException.class, () -> fi2.handleLegacySupportedUpdates(fi1)); + } + + private static class FieldInfoTestBuilder { + String name = "f1"; + int number = 1; + boolean storeTermVector = true; + boolean omitNorms = false; + boolean storePayloads = false; + IndexOptions indexOptions = IndexOptions.DOCS; + DocValuesType docValues = DocValuesType.NONE; + long dvGen = -1; + Map<String, String> attributes = new HashMap<>(); + int pointDimensionCount = 0; + int pointIndexDimensionCount = 0; + int pointNumBytes = 0; + int vectorDimension = 10; + VectorEncoding vectorEncoding = VectorEncoding.FLOAT32; + VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN; + boolean softDeletesField = false; + + FieldInfo get() { + return new FieldInfo( + name, + number, + storeTermVector, + omitNorms, + storePayloads, + indexOptions, + docValues, + dvGen, + attributes, + pointDimensionCount, + pointIndexDimensionCount, + pointNumBytes, + vectorDimension, + vectorEncoding, + vectorSimilarityFunction, + softDeletesField); + } + + public FieldInfoTestBuilder setStoreTermVector(boolean storeTermVector) { + this.storeTermVector = storeTermVector; + return this; + } + + public FieldInfoTestBuilder setOmitNorms(boolean omitNorms) { + this.omitNorms = omitNorms; + return this; + } + + public FieldInfoTestBuilder setStorePayloads(boolean storePayloads) { + this.storePayloads = storePayloads; + return this; + } + + public FieldInfoTestBuilder setIndexOptions(IndexOptions indexOptions) { + this.indexOptions = indexOptions; + return this; + } + + public FieldInfoTestBuilder setDocValues(DocValuesType docValues) { + this.docValues = docValues; + return this; + } + + public FieldInfoTestBuilder setDvGen(long dvGen) { + this.dvGen = dvGen; + return this; + } + + public FieldInfoTestBuilder setPointDimensionCount(int pointDimensionCount) { + this.pointDimensionCount = pointDimensionCount; + return this; + } + + public FieldInfoTestBuilder setPointIndexDimensionCount(int pointIndexDimensionCount) { + this.pointIndexDimensionCount = pointIndexDimensionCount; + return this; + } + + public FieldInfoTestBuilder setPointNumBytes(int pointNumBytes) { + this.pointNumBytes = pointNumBytes; + return this; + } + } + + private void compareAttributes(FieldInfo fi1, FieldInfo fi2, Set<String> exclude) { + assertNotNull(fi1); + assertNotNull(fi2); + Arrays.stream(FieldInfo.class.getMethods()) + .filter( + m -> + (m.getName().startsWith("get") ||
m.getName().startsWith("has")) + && !m.getName().equals("hashCode") + && !exclude.contains(m.getName()) + && m.getParameterCount() == 0) + .forEach( + m -> { + try { + assertEquals( + "Unexpected difference in FieldInfo for method: " + m.getName(), + m.invoke(fi1), + m.invoke(fi2)); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + }); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java b/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java index b5a9ac5034b7..a4272cf56334 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java @@ -361,4 +361,43 @@ public void testRelaxConsistencyCheckForOldIndices() throws IOException { } } } + + public void testFieldInfosMergeBehaviorOnOldIndices() throws IOException { + try (Directory dir = newDirectory()) { + IndexWriterConfig config = + new IndexWriterConfig() + .setIndexCreatedVersionMajor(8) + .setMergeScheduler(new SerialMergeScheduler()) + .setOpenMode(IndexWriterConfig.OpenMode.CREATE); + FieldType ft1 = new FieldType(); + ft1.setIndexOptions(IndexOptions.NONE); + ft1.setStored(true); + FieldType ft2 = new FieldType(); + ft2.setIndexOptions(IndexOptions.DOCS); + ft2.setStored(true); + + try (IndexWriter writer = new IndexWriter(dir, config)) { + Document d1 = new Document(); + // Document 1 has my_field with IndexOptions.NONE + d1.add(new Field("my_field", "first", ft1)); + writer.addDocument(d1); + for (int i = 0; i < 100; i++) { + // Add some more docs to make sure segment 0 is the biggest one + Document d = new Document(); + d.add(new Field("foo", "bar" + i, ft2)); + writer.addDocument(d); + } + writer.flush(); + + Document d2 = new Document(); + // Document 2 has my_field with IndexOptions.DOCS + d2.add(new Field("my_field", "first", ft2)); + writer.addDocument(d2); + writer.flush(); + + writer.commit(); + writer.forceMerge(1); + } + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestIndexWriter.java b/lucene/core/src/test/org/apache/lucene/index/TestIndexWriter.java index e9a2fd9141dd..9c7b7c8a5290 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestIndexWriter.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestIndexWriter.java @@ -3473,7 +3473,12 @@ public int numDeletesToMerge( Document doc = new Document(); doc.add(new StringField("id", id, Field.Store.YES)); if (mixDeletes && random().nextBoolean()) { - writer.updateDocuments(new Term("id", id), Arrays.asList(doc, doc)); + if (random().nextBoolean()) { + writer.updateDocuments(new Term("id", id), Arrays.asList(doc, doc)); + } else { + writer.updateDocuments( + new TermQuery(new Term("id", id)), Arrays.asList(doc, doc)); + } } else { writer.softUpdateDocuments( new Term("id", id), diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 08f089430ba5..372384df5572 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -54,6 +54,7 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; +import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.junit.After; import org.junit.Before; @@ -62,7 +63,7 @@ public class TestKnnGraph extends LuceneTestCase { private static final String KNN_GRAPH_FIELD = "vector"; - 
private static int M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN; + private static int M = HnswGraphBuilder.DEFAULT_MAX_CONN; private Codec codec; private Codec float32Codec; @@ -80,7 +81,7 @@ public void setup() { new Lucene95Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); } }; @@ -92,7 +93,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { new Lucene95Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); } }; @@ -103,7 +104,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { new Lucene95Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene95HnswVectorsFormat(M, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); } }; } @@ -115,7 +116,7 @@ private VectorEncoding randomVectorEncoding() { @After public void cleanup() { - M = Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN; + M = HnswGraphBuilder.DEFAULT_MAX_CONN; } /** Basic test of creating documents in a graph */ diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index ce1fa6d7abbe..24de1a463c73 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -210,7 +210,10 @@ public void testDimensionMismatch() throws IOException { IndexSearcher searcher = newSearcher(reader); AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10); IllegalArgumentException e = - expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10)); + expectThrows( + RuntimeException.class, + IllegalArgumentException.class, + () -> searcher.search(kvq, 10)); assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage()); } } @@ -236,7 +239,7 @@ public void testDifferentReader() throws IOException { getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); IndexReader reader = DirectoryReader.open(indexStore)) { AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3); - Query dasq = query.rewrite(reader); + Query dasq = query.rewrite(newSearcher(reader)); IndexSearcher leafSearcher = newSearcher(reader.leaves().get(0).reader()); expectThrows( IllegalStateException.class, @@ -256,7 +259,7 @@ public void testAdvanceShallow() throws IOException { try (IndexReader reader = DirectoryReader.open(d)) { IndexSearcher searcher = new IndexSearcher(reader); AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3); - Query dasq = query.rewrite(reader); + Query dasq = query.rewrite(searcher); Scorer scorer = dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0)); // before advancing the iterator @@ -283,7 +286,7 @@ public void testScoreEuclidean() throws IOException { IndexReader reader = DirectoryReader.open(d)) { IndexSearcher searcher = new IndexSearcher(reader); AbstractKnnVectorQuery query = 
getKnnVectorQuery("field", new float[] {2, 3}, 3); - Query rewritten = query.rewrite(reader); + Query rewritten = query.rewrite(searcher); Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1); Scorer scorer = weight.scorer(reader.leaves().get(0)); @@ -322,7 +325,7 @@ public void testScoreCosine() throws IOException { assertEquals(1, reader.leaves().size()); IndexSearcher searcher = new IndexSearcher(reader); AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {2, 3}, 3); - Query rewritten = query.rewrite(reader); + Query rewritten = query.rewrite(searcher); Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1); Scorer scorer = weight.scorer(reader.leaves().get(0)); @@ -529,6 +532,7 @@ public void testRandomWithFilter() throws IOException { assertEquals(9, results.totalHits.value); assertEquals(results.totalHits.value, results.scoreDocs.length); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search( @@ -543,6 +547,7 @@ public void testRandomWithFilter() throws IOException { assertEquals(5, results.totalHits.value); assertEquals(results.totalHits.value, results.scoreDocs.length); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search( @@ -570,6 +575,7 @@ public void testRandomWithFilter() throws IOException { // Test a filter that exhausts visitedLimit in upper levels, and switches to exact search Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search( @@ -742,6 +748,7 @@ public void testBitSetQuery() throws IOException { Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs)); expectThrows( + RuntimeException.class, UnsupportedOperationException.class, () -> searcher.search( diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java index 74badd25e777..87b8068d2f1f 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java @@ -83,7 +83,7 @@ public void testSingleFilterClause() throws IOException { query1.add(new TermQuery(new Term("field", "a")), Occur.FILTER); // Single clauses rewrite to a term query - final Query rewritten1 = query1.build().rewrite(reader); + final Query rewritten1 = query1.build().rewrite(searcher); assertTrue(rewritten1 instanceof BoostQuery); assertEquals(0f, ((BoostQuery) rewritten1).getBoost(), 0f); @@ -946,7 +946,7 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader indexReader) { + public Query rewrite(IndexSearcher indexSearcher) { numRewrites++; return this; } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestFieldExistsQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestFieldExistsQuery.java index 307d6800aa57..209ad510889b 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestFieldExistsQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestFieldExistsQuery.java @@ -63,7 +63,8 @@ public void testDocValuesRewriteWithTermsPresent() throws IOException { final IndexReader reader = iw.getReader(); iw.close(); - assertTrue((new FieldExistsQuery("f")).rewrite(reader) instanceof MatchAllDocsQuery); + assertTrue( + (new FieldExistsQuery("f")).rewrite(newSearcher(reader)) instanceof MatchAllDocsQuery); reader.close(); dir.close(); } @@ 
-82,7 +83,8 @@ public void testDocValuesRewriteWithPointValuesPresent() throws IOException { final IndexReader reader = iw.getReader(); iw.close(); - assertTrue(new FieldExistsQuery("dim").rewrite(reader) instanceof MatchAllDocsQuery); + assertTrue( + new FieldExistsQuery("dim").rewrite(newSearcher(reader)) instanceof MatchAllDocsQuery); reader.close(); dir.close(); } @@ -106,9 +108,10 @@ public void testDocValuesNoRewrite() throws IOException { iw.commit(); final IndexReader reader = iw.getReader(); iw.close(); + final IndexSearcher searcher = newSearcher(reader); - assertFalse((new FieldExistsQuery("dim")).rewrite(reader) instanceof MatchAllDocsQuery); - assertFalse((new FieldExistsQuery("f")).rewrite(reader) instanceof MatchAllDocsQuery); + assertFalse((new FieldExistsQuery("dim")).rewrite(searcher) instanceof MatchAllDocsQuery); + assertFalse((new FieldExistsQuery("f")).rewrite(searcher) instanceof MatchAllDocsQuery); reader.close(); dir.close(); } @@ -127,10 +130,11 @@ public void testDocValuesNoRewriteWithDocValues() throws IOException { iw.commit(); final IndexReader reader = iw.getReader(); iw.close(); + final IndexSearcher searcher = newSearcher(reader); - assertFalse((new FieldExistsQuery("dv1")).rewrite(reader) instanceof MatchAllDocsQuery); - assertFalse((new FieldExistsQuery("dv2")).rewrite(reader) instanceof MatchAllDocsQuery); - assertFalse((new FieldExistsQuery("dv3")).rewrite(reader) instanceof MatchAllDocsQuery); + assertFalse((new FieldExistsQuery("dv1")).rewrite(searcher) instanceof MatchAllDocsQuery); + assertFalse((new FieldExistsQuery("dv2")).rewrite(searcher) instanceof MatchAllDocsQuery); + assertFalse((new FieldExistsQuery("dv3")).rewrite(searcher) instanceof MatchAllDocsQuery); reader.close(); dir.close(); } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestIndexSearcher.java b/lucene/core/src/test/org/apache/lucene/search/TestIndexSearcher.java index c3bcda9d47a6..7e8693bd8b31 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestIndexSearcher.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestIndexSearcher.java @@ -453,20 +453,15 @@ private void runSliceExecutorTest(ThreadPoolExecutor service, boolean useRandomS } } - private class RandomBlockingSliceExecutor extends SliceExecutor { + private static class RandomBlockingSliceExecutor extends SliceExecutor { - public RandomBlockingSliceExecutor(Executor executor) { + RandomBlockingSliceExecutor(Executor executor) { super(executor); } @Override - public void invokeAll(Collection tasks) { - - for (Runnable task : tasks) { - boolean shouldExecuteOnCallerThread = random().nextBoolean(); - - processTask(task, shouldExecuteOnCallerThread); - } + boolean shouldExecuteOnCallerThread(int index, int numTasks) { + return random().nextBoolean(); } } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestIndexSortSortedNumericDocValuesRangeQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestIndexSortSortedNumericDocValuesRangeQuery.java index 8df8dcb1917e..cc5d8800e8a2 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestIndexSortSortedNumericDocValuesRangeQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestIndexSortSortedNumericDocValuesRangeQuery.java @@ -339,7 +339,7 @@ public void testRewriteExhaustiveRange() throws IOException { IndexReader reader = writer.getReader(); Query query = createQuery("field", Long.MIN_VALUE, Long.MAX_VALUE); - Query rewrittenQuery = query.rewrite(reader); + Query rewrittenQuery = query.rewrite(newSearcher(reader)); 
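+ // a range spanning all possible long values should rewrite to a plain FieldExistsQuery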
assertEquals(new FieldExistsQuery("field"), rewrittenQuery); writer.close(); @@ -357,7 +357,7 @@ public void testRewriteFallbackQuery() throws IOException { Query fallbackQuery = new BooleanQuery.Builder().build(); Query query = new IndexSortSortedNumericDocValuesRangeQuery("field", 1, 42, fallbackQuery); - Query rewrittenQuery = query.rewrite(reader); + Query rewrittenQuery = query.rewrite(newSearcher(reader)); assertNotEquals(query, rewrittenQuery); assertThat(rewrittenQuery, instanceOf(IndexSortSortedNumericDocValuesRangeQuery.class)); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index 684e260f12b7..8467ac506ae4 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -95,7 +95,7 @@ public void testScoreNegativeDotProduct() throws IOException { assertEquals(1, reader.leaves().size()); IndexSearcher searcher = new IndexSearcher(reader); AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {1, 0}, 2); - Query rewritten = query.rewrite(reader); + Query rewritten = query.rewrite(searcher); Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1); Scorer scorer = weight.scorer(reader.leaves().get(0)); @@ -126,7 +126,7 @@ public void testScoreDotProduct() throws IOException { IndexSearcher searcher = new IndexSearcher(reader); AbstractKnnVectorQuery query = getKnnVectorQuery("field", VectorUtil.l2normalize(new float[] {2, 3}), 3); - Query rewritten = query.rewrite(reader); + Query rewritten = query.rewrite(searcher); Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1); Scorer scorer = weight.scorer(reader.leaves().get(0)); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestMatchNoDocsQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestMatchNoDocsQuery.java index e1facc81e61a..43db42cc1d34 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestMatchNoDocsQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestMatchNoDocsQuery.java @@ -44,7 +44,7 @@ public void testSimple() throws Exception { assertEquals(query.toString(), "MatchNoDocsQuery(\"\")"); query = new MatchNoDocsQuery("field 'title' not found"); assertEquals(query.toString(), "MatchNoDocsQuery(\"field 'title' not found\")"); - Query rewrite = query.rewrite(null); + Query rewrite = query.rewrite((IndexSearcher) null); assertTrue(rewrite instanceof MatchNoDocsQuery); assertEquals(rewrite.toString(), "MatchNoDocsQuery(\"field 'title' not found\")"); } @@ -87,7 +87,7 @@ public void testQuery() throws Exception { assertEquals(query.toString(), "key:one MatchNoDocsQuery(\"field not found\")"); assertEquals(searcher.count(query), 1); hits = searcher.search(query, 1000).scoreDocs; - Query rewrite = query.rewrite(ir); + Query rewrite = query.rewrite(searcher); assertEquals(1, hits.length); assertEquals(rewrite.toString(), "key:one"); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestMatchesIterator.java b/lucene/core/src/test/org/apache/lucene/search/TestMatchesIterator.java index b62d1ec69f15..33acb57c8a5d 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestMatchesIterator.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestMatchesIterator.java @@ -499,7 +499,7 @@ public void testMinimalSeekingWithWildcards() throws IOException { SeekCountingLeafReader reader = new 
SeekCountingLeafReader(getOnlyLeafReader(this.reader)); this.searcher = new IndexSearcher(reader); Query query = new PrefixQuery(new Term(FIELD_WITH_OFFSETS, "w")); - Weight w = searcher.createWeight(query.rewrite(reader), ScoreMode.COMPLETE, 1); + Weight w = searcher.createWeight(query.rewrite(searcher), ScoreMode.COMPLETE, 1); // docs 0-3 match several different terms here, but we only seek to the first term and // then short-cut return; other terms are ignored until we try and iterate over matches diff --git a/lucene/core/src/test/org/apache/lucene/search/TestNGramPhraseQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestNGramPhraseQuery.java index f42a576ce54f..905b43b93ed1 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestNGramPhraseQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestNGramPhraseQuery.java @@ -29,6 +29,7 @@ public class TestNGramPhraseQuery extends LuceneTestCase { private static IndexReader reader; private static Directory directory; + private static IndexSearcher searcher; @BeforeClass public static void beforeClass() throws Exception { @@ -36,6 +37,7 @@ public static void beforeClass() throws Exception { RandomIndexWriter writer = new RandomIndexWriter(random(), directory); writer.close(); reader = DirectoryReader.open(directory); + searcher = new IndexSearcher(reader); } @AfterClass @@ -50,8 +52,8 @@ public void testRewrite() throws Exception { // bi-gram test ABC => AB/BC => AB/BC NGramPhraseQuery pq1 = new NGramPhraseQuery(2, new PhraseQuery("f", "AB", "BC")); - Query q = pq1.rewrite(reader); - assertSame(q.rewrite(reader), q); + Query q = pq1.rewrite(searcher); + assertSame(q.rewrite(searcher), q); PhraseQuery rewritten1 = (PhraseQuery) q; assertArrayEquals(new Term[] {new Term("f", "AB"), new Term("f", "BC")}, rewritten1.getTerms()); assertArrayEquals(new int[] {0, 1}, rewritten1.getPositions()); @@ -59,7 +61,7 @@ public void testRewrite() throws Exception { // bi-gram test ABCD => AB/BC/CD => AB//CD NGramPhraseQuery pq2 = new NGramPhraseQuery(2, new PhraseQuery("f", "AB", "BC", "CD")); - q = pq2.rewrite(reader); + q = pq2.rewrite(searcher); assertTrue(q instanceof PhraseQuery); assertNotSame(pq2, q); PhraseQuery rewritten2 = (PhraseQuery) q; @@ -70,7 +72,7 @@ public void testRewrite() throws Exception { NGramPhraseQuery pq3 = new NGramPhraseQuery(3, new PhraseQuery("f", "ABC", "BCD", "CDE", "DEF", "EFG", "FGH")); - q = pq3.rewrite(reader); + q = pq3.rewrite(searcher); assertTrue(q instanceof PhraseQuery); assertNotSame(pq3, q); PhraseQuery rewritten3 = (PhraseQuery) q; diff --git a/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java b/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java index aa2f3c892122..03b585710868 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestNeedsScores.java @@ -139,10 +139,10 @@ public Scorer scorer(LeafReaderContext context) throws IOException { } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query in2 = in.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query in2 = in.rewrite(indexSearcher); if (in2 == in) { - return super.rewrite(reader); + return super.rewrite(indexSearcher); } else { return new AssertNeedsScores(in2, value); } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPhraseQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestPhraseQuery.java index 
da93b3a19510..f721f5e4fe21 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestPhraseQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestPhraseQuery.java @@ -567,7 +567,7 @@ public void testEmptyPhraseQuery() throws Throwable { /* test that a single term is rewritten to a term query */ public void testRewrite() throws IOException { PhraseQuery pq = new PhraseQuery("foo", "bar"); - Query rewritten = pq.rewrite(searcher.getIndexReader()); + Query rewritten = pq.rewrite(searcher); assertTrue(rewritten instanceof TermQuery); } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestQueryRewriteBackwardsCompatibility.java b/lucene/core/src/test/org/apache/lucene/search/TestQueryRewriteBackwardsCompatibility.java new file mode 100644 index 000000000000..a9534b7f191e --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestQueryRewriteBackwardsCompatibility.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestQueryRewriteBackwardsCompatibility extends LuceneTestCase { + + public void testQueryRewriteNoOverrides() throws IOException { + Directory directory = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), directory, newIndexWriterConfig()); + IndexReader reader = w.getReader(); + w.close(); + IndexSearcher searcher = newSearcher(reader); + Query queryNoOverrides = new TestQueryNoOverrides(); + assertSame(queryNoOverrides, searcher.rewrite(queryNoOverrides)); + assertSame(queryNoOverrides, queryNoOverrides.rewrite(searcher)); + assertSame(queryNoOverrides, queryNoOverrides.rewrite(reader)); + reader.close(); + directory.close(); + } + + public void testSingleQueryRewrite() throws IOException { + Directory directory = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), directory, newIndexWriterConfig()); + IndexReader reader = w.getReader(); + w.close(); + IndexSearcher searcher = newSearcher(reader); + + RewriteCountingQuery oldQuery = new OldQuery(null); + RewriteCountingQuery newQuery = new NewQuery(null); + + oldQuery.rewrite(searcher); + oldQuery.rewrite(reader); + + newQuery.rewrite(searcher); + newQuery.rewrite(reader); + + assertEquals(2, oldQuery.rewriteCount); + assertEquals(2, newQuery.rewriteCount); + reader.close(); + directory.close(); + } + + public void testNestedQueryRewrite() throws IOException { + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig()); + 
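// exercises a random chain of old- and new-style queries; each level must see exactly one rewrite per rewrite(...) call +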
IndexReader reader = w.getReader(); + w.close(); + IndexSearcher searcher = newSearcher(reader); + + RewriteCountingQuery query = random().nextBoolean() ? new NewQuery(null) : new OldQuery(null); + + for (int i = 0; i < 5 + random().nextInt(5); i++) { + query = random().nextBoolean() ? new NewQuery(query) : new OldQuery(query); + } + + query.rewrite(searcher); + query.rewrite(reader); + + RewriteCountingQuery innerQuery = query; + while (innerQuery != null) { + assertEquals(2, innerQuery.rewriteCount); + innerQuery = innerQuery.getInnerQuery(); + } + reader.close(); + dir.close(); + } + + public void testRewriteQueryInheritance() throws IOException { + Directory directory = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), directory, newIndexWriterConfig()); + IndexReader reader = w.getReader(); + w.close(); + IndexSearcher searcher = newSearcher(reader); + NewRewritableCallingSuper oneRewrite = new NewRewritableCallingSuper(); + NewRewritableCallingSuper twoRewrites = new OldNewRewritableCallingSuper(); + NewRewritableCallingSuper threeRewrites = new OldOldNewRewritableCallingSuper(); + + searcher.rewrite(oneRewrite); + searcher.rewrite(twoRewrites); + searcher.rewrite(threeRewrites); + assertEquals(1, oneRewrite.rewriteCount); + assertEquals(2, twoRewrites.rewriteCount); + assertEquals(3, threeRewrites.rewriteCount); + + reader.close(); + directory.close(); + } + + private static class NewRewritableCallingSuper extends RewriteCountingQuery { + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + rewriteCount++; + return super.rewrite(searcher); + } + + @Override + public String toString(String field) { + return "NewRewritableCallingSuper"; + } + + @Override + public void visit(QueryVisitor visitor) {} + + @Override + public boolean equals(Object obj) { + return obj instanceof NewRewritableCallingSuper; + } + + @Override + public int hashCode() { + return 1; + } + + @Override + RewriteCountingQuery getInnerQuery() { + return null; + } + } + + private static class OldNewRewritableCallingSuper extends NewRewritableCallingSuper { + @Override + public Query rewrite(IndexReader reader) throws IOException { + rewriteCount++; + return super.rewrite(reader); + } + } + + private static class OldOldNewRewritableCallingSuper extends OldNewRewritableCallingSuper { + @Override + public Query rewrite(IndexReader reader) throws IOException { + rewriteCount++; + return super.rewrite(reader); + } + } + + private abstract static class RewriteCountingQuery extends Query { + int rewriteCount = 0; + + abstract RewriteCountingQuery getInnerQuery(); + } + + private static class OldQuery extends RewriteCountingQuery { + private final RewriteCountingQuery optionalSubQuery; + + private OldQuery(RewriteCountingQuery optionalSubQuery) { + this.optionalSubQuery = optionalSubQuery; + } + + @Override + public Query rewrite(IndexReader reader) throws IOException { + if (this.optionalSubQuery != null) { + this.optionalSubQuery.rewrite(reader); + } + rewriteCount++; + return this; + } + + @Override + public String toString(String field) { + return "OldQuery"; + } + + @Override + public void visit(QueryVisitor visitor) {} + + @Override + public boolean equals(Object obj) { + return obj instanceof OldQuery + && Objects.equals(((OldQuery) obj).optionalSubQuery, optionalSubQuery); + } + + @Override + public int hashCode() { + return 42 ^ Objects.hash(optionalSubQuery); + } + + @Override + RewriteCountingQuery getInnerQuery() { + return optionalSubQuery; + } + } + + private 
static class NewQuery extends RewriteCountingQuery { + private final RewriteCountingQuery optionalSubQuery; + + private NewQuery(RewriteCountingQuery optionalSubQuery) { + this.optionalSubQuery = optionalSubQuery; + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + if (this.optionalSubQuery != null) { + this.optionalSubQuery.rewrite(searcher); + } + rewriteCount++; + return this; + } + + @Override + public String toString(String field) { + return "NewQuery"; + } + + @Override + public void visit(QueryVisitor visitor) {} + + @Override + public boolean equals(Object obj) { + return obj instanceof NewQuery + && Objects.equals(((NewQuery) obj).optionalSubQuery, optionalSubQuery); + } + + @Override + public int hashCode() { + return 73 ^ Objects.hash(optionalSubQuery); + } + + @Override + RewriteCountingQuery getInnerQuery() { + return optionalSubQuery; + } + } + + private static class TestQueryNoOverrides extends Query { + + private final int randomHash = random().nextInt(); + + @Override + public String toString(String field) { + return "TestQueryNoOverrides"; + } + + @Override + public void visit(QueryVisitor visitor) {} + + @Override + public boolean equals(Object obj) { + return obj instanceof TestQueryNoOverrides + && randomHash == ((TestQueryNoOverrides) obj).randomHash; + } + + @Override + public int hashCode() { + return randomHash; + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java index 123307e75f8e..ae026a87367f 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestSynonymQuery.java @@ -73,6 +73,18 @@ public void testEquals() { .addTerm(new Term("field", "c"), 0.2f) .addTerm(new Term("field", "d")) .build()); + + QueryUtils.checkUnequal( + new SynonymQuery.Builder("field").addTerm(new Term("field", "a"), 0.4f).build(), + new SynonymQuery.Builder("field").addTerm(new Term("field", "b"), 0.4f).build()); + + QueryUtils.checkUnequal( + new SynonymQuery.Builder("field").addTerm(new Term("field", "a"), 0.2f).build(), + new SynonymQuery.Builder("field").addTerm(new Term("field", "a"), 0.4f).build()); + + QueryUtils.checkUnequal( + new SynonymQuery.Builder("field1").addTerm(new Term("field1", "b"), 0.4f).build(), + new SynonymQuery.Builder("field2").addTerm(new Term("field2", "b"), 0.4f).build()); } public void testBogusParams() { @@ -127,6 +139,12 @@ public void testBogusParams() { () -> { new SynonymQuery.Builder("field1").addTerm(new Term("field1", "a"), -0f); }); + + expectThrows( + NullPointerException.class, + () -> new SynonymQuery.Builder(null).addTerm(new Term("field1", "a"), -0f)); + + expectThrows(NullPointerException.class, () -> new SynonymQuery.Builder(null).build()); } public void testToString() { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java index 936ae8f5834a..abc6d7138f36 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java @@ -950,12 +950,12 @@ public int hashCode() { } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query rewritten = query.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query rewritten = query.rewrite(indexSearcher); if (rewritten != query) { return new 
MaxScoreWrapperQuery(rewritten, maxRange, maxScore); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/store/TestBufferedIndexInput.java b/lucene/core/src/test/org/apache/lucene/store/TestBufferedIndexInput.java index 546a22493975..d19b2058e07b 100644 --- a/lucene/core/src/test/org/apache/lucene/store/TestBufferedIndexInput.java +++ b/lucene/core/src/test/org/apache/lucene/store/TestBufferedIndexInput.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Random; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.ArrayUtil; @@ -142,6 +143,72 @@ public void testEOF() throws Exception { }); } + // Test that when reading backwards, we page backwards rather than refilling + // on every call + public void testBackwardsByteReads() throws IOException { + MyBufferedIndexInput input = new MyBufferedIndexInput(1024 * 8); + for (int i = 2048; i > 0; i -= random().nextInt(16)) { + assertEquals(byten(i), input.readByte(i)); + } + assertEquals(3, input.readCount); + } + + public void testBackwardsShortReads() throws IOException { + MyBufferedIndexInput input = new MyBufferedIndexInput(1024 * 8); + ByteBuffer bb = ByteBuffer.allocate(2); + bb.order(ByteOrder.LITTLE_ENDIAN); + for (int i = 2048; i > 0; i -= (random().nextInt(16) + 1)) { + bb.clear(); + bb.put(byten(i)); + bb.put(byten(i + 1)); + assertEquals(bb.getShort(0), input.readShort(i)); + } + // readCount can be three or four, depending on whether or not we had to adjust the bufferStart + // to include a whole short + assertTrue( + "Expected 4 or 3, got " + input.readCount, input.readCount == 4 || input.readCount == 3); + } + + public void testBackwardsIntReads() throws IOException { + MyBufferedIndexInput input = new MyBufferedIndexInput(1024 * 8); + ByteBuffer bb = ByteBuffer.allocate(4); + bb.order(ByteOrder.LITTLE_ENDIAN); + for (int i = 2048; i > 0; i -= (random().nextInt(16) + 3)) { + bb.clear(); + bb.put(byten(i)); + bb.put(byten(i + 1)); + bb.put(byten(i + 2)); + bb.put(byten(i + 3)); + assertEquals(bb.getInt(0), input.readInt(i)); + } + // readCount can be three or four, depending on whether or not we had to adjust the bufferStart + // to include a whole int + assertTrue( + "Expected 4 or 3, got " + input.readCount, input.readCount == 4 || input.readCount == 3); + } + + public void testBackwardsLongReads() throws IOException { + MyBufferedIndexInput input = new MyBufferedIndexInput(1024 * 8); + ByteBuffer bb = ByteBuffer.allocate(8); + bb.order(ByteOrder.LITTLE_ENDIAN); + for (int i = 2048; i > 0; i -= (random().nextInt(16) + 7)) { + bb.clear(); + bb.put(byten(i)); + bb.put(byten(i + 1)); + bb.put(byten(i + 2)); + bb.put(byten(i + 3)); + bb.put(byten(i + 4)); + bb.put(byten(i + 5)); + bb.put(byten(i + 6)); + bb.put(byten(i + 7)); + assertEquals(bb.getLong(0), input.readLong(i)); + } + // readCount can be three or four, depending on whether or not we had to adjust the bufferStart + // to include a whole long + assertTrue( + "Expected 4 or 3, got " + input.readCount, input.readCount == 4 || input.readCount == 3); + } + // byten emulates a file - byten(n) returns the n'th byte in that file. // MyBufferedIndexInput reads this "file". 
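An aside before the byten() helper the tests above rely on: the four backwards-read tests all assert the same paging guarantee. The sketch below is illustrative only, not part of the patch; it assumes, like the test's MyBufferedIndexInput, that readInternal/seekInternal plus length/close are the only abstract hooks a BufferedIndexInput subclass must supply.

import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.store.BufferedIndexInput;

final class BackwardsPagingDemo extends BufferedIndexInput {
  private long pos;      // position in the synthetic "file"
  private long refills;  // how many times readInternal actually ran
  private final long len;

  BackwardsPagingDemo(long len) {
    super("BackwardsPagingDemo", BufferedIndexInput.BUFFER_SIZE);
    this.len = len;
  }

  @Override
  protected void readInternal(ByteBuffer b) {
    refills++; // one refill per buffered page, not per positional read
    while (b.hasRemaining()) {
      b.put((byte) (pos++ & 0x3F)); // synthetic, deterministic contents
    }
  }

  @Override
  protected void seekInternal(long p) {
    pos = p;
  }

  @Override
  public long length() {
    return len;
  }

  @Override
  public void close() {}

  public static void main(String[] args) throws IOException {
    BackwardsPagingDemo in = new BackwardsPagingDemo(8 * 1024);
    for (long i = 2048; i > 0; i -= 13) {
      in.readByte(i); // positional read, walking backwards through two pages
    }
    // Before this change every backwards step could force a refill; with
    // backwards paging the count stays a small constant (the tests pin it at 3 or 4).
    System.out.println("refills = " + in.refills);
  }
}

The design point mirrors the readCount assertions above: the buffer start slides back a whole page (adjusted so a multi-byte value fits), so a backwards scan costs on the order of fileLength/bufferSize refills rather than one per read.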
private static byte byten(long n) { @@ -150,7 +217,8 @@ private static byte byten(long n) { private static class MyBufferedIndexInput extends BufferedIndexInput { private long pos; - private long len; + private final long len; + private long readCount = 0; public MyBufferedIndexInput(long len) { super("MyBufferedIndexInput(len=" + len + ")", BufferedIndexInput.BUFFER_SIZE); @@ -164,14 +232,15 @@ public MyBufferedIndexInput() { } @Override - protected void readInternal(ByteBuffer b) throws IOException { + protected void readInternal(ByteBuffer b) { + readCount++; while (b.hasRemaining()) { b.put(byten(pos++)); } } @Override - protected void seekInternal(long pos) throws IOException { + protected void seekInternal(long pos) { this.pos = pos; } diff --git a/lucene/core/src/test/org/apache/lucene/store/TestMmapDirectory.java b/lucene/core/src/test/org/apache/lucene/store/TestMmapDirectory.java index dc638f6a529e..5597161a3a98 100644 --- a/lucene/core/src/test/org/apache/lucene/store/TestMmapDirectory.java +++ b/lucene/core/src/test/org/apache/lucene/store/TestMmapDirectory.java @@ -48,9 +48,9 @@ private static boolean isMemorySegmentImpl() { public void testCorrectImplementation() { final int runtimeVersion = Runtime.version().feature(); - if (runtimeVersion == 19 || runtimeVersion == 20) { + if (runtimeVersion >= 19 && runtimeVersion <= 21) { assertTrue( - "on Java 19 and Java 20 we should use MemorySegmentIndexInputProvider to create mmap IndexInputs", + "on Java 19, 20, and 21 we should use MemorySegmentIndexInputProvider to create mmap IndexInputs", isMemorySegmentImpl()); } else { assertSame(MappedByteBufferIndexInputProvider.class, MMapDirectory.PROVIDER.getClass()); diff --git a/lucene/core/src/test/org/apache/lucene/util/TestUnicodeUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestUnicodeUtil.java index 8e0348b706c6..2f309e945b74 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestUnicodeUtil.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestUnicodeUtil.java @@ -166,6 +166,28 @@ public void testUTF8toUTF32() { } } + public void testUTF8CodePointAt() { + int num = atLeast(50000); + UnicodeUtil.UTF8CodePoint reuse = null; + for (int i = 0; i < num; i++) { + final String s = TestUtil.randomUnicodeString(random()); + final byte[] utf8 = new byte[UnicodeUtil.maxUTF8Length(s.length())]; + final int utf8Len = UnicodeUtil.UTF16toUTF8(s, 0, s.length(), utf8); + + int[] expected = s.codePoints().toArray(); + int pos = 0; + int expectedUpto = 0; + while (pos < utf8Len) { + reuse = UnicodeUtil.codePointAt(utf8, pos, reuse); + assertEquals(expected[expectedUpto], reuse.codePoint); + expectedUpto++; + pos += reuse.numBytes; + } + assertEquals(utf8Len, pos); + assertEquals(expected.length, expectedUpto); + } + } + public void testNewString() { final int[] codePoints = { Character.toCodePoint(Character.MIN_HIGH_SURROGATE, Character.MAX_LOW_SURROGATE), diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtilProviders.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtilProviders.java new file mode 100644 index 000000000000..be1205890bd8 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtilProviders.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import java.util.function.ToDoubleFunction; +import java.util.function.ToIntFunction; +import java.util.stream.IntStream; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.junit.BeforeClass; + +public class TestVectorUtilProviders extends LuceneTestCase { + + private static final double DELTA = 1e-3; + private static final VectorUtilProvider LUCENE_PROVIDER = new VectorUtilDefaultProvider(); + private static final VectorUtilProvider JDK_PROVIDER = VectorUtilProvider.lookup(true); + + private static final int[] VECTOR_SIZES = { + 1, 4, 6, 8, 13, 16, 25, 32, 64, 100, 128, 207, 256, 300, 512, 702, 1024 + }; + + private final int size; + + public TestVectorUtilProviders(int size) { + this.size = size; + } + + @ParametersFactory + public static Iterable parametersFactory() { + return () -> IntStream.of(VECTOR_SIZES).boxed().map(i -> new Object[] {i}).iterator(); + } + + @BeforeClass + public static void beforeClass() throws Exception { + assumeFalse( + "Test only works when JDK's vector incubator module is enabled.", + JDK_PROVIDER instanceof VectorUtilDefaultProvider); + } + + public void testFloatVectors() { + var a = new float[size]; + var b = new float[size]; + for (int i = 0; i < size; ++i) { + a[i] = random().nextFloat(); + b[i] = random().nextFloat(); + } + assertFloatReturningProviders(p -> p.dotProduct(a, b)); + assertFloatReturningProviders(p -> p.squareDistance(a, b)); + assertFloatReturningProviders(p -> p.cosine(a, b)); + } + + public void testBinaryVectors() { + var a = new byte[size]; + var b = new byte[size]; + random().nextBytes(a); + random().nextBytes(b); + assertIntReturningProviders(p -> p.dotProduct(a, b)); + assertIntReturningProviders(p -> p.squareDistance(a, b)); + assertFloatReturningProviders(p -> p.cosine(a, b)); + } + + private void assertFloatReturningProviders(ToDoubleFunction func) { + assertEquals(func.applyAsDouble(LUCENE_PROVIDER), func.applyAsDouble(JDK_PROVIDER), DELTA); + } + + private void assertIntReturningProviders(ToIntFunction func) { + assertEquals(func.applyAsInt(LUCENE_PROVIDER), func.applyAsInt(JDK_PROVIDER)); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/util/automaton/TestCompiledAutomaton.java b/lucene/core/src/test/org/apache/lucene/util/automaton/TestCompiledAutomaton.java index 5d63b8051cd0..4640bea2dc2c 100644 --- a/lucene/core/src/test/org/apache/lucene/util/automaton/TestCompiledAutomaton.java +++ b/lucene/core/src/test/org/apache/lucene/util/automaton/TestCompiledAutomaton.java @@ -35,7 +35,7 @@ private CompiledAutomaton build(int determinizeWorkLimit, String... 
strings) { terms.add(new BytesRef(s)); } Collections.sort(terms); - final Automaton a = DaciukMihovAutomatonBuilder.build(terms); + final Automaton a = Automata.makeStringUnion(terms); return new CompiledAutomaton(a, true, false, determinizeWorkLimit, false); } diff --git a/lucene/core/src/test/org/apache/lucene/util/automaton/TestDaciukMihovAutomatonBuilder.java b/lucene/core/src/test/org/apache/lucene/util/automaton/TestDaciukMihovAutomatonBuilder.java index 4b31eda98c6d..d72e08f18e98 100644 --- a/lucene/core/src/test/org/apache/lucene/util/automaton/TestDaciukMihovAutomatonBuilder.java +++ b/lucene/core/src/test/org/apache/lucene/util/automaton/TestDaciukMihovAutomatonBuilder.java @@ -16,26 +16,185 @@ */ package org.apache.lucene.util.automaton; +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.generators.RandomNumbers; +import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefBuilder; +import org.apache.lucene.util.BytesRefIterator; +import org.apache.lucene.util.IntsRef; +import org.apache.lucene.util.fst.Util; public class TestDaciukMihovAutomatonBuilder extends LuceneTestCase { - public void testLargeTerms() { + public void testBasic() throws Exception { + List<BytesRef> terms = basicTerms(); + Collections.sort(terms); + + Automaton a = build(terms, false); + checkAutomaton(terms, a, false); + checkMinimized(a); + } + + public void testBasicBinary() throws Exception { + List<BytesRef> terms = basicTerms(); + Collections.sort(terms); + + Automaton a = build(terms, true); + checkAutomaton(terms, a, true); + checkMinimized(a); + } + + public void testRandomMinimized() throws Exception { + int iters = RandomizedTest.isNightly() ?
20 : 5; + for (int i = 0; i < iters; i++) { + boolean buildBinary = random().nextBoolean(); + int size = RandomNumbers.randomIntBetween(random(), 2, 50); + Set<BytesRef> terms = new HashSet<>(); + List<Automaton> automatonList = new ArrayList<>(size); + for (int j = 0; j < size; j++) { + if (buildBinary) { + BytesRef t = TestUtil.randomBinaryTerm(random(), 8); + terms.add(t); + automatonList.add(Automata.makeBinary(t)); + } else { + String s = TestUtil.randomRealisticUnicodeString(random(), 8); + terms.add(newBytesRef(s)); + automatonList.add(Automata.makeString(s)); + } + } + List<BytesRef> sortedTerms = terms.stream().sorted().collect(Collectors.toList()); + + Automaton expected = + MinimizationOperations.minimize( + Operations.union(automatonList), Operations.DEFAULT_DETERMINIZE_WORK_LIMIT); + Automaton actual = build(sortedTerms, buildBinary); + assertSameAutomaton(expected, actual); + } + } + + public void testRandomUnicodeOnly() throws Exception { + testRandom(false); + } + + public void testRandomBinary() throws Exception { + testRandom(true); + } + + public void testLargeTerms() throws Exception { byte[] b10k = new byte[10_000]; Arrays.fill(b10k, (byte) 'a'); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> DaciukMihovAutomatonBuilder.build(Collections.singleton(new BytesRef(b10k)))); + () -> build(Collections.singleton(new BytesRef(b10k)), false)); assertTrue( e.getMessage() - .startsWith("This builder doesn't allow terms that are larger than 1,000 characters")); + .startsWith( + "This builder doesn't allow terms that are larger than " + + Automata.MAX_STRING_UNION_TERM_LENGTH + + " characters")); byte[] b1k = ArrayUtil.copyOfSubArray(b10k, 0, 1000); - DaciukMihovAutomatonBuilder.build(Collections.singleton(new BytesRef(b1k))); // no exception + build(Collections.singleton(new BytesRef(b1k)), false); // no exception + } + + private void testRandom(boolean allowBinary) throws Exception { + int iters = RandomizedTest.isNightly() ? 50 : 10; + for (int i = 0; i < iters; i++) { + int size = RandomNumbers.randomIntBetween(random(), 500, 2_000); + Set<BytesRef> terms = new HashSet<>(size); + for (int j = 0; j < size; j++) { + if (allowBinary && random().nextInt(10) < 2) { + // Sometimes use a random binary term that isn't necessarily valid unicode + terms.add(newBytesRef(TestUtil.randomBinaryTerm(random()))); + } else { + terms.add(newBytesRef(TestUtil.randomRealisticUnicodeString(random()))); + } + } + + List<BytesRef> sorted = terms.stream().sorted().collect(Collectors.toList()); + Automaton a = build(sorted, allowBinary); + checkAutomaton(sorted, a, allowBinary); + } + } + + private void checkAutomaton(List<BytesRef> expected, Automaton a, boolean isBinary) { + CompiledAutomaton c = + new CompiledAutomaton(a, true, false, Operations.DEFAULT_DETERMINIZE_WORK_LIMIT, isBinary); + ByteRunAutomaton runAutomaton = c.runAutomaton; + + // Make sure every expected term is accepted + for (BytesRef t : expected) { + String readable = isBinary ?
t.toString() : t.utf8ToString(); + assertTrue( + readable + " should be found but wasn't", runAutomaton.run(t.bytes, t.offset, t.length)); + } + + // Make sure every term produced by the automaton is expected + BytesRefBuilder scratch = new BytesRefBuilder(); + FiniteStringsIterator it = new FiniteStringsIterator(c.automaton); + for (IntsRef r = it.next(); r != null; r = it.next()) { + BytesRef t = Util.toBytesRef(r, scratch); + assertTrue(expected.contains(t)); + } + } + + private void checkMinimized(Automaton a) { + Automaton minimized = + MinimizationOperations.minimize(a, Operations.DEFAULT_DETERMINIZE_WORK_LIMIT); + assertSameAutomaton(minimized, a); + } + + private static void assertSameAutomaton(Automaton a, Automaton b) { + assertEquals(a.getNumStates(), b.getNumStates()); + assertEquals(a.getNumTransitions(), b.getNumTransitions()); + assertTrue(Operations.sameLanguage(a, b)); + } + + private List<BytesRef> basicTerms() { + List<BytesRef> terms = new ArrayList<>(); + terms.add(newBytesRef("dog")); + terms.add(newBytesRef("day")); + terms.add(newBytesRef("dad")); + terms.add(newBytesRef("cats")); + terms.add(newBytesRef("cat")); + return terms; + } + + private Automaton build(Collection<BytesRef> terms, boolean asBinary) throws IOException { + if (random().nextBoolean()) { + return DaciukMihovAutomatonBuilder.build(terms, asBinary); + } else { + return DaciukMihovAutomatonBuilder.build(new TermIterator(terms), asBinary); + } + } + + private static final class TermIterator implements BytesRefIterator { + private final Iterator<BytesRef> it; + + TermIterator(Collection<BytesRef> terms) { + this.it = terms.iterator(); + } + + @Override + public BytesRef next() throws IOException { + if (it.hasNext() == false) { + return null; + } + return it.next(); + } } } diff --git a/lucene/core/src/test/org/apache/lucene/util/automaton/TestOperations.java b/lucene/core/src/test/org/apache/lucene/util/automaton/TestOperations.java index 62913c166598..b8ef5645438b 100644 --- a/lucene/core/src/test/org/apache/lucene/util/automaton/TestOperations.java +++ b/lucene/core/src/test/org/apache/lucene/util/automaton/TestOperations.java @@ -17,6 +17,7 @@ package org.apache.lucene.util.automaton; import static org.apache.lucene.util.automaton.Operations.DEFAULT_DETERMINIZE_WORK_LIMIT; +import static org.apache.lucene.util.automaton.Operations.topoSortStates; import com.carrotsearch.randomizedtesting.generators.RandomNumbers; import java.util.ArrayList; @@ -69,6 +70,42 @@ public void testEmptyLanguageConcatenate() { assertTrue(Operations.isEmpty(concat)); } + /** + * Test case for the topoSortStates method when the input Automaton contains a cycle. This test + * case constructs an Automaton with two disjoint sets of states—one without a cycle and one with + * a cycle. The topoSortStates method should detect the presence of a cycle and throw an + * IllegalArgumentException.
+ */ + public void testCycledAutomaton() { + Automaton a = generateRandomAutomaton(true); + IllegalArgumentException exc = + expectThrows(IllegalArgumentException.class, () -> topoSortStates(a)); + assertTrue(exc.getMessage().contains("Input automaton has a cycle")); + } + + public void testTopoSortStates() { + Automaton a = generateRandomAutomaton(false); + + int[] sorted = topoSortStates(a); + int[] stateMap = new int[a.getNumStates()]; + Arrays.fill(stateMap, -1); + int order = 0; + for (int state : sorted) { + assertEquals(-1, stateMap[state]); + stateMap[state] = (order++); + } + + Transition transition = new Transition(); + for (int state : sorted) { + int count = a.initTransition(state, transition); + for (int i = 0; i < count; i++) { + a.getNextTransition(transition); + // ensure dest's order is higher than current state + assertTrue(stateMap[transition.dest] > stateMap[state]); + } + } + } + /** Test optimization to concatenate() with empty String to an NFA */ public void testEmptySingletonNFAConcatenate() { Automaton singleton = Automata.makeString(""); @@ -144,19 +181,6 @@ public void testIsFiniteEatsStack() { assertTrue(exc.getMessage().contains("input automaton is too large")); } - public void testTopoSortEatsStack() { - char[] chars = new char[50000]; - TestUtil.randomFixedLengthUnicodeString(random(), chars, 0, chars.length); - String bigString1 = new String(chars); - TestUtil.randomFixedLengthUnicodeString(random(), chars, 0, chars.length); - String bigString2 = new String(chars); - Automaton a = - Operations.union(Automata.makeString(bigString1), Automata.makeString(bigString2)); - IllegalArgumentException exc = - expectThrows(IllegalArgumentException.class, () -> Operations.topoSortStates(a)); - assertTrue(exc.getMessage().contains("input automaton is too large")); - } - /** * Returns the set of all accepted strings. * @@ -190,4 +214,52 @@ private static Set getFiniteStrings(FiniteStringsIterator iterator) { return result; } + + /** + * This method creates a random Automaton by generating states at multiple levels. At each level, + * a random number of states are created, and transitions are added between the states of the + * current and the previous level randomly. If the 'hasCycle' parameter is true, a transition is + * added from the first state of the last level back to the initial state to create a cycle in the + * Automaton. + * + * @param hasCycle if true, the generated Automaton will have a cycle; if false, it won't have a + * cycle. + * @return a randomly generated Automaton instance. + */ + private Automaton generateRandomAutomaton(boolean hasCycle) { + Automaton a = new Automaton(); + List<Integer> lastLevelStates = new ArrayList<>(); + int initialState = a.createState(); + int maxLevel = TestUtil.nextInt(random(), 4, 9); + lastLevelStates.add(initialState); + + for (int level = 1; level < maxLevel; level++) { + int numStates = TestUtil.nextInt(random(), 3, 9); + List<Integer> nextLevelStates = new ArrayList<>(); + + for (int i = 0; i < numStates; i++) { + int nextState = a.createState(); + nextLevelStates.add(nextState); + } + + for (int lastState : lastLevelStates) { + for (int nextState : nextLevelStates) { + // if hasCycle is enabled, we will always add a transition, so we can make sure the + // generated Automaton has a cycle.
+ if (hasCycle || random().nextInt(7) >= 1) { + a.addTransition(lastState, nextState, random().nextInt(10)); + } + } + } + lastLevelStates = nextLevelStates; + } + + if (hasCycle) { + int lastState = lastLevelStates.get(0); + a.addTransition(lastState, initialState, random().nextInt(10)); + } + + a.finishState(); + return a; + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/graph/TestGraphTokenStreamFiniteStrings.java b/lucene/core/src/test/org/apache/lucene/util/graph/TestGraphTokenStreamFiniteStrings.java index e68b892d01ce..4df3a0eb029a 100644 --- a/lucene/core/src/test/org/apache/lucene/util/graph/TestGraphTokenStreamFiniteStrings.java +++ b/lucene/core/src/test/org/apache/lucene/util/graph/TestGraphTokenStreamFiniteStrings.java @@ -16,6 +16,7 @@ */ package org.apache.lucene.util.graph; +import java.util.ArrayList; import java.util.Iterator; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; @@ -660,4 +661,27 @@ public void testMultipleSidePathsWithGaps() throws Exception { it.next(), new String[] {"king", "alfred", "saxons", "ruled"}, new int[] {1, 1, 3, 1}); assertFalse(it.hasNext()); } + + public void testLongTokenStreamStackOverflowError() throws Exception { + + ArrayList tokens = + new ArrayList() { + { + add(token("fast", 1, 1)); + add(token("wi", 1, 1)); + add(token("wifi", 0, 2)); + add(token("fi", 1, 1)); + } + }; + + // Add in too many tokens to get a high depth graph + for (int i = 0; i < 1024 + 1; i++) { + tokens.add(token("network", 1, 1)); + } + + TokenStream ts = new CannedTokenStream(tokens.toArray(new Token[0])); + GraphTokenStreamFiniteStrings graph = new GraphTokenStreamFiniteStrings(ts); + + assertThrows(IllegalArgumentException.class, graph::articulationPoints); + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 63aebc40dd75..a6ddb84f2da3 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -32,6 +32,12 @@ import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.lucene95.Lucene95Codec; @@ -66,6 +72,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.NamedThreadFactory; import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; @@ -990,6 +997,115 @@ public void testRandom() throws IOException { assertTrue("overlap=" + overlap, overlap > 0.9); } + /* test thread-safety of searching OnHeapHnswGraph */ + @SuppressWarnings("unchecked") + public void testOnHeapHnswGraphSearch() + throws IOException, ExecutionException, InterruptedException, TimeoutException { + int size = atLeast(100); + int dim = atLeast(10); + AbstractMockVectorValues vectors = vectorValues(size, dim); + int topK = 5; + HnswGraphBuilder builder = + HnswGraphBuilder.create( + vectors, getVectorEncoding(), similarityFunction, 10, 
30, random().nextLong()); + OnHeapHnswGraph hnsw = builder.build(vectors.copy()); + Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size); + + List queries = new ArrayList<>(); + List expects = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + NeighborQueue expect = null; + T query = randomVector(dim); + queries.add(query); + switch (getVectorEncoding()) { + case BYTE: + expect = + HnswGraphSearcher.search( + (byte[]) query, + 100, + (RandomAccessVectorValues) vectors, + getVectorEncoding(), + similarityFunction, + hnsw, + acceptOrds, + Integer.MAX_VALUE); + break; + case FLOAT32: + expect = + HnswGraphSearcher.search( + (float[]) query, + 100, + (RandomAccessVectorValues) vectors, + getVectorEncoding(), + similarityFunction, + hnsw, + acceptOrds, + Integer.MAX_VALUE); + } + ; + while (expect.size() > topK) { + expect.pop(); + } + expects.add(expect); + } + + ExecutorService exec = + Executors.newFixedThreadPool(4, new NamedThreadFactory("onHeapHnswSearch")); + List> futures = new ArrayList<>(); + for (T query : queries) { + futures.add( + exec.submit( + () -> { + NeighborQueue actual = null; + try { + + switch (getVectorEncoding()) { + case BYTE: + actual = + HnswGraphSearcher.search( + (byte[]) query, + 100, + (RandomAccessVectorValues) vectors, + getVectorEncoding(), + similarityFunction, + hnsw, + acceptOrds, + Integer.MAX_VALUE); + break; + case FLOAT32: + actual = + HnswGraphSearcher.search( + (float[]) query, + 100, + (RandomAccessVectorValues) vectors, + getVectorEncoding(), + similarityFunction, + hnsw, + acceptOrds, + Integer.MAX_VALUE); + } + ; + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + while (actual.size() > topK) { + actual.pop(); + } + return actual; + })); + } + List actuals = new ArrayList<>(); + for (Future future : futures) { + actuals.add(future.get(10, TimeUnit.SECONDS)); + } + exec.shutdownNow(); + for (int i = 0; i < expects.size(); i++) { + NeighborQueue expect = expects.get(i); + NeighborQueue actual = actuals.get(i); + assertArrayEquals(expect.nodes(), actual.nodes()); + } + } + private int computeOverlap(int[] a, int[] b) { Arrays.sort(a); Arrays.sort(b); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java index b8ae24f62009..c81077aa6daf 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java @@ -23,100 +23,160 @@ public class TestNeighborArray extends LuceneTestCase { public void testScoresDescOrder() { NeighborArray neighbors = new NeighborArray(10, true); - neighbors.add(0, 1); - neighbors.add(1, 0.8f); + neighbors.addInOrder(0, 1); + neighbors.addInOrder(1, 0.8f); - AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.9f)); - assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); + AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.9f)); + assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage(); neighbors.insertSorted(3, 0.9f); assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 3, 1}, neighbors); + assertNodesEqual(new int[] {0, 3, 1}, neighbors); neighbors.insertSorted(4, 1f); assertScoresEqual(new float[] {1, 1, 0.9f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 4, 3, 1}, neighbors); + assertNodesEqual(new int[] {0, 4, 3, 
1}, neighbors); neighbors.insertSorted(5, 1.1f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 3, 1}, neighbors); neighbors.insertSorted(6, 0.8f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6}, neighbors); neighbors.insertSorted(7, 0.8f); assertScoresEqual(new float[] {1.1f, 1, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 3, 1, 6, 7}, neighbors); neighbors.removeIndex(2); assertScoresEqual(new float[] {1.1f, 1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - asserNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {5, 0, 3, 1, 6, 7}, neighbors); neighbors.removeIndex(0); assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors); + assertNodesEqual(new int[] {0, 3, 1, 6, 7}, neighbors); neighbors.removeIndex(4); assertScoresEqual(new float[] {1, 0.9f, 0.8f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 3, 1, 6}, neighbors); + assertNodesEqual(new int[] {0, 3, 1, 6}, neighbors); neighbors.removeLast(); assertScoresEqual(new float[] {1, 0.9f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 3, 1}, neighbors); + assertNodesEqual(new int[] {0, 3, 1}, neighbors); neighbors.insertSorted(8, 0.9f); assertScoresEqual(new float[] {1, 0.9f, 0.9f, 0.8f}, neighbors); - asserNodesEqual(new int[] {0, 3, 8, 1}, neighbors); + assertNodesEqual(new int[] {0, 3, 8, 1}, neighbors); } public void testScoresAscOrder() { NeighborArray neighbors = new NeighborArray(10, false); - neighbors.add(0, 0.1f); - neighbors.add(1, 0.3f); + neighbors.addInOrder(0, 0.1f); + neighbors.addInOrder(1, 0.3f); - AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.add(2, 0.15f)); - assertEquals("Nodes are added in the incorrect order!", ex.getMessage()); + AssertionError ex = expectThrows(AssertionError.class, () -> neighbors.addInOrder(2, 0.15f)); + assert ex.getMessage().startsWith("Nodes are added in the incorrect order!") : ex.getMessage(); neighbors.insertSorted(3, 0.3f); assertScoresEqual(new float[] {0.1f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {0, 1, 3}, neighbors); + assertNodesEqual(new int[] {0, 1, 3}, neighbors); neighbors.insertSorted(4, 0.2f); assertScoresEqual(new float[] {0.1f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {0, 4, 1, 3}, neighbors); + assertNodesEqual(new int[] {0, 4, 1, 3}, neighbors); neighbors.insertSorted(5, 0.05f); assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 1, 3}, neighbors); neighbors.insertSorted(6, 0.2f); assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 6, 1, 3}, neighbors); neighbors.insertSorted(7, 0.2f); assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors); + assertNodesEqual(new int[] {5, 0, 4, 6, 7, 1, 3}, neighbors); neighbors.removeIndex(2); assertScoresEqual(new float[] {0.05f, 0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {5, 0, 
6, 7, 1, 3}, neighbors); + assertNodesEqual(new int[] {5, 0, 6, 7, 1, 3}, neighbors); neighbors.removeIndex(0); assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f, 0.3f}, neighbors); - asserNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors); + assertNodesEqual(new int[] {0, 6, 7, 1, 3}, neighbors); neighbors.removeIndex(4); assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f, 0.3f}, neighbors); - asserNodesEqual(new int[] {0, 6, 7, 1}, neighbors); + assertNodesEqual(new int[] {0, 6, 7, 1}, neighbors); neighbors.removeLast(); assertScoresEqual(new float[] {0.1f, 0.2f, 0.2f}, neighbors); - asserNodesEqual(new int[] {0, 6, 7}, neighbors); + assertNodesEqual(new int[] {0, 6, 7}, neighbors); neighbors.insertSorted(8, 0.01f); assertScoresEqual(new float[] {0.01f, 0.1f, 0.2f, 0.2f}, neighbors); - asserNodesEqual(new int[] {8, 0, 6, 7}, neighbors); + assertNodesEqual(new int[] {8, 0, 6, 7}, neighbors); + } + + public void testSortAsc() { + NeighborArray neighbors = new NeighborArray(10, false); + neighbors.addOutOfOrder(1, 2); + // we disallow calling addInOrder after addOutOfOrder even if they're actually in order + expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2)); + neighbors.addOutOfOrder(2, 3); + neighbors.addOutOfOrder(5, 6); + neighbors.addOutOfOrder(3, 4); + neighbors.addOutOfOrder(7, 8); + neighbors.addOutOfOrder(6, 7); + neighbors.addOutOfOrder(4, 5); + int[] unchecked = neighbors.sort(); + assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); + assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors); + assertScoresEqual(new float[] {2, 3, 4, 5, 6, 7, 8}, neighbors); + + NeighborArray neighbors2 = new NeighborArray(10, false); + neighbors2.addInOrder(0, 1); + neighbors2.addInOrder(1, 2); + neighbors2.addInOrder(4, 5); + neighbors2.addOutOfOrder(2, 3); + neighbors2.addOutOfOrder(6, 7); + neighbors2.addOutOfOrder(5, 6); + neighbors2.addOutOfOrder(3, 4); + unchecked = neighbors2.sort(); + assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked); + assertNodesEqual(new int[] {0, 1, 2, 3, 4, 5, 6}, neighbors2); + assertScoresEqual(new float[] {1, 2, 3, 4, 5, 6, 7}, neighbors2); + } + + public void testSortDesc() { + NeighborArray neighbors = new NeighborArray(10, true); + neighbors.addOutOfOrder(1, 7); + // we disallow calling addInOrder after addOutOfOrder even if they're actually in order + expectThrows(AssertionError.class, () -> neighbors.addInOrder(1, 2)); + neighbors.addOutOfOrder(2, 6); + neighbors.addOutOfOrder(5, 3); + neighbors.addOutOfOrder(3, 5); + neighbors.addOutOfOrder(7, 1); + neighbors.addOutOfOrder(6, 2); + neighbors.addOutOfOrder(4, 4); + int[] unchecked = neighbors.sort(); + assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); + assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors); + assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors); + + NeighborArray neighbors2 = new NeighborArray(10, true); + neighbors2.addInOrder(1, 7); + neighbors2.addInOrder(2, 6); + neighbors2.addInOrder(5, 3); + neighbors2.addOutOfOrder(3, 5); + neighbors2.addOutOfOrder(7, 1); + neighbors2.addOutOfOrder(6, 2); + neighbors2.addOutOfOrder(4, 4); + unchecked = neighbors2.sort(); + assertArrayEquals(new int[] {2, 3, 5, 6}, unchecked); + assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors2); + assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors2); } private void assertScoresEqual(float[] scores, NeighborArray neighbors) { @@ -125,7 +185,7 @@ private void assertScoresEqual(float[] scores, NeighborArray neighbors) { } } - private void
asserNodesEqual(int[] nodes, NeighborArray neighbors) { + private void assertNodesEqual(int[] nodes, NeighborArray neighbors) { for (int i = 0; i < nodes.length; i++) { assertEquals(nodes[i], neighbors.node[i]); } diff --git a/lucene/distribution.tests/src/test/org/apache/lucene/distribution/TestModularLayer.java b/lucene/distribution.tests/src/test/org/apache/lucene/distribution/TestModularLayer.java index b3e55f277706..a5b8f6409e93 100644 --- a/lucene/distribution.tests/src/test/org/apache/lucene/distribution/TestModularLayer.java +++ b/lucene/distribution.tests/src/test/org/apache/lucene/distribution/TestModularLayer.java @@ -206,7 +206,7 @@ public void testMultiReleaseJar() { ClassLoader loader = layer.findLoader(coreModuleId); - final Set jarVersions = Set.of(19, 20); + final Set jarVersions = Set.of(19, 20, 21); for (var v : jarVersions) { Assertions.assertThat( loader.getResource( diff --git a/lucene/facet/src/java/org/apache/lucene/facet/DrillDownQuery.java b/lucene/facet/src/java/org/apache/lucene/facet/DrillDownQuery.java index a4d47aa8c129..9fe4b15b9072 100644 --- a/lucene/facet/src/java/org/apache/lucene/facet/DrillDownQuery.java +++ b/lucene/facet/src/java/org/apache/lucene/facet/DrillDownQuery.java @@ -24,11 +24,11 @@ import java.util.Map; import java.util.Objects; import java.util.Set; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -100,8 +100,8 @@ public DrillDownQuery(FacetsConfig config) { /** * Creates a new {@code DrillDownQuery} over the given base query. Can be {@code null}, in which - * case the result {@link Query} from {@link #rewrite(IndexReader)} will be a pure browsing query, - * filtering on the added categories only. + * case the result {@link Query} from {@link Query#rewrite(IndexSearcher)} will be a pure browsing + * query, filtering on the added categories only. */ public DrillDownQuery(FacetsConfig config, Query baseQuery) { this.baseQuery = baseQuery; @@ -156,7 +156,7 @@ private boolean equalsTo(DrillDownQuery other) { } @Override - public Query rewrite(IndexReader r) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { BooleanQuery rewritten = getBooleanQuery(); if (rewritten.clauses().isEmpty()) { return new MatchAllDocsQuery(); diff --git a/lucene/facet/src/java/org/apache/lucene/facet/DrillSidewaysQuery.java b/lucene/facet/src/java/org/apache/lucene/facet/DrillSidewaysQuery.java index 81643019ad93..88850df057ca 100644 --- a/lucene/facet/src/java/org/apache/lucene/facet/DrillSidewaysQuery.java +++ b/lucene/facet/src/java/org/apache/lucene/facet/DrillSidewaysQuery.java @@ -24,7 +24,6 @@ import java.util.List; import java.util.Objects; import org.apache.lucene.facet.DrillSidewaysScorer.DocsAndCost; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.ConstantScoreScorer; @@ -80,8 +79,8 @@ class DrillSidewaysQuery extends Query { } /** - * Needed for {@link #rewrite(IndexReader)}. 
Ensures the same "managed" lists get used since - * {@link DrillSideways} accesses references to these through the original {@code + * Needed for {@link Query#rewrite(IndexSearcher)}. Ensures the same "managed" lists get used + * since {@link DrillSideways} accesses references to these through the original {@code * DrillSidewaysQuery}. */ private DrillSidewaysQuery( @@ -107,17 +106,17 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { Query newQuery = baseQuery; while (true) { - Query rewrittenQuery = newQuery.rewrite(reader); + Query rewrittenQuery = newQuery.rewrite(indexSearcher); if (rewrittenQuery == newQuery) { break; } newQuery = rewrittenQuery; } if (newQuery == baseQuery) { - return super.rewrite(reader); + return super.rewrite(indexSearcher); } else { return new DrillSidewaysQuery( newQuery, diff --git a/lucene/facet/src/java/org/apache/lucene/facet/range/DoubleRange.java b/lucene/facet/src/java/org/apache/lucene/facet/range/DoubleRange.java index bf188a9d80f1..3824baca4d34 100644 --- a/lucene/facet/src/java/org/apache/lucene/facet/range/DoubleRange.java +++ b/lucene/facet/src/java/org/apache/lucene/facet/range/DoubleRange.java @@ -20,7 +20,6 @@ import java.util.Objects; import org.apache.lucene.facet.MultiDoubleValues; import org.apache.lucene.facet.MultiDoubleValuesSource; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; @@ -154,14 +153,14 @@ public void visit(QueryVisitor visitor) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (fastMatchQuery != null) { - final Query fastMatchRewritten = fastMatchQuery.rewrite(reader); + final Query fastMatchRewritten = fastMatchQuery.rewrite(indexSearcher); if (fastMatchRewritten != fastMatchQuery) { return new ValueSourceQuery(range, fastMatchRewritten, valueSource); } } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override @@ -252,14 +251,14 @@ public void visit(QueryVisitor visitor) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (fastMatchQuery != null) { - final Query fastMatchRewritten = fastMatchQuery.rewrite(reader); + final Query fastMatchRewritten = fastMatchQuery.rewrite(indexSearcher); if (fastMatchRewritten != fastMatchQuery) { return new MultiValueSourceQuery(range, fastMatchRewritten, valueSource); } } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/facet/src/java/org/apache/lucene/facet/range/LongRange.java b/lucene/facet/src/java/org/apache/lucene/facet/range/LongRange.java index 63e991904378..6796780d010d 100644 --- a/lucene/facet/src/java/org/apache/lucene/facet/range/LongRange.java +++ b/lucene/facet/src/java/org/apache/lucene/facet/range/LongRange.java @@ -20,7 +20,6 @@ import java.util.Objects; import org.apache.lucene.facet.MultiLongValues; import org.apache.lucene.facet.MultiLongValuesSource; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; @@ -141,14 +140,14 @@ public void visit(QueryVisitor 
visitor) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (fastMatchQuery != null) { - final Query fastMatchRewritten = fastMatchQuery.rewrite(reader); + final Query fastMatchRewritten = fastMatchQuery.rewrite(indexSearcher); if (fastMatchRewritten != fastMatchQuery) { return new ValueSourceQuery(range, fastMatchRewritten, valueSource); } } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override @@ -239,14 +238,14 @@ public void visit(QueryVisitor visitor) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (fastMatchQuery != null) { - final Query fastMatchRewritten = fastMatchQuery.rewrite(reader); + final Query fastMatchRewritten = fastMatchQuery.rewrite(indexSearcher); if (fastMatchRewritten != fastMatchQuery) { return new MultiValueSourceQuery(range, fastMatchRewritten, valuesSource); } } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/facet/src/java/org/apache/lucene/facet/range/OverlappingLongRangeCounter.java b/lucene/facet/src/java/org/apache/lucene/facet/range/OverlappingLongRangeCounter.java index 3c9ae7ebd7cb..b478e3994e11 100644 --- a/lucene/facet/src/java/org/apache/lucene/facet/range/OverlappingLongRangeCounter.java +++ b/lucene/facet/src/java/org/apache/lucene/facet/range/OverlappingLongRangeCounter.java @@ -84,7 +84,7 @@ void startMultiValuedDoc() { if (multiValuedDocElementaryIntervalHits == null) { multiValuedDocElementaryIntervalHits = new FixedBitSet(boundaries.length); } else { - multiValuedDocElementaryIntervalHits.clear(0, multiValuedDocElementaryIntervalHits.length()); + multiValuedDocElementaryIntervalHits.clear(); } } @@ -103,7 +103,7 @@ boolean endMultiValuedDoc() { if (multiValuedDocRangeHits == null) { multiValuedDocRangeHits = new FixedBitSet(rangeCount()); } else { - multiValuedDocRangeHits.clear(0, multiValuedDocRangeHits.length()); + multiValuedDocRangeHits.clear(); } elementaryIntervalUpto = 0; rollupMultiValued(root); diff --git a/lucene/facet/src/test/org/apache/lucene/facet/TestDrillDownQuery.java b/lucene/facet/src/test/org/apache/lucene/facet/TestDrillDownQuery.java index fdac9935eeef..d606009d79f0 100644 --- a/lucene/facet/src/test/org/apache/lucene/facet/TestDrillDownQuery.java +++ b/lucene/facet/src/test/org/apache/lucene/facet/TestDrillDownQuery.java @@ -255,7 +255,8 @@ public void testClone() throws Exception { public void testNoDrillDown() throws Exception { Query base = new MatchAllDocsQuery(); DrillDownQuery q = new DrillDownQuery(config, base); - Query rewrite = q.rewrite(reader).rewrite(reader); + IndexSearcher searcher = newSearcher(reader); + Query rewrite = q.rewrite(searcher).rewrite(searcher); assertEquals(base, rewrite); } diff --git a/lucene/facet/src/test/org/apache/lucene/facet/TestDrillSideways.java b/lucene/facet/src/test/org/apache/lucene/facet/TestDrillSideways.java index eb12c2b42eed..1280e2f75114 100644 --- a/lucene/facet/src/test/org/apache/lucene/facet/TestDrillSideways.java +++ b/lucene/facet/src/test/org/apache/lucene/facet/TestDrillSideways.java @@ -740,7 +740,7 @@ protected FacetsCollectorManager createDrillDownFacetsCollectorManager() { Query baseQuery = new TermQuery(new Term("content", "foo")) { @Override - public Query rewrite(IndexReader reader) { + public Query rewrite(IndexSearcher indexSearcher) { // return a new instance, 
forcing the DrillDownQuery to also rewrite itself, exposing // the bug in LUCENE-9988: return new TermQuery(getTerm()); diff --git a/lucene/facet/src/test/org/apache/lucene/facet/range/TestRangeFacetCounts.java b/lucene/facet/src/test/org/apache/lucene/facet/range/TestRangeFacetCounts.java index 97cf11c8e1ac..ca86e83e41b1 100644 --- a/lucene/facet/src/test/org/apache/lucene/facet/range/TestRangeFacetCounts.java +++ b/lucene/facet/src/test/org/apache/lucene/facet/range/TestRangeFacetCounts.java @@ -1450,12 +1450,12 @@ public int hashCode() { } @Override - public Query rewrite(IndexReader reader) throws IOException { - final Query inRewritten = in.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + final Query inRewritten = in.rewrite(indexSearcher); if (in != inRewritten) { return new UsedQuery(inRewritten, used); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/WeightedSpanTermExtractor.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/WeightedSpanTermExtractor.java index da77a97ac9fb..7fbe6fbd080d 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/WeightedSpanTermExtractor.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/WeightedSpanTermExtractor.java @@ -250,7 +250,7 @@ protected void extract(Query query, float boost, Map t if (query instanceof MultiTermQuery) { rewritten = MultiTermQuery.SCORING_BOOLEAN_REWRITE.rewrite(reader, (MultiTermQuery) query); } else { - rewritten = origQuery.rewrite(reader); + rewritten = origQuery.rewrite(new IndexSearcher(reader)); } if (rewritten != origQuery) { // only rewrite once and then flatten again - the rewritten query could have a special diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/uhighlight/MemoryIndexOffsetStrategy.java b/lucene/highlighter/src/java/org/apache/lucene/search/uhighlight/MemoryIndexOffsetStrategy.java index aad68190c644..851c4639df80 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/uhighlight/MemoryIndexOffsetStrategy.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/uhighlight/MemoryIndexOffsetStrategy.java @@ -30,7 +30,7 @@ import org.apache.lucene.index.memory.MemoryIndex; import org.apache.lucene.queries.spans.SpanQuery; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.automaton.DaciukMihovAutomatonBuilder; +import org.apache.lucene.util.automaton.Automata; /** * Uses an {@link Analyzer} on content to get offsets and then populates a {@link MemoryIndex}. 
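The WeightedSpanTermExtractor hunk above shows the migration idiom for call sites that still hold only an IndexReader: wrap it in a throwaway IndexSearcher to reach the new Query#rewrite(IndexSearcher) entry point. A hedged sketch of that idiom, combined with the rewrite-until-fixed-point loop the DrillSidewaysQuery hunk uses; RewriteShim is a hypothetical helper, not part of this patch:

import java.io.IOException;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;

final class RewriteShim {
  private RewriteShim() {}

  // The searcher exists only to carry the reader into rewrite(); nothing is cached.
  static Query fullyRewrite(Query query, IndexReader reader) throws IOException {
    IndexSearcher searcher = new IndexSearcher(reader);
    Query rewritten = query.rewrite(searcher);
    // rewrite() may need several rounds before reaching a fixed point.
    while (rewritten != query) {
      query = rewritten;
      rewritten = query.rewrite(searcher);
    }
    return rewritten;
  }
}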
@@ -67,7 +67,7 @@ private static CharArrayMatcher buildCombinedAutomaton(UHComponents components) // to build an automaton on them List filteredTerms = Arrays.stream(components.getTerms()) - .filter(b -> b.length < DaciukMihovAutomatonBuilder.MAX_TERM_LENGTH) + .filter(b -> b.length < Automata.MAX_STRING_UNION_TERM_LENGTH) .collect(Collectors.toList()); allAutomata.add(CharArrayMatcher.fromTerms(filteredTerms)); } diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/uhighlight/UnifiedHighlighter.java b/lucene/highlighter/src/java/org/apache/lucene/search/uhighlight/UnifiedHighlighter.java index 0dee795ebc28..615db5ecd5ff 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/uhighlight/UnifiedHighlighter.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/uhighlight/UnifiedHighlighter.java @@ -1252,11 +1252,11 @@ protected FieldOffsetStrategy getOffsetStrategy( /** * When highlighting phrases accurately, we need to know which {@link SpanQuery}'s need to have - * {@link Query#rewrite(IndexReader)} called on them. It helps performance to avoid it if it's not - * needed. This method will be invoked on all SpanQuery instances recursively. If you have custom - * SpanQuery queries then override this to check instanceof and provide a definitive answer. If - * the query isn't your custom one, simply return null to have the default rules apply, which - * govern the ones included in Lucene. + * {@link Query#rewrite(IndexSearcher)} called on them. It helps performance to avoid it if it's + * not needed. This method will be invoked on all SpanQuery instances recursively. If you have + * custom SpanQuery queries then override this to check instanceof and provide a definitive + * answer. If the query isn't your custom one, simply return null to have the default rules apply, + * which govern the ones included in Lucene. 
diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/vectorhighlight/FieldQuery.java b/lucene/highlighter/src/java/org/apache/lucene/search/vectorhighlight/FieldQuery.java
index ff5a11c47d66..851197e42d5e 100644
--- a/lucene/highlighter/src/java/org/apache/lucene/search/vectorhighlight/FieldQuery.java
+++ b/lucene/highlighter/src/java/org/apache/lucene/search/vectorhighlight/FieldQuery.java
@@ -33,6 +33,7 @@ import org.apache.lucene.search.BoostQuery;
 import org.apache.lucene.search.ConstantScoreQuery;
 import org.apache.lucene.search.DisjunctionMaxQuery;
+import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MultiTermQuery;
 import org.apache.lucene.search.PhraseQuery;
 import org.apache.lucene.search.Query;
@@ -65,8 +66,14 @@ public FieldQuery(Query query, IndexReader reader, boolean phraseHighlight, bool
       throws IOException {
     this.fieldMatch = fieldMatch;
     Set<Query> flatQueries = new LinkedHashSet<>();
-    flatten(query, reader, flatQueries, 1f);
-    saveTerms(flatQueries, reader);
+    IndexSearcher searcher;
+    if (reader == null) {
+      searcher = null;
+    } else {
+      searcher = new IndexSearcher(reader);
+    }
+    flatten(query, searcher, flatQueries, 1f);
+    saveTerms(flatQueries, searcher);
     Collection<Query> expandQueries = expand(flatQueries);
 
     for (Query flatQuery : expandQueries) {
@@ -96,7 +103,7 @@ public FieldQuery(Query query, IndexReader reader, boolean phraseHighlight, bool
   }
 
   protected void flatten(
-      Query sourceQuery, IndexReader reader, Collection<Query> flatQueries, float boost)
+      Query sourceQuery, IndexSearcher searcher, Collection<Query> flatQueries, float boost)
       throws IOException {
     while (sourceQuery instanceof BoostQuery) {
       BoostQuery bq = (BoostQuery) sourceQuery;
@@ -107,13 +114,13 @@ protected void flatten(
       BooleanQuery bq = (BooleanQuery) sourceQuery;
       for (BooleanClause clause : bq) {
         if (!clause.isProhibited()) {
-          flatten(clause.getQuery(), reader, flatQueries, boost);
+          flatten(clause.getQuery(), searcher, flatQueries, boost);
         }
       }
     } else if (sourceQuery instanceof DisjunctionMaxQuery) {
       DisjunctionMaxQuery dmq = (DisjunctionMaxQuery) sourceQuery;
       for (Query query : dmq) {
-        flatten(query, reader, flatQueries, boost);
+        flatten(query, searcher, flatQueries, boost);
       }
     } else if (sourceQuery instanceof TermQuery) {
       if (boost != 1f) {
@@ -123,7 +130,7 @@ protected void flatten(
     } else if (sourceQuery instanceof SynonymQuery) {
       SynonymQuery synQuery = (SynonymQuery) sourceQuery;
       for (Term term : synQuery.getTerms()) {
-        flatten(new TermQuery(term), reader, flatQueries, boost);
+        flatten(new TermQuery(term), searcher, flatQueries, boost);
       }
     } else if (sourceQuery instanceof PhraseQuery) {
       PhraseQuery pq = (PhraseQuery) sourceQuery;
@@ -135,28 +142,28 @@ protected void flatten(
     } else if (sourceQuery instanceof ConstantScoreQuery) {
       final Query q = ((ConstantScoreQuery) sourceQuery).getQuery();
       if (q != null) {
-        flatten(q, reader, flatQueries, boost);
+        flatten(q, searcher, flatQueries, boost);
       }
     } else if (sourceQuery instanceof FunctionScoreQuery) {
       final Query q = ((FunctionScoreQuery) sourceQuery).getWrappedQuery();
       if (q != null) {
-        flatten(q, reader, flatQueries, boost);
+        flatten(q, searcher, flatQueries, boost);
       }
-    } else if (reader != null) {
+    } else if (searcher != null) {
       Query query = sourceQuery;
       Query rewritten;
       if (sourceQuery instanceof MultiTermQuery) {
         rewritten =
             new MultiTermQuery.TopTermsScoringBooleanQueryRewrite(MAX_MTQ_TERMS)
-                .rewrite(reader, (MultiTermQuery) query);
+                .rewrite(searcher.getIndexReader(), (MultiTermQuery) query);
       } else {
-        rewritten = query.rewrite(reader);
+        rewritten = query.rewrite(searcher);
       }
       if (rewritten != query) {
         // only rewrite once and then flatten again - the rewritten query could have a special
         // treatment
         // if this method is overridden in a subclass.
-        flatten(rewritten, reader, flatQueries, boost);
+        flatten(rewritten, searcher, flatQueries, boost);
       }
       // if the query is already rewritten we discard it
     }
@@ -311,7 +318,7 @@ else if (query instanceof PhraseQuery) {
    *      - fieldMatch==false
    *          termSetMap=Map>
    */
-  void saveTerms(Collection<Query> flatQueries, IndexReader reader) throws IOException {
+  void saveTerms(Collection<Query> flatQueries, IndexSearcher searcher) throws IOException {
     for (Query query : flatQueries) {
       while (query instanceof BoostQuery) {
         query = ((BoostQuery) query).getQuery();
@@ -320,8 +327,8 @@ void saveTerms(Collection<Query> flatQueries, IndexReader reader) throws IOExcep
       if (query instanceof TermQuery) termSet.add(((TermQuery) query).getTerm().text());
       else if (query instanceof PhraseQuery) {
         for (Term term : ((PhraseQuery) query).getTerms()) termSet.add(term.text());
-      } else if (query instanceof MultiTermQuery && reader != null) {
-        BooleanQuery mtqTerms = (BooleanQuery) query.rewrite(reader);
+      } else if (query instanceof MultiTermQuery && searcher != null) {
+        BooleanQuery mtqTerms = (BooleanQuery) query.rewrite(searcher);
         for (BooleanClause clause : mtqTerms) {
           termSet.add(((TermQuery) clause.getQuery()).getTerm().text());
         }
diff --git a/lucene/highlighter/src/test/org/apache/lucene/search/highlight/TestHighlighter.java b/lucene/highlighter/src/test/org/apache/lucene/search/highlight/TestHighlighter.java
index cb1e640f3551..1664eb4be67d 100644
--- a/lucene/highlighter/src/test/org/apache/lucene/search/highlight/TestHighlighter.java
+++ b/lucene/highlighter/src/test/org/apache/lucene/search/highlight/TestHighlighter.java
@@ -261,7 +261,7 @@ public void testHighlightUnknownQueryAfterRewrite()
         new Query() {
 
           @Override
-          public Query rewrite(IndexReader reader) throws IOException {
+          public Query rewrite(IndexSearcher indexSearcher) throws IOException {
             CommonTermsQuery query = new CommonTermsQuery(Occur.MUST, Occur.SHOULD, 3);
             query.add(new Term(FIELD_NAME, "this")); // stop-word
             query.add(new Term(FIELD_NAME, "long"));
@@ -2223,7 +2223,7 @@ public void doSearching(Query unReWrittenQuery) throws Exception {
     searcher = newSearcher(reader);
     // for any multi-term queries to work (prefix, wildcard, range, fuzzy etc)
     // you must use a rewritten query!
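A hedged sketch of what that rewrite step looks like on the 9.7 API, mirroring WeightedSpanTermExtractor above (the field name and prefix are invented for illustration; reader is assumed to be an open IndexReader):

import java.io.IOException;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.PrefixQuery;
import org.apache.lucene.search.Query;

// Sketch: expand a multi-term query into concrete terms before highlighting.
static Query rewriteForHighlighting(IndexReader reader) throws IOException {
  MultiTermQuery prefix = new PrefixQuery(new Term("body", "luc")); // illustrative field/prefix
  // Force a rewrite that exposes the concrete matching terms, as the extractor
  // above does; a plain prefix.rewrite(new IndexSearcher(reader)) could instead
  // produce a constant-score form that hides the terms from the highlighter.
  return MultiTermQuery.SCORING_BOOLEAN_REWRITE.rewrite(reader, prefix);
}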
- query = unReWrittenQuery.rewrite(reader); + query = unReWrittenQuery.rewrite(searcher); if (VERBOSE) System.out.println("Searching for: " + query.toString(FIELD_NAME)); hits = searcher.search(query, 1000); } diff --git a/lucene/highlighter/src/test/org/apache/lucene/search/highlight/custom/TestHighlightCustomQuery.java b/lucene/highlighter/src/test/org/apache/lucene/search/highlight/custom/TestHighlightCustomQuery.java index 97e0baf6e912..37235bdb6110 100644 --- a/lucene/highlighter/src/test/org/apache/lucene/search/highlight/custom/TestHighlightCustomQuery.java +++ b/lucene/highlighter/src/test/org/apache/lucene/search/highlight/custom/TestHighlightCustomQuery.java @@ -21,9 +21,9 @@ import java.util.Map; import java.util.Objects; import org.apache.lucene.analysis.TokenStream; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; import org.apache.lucene.search.BoostQuery; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.TermQuery; @@ -167,7 +167,7 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { return new TermQuery(term); } diff --git a/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighter.java b/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighter.java index f831fbf77a7f..54792a569f76 100644 --- a/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighter.java +++ b/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighter.java @@ -58,7 +58,7 @@ import org.apache.lucene.tests.analysis.MockTokenizer; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.automaton.DaciukMihovAutomatonBuilder; +import org.apache.lucene.util.automaton.Automata; import org.junit.After; import org.junit.Before; @@ -1619,7 +1619,7 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader reader) { + public Query rewrite(IndexSearcher indexSearcher) { return this; } @@ -1671,12 +1671,11 @@ public void testQueryWithLongTerm() throws IOException { Query query = new BooleanQuery.Builder() .add( - new TermQuery( - new Term("title", "a".repeat(DaciukMihovAutomatonBuilder.MAX_TERM_LENGTH))), + new TermQuery(new Term("title", "a".repeat(Automata.MAX_STRING_UNION_TERM_LENGTH))), BooleanClause.Occur.SHOULD) .add( new TermQuery( - new Term("title", "a".repeat(DaciukMihovAutomatonBuilder.MAX_TERM_LENGTH + 1))), + new Term("title", "a".repeat(Automata.MAX_STRING_UNION_TERM_LENGTH + 1))), BooleanClause.Occur.SHOULD) .add(new TermQuery(new Term("title", "title")), BooleanClause.Occur.SHOULD) .build(); diff --git a/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighterMTQ.java b/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighterMTQ.java index 4f7bd056ca19..4aabf26d45bc 100644 --- a/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighterMTQ.java +++ b/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighterMTQ.java @@ -1111,8 +1111,8 @@ public void visit(QueryVisitor visitor) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newOriginalQuery = 
originalQuery.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query newOriginalQuery = originalQuery.rewrite(indexSearcher); if (newOriginalQuery != originalQuery) { return new MyWrapperSpanQuery((SpanQuery) newOriginalQuery); } diff --git a/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighterStrictPhrases.java b/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighterStrictPhrases.java index 8116ece1c1c0..9f189a215b88 100644 --- a/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighterStrictPhrases.java +++ b/lucene/highlighter/src/test/org/apache/lucene/search/uhighlight/TestUnifiedHighlighterStrictPhrases.java @@ -658,8 +658,8 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newWrapped = wrapped.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query newWrapped = wrapped.rewrite(indexSearcher); if (newWrapped != wrapped) { return new MyQuery(newWrapped); } diff --git a/lucene/highlighter/src/test/org/apache/lucene/search/vectorhighlight/AbstractTestCase.java b/lucene/highlighter/src/test/org/apache/lucene/search/vectorhighlight/AbstractTestCase.java index 8542a45aab7b..9a2f8a9e4faa 100644 --- a/lucene/highlighter/src/test/org/apache/lucene/search/vectorhighlight/AbstractTestCase.java +++ b/lucene/highlighter/src/test/org/apache/lucene/search/vectorhighlight/AbstractTestCase.java @@ -21,7 +21,9 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; -import org.apache.lucene.analysis.*; +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.Tokenizer; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.tokenattributes.OffsetAttribute; import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute; @@ -37,6 +39,7 @@ import org.apache.lucene.index.Term; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.DisjunctionMaxQuery; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.PhraseQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; @@ -56,6 +59,7 @@ public abstract class AbstractTestCase extends LuceneTestCase { protected Analyzer analyzerB; protected Analyzer analyzerK; protected IndexReader reader; + protected IndexSearcher searcher; protected static final String[] shortMVValues = { "", "", "a b c", "", // empty data in multi valued field @@ -343,6 +347,7 @@ protected void make1dmfIndex(Analyzer analyzer, String... values) throws Excepti writer.close(); if (reader != null) reader.close(); reader = DirectoryReader.open(dir); + searcher = newSearcher(reader); } // make 1 doc with multi valued & not analyzed field @@ -363,6 +368,7 @@ protected void make1dmfIndexNA(String... 
values) throws Exception {
     writer.close();
     if (reader != null) reader.close();
     reader = DirectoryReader.open(dir);
+    searcher = newSearcher(reader);
   }
 
   protected void makeIndexShortMV() throws Exception {
diff --git a/lucene/highlighter/src/test/org/apache/lucene/search/vectorhighlight/TestFieldQuery.java b/lucene/highlighter/src/test/org/apache/lucene/search/vectorhighlight/TestFieldQuery.java
index 42fec6dba55f..6538c0d2985f 100644
--- a/lucene/highlighter/src/test/org/apache/lucene/search/vectorhighlight/TestFieldQuery.java
+++ b/lucene/highlighter/src/test/org/apache/lucene/search/vectorhighlight/TestFieldQuery.java
@@ -63,7 +63,7 @@ public void testFlattenBoolean() throws Exception {
     FieldQuery fq = new FieldQuery(booleanQuery, true, true);
     Set<Query> flatQueries = new HashSet<>();
-    fq.flatten(booleanQuery, reader, flatQueries, 1f);
+    fq.flatten(booleanQuery, searcher, flatQueries, 1f);
     assertCollectionQueries(flatQueries, tq(boost, "A"), tq(boost, "B"), tq(boost, "C"));
   }
 
@@ -73,7 +73,7 @@ public void testFlattenDisjunctionMaxQuery() throws Exception {
     query = new BoostQuery(query, boost);
     FieldQuery fq = new FieldQuery(query, true, true);
     Set<Query> flatQueries = new HashSet<>();
-    fq.flatten(query, reader, flatQueries, 1f);
+    fq.flatten(query, searcher, flatQueries, 1f);
     assertCollectionQueries(flatQueries, tq(boost, "A"), tq(boost, "B"), pqF(boost, "C", "D"));
   }
 
@@ -87,7 +87,7 @@ public void testFlattenTermAndPhrase() throws Exception {
     FieldQuery fq = new FieldQuery(booleanQuery, true, true);
     Set<Query> flatQueries = new HashSet<>();
-    fq.flatten(booleanQuery, reader, flatQueries, 1f);
+    fq.flatten(booleanQuery, searcher, flatQueries, 1f);
     assertCollectionQueries(flatQueries, tq(boost, "A"), pqF(boost, "B", "C"));
   }
 
@@ -99,7 +99,7 @@ public void testFlattenTermAndPhrase2gram() throws Exception {
     FieldQuery fq = new FieldQuery(query.build(), true, true);
     Set<Query> flatQueries = new HashSet<>();
-    fq.flatten(query.build(), reader, flatQueries, 1f);
+    fq.flatten(query.build(), searcher, flatQueries, 1f);
     assertCollectionQueries(flatQueries, tq("AA"), pqF("BC", "CD"), pqF("EF", "FG", "GH"));
   }
 
@@ -107,7 +107,7 @@ public void testFlatten1TermPhrase() throws Exception {
     Query query = pqF("A");
     FieldQuery fq = new FieldQuery(query, true, true);
     Set<Query> flatQueries = new HashSet<>();
-    fq.flatten(query, reader, flatQueries, 1f);
+    fq.flatten(query, searcher, flatQueries, 1f);
     assertCollectionQueries(flatQueries, tq("A"));
   }
 
@@ -950,7 +950,7 @@ public void testFlattenConstantScoreQuery() throws Exception {
     query = new BoostQuery(query, boost);
     FieldQuery fq = new FieldQuery(query, true, true);
     Set<Query> flatQueries = new HashSet<>();
-    fq.flatten(query, reader, flatQueries, 1f);
+    fq.flatten(query, searcher, flatQueries, 1f);
     assertCollectionQueries(flatQueries, tq(boost, "A"));
   }
 }
diff --git a/lucene/join/src/java/org/apache/lucene/search/join/ParentChildrenBlockJoinQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/ParentChildrenBlockJoinQuery.java
index 9ba6daee284a..64f5e5ba1169 100644
--- a/lucene/join/src/java/org/apache/lucene/search/join/ParentChildrenBlockJoinQuery.java
+++ b/lucene/join/src/java/org/apache/lucene/search/join/ParentChildrenBlockJoinQuery.java
@@ -18,7 +18,6 @@ package org.apache.lucene.search.join;
 
 import java.io.IOException;
-import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.ReaderUtil;
 import org.apache.lucene.search.DocIdSetIterator;
@@ -88,12 +87,12 @@ public void visit(QueryVisitor visitor) {
   }
 
   @Override
-
public Query rewrite(IndexReader reader) throws IOException { - final Query childRewrite = childQuery.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + final Query childRewrite = childQuery.rewrite(indexSearcher); if (childRewrite != childQuery) { return new ParentChildrenBlockJoinQuery(parentFilter, childRewrite, parentDocId); } else { - return super.rewrite(reader); + return super.rewrite(indexSearcher); } } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/ToChildBlockJoinQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/ToChildBlockJoinQuery.java index 42c41efcefa2..3f2a9e76c54b 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/ToChildBlockJoinQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/ToChildBlockJoinQuery.java @@ -20,7 +20,6 @@ import java.util.Collection; import java.util.Collections; import java.util.Locale; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; @@ -305,12 +304,12 @@ int getParentDoc() { } @Override - public Query rewrite(IndexReader reader) throws IOException { - final Query parentRewrite = parentQuery.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + final Query parentRewrite = parentQuery.rewrite(indexSearcher); if (parentRewrite != parentQuery) { return new ToChildBlockJoinQuery(parentRewrite, parentsFilter); } else { - return super.rewrite(reader); + return super.rewrite(indexSearcher); } } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java index 5640c9e4a7cd..6134cd0e1dde 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/ToParentBlockJoinQuery.java @@ -22,7 +22,6 @@ import java.util.Collection; import java.util.Collections; import java.util.Locale; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.ConstantScoreQuery; @@ -173,7 +172,7 @@ public long cost() { public Explanation explain(LeafReaderContext context, int doc) throws IOException { BlockJoinScorer scorer = (BlockJoinScorer) scorer(context); if (scorer != null && scorer.iterator().advance(doc) == doc) { - return scorer.explain(context, in); + return scorer.explain(context, in, scoreMode); } return Explanation.noMatch("Not a match"); } @@ -392,45 +391,61 @@ private void setScoreAndFreq() throws IOException { } this.score = (float) score; } - - public Explanation explain(LeafReaderContext context, Weight childWeight) throws IOException { + /* + * This instance of Explanation requires three parameters, context, childWeight, and scoreMode. + * The scoreMode parameter considers Avg, Total, Min, Max, and None. 
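For orientation, a hedged sketch of how the reworked explain (whose signature follows just below) surfaces to callers; searcher, childQuery, parentsFilter, and parentDocId are assumed surrounding context, not part of this patch:

// Sketch: with ScoreMode.Min the explanation now reports the worst-scoring
// child and names the score mode in its description.
ToParentBlockJoinQuery join =
    new ToParentBlockJoinQuery(childQuery, parentsFilter, ScoreMode.Min);
Explanation explanation = searcher.explain(join, parentDocId);
// explanation.getDescription() now reads like:
// "Score based on 2 child docs in range from 11 to 12, using score mode Min"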
+ * */ + public Explanation explain(LeafReaderContext context, Weight childWeight, ScoreMode scoreMode) + throws IOException { int prevParentDoc = parentBits.prevSetBit(parentApproximation.docID() - 1); int start = context.docBase + prevParentDoc + 1; // +1 b/c prevParentDoc is previous parent doc int end = context.docBase + parentApproximation.docID() - 1; // -1 b/c parentDoc is parent doc Explanation bestChild = null; + Explanation worstChild = null; + int matches = 0; for (int childDoc = start; childDoc <= end; childDoc++) { Explanation child = childWeight.explain(context, childDoc - context.docBase); if (child.isMatch()) { matches++; if (bestChild == null - || child.getValue().floatValue() > bestChild.getValue().floatValue()) { + || child.getValue().doubleValue() > bestChild.getValue().doubleValue()) { bestChild = child; } + if (worstChild == null + || child.getValue().doubleValue() < worstChild.getValue().doubleValue()) { + worstChild = child; + } } } - + assert matches > 0 : "No matches should be handled before."; + Explanation subExplain = scoreMode == ScoreMode.Min ? worstChild : bestChild; return Explanation.match( - score(), - String.format( - Locale.ROOT, - "Score based on %d child docs in range from %d to %d, best match:", - matches, - start, - end), - bestChild); + this.score(), + formatScoreExplanation(matches, start, end, scoreMode), + subExplain == null ? Collections.emptyList() : Collections.singleton(subExplain)); + } + + private String formatScoreExplanation(int matches, int start, int end, ScoreMode scoreMode) { + return String.format( + Locale.ROOT, + "Score based on %d child docs in range from %d to %d, using score mode %s", + matches, + start, + end, + scoreMode); } } @Override - public Query rewrite(IndexReader reader) throws IOException { - final Query childRewrite = childQuery.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + final Query childRewrite = childQuery.rewrite(indexSearcher); if (childRewrite != childQuery) { return new ToParentBlockJoinQuery(childRewrite, parentsFilter, scoreMode); } else { - return super.rewrite(reader); + return super.rewrite(indexSearcher); } } diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java b/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java index b77e634aac20..ca4246196fce 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestBlockJoin.java @@ -277,7 +277,6 @@ public void testSimple() throws Exception { CheckHits.checkHitCollector(random(), fullQuery.build(), "country", s, new int[] {2}); TopDocs topDocs = s.search(fullQuery.build(), 1); - // assertEquals(1, results.totalHitCount); assertEquals(1, topDocs.totalHits.value); Document parentDoc = s.storedFields().document(topDocs.scoreDocs[0].doc); @@ -890,13 +889,11 @@ public void testRandom() throws Exception { Explanation explanation = joinS.explain(childJoinQuery, hit.doc); Document document = joinS.storedFields().document(hit.doc - 1); int childId = Integer.parseInt(document.get("childID")); - // System.out.println(" hit docID=" + hit.doc + " childId=" + childId + " parentId=" + - // document.get("parentID")); assertTrue(explanation.isMatch()); assertEquals(hit.score, explanation.getValue().doubleValue(), 0.0f); Matcher m = Pattern.compile( - "Score based on ([0-9]+) child docs in range from ([0-9]+) to ([0-9]+), best match:") + "Score based on ([0-9]+) child docs in range from ([0-9]+) to 
([0-9]+), using score mode (None|Avg|Min|Max|Total)") .matcher(explanation.getDescription()); assertTrue("Block Join description not matches", m.matches()); assertTrue("Matched children not positive", Integer.parseInt(m.group(1)) > 0); diff --git a/lucene/luke/src/java/org/apache/lucene/luke/models/search/SearchImpl.java b/lucene/luke/src/java/org/apache/lucene/luke/models/search/SearchImpl.java index 9936d6924150..fa552e65d3f0 100644 --- a/lucene/luke/src/java/org/apache/lucene/luke/models/search/SearchImpl.java +++ b/lucene/luke/src/java/org/apache/lucene/luke/models/search/SearchImpl.java @@ -152,7 +152,7 @@ public Query parseQuery( if (rewrite) { try { - query = query.rewrite(reader); + query = query.rewrite(searcher); } catch (IOException e) { throw new LukeException( String.format(Locale.ENGLISH, "Failed to rewrite query: %s", query.toString()), e); diff --git a/lucene/misc/src/test/org/apache/lucene/misc/search/TestDiversifiedTopDocsCollector.java b/lucene/misc/src/test/org/apache/lucene/misc/search/TestDiversifiedTopDocsCollector.java index e561346f8cf1..ca63678fd2b3 100644 --- a/lucene/misc/src/test/org/apache/lucene/misc/search/TestDiversifiedTopDocsCollector.java +++ b/lucene/misc/src/test/org/apache/lucene/misc/search/TestDiversifiedTopDocsCollector.java @@ -499,12 +499,12 @@ public int hashCode() { } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query rewritten = query.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query rewritten = query.rewrite(indexSearcher); if (rewritten != query) { return new DocValueScoreQuery(rewritten, scoreField); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/monitor/src/java/org/apache/lucene/monitor/ForceNoBulkScoringQuery.java b/lucene/monitor/src/java/org/apache/lucene/monitor/ForceNoBulkScoringQuery.java index 2af23b4a5616..a97d1054c4c9 100644 --- a/lucene/monitor/src/java/org/apache/lucene/monitor/ForceNoBulkScoringQuery.java +++ b/lucene/monitor/src/java/org/apache/lucene/monitor/ForceNoBulkScoringQuery.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.*; import org.apache.lucene.search.Matches; @@ -34,10 +33,10 @@ public ForceNoBulkScoringQuery(Query inner) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query rewritten = inner.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query rewritten = inner.rewrite(indexSearcher); if (rewritten != inner) return new ForceNoBulkScoringQuery(rewritten); - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/monitor/src/test/org/apache/lucene/monitor/MonitorTestBase.java b/lucene/monitor/src/test/org/apache/lucene/monitor/MonitorTestBase.java index 65017e732470..548c417aafea 100644 --- a/lucene/monitor/src/test/org/apache/lucene/monitor/MonitorTestBase.java +++ b/lucene/monitor/src/test/org/apache/lucene/monitor/MonitorTestBase.java @@ -22,9 +22,9 @@ import java.util.Map; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.standard.StandardAnalyzer; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.queryparser.classic.ParseException; import org.apache.lucene.queryparser.classic.QueryParser; +import org.apache.lucene.search.IndexSearcher; import 
org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.tests.util.LuceneTestCase; @@ -65,7 +65,7 @@ protected Monitor newMonitor(Analyzer analyzer) throws IOException { public static class ThrowOnRewriteQuery extends Query { @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { throw new IOException("Error rewriting"); } diff --git a/lucene/monitor/src/test/org/apache/lucene/monitor/TestForceNoBulkScoringQuery.java b/lucene/monitor/src/test/org/apache/lucene/monitor/TestForceNoBulkScoringQuery.java index 28f96c760b33..4f5011ba17e7 100644 --- a/lucene/monitor/src/test/org/apache/lucene/monitor/TestForceNoBulkScoringQuery.java +++ b/lucene/monitor/src/test/org/apache/lucene/monitor/TestForceNoBulkScoringQuery.java @@ -67,7 +67,7 @@ public void testRewrite() throws IOException { assertEquals(q.getWrappedQuery(), pq); - Query rewritten = q.rewrite(reader); + Query rewritten = q.rewrite(newSearcher(reader)); assertTrue(rewritten instanceof ForceNoBulkScoringQuery); Query inner = ((ForceNoBulkScoringQuery) rewritten).getWrappedQuery(); diff --git a/lucene/queries/src/java/org/apache/lucene/queries/CommonTermsQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/CommonTermsQuery.java index 46a7f45069ad..b6ff82dce38d 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/CommonTermsQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/CommonTermsQuery.java @@ -30,6 +30,7 @@ import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -101,7 +102,8 @@ public void add(Term term) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + IndexReader reader = indexSearcher.getIndexReader(); if (this.terms.isEmpty()) { return new MatchNoDocsQuery("CommonTermsQuery with no terms"); } else if (this.terms.size() == 1) { diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java index f5f03d202919..449188091bd6 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.DoubleValues; @@ -122,8 +121,8 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query rewritten = in.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query rewritten = in.rewrite(indexSearcher); if (rewritten == in) { return this; } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java index 0d0a99919006..6b2a8a1f546c 100644 --- 
a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionValues.java @@ -70,6 +70,14 @@ public boolean boolVal(int doc) throws IOException { return intVal(doc) != 0; } + public float[] floatVectorVal(int doc) throws IOException { + throw new UnsupportedOperationException(); + } + + public byte[] byteVectorVal(int doc) throws IOException { + throw new UnsupportedOperationException(); + } + /** * returns the bytes representation of the string val - TODO: should this return the indexed raw * bytes not? diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java new file mode 100644 index 000000000000..c8a4a93a2dfc --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.search.DocIdSetIterator; + +/** + * An implementation for retrieving {@link FunctionValues} instances for byte knn vectors fields. 
+ */ +public class ByteKnnVectorFieldSource extends ValueSource { + private final String fieldName; + + public ByteKnnVectorFieldSource(String fieldName) { + this.fieldName = fieldName; + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + + final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName); + + if (vectorValues == null) { + throw new IllegalArgumentException( + "no byte vector value is indexed for field '" + fieldName + "'"); + } + + return new VectorFieldFunction(this) { + + @Override + public byte[] byteVectorVal(int doc) throws IOException { + if (exists(doc)) { + return vectorValues.vectorValue(); + } else { + return null; + } + } + + @Override + protected DocIdSetIterator getVectorIterator() { + return vectorValues; + } + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ByteKnnVectorFieldSource other = (ByteKnnVectorFieldSource) o; + return Objects.equals(fieldName, other.fieldName); + } + + @Override + public int hashCode() { + return Objects.hash(getClass().hashCode(), fieldName); + } + + @Override + public String description() { + return "ByteKnnVectorFieldSource(" + fieldName + ")"; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java new file mode 100644 index 000000000000..fb6ec68ee9e5 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +/** + * ByteVectorSimilarityFunction returns a similarity function between two knn vectors + * with byte elements. 
+ */ +public class ByteVectorSimilarityFunction extends VectorSimilarityFunction { + public ByteVectorSimilarityFunction( + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, + ValueSource vector1, + ValueSource vector2) { + super(similarityFunction, vector1, vector2); + } + + @Override + protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { + + var v1 = f1.byteVectorVal(doc); + var v2 = f2.byteVectorVal(doc); + + if (v1 == null || v2 == null) { + return 0.f; + } + + assert v1.length == v2.length : "Vectors must have the same length"; + + return similarityFunction.compare(v1, v2); + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java new file mode 100644 index 000000000000..4996e026abee --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +/** Function that returns a constant byte vector value for every document. 
*/ +public class ConstKnnByteVectorValueSource extends ValueSource { + private final byte[] vector; + + public ConstKnnByteVectorValueSource(byte[] constVector) { + this.vector = Objects.requireNonNull(constVector, "constVector"); + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + return new FunctionValues() { + @Override + public byte[] byteVectorVal(int doc) { + return vector; + } + + @Override + public String strVal(int doc) { + return Arrays.toString(vector); + } + + @Override + public String toString(int doc) throws IOException { + return description() + '=' + strVal(doc); + } + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConstKnnByteVectorValueSource other = (ConstKnnByteVectorValueSource) o; + return Arrays.equals(vector, other.vector); + } + + @Override + public int hashCode() { + return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector)); + } + + @Override + public String description() { + return "ConstKnnByteVectorValueSource(" + Arrays.toString(vector) + ')'; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java new file mode 100644 index 000000000000..57c016eb793e --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.util.VectorUtil; + +/** Function that returns a constant float vector value for every document. 
*/ +public class ConstKnnFloatValueSource extends ValueSource { + private final float[] vector; + + public ConstKnnFloatValueSource(float[] constVector) { + this.vector = VectorUtil.checkFinite(Objects.requireNonNull(constVector, "constVector")); + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + return new FunctionValues() { + @Override + public float[] floatVectorVal(int doc) { + return vector; + } + + @Override + public String strVal(int doc) { + return Arrays.toString(vector); + } + + @Override + public String toString(int doc) throws IOException { + return description() + '=' + strVal(doc); + } + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConstKnnFloatValueSource other = (ConstKnnFloatValueSource) o; + return Arrays.equals(vector, other.vector); + } + + @Override + public int hashCode() { + return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector)); + } + + @Override + public String description() { + return "ConstKnnFloatValueSource(" + Arrays.toString(vector) + ')'; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java new file mode 100644 index 000000000000..9a1f27a7c79d --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.search.DocIdSetIterator; + +/** + * An implementation for retrieving {@link FunctionValues} instances for float knn vectors fields. 
+ */ +public class FloatKnnVectorFieldSource extends ValueSource { + private final String fieldName; + + public FloatKnnVectorFieldSource(String fieldName) { + this.fieldName = fieldName; + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + + final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName); + + if (vectorValues == null) { + throw new IllegalArgumentException( + "no float vector value is indexed for field '" + fieldName + "'"); + } + return new VectorFieldFunction(this) { + + @Override + public float[] floatVectorVal(int doc) throws IOException { + if (exists(doc)) { + return vectorValues.vectorValue(); + } else { + return null; + } + } + + @Override + protected DocIdSetIterator getVectorIterator() { + return vectorValues; + } + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FloatKnnVectorFieldSource other = (FloatKnnVectorFieldSource) o; + return Objects.equals(fieldName, other.fieldName); + } + + @Override + public int hashCode() { + return Objects.hash(getClass().hashCode(), fieldName); + } + + @Override + public String description() { + return "FloatKnnVectorFieldSource(" + fieldName + ")"; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java new file mode 100644 index 000000000000..296775388856 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatVectorSimilarityFunction.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +/** + * FloatVectorSimilarityFunction returns a similarity function between two knn vectors + * with float elements. 
+ */ +public class FloatVectorSimilarityFunction extends VectorSimilarityFunction { + public FloatVectorSimilarityFunction( + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, + ValueSource vector1, + ValueSource vector2) { + super(similarityFunction, vector1, vector2); + } + + @Override + protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { + + var v1 = f1.floatVectorVal(doc); + var v2 = f2.floatVectorVal(doc); + + if (v1 == null || v2 == null) { + return 0.f; + } + + assert v1.length == v2.length : "Vectors must have the same length"; + return similarityFunction.compare(v1, v2); + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java new file mode 100644 index 000000000000..de64984249fe --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorFieldFunction.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.search.DocIdSetIterator; + +/** An implementation for retrieving {@link FunctionValues} instances for knn vectors fields. */ +public abstract class VectorFieldFunction extends FunctionValues { + + protected final ValueSource valueSource; + int lastDocID; + + protected VectorFieldFunction(ValueSource valueSource) { + this.valueSource = valueSource; + } + + protected abstract DocIdSetIterator getVectorIterator(); + + @Override + public String toString(int doc) throws IOException { + return valueSource.description() + strVal(doc); + } + + @Override + public boolean exists(int doc) throws IOException { + if (doc < lastDocID) { + throw new IllegalArgumentException( + "docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + doc); + } + + lastDocID = doc; + + int curDocID = getVectorIterator().docID(); + if (doc > curDocID) { + curDocID = getVectorIterator().advance(doc); + } + return doc == curDocID; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java new file mode 100644 index 000000000000..9ba2d359a568 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/VectorSimilarityFunction.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.queries.function.valuesource; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.queries.function.FunctionValues; +import org.apache.lucene.queries.function.ValueSource; + +/** VectorSimilarityFunction returns a similarity function between two knn vectors. */ +public abstract class VectorSimilarityFunction extends ValueSource { + + protected final org.apache.lucene.index.VectorSimilarityFunction similarityFunction; + protected final ValueSource vector1; + protected final ValueSource vector2; + + public VectorSimilarityFunction( + org.apache.lucene.index.VectorSimilarityFunction similarityFunction, + ValueSource vector1, + ValueSource vector2) { + + this.similarityFunction = similarityFunction; + this.vector1 = vector1; + this.vector2 = vector2; + } + + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + + final FunctionValues vector1Vals = vector1.getValues(context, readerContext); + final FunctionValues vector2Vals = vector2.getValues(context, readerContext); + return new FunctionValues() { + @Override + public float floatVal(int doc) throws IOException { + return func(doc, vector1Vals, vector2Vals); + } + + @Override + public String strVal(int doc) throws IOException { + return Float.toString(floatVal(doc)); + } + + @Override + public boolean exists(int doc) throws IOException { + return MultiFunction.allExists(doc, vector1Vals, vector2Vals); + } + + @Override + public String toString(int doc) throws IOException { + return description() + " = " + strVal(doc); + } + }; + } + + protected abstract float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException; + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return Objects.equals(vector1, ((VectorSimilarityFunction) o).vector1) + && Objects.equals(vector2, ((VectorSimilarityFunction) o).vector2); + } + + @Override + public int hashCode() { + return Objects.hash(similarityFunction, vector1, vector2); + } + + @Override + public String description() { + return similarityFunction.name() + + "(" + + vector1.description() + + ", " + + vector2.description() + + ")"; + } +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThisQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThisQuery.java index 7ba1bb00b5ee..159d30c8e547 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThisQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThisQuery.java @@ -22,9 +22,9 @@ import java.util.Objects; import java.util.Set; import org.apache.lucene.analysis.Analyzer; -import 
org.apache.lucene.index.IndexReader; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -57,8 +57,8 @@ public MoreLikeThisQuery( } @Override - public Query rewrite(IndexReader reader) throws IOException { - MoreLikeThis mlt = new MoreLikeThis(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + MoreLikeThis mlt = new MoreLikeThis(indexSearcher.getIndexReader()); mlt.setFieldNames(moreLikeFields); mlt.setAnalyzer(analyzer); diff --git a/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java index a782901f9328..d4dd52b46205 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Map; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.Term; @@ -83,12 +82,12 @@ public String getField() { } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query matchRewritten = wrappedQuery.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query matchRewritten = wrappedQuery.rewrite(indexSearcher); if (wrappedQuery != matchRewritten && matchRewritten instanceof SpanQuery) { return new PayloadScoreQuery((SpanQuery) matchRewritten, function, decoder, includeSpanScore); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/payloads/SpanPayloadCheckQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/payloads/SpanPayloadCheckQuery.java index 1d3c13ed9cd5..ef04a8dccd08 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/payloads/SpanPayloadCheckQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/payloads/SpanPayloadCheckQuery.java @@ -20,7 +20,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.Term; @@ -116,13 +115,13 @@ public SpanWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, floa } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query matchRewritten = match.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query matchRewritten = match.rewrite(indexSearcher); if (match != matchRewritten && matchRewritten instanceof SpanQuery) { return new SpanPayloadCheckQuery( (SpanQuery) matchRewritten, payloadToMatch, payloadType, operation); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/spans/FieldMaskingSpanQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/spans/FieldMaskingSpanQuery.java index 500f9aad456f..038a2e3742de 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/spans/FieldMaskingSpanQuery.java +++ 
b/lucene/queries/src/java/org/apache/lucene/queries/spans/FieldMaskingSpanQuery.java @@ -18,7 +18,6 @@ import java.io.IOException; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; @@ -93,13 +92,13 @@ public SpanWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, floa } @Override - public Query rewrite(IndexReader reader) throws IOException { - SpanQuery rewritten = (SpanQuery) maskedQuery.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + SpanQuery rewritten = (SpanQuery) maskedQuery.rewrite(indexSearcher); if (rewritten != maskedQuery) { return new FieldMaskingSpanQuery(rewritten, field); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanContainQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanContainQuery.java index 680ceea0f789..99412f964065 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanContainQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanContainQuery.java @@ -20,7 +20,6 @@ import java.util.ArrayList; import java.util.Map; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; @@ -109,9 +108,9 @@ String toString(String field, String name) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - SpanQuery rewrittenBig = (SpanQuery) big.rewrite(reader); - SpanQuery rewrittenLittle = (SpanQuery) little.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + SpanQuery rewrittenBig = (SpanQuery) big.rewrite(indexSearcher); + SpanQuery rewrittenLittle = (SpanQuery) little.rewrite(indexSearcher); if (big != rewrittenBig || little != rewrittenLittle) { try { SpanContainQuery clone = (SpanContainQuery) super.clone(); @@ -122,7 +121,7 @@ public Query rewrite(IndexReader reader) throws IOException { throw new AssertionError(e); } } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanMultiTermQueryWrapper.java b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanMultiTermQueryWrapper.java index aa36368270ed..882e6c96fb3d 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanMultiTermQueryWrapper.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanMultiTermQueryWrapper.java @@ -117,8 +117,8 @@ public String toString(String field) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - return rewriteMethod.rewrite(reader, query); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return rewriteMethod.rewrite(indexSearcher.getIndexReader(), query); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanNearQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanNearQuery.java index b3f9f8e357bc..00318e4bae64 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanNearQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanNearQuery.java @@ -23,7 +23,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import 
org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; @@ -242,12 +241,12 @@ public boolean isCacheable(LeafReaderContext ctx) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { boolean actuallyRewritten = false; List<SpanQuery> rewrittenClauses = new ArrayList<>(); for (int i = 0; i < clauses.size(); i++) { SpanQuery c = clauses.get(i); - SpanQuery query = (SpanQuery) c.rewrite(reader); + SpanQuery query = (SpanQuery) c.rewrite(indexSearcher); actuallyRewritten |= query != c; rewrittenClauses.add(query); } @@ -260,7 +259,7 @@ public Query rewrite(IndexReader reader) throws IOException { throw new AssertionError(e); } } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanNotQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanNotQuery.java index fc2f334a4bae..aded21f5f7cf 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanNotQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanNotQuery.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Map; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; @@ -224,13 +223,13 @@ public boolean isCacheable(LeafReaderContext ctx) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - SpanQuery rewrittenInclude = (SpanQuery) include.rewrite(reader); - SpanQuery rewrittenExclude = (SpanQuery) exclude.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + SpanQuery rewrittenInclude = (SpanQuery) include.rewrite(indexSearcher); + SpanQuery rewrittenExclude = (SpanQuery) exclude.rewrite(indexSearcher); if (rewrittenInclude != include || rewrittenExclude != exclude) { return new SpanNotQuery(rewrittenInclude, rewrittenExclude, pre, post); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanOrQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanOrQuery.java index 2b8e1856774a..46b14aef92dd 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanOrQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanOrQuery.java @@ -21,7 +21,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; @@ -67,19 +66,19 @@ public String getField() { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { SpanOrQuery rewritten = new SpanOrQuery(); boolean actuallyRewritten = false; for (int i = 0; i < clauses.size(); i++) { SpanQuery c = clauses.get(i); - SpanQuery query = (SpanQuery) c.rewrite(reader); + SpanQuery query = (SpanQuery) c.rewrite(indexSearcher); actuallyRewritten |= query != c; rewritten.addClause(query); } if (actuallyRewritten) { return rewritten; } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff
--git a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanPositionCheckQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanPositionCheckQuery.java index 0227a13f3ab4..a83969f3fb5e 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanPositionCheckQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanPositionCheckQuery.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Map; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; @@ -113,8 +112,8 @@ protected AcceptStatus accept(Spans candidate) throws IOException { } @Override - public Query rewrite(IndexReader reader) throws IOException { - SpanQuery rewritten = (SpanQuery) match.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + SpanQuery rewritten = (SpanQuery) match.rewrite(indexSearcher); if (rewritten != match) { try { SpanPositionCheckQuery clone = (SpanPositionCheckQuery) this.clone(); @@ -125,7 +124,7 @@ public Query rewrite(IndexReader reader) throws IOException { } } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java new file mode 100644 index 000000000000..12144b252ba0 --- /dev/null +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestKnnVectorSimilarityFunctions.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.queries.function; + +import java.util.List; +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.SortedDocValuesField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource; +import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction; +import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource; +import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource; +import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource; +import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.search.CheckHits; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.BytesRef; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestKnnVectorSimilarityFunctions extends LuceneTestCase { + static Directory dir; + static Analyzer analyzer; + static IndexReader reader; + static IndexSearcher searcher; + static final List<String> documents = List.of("1", "2"); + + @BeforeClass + public static void beforeClass() throws Exception { + dir = newDirectory(); + analyzer = new MockAnalyzer(random()); + IndexWriterConfig iwConfig = newIndexWriterConfig(analyzer); + iwConfig.setMergePolicy(newLogMergePolicy()); + RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwConfig); + + Document document = new Document(); + document.add(new StringField("id", "1", Field.Store.NO)); + document.add(new SortedDocValuesField("id", new BytesRef("1"))); + document.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f})); + document.add(new KnnFloatVectorField("knnFloatField2", new float[] {5.2f, 3.2f, 3.1f})); + + // add only to the first document + document.add(new KnnFloatVectorField("knnFloatField3", new float[] {1.0f, 1.0f, 1.0f})); + document.add(new KnnByteVectorField("knnByteField3", new byte[] {1, 1, 1})); + + document.add(new KnnByteVectorField("knnByteField1", new byte[] {1, 2, 3})); + document.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3})); + iw.addDocument(document); + + Document document2 = new Document(); + document2.add(new StringField("id", "2", Field.Store.NO)); + document2.add(new SortedDocValuesField("id", new BytesRef("2"))); + document2.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f})); + document2.add(new KnnFloatVectorField("knnFloatField2", new float[] {5.2f, 3.2f, 3.1f})); + + document2.add(new KnnByteVectorField("knnByteField1", new byte[] {1, 2, 3})); + document2.add(new KnnByteVectorField("knnByteField2", new byte[] {4, 2, 3})); +
iw.addDocument(document2); + + reader = iw.getReader(); + searcher = newSearcher(reader); + iw.close(); + } + + @AfterClass + public static void afterClass() throws Exception { + searcher = null; + reader.close(); + reader = null; + dir.close(); + dir = null; + analyzer.close(); + analyzer = null; + } + + @Test + public void vectorSimilarity_floatConstantVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new ConstKnnFloatValueSource(new float[] {1, 2, 3}); + var v2 = new ConstKnnFloatValueSource(new float[] {5, 4, 1}); + assertHits( + new FunctionQuery( + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.04f, 0.04f}); + } + + @Test + public void vectorSimilarity_byteConstantVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 3}); + var v2 = new ConstKnnByteVectorValueSource(new byte[] {2, 5, 6}); + assertHits( + new FunctionQuery( + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.05f, 0.05f}); + } + + @Test + public void vectorSimilarity_floatFieldVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new FloatKnnVectorFieldSource("knnFloatField1"); + var v2 = new FloatKnnVectorFieldSource("knnFloatField2"); + assertHits( + new FunctionQuery( + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.049776014f, 0.049776014f}); + } + + @Test + public void vectorSimilarity_byteFieldVectors_shouldReturnFloatSimilarity() throws Exception { + var v1 = new ByteKnnVectorFieldSource("knnByteField1"); + var v2 = new ByteKnnVectorFieldSource("knnByteField2"); + assertHits( + new FunctionQuery( + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.1f, 0.1f}); + } + + @Test + public void vectorSimilarity_FloatConstAndFloatFieldVectors_shouldReturnFloatSimilarity() + throws Exception { + var v1 = new ConstKnnFloatValueSource(new float[] {1, 2, 4}); + var v2 = new FloatKnnVectorFieldSource("knnFloatField1"); + assertHits( + new FunctionQuery( + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, 0.5f}); + } + + @Test + public void vectorSimilarity_ByteConstAndByteFieldVectors_shouldReturnFloatSimilarity() + throws Exception { + var v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 4}); + var v2 = new ByteKnnVectorFieldSource("knnByteField1"); + assertHits( + new FunctionQuery( + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, 0.5f}); + } + + @Test + public void vectorSimilarity_missingFloatVectorField_shouldReturnZero() throws Exception { + var v1 = new ConstKnnFloatValueSource(new float[] {2.f, 1.f, 1.f}); + var v2 = new FloatKnnVectorFieldSource("knnFloatField3"); + assertHits( + new FunctionQuery( + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, 0.f}); + } + + @Test + public void vectorSimilarity_missingByteVectorField_shouldReturnZero() throws Exception { + var v1 = new ConstKnnByteVectorValueSource(new byte[] {2, 1, 1}); + var v2 = new ByteKnnVectorFieldSource("knnByteField3"); + assertHits( + new FunctionQuery( + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2)), + new float[] {0.5f, 0.f}); + } + + @Test + public void vectorSimilarity_twoVectorsWithDifferentDimensions_shouldRaiseException() { + ValueSource v1 = new 
ConstKnnByteVectorValueSource(new byte[] {1, 2, 3, 4}); + ValueSource v2 = new ByteKnnVectorFieldSource("knnByteField1"); + ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + AssertionError.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + + v1 = new ConstKnnFloatValueSource(new float[] {1.f, 2.f}); + v2 = new FloatKnnVectorFieldSource("knnFloatField1"); + FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + AssertionError.class, + () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); + } + + @Test + public void vectorSimilarity_byteAndFloatVectors_shouldRaiseException() { + var v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 3}); + ValueSource v2 = new ByteKnnVectorFieldSource("knnByteField1"); + FloatVectorSimilarityFunction floatDenseVectorSimilarityFunction = + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + UnsupportedOperationException.class, + () -> searcher.search(new FunctionQuery(floatDenseVectorSimilarityFunction), 10)); + + v1 = new ConstKnnByteVectorValueSource(new byte[] {1, 2, 3}); + v2 = new FloatKnnVectorFieldSource("knnFloatField1"); + ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + assertThrows( + UnsupportedOperationException.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + } + + @Test + public void vectorSimilarity_wrongFieldType_shouldRaiseException() { + ValueSource v1 = new ByteKnnVectorFieldSource("knnByteField1"); + ValueSource v2 = new ByteKnnVectorFieldSource("knnFloatField2"); + ByteVectorSimilarityFunction byteDenseVectorSimilarityFunction = + new ByteVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + + assertThrows( + IllegalArgumentException.class, + () -> searcher.search(new FunctionQuery(byteDenseVectorSimilarityFunction), 10)); + + v1 = new FloatKnnVectorFieldSource("knnByteField1"); + v2 = new FloatKnnVectorFieldSource("knnFloatField2"); + FloatVectorSimilarityFunction floatVectorSimilarityFunction = + new FloatVectorSimilarityFunction(VectorSimilarityFunction.EUCLIDEAN, v1, v2); + + assertThrows( + IllegalArgumentException.class, + () -> searcher.search(new FunctionQuery(floatVectorSimilarityFunction), 10)); + } + + private static void assertHits(Query q, float[] scores) throws Exception { + ScoreDoc[] expected = new ScoreDoc[scores.length]; + int[] expectedDocs = new int[scores.length]; + for (int i = 0; i < expected.length; i++) { + expectedDocs[i] = i; + expected[i] = new ScoreDoc(i, scores[i]); + } + TopDocs docs = + searcher.search( + q, documents.size(), new Sort(new SortField("id", SortField.Type.STRING)), true); + CheckHits.checkHitsQuery(q, expected, docs.scoreDocs, expectedDocs); + } +} diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestValueSources.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestValueSources.java index 93bed6fe9f5b..48a84c105f22 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestValueSources.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestValueSources.java @@ -781,8 +781,8 @@ public float getMaxScore(int upTo) throws 
IOException { } @Override - public Query rewrite(IndexReader reader) throws IOException { - var rewrite = in.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + var rewrite = in.rewrite(indexSearcher); return rewrite == in ? this : new AssertScoreComputedOnceQuery(rewrite); } diff --git a/lucene/queries/src/test/org/apache/lucene/queries/spans/AssertingSpanQuery.java b/lucene/queries/src/test/org/apache/lucene/queries/spans/AssertingSpanQuery.java index f22a70a03023..8eab12aa7d11 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/spans/AssertingSpanQuery.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/spans/AssertingSpanQuery.java @@ -18,7 +18,6 @@ import java.io.IOException; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -50,10 +49,10 @@ public SpanWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, floa } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query q = in.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query q = in.rewrite(indexSearcher); if (q == in) { - return super.rewrite(reader); + return super.rewrite(indexSearcher); } else if (q instanceof SpanQuery) { return new AssertingSpanQuery((SpanQuery) q); } else { diff --git a/lucene/queries/src/test/org/apache/lucene/queries/spans/TestFieldMaskingSpanQuery.java b/lucene/queries/src/test/org/apache/lucene/queries/spans/TestFieldMaskingSpanQuery.java index 567bfdf711ee..9c98a0025829 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/spans/TestFieldMaskingSpanQuery.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/spans/TestFieldMaskingSpanQuery.java @@ -164,7 +164,7 @@ public void testRewrite1() throws Exception { new FieldMaskingSpanQuery( new SpanTermQuery(new Term("last", "sally")) { @Override - public Query rewrite(IndexReader reader) { + public Query rewrite(IndexSearcher indexSearcher) { return new SpanOrQuery( new SpanTermQuery(new Term("first", "sally")), new SpanTermQuery(new Term("first", "james"))); diff --git a/lucene/queryparser/src/java/org/apache/lucene/queryparser/complexPhrase/ComplexPhraseQueryParser.java b/lucene/queryparser/src/java/org/apache/lucene/queryparser/complexPhrase/ComplexPhraseQueryParser.java index 4dad705b876a..339e60a14bd6 100644 --- a/lucene/queryparser/src/java/org/apache/lucene/queryparser/complexPhrase/ComplexPhraseQueryParser.java +++ b/lucene/queryparser/src/java/org/apache/lucene/queryparser/complexPhrase/ComplexPhraseQueryParser.java @@ -22,7 +22,6 @@ import java.util.List; import java.util.Objects; import org.apache.lucene.analysis.Analyzer; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; import org.apache.lucene.queries.spans.SpanNearQuery; import org.apache.lucene.queries.spans.SpanNotQuery; @@ -256,7 +255,7 @@ public void visit(QueryVisitor visitor) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { final Query contents = this.contents[0]; // ArrayList spanClauses = new ArrayList(); if (contents instanceof TermQuery @@ -284,7 +283,7 @@ public Query rewrite(IndexReader reader) throws IOException { // HashSet bclauseterms=new HashSet(); Query qc = clause.getQuery(); // Rewrite this clause e.g one* becomes (one OR onerous) - qc = 
new IndexSearcher(reader).rewrite(qc); + qc = indexSearcher.rewrite(qc); if (clause.getOccur().equals(BooleanClause.Occur.MUST_NOT)) { numNegatives++; } diff --git a/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/DistanceRewriteQuery.java b/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/DistanceRewriteQuery.java index d739bc022e67..7186aa3cf244 100644 --- a/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/DistanceRewriteQuery.java +++ b/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/DistanceRewriteQuery.java @@ -17,7 +17,7 @@ package org.apache.lucene.queryparser.surround.query; import java.io.IOException; -import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -28,8 +28,8 @@ class DistanceRewriteQuery extends RewriteQuery<DistanceQuery> { } @Override - public Query rewrite(IndexReader reader) throws IOException { - return srndQuery.getSpanNearQuery(reader, fieldName, qf); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return srndQuery.getSpanNearQuery(indexSearcher.getIndexReader(), fieldName, qf); } @Override diff --git a/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/RewriteQuery.java b/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/RewriteQuery.java index 69005e6dd00c..533cf72ce023 100644 --- a/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/RewriteQuery.java +++ b/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/RewriteQuery.java @@ -18,7 +18,7 @@ import java.io.IOException; import java.util.Objects; -import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; abstract class RewriteQuery<SQ extends SrndQuery> extends Query { @@ -33,7 +33,7 @@ abstract class RewriteQuery<SQ extends SrndQuery> { } @Override - public abstract Query rewrite(IndexReader reader) throws IOException; + public abstract Query rewrite(IndexSearcher indexSearcher) throws IOException; @Override public String toString(String field) { diff --git a/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/SimpleTermRewriteQuery.java b/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/SimpleTermRewriteQuery.java index 22f1118beb48..4371e5f224fd 100644 --- a/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/SimpleTermRewriteQuery.java +++ b/lucene/queryparser/src/java/org/apache/lucene/queryparser/surround/query/SimpleTermRewriteQuery.java @@ -19,9 +19,9 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -33,10 +33,10 @@ class SimpleTermRewriteQuery extends RewriteQuery<SimpleTerm> { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { final List<Query> luceneSubQueries = new ArrayList<>(); srndQuery.visitMatchingTerms( - reader, + indexSearcher.getIndexReader(), fieldName, new SimpleTerm.MatchingTermVisitor() { @Override diff --git
a/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java b/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java index ae306da9afc3..102360fae87a 100644 --- a/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java +++ b/lucene/queryparser/src/test/org/apache/lucene/queryparser/xml/TestCoreParser.java @@ -316,7 +316,7 @@ private Query implParse(String xmlFileName, boolean span) throws ParserException } protected Query rewrite(Query q) throws IOException { - return q.rewrite(reader()); + return q.rewrite(searcher()); } protected void dumpResults(String qType, Query q, int numDocs) throws IOException { diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java index 6193d27a99ff..ab4cbb05fc99 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java @@ -38,6 +38,7 @@ import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.FuzzyTermsEnum; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.TermQuery; @@ -282,7 +283,8 @@ public void visit(QueryVisitor visitor) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + IndexReader reader = indexSearcher.getIndexReader(); ScoreTermQueue q = new ScoreTermQueue(maxNumTerms); // load up the list of possible terms for (FieldVals f : fieldVals) { diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java index 3196f9a6d997..08bb24a846cf 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java @@ -253,7 +253,7 @@ public long ramBytesUsed() { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (terms.length == 0 || fieldAndWeights.isEmpty()) { return new BooleanQuery.Builder().build(); } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CoveringQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CoveringQuery.java index 69e5bd69f834..5a493a00c1a5 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CoveringQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CoveringQuery.java @@ -22,7 +22,6 @@ import java.util.List; import java.util.Objects; import java.util.stream.Collectors; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; @@ -125,7 +124,7 @@ public long ramBytesUsed() { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (minimumNumberMatch instanceof LongValuesSource.ConstantLongValuesSource) { final long constantMin = ((LongValuesSource.ConstantLongValuesSource) minimumNumberMatch).getValue(); @@ -136,7 
+135,7 @@ public Query rewrite(IndexReader reader) throws IOException { BooleanQuery.Builder builder = new BooleanQuery.Builder().setMinimumNumberShouldMatch((int) Math.max(constantMin, 1)); for (Query query : queries) { - Query r = query.rewrite(reader); + Query r = query.rewrite(indexSearcher); builder.add(r, BooleanClause.Occur.SHOULD); } return builder.build(); @@ -144,14 +143,14 @@ public Query rewrite(IndexReader reader) throws IOException { Multiset<Query> rewritten = new Multiset<>(); boolean actuallyRewritten = false; for (Query query : queries) { - Query r = query.rewrite(reader); + Query r = query.rewrite(indexSearcher); rewritten.add(r); actuallyRewritten |= query != r; } if (actuallyRewritten) { return new CoveringQuery(rewritten, minimumNumberMatch); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java index 19b885dfed5a..837de213d227 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java @@ -23,7 +23,6 @@ import java.util.Comparator; import java.util.List; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PointValues; @@ -169,7 +168,7 @@ public void visit(QueryVisitor visitor) { * #mergeOverlappingRanges} */ @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (numDims != 1) { return this; } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/PhraseWildcardQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/PhraseWildcardQuery.java index c4fc7189fe17..bbe8a4970eb2 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/PhraseWildcardQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/PhraseWildcardQuery.java @@ -113,14 +113,14 @@ public String getField() { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (phraseTerms.isEmpty()) { return NO_MATCH_QUERY; } if (phraseTerms.size() == 1) { return phraseTerms.get(0).getQuery(); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/TermAutomatonQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/TermAutomatonQuery.java index d06c99523bb6..7fae86711d87 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/TermAutomatonQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/TermAutomatonQuery.java @@ -23,7 +23,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; @@ -485,7 +484,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (Operations.isEmpty(det)) { return new
MatchNoDocsQuery(); } diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/queries/TestFuzzyLikeThisQuery.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/queries/TestFuzzyLikeThisQuery.java index 89506bfa7bfb..2b8fca789cb9 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/queries/TestFuzzyLikeThisQuery.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/queries/TestFuzzyLikeThisQuery.java @@ -81,7 +81,7 @@ private void addDoc(RandomIndexWriter writer, String name, String id) throws IOE public void testClosestEditDistanceMatchComesFirst() throws Throwable { FuzzyLikeThisQuery flt = new FuzzyLikeThisQuery(10, analyzer); flt.addTerms("smith", "name", 2, 1); - Query q = flt.rewrite(searcher.getIndexReader()); + Query q = flt.rewrite(searcher); HashSet<Term> queryTerms = new HashSet<>(); q.visit(QueryVisitor.termCollector(queryTerms)); assertTrue("Should have variant smythe", queryTerms.contains(new Term("name", "smythe"))); @@ -98,7 +98,7 @@ public void testClosestEditDistanceMatchComesFirst() throws Throwable { public void testMultiWord() throws Throwable { FuzzyLikeThisQuery flt = new FuzzyLikeThisQuery(10, analyzer); flt.addTerms("jonathin smoth", "name", 2, 1); - Query q = flt.rewrite(searcher.getIndexReader()); + Query q = flt.rewrite(searcher); HashSet<Term> queryTerms = new HashSet<>(); q.visit(QueryVisitor.termCollector(queryTerms)); assertTrue("Should have variant jonathan", queryTerms.contains(new Term("name", "jonathan"))); @@ -116,7 +116,7 @@ public void testNonExistingField() throws Throwable { flt.addTerms("jonathin smoth", "name", 2, 1); flt.addTerms("jonathin smoth", "this field does not exist", 2, 1); // don't fail here just because the field doesn't exist - Query q = flt.rewrite(searcher.getIndexReader()); + Query q = flt.rewrite(searcher); HashSet<Term> queryTerms = new HashSet<>(); q.visit(QueryVisitor.termCollector(queryTerms)); assertTrue("Should have variant jonathan", queryTerms.contains(new Term("name", "jonathan"))); @@ -132,7 +132,7 @@ public void testNoMatchFirstWordBug() throws Throwable { FuzzyLikeThisQuery flt = new FuzzyLikeThisQuery(10, analyzer); flt.addTerms("fernando smith", "name", 2, 1); - Query q = flt.rewrite(searcher.getIndexReader()); + Query q = flt.rewrite(searcher); HashSet<Term> queryTerms = new HashSet<>(); q.visit(QueryVisitor.termCollector(queryTerms)); assertTrue("Should have variant smith", queryTerms.contains(new Term("name", "smith"))); diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestCoveringQuery.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestCoveringQuery.java index 98d174f806c5..3edc18b16a0a 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestCoveringQuery.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestCoveringQuery.java @@ -80,7 +80,8 @@ public void testRewrite() throws IOException { LongValuesSource vs = LongValuesSource.fromIntField("field"); assertEquals( new CoveringQuery(Collections.singleton(tq), vs), - new CoveringQuery(Collections.singleton(pq), vs).rewrite(new MultiReader())); + new CoveringQuery(Collections.singleton(pq), vs) + .rewrite(new IndexSearcher(new MultiReader()))); } public void testToString() { diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java index b4d1e49445f9..0162fe2d3b58 100644 ---
a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestMultiRangeQueries.java @@ -806,7 +806,7 @@ public void testRandomRewrite() throws IOException { builder2.add(LongPoint.newRangeQuery("point", lower, upper), BooleanClause.Occur.SHOULD); } - MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader); + MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(searcher); BooleanQuery booleanQuery = builder2.build(); int count = searcher.search(multiRangeQuery, DummyTotalHitCountCollector.createManager()); int booleanCount = searcher.search(booleanQuery, DummyTotalHitCountCollector.createManager()); @@ -839,7 +839,7 @@ public void testOneDimensionCount() throws IOException { builder2.add(LongPoint.newRangeQuery("point", lower, upper), BooleanClause.Occur.SHOULD); } - MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(reader); + MultiRangeQuery multiRangeQuery = (MultiRangeQuery) builder1.build().rewrite(searcher); BooleanQuery booleanQuery = builder2.build(); int count = multiRangeQuery diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestTermAutomatonQuery.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestTermAutomatonQuery.java index 8598d1efe9b2..3d5539d2e371 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestTermAutomatonQuery.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestTermAutomatonQuery.java @@ -788,7 +788,7 @@ public void testRewriteNoMatch() throws Exception { w.addDocument(doc); IndexReader r = w.getReader(); - assertTrue(q.rewrite(r) instanceof MatchNoDocsQuery); + assertTrue(q.rewrite(newSearcher(r)) instanceof MatchNoDocsQuery); IOUtils.close(w, r, dir); } @@ -807,7 +807,7 @@ public void testRewriteTerm() throws Exception { w.addDocument(doc); IndexReader r = w.getReader(); - Query rewrite = q.rewrite(r); + Query rewrite = q.rewrite(newSearcher(r)); assertTrue(rewrite instanceof TermQuery); assertEquals(new Term("field", "foo"), ((TermQuery) rewrite).getTerm()); IOUtils.close(w, r, dir); @@ -830,7 +830,7 @@ public void testRewriteSimplePhrase() throws Exception { w.addDocument(doc); IndexReader r = w.getReader(); - Query rewrite = q.rewrite(r); + Query rewrite = q.rewrite(newSearcher(r)); assertTrue(rewrite instanceof PhraseQuery); Term[] terms = ((PhraseQuery) rewrite).getTerms(); assertEquals(new Term("field", "foo"), terms[0]); @@ -855,7 +855,7 @@ public CustomTermAutomatonQuery(String field) { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher searcher) { return this; } } @@ -876,7 +876,7 @@ public void testExplainNoMatchingDocument() throws Exception { IndexReader r = w.getReader(); IndexSearcher searcher = newSearcher(r); - Query rewrittenQuery = q.rewrite(r); + Query rewrittenQuery = q.rewrite(searcher); assertTrue(rewrittenQuery instanceof TermAutomatonQuery); TopDocs topDocs = searcher.search(rewrittenQuery, 10); @@ -918,7 +918,7 @@ public void testExplainMatchingDocuments() throws Exception { IndexReader r = w.getReader(); IndexSearcher searcher = newSearcher(r); - Query rewrittenQuery = q.rewrite(r); + Query rewrittenQuery = q.rewrite(searcher); assertTrue( "Rewritten query should be an instance of TermAutomatonQuery", rewrittenQuery instanceof TermAutomatonQuery); @@ -953,7 +953,7 @@ public void testRewritePhraseWithAny() throws Exception { 
w.addDocument(doc); IndexReader r = w.getReader(); - Query rewrite = q.rewrite(r); + Query rewrite = q.rewrite(newSearcher(r)); assertTrue(rewrite instanceof PhraseQuery); Term[] terms = ((PhraseQuery) rewrite).getTerms(); assertEquals(new Term("field", "foo"), terms[0]); @@ -982,7 +982,7 @@ public void testRewriteSimpleMultiPhrase() throws Exception { w.addDocument(doc); IndexReader r = w.getReader(); - Query rewrite = q.rewrite(r); + Query rewrite = q.rewrite(newSearcher(r)); assertTrue(rewrite instanceof MultiPhraseQuery); Term[][] terms = ((MultiPhraseQuery) rewrite).getTermArrays(); assertEquals(1, terms.length); @@ -1017,7 +1017,7 @@ public void testRewriteMultiPhraseWithAny() throws Exception { w.addDocument(doc); IndexReader r = w.getReader(); - Query rewrite = q.rewrite(r); + Query rewrite = q.rewrite(newSearcher(r)); assertTrue(rewrite instanceof MultiPhraseQuery); Term[][] terms = ((MultiPhraseQuery) rewrite).getTermArrays(); assertEquals(2, terms.length); diff --git a/lucene/spatial-extras/src/java/org/apache/lucene/spatial/composite/CompositeVerifyQuery.java b/lucene/spatial-extras/src/java/org/apache/lucene/spatial/composite/CompositeVerifyQuery.java index 354a0196229d..6483e4253940 100644 --- a/lucene/spatial-extras/src/java/org/apache/lucene/spatial/composite/CompositeVerifyQuery.java +++ b/lucene/spatial-extras/src/java/org/apache/lucene/spatial/composite/CompositeVerifyQuery.java @@ -17,7 +17,6 @@ package org.apache.lucene.spatial.composite; import java.io.IOException; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; @@ -47,12 +46,12 @@ public CompositeVerifyQuery(Query indexQuery, ShapeValuesPredicate predicateValu } @Override - public Query rewrite(IndexReader reader) throws IOException { - final Query rewritten = indexQuery.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + final Query rewritten = indexQuery.rewrite(indexSearcher); if (rewritten != indexQuery) { return new CompositeVerifyQuery(rewritten, predicateValueSource); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/spatial-extras/src/java/org/apache/lucene/spatial/vector/PointVectorStrategy.java b/lucene/spatial-extras/src/java/org/apache/lucene/spatial/vector/PointVectorStrategy.java index 77055d15ae27..0442b1c7ec02 100644 --- a/lucene/spatial-extras/src/java/org/apache/lucene/spatial/vector/PointVectorStrategy.java +++ b/lucene/spatial-extras/src/java/org/apache/lucene/spatial/vector/PointVectorStrategy.java @@ -24,7 +24,6 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValuesType; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; @@ -253,8 +252,8 @@ private DistanceRangeQuery(Query inner, DoubleValuesSource distanceSource, doubl } @Override - public Query rewrite(IndexReader reader) throws IOException { - Query rewritten = inner.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + Query rewritten = inner.rewrite(indexSearcher); if (rewritten == inner) return this; return new DistanceRangeQuery(rewritten, distanceSource, limit); } diff --git 
a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/CompletionQuery.java b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/CompletionQuery.java index 577c1a886569..8db16c194a71 100644 --- a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/CompletionQuery.java +++ b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/CompletionQuery.java @@ -20,11 +20,11 @@ import static org.apache.lucene.search.suggest.document.CompletionAnalyzer.HOLE_CHARACTER; import java.io.IOException; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.Terms; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.suggest.BitsProducer; @@ -83,11 +83,11 @@ public Term getTerm() { } @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { byte type = 0; boolean first = true; Terms terms; - for (LeafReaderContext context : reader.leaves()) { + for (LeafReaderContext context : indexSearcher.getLeafContexts()) { LeafReader leafReader = context.reader(); try { if ((terms = leafReader.terms(getField())) == null) { @@ -124,7 +124,7 @@ public Query rewrite(IndexReader reader) throws IOException { } } } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/SuggestIndexSearcher.java b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/SuggestIndexSearcher.java index b4b97dbde29e..e46e73bb1aa9 100644 --- a/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/SuggestIndexSearcher.java +++ b/lucene/suggest/src/java/org/apache/lucene/search/suggest/document/SuggestIndexSearcher.java @@ -62,7 +62,7 @@ public TopSuggestDocs suggest(CompletionQuery query, int n, boolean skipDuplicat public void suggest(CompletionQuery query, TopSuggestDocsCollector collector) throws IOException { // TODO use IndexSearcher.rewrite instead // have to implement equals() and hashCode() in CompletionQuerys and co - query = (CompletionQuery) query.rewrite(getIndexReader()); + query = (CompletionQuery) query.rewrite(this); Weight weight = query.createWeight(this, collector.scoreMode(), 1f); for (LeafReaderContext context : getIndexReader().leaves()) { BulkScorer scorer = weight.bulkScorer(context); diff --git a/lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestAnalyzingSuggester.java b/lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestAnalyzingSuggester.java index b5e49c5ec79b..cb14d983716f 100644 --- a/lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestAnalyzingSuggester.java +++ b/lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestAnalyzingSuggester.java @@ -1325,22 +1325,6 @@ static final Iterable<Input> shuffle(Input... values) { return asList; } - // TODO: we need BaseSuggesterTestCase?
- public void testTooLongSuggestion() throws Exception { - Analyzer a = new MockAnalyzer(random()); - Directory tempDir = getDirectory(); - AnalyzingSuggester suggester = new AnalyzingSuggester(tempDir, "suggest", a); - String bigString = TestUtil.randomSimpleString(random(), 30000, 30000); - IllegalArgumentException ex = - expectThrows( - IllegalArgumentException.class, - () -> { - suggester.build(new InputArrayIterator(new Input[] {new Input(bigString, 7)})); - }); - assertTrue(ex.getMessage().contains("input automaton is too large")); - IOUtils.close(a, tempDir); - } - private Directory getDirectory() { return newDirectory(); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java index b2ce16d80776..f9cd607ccacc 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/analysis/BaseTokenStreamTestCase.java @@ -55,6 +55,7 @@ import org.apache.lucene.document.TextField; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexableFieldType; +import org.apache.lucene.search.BoostAttribute; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; @@ -154,7 +155,8 @@ public static void assertTokenStreamContents( boolean[] keywordAtts, boolean graphOffsetsAreCorrect, byte[][] payloads, - int[] flags) + int[] flags, + float[] boost) throws IOException { assertNotNull(output); CheckClearAttributesAttribute checkClearAtt = @@ -221,6 +223,12 @@ public static void assertTokenStreamContents( flagsAtt = ts.getAttribute(FlagsAttribute.class); } + BoostAttribute boostAtt = null; + if (boost != null) { + assertTrue("has no BoostAttribute", ts.hasAttribute(BoostAttribute.class)); + boostAtt = ts.getAttribute(BoostAttribute.class); + } + // Maps position to the start/end offset: final Map<Integer, Integer> posToStartOffset = new HashMap<>(); final Map<Integer, Integer> posToEndOffset = new HashMap<>(); @@ -243,6 +251,7 @@ public static void assertTokenStreamContents( if (payloadAtt != null) payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24})); if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's + if (boostAtt != null) boostAtt.setBoost(-1f); checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before assertTrue("token " + i + " does not exist", ts.incrementToken()); @@ -278,6 +287,9 @@ public static void assertTokenStreamContents( if (flagsAtt != null) { assertEquals("flagsAtt " + i + " term=" + termAtt, flags[i], flagsAtt.getFlags()); } + if (boostAtt != null) { + assertEquals("boostAtt " + i + " term=" + termAtt, boost[i], boostAtt.getBoost(), 0.001); + } if (payloads != null) { if (payloads[i] != null) { assertEquals("payloads " + i, new BytesRef(payloads[i]), payloadAtt.getPayload()); @@ -405,6 +417,7 @@ public static void assertTokenStreamContents( if (payloadAtt != null) payloadAtt.setPayload(new BytesRef(new byte[] {0x00, -0x21, 0x12, -0x43, 0x24})); if (flagsAtt != null) flagsAtt.setFlags(~0); // all 1's + if (boostAtt != null) boostAtt.setBoost(-1); checkClearAtt.getAndResetClearCalled(); // reset it, because we called clearAttribute() before @@ -426,6 +439,38 @@ public static void assertTokenStreamContents( ts.close(); } + public static void assertTokenStreamContents( + TokenStream ts, + String[]
output, + int[] startOffsets, + int[] endOffsets, + String[] types, + int[] posIncrements, + int[] posLengths, + Integer finalOffset, + Integer finalPosInc, + boolean[] keywordAtts, + boolean graphOffsetsAreCorrect, + byte[][] payloads, + int[] flags) + throws IOException { + assertTokenStreamContents( + ts, + output, + startOffsets, + endOffsets, + types, + posIncrements, + posLengths, + finalOffset, + finalPosInc, + keywordAtts, + graphOffsetsAreCorrect, + payloads, + flags, + null); + } + public static void assertTokenStreamContents( TokenStream ts, String[] output, @@ -438,6 +483,33 @@ public static void assertTokenStreamContents( boolean[] keywordAtts, boolean graphOffsetsAreCorrect) throws IOException { + assertTokenStreamContents( + ts, + output, + startOffsets, + endOffsets, + types, + posIncrements, + posLengths, + finalOffset, + keywordAtts, + graphOffsetsAreCorrect, + null); + } + + public static void assertTokenStreamContents( + TokenStream ts, + String[] output, + int[] startOffsets, + int[] endOffsets, + String[] types, + int[] posIncrements, + int[] posLengths, + Integer finalOffset, + boolean[] keywordAtts, + boolean graphOffsetsAreCorrect, + float[] boost) + throws IOException { assertTokenStreamContents( ts, output, @@ -451,7 +523,8 @@ public static void assertTokenStreamContents( keywordAtts, graphOffsetsAreCorrect, null, - null); + null, + boost); } public static void assertTokenStreamContents( @@ -481,9 +554,36 @@ public static void assertTokenStreamContents( keywordAtts, graphOffsetsAreCorrect, payloads, + null, null); } + public static void assertTokenStreamContents( + TokenStream ts, + String[] output, + int[] startOffsets, + int[] endOffsets, + String[] types, + int[] posIncrements, + int[] posLengths, + Integer finalOffset, + boolean graphOffsetsAreCorrect, + float[] boost) + throws IOException { + assertTokenStreamContents( + ts, + output, + startOffsets, + endOffsets, + types, + posIncrements, + posLengths, + finalOffset, + null, + graphOffsetsAreCorrect, + boost); + } + public static void assertTokenStreamContents( TokenStream ts, String[] output, @@ -505,7 +605,8 @@ public static void assertTokenStreamContents( posLengths, finalOffset, null, - graphOffsetsAreCorrect); + graphOffsetsAreCorrect, + null); } public static void assertTokenStreamContents( @@ -522,6 +623,30 @@ public static void assertTokenStreamContents( ts, output, startOffsets, endOffsets, types, posIncrements, posLengths, finalOffset, true); } + public static void assertTokenStreamContents( + TokenStream ts, + String[] output, + int[] startOffsets, + int[] endOffsets, + String[] types, + int[] posIncrements, + int[] posLengths, + Integer finalOffset, + float[] boost) + throws IOException { + assertTokenStreamContents( + ts, + output, + startOffsets, + endOffsets, + types, + posIncrements, + posLengths, + finalOffset, + true, + boost); + } + public static void assertTokenStreamContents( TokenStream ts, String[] output, @@ -649,6 +774,21 @@ public static void assertAnalyzesTo( int[] posIncrements, int[] posLengths) throws IOException { + assertAnalyzesTo( + a, input, output, startOffsets, endOffsets, types, posIncrements, posLengths, null); + } + + public static void assertAnalyzesTo( + Analyzer a, + String input, + String[] output, + int[] startOffsets, + int[] endOffsets, + String[] types, + int[] posIncrements, + int[] posLengths, + float[] boost) + throws IOException { assertTokenStreamContents( a.tokenStream("dummy", input), output, @@ -657,7 +797,8 @@ public static void assertAnalyzesTo( 
types, posIncrements, posLengths, - input.length()); + input.length(), + boost); checkResetException(a, input); checkAnalysisConsistency(random(), a, true, input); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/RandomIndexWriter.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/RandomIndexWriter.java index 11cd73a2aaf5..2988957ba295 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/RandomIndexWriter.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/RandomIndexWriter.java @@ -39,6 +39,7 @@ import org.apache.lucene.internal.tests.IndexWriterAccess; import org.apache.lucene.internal.tests.TestSecrets; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.util.LuceneTestCase; @@ -283,7 +284,12 @@ public long updateDocuments( w.softUpdateDocuments( delTerm, docs, new NumericDocValuesField(config.getSoftDeletesField(), 1)); } else { - seqNo = w.updateDocuments(delTerm, docs); + if (r.nextInt(10) < 3) { + // 30% chance + seqNo = w.updateDocuments(new TermQuery(delTerm), docs); + } else { + seqNo = w.updateDocuments(delTerm, docs); + } } maybeFlushOrCommit(); return seqNo; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingQuery.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingQuery.java index 0947ff908bb4..d4fcb653aba1 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingQuery.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingQuery.java @@ -18,7 +18,6 @@ import java.io.IOException; import java.util.Random; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -74,10 +73,10 @@ public Query getIn() { } @Override - public Query rewrite(IndexReader reader) throws IOException { - final Query rewritten = in.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + final Query rewritten = in.rewrite(indexSearcher); if (rewritten == in) { - return super.rewrite(reader); + return super.rewrite(indexSearcher); } else { return wrap(new Random(random.nextLong()), rewritten); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/BlockScoreQueryWrapper.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/BlockScoreQueryWrapper.java index ebb5c4bd1322..5745aef757be 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/BlockScoreQueryWrapper.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/BlockScoreQueryWrapper.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Arrays; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; @@ -69,12 +68,12 @@ public int hashCode() { } @Override - public Query rewrite(IndexReader reader) throws IOException { - final Query rewritten = query.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + final Query rewritten = query.rewrite(indexSearcher); if (rewritten != query) { return new BlockScoreQueryWrapper(rewritten, blockLength); } - return 
super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/RandomApproximationQuery.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/RandomApproximationQuery.java index 78dc60732fb1..d58baae69361 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/RandomApproximationQuery.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/RandomApproximationQuery.java @@ -19,7 +19,6 @@ import com.carrotsearch.randomizedtesting.generators.RandomNumbers; import java.io.IOException; import java.util.Random; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.FilterWeight; @@ -43,12 +42,12 @@ public RandomApproximationQuery(Query query, Random random) { } @Override - public Query rewrite(IndexReader reader) throws IOException { - final Query rewritten = query.rewrite(reader); + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + final Query rewritten = query.rewrite(indexSearcher); if (rewritten != query) { return new RandomApproximationQuery(rewritten, random); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } @Override diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/BaseBitSetTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/BaseBitSetTestCase.java index 1bb5e500a44c..cbce97d87ac7 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/util/BaseBitSetTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/util/BaseBitSetTestCase.java @@ -170,6 +170,22 @@ public void testClearRange() throws IOException { } } + /** Test the {@link BitSet#clear()} method. */ + public void testClearAll() throws IOException { + Random random = random(); + final int numBits = 1 + random.nextInt(100000); + for (float percentSet : new float[] {0, 0.01f, 0.1f, 0.5f, 0.9f, 0.99f, 1f}) { + BitSet set1 = new JavaUtilBitSet(randomSet(numBits, percentSet), numBits); + T set2 = copyOf(set1, numBits); + final int iters = atLeast(random, 10); + for (int i = 0; i < iters; ++i) { + set1.clear(); + set2.clear(); + assertEquals(set1, set2, numBits); + } + } + } + private DocIdSet randomCopy(BitSet set, int numBits) throws IOException { switch (random().nextInt(5)) { case 0: @@ -241,6 +257,11 @@ private static class JavaUtilBitSet extends BitSet { this.numBits = numBits; } + @Override + public void clear() { + bitSet.clear(); + } + @Override public void clear(int index) { bitSet.clear(index); diff --git a/lucene/test-framework/src/resources/org/apache/lucene/tests/geo/github-12352-1.geojson.gz b/lucene/test-framework/src/resources/org/apache/lucene/tests/geo/github-12352-1.geojson.gz new file mode 100644 index 000000000000..17a9a17393e7 Binary files /dev/null and b/lucene/test-framework/src/resources/org/apache/lucene/tests/geo/github-12352-1.geojson.gz differ diff --git a/lucene/test-framework/src/resources/org/apache/lucene/tests/geo/github-12352-2.geojson.gz b/lucene/test-framework/src/resources/org/apache/lucene/tests/geo/github-12352-2.geojson.gz new file mode 100644 index 000000000000..eec620e2b8c9 Binary files /dev/null and b/lucene/test-framework/src/resources/org/apache/lucene/tests/geo/github-12352-2.geojson.gz differ