Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved SpecialTokens assignment after the modification to avoid "Collection Modified" error #7328

Merged
merged 7 commits into from
Dec 5, 2024
45 changes: 29 additions & 16 deletions src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand Down Expand Up @@ -762,15 +762,16 @@ private static BertTokenizer Create(

options.Normalizer ??= options.ApplyBasicTokenization ? new BertNormalizer(options.LowerCaseBeforeTokenization, options.IndividuallyTokenizeCjk, options.RemoveNonSpacingMarks) : null;

IReadOnlyDictionary<string, int>? specialTokensDict = options.SpecialTokens;
if (options.SplitOnSpecialTokens)
{
bool lowerCase = options.ApplyBasicTokenization && options.LowerCaseBeforeTokenization;
if (options.SpecialTokens is not null)
{
if (lowerCase)
{
Dictionary<string, int> dic = options.SpecialTokens.ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
options.SpecialTokens = dic;
Dictionary<string, int> tempSpecialTokens = [];
specialTokensDict = tempSpecialTokens;

foreach (var kvp in options.SpecialTokens)
{
Expand All @@ -779,37 +780,49 @@ private static BertTokenizer Create(
throw new ArgumentException($"The special token '{kvp.Key}' is not in the vocabulary or assigned id value {id} different than the value {kvp.Value} in the special tokens.");
}

// Ensure that the special tokens are lowercased.
dic[kvp.Key.ToLowerInvariant()] = kvp.Value;
// Add the special token into our dictionary, normalizing it, and adding it into the
// main vocab, if needed.
AddSpecialToken(vocab, tempSpecialTokens, kvp.Key, true);
}
}
}
else
{
// Create a dictionary with the special tokens.
Dictionary<string, int> specialTokens = new Dictionary<string, int>();
options.SpecialTokens = specialTokens;

AddSpecialToken(vocab, specialTokens, options.UnknownToken, lowerCase);
AddSpecialToken(vocab, specialTokens, options.SeparatorToken, lowerCase);
AddSpecialToken(vocab, specialTokens, options.PaddingToken, lowerCase);
AddSpecialToken(vocab, specialTokens, options.ClassificationToken, lowerCase);
AddSpecialToken(vocab, specialTokens, options.MaskingToken, lowerCase);
// Create a dictionary with the special tokens - store the un-normalized forms in the options as
// that field is exposed to the public. In addition, store the normalized form for creating the
// pre-tokenizer.
Dictionary<string, int> tempSpecialTokens = [];
Dictionary<string, int> notNormalizedSpecialTokens = [];
AddSpecialToken(vocab, tempSpecialTokens, options.UnknownToken, lowerCase, notNormalizedSpecialTokens);
AddSpecialToken(vocab, tempSpecialTokens, options.SeparatorToken, lowerCase, notNormalizedSpecialTokens);
AddSpecialToken(vocab, tempSpecialTokens, options.PaddingToken, lowerCase, notNormalizedSpecialTokens);
AddSpecialToken(vocab, tempSpecialTokens, options.ClassificationToken, lowerCase, notNormalizedSpecialTokens);
AddSpecialToken(vocab, tempSpecialTokens, options.MaskingToken, lowerCase, notNormalizedSpecialTokens);

options.SpecialTokens = notNormalizedSpecialTokens;
specialTokensDict = tempSpecialTokens;
}
}

options.PreTokenizer ??= options.ApplyBasicTokenization ? PreTokenizer.CreateWordOrPunctuation(options.SplitOnSpecialTokens ? options.SpecialTokens : null) : PreTokenizer.CreateWhiteSpace();
// We set the PreTokenizer here using the normalized special tokens dict (if relevant), and therefore we can
// keep the not-normalized special tokens dict in the options passed to the WordPieceTokenizer.
options.PreTokenizer ??= options.ApplyBasicTokenization ? PreTokenizer.CreateWordOrPunctuation(options.SplitOnSpecialTokens ? specialTokensDict : null) : PreTokenizer.CreateWhiteSpace();

return new BertTokenizer(vocab, vocabReverse, options);
}

