From c3ff4a5d54c5f6d3e8b694851d0825ebfb27a6e8 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Tue, 12 Apr 2022 10:16:18 -0700 Subject: [PATCH] Separate AbstractSymbolBlock from AbstractBlock (#1555) Right now, the AbstractSymbolBlock inherits from AbstractBlock. However, the AbstractBlock parameters system and getParameters imply that it is always possible to get the parameters for a block. While this should be true for a correctly written DJL block, it is not always true of symbol blocks. So, this change separates them to ensure that trying to get parameters when it is not possible throws an exception reflecting that the operation is unsupported. The shared functionality between AbstractSymbolBlock and AbstractBlock was moved to a common base class, AbstractBaseBlock. --- .../java/ai/djl/nn/AbstractBaseBlock.java | 417 ++++++++++++++++++ .../main/java/ai/djl/nn/AbstractBlock.java | 380 +--------------- .../java/ai/djl/nn/AbstractSymbolBlock.java | 8 +- .../ai/djl/dlr/engine/DlrSymbolBlock.java | 7 + .../ai/djl/ml/xgboost/XgbSymbolBlock.java | 7 + .../ai/djl/mxnet/engine/MxSymbolBlock.java | 12 +- .../onnxruntime/engine/OrtSymbolBlock.java | 7 + .../paddlepaddle/engine/PpSymbolBlock.java | 7 + .../ai/djl/pytorch/engine/PtSymbolBlock.java | 42 ++ .../djl/tensorflow/engine/TfSymbolBlock.java | 7 + .../djl/tensorrt/engine/TrtSymbolBlock.java | 7 + .../djl/tflite/engine/TfLiteSymbolBlock.java | 7 + .../java/ai/djl/fasttext/FtAbstractBlock.java | 7 + 13 files changed, 538 insertions(+), 377 deletions(-) create mode 100644 api/src/main/java/ai/djl/nn/AbstractBaseBlock.java diff --git a/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java b/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java new file mode 100644 index 00000000000..5c6a5503091 --- /dev/null +++ b/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java @@ -0,0 +1,417 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.nn; + +import ai.djl.MalformedModelException; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.training.ParameterStore; +import ai.djl.training.initializer.Initializer; +import ai.djl.util.Pair; +import ai.djl.util.PairList; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Predicate; + +/** + * This provides shared functionality for both the DJL-based {@link AbstractBlock}s and the imported + * {@link AbstractSymbolBlock}s. + */ +public abstract class AbstractBaseBlock implements Block { + + /** + * The model version of this block, used for checking if parameters are still valid during + * parameter loading. + */ + protected byte version; + + /** The shape of the input for this block, set by the initialization process. */ + protected Shape[] inputShapes; + + /** List of names for the input, named inputs should be manually set in sub class. */ + protected List inputNames = Collections.emptyList(); + + /** Constructs a new {@link AbstractBaseBlock} instance. */ + public AbstractBaseBlock() { + this((byte) 1); + } + + /** + * Builds an empty block with the given version for parameter serialization. + * + * @param version the version to use for parameter serialization. 
+ */ + public AbstractBaseBlock(byte version) { + this.version = version; + } + + /** {@inheritDoc} */ + @Override + public final NDList forward( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDManager paramsManager = parameterStore.getManager(); + if (training && !isInitialized()) { + initialize(paramsManager, DataType.FLOAT32, inputs.getShapes()); + } + return forwardInternal(parameterStore, inputs, training, params); + } + + /** {@inheritDoc} */ + @Override + public NDList forward( + ParameterStore parameterStore, + NDList data, + NDList labels, + PairList params) { + NDManager paramsManager = parameterStore.getManager(); + if (!isInitialized()) { + initialize(paramsManager, DataType.FLOAT32, data.getShapes()); + } + return forwardInternal(parameterStore, data, labels, params); + } + + /** + * A helper for {@link Block#forward(ParameterStore, NDList, boolean, PairList)} after + * initialization. + * + * @param parameterStore the parameter store + * @param inputs the input NDList + * @param training true for a training forward pass + * @param params optional parameters + * @return the output of the forward pass + */ + protected abstract NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params); + + /** + * A helper for {@link Block#forward(ParameterStore, NDList, NDList, PairList)} after + * initialization. 
+ * + * @param parameterStore the parameter store + * @param data the input data NDList + * @param labels the input labels NDList + * @param params optional parameters + * @return the output of the forward pass + * @see #forward(ParameterStore, NDList, boolean, PairList) + */ + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList data, + NDList labels, + PairList params) { + return forwardInternal(parameterStore, data, true, params); + } + + /** {@inheritDoc} */ + @Override + public PairList describeInput() { + if (!isInitialized()) { + throw new IllegalStateException( + "Parameter of this block are not initialised," + + "please call model.newTrainer and trainer.initialize"); + } + return new PairList<>(inputNames, Arrays.asList(inputShapes)); + } + + /** {@inheritDoc} */ + @Override + public void setInitializer(Initializer initializer, Parameter.Type params) { + Predicate predicate = parameter -> parameter.getType().equals(params); + setInitializer(initializer, predicate); + } + + /** {@inheritDoc} */ + @Override + public void setInitializer(Initializer initializer, String paramName) { + Parameter parameter = + getDirectParameters() + .values() + .stream() + .filter(p -> p.getName().equals(paramName)) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + "Could not find parameter " + paramName)); + parameter.setInitializer(initializer); + } + + /** {@inheritDoc} */ + @Override + public void setInitializer(Initializer initializer, Predicate predicate) { + List params = getParameters().values(); + for (Parameter param : params) { + if (predicate.test(param)) { + param.setInitializer(initializer); + } + } + } + + /** {@inheritDoc} */ + @Override + public void initialize(NDManager manager, DataType dataType, Shape... 
inputShapes) { + beforeInitialize(inputShapes); + // if parameters are initialized, skip it + if (!isInitialized()) { + // setShape for all params + prepare(inputShapes); + } + for (Parameter parameter : getDirectParameters().values()) { + parameter.initialize(manager, dataType); + } + initializeChildBlocks(manager, dataType, inputShapes); + } + + /** + * Performs any action necessary before initialization. For example, keep the input information + * or verify the layout. + * + * @param inputShapes the expected shapes of the input + */ + protected void beforeInitialize(Shape... inputShapes) { + if (inputNames.isEmpty()) { + // automatically assign input names + inputNames = new ArrayList<>(); + for (int i = 0; i < inputShapes.length; ++i) { + inputNames.add("data" + i); + } + } + this.inputShapes = inputShapes; + } + + /** + * Initializes the Child blocks of this block. You need to override this method if your subclass + * has child blocks. Used to determine the correct input shapes for child blocks based on the + * requested input shape for this block. + * + * @param manager the manager to use for initialization + * @param dataType the requested data type + * @param inputShapes the expected input shapes for this block + */ + protected void initializeChildBlocks( + NDManager manager, DataType dataType, Shape... inputShapes) { + if (!getChildren().isEmpty()) { + throw new IllegalStateException( + getClass().getSimpleName() + + " has child blocks but initializeChildBlocks is not overwritten."); + } + } + + /** + * Sets the shape of {@link Parameter}s. 
+ * + * @param inputShapes the shapes of inputs + */ + protected void prepare(Shape[] inputShapes) {} + + /** {@inheritDoc} */ + @Override + public ParameterList getParameters() { + // we accumulate a list of all parameters by starting with a list of the direct parameters + ParameterList allParams = getDirectParameters(); + // then we add the parameters of child blocks + for (Pair childPair : getChildren()) { + for (Pair paramPair : childPair.getValue().getParameters()) { + // we prepend the name of the child block to the parameter name + allParams.add(childPair.getKey() + "_" + paramPair.getKey(), paramPair.getValue()); + } + } + return allParams; + } + + /** {@inheritDoc} */ + @Override + public boolean isInitialized() { + if (inputShapes == null) { + return false; + } + for (Parameter param : getParameters().values()) { + if (!param.isInitialized()) { + return false; + } + } + return true; + } + + /** {@inheritDoc} */ + @Override + public void clear() { + getParameters().forEach(param -> param.getValue().close()); + } + + /** {@inheritDoc} */ + @Override + public void cast(DataType dataType) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public void saveParameters(DataOutputStream os) throws IOException { + os.write(version); + saveMetadata(os); + for (Parameter parameter : getDirectParameters().values()) { + parameter.save(os); + } + for (Block child : getChildren().values()) { + child.saveParameters(os); + } + } + + /** {@inheritDoc} */ + @Override + public void loadParameters(NDManager manager, DataInputStream is) + throws IOException, MalformedModelException { + byte loadVersion = is.readByte(); + loadMetadata(loadVersion, is); + for (Parameter parameter : getDirectParameters().values()) { + parameter.load(manager, is); + } + for (Block child : getChildren().values()) { + child.loadParameters(manager, is); + } + } + + /** + * Override this method to save additional data apart from parameter 
values. + * + *

This default implementation saves the currently set input shapes. + * + * @param os the non-null output stream the parameter values and metadata are written to + * @throws IOException saving failed + */ + protected void saveMetadata(DataOutputStream os) throws IOException { + saveInputShapes(os); + } + + /** + * Override this to load additional metadata with the parameter values. + * + *

If you override {@link AbstractBlock#saveMetadata(DataOutputStream)} or need to provide + * backward compatibility to older binary formats, you probably need to override this. This + * default implementation checks if the version number fits, if not it throws an {@link + * MalformedModelException}. After that it restores the input shapes. + * + * @param loadVersion the version used for loading this metadata. + * @param is the input stream we are loading from + * @throws IOException loading failed + * @throws MalformedModelException data can be loaded but has wrong format + */ + protected void loadMetadata(byte loadVersion, DataInputStream is) + throws IOException, MalformedModelException { + if (loadVersion != version) { + throw new MalformedModelException( + "Cannot load parameters for " + + this.getClass().getCanonicalName() + + ", expected version " + + version + + ", got " + + loadVersion + + "."); + } + readInputShapes(is); + } + + protected void saveInputShapes(DataOutputStream os) throws IOException { + os.writeInt(inputShapes.length); + for (Shape shape : inputShapes) { + os.write(shape.getEncoded()); + } + } + + protected void readInputShapes(DataInputStream is) throws IOException { + int len = is.readInt(); + Shape[] shapes = new Shape[len]; + for (int i = 0; i < len; ++i) { + shapes[i] = Shape.decode(is); + } + if (inputShapes == null) { + // load inputShapes from parameter file if Block has not been initialized + inputShapes = shapes; + } + } + + /** {@inheritDoc} */ + @Override + public String toString() { + // FIXME: This is a quick hack for display in jupyter notebook.
+ StringBuilder sb = new StringBuilder(200); + String className = getClass().getSimpleName(); + if (className.endsWith("Block")) { + className = className.substring(0, className.length() - 5); + } + sb.append(className).append('('); + if (isInitialized()) { + PairList inputShapeDescription = describeInput(); + appendShape(sb, inputShapeDescription.values().toArray(new Shape[0])); + sb.append(" -> "); + Shape[] outputShapes = + getOutputShapes(inputShapeDescription.values().toArray(new Shape[0])); + appendShape(sb, outputShapes); + } else { + sb.append("Uninitialized"); + } + sb.append(')'); + return sb.toString(); + } + + private void appendShape(StringBuilder sb, Shape[] shapes) { + boolean first = true; + for (Shape shape : shapes) { + if (first) { + first = false; + } else { + sb.append(", "); + } + long[] sh = shape.getShape(); + int length = sh.length; + if (length == 0) { + sb.append("()"); + } else { + int index = 0; + if (sh[0] == -1) { + --length; + index = 1; + } + + if (length == 0) { + sb.append("()"); + } else if (length == 1) { + sb.append(sh[index]); + } else { + sb.append('('); + for (int i = index; i < sh.length; ++i) { + if (i > index) { + sb.append(", "); + } + sb.append(sh[i]); + } + sb.append(')'); + } + } + } + } +} diff --git a/api/src/main/java/ai/djl/nn/AbstractBlock.java b/api/src/main/java/ai/djl/nn/AbstractBlock.java index 4bf3e7ae010..e3cc6cd6adc 100644 --- a/api/src/main/java/ai/djl/nn/AbstractBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractBlock.java @@ -12,25 +12,16 @@ */ package ai.djl.nn; -import ai.djl.MalformedModelException; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.training.ParameterStore; -import ai.djl.training.initializer.Initializer; import ai.djl.util.Pair; import ai.djl.util.PairList; -import java.io.DataInputStream; import java.io.DataOutputStream; -import java.io.IOException; -import java.util.ArrayList; 
-import java.util.Arrays; -import java.util.Collections; import java.util.LinkedHashMap; -import java.util.List; import java.util.Locale; -import java.util.function.Predicate; /** * {@code AbstractBlock} is an abstract implementation of {@link Block}. @@ -55,8 +46,8 @@ * implement the computation of your block *

  • IFF you need to save data apart from the parameter values of your block, you need to * override {@link AbstractBlock#saveMetadata(DataOutputStream)} and {@link - * AbstractBlock#loadMetadata(byte, DataInputStream)}. If you do not need to save or load any - * state other than parameters in your block, you can skip this. + * AbstractBlock#loadMetadata(byte, java.io.DataInputStream)}. If you do not need to save or + * load any state other than parameters in your block, you can skip this. * * *

    If you use {@link AbstractBlock#addParameter(Parameter)} to add parameters, you have to take @@ -68,19 +59,7 @@ // of this API know the children and parameters are always iterated over in insertion order. // LinkedHashMap provides this guarantee, Map does not. @SuppressWarnings("PMD.LooseCoupling") -public abstract class AbstractBlock implements Block { - - /** The shape of the input for this block, set by the initialization process. */ - protected Shape[] inputShapes; - - /** List of names for the input, named inputs should be manually set in sub class. */ - protected List inputNames = Collections.emptyList(); - - /** - * The model version of this block, used for checking if parameters are still valid during - * parameter loading. - */ - protected byte version; +public abstract class AbstractBlock extends AbstractBaseBlock { /** * All direct children of this Block. Keys are names of the blocks. @@ -99,9 +78,7 @@ public abstract class AbstractBlock implements Block { protected LinkedHashMap parameters = new LinkedHashMap<>(); /** Constructs a new {@code AbstractBlock} instance. */ - public AbstractBlock() { - this((byte) 1); - } + public AbstractBlock() {} /** * Builds an empty block with the given version for parameter serialization. @@ -109,70 +86,7 @@ public AbstractBlock() { * @param version the version to use for parameter serialization. 
*/ public AbstractBlock(byte version) { - this.version = version; - } - - /** {@inheritDoc} */ - @Override - public final NDList forward( - ParameterStore parameterStore, - NDList inputs, - boolean training, - PairList params) { - NDManager paramsManager = parameterStore.getManager(); - if (training && !isInitialized()) { - initialize(paramsManager, DataType.FLOAT32, inputs.getShapes()); - } - return forwardInternal(parameterStore, inputs, training, params); - } - - /** {@inheritDoc} */ - @Override - public NDList forward( - ParameterStore parameterStore, - NDList data, - NDList labels, - PairList params) { - NDManager paramsManager = parameterStore.getManager(); - if (!isInitialized()) { - initialize(paramsManager, DataType.FLOAT32, data.getShapes()); - } - return forwardInternal(parameterStore, data, labels, params); - } - - /** - * A helper for {@link Block#forward(ParameterStore, NDList, boolean, PairList)} after - * initialization. - * - * @param parameterStore the parameter store - * @param inputs the input NDList - * @param training true for a training forward pass - * @param params optional parameters - * @return the output of the forward pass - */ - protected abstract NDList forwardInternal( - ParameterStore parameterStore, - NDList inputs, - boolean training, - PairList params); - - /** - * A helper for {@link Block#forward(ParameterStore, NDList, NDList, PairList)} after - * initialization. 
- * - * @param parameterStore the parameter store - * @param data the input data NDList - * @param labels the input labels NDList - * @param params optional parameters - * @return the output of the forward pass - * @see #forward(ParameterStore, NDList, boolean, PairList) - */ - protected NDList forwardInternal( - ParameterStore parameterStore, - NDList data, - NDList labels, - PairList params) { - return forwardInternal(parameterStore, data, true, params); + super(version); } /** @@ -215,293 +129,9 @@ public BlockList getChildren() { return defensiveCopy; } - /** {@inheritDoc} */ - @Override - public PairList describeInput() { - if (!isInitialized()) { - throw new IllegalStateException( - "Parameter of this block are not initialised," - + "please call model.newTrainer and trainer.initialize"); - } - return new PairList<>(inputNames, Arrays.asList(inputShapes)); - } - - /** {@inheritDoc} */ - @Override - public void setInitializer(Initializer initializer, Parameter.Type params) { - Predicate predicate = parameter -> parameter.getType().equals(params); - setInitializer(initializer, predicate); - } - - /** {@inheritDoc} */ - @Override - public void setInitializer(Initializer initializer, String paramName) { - Parameter parameter = parameters.get(paramName); - if (parameter == null) { - throw new IllegalArgumentException("Could not find parameter " + paramName); - } - parameter.setInitializer(initializer); - } - - /** {@inheritDoc} */ - @Override - public void setInitializer(Initializer initializer, Predicate predicate) { - List params = getParameters().values(); - for (Parameter param : params) { - if (predicate.test(param)) { - param.setInitializer(initializer); - } - } - } - - /** {@inheritDoc} */ - @Override - public void initialize(NDManager manager, DataType dataType, Shape... 
inputShapes) { - beforeInitialize(inputShapes); - // if parameters are initialized, skip it - if (!isInitialized()) { - // setShape for all params - prepare(inputShapes); - } - for (Parameter parameter : parameters.values()) { - parameter.initialize(manager, dataType); - } - initializeChildBlocks(manager, dataType, inputShapes); - } - - /** - * Performs any action necessary before initialization. For example, keep the input information - * or verify the layout. - * - * @param inputShapes the expected shapes of the input - */ - protected void beforeInitialize(Shape... inputShapes) { - if (inputNames.isEmpty()) { - // automatically assign input names - inputNames = new ArrayList<>(); - for (int i = 0; i < inputShapes.length; ++i) { - inputNames.add("data" + i); - } - } - this.inputShapes = inputShapes; - } - - /** - * Initializes the Child blocks of this block. You need to override this method if your subclass - * has child blocks. Used to determine the correct input shapes for child blocks based on the - * requested input shape for this block. - * - * @param manager the manager to use for initialization - * @param dataType the requested data type - * @param inputShapes the expected input shapes for this block - */ - protected void initializeChildBlocks( - NDManager manager, DataType dataType, Shape... 
inputShapes) { - if (!children.isEmpty()) { - throw new IllegalStateException( - getClass().getSimpleName() - + " has child blocks but initializeChildBlocks is not overwritten."); - } - } - - /** {@inheritDoc} */ - @Override - public ParameterList getParameters() { - // we accumulate a list of all parameters by starting with a list of the direct parameters - ParameterList allParams = getDirectParameters(); - // then we add the parameters of child blocks - for (Pair childPair : getChildren()) { - for (Pair paramPair : childPair.getValue().getParameters()) { - // we prepend the name of the child block to the parameter name - allParams.add(childPair.getKey() + "_" + paramPair.getKey(), paramPair.getValue()); - } - } - return allParams; - } - /** {@inheritDoc} */ @Override public ParameterList getDirectParameters() { return new ParameterList(parameters); } - - /** - * Sets the shape of {@link Parameter}s. - * - * @param inputShapes the shapes of inputs - */ - protected void prepare(Shape[] inputShapes) {} - - /** {@inheritDoc} */ - @Override - public boolean isInitialized() { - if (inputShapes == null) { - return false; - } - for (Parameter param : getParameters().values()) { - if (!param.isInitialized()) { - return false; - } - } - return true; - } - - /** {@inheritDoc} */ - @Override - public void clear() { - getParameters().forEach(param -> param.getValue().close()); - } - - /** {@inheritDoc} */ - @Override - public void cast(DataType dataType) { - throw new UnsupportedOperationException("Not implemented yet."); - } - - /** {@inheritDoc} */ - @Override - public void saveParameters(DataOutputStream os) throws IOException { - os.write(version); - saveMetadata(os); - for (Parameter parameter : parameters.values()) { - parameter.save(os); - } - for (Block child : children.values()) { - child.saveParameters(os); - } - } - - /** {@inheritDoc} */ - @Override - public void loadParameters(NDManager manager, DataInputStream is) - throws IOException, MalformedModelException { 
- byte loadVersion = is.readByte(); - loadMetadata(loadVersion, is); - for (Parameter parameter : parameters.values()) { - parameter.load(manager, is); - } - for (Block child : children.values()) { - child.loadParameters(manager, is); - } - } - - /** - * Override this method to save additional data apart from parameter values. - * - *

    This default implementation saves the currently set input shapes. - * - * @param os the non-null output stream the parameter values and metadata are written to - * @throws IOException saving failed - */ - protected void saveMetadata(DataOutputStream os) throws IOException { - saveInputShapes(os); - } - - /** - * Overwrite this to load additional metadata with the parameter values. - * - *

    If you overwrite {@link AbstractBlock#saveMetadata(DataOutputStream)} or need to provide - * backward compatibility to older binary formats, you prabably need to overwrite this. This - * default implementation checks if the version number fits, if not it throws an {@link - * MalformedModelException}. After that it restores the input shapes. - * - * @param loadVersion the version used for loading this metadata. - * @param is the input stream we are loading from - * @throws IOException loading failed - * @throws MalformedModelException data can be loaded but has wrong format - */ - protected void loadMetadata(byte loadVersion, DataInputStream is) - throws IOException, MalformedModelException { - if (loadVersion != version) { - throw new MalformedModelException( - "Cannot load parameters for " - + this.getClass().getCanonicalName() - + ", expected version " - + version - + ", got " - + loadVersion - + "."); - } - readInputShapes(is); - } - - protected void saveInputShapes(DataOutputStream os) throws IOException { - os.writeInt(inputShapes.length); - for (Shape shape : inputShapes) { - os.write(shape.getEncoded()); - } - } - - protected void readInputShapes(DataInputStream is) throws IOException { - int len = is.readInt(); - Shape[] shapes = new Shape[len]; - for (int i = 0; i < len; ++i) { - shapes[i] = Shape.decode(is); - } - if (inputShapes == null) { - // load inputShapes from parameter file if Block has not been initialized - inputShapes = shapes; - } - } - - /** {@inheritDoc} */ - @Override - public String toString() { - // FIXME: This is a quick hack for display in jupyter notebook. 
- StringBuilder sb = new StringBuilder(200); - String className = getClass().getSimpleName(); - if (className.endsWith("Block")) { - className = className.substring(0, className.length() - 5); - } - sb.append(className).append('('); - if (isInitialized()) { - PairList inputShapeDescription = describeInput(); - appendShape(sb, inputShapeDescription.values().toArray(new Shape[0])); - sb.append(" -> "); - Shape[] outputShapes = - getOutputShapes(inputShapeDescription.values().toArray(new Shape[0])); - appendShape(sb, outputShapes); - } else { - sb.append("Uninitialized"); - } - sb.append(')'); - return sb.toString(); - } - - private void appendShape(StringBuilder sb, Shape[] shapes) { - boolean first = true; - for (Shape shape : shapes) { - if (first) { - first = false; - } else { - sb.append(", "); - } - long[] sh = shape.getShape(); - int length = sh.length; - if (length == 0) { - sb.append("()"); - } else { - int index = 0; - if (sh[0] == -1) { - --length; - index = 1; - } - - if (length == 0) { - sb.append("()"); - } else if (length == 1) { - sb.append(sh[index]); - } else { - sb.append('('); - for (int i = index; i < sh.length; ++i) { - if (i > index) { - sb.append(", "); - } - sb.append(sh[i]); - } - sb.append(')'); - } - } - } - } } diff --git a/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java b/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java index efcb4622276..9ffbd3f0083 100644 --- a/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java @@ -15,7 +15,7 @@ import ai.djl.ndarray.types.Shape; /** {@code AbstractSymbolBlock} is an abstract implementation of {@link SymbolBlock}. */ -public abstract class AbstractSymbolBlock extends AbstractBlock implements SymbolBlock { +public abstract class AbstractSymbolBlock extends AbstractBaseBlock implements SymbolBlock { /** Constructs a new {@code AbstractSymbolBlock} instance. 
*/ public AbstractSymbolBlock() {} @@ -34,4 +34,10 @@ public AbstractSymbolBlock(byte version) { public Shape[] getOutputShapes(Shape[] inputShapes) { throw new UnsupportedOperationException("not implement!"); } + + /** {@inheritDoc} */ + @Override + public BlockList getChildren() { + return new BlockList(); + } } diff --git a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrSymbolBlock.java b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrSymbolBlock.java index d46d87282a3..0e7dd93f59f 100644 --- a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrSymbolBlock.java +++ b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrSymbolBlock.java @@ -16,6 +16,7 @@ import ai.djl.dlr.jni.JniUtils; import ai.djl.ndarray.NDList; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -75,4 +76,10 @@ public void close() { JniUtils.deleteDlrModel(pointer); } } + + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java index 860526ef0d3..8f0d1b51c39 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -105,6 +106,12 @@ void setTreeLimit(int treeLimit) { this.treeLimit = treeLimit; } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet 
supported"); + } + /** The mode of inference for OptionMask. */ public enum Mode { DEFAULT(0), diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java index 4d6c5be9be0..27959169c68 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java @@ -21,6 +21,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; import ai.djl.nn.Parameter; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -31,6 +32,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -53,6 +55,7 @@ public class MxSymbolBlock extends AbstractSymbolBlock { private CachedOp op; private Symbol symbol; private List mxNetParams; // includes input data + private Map parameters; private Map paramShapes; private Shape[] outputShapes; private PairList inputDescriptions; @@ -94,9 +97,10 @@ public void setInputNames(List inputNames) { // now that we know which of the parameters are just input placeholders and which // are trainable, add them properly so they are correctly handled Set nameLookup = new HashSet<>(inputNames); + parameters = new LinkedHashMap<>(mxNetParams.size()); for (Parameter mxNetParameter : mxNetParams) { if (!nameLookup.contains(mxNetParameter.getName())) { - addParameter(mxNetParameter); + parameters.put(mxNetParameter.getName(), mxNetParameter); } } } @@ -156,6 +160,12 @@ public PairList describeInput() { return inputDescriptions; } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + return new ParameterList(parameters); + } + /** {@inheritDoc} */ @Override public PairList describeOutput() { 
diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java index 3e2b9d76fc4..b9dad479b86 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -174,4 +175,10 @@ public void close() { } } } + + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } } diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java index 9fb7b5264af..cb75ce6d7cf 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.paddlepaddle.jni.JniUtils; import ai.djl.training.ParameterStore; @@ -73,6 +74,12 @@ private PpNDArray[] getInputs(PpNDManager sub, NDList inputs) { return inputArray; } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } + /** {@inheritDoc} */ @Override public Shape[] getOutputShapes(Shape[] inputShapes) 
{ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java index 0bde372bdf1..89463bef0ea 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java @@ -18,6 +18,8 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.Parameter; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.pytorch.jni.IValue; import ai.djl.pytorch.jni.IValueUtils; @@ -27,6 +29,8 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,6 +53,7 @@ public class PtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable private PairList inputDescriptions; private PairList outputDescriptions; private boolean first; + private Map parameters; /** * Constructs a {@code PtSymbolBlock}. 
@@ -146,6 +151,43 @@ public PairList describeInput() { return inputDescriptions; } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + if (parameters == null) { + NDList params = JniUtils.moduleGetParams(this, manager); + parameters = new LinkedHashMap<>(params.size()); + for (NDArray param : params) { + parameters.put( + param.getName(), + Parameter.builder() + .setName(param.getName()) + .setType(inferType(param.getName())) + .optArray(param) + .build()); + } + } + // Defensive copy + return new ParameterList(parameters); + } + + private static Parameter.Type inferType(String name) { + if (name.contains("bias")) { + return Parameter.Type.BIAS; + } else if (name.contains("gamma")) { + return Parameter.Type.GAMMA; + } else if (name.contains("beta")) { + return Parameter.Type.BETA; + } else if (name.contains("moving_mean") || name.contains("running_mean")) { + return Parameter.Type.RUNNING_MEAN; + } else if (name.contains("moving_var") || name.contains("running_var")) { + return Parameter.Type.RUNNING_VAR; + } else if (name.contains("weight")) { + return Parameter.Type.WEIGHT; + } + return Parameter.Type.OTHER; + } + /** {@inheritDoc} */ @Override public PairList describeOutput() { diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java index 5213d2fed3a..f0819808a9c 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.tensorflow.engine.javacpp.JavacppUtils; import ai.djl.training.ParameterStore; @@ -192,6 +193,12 @@ public final 
PairList describeInput() { return inputDescriptions; } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } + /** {@inheritDoc} */ @Override public final PairList describeOutput() { diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtSymbolBlock.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtSymbolBlock.java index c032515d61d..232a7fb8dbb 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtSymbolBlock.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtSymbolBlock.java @@ -15,6 +15,7 @@ import ai.djl.ndarray.NDList; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.tensorrt.jni.JniUtils; import ai.djl.training.ParameterStore; @@ -66,4 +67,10 @@ TrtSession createSession(TrtNDManager manager) { long session = JniUtils.createSession(handle.get()); return new TrtSession(manager, handle.get(), session); } + + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } } diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteSymbolBlock.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteSymbolBlock.java index 942bdb414c7..360801d72f8 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteSymbolBlock.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteSymbolBlock.java @@ -16,6 +16,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -65,4 +66,10 @@ protected NDList forwardInternal( public void close() { interpreter.close(); } + + /** {@inheritDoc} */ + @Override + 
public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } } diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java index c7ffec04a95..e450307a0b2 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java @@ -14,6 +14,7 @@ import ai.djl.fasttext.jni.FtWrapper; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import java.nio.file.Path; /** @@ -55,6 +56,12 @@ public float[] embedWord(String word) { return fta.getWordVector(word); } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } + @Override public void close() { fta.unloadModel();