diff --git a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
index 6c08fae5b5..8d23442e89 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
@@ -762,6 +762,7 @@ private static BertTokenizer Create(
 
             options.Normalizer ??= options.ApplyBasicTokenization ? new BertNormalizer(options.LowerCaseBeforeTokenization, options.IndividuallyTokenizeCjk, options.RemoveNonSpacingMarks) : null;
 
+            IReadOnlyDictionary<string, int>? specialTokensDict = options.SpecialTokens;
             if (options.SplitOnSpecialTokens)
             {
                 bool lowerCase = options.ApplyBasicTokenization && options.LowerCaseBeforeTokenization;
@@ -769,8 +770,8 @@ private static BertTokenizer Create(
             {
                 if (lowerCase)
                 {
-                    Dictionary<string, int> dic = options.SpecialTokens.ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
-                    options.SpecialTokens = dic;
+                    Dictionary<string, int> tempSpecialTokens = [];
+                    specialTokensDict = tempSpecialTokens;
 
                     foreach (var kvp in options.SpecialTokens)
                     {
@@ -779,37 +780,49 @@ private static BertTokenizer Create(
                             throw new ArgumentException($"The special token '{kvp.Key}' is not in the vocabulary or assigned id value {id} different than the value {kvp.Value} in the special tokens.");
                         }
 
-                        // Ensure that the special tokens are lowercased.
-                        dic[kvp.Key.ToLowerInvariant()] = kvp.Value;
+                        // Add the special token into our dictionary, normalizing it, and adding it into the
+                        // main vocab, if needed.
+                        AddSpecialToken(vocab, tempSpecialTokens, kvp.Key, true);
                     }
                 }
                 else
                 {
-                    // Create a dictionary with the special tokens.
-                    Dictionary<string, int> specialTokens = new Dictionary<string, int>();
-                    options.SpecialTokens = specialTokens;
-
-                    AddSpecialToken(vocab, specialTokens, options.UnknownToken, lowerCase);
-                    AddSpecialToken(vocab, specialTokens, options.SeparatorToken, lowerCase);
-                    AddSpecialToken(vocab, specialTokens, options.PaddingToken, lowerCase);
-                    AddSpecialToken(vocab, specialTokens, options.ClassificationToken, lowerCase);
-                    AddSpecialToken(vocab, specialTokens, options.MaskingToken, lowerCase);
+                    // Create a dictionary with the special tokens - store the un-normalized forms in the options as
+                    // that field is exposed to the public. In addition, store the normalized form for creating the
+                    // pre-tokenizer.
+                    Dictionary<string, int> tempSpecialTokens = [];
+                    Dictionary<string, int> notNormalizedSpecialTokens = [];
+                    AddSpecialToken(vocab, tempSpecialTokens, options.UnknownToken, lowerCase, notNormalizedSpecialTokens);
+                    AddSpecialToken(vocab, tempSpecialTokens, options.SeparatorToken, lowerCase, notNormalizedSpecialTokens);
+                    AddSpecialToken(vocab, tempSpecialTokens, options.PaddingToken, lowerCase, notNormalizedSpecialTokens);
+                    AddSpecialToken(vocab, tempSpecialTokens, options.ClassificationToken, lowerCase, notNormalizedSpecialTokens);
+                    AddSpecialToken(vocab, tempSpecialTokens, options.MaskingToken, lowerCase, notNormalizedSpecialTokens);
+
+                    options.SpecialTokens = notNormalizedSpecialTokens;
+                    specialTokensDict = tempSpecialTokens;
                 }
             }
 
-            options.PreTokenizer ??= options.ApplyBasicTokenization ? PreTokenizer.CreateWordOrPunctuation(options.SplitOnSpecialTokens ? options.SpecialTokens : null) : PreTokenizer.CreateWhiteSpace();
+            // We set the PreTokenizer here using the normalized special tokens dict (if relevant), and therefore we can
+            // keep the not-normalized special tokens dict in the options passed to the WordPieceTokenizer.
+            options.PreTokenizer ??= options.ApplyBasicTokenization ? PreTokenizer.CreateWordOrPunctuation(options.SplitOnSpecialTokens ? specialTokensDict : null) : PreTokenizer.CreateWhiteSpace();
 
             return new BertTokenizer(vocab, vocabReverse, options);
         }
 
-        private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase)
+        private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase, Dictionary<string, int>? notNormalizedSpecialTokens = null)
         {
             if (token is null || !vocab.TryGetValue(new StringSpanOrdinalKey(token), out int id))
             {
                 throw new ArgumentException($"The special token '{token}' is not in the vocabulary.");
             }
 
+            if (notNormalizedSpecialTokens is not null)
+            {
+                notNormalizedSpecialTokens[token] = id;
+            }
+
             string normalizedToken = token;
             if (lowerCase)
             {
diff --git a/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs
index e362da9b93..6dc0051346 100644
--- a/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
@@ -42,7 +42,7 @@ internal WordPieceTokenizer(
         options ??= new();
 
         SpecialTokens = options.SpecialTokens;
-        SpecialTokensReverse = options.SpecialTokens is not null ? options.SpecialTokens.ToDictionary(kvp => kvp.Value, kvp => kvp.Key) : null;
+        SpecialTokensReverse = options.SpecialTokens is not null ? options.SpecialTokens.GroupBy(kvp => kvp.Value).ToDictionary(g => g.Key, g => g.First().Key) : null;
 
         if (options.UnknownToken is null)
         {
@@ -800,4 +800,4 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
             return OperationStatus.Done;
         }
     }
-}
\ No newline at end of file
+}
diff --git a/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
index fb1c3850ba..8a5042f645 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
@@ -14,6 +14,91 @@ namespace Microsoft.ML.Tokenizers.Tests
 {
     public class BertTokenizerTests
     {
+        [Fact]
+        public void TestWithLowerCasingExplicitSpecialTokens()
+        {
+            // Add [SPECIAL] token at end (to keep indices as is)
+            // Ids:                     0        1        2        3        4        5    6    7    8        9        10     11     12     13
+            string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"];
+
+            string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
+
+            Dictionary<string, int> specialTokens = new()
+            {
+                { "[PAD]", 0 },
+                { "[UNK]", 1 },
+                { "[CLS]", 2 },
+                { "[SEP]", 3 },
+                { "[MASK]", 4 },
+                { "[SPECIAL]", 13 },
+            };
+
+            var bertOptions = new BertOptions()
+            {
+                SpecialTokens = specialTokens
+            };
+
+            try
+            {
+                using Stream vocabStream = File.OpenRead(vocabFile);
+                BertTokenizer[] bertTokenizers = [BertTokenizer.Create(vocabFile, bertOptions), BertTokenizer.Create(vocabStream, bertOptions)];
+
+                foreach (var tokenizer in bertTokenizers)
+                {
+                    Assert.NotNull(tokenizer.PreTokenizer);
+                    Assert.Equal("[UNK]", tokenizer.UnknownToken);
+                    Assert.Equal(1, tokenizer.UnknownTokenId);
+                    Assert.NotNull(tokenizer.Normalizer);
+                    Assert.NotNull(tokenizer.PreTokenizer);
+
+                    Assert.True(tokenizer.SpecialTokens!.ContainsKey("[SPECIAL]"));
+
+                    string text = "Hello, How are you [SPECIAL]?";
+                    var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
+                    Assert.Equal("hello, how are you [special]?", normalizedText);
+
+                    Assert.Equal(
+                        [
+                            new EncodedToken(8, "hello", new Range(0, 5)),
+                            new EncodedToken(6, ",", new Range(5, 6)),
+                            new EncodedToken(10, "how", new Range(7, 10)),
+                            new EncodedToken(11, "are", new Range(11, 14)),
+                            new EncodedToken(12, "you", new Range(15, 18)),
+                            new EncodedToken(13, "[SPECIAL]", new Range(19, 28)),
+                            new EncodedToken(7, "?", new Range(28, 29))
+                        ],
+                        tokens);
+
+                    var ids = tokenizer.EncodeToIds(text);
+                    Assert.Equal([tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId], ids);
+
+                    Assert.Equal("[CLS] hello, how are you [SPECIAL]? [SEP]", tokenizer.Decode(ids));
+                    Assert.Equal("hello, how are you?", tokenizer.Decode(ids, skipSpecialTokens: true));
+
+                    tokens = tokenizer.EncodeToTokens(tokenizer.Decode(ids), out normalizedText);
+                    Assert.Equal("[cls] hello, how are you [special]? [sep]", normalizedText);
+                    Assert.Equal(
+                        [
+                            new EncodedToken(2, "[CLS]", new Range(0, 5)),
+                            new EncodedToken(8, "hello", new Range(6, 11)),
+                            new EncodedToken(6, ",", new Range(11, 12)),
+                            new EncodedToken(10, "how", new Range(13, 16)),
+                            new EncodedToken(11, "are", new Range(17, 20)),
+                            new EncodedToken(12, "you", new Range(21, 24)),
+                            new EncodedToken(13, "[SPECIAL]", new Range(25, 34)),
+                            new EncodedToken(7, "?", new Range(34, 35)),
+                            new EncodedToken(3, "[SEP]", new Range(36, 41))
+                        ],
+                        tokens);
+
+                    ids = tokenizer.EncodeToIds(normalizedText!);
+                    Assert.Equal([tokenizer.ClassificationTokenId, tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId, tokenizer.SeparatorTokenId], ids);
+                }
+            }
+            finally
+            {
+                File.Delete(vocabFile);
+            }
+        }
+
         [Fact]
         public void TestWithLowerCasing()
         {
@@ -35,6 +120,10 @@ public void TestWithLowerCasing()
                 Assert.NotNull(tokenizer.Normalizer);
                 Assert.NotNull(tokenizer.PreTokenizer);
 
+                // Make sure the SpecialTokens dictionary contains the not-normalized tokens
+                Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.UnknownToken));
+                Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.ClassificationToken));
+
                 string text = "Hello, How are you?";
                 var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
                 Assert.Equal("hello, how are you?", normalizedText);
@@ -511,4 +600,4 @@ public void TestCreateTokenTypeIdsFromSequences()
             }
         }
     }
-}
\ No newline at end of file
+}