diff --git a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java index 89de4c254d0..f2bdf29a76a 100644 --- a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java +++ b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java @@ -195,11 +195,19 @@ protected Model createModel( throws IOException { Model model = Model.newInstance(name, device, engine); if (block == null) { - String className = (String) arguments.get("blockFactory"); - BlockFactory factory = - ClassLoaderUtils.findImplementation(modelPath, BlockFactory.class, className); - if (factory != null) { - block = factory.newBlock(model, modelPath, arguments); + Object bf = arguments.get("blockFactory"); + if (bf instanceof BlockFactory) { + block = ((BlockFactory) bf).newBlock(model, modelPath, arguments); + } else { + String className = (String) bf; + BlockFactory factory = + ClassLoaderUtils.findImplementation( + modelPath, BlockFactory.class, className); + if (factory != null) { + block = factory.newBlock(model, modelPath, arguments); + } else if (className != null) { + throw new IllegalArgumentException("Failed to load BlockFactory: " + className); + } } } if (block != null) { diff --git a/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java b/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java index b9f254eea5f..d8b6c9dcc2e 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainWithOptimizers.java @@ -15,7 +15,6 @@ import ai.djl.Model; import ai.djl.ModelException; import ai.djl.basicdataset.cv.classification.Cifar10; -import ai.djl.basicmodelzoo.BasicModelZoo; import ai.djl.basicmodelzoo.cv.classification.ResNetV1; import ai.djl.examples.training.util.Arguments; import ai.djl.metric.Metrics; @@ -25,10 +24,6 @@ import ai.djl.modality.cv.transform.ToTensor; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; -import ai.djl.nn.Blocks; -import ai.djl.nn.SequentialBlock; -import ai.djl.nn.SymbolBlock; -import ai.djl.nn.core.Linear; import ai.djl.repository.zoo.Criteria; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.EasyTrain; @@ -53,7 +48,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.Map; /** This example features sample usage of a variety of optimizers to train Cifar10. */ public final class TrainWithOptimizers { @@ -97,50 +91,15 @@ public static TrainingResult runExample(String[] args) } private static Model getModel(Arguments arguments) throws IOException, ModelException { - boolean isSymbolic = arguments.isSymbolic(); boolean preTrained = arguments.isPreTrained(); - Map options = arguments.getCriteria(); Criteria.Builder builder = Criteria.builder() .setTypes(Image.class, Classifications.class) .optEngine(arguments.getEngine()) - .optProgress(new ProgressBar()) - .optArtifactId("resnet"); - if (isSymbolic) { - // currently only MxEngine support removeLastBlock - builder.optGroupId("ai.djl.mxnet"); - if (options == null) { - builder.optFilter("layers", "50"); - builder.optFilter("flavor", "v1"); - } else { - builder.optFilters(options); - } - - Model model = builder.build().loadModel(); - SequentialBlock newBlock = new SequentialBlock(); - SymbolBlock block = (SymbolBlock) model.getBlock(); - block.removeLastBlock(); - newBlock.add(block); - // the original model don't include the flatten - // so apply the flatten here - newBlock.add(Blocks.batchFlattenBlock()); - newBlock.add(Linear.builder().setUnits(10).build()); - model.setBlock(newBlock); - if (!preTrained) { - model.getBlock().clear(); - } - return model; - } + .optProgress(new ProgressBar()); // imperative resnet50 if (preTrained) { - builder.optGroupId(BasicModelZoo.GROUP_ID); - if (options == null) { - builder.optFilter("layers", "50"); - builder.optFilter("flavor", "v1"); - builder.optFilter("dataset", "cifar10"); - } else { - builder.optFilters(options); - } + builder.optModelUrls("djl://ai.djl.zoo/resnet/0.0.2/resnetv1"); // load pre-trained imperative ResNet50 from DJL model zoo return builder.build().loadModel(); } else { diff --git a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java index e6e7e33d996..00317f82b99 100644 --- a/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java +++ b/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java @@ -12,11 +12,9 @@ */ package ai.djl.examples.training.transferlearning; -import ai.djl.Application; import ai.djl.Model; import ai.djl.ModelException; import ai.djl.basicdataset.cv.classification.Cifar10; -import ai.djl.basicmodelzoo.BasicModelZoo; import ai.djl.basicmodelzoo.cv.classification.ResNetV1; import ai.djl.examples.training.util.Arguments; import ai.djl.inference.Predictor; @@ -30,10 +28,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; -import ai.djl.nn.Blocks; -import ai.djl.nn.SequentialBlock; -import ai.djl.nn.SymbolBlock; -import ai.djl.nn.core.Linear; +import ai.djl.nn.BlockFactory; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; import ai.djl.training.DefaultTrainingConfig; @@ -55,7 +50,6 @@ import java.io.IOException; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.Map; /** * An example of training an image classification (ResNet for Cifar10) model. @@ -112,8 +106,7 @@ public static TrainingResult runExample(String[] args) Path modelPath = Paths.get("build/model"); model.save(modelPath, "resnetv1"); - Classifications classifications = - testSaveParameters(model.getBlock(), modelPath, arguments); + Classifications classifications = testSaveParameters(modelPath, arguments); logger.info("Predict result: {}", classifications.topK(3)); return result; } @@ -121,50 +114,15 @@ public static TrainingResult runExample(String[] args) } private static Model getModel(Arguments arguments) throws IOException, ModelException { - boolean isSymbolic = arguments.isSymbolic(); boolean preTrained = arguments.isPreTrained(); - Map options = arguments.getCriteria(); Criteria.Builder builder = Criteria.builder() - .optApplication(Application.CV.IMAGE_CLASSIFICATION) .setTypes(Image.class, Classifications.class) .optEngine(arguments.getEngine()) - .optProgress(new ProgressBar()) - .optArtifactId("resnet"); - if (isSymbolic) { - // load the model - builder.optGroupId("ai.djl.mxnet"); - if (options == null) { - builder.optFilter("layers", "50"); - builder.optFilter("flavor", "v1"); - } else { - builder.optFilters(options); - } - Model model = builder.build().loadModel(); - SequentialBlock newBlock = new SequentialBlock(); - SymbolBlock block = (SymbolBlock) model.getBlock(); - block.removeLastBlock(); - newBlock.add(block); - // the original model don't include the flatten - // so apply the flatten here - newBlock.add(Blocks.batchFlattenBlock()); - newBlock.add(Linear.builder().setUnits(10).build()); - model.setBlock(newBlock); - if (!preTrained) { - model.getBlock().clear(); - } - return model; - } + .optProgress(new ProgressBar()); // imperative resnet50 if (preTrained) { - builder.optGroupId(BasicModelZoo.GROUP_ID); - if (options == null) { - builder.optFilter("layers", "50"); - builder.optFilter("flavor", "v1"); - builder.optFilter("dataset", "cifar10"); - } else { - builder.optFilters(options); - } + builder.optModelUrls("djl://ai.djl.zoo/resnet/0.0.2/resnetv1"); // load pre-trained imperative ResNet50 from DJL model zoo return builder.build().loadModel(); } else { @@ -181,7 +139,7 @@ private static Model getModel(Arguments arguments) throws IOException, ModelExce } } - private static Classifications testSaveParameters(Block block, Path path, Arguments arguments) + static Classifications testSaveParameters(Path path, Arguments arguments) throws IOException, ModelException, TranslateException { String synsetUrl = "https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset_cifar10.txt"; @@ -192,6 +150,13 @@ private static Classifications testSaveParameters(Block block, Path path, Argume .optSynsetUrl(synsetUrl) .optApplySoftmax(true) .build(); + BlockFactory resnetFactory = + (model, modelPath, arguments1) -> + ResNetV1.builder() + .setImageShape(new Shape(3, 32, 32)) + .setNumLayers(50) + .setOutSize(10) + .build(); Image img = ImageFactory.getInstance().fromUrl("src/test/resources/airplane1.png"); @@ -201,7 +166,7 @@ private static Classifications testSaveParameters(Block block, Path path, Argume .optModelPath(path) .optEngine(arguments.getEngine()) .optTranslator(translator) - .optBlock(block) + .optArgument("blockFactory", resnetFactory) .optModelName("resnetv1") .build(); diff --git a/examples/src/main/java/ai/djl/examples/training/util/Arguments.java b/examples/src/main/java/ai/djl/examples/training/util/Arguments.java index 5a0048226c4..673698df123 100644 --- a/examples/src/main/java/ai/djl/examples/training/util/Arguments.java +++ b/examples/src/main/java/ai/djl/examples/training/util/Arguments.java @@ -33,7 +33,6 @@ public class Arguments { protected int epoch; protected int batchSize; protected int maxGpus; - protected boolean isSymbolic; protected boolean preTrained; protected String outputDir; protected long limit; @@ -60,7 +59,6 @@ protected void setCmd(CommandLine cmd) { } else { batchSize = maxGpus > 0 ? 32 * maxGpus : 32; } - isSymbolic = cmd.hasOption("symbolic-model"); preTrained = cmd.hasOption("pre-trained"); if (cmd.hasOption("output-dir")) { @@ -126,12 +124,6 @@ public Options getOptions() { .argName("MAXGPUS") .desc("Max number of GPUs to use for training") .build()); - options.addOption( - Option.builder("s") - .longOpt("symbolic-model") - .argName("SYMBOLIC") - .desc("Use symbolic model, use imperative model if false") - .build()); options.addOption( Option.builder("p") .longOpt("pre-trained") @@ -190,10 +182,6 @@ public Device[] getMaxGpus() { return Engine.getEngine(engine).getDevices(maxGpus); } - public boolean isSymbolic() { - return isSymbolic; - } - public boolean isPreTrained() { return preTrained; } diff --git a/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java b/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java index f33be26de90..13cb4b3f474 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainResNetTest.java @@ -26,40 +26,20 @@ import java.io.IOException; public class TrainResNetTest { + private static final int SEED = 1234; @Test public void testTrainResNet() throws ModelException, IOException, TranslateException { - TestRequirements.linux(); - - // Limit max 4 gpu for cifar10 training to make it converge faster. - // and only train 10 batch for unit test. - // only MXNet support symbolic model - String[] args = {"-e", "2", "-g", "4", "-m", "10", "-s", "-p", "--engine", "MXNet"}; - - TrainingResult result = TrainResnetWithCifar10.runExample(args); - Assert.assertNotNull(result); - } - - @Test - public void testTrainResNetSymbolicNightly() - throws ModelException, IOException, TranslateException { - TestRequirements.linux(); TestRequirements.nightly(); - TestRequirements.gpu("MXNet"); // Limit max 4 gpu for cifar10 training to make it converge faster. // and only train 10 batch for unit test. // only MXNet support symbolic model - String[] args = {"-e", "10", "-g", "4", "-s", "-p", "--engine", "MXNet"}; - - Engine.getEngine("MXNet").setRandomSeed(SEED); + String[] args = {"-e", "2", "-g", "4", "-m", "10", "-p"}; TrainingResult result = TrainResnetWithCifar10.runExample(args); - Assert.assertNotNull(result); - Assert.assertTrue(result.getTrainEvaluation("Accuracy") >= 0.8f); - Assert.assertTrue(result.getValidateEvaluation("Accuracy") >= 0.68f); - Assert.assertTrue(result.getValidateLoss() < 1.1); + Assert.assertNotNull(result); } @Test @@ -67,13 +47,13 @@ public void testTrainResNetImperativeNightly() throws ModelException, IOException, TranslateException { TestRequirements.linux(); TestRequirements.nightly(); - TestRequirements.gpu("MXNet"); + TestRequirements.gpu("PyTorch"); // Limit max 4 gpu for cifar10 training to make it converge faster. // and only train 10 batch for unit test. - String[] args = {"-e", "10", "-g", "4", "--engine", "MXNet"}; + String[] args = {"-e", "10", "-g", "4"}; - Engine.getEngine("MXNet").setRandomSeed(SEED); + Engine.getEngine("PyTorch").setRandomSeed(SEED); TrainingResult result = TrainResnetWithCifar10.runExample(args); Assert.assertNotNull(result);