Skip to content

Commit

Permalink
[api] Adds ZeroShotClassification support
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Dec 30, 2024
1 parent 9b40f85 commit 94fbc93
Show file tree
Hide file tree
Showing 9 changed files with 744 additions and 0 deletions.
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ public static Application of(String path) {
case "nlp/token_classification":
case "token_classification":
return NLP.TOKEN_CLASSIFICATION;
case "nlp/zero_shot_classification":
case "zero_shot_classification":
return NLP.ZERO_SHOT_CLASSIFICATION;
case "nlp/word_embedding":
case "word_embedding":
return NLP.WORD_EMBEDDING;
Expand Down Expand Up @@ -280,6 +283,12 @@ public interface NLP {
*/
Application SENTIMENT_ANALYSIS = new Application("nlp/sentiment_analysis");

/**
* An application that classifies text into arbitrary label, a specific case of {@link
* #TEXT_CLASSIFICATION}.
*/
Application ZERO_SHOT_CLASSIFICATION = new Application("nlp/zero_shot_classification");

/**
* A natural language understanding application that assigns a label to some tokens in a
* text.
Expand Down
3 changes: 3 additions & 0 deletions api/src/main/java/ai/djl/inference/Predictor.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ public O predict(I input) throws TranslateException {
protected NDList predictInternal(TranslatorContext ctx, NDList ndList)
throws TranslateException {
logger.trace("Predictor input data: {}", ndList);
if (ndList.isEmpty()) {
return new NDList();
}
return block.forward(parameterStore, ndList, false);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.translator;

import ai.djl.modality.Input;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;

import com.google.gson.JsonParseException;
import com.google.gson.annotations.SerializedName;

/** A class that represents a {@code ZeroShotClassificationInput} object. */
public class ZeroShotClassificationInput {

private String text;

@SerializedName("candidate_labels")
private String[] candidates;

@SerializedName("multi_label")
private boolean multiLabel;

@SerializedName("hypothesis_template")
private String hypothesisTemplate;

/**
* Constructs a new {@code ZeroShotClassificationInput} instance.
*
* @param text the text to classify
* @param candidates the candidate labels
*/
public ZeroShotClassificationInput(String text, String[] candidates) {
this(text, candidates, false);
}

/**
* Constructs a new {@code ZeroShotClassificationInput} instance.
*
* @param text the text to classify
* @param candidates the candidate labels
* @param multiLabel true to classify multiple labels
*/
public ZeroShotClassificationInput(String text, String[] candidates, boolean multiLabel) {
this(text, candidates, multiLabel, null);
}

/**
* Constructs a new {@code ZeroShotClassificationInput} instance.
*
* @param text the text to classify
* @param candidates the candidate labels
* @param multiLabel true to classify multiple labels
* @param hypothesisTemplate the custom template
*/
public ZeroShotClassificationInput(
String text, String[] candidates, boolean multiLabel, String hypothesisTemplate) {
this.text = text;
this.candidates = candidates;
this.multiLabel = multiLabel;
this.hypothesisTemplate = hypothesisTemplate;
}

/**
* Returns the {@code ZeroShotClassificationInput} from the {@link Input}.
*
* @param input the input object
* @return the {@code ZeroShotClassificationInput} from the {@link Input}
* @throws TranslateException if the input is invalid
*/
public static ZeroShotClassificationInput parseInput(Input input) throws TranslateException {
String text = input.getData().getAsString();
try {
return JsonUtils.GSON.fromJson(text, ZeroShotClassificationInput.class);
} catch (JsonParseException e) {
throw new TranslateException("Input is not a valid json.", e);
}
}

/**
* Returns the text.
*
* @return the text to be classified
*/
public String getText() {
return text;
}

/**
* Returns the candidate labels.
*
* @return the candidate labels
*/
public String[] getCandidates() {
return candidates;
}

/**
* Returns {@code true} if to classify multiple labels.
*
* @return {@code true} if to classify multiple labels
*/
public boolean isMultiLabel() {
return multiLabel;
}

/**
* Returns the custom template.
*
* @return the custom template
*/
public String getHypothesisTemplate() {
return hypothesisTemplate == null ? "This example is {}." : hypothesisTemplate;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.translator;

/** A class that represents a {@code ZeroShotClassificationOutput} object. */
public class ZeroShotClassificationOutput {

private String sequence;
private String[] labels;
private double[] scores;

/**
* Constructs a new {@code ZeroShotClassificationOutput} instance.
*
* @param sequence the input text
* @param labels the labels
* @param scores the scores of the labels
*/
public ZeroShotClassificationOutput(String sequence, String[] labels, double[] scores) {
this.sequence = sequence;
this.labels = labels;
this.scores = scores;
}

/**
* Returns the input text.
*
* @return the input text
*/
public String getSequence() {
return sequence;
}

/**
* Returns the labels in sorted order.
*
* @return the labels in sorted order
*/
public String[] getLabels() {
return labels;
}

/**
* Returns the scores of the labels.
*
* @return the scores of the labels
*/
public double[] getScores() {
return scores;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.translator;

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

/**
* A {@link Translator} that can handle generic zero-shot-classification {@link Input} and {@link
* Output}.
*/
public class ZeroShotClassificationServingTranslator
implements NoBatchifyTranslator<Input, Output> {

private Translator<ZeroShotClassificationInput, ZeroShotClassificationOutput> translator;

/**
* Constructs a {@code TokenClassificationServingTranslator} instance.
*
* @param translator a {@code Translator} processes token classification input
*/
public ZeroShotClassificationServingTranslator(
Translator<ZeroShotClassificationInput, ZeroShotClassificationOutput> translator) {
this.translator = translator;
}

/** {@inheritDoc} */
@Override
public void prepare(TranslatorContext ctx) throws Exception {
translator.prepare(ctx);
}

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
if (input.getContent().isEmpty()) {
throw new TranslateException("Input data is empty.");
}

ZeroShotClassificationInput prompt = ZeroShotClassificationInput.parseInput(input);
NDList ret = translator.processInput(ctx, prompt);
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
NDList[] batch = {ret};
return batchifier.batchify(batch);
}
return ret;
}

/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
Output output = new Output();
output.addProperty("Content-Type", "application/json");
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
list = batchifier.unbatchify(list)[0];
}
output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list)));
return output;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package ai.djl.examples.inference.nlp;

import ai.djl.ModelException;
import ai.djl.huggingface.translator.ZeroShotClassificationTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.translator.ZeroShotClassificationInput;
import ai.djl.modality.nlp.translator.ZeroShotClassificationOutput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;

public class ZeroShotClassification {

private static final Logger logger = LoggerFactory.getLogger(ZeroShotClassification.class);

private ZeroShotClassification() {}

public static void main(String[] args) throws ModelException, IOException, TranslateException {
ZeroShotClassificationOutput ret = predict(false);
logger.info("{}", JsonUtils.GSON_PRETTY.toJson(ret));

ret = predict(true);
logger.info("{}", JsonUtils.GSON_PRETTY.toJson(ret));
}

public static ZeroShotClassificationOutput predict(boolean multiLabels)
throws ModelException, IOException, TranslateException {
Path path =
Paths.get(
"/Users/frankliu/source/junkyard/ptest/huggingface/zero-shot-classification/models/model.pt");

Criteria<ZeroShotClassificationInput, ZeroShotClassificationOutput> criteria =
Criteria.builder()
.setTypes(
ZeroShotClassificationInput.class,
ZeroShotClassificationOutput.class)
.optModelPath(path)
.optEngine("PyTorch")
.optTranslatorFactory(new ZeroShotClassificationTranslatorFactory())
.build();
String prompt = "one day I will see the world";
String[] candidates = {"travel", "cooking", "dancing", "exploration"};

try (ZooModel<ZeroShotClassificationInput, ZeroShotClassificationOutput> model =
criteria.loadModel();
Predictor<ZeroShotClassificationInput, ZeroShotClassificationOutput> predictor =
model.newPredictor()) {
ZeroShotClassificationInput input =
new ZeroShotClassificationInput(prompt, candidates, multiLabels);
return predictor.predict(input);
}
}
}
Loading

0 comments on commit 94fbc93

Please sign in to comment.