private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase)
private static void AddSpecialToken(Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<string, int> specialTokens, string token, bool lowerCase, Dictionary<string, int>? notNormalizedSpecialTokens = null)
{
if (token is null || !vocab.TryGetValue(new StringSpanOrdinalKey(token), out int id))
{
throw new ArgumentException($"The special token '{token}' is not in the vocabulary.");
}

if (notNormalizedSpecialTokens is not null)
{
notNormalizedSpecialTokens[token] = id;
}

string normalizedToken = token;
if (lowerCase)
{
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand Down Expand Up @@ -42,7 +42,7 @@ internal WordPieceTokenizer(
options ??= new();

SpecialTokens = options.SpecialTokens;
SpecialTokensReverse = options.SpecialTokens is not null ? options.SpecialTokens.ToDictionary(kvp => kvp.Value, kvp => kvp.Key) : null;
SpecialTokensReverse = options.SpecialTokens is not null ? options.SpecialTokens.GroupBy(kvp => kvp.Value).ToDictionary(g => g.Key, g => g.First().Key) : null;

if (options.UnknownToken is null)
{
Expand Down Expand Up @@ -800,4 +800,4 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
return OperationStatus.Done;
}
}
}
}
93 changes: 91 additions & 2 deletions test/Microsoft.ML.Tokenizers.Tests/BertTokenizerTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand All @@ -14,6 +14,91 @@ namespace Microsoft.ML.Tokenizers.Tests
{
public class BertTokenizerTests
{
[Fact]
public void TestWithLowerCasingExplicitSpecialTokens()
{
    // Vocabulary ids: 0..12 are the standard tokens; [SPECIAL] is appended last
    // (id 13) so the ids of the existing entries stay unchanged.
    string[] vocabEntries = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"];

    string vocabPath = WordPieceTests.CreateVocabFile(vocabEntries);

    // Explicit special-token map handed to the tokenizer, including the custom [SPECIAL] entry.
    Dictionary<string, int> explicitSpecialTokens = new()
    {
        { "[PAD]", 0 },
        { "[UNK]", 1 },
        { "[CLS]", 2 },
        { "[SEP]", 3 },
        { "[MASK]", 4 },
        { "[SPECIAL]", 13 },
    };

    var options = new BertOptions()
    {
        SpecialTokens = explicitSpecialTokens
    };

    try
    {
        using Stream vocabStream = File.OpenRead(vocabPath);

        // Exercise both factory overloads (file path and stream) against identical expectations.
        BertTokenizer[] tokenizers = [BertTokenizer.Create(vocabPath, options), BertTokenizer.Create(vocabStream, options)];

        foreach (var bert in tokenizers)
        {
            Assert.NotNull(bert.PreTokenizer);
            Assert.Equal("[UNK]", bert.UnknownToken);
            Assert.Equal(1, bert.UnknownTokenId);
            Assert.NotNull(bert.Normalizer);
            Assert.NotNull(bert.PreTokenizer);

            // The publicly exposed SpecialTokens dictionary must keep the caller's
            // original (un-normalized) spelling of the custom token.
            Assert.True(bert.SpecialTokens!.ContainsKey("[SPECIAL]"));

            string input = "Hello, How are you [SPECIAL]?";
            var encoded = bert.EncodeToTokens(input, out string? normalized);
            Assert.Equal("hello, how are you [special]?", normalized);

            Assert.Equal(
                [
                    new EncodedToken(8, "hello", new Range(0, 5)),
                    new EncodedToken(6, ",", new Range(5, 6)),
                    new EncodedToken(10, "how", new Range(7, 10)),
                    new EncodedToken(11, "are", new Range(11, 14)),
                    new EncodedToken(12, "you", new Range(15, 18)),
                    new EncodedToken(13, "[SPECIAL]", new Range(19, 28)),
                    new EncodedToken(7, "?", new Range(28, 29))
                ],
                encoded);

            var encodedIds = bert.EncodeToIds(input);
            Assert.Equal([bert.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, bert.SeparatorTokenId], encodedIds);

            Assert.Equal("[CLS] hello, how are you [SPECIAL]? [SEP]", bert.Decode(encodedIds));
            Assert.Equal("hello, how are you?", bert.Decode(encodedIds, skipSpecialTokens: true));

            // Round-trip the decoded text back through the tokenizer; the embedded
            // special tokens are normalized in the text but keep their vocab ids.
            string roundTripped = bert.Decode(encodedIds);
            encoded = bert.EncodeToTokens(roundTripped, out normalized);
            Assert.Equal("[cls] hello, how are you [special]? [sep]", normalized);
            Assert.Equal(
                [
                    new EncodedToken(2, "[CLS]", new Range(0, 5)),
                    new EncodedToken(8, "hello", new Range(6, 11)),
                    new EncodedToken(6, ",", new Range(11, 12)),
                    new EncodedToken(10, "how", new Range(13, 16)),
                    new EncodedToken(11, "are", new Range(17, 20)),
                    new EncodedToken(12, "you", new Range(21, 24)),
                    new EncodedToken(13, "[SPECIAL]", new Range(25, 34)),
                    new EncodedToken(7, "?", new Range(34, 35)),
                    new EncodedToken(3, "[SEP]", new Range(36, 41))
                ],
                encoded);

            // Encoding the already-decorated text adds a second [CLS]/[SEP] pair
            // around the ones embedded in the string.
            encodedIds = bert.EncodeToIds(normalized!);
            Assert.Equal([bert.ClassificationTokenId, bert.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, bert.SeparatorTokenId, bert.SeparatorTokenId], encodedIds);
        }
    }
    finally
    {
        File.Delete(vocabPath);
    }
}

[Fact]
public void TestWithLowerCasing()
{
Expand All @@ -35,6 +120,10 @@ public void TestWithLowerCasing()
Assert.NotNull(tokenizer.Normalizer);
Assert.NotNull(tokenizer.PreTokenizer);

// Make sure the SpecialTokens dictionary contains the not-normalized tokens
Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.UnknownToken));
Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.ClassificationToken));

string text = "Hello, How are you?";
var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
Assert.Equal("hello, how are you?", normalizedText);
Expand Down Expand Up @@ -511,4 +600,4 @@ public void TestCreateTokenTypeIdsFromSequences()
}
}
}
}
}