[examples] Remove symbolic training for MXNet #3299

Merged
merged 1 commit on Jul 3, 2024
18 changes: 13 additions & 5 deletions api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java
@@ -195,11 +195,19 @@ protected Model createModel(
throws IOException {
Model model = Model.newInstance(name, device, engine);
if (block == null) {
String className = (String) arguments.get("blockFactory");
BlockFactory factory =
ClassLoaderUtils.findImplementation(modelPath, BlockFactory.class, className);
if (factory != null) {
block = factory.newBlock(model, modelPath, arguments);
Object bf = arguments.get("blockFactory");
if (bf instanceof BlockFactory) {
block = ((BlockFactory) bf).newBlock(model, modelPath, arguments);
} else {
String className = (String) bf;
BlockFactory factory =
ClassLoaderUtils.findImplementation(
modelPath, BlockFactory.class, className);
if (factory != null) {
block = factory.newBlock(model, modelPath, arguments);
} else if (className != null) {
throw new IllegalArgumentException("Failed to load BlockFactory: " + className);
}
}
}
if (block != null) {
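With this change, the `blockFactory` argument accepts either a fully-qualified class name (the old behavior) or a `BlockFactory` instance passed directly. A minimal sketch of the two forms — `com.example.MyBlockFactory` is a hypothetical class name used only for illustration, and `Blocks.batchFlattenBlock()` stands in for any real block:

```java
import ai.djl.nn.BlockFactory;
import ai.djl.nn.Blocks;

import java.util.HashMap;
import java.util.Map;

public final class BlockFactoryArgumentSketch {

    public static void main(String[] args) {
        Map<String, Object> arguments = new HashMap<>();

        // Form 1 (unchanged): a fully-qualified class name, resolved through ClassLoaderUtils.
        arguments.put("blockFactory", "com.example.MyBlockFactory"); // hypothetical class

        // Form 2 (new): a BlockFactory instance, used as-is without any class loading.
        BlockFactory factory = (model, modelPath, factoryArgs) -> Blocks.batchFlattenBlock();
        arguments.put("blockFactory", factory);
    }
}
```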
@@ -15,7 +15,6 @@
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.cv.classification.Cifar10;
import ai.djl.basicmodelzoo.BasicModelZoo;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.examples.training.util.Arguments;
import ai.djl.metric.Metrics;
@@ -25,10 +24,6 @@
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
@@ -53,7 +48,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;

/** This example features sample usage of a variety of optimizers to train Cifar10. */
public final class TrainWithOptimizers {
@@ -97,50 +91,15 @@ public static TrainingResult runExample(String[] args)
}

private static Model getModel(Arguments arguments) throws IOException, ModelException {
boolean isSymbolic = arguments.isSymbolic();
boolean preTrained = arguments.isPreTrained();
Map<String, String> options = arguments.getCriteria();
Criteria.Builder<Image, Classifications> builder =
Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optEngine(arguments.getEngine())
.optProgress(new ProgressBar())
.optArtifactId("resnet");
if (isSymbolic) {
// currently only MxEngine support removeLastBlock
builder.optGroupId("ai.djl.mxnet");
if (options == null) {
builder.optFilter("layers", "50");
builder.optFilter("flavor", "v1");
} else {
builder.optFilters(options);
}

Model model = builder.build().loadModel();
SequentialBlock newBlock = new SequentialBlock();
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();
newBlock.add(block);
// the original model doesn't include a flatten layer,
// so apply the flatten here
newBlock.add(Blocks.batchFlattenBlock());
newBlock.add(Linear.builder().setUnits(10).build());
model.setBlock(newBlock);
if (!preTrained) {
model.getBlock().clear();
}
return model;
}
.optProgress(new ProgressBar());
// imperative resnet50
if (preTrained) {
builder.optGroupId(BasicModelZoo.GROUP_ID);
if (options == null) {
builder.optFilter("layers", "50");
builder.optFilter("flavor", "v1");
builder.optFilter("dataset", "cifar10");
} else {
builder.optFilters(options);
}
builder.optModelUrls("djl://ai.djl.zoo/resnet/0.0.2/resnetv1");
// load pre-trained imperative ResNet50 from DJL model zoo
return builder.build().loadModel();
} else {
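The symbolic branch above is replaced by a single imperative path. For reference, a hedged sketch of loading the pre-trained imperative ResNet50 through the `djl://` URL added in this diff — the engine option is omitted (so the default engine is used), and the try-with-resources wrapper is an assumption of this sketch:

```java
import ai.djl.Model;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;

public final class ZooResnetLoadSketch {

    public static void main(String[] args) throws Exception {
        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optModelUrls("djl://ai.djl.zoo/resnet/0.0.2/resnetv1")
                        .optProgress(new ProgressBar())
                        .build();
        // Loads the pre-trained imperative ResNet50 from the DJL model zoo.
        try (Model model = criteria.loadModel()) {
            System.out.println(model.getBlock());
        }
    }
}
```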
@@ -12,11 +12,9 @@
*/
package ai.djl.examples.training.transferlearning;

import ai.djl.Application;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.cv.classification.Cifar10;
import ai.djl.basicmodelzoo.BasicModelZoo;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.examples.training.util.Arguments;
import ai.djl.inference.Predictor;
@@ -30,10 +28,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.BlockFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
@@ -55,7 +50,6 @@
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;

/**
* An example of training an image classification (ResNet for Cifar10) model.
@@ -112,59 +106,23 @@ public static TrainingResult runExample(String[] args)
Path modelPath = Paths.get("build/model");
model.save(modelPath, "resnetv1");

Classifications classifications =
testSaveParameters(model.getBlock(), modelPath, arguments);
Classifications classifications = testSaveParameters(modelPath, arguments);
logger.info("Predict result: {}", classifications.topK(3));
return result;
}
}
}

private static Model getModel(Arguments arguments) throws IOException, ModelException {
boolean isSymbolic = arguments.isSymbolic();
boolean preTrained = arguments.isPreTrained();
Map<String, String> options = arguments.getCriteria();
Criteria.Builder<Image, Classifications> builder =
Criteria.builder()
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
.setTypes(Image.class, Classifications.class)
.optEngine(arguments.getEngine())
.optProgress(new ProgressBar())
.optArtifactId("resnet");
if (isSymbolic) {
// load the model
builder.optGroupId("ai.djl.mxnet");
if (options == null) {
builder.optFilter("layers", "50");
builder.optFilter("flavor", "v1");
} else {
builder.optFilters(options);
}
Model model = builder.build().loadModel();
SequentialBlock newBlock = new SequentialBlock();
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();
newBlock.add(block);
// the original model doesn't include a flatten layer,
// so apply the flatten here
newBlock.add(Blocks.batchFlattenBlock());
newBlock.add(Linear.builder().setUnits(10).build());
model.setBlock(newBlock);
if (!preTrained) {
model.getBlock().clear();
}
return model;
}
.optProgress(new ProgressBar());
// imperative resnet50
if (preTrained) {
builder.optGroupId(BasicModelZoo.GROUP_ID);
if (options == null) {
builder.optFilter("layers", "50");
builder.optFilter("flavor", "v1");
builder.optFilter("dataset", "cifar10");
} else {
builder.optFilters(options);
}
builder.optModelUrls("djl://ai.djl.zoo/resnet/0.0.2/resnetv1");
// load pre-trained imperative ResNet50 from DJL model zoo
return builder.build().loadModel();
} else {
@@ -181,7 +139,7 @@ private static Model getModel(Arguments arguments) throws IOException, ModelExce
}
}

private static Classifications testSaveParameters(Block block, Path path, Arguments arguments)
static Classifications testSaveParameters(Path path, Arguments arguments)
throws IOException, ModelException, TranslateException {
String synsetUrl =
"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset_cifar10.txt";
@@ -192,6 +150,13 @@ private static Classifications testSaveParameters(Block block, Path path, Argume
.optSynsetUrl(synsetUrl)
.optApplySoftmax(true)
.build();
BlockFactory resnetFactory =
(model, modelPath, arguments1) ->
ResNetV1.builder()
.setImageShape(new Shape(3, 32, 32))
.setNumLayers(50)
.setOutSize(10)
.build();

Image img = ImageFactory.getInstance().fromUrl("src/test/resources/airplane1.png");

@@ -201,7 +166,7 @@ private static Classifications testSaveParameters(Block block, Path path, Argume
.optModelPath(path)
.optEngine(arguments.getEngine())
.optTranslator(translator)
.optBlock(block)
.optArgument("blockFactory", resnetFactory)
.optModelName("resnetv1")
.build();

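`testSaveParameters` now rebuilds the block through a `blockFactory` argument instead of receiving a pre-built `Block`. A self-contained sketch of that save-then-reload round trip, assuming the parameters were previously written by `model.save(Paths.get("build/model"), "resnetv1")` as in the example; the translator, paths, and synset URL mirror the diff, while the engine option is omitted:

```java
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.BlockFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;

import java.nio.file.Path;
import java.nio.file.Paths;

public final class ReloadWithBlockFactorySketch {

    public static void main(String[] args) throws Exception {
        Path modelPath = Paths.get("build/model"); // where the example saved "resnetv1"

        // Rebuild the same imperative ResNetV1 block when the saved parameters are loaded.
        BlockFactory resnetFactory =
                (model, path, factoryArgs) ->
                        ResNetV1.builder()
                                .setImageShape(new Shape(3, 32, 32))
                                .setNumLayers(50)
                                .setOutSize(10)
                                .build();

        Translator<Image, Classifications> translator =
                ImageClassificationTranslator.builder()
                        .addTransform(new ToTensor())
                        .optSynsetUrl(
                                "https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset_cifar10.txt")
                        .optApplySoftmax(true)
                        .build();

        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optModelPath(modelPath)
                        .optTranslator(translator)
                        .optArgument("blockFactory", resnetFactory)
                        .optModelName("resnetv1")
                        .build();

        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image img = ImageFactory.getInstance().fromUrl("src/test/resources/airplane1.png");
            Classifications classifications = predictor.predict(img);
            System.out.println(classifications.topK(3));
        }
    }
}
```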
@@ -33,7 +33,6 @@ public class Arguments {
protected int epoch;
protected int batchSize;
protected int maxGpus;
protected boolean isSymbolic;
protected boolean preTrained;
protected String outputDir;
protected long limit;
@@ -60,7 +59,6 @@ protected void setCmd(CommandLine cmd) {
} else {
batchSize = maxGpus > 0 ? 32 * maxGpus : 32;
}
isSymbolic = cmd.hasOption("symbolic-model");
preTrained = cmd.hasOption("pre-trained");

if (cmd.hasOption("output-dir")) {
@@ -126,12 +124,6 @@ public Options getOptions() {
.argName("MAXGPUS")
.desc("Max number of GPUs to use for training")
.build());
options.addOption(
Option.builder("s")
.longOpt("symbolic-model")
.argName("SYMBOLIC")
.desc("Use symbolic model, use imperative model if false")
.build());
options.addOption(
Option.builder("p")
.longOpt("pre-trained")
@@ -190,10 +182,6 @@ public Device[] getMaxGpus() {
return Engine.getEngine(engine).getDevices(maxGpus);
}

public boolean isSymbolic() {
return isSymbolic;
}

public boolean isPreTrained() {
return preTrained;
}
@@ -26,54 +26,34 @@
import java.io.IOException;

public class TrainResNetTest {

private static final int SEED = 1234;

@Test
public void testTrainResNet() throws ModelException, IOException, TranslateException {
TestRequirements.linux();

// Limit max 4 gpu for cifar10 training to make it converge faster.
// and only train 10 batch for unit test.
// only MXNet support symbolic model
String[] args = {"-e", "2", "-g", "4", "-m", "10", "-s", "-p", "--engine", "MXNet"};

TrainingResult result = TrainResnetWithCifar10.runExample(args);
Assert.assertNotNull(result);
}

@Test
public void testTrainResNetSymbolicNightly()
throws ModelException, IOException, TranslateException {
TestRequirements.linux();
TestRequirements.nightly();
TestRequirements.gpu("MXNet");

// Limit max 4 gpu for cifar10 training to make it converge faster.
// and only train 10 batch for unit test.
// only MXNet support symbolic model
String[] args = {"-e", "10", "-g", "4", "-s", "-p", "--engine", "MXNet"};

Engine.getEngine("MXNet").setRandomSeed(SEED);
String[] args = {"-e", "2", "-g", "4", "-m", "10", "-p"};
TrainingResult result = TrainResnetWithCifar10.runExample(args);
Assert.assertNotNull(result);

Assert.assertTrue(result.getTrainEvaluation("Accuracy") >= 0.8f);
Assert.assertTrue(result.getValidateEvaluation("Accuracy") >= 0.68f);
Assert.assertTrue(result.getValidateLoss() < 1.1);
Assert.assertNotNull(result);
}

@Test
public void testTrainResNetImperativeNightly()
throws ModelException, IOException, TranslateException {
TestRequirements.linux();
TestRequirements.nightly();
TestRequirements.gpu("MXNet");
TestRequirements.gpu("PyTorch");

// Limit max 4 gpu for cifar10 training to make it converge faster.
// and only train 10 batch for unit test.
String[] args = {"-e", "10", "-g", "4", "--engine", "MXNet"};
String[] args = {"-e", "10", "-g", "4"};

Engine.getEngine("MXNet").setRandomSeed(SEED);
Engine.getEngine("PyTorch").setRandomSeed(SEED);
TrainingResult result = TrainResnetWithCifar10.runExample(args);
Assert.assertNotNull(result);
