diff --git a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
index d135d5c8bba..11e1274b603 100644
--- a/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
+++ b/api/src/main/java/ai/djl/nn/transformer/IdEmbedding.java
@@ -63,15 +63,23 @@ public Shape[] getOutputShapes(Shape[] inputShapes) {
     protected NDList forwardInternal(
             ParameterStore ps, NDList inputs, boolean training, PairList params) {
         NDArray input = inputs.singletonOrThrow();
-        // on info to the right shapes, see: http://beta.mxnet.io/r/api/mx.symbol.gather_nd.html
-        NDArray ids = input.flatten().reshape(1, input.getShape().size());
-        // create the embedding Table
-        NDArray embeddingTable = ps.getValue(embedding, ids.getDevice(), training);
-        // We do not perform a sparse lookup, instead we just project into the table
-        NDArray result = embeddingTable.gatherNd(ids);
-        // we want the original shape of the input + the last dimension of the embedding
-        Shape targetShape = input.getShape().addAll(new Shape(embeddingTable.getShape().get(1)));
-        return new NDList(result.reshape(targetShape));
+        Shape inputShape = input.getShape();
+        try (NDManager tempScope = NDManager.subManagerOf(input)) {
+            // flatten the ids and reshape them to (1, inputShape.size()) for gatherNd; for the
+            // shape rules see: http://beta.mxnet.io/r/api/mx.symbol.gather_nd.html
+            NDArray flatIds = input.flatten().reshape(1, inputShape.size());
+            // fetch the embedding table on the ids' device and temp-attach it to the local scope
+            NDArray table = ps.getValue(embedding, flatIds.getDevice(), training);
+            tempScope.tempAttachAll(table);
+            // no sparse lookup is performed: the ids are simply used to index into the table
+            NDArray gathered = table.gatherNd(flatIds);
+            // attach the gathered array to the manager of the inputs
+            gathered.attach(inputs.getManager());
+            // output shape = original input shape + dimension 1 of the embedding table
+            Shape embeddingDim = new Shape(table.getShape().get(1));
+            Shape targetShape = inputShape.addAll(embeddingDim);
+            return new NDList(gathered.reshape(targetShape));
+        }
     }
 
     /**