Optimize reduced_vocabulary #1070

Closed

Conversation


@lapp0 (Contributor) commented on Jul 29, 2024

Maybe? fixes #768 (discuss)

Original Issue

The suggestion to execute reduced_vocabulary and model loading in parallel presents challenges:

  • For models.llamacpp, the model must be loaded before its tokenizer can be accessed, so parallelism offers no benefit.
  • Implementing this for models.transformers_vision would require complex changes.
  • Each model loader needs a distinct implementation to parallelize.

Instead, I profiled reduced_vocabulary itself and found a performance issue similar to one I've hit before. Applying the same kind of fix (described below) works well.

Problem

On main, reduced_vocabulary constructs a numba.typed.List in pure-Python (interpreted) mode. numba.typed.List.append() is seriously slow when called from interpreted code: every call boxes and unboxes its arguments across the Python/native boundary.

reduced_vocabulary runs once per model load, and the cost has become more noticeable now that models ship with 100,000 - 200,000 token vocabularies.
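
To make the issue concrete, here is a minimal timing sketch (hypothetical, not code from this PR) contrasting interpreted-mode appends to a numba.typed.List with plain Python list appends:

```python
# Hypothetical micro-benchmark, not from this PR: appending to a
# numba.typed.List from interpreted Python crosses the Python/native
# boundary on every call, while list.append stays entirely in Python.
import time

from numba.typed import List as NumbaList

n = 200_000  # roughly a modern vocabulary size

start = time.perf_counter()
typed_list = NumbaList()
for i in range(n):
    typed_list.append(i)  # boxed and unboxed on every call
print(f"numba.typed.List.append: {time.perf_counter() - start:.2f}s")

start = time.perf_counter()
py_list = []
for i in range(n):
    py_list.append(i)
print(f"list.append:             {time.perf_counter() - start:.2f}s")
```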

Solution

  • In pure Python, construct a dict mapping normalized token strings to their token_ids.
  • In an @njit helper, convert that dict to a list of tuples, where each tuple is (normalized_token: unicode_type, token_ids: int64[:]). A hedged sketch of this structure follows.
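
A sketch of that shape, with illustrative names (group_tokens and to_vocabulary_list are not the PR's identifiers, and the real code normalizes tokens via the tokenizer's convert_token_to_string):

```python
# Illustrative sketch of the approach; names and details are assumptions,
# not the PR's actual implementation.
import numpy as np
from numba import njit, types
from numba.typed import Dict as NumbaDict, List as NumbaList


def group_tokens(vocabulary):
    """vocabulary: dict of token string -> token id, e.g. tokenizer.get_vocab()."""
    # Pure-Python grouping: normalized token string -> list of token ids.
    token_to_ids = {}
    for token, token_id in vocabulary.items():
        token_str = token  # real code would normalize via convert_token_to_string
        token_to_ids.setdefault(token_str, []).append(token_id)
    # Copy into a numba.typed.Dict once, with the ids already batched into
    # arrays, so there is one boundary crossing per unique token.
    numba_dict = NumbaDict.empty(types.unicode_type, types.int64[:])
    for token_str, ids in token_to_ids.items():
        numba_dict[token_str] = np.asarray(ids, dtype=np.int64)
    return numba_dict


@njit(cache=True)
def to_vocabulary_list(numba_dict):
    # Runs in native code, so these append calls are cheap.
    result = NumbaList()
    for token_str, token_ids in numba_dict.items():
        result.append((token_str, token_ids))
    return result
```

The exact division of work in the PR may differ; the key idea is that the per-token numba.typed.List.append() calls happen in native code rather than in the interpreter.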

Benchmarks

New benchmarks

| Change   | Before [d78041e8]    | After [c87534d7]    |   Ratio | Benchmark (Parameter)                                                                                     |
|----------|----------------------|---------------------|---------|-----------------------------------------------------------------------------------------------------------|
| -        | 376±3ms              | 17.5±0.5ms          | 0.05    | bench_regex_fsm.RegexReducedVocabularyBenchmark.time_reduced_vocabulary(10000)                            |
| -        | 3.51±0.03s           | 188±2ms             | 0.05    | bench_regex_fsm.RegexReducedVocabularyBenchmark.time_reduced_vocabulary(100000)                           |
|          | failed               | 1.94±0.04s          | n/a     | bench_regex_fsm.RegexReducedVocabularyBenchmark.time_reduced_vocabulary(1000000)                          |

Old benchmarks

Benchmarks that have improved:
| Change   | Before [26e29344]    | After [58df8ee1]    | Ratio   | Benchmark (Parameter)                                                                                     |
|----------|----------------------|---------------------|---------|-----------------------------------------------------------------------------------------------------------|
| -        | 5.25±0.05s           | 3.22±0.1s           | 0.61    | bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm('complex_schema')                           |
| -        | 3.71±0.06s           | 1.62±0.03s          | 0.44    | bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm('simple_schema')                            |
| -        | 5.82±0.05s           | 3.86±0.04s          | 0.66    | bench_numba_compile.NumbaCompileBenchmark.time_compile_numba                                              |
| -        | 2.74±0.03s           | 675±7ms             | 0.25    | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('complex_phone')                                |
| -        | 6.37±0.01s           | 4.35±0.06s          | 0.68    | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('complex_span_constrained_relation_extraction') |
| -        | 2.57±0.03s           | 514±3ms             | 0.20    | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('date')                                         |
| -        | 2.54±0.02s           | 479±4ms             | 0.19    | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('email')                                        |
| -        | 2.50±0.01s           | 434±4ms             | 0.17    | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('ip')                                           |
| -        | 2.46±0.02s           | 400±2ms             | 0.16    | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('simple_phone')                                 |
| -        | 2.38±0s              | 350±5ms             | 0.15    | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('ssn')                                          |
| -        | 2.43±0.02s           | 350±3ms             | 0.14    | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('time')                                         |
| -        | 2.63±0.01s           | 582±3ms             | 0.22    | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('url')                                          |
Benchmarks that have stayed the same:
| Change   | Before [26e29344]    | After [58df8ee1]    |   Ratio | Benchmark (Parameter)                                                                                              |
|----------|----------------------|---------------------|---------|--------------------------------------------------------------------------------------------------------------------|
|          | 89.8±2μs             | 90.3±0.6μs          |    1.01 | bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_regex('complex_schema')                                  |
|          | 50.1±0.3μs           | 50.3±0.2μs          |    1    | bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_regex('simple_schema')                                   |
|          | 187±4μs              | 188±3μs             |    1.01 | bench_processors.LogitsProcessorPassthroughBenchmark.time_passthrough('numpy')                                     |
|          | 192±5μs              | 175±3μs             |    0.91 | bench_processors.LogitsProcessorPassthroughBenchmark.time_passthrough('torch')                                     |
|          | 247±3μs              | 247±3μs             |    1    | bench_processors.LogitsProcessorStructuredBenchmark.time_structured_generation('numpy', 'Z*')                      |
|          | 1.09±0.01ms          | 1.12±0.02ms         |    1.03 | bench_processors.LogitsProcessorStructuredBenchmark.time_structured_generation('numpy', '[^Z]*')                   |
|          | 235±4μs              | 234±4μs             |    1    | bench_processors.LogitsProcessorStructuredBenchmark.time_structured_generation('torch', 'Z*')                      |
|          | 1.11±0.01ms          | 1.08±0.02ms         |    0.97 | bench_processors.LogitsProcessorStructuredBenchmark.time_structured_generation('torch', '[^Z]*')                   |
|          | 594M                 | 590M                |    0.99 | bench_regex_guide.MemoryRegexGuideBenchmark.peakmem_regex_to_guide('complex_span_constrained_relation_extraction') |
|          | 496M                 | 489M                |    0.99 | bench_regex_guide.MemoryRegexGuideBenchmark.peakmem_regex_to_guide('simple_phone')                                 |

Open Questions

Why is the numba compilation benchmark faster? This raises an eyebrow.

@lapp0 added the optimization, run-benchmarks, and regex labels on Jul 29, 2024
@lapp0 force-pushed the parallel-model-tokenizer-index-load branch from 1bf8c30 to 4058ce0 on July 29, 2024 at 01:29

Review thread on the benchmark diff (context shown as far as it was captured):

```python
from .common import ensure_numba_compiled

tokenizer_uris = {
```

Member:
Can we create synthetic datasets instead of pulling external data?

@lapp0 (Contributor, Author) replied on Jul 29, 2024:

Done!

The benchmarks show a 95% reduction in runtime, but this is a bit misleading: with the real tokenizers the reduction was ~90%. The difference is that convert_token_to_string is much simpler in this new benchmark, and the branch `if "\ufffd" in token_str and not re_replacement_seq.match(token):` is never taken.
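
For context, a synthetic vocabulary along these lines could be generated as follows (a hypothetical sketch, not the PR's benchmark code):

```python
# Hypothetical synthetic-vocabulary generator for benchmarking (an
# assumption, not the PR's actual code). ASCII-only tokens avoid
# downloading a real tokenizer and never trigger the "\ufffd"
# replacement-sequence branch.
import random
import string


def synthetic_vocabulary(size, seed=0):
    rng = random.Random(seed)
    vocab = {}
    for token_id in range(size):
        # Random stem plus the id suffix guarantees unique token strings.
        stem = "".join(rng.choices(string.ascii_letters, k=4))
        vocab[f"{stem}{token_id}"] = token_id
    return vocab


# e.g. the 100,000-token benchmark case:
vocab = synthetic_vocabulary(100_000)
```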

@lapp0 force-pushed the parallel-model-tokenizer-index-load branch from 4058ce0 to 144b97a on July 29, 2024 at 20:35
@lapp0 force-pushed the parallel-model-tokenizer-index-load branch from 144b97a to f2abefe on July 29, 2024 at 20:38
@lapp0 closed this on Jul 31, 2024
Development

Merging this pull request may close: Speed up index construction by converting vocabulary types while loading the model (#768)