🚨🚨🚨 Limit backtracking in Nougat regexp #35264

Merged (4 commits) on Dec 17, 2024

@qubvel (Member) commented Dec 13, 2024

What does this PR do?

Limit the number of repetitions in a regular expression pattern to prevent the method from hanging.
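As a rough illustration of why the old pattern could hang (a sketch, not part of this PR's code): the nested quantifiers in `(?:\.?(?:\d|[ixv])+)*` let the regex engine split a run of digits into groups in exponentially many ways once the overall match fails. The helper name `time_original` and the input sizes below are assumptions made for this illustration.

```python
import re
import time

# Pattern used before this PR (copied from the test script in this description).
ORIGINAL = re.compile(r"^#+ (?:\.?(?:\d|[ixv])+)*\s*(?:$|\n\s*)", flags=re.M)


def time_original(n_zeros):
    """Hypothetical helper: time the old pattern on a heading that cannot match."""
    text = "# " + "0" * n_zeros + ":\n"  # the ":" after the digits forces the match to fail
    start = time.perf_counter()
    result = ORIGINAL.sub("", text)
    elapsed = time.perf_counter() - start
    return result, elapsed


if __name__ == "__main__":
    # Sizes are kept small on purpose; each extra digit roughly doubles the work.
    for n in (10, 14, 18):
        _, elapsed = time_original(n)
        print(f"{n} zeros: {elapsed:.6f} s")
```

Timings depend on the machine and Python version, so none are quoted here; the point is only the growth rate as `n` increases.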

Here is a code sample to test previous and updated regex matches, along with a performance test.

The only test case whose output differs is the one containing a heading like `## 1.` (a number followed by a trailing dot), which I believe should be parsed. So this change can also be considered a fix, not just a breaking change.
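Separately from the full test script, the differing case can be isolated in a few lines. This is a minimal sketch: the two patterns are copied from the script, while the helper name `strip_heading` and the sample headings are hypothetical and chosen only for illustration. The old pattern accepts a dot only before a number, so a trailing dot blocks the match; the new pattern accepts dots anywhere in the numbering.

```python
import re

# Both patterns are copied from the test script in this description.
ORIGINAL = r"^#+ (?:\.?(?:\d|[ixv])+)*\s*(?:$|\n\s*)"
UPDATED = r"^#+ (?:[\d+\.]+|[ixv\.]+)?\s*(?:$|\n\s*)"


def strip_heading(pattern, text):
    """Hypothetical helper: apply one of the two patterns in multiline mode."""
    return re.sub(pattern, "", text, flags=re.M)


if __name__ == "__main__":
    for heading in ("## 1.2", "## 1.", "## Section"):
        old = strip_heading(ORIGINAL, heading)
        new = strip_heading(UPDATED, heading)
        print(f"{heading!r}: original -> {old!r}, updated -> {new!r}")
```

Only `"## 1."` is treated differently: the original pattern leaves it in place, and the updated one removes it.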

import re
import time


# Original Regex (Vulnerable to ReDoS)
def original_post_process(generation):
    return re.sub(r"^#+ (?:\.?(?:\d|[ixv])+)*\s*(?:$|\n\s*)", "", generation, flags=re.M)

# Updated Regex (Avoids ReDoS by simplifying backtracking risks)
def updated_post_process(generation):
    return re.sub(r"^#+ (?:[\d+\.]+|[ixv\.]+)?\s*(?:$|\n\s*)", "", generation, flags=re.M)


# Test cases to validate equivalence and performance
def test_post_process_equivalence():
    test_cases = [
        
        # Simple headings
        "# Section",
        "# 1.2.3",
        "# .1.2.3",
        "# .i.v.x",

        # Standard headings
        "# Heading 1\n## 1. Subheading\n### 1.1 Sub-subheading\n#### IV. Roman numeral heading\nRegular text starts here.",
        
        # Trailing spaces
        "#    \n## 1.    \n### Roman numeral heading with spaces    \nRegular text here.",
        
        # Roman numerals
        "# i\n# iv Roman numeral heading\n# x Section\nText with valid content.",
        
        # Mixed content
        "# Heading 1\n## Subheading 1\nRegular text.\n### Subheading with text\nSome more regular text.",
        
        # Non-heading patterns
        "# This is a valid heading\nSome text that shouldn't be removed.\n# Heading with text afterward\nText with valid content.",
        
        # Completely empty or irregular inputs
        "",
        "   \n   \n   ",
        
        # Inputs with special characters
        "# # Special heading\nRegular text.",
        
        # Escaped markdown syntax
        "\\# Escaped heading\n## Valid heading\nText content.",
        
        # Multiline text blocks
        "# Valid heading\n\nText under heading.\n\n## Another valid heading\n\nMore text here.",
        
        # Random non-heading input
        "This is just random text with no headings.\nAnother line of text.",

        # Long problematic input for ReDoS
        "# " + "0" * 25 + ":\n",  # Long problematic input for ReDoS

        # Large multiline input
        "# Heading 1\n## 1. Subheading\n### 1.1 Sub-subheading\n#### IV. Roman numeral heading\nRegular text starts here.\n" * 100,

    ]

    for i, input_str in enumerate(test_cases):
        original_output = original_post_process(input_str)
        updated_output = updated_post_process(input_str)

        if original_output != updated_output:
            print("\nInput:\n", input_str)
            print("\nOriginal:\n", original_output)
            print("\nUpdated:\n", updated_output)
            print("\n" * 3)
            # raise ValueError(f"Test {i + 1}: Outputs do not match!")
        else:
            # Report a match only when the two outputs are actually identical
            print(f"Test {i + 1}: Outputs match!")


# Performance comparison
def performance_test():
    long_input = "# " + "0" * 25 + ":\n"  # Long problematic input for ReDoS

    # Test original method (perf_counter is a monotonic clock suited to timing)
    start_time = time.perf_counter()
    original_post_process(long_input)
    print(f"Original method execution time: {time.perf_counter() - start_time:.6f} seconds")

    # Test updated method
    start_time = time.perf_counter()
    updated_post_process(long_input)
    print(f"Updated method execution time: {time.perf_counter() - start_time:.6f} seconds")


if __name__ == "__main__":
    # Run tests
    print("Running equivalence tests...")
    test_post_process_equivalence()

    print("\nRunning performance tests...")
    performance_test()

Output:

Running equivalence tests...
Test 1: Outputs match!
Test 2: Outputs match!
Test 3: Outputs match!
Test 4: Outputs match!
Test 5: Outputs match!

Input:
 #    
## 1.    
### Roman numeral heading with spaces    
Regular text here.

Original:
 
## 1.    
### Roman numeral heading with spaces    
Regular text here.

Updated:
 

### Roman numeral heading with spaces    
Regular text here.

Test 7: Outputs match!
Test 8: Outputs match!
Test 9: Outputs match!
Test 10: Outputs match!
Test 11: Outputs match!
Test 12: Outputs match!
Test 13: Outputs match!
Test 14: Outputs match!
Test 15: Outputs match!
Test 16: Outputs match!
Test 17: Outputs match!

Running performance tests...
Updated method execution time: 0.000503 seconds

Even with 100,000 zeros in the input, the updated regex now finishes in under a second.
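A hedged sketch of how that scaling could be checked: the updated pattern uses a single character-class repetition, so a failed match should only need to back off one character at a time. The helper name `time_updated` and the chosen sizes are assumptions for illustration, not part of the PR.

```python
import re
import time

# Updated pattern, copied from the test script in this description.
UPDATED = re.compile(r"^#+ (?:[\d+\.]+|[ixv\.]+)?\s*(?:$|\n\s*)", flags=re.M)


def time_updated(n_zeros):
    """Hypothetical helper: time the new pattern on the same non-matching input shape."""
    text = "# " + "0" * n_zeros + ":\n"
    start = time.perf_counter()
    result = UPDATED.sub("", text)
    return result, time.perf_counter() - start


if __name__ == "__main__":
    for n in (1_000, 10_000, 100_000):
        _, elapsed = time_updated(n)
        print(f"{n} zeros: {elapsed:.6f} s")
```

Absolute timings vary by machine, so only the trend across the three sizes is meaningful.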

@qubvel qubvel changed the title Limit backtracking in Nougat regexp 🚨🚨🚨 Limit backtracking in Nougat regexp Dec 13, 2024
@qubvel qubvel requested a review from NielsRogge December 17, 2024 11:31

@ArthurZucker (Collaborator) left a comment

Thanks 🤗

@qubvel qubvel merged commit deac971 into huggingface:main Dec 17, 2024
9 checks passed