Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LUCENE-9825: Hunspell: reverse the "words" trie for faster word lookup/suggestions (lucene repo) #2

Merged
merged 8 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
*/
package org.apache.lucene.analysis.hunspell;

import static org.apache.lucene.analysis.hunspell.AffixKind.*;
import static org.apache.lucene.analysis.hunspell.AffixKind.PREFIX;
import static org.apache.lucene.analysis.hunspell.AffixKind.SUFFIX;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
Expand Down Expand Up @@ -53,8 +54,6 @@
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.IntsRefBuilder;
Expand Down Expand Up @@ -91,14 +90,8 @@ public class Dictionary {
*/
ArrayList<AffixCondition> patterns = new ArrayList<>();

/**
* The entries in the .dic file, mapping to their set of flags. the fst output is the ordinal list
* for flagLookup.
*/
FST<IntsRef> words;

/** A Bloom filter over {@link #words} to avoid unnecessary expensive FST traversals */
FixedBitSet wordHashes;
/** The entries in the .dic file, mapping to their set of flags */
WordStorage words;

/**
* The list of unique flagsets (wordforms). theoretically huge, but practically small (for Polish
Expand Down Expand Up @@ -257,9 +250,8 @@ public void close() {
// read dictionary entries
IndexOutput unsorted = tempDir.createTempOutput(tempFileNamePrefix, "dat", IOContext.DEFAULT);
int wordCount = mergeDictionaries(dictionaries, decoder, unsorted);
wordHashes = new FixedBitSet(Integer.highestOneBit(wordCount * 10));
String sortedFile = sortWordsOffline(tempDir, tempFileNamePrefix, unsorted);
words = readSortedDictionaries(tempDir, sortedFile, flagEnumerator);
words = readSortedDictionaries(tempDir, sortedFile, flagEnumerator, wordCount);
flagLookup = flagEnumerator.finish();
aliases = null; // no longer needed
morphAliases = null; // no longer needed
Expand All @@ -272,36 +264,27 @@ int formStep() {

/** Looks up Hunspell word forms from the dictionary */
IntsRef lookupWord(char[] word, int offset, int length) {
int hash = CharsRef.stringHashCode(word, offset, length);
if (!wordHashes.get(Math.abs(hash) % wordHashes.length())) {
return null;
}

return lookup(words, word, offset, length);
return words.lookupWord(word, offset, length);
}

// only for testing
IntsRef lookupPrefix(char[] word) {
return lookup(prefixes, word, 0, word.length);
return lookup(prefixes, word);
}

// only for testing
IntsRef lookupSuffix(char[] word) {
return lookup(suffixes, word, 0, word.length);
return lookup(suffixes, word);
}

IntsRef lookup(FST<IntsRef> fst, char[] word, int offset, int length) {
if (fst == null) {
return null;
}
private IntsRef lookup(FST<IntsRef> fst, char[] word) {
final FST.BytesReader bytesReader = fst.getBytesReader();
final FST.Arc<IntsRef> arc = fst.getFirstArc(new FST.Arc<>());
// Accumulate output as we go
IntsRef output = fst.outputs.getNoOutput();

int l = offset + length;
for (int i = offset, cp; i < l; i += Character.charCount(cp)) {
cp = Character.codePointAt(word, i, l);
for (int i = 0, cp; i < word.length; i += Character.charCount(cp)) {
cp = Character.codePointAt(word, i, word.length);
output = nextArc(fst, arc, bytesReader, output, cp);
if (output == null) {
return null;
Expand Down Expand Up @@ -1134,13 +1117,13 @@ public int compare(BytesRef o1, BytesRef o2) {
return sorted;
}

private FST<IntsRef> readSortedDictionaries(
Directory tempDir, String sorted, FlagEnumerator flags) throws IOException {
private WordStorage readSortedDictionaries(
Directory tempDir, String sorted, FlagEnumerator flags, int wordCount) throws IOException {
boolean success = false;

Map<String, Integer> morphIndices = new HashMap<>();

EntryGrouper grouper = new EntryGrouper(flags);
WordStorage.Builder builder = new WordStorage.Builder(wordCount, hasCustomMorphData, flags);

try (ByteSequencesReader reader =
new ByteSequencesReader(tempDir.openChecksumInput(sorted, IOContext.READONCE), sorted)) {
Expand Down Expand Up @@ -1180,6 +1163,8 @@ private FST<IntsRef> readSortedDictionaries(
entry = line.substring(0, flagSep);
}

if (entry.isEmpty()) continue;

int morphDataID = 0;
if (end + 1 < line.length()) {
List<String> morphFields = readMorphFields(entry, line.substring(end + 1));
Expand All @@ -1189,14 +1174,12 @@ private FST<IntsRef> readSortedDictionaries(
}
}

wordHashes.set(Math.abs(entry.hashCode()) % wordHashes.length());
grouper.add(entry, wordForm, morphDataID);
builder.add(entry, wordForm, morphDataID);
}

// finalize last entry
grouper.flushGroup();
success = true;
return grouper.words.compile();
return builder.build();
} finally {
if (success) {
tempDir.deleteFile(sorted);
Expand Down Expand Up @@ -1275,76 +1258,6 @@ boolean isDotICaseChangeDisallowed(char[] word) {
return word[0] == 'İ' && !alternateCasing;
}

/**
 * Accumulates sorted .dic entries and builds the {@code words} FST mapping each unique entry to
 * the concatenated list of its flag-set ordinals (plus morph-data ids when present). Entries must
 * be fed in sorted order; consecutive duplicates are grouped and flushed as a single FST arc.
 */
private class EntryGrouper {
// FST under construction: UTF-32 code points of the entry -> ordinal list output
final FSTCompiler<IntsRef> words =
new FSTCompiler<>(FST.INPUT_TYPE.BYTE4, IntSequenceOutputs.getSingleton());
// flag sets seen so far for the current (not yet flushed) entry
private final List<char[]> group = new ArrayList<>();
// parallel to `group`; only populated when the outer dictionary has custom morph data
private final List<Integer> morphDataIDs = new ArrayList<>();
// reusable buffer for the UTF-32 conversion in flushGroup()
private final IntsRefBuilder scratchInts = new IntsRefBuilder();
// the entry whose flag sets are currently being accumulated; null before the first add()
private String currentEntry = null;
private final FlagEnumerator flagEnumerator;

EntryGrouper(FlagEnumerator flagEnumerator) {
this.flagEnumerator = flagEnumerator;
}

/**
 * Adds one dictionary line's data. When {@code entry} differs from the previous one, the
 * previous group is flushed into the FST first.
 *
 * @param entry the dictionary word; must be >= the previously added entry (sorted input)
 * @param flags the flag set parsed for this line
 * @param morphDataID id of this line's morphological data; ignored unless hasCustomMorphData
 * @throws IllegalArgumentException if entries arrive out of sorted order
 */
void add(String entry, char[] flags, int morphDataID) throws IOException {
if (!entry.equals(currentEntry)) {
if (currentEntry != null) {
if (entry.compareTo(currentEntry) < 0) {
throw new IllegalArgumentException("out of order: " + entry + " < " + currentEntry);
}
flushGroup();
}
currentEntry = entry;
}

group.add(flags);
if (hasCustomMorphData) {
morphDataIDs.add(morphDataID);
}
}

/**
 * Writes the accumulated flag sets of {@code currentEntry} into the FST and resets the group
 * buffers. Must also be invoked once after the last add() to flush the final entry.
 */
void flushGroup() throws IOException {
IntsRefBuilder currentOrds = new IntsRefBuilder();

// If any flag set is visible (non-hidden), the hidden ones are dropped below
boolean hasNonHidden = false;
for (char[] flags : group) {
if (!hasHiddenFlag(flags)) {
hasNonHidden = true;
break;
}
}

for (int i = 0; i < group.size(); i++) {
char[] flags = group.get(i);
// skip hidden flag sets, but only when a non-hidden alternative exists
if (hasNonHidden && hasHiddenFlag(flags)) {
continue;
}

// output layout per kept flag set: flag-set ordinal [, morph-data id]
currentOrds.append(flagEnumerator.add(flags));
if (hasCustomMorphData) {
currentOrds.append(morphDataIDs.get(i));
}
}

// FST input is the entry's code points (BYTE4 input type above)
Util.toUTF32(currentEntry, scratchInts);
words.add(scratchInts.get(), currentOrds.get());

group.clear();
morphDataIDs.clear();
}
}

/** Returns whether the given flag set contains {@code HIDDEN_FLAG}. */
private static boolean hasHiddenFlag(char[] flags) {
for (int i = 0; i < flags.length; i++) {
if (flags[i] == HIDDEN_FLAG) {
return true;
}
}
return false;
}

private void parseAlias(String line) {
String[] ruleArgs = line.split("\\s+");
if (aliases == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static org.apache.lucene.analysis.hunspell.Dictionary.AFFIX_FLAG;
import static org.apache.lucene.analysis.hunspell.Dictionary.AFFIX_STRIP_ORD;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashSet;
Expand All @@ -30,11 +29,8 @@
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.IntsRefFSTEnum;
import org.apache.lucene.util.fst.IntsRefFSTEnum.InputOutput;

/**
* A class that traverses the entire dictionary and applies affix rules to check if those yield
Expand Down Expand Up @@ -68,66 +64,42 @@ private List<Weighted<Root<String>>> findSimilarDictionaryEntries(
boolean ignoreTitleCaseRoots = originalCase == WordCase.LOWER && !dictionary.hasLanguage("de");
TrigramAutomaton automaton = new TrigramAutomaton(word);

IntsRefFSTEnum<IntsRef> fstEnum = new IntsRefFSTEnum<>(dictionary.words);
InputOutput<IntsRef> mapping;
while ((mapping = nextKey(fstEnum, word.length() + 4)) != null) {
speller.checkCanceled.run();
dictionary.words.processAllWords(
word.length() + 4,
(rootChars, forms) -> {
speller.checkCanceled.run();

IntsRef key = mapping.input;
assert key.length > 0;
if (Math.abs(key.length - word.length()) > MAX_ROOT_LENGTH_DIFF) {
assert key.length < word.length(); // nextKey takes care of longer keys
continue;
}

String root = toString(key);
filterSuitableEntries(root, mapping.output, entries);
if (entries.isEmpty()) continue;
assert rootChars.length > 0;
if (Math.abs(rootChars.length - word.length()) > MAX_ROOT_LENGTH_DIFF) {
assert rootChars.length < word.length(); // processAllWords takes care of longer keys
return;
}

if (ignoreTitleCaseRoots && WordCase.caseOf(root) == WordCase.TITLE) {
continue;
}
String root = rootChars.toString();
filterSuitableEntries(root, forms, entries);
if (entries.isEmpty()) return;

String lower = dictionary.toLowerCase(root);
int sc =
automaton.ngramScore(lower) - longerWorsePenalty(word, lower) + commonPrefix(word, root);
if (ignoreTitleCaseRoots && WordCase.caseOf(rootChars) == WordCase.TITLE) {
return;
}

if (roots.size() == MAX_ROOTS && sc < roots.peek().score) {
continue;
}
String lower = dictionary.toLowerCase(root);
int sc =
automaton.ngramScore(lower)
- longerWorsePenalty(word, lower)
+ commonPrefix(word, root);

entries.forEach(e -> roots.add(new Weighted<>(e, sc)));
while (roots.size() > MAX_ROOTS) {
roots.poll();
}
}
return roots.stream().sorted().collect(Collectors.toList());
}
if (roots.size() == MAX_ROOTS && sc < roots.peek().score) {
return;
}

/**
 * Advances the enumerator to the next FST entry whose input is at most {@code maxLen} ints long.
 * Instead of stepping through every too-long key one by one, it seeks directly past the whole
 * subtree sharing the current {@code maxLen}-length prefix.
 *
 * @param fstEnum enumerator over the dictionary words FST
 * @param maxLen maximum accepted input length (in code points / ints)
 * @return the next acceptable entry, or null when the enumeration is exhausted
 */
private static InputOutput<IntsRef> nextKey(IntsRefFSTEnum<IntsRef> fstEnum, int maxLen) {
try {
InputOutput<IntsRef> next = fstEnum.next();
while (next != null && next.input.length > maxLen) {
int offset = next.input.offset;
// take the first maxLen ints of the too-long key...
int[] ints = ArrayUtil.copyOfSubArray(next.input.ints, offset, offset + maxLen);
if (ints[ints.length - 1] == Integer.MAX_VALUE) {
// incrementing would overflow; no valid code point reaches this value
throw new AssertionError("Too large char");
}
// ...and bump the last int so seekCeil lands strictly after every key with that prefix
ints[ints.length - 1]++;
next = fstEnum.seekCeil(new IntsRef(ints, 0, ints.length));
}
return next;
} catch (IOException e) {
// IOException is not expected from an in-memory FST; surface it unchecked
throw new RuntimeException(e);
}
}
entries.forEach(e -> roots.add(new Weighted<>(e, sc)));
while (roots.size() > MAX_ROOTS) {
roots.poll();
}
});

private static String toString(IntsRef key) {
char[] chars = new char[key.length];
for (int i = 0; i < key.length; i++) {
chars[i] = (char) key.ints[i + key.offset];
}
return new String(chars);
return roots.stream().sorted().collect(Collectors.toList());
}

private void filterSuitableEntries(String word, IntsRef forms, List<Root<String>> result) {
Expand Down Expand Up @@ -363,7 +335,7 @@ private List<String> getMostRelevantSuggestions(
return result;
}

private static int commonPrefix(String s1, String s2) {
static int commonPrefix(String s1, String s2) {
int i = 0;
int limit = Math.min(s1.length(), s2.length());
while (i < limit && s1.charAt(i) == s2.charAt(i)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ private void tryLongSwap(String word) {
}

private void tryRemovingChar(String word) {
if (word.length() == 1) return;

for (int i = 0; i < word.length(); i++) {
trySuggestion(word.substring(0, i) + word.substring(i + 1));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ public List<CharsRef> stem(char[] word, int length) {
}

List<CharsRef> list = new ArrayList<>();
if (length == 0) {
return list;
}

RootProcessor processor =
(stem, formID, stemException) -> {
list.add(newStem(stem, stemException));
Expand Down Expand Up @@ -484,6 +488,8 @@ private char[] stripAffix(
int stripEnd = dictionary.stripOffsets[stripOrd + 1];
int stripLen = stripEnd - stripStart;

if (stripLen + deAffixedLen == 0) return null;

char[] stripData = dictionary.stripData;
int condition = dictionary.getAffixCondition(affix);
if (condition != 0) {
Expand Down
Loading