diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java
new file mode 100644
index 00000000000..eb7d3e699e0
--- /dev/null
+++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.util.passthrough;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDArrayAdapter;
+import java.nio.ByteBuffer;
+
+/**
+ * An {@link NDArray} that stores an arbitrary Java object.
+ *
+ * <p>This class is mainly for use in extensions and hybrid engines. Despite its name, it will
+ * often not contain actual {@link NDArray}s but just any object necessary to conform to the DJL
+ * predictor API.
+ */
+public class PassthroughNDArray extends NDArrayAdapter {
+
+    /** Message for array operations that a pure object wrapper cannot provide. */
+    private static final String UNSUPPORTED = "Operation not supported by PassthroughNDArray";
+
+    /** The wrapped payload; no check is made on it, so it may be {@code null}. */
+    private final Object object;
+
+    /**
+     * Constructs a {@link PassthroughNDArray} storing an object.
+     *
+     * @param object the object to store
+     */
+    public PassthroughNDArray(Object object) {
+        // Every argument of the adapter constructor is left null: this instance only carries
+        // the wrapped object.
+        super(null, null, null, null, null);
+        this.object = object;
+    }
+
+    /**
+     * Returns the object stored.
+     *
+     * @return the object stored
+     */
+    public Object getObject() {
+        return object;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * @throws UnsupportedOperationException always, the wrapped object has no buffer form
+     */
+    @Override
+    public ByteBuffer toByteBuffer() {
+        throw new UnsupportedOperationException(UNSUPPORTED);
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * @throws UnsupportedOperationException always
+     */
+    @Override
+    public void intern(NDArray replaced) {
+        throw new UnsupportedOperationException(UNSUPPORTED);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public void detach() {
+        // Intentionally a no-op.
+    }
+}
diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java
new file mode 100644
index 00000000000..493b2fa0e9a
--- /dev/null
+++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java
@@ -0,0 +1,209 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.util.passthrough;
+
+import ai.djl.Device;
+import ai.djl.engine.Engine;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.NDResource;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.util.PairList;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.charset.Charset;
+import java.nio.file.Path;
+
+/** An {@link NDManager} that does nothing, for use in extensions and hybrid engines. */
+public final class PassthroughNDManager implements NDManager {
+
+ private static final String UNSUPPORTED = "Not supported by PassthroughNDManager";
+ public static final PassthroughNDManager INSTANCE = new PassthroughNDManager();
+
+ private PassthroughNDManager() {}
+
+ @Override
+ public Device defaultDevice() {
+ return Device.cpu();
+ }
+
+ @Override
+ public ByteBuffer allocateDirect(int capacity) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray from(NDArray array) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray create(String[] data, Charset charset, Shape shape) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray create(Shape shape, DataType dataType) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray createCoo(Buffer data, long[][] indices, Shape shape) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDList load(Path path) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public void setName(String name) {}
+
+ @Override
+ public String getName() {
+ return "PassthroughNDManager";
+ }
+
+ @Override
+ public NDArray zeros(Shape shape, DataType dataType) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray ones(Shape shape, DataType dataType) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray full(Shape shape, float value, DataType dataType) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray arange(float start, float stop, float step, DataType dataType) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray eye(int rows, int cols, int k, DataType dataType) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray linspace(float start, float stop, int num, boolean endpoint) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray randomInteger(long low, long high, Shape shape, DataType dataType) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataType) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray randomMultinomial(int n, NDArray pValues) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDArray randomMultinomial(int n, NDArray pValues, Shape shape) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public boolean isOpen() {
+ return true;
+ }
+
+ @Override
+ public NDManager getParentManager() {
+ return this;
+ }
+
+ @Override
+ public NDManager newSubManager() {
+ return this;
+ }
+
+ @Override
+ public NDManager newSubManager(Device device) {
+ return this;
+ }
+
+ @Override
+ public Device getDevice() {
+ return Device.cpu();
+ }
+
+ @Override
+ public void attachInternal(String resourceId, AutoCloseable resource) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public void tempAttachInternal(
+ NDManager originalManager, String resourceId, NDResource resource) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public void detachInternal(String resourceId) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public void invoke(
+ String operation, NDArray[] src, NDArray[] dest, PairList params) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public NDList invoke(String operation, NDList src, PairList params) {
+ throw new UnsupportedOperationException(UNSUPPORTED);
+ }
+
+ @Override
+ public Engine getEngine() {
+ return null;
+ }
+
+ @Override
+ public void close() {}
+}
diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughTranslator.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughTranslator.java
new file mode 100644
index 00000000000..55f6692a2b9
--- /dev/null
+++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughTranslator.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.util.passthrough;
+
+import ai.djl.ndarray.NDList;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslatorContext;
+
+/**
+ * A translator that stores and removes data from a {@link PassthroughNDArray}.
+ *
+ * <p>The input object is wrapped unchanged into a single-element {@link NDList}, and the output is
+ * whatever object the single returned {@link PassthroughNDArray} carries.
+ *
+ * @param <I> translator input type
+ * @param <O> translator output type
+ */
+public class PassthroughTranslator<I, O> implements NoBatchifyTranslator<I, O> {
+
+    /** {@inheritDoc} */
+    @Override
+    public NDList processInput(TranslatorContext ctx, I input) throws Exception {
+        return new NDList(new PassthroughNDArray(input));
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * @throws ClassCastException if the single element of {@code list} is not a {@link
+     *     PassthroughNDArray}
+     */
+    @Override
+    @SuppressWarnings("unchecked")
+    public O processOutput(TranslatorContext ctx, NDList list) {
+        PassthroughNDArray wrapper = (PassthroughNDArray) list.singletonOrThrow();
+        // Unchecked by necessity: the wrapper stores a plain Object, so a payload of the wrong
+        // type only surfaces at the caller's use site.
+        return (O) wrapper.getObject();
+    }
+}
diff --git a/api/src/main/java/ai/djl/util/passthrough/package-info.java b/api/src/main/java/ai/djl/util/passthrough/package-info.java
new file mode 100644
index 00000000000..62a0fd37ce9
--- /dev/null
+++ b/api/src/main/java/ai/djl/util/passthrough/package-info.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
+/** Contains passthrough DJL classes for use in extensions and hybrid engines. */
+package ai.djl.util.passthrough;
diff --git a/extensions/fasttext/README.md b/extensions/fasttext/README.md
index 64f582f8b07..f39e64b5730 100644
--- a/extensions/fasttext/README.md
+++ b/extensions/fasttext/README.md
@@ -5,8 +5,10 @@
This module contains the NLP support with fastText implementation.
fastText module's implementation in DJL is not considered as an Engine, it doesn't support Trainer and Predictor.
-The training and inference functionality is directly provided through [FtModel](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/FtModel.html)
-class. You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java).
+Training is only supported by using [TrainFastText](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/TrainFastText.html).
+This produces a special block which can perform inference on its own or by using a model and predictor.
+Pre-trained FastText models can also be loaded by using the standard DJL criteria.
+You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java).
Current implementation has the following limitations:
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java
new file mode 100644
index 00000000000..c7ffec04a95
--- /dev/null
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.fasttext;
+
+import ai.djl.fasttext.jni.FtWrapper;
+import ai.djl.nn.AbstractSymbolBlock;
+import java.nio.file.Path;
+
+/**
+ * A parent class containing shared behavior for {@link ai.djl.nn.SymbolBlock}s based on fasttext
+ * models.
+ */
+public abstract class FtAbstractBlock extends AbstractSymbolBlock implements AutoCloseable {
+
+ protected FtWrapper fta;
+
+ protected Path modelFile;
+
+ /**
+ * Constructs a {@link FtAbstractBlock}.
+ *
+ * @param fta the {@link FtWrapper} containing the "fasttext model"
+ */
+ public FtAbstractBlock(FtWrapper fta) {
+ this.fta = fta;
+ }
+
+ /**
+ * Returns the fasttext model file for the block.
+ *
+ * @return the fasttext model file for the block
+ */
+ public Path getModelFile() {
+ return modelFile;
+ }
+
+ /**
+ * Embeds a word using fasttext.
+ *
+ * @param word the word to embed
+ * @return the embedding
+ * @see ai.djl.modality.nlp.embedding.WordEmbedding
+ */
+ public float[] embedWord(String word) {
+ return fta.getWordVector(word);
+ }
+
+ @Override
+ public void close() {
+ fta.unloadModel();
+ fta.close();
+ }
+}
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java
index fed7a38b4d8..d7b4451b739 100644
--- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java
@@ -15,19 +15,19 @@
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
-import ai.djl.basicdataset.RawDataset;
import ai.djl.fasttext.jni.FtWrapper;
+import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification;
+import ai.djl.fasttext.zoo.nlp.word_embedding.FtWordEmbeddingBlock;
import ai.djl.inference.Predictor;
-import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
-import ai.djl.training.TrainingResult;
import ai.djl.translate.Translator;
import ai.djl.util.PairList;
+import ai.djl.util.passthrough.PassthroughNDManager;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
@@ -41,11 +41,12 @@
/**
* {@code FtModel} is the fastText implementation of {@link Model}.
*
- * FtModel contains all the methods in Model to load and process a model.
+ * <p>FtModel contains all the methods in Model to load and process a model. However, it only
+ * supports training by using {@link TrainFastText}.
*/
public class FtModel implements Model {
- FtWrapper fta;
+ FtAbstractBlock block;
private Path modelDir;
private String modelName;
@@ -58,7 +59,6 @@ public class FtModel implements Model {
*/
public FtModel(String name) {
this.modelName = name;
- fta = FtWrapper.newInstance();
properties = new ConcurrentHashMap<>();
}
@@ -80,6 +80,7 @@ public void load(Path modelPath, String prefix, Map options)
}
String modelFilePath = modelFile.toString();
+ FtWrapper fta = FtWrapper.newInstance();
if (!fta.checkModel(modelFilePath)) {
throw new MalformedModelException("Malformed FastText model file:" + modelFilePath);
}
@@ -90,7 +91,21 @@ public void load(Path modelPath, String prefix, Map options)
properties.put(entry.getKey(), entry.getValue().toString());
}
}
- properties.put("model-type", fta.getModelType());
+ String modelType = fta.getModelType();
+ properties.put("model-type", modelType);
+
+ if ("sup".equals(modelType)) {
+ String labelPrefix =
+ properties.getOrDefault(
+ "label-prefix", FtTextClassification.DEFAULT_LABEL_PREFIX);
+ block = new FtTextClassification(fta, labelPrefix);
+ modelDir = block.getModelFile();
+ } else if ("cbow".equals(modelType) || "sg".equals(modelType)) {
+ block = new FtWordEmbeddingBlock(fta);
+ modelDir = block.getModelFile();
+ } else {
+ throw new MalformedModelException("Unexpected FastText model type: " + modelType);
+ }
}
/** {@inheritDoc} */
@@ -130,49 +145,6 @@ private Path findModelFile(String prefix) {
return modelFile;
}
- /**
- * Returns top K number of classifications of the input text.
- *
- * @param text the input text to be classified
- * @param topK the value of K
- * @return classifications of the input text
- */
- public Classifications classify(String text, int topK) {
- String labelPrefix = properties.getOrDefault("label-prefix", "__label__");
- return fta.predictProba(text, topK, labelPrefix);
- }
-
- /**
- * Train the fastText model.
- *
- * @param config the training configuration to use
- * @param dataset the training dataset
- * @return the result of the training
- * @throws IOException when IO operation fails in loading a resource
- */
- public TrainingResult fit(FtTrainingConfig config, RawDataset dataset)
- throws IOException {
- Path outputDir = config.getOutputDir();
- if (Files.notExists(outputDir)) {
- Files.createDirectory(outputDir);
- }
- String fitModelName = config.getModelName();
- Path modelFile = outputDir.resolve(fitModelName).toAbsolutePath();
-
- String[] args = config.toCommand(dataset.getData().toString());
-
- fta.runCmd(args);
- setModelFile(modelFile);
-
- TrainingResult result = new TrainingResult();
- int epoch = config.getEpoch();
- if (epoch <= 0) {
- epoch = 5;
- }
- result.setEpoch(epoch);
- return result;
- }
-
/** {@inheritDoc} */
@Override
public void save(Path modelDir, String newModelName) {}
@@ -185,14 +157,17 @@ public Path getModelPath() {
/** {@inheritDoc} */
@Override
- public Block getBlock() {
- throw new UnsupportedOperationException("Fasttext doesn't support Block.");
+ public FtAbstractBlock getBlock() {
+ return block;
}
/** {@inheritDoc} */
@Override
public void setBlock(Block block) {
- throw new UnsupportedOperationException("Fasttext doesn't support setting the Block.");
+ if (!(block instanceof FtAbstractBlock)) {
+ throw new IllegalArgumentException("Expected a FtAbstractBlock Block");
+ }
+ this.block = (FtAbstractBlock) block;
}
/** {@inheritDoc} */
@@ -205,7 +180,7 @@ public String getName() {
@Override
public Trainer newTrainer(TrainingConfig trainingConfig) {
throw new UnsupportedOperationException(
- "FastText only supports training using FtModel.fit");
+ "FastText only supports training using the FtAbstractBlocks");
}
/** {@inheritDoc} */
@@ -263,7 +238,7 @@ public InputStream getArtifactAsStream(String name) {
/** {@inheritDoc} */
@Override
public NDManager getNDManager() {
- return null;
+ return PassthroughNDManager.INSTANCE;
}
/** {@inheritDoc} */
@@ -278,15 +253,10 @@ public String getProperty(String key) {
return properties.get(key);
}
- void setModelFile(Path modelFile) {
- this.modelDir = modelFile;
- }
-
/** {@inheritDoc} */
@Override
public void close() {
- fta.unloadModel();
- fta.close();
+ block.close();
}
/** {@inheritDoc} */
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java
new file mode 100644
index 00000000000..6250063932a
--- /dev/null
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.fasttext;
+
+import ai.djl.basicdataset.RawDataset;
+import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification;
+import java.io.IOException;
+import java.nio.file.Path;
+
+/** A utility to aggregate options for training with fasttext. */
+public final class TrainFastText {
+
+ // Not instantiable: the class only exposes static entry points.
+ private TrainFastText() {}
+
+ /**
+ * Trains a fastText {@link ai.djl.Application.NLP#TEXT_CLASSIFICATION} model.
+ *
+ * <p>This simply delegates to {@link FtTextClassification#fit(FtTrainingConfig, RawDataset)}.
+ *
+ * @param config the training configuration to use
+ * @param dataset the training dataset
+ * @return the trained {@link FtTextClassification} block
+ * @throws IOException when IO operation fails in loading a resource
+ * @see FtTextClassification#fit(FtTrainingConfig, RawDataset)
+ */
+ public static FtTextClassification textClassification(
+ FtTrainingConfig config, RawDataset dataset) throws IOException {
+ return FtTextClassification.fit(config, dataset);
+ }
+}
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java
index 42fecd3e219..d874d3ad1aa 100644
--- a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java
@@ -12,6 +12,8 @@
*/
package ai.djl.fasttext.jni;
+import java.util.ArrayList;
+
/** A class containing utilities to interact with the SentencePiece Engine's JNI layer. */
@SuppressWarnings("MissingJavadocMethod")
final class FastTextLibrary {
@@ -32,8 +34,13 @@ private FastTextLibrary() {}
native String getModelType(long handle);
+ @SuppressWarnings("PMD.LooseCoupling")
native int predictProba(
- long handle, String text, int topK, String[] classes, float[] probabilities);
+ long handle,
+ String text,
+ int topK,
+ ArrayList classes,
+ ArrayList probabilities);
native float[] getWordVector(long handle, String word);
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java
index 7bbfcb1fe58..03d9499698d 100644
--- a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java
@@ -59,25 +59,26 @@ public String getModelType() {
}
public Classifications predictProba(String text, int topK, String labelPrefix) {
- String[] labels = new String[topK];
- float[] probs = new float[topK];
+ int cap = topK != -1 ? topK : 10;
+ ArrayList labels = new ArrayList<>(cap);
+ ArrayList probs = new ArrayList<>(cap);
int size = FastTextLibrary.LIB.predictProba(getHandle(), text, topK, labels, probs);
List classes = new ArrayList<>(size);
List probabilities = new ArrayList<>(size);
for (int i = 0; i < size; ++i) {
- String label = labels[i];
+ String label = labels.get(i);
if (label.startsWith(labelPrefix)) {
label = label.substring(labelPrefix.length());
}
classes.add(label);
- probabilities.add((double) probs[i]);
+ probabilities.add((double) probs.get(i));
}
return new Classifications(classes, probabilities);
}
- public float[] getDataVector(String word) {
+ public float[] getWordVector(String word) {
return FastTextLibrary.LIB.getWordVector(getHandle(), word);
}
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/FtTextClassification.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/FtTextClassification.java
new file mode 100644
index 00000000000..c19b1370bfb
--- /dev/null
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/FtTextClassification.java
@@ -0,0 +1,144 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.fasttext.zoo.nlp.textclassification;
+
+import ai.djl.basicdataset.RawDataset;
+import ai.djl.fasttext.FtAbstractBlock;
+import ai.djl.fasttext.FtTrainingConfig;
+import ai.djl.fasttext.jni.FtWrapper;
+import ai.djl.fasttext.zoo.nlp.word_embedding.FtWordEmbeddingBlock;
+import ai.djl.modality.Classifications;
+import ai.djl.ndarray.NDList;
+import ai.djl.training.ParameterStore;
+import ai.djl.training.TrainingResult;
+import ai.djl.util.PairList;
+import ai.djl.util.passthrough.PassthroughNDArray;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+
+/** A {@link FtAbstractBlock} for {@link ai.djl.Application.NLP#TEXT_CLASSIFICATION}. */
+public class FtTextClassification extends FtAbstractBlock {
+
+    /** The label prefix used when none is configured. */
+    public static final String DEFAULT_LABEL_PREFIX = "__label__";
+
+    /**
+     * The {@code topK} value used when the caller gives no limit.
+     *
+     * <p>NOTE(review): presumably this makes fastText return every label; confirm against the JNI
+     * layer.
+     */
+    private static final int NO_TOP_K_LIMIT = -1;
+
+    /** Epoch count recorded in the training result when the config has a non-positive epoch. */
+    private static final int FALLBACK_EPOCH = 5;
+
+    private final String labelPrefix;
+
+    /** Set only by {@link #fit(FtTrainingConfig, RawDataset)}; {@code null} otherwise. */
+    private TrainingResult trainingResult;
+
+    /**
+     * Constructs a {@link FtTextClassification}.
+     *
+     * @param fta the {@link FtWrapper} containing the "fasttext model"
+     * @param labelPrefix the prefix to use for labels
+     */
+    public FtTextClassification(FtWrapper fta, String labelPrefix) {
+        super(fta);
+        this.labelPrefix = labelPrefix;
+    }
+
+    /**
+     * Trains the fastText model.
+     *
+     * <p>The output directory from {@code config} is created, including missing parents, if it
+     * does not exist yet.
+     *
+     * @param config the training configuration to use
+     * @param dataset the training dataset
+     * @return the trained block; its {@link TrainingResult} is available through {@link
+     *     #getTrainingResult()}
+     * @throws IOException when IO operation fails in loading a resource
+     */
+    public static FtTextClassification fit(FtTrainingConfig config, RawDataset dataset)
+            throws IOException {
+        Path outputDir = config.getOutputDir();
+        if (Files.notExists(outputDir)) {
+            // createDirectories also creates missing parents, which createDirectory rejects.
+            Files.createDirectories(outputDir);
+        }
+        String fitModelName = config.getModelName();
+        Path modelFile = outputDir.resolve(fitModelName).toAbsolutePath();
+
+        String[] args = config.toCommand(dataset.getData().toString());
+
+        FtWrapper fta = FtWrapper.newInstance();
+        boolean trained = false;
+        try {
+            fta.runCmd(args);
+            trained = true;
+        } finally {
+            if (!trained) {
+                // Nobody else holds the wrapper yet, so release it here or it is leaked.
+                fta.close();
+            }
+        }
+
+        TrainingResult result = new TrainingResult();
+        int epoch = config.getEpoch();
+        if (epoch <= 0) {
+            epoch = FALLBACK_EPOCH;
+        }
+        result.setEpoch(epoch);
+
+        FtTextClassification block = new FtTextClassification(fta, config.getLabelPrefix());
+        block.modelFile = modelFile;
+        block.trainingResult = result;
+        return block;
+    }
+
+    /**
+     * Returns the fasttext label prefix.
+     *
+     * @return the fasttext label prefix
+     */
+    public String getLabelPrefix() {
+        return labelPrefix;
+    }
+
+    /**
+     * Returns the results of training, or null if not trained.
+     *
+     * @return the results of training, or null if not trained
+     */
+    public TrainingResult getTrainingResult() {
+        return trainingResult;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>Expects a single {@link PassthroughNDArray} wrapping the input {@link String} and returns
+     * a single {@link PassthroughNDArray} wrapping the resulting {@link Classifications}.
+     */
+    @Override
+    protected NDList forwardInternal(
+            ParameterStore parameterStore,
+            NDList inputs,
+            boolean training,
+            PairList params) {
+        PassthroughNDArray inputWrapper = (PassthroughNDArray) inputs.singletonOrThrow();
+        String input = (String) inputWrapper.getObject();
+        Classifications result = fta.predictProba(input, NO_TOP_K_LIMIT, labelPrefix);
+        return new NDList(new PassthroughNDArray(result));
+    }
+
+    /**
+     * Converts the block into the equivalent {@link FtWordEmbeddingBlock}.
+     *
+     * <p>The returned block is built on the same {@link FtWrapper} instance as this one.
+     *
+     * @return the equivalent {@link FtWordEmbeddingBlock}
+     */
+    public FtWordEmbeddingBlock toWordEmbedding() {
+        return new FtWordEmbeddingBlock(fta);
+    }
+
+    /**
+     * Returns the classifications of the input text.
+     *
+     * @param text the input text to be classified
+     * @return classifications of the input text
+     */
+    public Classifications classify(String text) {
+        return classify(text, NO_TOP_K_LIMIT);
+    }
+
+    /**
+     * Returns top K classifications of the input text.
+     *
+     * @param text the input text to be classified
+     * @param topK the value of K
+     * @return classifications of the input text
+     */
+    public Classifications classify(String text, int topK) {
+        return fta.predictProba(text, topK, labelPrefix);
+    }
+}
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java
index 06c863301bd..4219420c585 100644
--- a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java
@@ -24,6 +24,7 @@
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.util.Progress;
+import ai.djl.util.passthrough.PassthroughTranslator;
import java.io.IOException;
import java.nio.file.Path;
@@ -66,6 +67,6 @@ public ZooModel loadModel(Criteria criteria)
Model model = new FtModel(modelName);
Path modelPath = mrl.getRepository().getResourceDirectory(artifact);
model.load(modelPath, modelName, criteria.getOptions());
- return new ZooModel<>(model, null);
+ return new ZooModel<>(model, new PassthroughTranslator<>());
}
}
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtWord2VecWordEmbedding.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java
similarity index 64%
rename from extensions/fasttext/src/main/java/ai/djl/fasttext/FtWord2VecWordEmbedding.java
rename to extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java
index 18e765092f2..230079d79cd 100644
--- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtWord2VecWordEmbedding.java
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java
@@ -10,27 +10,50 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
-package ai.djl.fasttext;
+package ai.djl.fasttext.zoo.nlp.word_embedding;
+import ai.djl.Model;
+import ai.djl.fasttext.FtAbstractBlock;
+import ai.djl.fasttext.FtModel;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.embedding.WordEmbedding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
+import ai.djl.repository.zoo.ZooModel;
/** An implementation of {@link WordEmbedding} for FastText word embeddings. */
public class FtWord2VecWordEmbedding implements WordEmbedding {
- private FtModel model;
+ private FtAbstractBlock embedding;
private Vocabulary vocabulary;
/**
* Constructs a {@link FtWord2VecWordEmbedding}.
*
- * @param model a loaded FastText model
+ * @param model a loaded FastText wordEmbedding model or a ZooModel containing one
* @param vocabulary the {@link Vocabulary} to get indices from
*/
- public FtWord2VecWordEmbedding(FtModel model, Vocabulary vocabulary) {
- this.model = model;
+ public FtWord2VecWordEmbedding(Model model, Vocabulary vocabulary) {
+ if (model instanceof ZooModel) {
+ model = ((ZooModel, ?>) model).getWrappedModel();
+ }
+
+ if (!(model instanceof FtModel)) {
+ throw new IllegalArgumentException("The FtWord2VecWordEmbedding requires an FtModel");
+ }
+
+ this.embedding = (FtAbstractBlock) model.getBlock();
+ this.vocabulary = vocabulary;
+ }
+
+ /**
+ * Constructs a {@link FtWord2VecWordEmbedding}.
+ *
+ * @param embedding the word embedding
+ * @param vocabulary the {@link Vocabulary} to get indices from
+ */
+ public FtWord2VecWordEmbedding(FtAbstractBlock embedding, Vocabulary vocabulary) {
+ this.embedding = embedding;
this.vocabulary = vocabulary;
}
@@ -56,7 +79,7 @@ public NDArray embedWord(NDArray index) {
@Override
public NDArray embedWord(NDManager manager, long index) {
String word = vocabulary.getToken(index);
- float[] buf = model.fta.getDataVector(word);
+ float[] buf = embedding.embedWord(word);
return manager.create(buf);
}
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWordEmbeddingBlock.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWordEmbeddingBlock.java
new file mode 100644
index 00000000000..8f18558858d
--- /dev/null
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWordEmbeddingBlock.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.fasttext.zoo.nlp.word_embedding;
+
+import ai.djl.fasttext.FtAbstractBlock;
+import ai.djl.fasttext.jni.FtWrapper;
+import ai.djl.ndarray.NDList;
+import ai.djl.training.ParameterStore;
+import ai.djl.util.PairList;
+import ai.djl.util.passthrough.PassthroughNDArray;
+
+/** A {@link FtAbstractBlock} for {@link ai.djl.Application.NLP#WORD_EMBEDDING}. */
+public class FtWordEmbeddingBlock extends FtAbstractBlock {
+
+ /**
+ * Constructs a {@link FtWordEmbeddingBlock}.
+ *
+ * @param fta the {@link FtWrapper} for the "fasttext model".
+ */
+ public FtWordEmbeddingBlock(FtWrapper fta) {
+ super(fta);
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * <p>Expects {@code inputs} to contain a single {@link PassthroughNDArray} wrapping the word as a
+ * {@link String}, and returns a single {@link PassthroughNDArray} wrapping its {@code float[]}
+ * embedding.
+ */
+ @Override
+ protected NDList forwardInternal(
+ ParameterStore parameterStore,
+ NDList inputs,
+ boolean training,
+ PairList<String, Object> params) {
+ PassthroughNDArray inputWrapper = (PassthroughNDArray) inputs.singletonOrThrow();
+ String input = (String) inputWrapper.getObject();
+ float[] result = embedWord(input);
+ return new NDList(new PassthroughNDArray(result));
+ }
+}
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/package-info.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/package-info.java
new file mode 100644
index 00000000000..4bbd761e559
--- /dev/null
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/package-info.java
@@ -0,0 +1,18 @@
+/*
+ * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
+/**
+ * Contains classes for the {@link ai.djl.Application.NLP#WORD_EMBEDDING} models in the {@link
+ * ai.djl.fasttext.zoo.FtModelZoo}.
+ */
+package ai.djl.fasttext.zoo.nlp.word_embedding;
diff --git a/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc b/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc
index c4cb4354cee..8a14c1078b2 100644
--- a/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc
+++ b/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc
@@ -97,9 +97,9 @@ JNIEXPORT jstring JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_getModelType(
if (modelName == model_name::cbow) {
return env->NewStringUTF("cbow");
} else if (modelName == model_name::sg) {
- return env->NewStringUTF("cbow");
+ return env->NewStringUTF("sg");
} else if (modelName == model_name::sup) {
- return env->NewStringUTF("cbow");
+ return env->NewStringUTF("sup");
} else {
jclass jexception = env->FindClass("ai/djl/engine/EngineException");
env->ThrowNew(jexception, "Unrecognized model type");
@@ -108,7 +108,7 @@ JNIEXPORT jstring JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_getModelType(
}
JNIEXPORT jint JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_predictProba(
- JNIEnv* env, jobject jthis, jlong jhandle, jstring jtext, jint top_k, jobjectArray jclasses, jfloatArray jprob) {
+ JNIEnv* env, jobject jthis, jlong jhandle, jstring jtext, jint top_k, jobject jclasses, jobject jprob) {
auto* fasttext_ptr = reinterpret_cast(jhandle);
std::string text = djl::utils::jni::GetStringFromJString(env, jtext);
std::istringstream in(text);
@@ -116,13 +116,21 @@ JNIEXPORT jint JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_predictProba(
fasttext_ptr->predictLine(in, predictions, top_k, 0.0);
int size = predictions.size();
- std::vector prob;
+ // FindClass returns local references, which the JVM releases when this native call returns.
+ jclass java_lang_Float = env->FindClass("java/lang/Float");
+ jmethodID java_lang_Float_ = env->GetMethodID(java_lang_Float, "<init>", "(F)V");
+ jclass java_util_ArrayList = env->FindClass("java/util/ArrayList");
+ jmethodID java_util_ArrayList_add = env->GetMethodID(java_util_ArrayList, "add", "(Ljava/lang/Object;)Z");
for (int i = 0; i < size; ++i) {
std::pair pair = predictions[i];
- env->SetObjectArrayElement(jclasses, i, env->NewStringUTF(pair.second.c_str()));
- prob.push_back(pair.first);
+ jstring jlabel = env->NewStringUTF(pair.second.c_str());
+ jobject jscore = env->NewObject(java_lang_Float, java_lang_Float_, pair.first);
+ env->CallBooleanMethod(jclasses, java_util_ArrayList_add, jlabel);
+ env->CallBooleanMethod(jprob, java_util_ArrayList_add, jscore);
+ // Drop the per-iteration local references so a large top_k cannot exhaust the local reference table.
+ env->DeleteLocalRef(jlabel);
+ env->DeleteLocalRef(jscore);
}
- env->SetFloatArrayRegion(jprob, 0, size, prob.data());
return size;
}
diff --git a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java
index 0ab6d1180d4..c0bb2fc3485 100644
--- a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java
+++ b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java
@@ -12,19 +12,26 @@
*/
package ai.djl.fasttext;
+import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.basicdataset.nlp.CookingStackExchange;
+import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification;
+import ai.djl.fasttext.zoo.nlp.word_embedding.FtWord2VecWordEmbedding;
+import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
+import ai.djl.repository.Artifact;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.TestRequirements;
import ai.djl.training.TrainingResult;
+import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
@@ -33,6 +40,8 @@
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.Collections;
+import java.util.List;
+import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
@@ -43,30 +52,30 @@ public class CookingStackExchangeTest {
private static final Logger logger = LoggerFactory.getLogger(CookingStackExchangeTest.class);
@Test
- public void testTrainTextClassification() throws IOException {
+ public void testTrainTextClassification() throws IOException, TranslateException {
TestRequirements.notWindows(); // fastText is not supported on windows
- try (FtModel model = new FtModel("cooking")) {
- CookingStackExchange dataset = CookingStackExchange.builder().build();
-
- // setup training configuration
- FtTrainingConfig config =
- FtTrainingConfig.builder()
- .setOutputDir(Paths.get("build"))
- .setModelName("cooking")
- .optEpoch(5)
- .optLoss(FtTrainingConfig.FtLoss.HS)
- .build();
-
- TrainingResult result = model.fit(config, dataset);
- Assert.assertEquals(result.getEpoch(), 5);
- Assert.assertTrue(Files.exists(Paths.get("build/cooking.bin")));
- }
+ CookingStackExchange dataset = CookingStackExchange.builder().build();
+
+ // setup training configuration
+ FtTrainingConfig config =
+ FtTrainingConfig.builder()
+ .setOutputDir(Paths.get("build"))
+ .setModelName("cooking")
+ .optEpoch(5)
+ .optLoss(FtTrainingConfig.FtLoss.HS)
+ .build();
+
+ FtTextClassification block = TrainFastText.textClassification(config, dataset);
+ TrainingResult result = block.getTrainingResult();
+ Assert.assertEquals(result.getEpoch(), 5);
+ Assert.assertTrue(Files.exists(Paths.get("build/cooking.bin")));
}
@Test
public void testTextClassification()
- throws IOException, MalformedModelException, ModelNotFoundException {
+ throws IOException, MalformedModelException, ModelNotFoundException,
+ TranslateException {
TestRequirements.notWindows(); // fastText is not supported on windows
Criteria criteria =
@@ -75,11 +84,18 @@ public void testTextClassification()
.optArtifactId("ai.djl.fasttext:cooking_stackexchange")
.optOption("label-prefix", "__label")
.build();
+ Map<Application, List<Artifact>> models = ModelZoo.listModels(criteria);
+ models.forEach(
+ (app, list) -> {
+ String appName = app.toString();
+ list.forEach(artifact -> logger.info("{} {}", appName, artifact));
+ });
try (ZooModel<String, Classifications> model = criteria.loadModel()) {
String input = "Which baking dish is best to bake a banana bread ?";
- FtModel ftModel = (FtModel) model.getWrappedModel();
- Classifications result = ftModel.classify(input, 8);
- Assert.assertEquals(result.item(0).getClassName(), "__bread");
+ try (Predictor<String, Classifications> predictor = model.newPredictor()) {
+ Classifications result = predictor.predict(input);
+ Assert.assertEquals(result.item(0).getClassName(), "__bread");
+ }
}
}
@@ -95,10 +111,9 @@ public void testWord2Vec() throws IOException, MalformedModelException, ModelNot
try (ZooModel model = criteria.loadModel();
NDManager manager = NDManager.newBaseManager()) {
- FtModel ftModel = (FtModel) model.getWrappedModel();
FtWord2VecWordEmbedding fasttextWord2VecWordEmbedding =
new FtWord2VecWordEmbedding(
- ftModel, new DefaultVocabulary(Collections.singletonList("bread")));
+ model, new DefaultVocabulary(Collections.singletonList("bread")));
long index = fasttextWord2VecWordEmbedding.preprocessWordToEmbed("bread");
NDArray embedding = fasttextWord2VecWordEmbedding.embedWord(manager, index);
Assert.assertEquals(embedding.getShape(), new Shape(100));
@@ -125,7 +140,7 @@ public void testBlazingText() throws IOException, ModelException {
model.load(modelFile);
String text =
"Convair was an american aircraft manufacturing company which later expanded into rockets and spacecraft .";
- Classifications result = model.classify(text, 5);
+ Classifications result = ((FtTextClassification) model.getBlock()).classify(text, 5);
logger.info("{}", result);
Assert.assertEquals(result.item(0).getClassName(), "Company");