-
Notifications
You must be signed in to change notification settings - Fork 685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[tokenizers] Add int32 option to encoding #3571
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
package ai.djl.huggingface.tokenizers; | ||
|
||
import ai.djl.huggingface.tokenizers.jni.CharSpan; | ||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDList; | ||
import ai.djl.ndarray.NDManager; | ||
|
||
|
@@ -55,23 +56,85 @@ protected Encoding( | |
this.overflowing = overflowing; | ||
} | ||
|
||
/** | ||
* Returns the {@link NDList} representation of the encodings. | ||
* | ||
* @param encodings the {@code Encoding} batch | ||
* @param manager the {@link NDManager} to create the NDList | ||
* @param withTokenType true to include the token type id | ||
* @param int32 true to use int32 datatype | ||
* @return the {@link NDList} | ||
*/ | ||
public static NDList toNDList( | ||
Encoding[] encodings, NDManager manager, boolean withTokenType, boolean int32) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could we leverage the instance method toNdList here? It seems like we should be able to loop over all the encondings and call encoding.toNdList(...)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They are different, the static one is batched ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, but per encoding they are the same. Why can't we do something like this?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is an optimization for performance (was trying to compete with TEI). It reduced number of memory copies to create a batched NDArray. This way it only invoke JNI 3 times, while the loop one will copy N times and then use stack to batchify them. I actually not sure if this is necessary. |
||
NDList list = new NDList(); | ||
if (!int32) { | ||
long[][] ids = new long[encodings.length][]; | ||
long[][] attentionMask = new long[encodings.length][]; | ||
long[][] typeIds = new long[encodings.length][]; | ||
for (int i = 0; i < encodings.length; i++) { | ||
ids[i] = encodings[i].getIds(); | ||
attentionMask[i] = encodings[i].getAttentionMask(); | ||
if (withTokenType) { | ||
typeIds[i] = encodings[i].getTypeIds(); | ||
} | ||
} | ||
list.add(manager.create(ids)); | ||
NDArray inputAttentionMask = manager.create(attentionMask); | ||
list.add(inputAttentionMask); | ||
if (withTokenType) { | ||
list.add(manager.create(typeIds)); | ||
} | ||
return list; | ||
} | ||
|
||
int[][] ids = new int[encodings.length][]; | ||
int[][] attentionMask = new int[encodings.length][]; | ||
int[][] typeIds = new int[encodings.length][]; | ||
for (int i = 0; i < encodings.length; i++) { | ||
ids[i] = Arrays.stream(encodings[i].getIds()).mapToInt(l -> (int) l).toArray(); | ||
attentionMask[i] = | ||
Arrays.stream(encodings[i].getAttentionMask()).mapToInt(l -> (int) l).toArray(); | ||
if (withTokenType) { | ||
typeIds[i] = | ||
Arrays.stream(encodings[i].getTypeIds()).mapToInt(l -> (int) l).toArray(); | ||
} | ||
siddvenk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
list.add(manager.create(ids)); | ||
NDArray inputAttentionMask = manager.create(attentionMask); | ||
list.add(inputAttentionMask); | ||
if (withTokenType) { | ||
list.add(manager.create(typeIds)); | ||
} | ||
return list; | ||
} | ||
|
||
/** | ||
* Returns the {@link NDList} representation of the encoding. | ||
* | ||
* @param manager the {@link NDManager} to create the NDList | ||
* @param withTokenType true to include the token type id | ||
* @param int32 true to use int32 datatype | ||
* @return the {@link NDList} | ||
*/ | ||
public NDList toNDList(NDManager manager, boolean withTokenType) { | ||
public NDList toNDList(NDManager manager, boolean withTokenType, boolean int32) { | ||
// Converting encoding to int32 NDList because candle can't convert int64 to fp16 in cuda | ||
NDList list = new NDList(withTokenType ? 3 : 2); | ||
int[] intIds = Arrays.stream(ids).mapToInt(i -> (int) i).toArray(); | ||
int[] intAttentionMask = Arrays.stream(attentionMask).mapToInt(i -> (int) i).toArray(); | ||
list.add(manager.create(intIds)); | ||
list.add(manager.create(intAttentionMask)); | ||
if (withTokenType) { | ||
int[] intTypeIds = Arrays.stream(typeIds).mapToInt(i -> (int) i).toArray(); | ||
list.add(manager.create(intTypeIds)); | ||
if (int32) { | ||
int[] intIds = Arrays.stream(ids).mapToInt(i -> (int) i).toArray(); | ||
int[] intAttentionMask = Arrays.stream(attentionMask).mapToInt(i -> (int) i).toArray(); | ||
list.add(manager.create(intIds)); | ||
list.add(manager.create(intAttentionMask)); | ||
if (withTokenType) { | ||
int[] intTypeIds = Arrays.stream(typeIds).mapToInt(i -> (int) i).toArray(); | ||
list.add(manager.create(intTypeIds)); | ||
} | ||
} else { | ||
list.add(manager.create(ids)); | ||
list.add(manager.create(attentionMask)); | ||
if (withTokenType) { | ||
list.add(manager.create(typeIds)); | ||
} | ||
} | ||
return list; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for my understanding, why do we have a boolean for int32? Should these methods accept a dtype argument for more flexibility? Is there an issue with Rust/Candle that prevents us from doing so?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the token ids are always integer, it actually should always be int64, but some of the rust implementation currently only accept int32, that's a workaround for candle.