diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 7de2179817d0..7876d4bfbf56 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -126,6 +126,8 @@ Optimizations * GITHUB#12270 Don't generate stacktrace in CollectionTerminatedException. (Armin Braun) +* GITHUB#12286 Toposort uses an explicit stack instead of recursion to avoid stack overflow. (Tang Donghai) + Bug Fixes --------------------- (No changes) diff --git a/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java b/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java index 268744203154..1384e1e992b4 100644 --- a/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java +++ b/lucene/core/src/java/org/apache/lucene/util/automaton/Operations.java @@ -1273,9 +1273,14 @@ static Automaton totalize(Automaton a) { } /** - * Returns the topological sort of all states reachable from the initial state. Behavior is - * undefined if this automaton has cycles. CPU cost is O(numTransitions), and the implementation - * is recursive so an automaton matching long strings may exhaust the java stack. + * Returns the topological sort of all states reachable from the initial state. This method + * assumes that the automaton does not contain cycles, and will throw an IllegalArgumentException + * if a cycle is detected. The CPU cost is O(numTransitions), and the implementation is + * non-recursive, so it will not exhaust the java stack for automata matching long strings. If + * there are dead states in the automaton, they will be removed from the returned array. 
+ * + * @param a the Automaton to be sorted + * @return the topologically sorted array of state ids */ public static int[] topoSortStates(Automaton a) { if (a.getNumStates() == 0) { @@ -1283,8 +1288,7 @@ public static int[] topoSortStates(Automaton a) { } int numStates = a.getNumStates(); int[] states = new int[numStates]; - final BitSet visited = new BitSet(numStates); - int upto = topoSortStatesRecurse(a, visited, states, 0, 0, 0); + int upto = topoSortStates(a, states); if (upto < states.length) { // There were dead states @@ -1303,24 +1307,49 @@ public static int[] topoSortStates(Automaton a) { return states; } - // TODO: not great that this is recursive... in theory a - // large automata could exceed java's stack so the maximum level of recursion is bounded to 1000 - private static int topoSortStatesRecurse( - Automaton a, BitSet visited, int[] states, int upto, int state, int level) { - if (level > MAX_RECURSION_LEVEL) { - throw new IllegalArgumentException("input automaton is too large: " + level); - } + /** + * Performs a topological sort on the states of the given Automaton. + * + * @param a The automaton whose states are to be topologically sorted. + * @param states An int array which stores the states. + * @return the number of states in the final sorted list. + * @throws IllegalArgumentException if the input automaton has a cycle. + */ + private static int topoSortStates(Automaton a, int[] states) { + BitSet onStack = new BitSet(a.getNumStates()); + BitSet visited = new BitSet(a.getNumStates()); + var stack = new ArrayDeque(); + stack.push(0); // Assuming that the initial state is 0. 
+ int upto = 0; Transition t = new Transition(); - int count = a.initTransition(state, t); - for (int i = 0; i < count; i++) { - a.getNextTransition(t); - if (!visited.get(t.dest)) { - visited.set(t.dest); - upto = topoSortStatesRecurse(a, visited, states, upto, t.dest, level + 1); + + while (!stack.isEmpty()) { + int state = stack.peek(); // Just peek, don't remove the state yet + + int count = a.initTransition(state, t); + boolean pushed = false; + for (int i = 0; i < count; i++) { + a.getNextTransition(t); + if (!visited.get(t.dest)) { + visited.set(t.dest); + stack.push(t.dest); // Push the next unvisited state onto the stack + onStack.set(state); + pushed = true; + break; // Exit the loop, we'll continue from here in the next iteration + } else if (onStack.get(t.dest)) { + // If the state is on the current recursion stack, we have detected a cycle + throw new IllegalArgumentException("Input automaton has a cycle."); + } + } + + // If we haven't pushed any new state onto the stack, we're done with this state + if (!pushed) { + onStack.clear(state); // remove the node from the current recursion stack + stack.pop(); + states[upto] = state; + upto++; } } - states[upto] = state; - upto++; return upto; } } diff --git a/lucene/core/src/test/org/apache/lucene/util/automaton/TestOperations.java b/lucene/core/src/test/org/apache/lucene/util/automaton/TestOperations.java index 8875bf8c7718..ec38eafe0ced 100644 --- a/lucene/core/src/test/org/apache/lucene/util/automaton/TestOperations.java +++ b/lucene/core/src/test/org/apache/lucene/util/automaton/TestOperations.java @@ -17,6 +17,7 @@ package org.apache.lucene.util.automaton; import static org.apache.lucene.util.automaton.Operations.DEFAULT_DETERMINIZE_WORK_LIMIT; +import static org.apache.lucene.util.automaton.Operations.topoSortStates; import com.carrotsearch.randomizedtesting.generators.RandomNumbers; import java.util.ArrayList; @@ -69,6 +70,42 @@ public void testEmptyLanguageConcatenate() { 
assertTrue(Operations.isEmpty(concat)); } + /** + * Test case for the topoSortStates method when the input Automaton contains a cycle. This test + * case constructs a layered Automaton and then adds a transition from the last level back to + * the initial state, which creates a cycle. The topoSortStates method should detect the cycle + * and throw an IllegalArgumentException. + */ + public void testCycledAutomaton() { + Automaton a = generateRandomAutomaton(true); + IllegalArgumentException exc = + expectThrows(IllegalArgumentException.class, () -> topoSortStates(a)); + assertTrue(exc.getMessage().contains("Input automaton has a cycle")); + } + + public void testTopoSortStates() { + Automaton a = generateRandomAutomaton(false); + + int[] sorted = topoSortStates(a); + int[] stateMap = new int[a.getNumStates()]; + Arrays.fill(stateMap, -1); + int order = 0; + for (int state : sorted) { + assertEquals(-1, stateMap[state]); + stateMap[state] = (order++); + } + + Transition transition = new Transition(); + for (int state : sorted) { + int count = a.initTransition(state, transition); + for (int i = 0; i < count; i++) { + a.getNextTransition(transition); + // ensure dest's order is higher than current state + assertTrue(stateMap[transition.dest] > stateMap[state]); + } + } + } + /** Test optimization to concatenate() with empty String to an NFA */ public void testEmptySingletonNFAConcatenate() { Automaton singleton = Automata.makeString(""); @@ -136,19 +173,6 @@ public void testIsFiniteEatsStack() { assertTrue(exc.getMessage().contains("input automaton is too large")); } - public void testTopoSortEatsStack() { - char[] chars = new char[50000]; - TestUtil.randomFixedLengthUnicodeString(random(), chars, 0, chars.length); - String bigString1 = new String(chars); - TestUtil.randomFixedLengthUnicodeString(random(), chars, 0, chars.length); - String bigString2 = new String(chars); - Automaton a = - Operations.union(Automata.makeString(bigString1), Automata.makeString(bigString2)); - 
IllegalArgumentException exc = - expectThrows(IllegalArgumentException.class, () -> Operations.topoSortStates(a)); - assertTrue(exc.getMessage().contains("input automaton is too large")); - } - /** * Returns the set of all accepted strings. * @@ -182,4 +206,52 @@ private static Set getFiniteStrings(FiniteStringsIterator iterator) { return result; } + + /** + * This method creates a random Automaton by generating states at multiple levels. At each level, + * a random number of states are created, and transitions are added between the states of the + * current and the previous level randomly. If the 'hasCycle' parameter is true, a transition is + * added from the first state of the last level back to the initial state to create a cycle in the + * Automaton. + * + * @param hasCycle if true, the generated Automaton will have a cycle; if false, it won't have a + * cycle. + * @return a randomly generated Automaton instance. + */ + private Automaton generateRandomAutomaton(boolean hasCycle) { + Automaton a = new Automaton(); + List lastLevelStates = new ArrayList<>(); + int initialState = a.createState(); + int maxLevel = random().nextInt(4, 10); + lastLevelStates.add(initialState); + + for (int level = 1; level < maxLevel; level++) { + int numStates = random().nextInt(3, 10); + List nextLevelStates = new ArrayList<>(); + + for (int i = 0; i < numStates; i++) { + int nextState = a.createState(); + nextLevelStates.add(nextState); + } + + for (int lastState : lastLevelStates) { + for (int nextState : nextLevelStates) { + // if hasCycle is enabled, we will always add a transition, so we could make sure the + // generated Automaton has a cycle. 
+ if (hasCycle || random().nextInt(7) >= 1) { + a.addTransition(lastState, nextState, random().nextInt(10)); + } + } + } + lastLevelStates = nextLevelStates; + } + + if (hasCycle) { + int lastState = lastLevelStates.get(0); + a.addTransition(lastState, initialState, random().nextInt(10)); + } + + a.finishState(); + return a; + } } diff --git a/lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestAnalyzingSuggester.java b/lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestAnalyzingSuggester.java index b5e49c5ec79b..cb14d983716f 100644 --- a/lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestAnalyzingSuggester.java +++ b/lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestAnalyzingSuggester.java @@ -1325,22 +1325,6 @@ static final Iterable shuffle(Input... values) { return asList; } - // TODO: we need BaseSuggesterTestCase? - public void testTooLongSuggestion() throws Exception { - Analyzer a = new MockAnalyzer(random()); - Directory tempDir = getDirectory(); - AnalyzingSuggester suggester = new AnalyzingSuggester(tempDir, "suggest", a); - String bigString = TestUtil.randomSimpleString(random(), 30000, 30000); - IllegalArgumentException ex = - expectThrows( - IllegalArgumentException.class, - () -> { - suggester.build(new InputArrayIterator(new Input[] {new Input(bigString, 7)})); - }); - assertTrue(ex.getMessage().contains("input automaton is too large")); - IOUtils.close(a, tempDir); - } - private Directory getDirectory() { return newDirectory(); }