Skip to content

Commit

Permalink
[api] Fixes IdEmbedding memory leak (#3257)
Browse files Browse the repository at this point in the history
Fixes: #3251
  • Loading branch information
frankfliu authored Jun 17, 2024
1 parent 34caf33 commit bb7dfd1
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,20 @@ public Shape[] getOutputShapes(Shape[] inputShapes) {
protected NDList forwardInternal(
ParameterStore ps, NDList inputs, boolean training, PairList<String, Object> params) {
NDArray input = inputs.singletonOrThrow();
// on info to the right shapes, see: http://beta.mxnet.io/r/api/mx.symbol.gather_nd.html
NDArray ids = input.flatten().reshape(1, input.getShape().size());
// create the embedding Table
NDArray embeddingTable = ps.getValue(embedding, ids.getDevice(), training);
// We do not perform a sparse lookup, instead we just project into the table
NDArray result = embeddingTable.gatherNd(ids);
// we want the original shape of the input + the last dimension of the embedding
Shape targetShape = input.getShape().addAll(new Shape(embeddingTable.getShape().get(1)));
return new NDList(result.reshape(targetShape));
try (NDManager scope = NDManager.subManagerOf(input)) {
// on info to the right shapes, see: http://beta.mxnet.io/r/api/mx.symbol.gather_nd.html
NDArray ids = input.flatten().reshape(1, input.getShape().size());
// create the embedding Table
NDArray embeddingTable = ps.getValue(embedding, ids.getDevice(), training);
scope.tempAttachAll(embeddingTable);
// We do not perform a sparse lookup, instead we just project into the table
NDArray result = embeddingTable.gatherNd(ids);
result.attach(inputs.getManager());
// we want the original shape of the input + the last dimension of the embedding
Shape targetShape =
input.getShape().addAll(new Shape(embeddingTable.getShape().get(1)));
return new NDList(result.reshape(targetShape));
}
}

/**
Expand Down

0 comments on commit bb7dfd1

Please sign in to comment.