Skip to content

Commit

Permalink
fix char offset
Browse files Browse the repository at this point in the history
  • Loading branch information
Qing Lan committed Nov 5, 2022
1 parent 2f4879c commit 67cd1cf
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
8 changes: 4 additions & 4 deletions extensions/tokenizers/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_

let input_sequence = tk::InputSequence::from(sequence);
let encoded_input = EncodeInput::Single(input_sequence);
let encoding = tokenizer.encode(encoded_input, add_special_tokens == JNI_TRUE);
let encoding = tokenizer.encode_char_offsets(encoded_input, add_special_tokens == JNI_TRUE);

match encoding {
Ok(output) => to_handle(output),
Expand Down Expand Up @@ -127,7 +127,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
let input_sequence1 = tk::InputSequence::from(sequence1);
let input_sequence2 = tk::InputSequence::from(sequence2);
let encoded_input = EncodeInput::Dual(input_sequence1, input_sequence2);
let encoding = tokenizer.encode(encoded_input, add_special_tokens == JNI_TRUE);
let encoding = tokenizer.encode_char_offsets(encoded_input, add_special_tokens == JNI_TRUE);

match encoding {
Ok(output) => to_handle(output),
Expand Down Expand Up @@ -160,7 +160,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_

let input_sequence = tk::InputSequence::from(array);
let encoded_input = EncodeInput::from(input_sequence);
let encoding = tokenizer.encode(encoded_input, add_special_tokens == JNI_TRUE);
let encoding = tokenizer.encode_char_offsets(encoded_input, add_special_tokens == JNI_TRUE);

match encoding {
Ok(output) => to_handle(output),
Expand Down Expand Up @@ -192,7 +192,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
}

let encodings = tokenizer
.encode_batch(array, add_special_tokens == JNI_TRUE)
.encode_batch_char_offsets(array, add_special_tokens == JNI_TRUE)
.unwrap();
let handles = encodings
.into_iter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ public void testTokenizer() throws IOException {
new CharSpan(14, 17),
new CharSpan(18, 21),
new CharSpan(22, 25),
new CharSpan(26, 30),
new CharSpan(31, 32),
new CharSpan(26, 27),
new CharSpan(28, 29),
null
};
int expectedLength = charSpansExpected.length;
Expand Down Expand Up @@ -389,4 +389,29 @@ public void testTruncationAndPaddingForPairInputs() throws IOException {
Assert.assertEquals(encoding.getIds().length, 8);
}
}

/**
 * Checks that token spans are reported as character offsets for input that begins with a
 * non-ASCII character ("¥"), where an offset counted in encoded bytes would not line up with
 * the character index.
 *
 * @throws IOException if the tokenizer cannot be built
 */
@Test
public void testSpecialTokenHandling() throws IOException {
    try (HuggingFaceTokenizer tokenizer =
            HuggingFaceTokenizer.builder()
                    .optTokenizerName("distilbert-base-uncased")
                    .build()) {
        String someText = "¥$9";
        Encoding encodedText = tokenizer.encode(someText);

        // One single-character span per input character. The first and last entries are
        // placeholders for the tokens surrounding the text (presumably the special tokens
        // added by the tokenizer) and are never compared below.
        CharSpan[] expected =
                new CharSpan[] {
                    new CharSpan(-1, -1),
                    new CharSpan(0, 1),
                    new CharSpan(1, 2),
                    new CharSpan(2, 3),
                    new CharSpan(-1, -1)
                };
        CharSpan[] charSpans = encodedText.getCharTokenSpans();

        // Pin the token count first: otherwise a shorter result would let the loop pass
        // without checking anything, and a longer one would index past `expected`.
        Assert.assertEquals(charSpans.length, expected.length);

        // TestNG argument order is (actual, expected).
        for (int i = 1; i < charSpans.length - 1; i++) {
            Assert.assertEquals(charSpans[i].getStart(), expected[i].getStart());
            Assert.assertEquals(charSpans[i].getEnd(), expected[i].getEnd());
        }
    }
}
}

0 comments on commit 67cd1cf

Please sign in to comment.