[tokenizers] Add int32 option to encoding #3571

Merged
merged 1 commit into from
Jan 13, 2025
@@ -13,6 +13,7 @@
package ai.djl.huggingface.tokenizers;

import ai.djl.huggingface.tokenizers.jni.CharSpan;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;

@@ -55,23 +56,85 @@ protected Encoding(
this.overflowing = overflowing;
}

/**
* Returns the {@link NDList} representation of the encodings.
*
* @param encodings the {@code Encoding} batch
* @param manager the {@link NDManager} to create the NDList
* @param withTokenType true to include the token type id
* @param int32 true to use int32 datatype
Contributor

For my understanding, why do we have a boolean for int32? Should these methods accept a dtype argument for more flexibility? Is there an issue with Rust/Candle that prevents us from doing so?

Contributor

Token ids are always integers and should really always be int64, but some of the Rust implementations currently accept only int32, so this is a workaround for Candle.

* @return the {@link NDList}
*/
public static NDList toNDList(
Encoding[] encodings, NDManager manager, boolean withTokenType, boolean int32) {
Contributor

Could we leverage the instance method toNDList here? It seems like we should be able to loop over all the encodings and call encoding.toNDList(...).

Contributor

They are different: the static one is batched (Encoding[]).

Contributor

Yes, but per encoding they are the same. Why can't we do something like this?

public static NDList toNDList(Encoding[] encodings, NDManager manager, boolean withTokenType, boolean int32) {
    NDList list = new NDList();
    for (int i = 0; i < encodings.length; i++) {
        NDList encoding = encodings[i].toNDList(manager, withTokenType, int32);
        list.addAll(encoding);
    }
    return list;
}

Contributor

This is a performance optimization (I was trying to compete with TEI). It reduces the number of memory copies needed to create a batched NDArray: this way it only invokes JNI 3 times, while the loop version copies N times and then uses stack to batchify them.

I'm actually not sure whether this is necessary.

NDList list = new NDList();
if (!int32) {
long[][] ids = new long[encodings.length][];
long[][] attentionMask = new long[encodings.length][];
long[][] typeIds = new long[encodings.length][];
for (int i = 0; i < encodings.length; i++) {
ids[i] = encodings[i].getIds();
attentionMask[i] = encodings[i].getAttentionMask();
if (withTokenType) {
typeIds[i] = encodings[i].getTypeIds();
}
}
list.add(manager.create(ids));
NDArray inputAttentionMask = manager.create(attentionMask);
list.add(inputAttentionMask);
if (withTokenType) {
list.add(manager.create(typeIds));
}
return list;
}

int[][] ids = new int[encodings.length][];
int[][] attentionMask = new int[encodings.length][];
int[][] typeIds = new int[encodings.length][];
for (int i = 0; i < encodings.length; i++) {
ids[i] = Arrays.stream(encodings[i].getIds()).mapToInt(l -> (int) l).toArray();
attentionMask[i] =
Arrays.stream(encodings[i].getAttentionMask()).mapToInt(l -> (int) l).toArray();
if (withTokenType) {
typeIds[i] =
Arrays.stream(encodings[i].getTypeIds()).mapToInt(l -> (int) l).toArray();
}
}
list.add(manager.create(ids));
NDArray inputAttentionMask = manager.create(attentionMask);
list.add(inputAttentionMask);
if (withTokenType) {
list.add(manager.create(typeIds));
}
return list;
}
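
The trade-off discussed in the thread above (one create call per field for the whole batch, versus one per field per encoding followed by a stack) can be illustrated with a minimal sketch. It does not use DJL: `CountingManager` is a hypothetical stand-in for `NDManager` that only counts `create(...)` calls, and the stacking step of the per-encoding path is not modelled. The token values are made up.

```java
/** Sketch only: counts create(...) calls for the two batching strategies. */
final class BatchingSketch {

    /** Hypothetical stand-in for an NDManager; it copies nothing and only counts calls. */
    static final class CountingManager {
        int createCalls;

        long[][] create(long[][] batch) {
            createCalls++;
            return batch;
        }

        long[] create(long[] row) {
            createCalls++;
            return row;
        }
    }

    /** Batched path: rows are gathered in Java first, then each field is created once. */
    static int batchedCreateCalls(long[][] ids, long[][] mask, long[][] typeIds) {
        CountingManager manager = new CountingManager();
        manager.create(ids);
        manager.create(mask);
        manager.create(typeIds);
        return manager.createCalls;
    }

    /** Per-encoding path: each field of each encoding is created separately. */
    static int perEncodingCreateCalls(long[][] ids, long[][] mask, long[][] typeIds) {
        CountingManager manager = new CountingManager();
        for (int i = 0; i < ids.length; i++) {
            manager.create(ids[i]);
            manager.create(mask[i]);
            manager.create(typeIds[i]);
        }
        return manager.createCalls;
    }

    public static void main(String[] args) {
        // Four made-up encodings of length two.
        long[][] rows = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};
        System.out.println(batchedCreateCalls(rows, rows, rows)); // prints 3
        System.out.println(perEncodingCreateCalls(rows, rows, rows)); // prints 12
    }
}
```

With three fields the batched path stays at three calls regardless of batch size, while the per-encoding path grows linearly with the number of encodings. Whether that difference matters in practice is the open question raised in the thread.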

/**
* Returns the {@link NDList} representation of the encoding.
*
* @param manager the {@link NDManager} to create the NDList
* @param withTokenType true to include the token type id
* @param int32 true to use int32 datatype
* @return the {@link NDList}
*/
- public NDList toNDList(NDManager manager, boolean withTokenType) {
+ public NDList toNDList(NDManager manager, boolean withTokenType, boolean int32) {
// Converting encoding to int32 NDList because candle can't convert int64 to fp16 in cuda
NDList list = new NDList(withTokenType ? 3 : 2);
- int[] intIds = Arrays.stream(ids).mapToInt(i -> (int) i).toArray();
- int[] intAttentionMask = Arrays.stream(attentionMask).mapToInt(i -> (int) i).toArray();
- list.add(manager.create(intIds));
- list.add(manager.create(intAttentionMask));
- if (withTokenType) {
- int[] intTypeIds = Arrays.stream(typeIds).mapToInt(i -> (int) i).toArray();
- list.add(manager.create(intTypeIds));
+ if (int32) {
+ int[] intIds = Arrays.stream(ids).mapToInt(i -> (int) i).toArray();
+ int[] intAttentionMask = Arrays.stream(attentionMask).mapToInt(i -> (int) i).toArray();
+ list.add(manager.create(intIds));
+ list.add(manager.create(intAttentionMask));
+ if (withTokenType) {
+ int[] intTypeIds = Arrays.stream(typeIds).mapToInt(i -> (int) i).toArray();
+ list.add(manager.create(intTypeIds));
}
+ } else {
+ list.add(manager.create(ids));
+ list.add(manager.create(attentionMask));
+ if (withTokenType) {
+ list.add(manager.create(typeIds));
+ }
+ }
return list;
}
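
The int32 branch narrows each `long` with a plain `(int)` cast, which silently wraps values outside the int range. Token ids are not expected to get that large, so this is likely harmless here. For comparison, the sketch below shows the same stream-based conversion next to a checked variant built on `Math.toIntExact`. The class and method names are hypothetical and not part of the PR, and the sample ids are made up.

```java
import java.util.Arrays;

/** Sketch only: unchecked versus checked long[] to int[] narrowing. */
final class NarrowingSketch {

    /** Mirrors the cast used in the diff; out-of-range values wrap silently. */
    static int[] toIntArrayUnchecked(long[] values) {
        return Arrays.stream(values).mapToInt(v -> (int) v).toArray();
    }

    /** Checked variant: throws ArithmeticException if a value does not fit in an int. */
    static int[] toIntArrayExact(long[] values) {
        return Arrays.stream(values).mapToInt(Math::toIntExact).toArray();
    }

    public static void main(String[] args) {
        long[] ids = {101L, 7592L, 102L}; // made-up token ids
        System.out.println(Arrays.toString(toIntArrayExact(ids))); // prints [101, 7592, 102]

        long[] tooLarge = {1L << 32};
        System.out.println(toIntArrayUnchecked(tooLarge)[0]); // prints 0 (wrapped)
    }
}
```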
@@ -35,16 +35,19 @@ public class CrossEncoderTranslator implements Translator<StringPair, float[]> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private boolean int32;
private boolean sigmoid;
private Batchifier batchifier;

CrossEncoderTranslator(
HuggingFaceTokenizer tokenizer,
boolean includeTokenTypes,
boolean int32,
boolean sigmoid,
Batchifier batchifier) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.int32 = int32;
this.sigmoid = sigmoid;
this.batchifier = batchifier;
}
@@ -60,7 +63,7 @@ public Batchifier getBatchifier() {
public NDList processInput(TranslatorContext ctx, StringPair input) {
Encoding encoding = tokenizer.encode(input.getKey(), input.getValue());
ctx.setAttachment("encoding", encoding);
- return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
+ return encoding.toNDList(ctx.getNDManager(), includeTokenTypes, int32);
}

/** {@inheritDoc} */
@@ -71,7 +74,7 @@ public NDList batchProcessInput(TranslatorContext ctx, List<StringPair> inputs)
Encoding[] encodings = tokenizer.batchEncode(list);
NDList[] batch = new NDList[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
- batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
+ batch[i] = encodings[i].toNDList(manager, includeTokenTypes, int32);
}
return batchifier.batchify(batch);
}
@@ -145,6 +148,7 @@ public static final class Builder {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private boolean int32;
private boolean sigmoid = true;
private Batchifier batchifier = Batchifier.STACK;

@@ -163,6 +167,17 @@ public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
return this;
}

/**
* Sets whether to use the int32 datatype for the {@link Translator}.
*
* @param int32 true to use int32 datatype
* @return this builder
*/
public Builder optInt32(boolean int32) {
this.int32 = int32;
return this;
}

/**
* Sets if apply sigmoid for the {@link Translator}.
*
@@ -192,6 +207,7 @@ public Builder optBatchifier(Batchifier batchifier) {
*/
public void configure(Map<String, ?> arguments) {
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
optInt32(ArgumentsUtil.booleanValue(arguments, "int32"));
optSigmoid(ArgumentsUtil.booleanValue(arguments, "sigmoid", true));
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optBatchifier(Batchifier.fromString(batchifierStr));
@@ -204,7 +220,8 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public CrossEncoderTranslator build() throws IOException {
- return new CrossEncoderTranslator(tokenizer, includeTokenTypes, sigmoid, batchifier);
+ return new CrossEncoderTranslator(
+ tokenizer, includeTokenTypes, int32, sigmoid, batchifier);
}
}
}
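
The builder changes above follow one pattern: an `int32` field that defaults to false, an `optInt32` setter, and a `configure` step that reads the `int32` key from the arguments map. The sketch below reproduces that pattern without DJL. `Int32OptionSketch` is a hypothetical class, and the parsing inside `configure` is only an assumed approximation of `ArgumentsUtil.booleanValue` (a missing key is treated as false), not its actual implementation.

```java
import java.util.Map;

/** Sketch only: the optInt32/configure pattern from the diff, with no DJL types. */
final class Int32OptionSketch {

    private final boolean int32;

    private Int32OptionSketch(boolean int32) {
        this.int32 = int32;
    }

    boolean usesInt32() {
        return int32;
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder {

        private boolean int32; // defaults to false, so int64 output unless requested

        Builder optInt32(boolean int32) {
            this.int32 = int32;
            return this;
        }

        /** Assumed behaviour: a missing or non-"true" value leaves int32 disabled. */
        void configure(Map<String, ?> arguments) {
            Object value = arguments.get("int32");
            optInt32(value != null && Boolean.parseBoolean(value.toString().trim()));
        }

        Int32OptionSketch build() {
            return new Int32OptionSketch(int32);
        }
    }

    public static void main(String[] args) {
        Builder builder = builder();
        builder.configure(Map.of("int32", "true"));
        System.out.println(builder.build().usesInt32()); // prints true
        System.out.println(builder().build().usesInt32()); // prints false
    }
}
```

Because the flag defaults to false, existing callers that never set `int32` keep the int64 behaviour of the non-int32 branch.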
@@ -38,18 +38,21 @@ public class FillMaskTranslator implements Translator<String, Classifications> {
private long maskTokenId;
private int topK;
private boolean includeTokenTypes;
private boolean int32;
private Batchifier batchifier;

FillMaskTranslator(
HuggingFaceTokenizer tokenizer,
String maskToken,
int topK,
boolean includeTokenTypes,
boolean int32,
Batchifier batchifier) {
this.tokenizer = tokenizer;
this.maskToken = maskToken;
this.topK = topK;
this.includeTokenTypes = includeTokenTypes;
this.int32 = int32;
this.batchifier = batchifier;
Encoding encoding = tokenizer.encode(maskToken, false, false);
maskTokenId = encoding.getIds()[0];
@@ -68,7 +71,7 @@ public NDList processInput(TranslatorContext ctx, String input) throws Translate
long[] indices = encoding.getIds();
int maskIndex = getMaskIndex(indices);
ctx.setAttachment("maskIndex", maskIndex);
- return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
+ return encoding.toNDList(ctx.getNDManager(), includeTokenTypes, int32);
}

/** {@inheritDoc} */
@@ -83,7 +86,7 @@ public NDList batchProcessInput(TranslatorContext ctx, List<String> inputs)
for (int i = 0; i < batch.length; ++i) {
long[] indices = encodings[i].getIds();
maskIndices[i] = getMaskIndex(indices);
- batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
+ batch[i] = encodings[i].toNDList(manager, includeTokenTypes, int32);
}
return batchifier.batchify(batch);
}
@@ -167,6 +170,7 @@ public static final class Builder {
private String maskedToken = "[MASK]";
private int topK = 5;
private boolean includeTokenTypes;
private boolean int32;
private Batchifier batchifier = Batchifier.STACK;

Builder(HuggingFaceTokenizer tokenizer) {
@@ -206,6 +210,17 @@ public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
return this;
}

/**
* Sets whether to use the int32 datatype for the {@link Translator}.
*
* @param int32 true to use int32 datatype
* @return this builder
*/
public Builder optInt32(boolean int32) {
this.int32 = int32;
return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
@@ -224,6 +239,7 @@ public Builder optBatchifier(Batchifier batchifier) {
*/
public void configure(Map<String, ?> arguments) {
optMaskToken(ArgumentsUtil.stringValue(arguments, "maskToken", "[MASK]"));
optInt32(ArgumentsUtil.booleanValue(arguments, "int32"));
optTopK(ArgumentsUtil.intValue(arguments, "topK", 5));
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
@@ -238,7 +254,7 @@ public void configure(Map<String, ?> arguments) {
*/
public FillMaskTranslator build() throws IOException {
return new FillMaskTranslator(
- tokenizer, maskedToken, topK, includeTokenTypes, batchifier);
+ tokenizer, maskedToken, topK, includeTokenTypes, int32, batchifier);
}
}
}
@@ -37,16 +37,19 @@ public class QuestionAnsweringTranslator implements Translator<QAInput, String>

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private boolean int32;
private Batchifier batchifier;
private boolean detail;

QuestionAnsweringTranslator(
HuggingFaceTokenizer tokenizer,
boolean includeTokenTypes,
boolean int32,
Batchifier batchifier,
boolean detail) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.int32 = int32;
this.batchifier = batchifier;
this.detail = detail;
}
@@ -62,7 +65,7 @@ public Batchifier getBatchifier() {
public NDList processInput(TranslatorContext ctx, QAInput input) {
Encoding encoding = tokenizer.encode(input.getQuestion(), input.getParagraph());
ctx.setAttachment("encoding", encoding);
- return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
+ return encoding.toNDList(ctx.getNDManager(), includeTokenTypes, int32);
}

/** {@inheritDoc} */
@@ -77,7 +80,7 @@ public NDList batchProcessInput(TranslatorContext ctx, List<QAInput> inputs) {
ctx.setAttachment("encodings", encodings);
NDList[] batch = new NDList[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
- batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
+ batch[i] = encodings[i].toNDList(manager, includeTokenTypes, int32);
}
return batchifier.batchify(batch);
}
@@ -190,6 +193,7 @@ public static final class Builder {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private boolean int32;
private Batchifier batchifier = Batchifier.STACK;
private boolean detail;

@@ -208,6 +212,17 @@ public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
return this;
}

/**
* Sets whether to use the int32 datatype for the {@link Translator}.
*
* @param int32 true to use int32 datatype
* @return this builder
*/
public Builder optInt32(boolean int32) {
this.int32 = int32;
return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
@@ -237,6 +252,7 @@ public Builder optDetail(boolean detail) {
*/
public void configure(Map<String, ?> arguments) {
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
optInt32(ArgumentsUtil.booleanValue(arguments, "int32"));
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optDetail(ArgumentsUtil.booleanValue(arguments, "detail"));
optBatchifier(Batchifier.fromString(batchifierStr));
@@ -250,7 +266,7 @@ public void configure(Map<String, ?> arguments) {
*/
public QuestionAnsweringTranslator build() throws IOException {
return new QuestionAnsweringTranslator(
- tokenizer, includeTokenTypes, batchifier, detail);
+ tokenizer, includeTokenTypes, int32, batchifier, detail);
}
}
}