Skip to content

Commit

Permalink
[tokenizer] Return if exceed max token length (#2957)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Jan 22, 2024
1 parent 9c8cc60 commit 5ece342
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class Encoding {
private long[] specialTokenMask;
private CharSpan[] charTokenSpans;
private Encoding[] overflowing;
private boolean exceedMaxLength;

protected Encoding(
long[] ids,
Expand All @@ -36,6 +37,7 @@ protected Encoding(
long[] attentionMask,
long[] specialTokenMask,
CharSpan[] charTokenSpans,
boolean exceedMaxLength,
Encoding[] overflowing) {
this.ids = ids;
this.typeIds = typeIds;
Expand All @@ -44,6 +46,7 @@ protected Encoding(
this.attentionMask = attentionMask;
this.specialTokenMask = specialTokenMask;
this.charTokenSpans = charTokenSpans;
this.exceedMaxLength = exceedMaxLength;
this.overflowing = overflowing;
}

Expand Down Expand Up @@ -127,6 +130,15 @@ public CharSpan[] getCharTokenSpans() {
return charTokenSpans;
}

/**
 * Reports whether the tokenized input was longer than the maximum length.
 *
 * <p>The flag is supplied when this {@code Encoding} is constructed and is returned unchanged.
 *
 * @return {@code true} if the tokens exceeded the maximum length, {@code false} otherwise
 */
public boolean exceedMaxLength() {
    return this.exceedMaxLength;
}

/**
* Returns an array of overflowing encodings.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,10 +530,10 @@ private Encoding toEncoding(long encoding, boolean withOverflowingTokens) {
long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding);
CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding);

long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);
boolean exceedMaxLength = overflowingHandles.length > 0;
Encoding[] overflowing;
if (withOverflowingTokens) {
long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);

overflowing = new Encoding[overflowingHandles.length];
for (int i = 0; i < overflowingHandles.length; ++i) {
overflowing[i] = toEncoding(overflowingHandles[i], true);
Expand All @@ -551,6 +551,7 @@ private Encoding toEncoding(long encoding, boolean withOverflowingTokens) {
attentionMask,
specialTokenMask,
charSpans,
exceedMaxLength,
overflowing);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,9 @@ public void testTruncationStride() throws IOException {
.build()) {
String text = "Hello there my friend I am happy to see you";
String textPair = "How are you my friend";
Encoding[] overflowing = tokenizer.encode(text, textPair).getOverflowing();
Encoding encoding = tokenizer.encode(text, textPair);
Assert.assertTrue(encoding.exceedMaxLength());
Encoding[] overflowing = encoding.getOverflowing();

int expectedNumberOfOverflowEncodings = 7;
Assert.assertEquals(overflowing.length, expectedNumberOfOverflowEncodings);
Expand Down

0 comments on commit 5ece342

Please sign in to comment.