From f602919c1fd83d984316abe419bdcec291193d52 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Thu, 18 Jul 2024 19:24:12 -0700 Subject: [PATCH] [serving] Download model while initialize multi-node cluster --- .../main/java/ai/djl/serving/ModelServer.java | 226 ++------------- .../serving/http/ClusterRequestHandler.java | 20 ++ .../java/ai/djl/serving/util/ModelStore.java | 274 ++++++++++++++++++ .../java/ai/djl/serving/ModelServerTest.java | 3 +- .../java/ai/djl/serving/wlm/LmiUtils.java | 2 +- .../java/ai/djl/serving/wlm/ModelInfo.java | 19 +- 6 files changed, 332 insertions(+), 212 deletions(-) create mode 100644 serving/src/main/java/ai/djl/serving/util/ModelStore.java diff --git a/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/src/main/java/ai/djl/serving/ModelServer.java index d9bc956a9..b527ca871 100644 --- a/serving/src/main/java/ai/djl/serving/ModelServer.java +++ b/serving/src/main/java/ai/djl/serving/ModelServer.java @@ -13,6 +13,7 @@ package ai.djl.serving; import ai.djl.Device; +import ai.djl.ModelException; import ai.djl.engine.Engine; import ai.djl.engine.EngineException; import ai.djl.metric.Dimension; @@ -20,7 +21,6 @@ import ai.djl.metric.Unit; import ai.djl.modality.Input; import ai.djl.modality.Output; -import ai.djl.repository.FilenameUtils; import ai.djl.serving.http.ServerStartupException; import ai.djl.serving.models.ModelManager; import ai.djl.serving.plugins.DependencyManager; @@ -28,13 +28,11 @@ import ai.djl.serving.util.ClusterConfig; import ai.djl.serving.util.ConfigManager; import ai.djl.serving.util.Connector; +import ai.djl.serving.util.ModelStore; import ai.djl.serving.util.ServerGroups; -import ai.djl.serving.wlm.ModelInfo; +import ai.djl.serving.wlm.WorkerPoolConfig; import ai.djl.serving.workflow.BadWorkflowException; import ai.djl.serving.workflow.Workflow; -import ai.djl.serving.workflow.WorkflowDefinition; -import ai.djl.util.RandomUtils; -import ai.djl.util.Utils; import ai.djl.util.cuda.CudaUtils; import 
io.netty.bootstrap.ServerBootstrap; @@ -55,37 +53,21 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.BufferedWriter; import java.io.IOException; import java.lang.management.MemoryUsage; -import java.net.MalformedURLException; -import java.net.URI; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; import java.security.GeneralSecurityException; import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; import java.util.List; -import java.util.Objects; -import java.util.Properties; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import java.util.stream.Collectors; -import java.util.stream.Stream; /** The main entry point for model server. */ public class ModelServer { private static final Logger logger = LoggerFactory.getLogger(ModelServer.class); private static final Logger SERVER_METRIC = LoggerFactory.getLogger("server_metric"); - private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[?([^?]+?)]?=)?(.+)"); private ServerGroups serverGroups; private List futures = new ArrayList<>(2); @@ -204,11 +186,16 @@ public List start() pluginManager.loadPlugins(true); - initMultiNode(); - try { - initModelStore(); - } catch (BadWorkflowException | CompletionException e) { + ModelStore modelStore = ModelStore.getInstance(); + modelStore.initialize(); + + List workflows = modelStore.getWorkflows(); + + initMultiNode(workflows); + + loadModels(workflows); + } catch (BadWorkflowException | ModelException | CompletionException e) { throw new ServerStartupException( "Failed to initialize startup models and workflows", e); } @@ -276,11 +263,12 @@ public void stop() { serverGroups.reset(); } - private void initMultiNode() + private void 
initMultiNode(List workflows) throws GeneralSecurityException, IOException, InterruptedException, - ServerStartupException { + ServerStartupException, + ModelException { ClusterConfig cc = ClusterConfig.getInstance(); int clusterSize = cc.getClusterSize(); if (clusterSize > 1) { @@ -293,7 +281,12 @@ private void initMultiNode() ChannelFuture future = initializeServer(multiNodeConnector, serverGroup, workerGroup); - // start download model here + // download the models + for (Workflow workflow : workflows) { + for (WorkerPoolConfig model : workflow.getWpcs()) { + model.initialize(); + } + } cc.countDown(); logger.info("Waiting for all worker nodes ready ..."); @@ -369,116 +362,9 @@ private ChannelFuture initializeServer( return f; } - private void initModelStore() throws IOException, BadWorkflowException { - Set startupModels = ModelManager.getInstance().getStartupWorkflows(); - - String loadModels = configManager.getLoadModels(); - Path modelStore = configManager.getModelStore(); - if (loadModels == null || loadModels.isEmpty()) { - loadModels = "ALL"; - } - + private void loadModels(List workflows) { ModelManager modelManager = ModelManager.getInstance(); - Set urls = new HashSet<>(); - if ("NONE".equalsIgnoreCase(loadModels)) { - // to disable load all models from model store - return; - } else if ("ALL".equalsIgnoreCase(loadModels)) { - if (modelStore == null) { - logger.warn("Model store is not configured."); - return; - } - - if (Files.isDirectory(modelStore)) { - // contains only directory or archive files - boolean isMultiModelsDirectory = - Files.list(modelStore) - .filter(p -> !p.getFileName().toString().startsWith(".")) - .allMatch( - p -> - Files.isDirectory(p) - || FilenameUtils.isArchiveFile( - p.toString())); - - if (isMultiModelsDirectory) { - // Check folders to see if they can be models as well - try (Stream stream = Files.list(modelStore)) { - urls.addAll( - stream.map(this::mapModelUrl) - .filter(Objects::nonNull) - 
.collect(Collectors.toList())); - } - } else { - // Check if root model store folder contains a model - String url = mapModelUrl(modelStore); - if (url != null) { - urls.add(url); - } - } - } else { - logger.warn("Model store path is not found: {}", modelStore); - } - } else { - String[] modelsUrls = loadModels.split("[, ]+"); - urls.addAll(Arrays.asList(modelsUrls)); - } - - String huggingFaceModelId = Utils.getEnvOrSystemProperty("HF_MODEL_ID"); - if (huggingFaceModelId != null) { - urls.add(createHuggingFaceModel(huggingFaceModelId)); - } - - for (String url : urls) { - logger.info("Initializing model: {}", url); - Matcher matcher = MODEL_STORE_PATTERN.matcher(url); - if (!matcher.matches()) { - throw new AssertionError("Invalid model store url: " + url); - } - String endpoint = matcher.group(2); - String modelUrl = matcher.group(3); - String version = null; - String engineName = null; - String deviceMapping = null; - String modelName = null; - if (endpoint != null) { - String[] tokens = endpoint.split(":", -1); - modelName = tokens[0]; - if (tokens.length > 1) { - version = tokens[1].isEmpty() ? null : tokens[1]; - } - if (tokens.length > 2) { - engineName = tokens[2].isEmpty() ? 
null : tokens[2]; - } - if (tokens.length > 3) { - deviceMapping = tokens[3]; - } - } - - Workflow workflow; - URI uri = WorkflowDefinition.toWorkflowUri(modelUrl); - if (uri != null) { - workflow = WorkflowDefinition.parse(modelName, uri).toWorkflow(); - } else { - if (modelName == null) { - modelName = ModelInfo.inferModelNameFromUrl(modelUrl); - } - ModelInfo modelInfo = - new ModelInfo<>( - modelName, - modelUrl, - version, - engineName, - deviceMapping, - Input.class, - Output.class, - -1, - -1, - -1, - -1, - -1, - -1); - workflow = new Workflow(modelInfo); - } + for (Workflow workflow : workflows) { CompletableFuture f = modelManager.registerWorkflow(workflow); f.exceptionally( t -> { @@ -499,33 +385,6 @@ private void initModelStore() throws IOException, BadWorkflowException { if (configManager.waitModelLoading()) { f.join(); } - startupModels.add(modelName); - } - } - - String mapModelUrl(Path path) { - try { - if (!Files.exists(path) - || Files.isHidden(path) - || (!Files.isDirectory(path) - && !FilenameUtils.isArchiveFile(path.toString()))) { - return null; - } - - if (Files.list(path).findFirst().isEmpty()) { - return null; - } - - path = Utils.getNestedModelDir(path); - String url = path.toUri().toURL().toString(); - String modelName = ModelInfo.inferModelNameFromUrl(url); - logger.info("Found model {}={}", modelName, url); - return modelName + '=' + url; - } catch (MalformedURLException e) { - throw new AssertionError("Invalid path: " + path, e); - } catch (IOException e) { - logger.warn("Failed to access file: {}", path, e); - return null; } } @@ -535,43 +394,4 @@ private static void printHelp(String msg, Options options) { formatter.setWidth(120); formatter.printHelp(msg, options); } - - private String createHuggingFaceModel(String modelId) throws IOException { - if (modelId.startsWith("djl://") || modelId.startsWith("s3://")) { - return modelId; - } - Path path = Paths.get(modelId); - if (Files.exists(path)) { - // modelId point to a local file - 
return mapModelUrl(path); - } - - // TODO: Download the full model from HF - String hash = Utils.hash(modelId); - String downloadDir = Utils.getenv("SERVING_DOWNLOAD_DIR", null); - Path parent = downloadDir == null ? Utils.getCacheDir() : Paths.get(downloadDir); - Path huggingFaceModelDir = parent.resolve(hash); - String modelName = modelId.replaceAll("(\\W|^_)", "_"); - if (Files.exists(huggingFaceModelDir)) { - logger.warn("HuggingFace Model {} already exists, use random model name", modelId); - return modelName + '_' + RandomUtils.nextInt() + '=' + huggingFaceModelDir; - } - String huggingFaceModelRevision = Utils.getEnvOrSystemProperty("HF_REVISION"); - Properties huggingFaceProperties = new Properties(); - huggingFaceProperties.put("option.model_id", modelId); - if (huggingFaceModelRevision != null) { - huggingFaceProperties.put("option.revision", huggingFaceModelRevision); - } - String task = Utils.getEnvOrSystemProperty("HF_TASK"); - if (task != null) { - huggingFaceProperties.put("option.task", task); - } - Files.createDirectories(huggingFaceModelDir); - Path propertiesFile = huggingFaceModelDir.resolve("serving.properties"); - try (BufferedWriter writer = Files.newBufferedWriter(propertiesFile)) { - huggingFaceProperties.store(writer, null); - } - logger.debug("Created serving.properties for model at path {}", propertiesFile); - return modelName + '=' + huggingFaceModelDir; - } } diff --git a/serving/src/main/java/ai/djl/serving/http/ClusterRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/ClusterRequestHandler.java index 2f7fb0ff7..063d60db7 100644 --- a/serving/src/main/java/ai/djl/serving/http/ClusterRequestHandler.java +++ b/serving/src/main/java/ai/djl/serving/http/ClusterRequestHandler.java @@ -13,8 +13,14 @@ package ai.djl.serving.http; import ai.djl.ModelException; +import ai.djl.modality.Input; +import ai.djl.modality.Output; import ai.djl.serving.util.ClusterConfig; +import ai.djl.serving.util.ModelStore; import 
ai.djl.serving.util.NettyUtils; +import ai.djl.serving.wlm.ModelInfo; +import ai.djl.serving.wlm.WorkerPoolConfig; +import ai.djl.serving.workflow.Workflow; import ai.djl.util.Utils; import io.netty.channel.ChannelHandlerContext; @@ -30,6 +36,8 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** A class handling inbound HTTP requests for the cluster management API. */ public class ClusterRequestHandler extends HttpRequestHandler { @@ -65,6 +73,18 @@ protected void handleRequest( } NettyUtils.sendFile(ctx, file, false); return; + case "models": + ModelStore modelStore = ModelStore.getInstance(); + List workflows = modelStore.getWorkflows(); + Map map = new ConcurrentHashMap<>(); + for (Workflow workflow : workflows) { + for (WorkerPoolConfig wpc : workflow.getWpcs()) { + ModelInfo model = (ModelInfo) wpc; + map.put(model.getId(), model.getModelUrl()); + } + } + NettyUtils.sendJsonResponse(ctx, map); + return; case "status": List messages = decoder.parameters().get("message"); if (messages.size() != 1) { diff --git a/serving/src/main/java/ai/djl/serving/util/ModelStore.java b/serving/src/main/java/ai/djl/serving/util/ModelStore.java new file mode 100644 index 000000000..8f2304372 --- /dev/null +++ b/serving/src/main/java/ai/djl/serving/util/ModelStore.java @@ -0,0 +1,274 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.serving.util; + +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.repository.FilenameUtils; +import ai.djl.serving.ModelServer; +import ai.djl.serving.models.ModelManager; +import ai.djl.serving.wlm.ModelInfo; +import ai.djl.serving.workflow.BadWorkflowException; +import ai.djl.serving.workflow.Workflow; +import ai.djl.serving.workflow.WorkflowDefinition; +import ai.djl.util.RandomUtils; +import ai.djl.util.Utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** A class that represents the model server's model store. */ +public final class ModelStore { + + private static final Logger logger = LoggerFactory.getLogger(ModelServer.class); + private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[?([^?]+?)]?=)?(.+)"); + + private static final ModelStore INSTANCE = new ModelStore(); + + private List workflows; + + private ModelStore() { + workflows = new ArrayList<>(); + } + + /** + * Returns the {@code ModelStore} singleton instance. + * + * @return the {@code ModelStore} singleton instance + */ + public static ModelStore getInstance() { + return INSTANCE; + } + + /** + * Initializes the model store. 
+ * + * @throws IOException if it fails to read a model from the file system + * @throws BadWorkflowException if it fails to parse a workflow definition + */ + public void initialize() throws IOException, BadWorkflowException { + workflows.clear(); + Set startupModels = ModelManager.getInstance().getStartupWorkflows(); + ConfigManager configManager = ConfigManager.getInstance(); + + String loadModels = configManager.getLoadModels(); + Path modelStore = configManager.getModelStore(); + if (loadModels == null || loadModels.isEmpty()) { + loadModels = "ALL"; + } + + Set urls = new HashSet<>(); + if ("NONE".equalsIgnoreCase(loadModels)) { + // NONE disables loading of all models from the model store + return; + } else if ("ALL".equalsIgnoreCase(loadModels)) { + if (modelStore == null) { + logger.warn("Model store is not configured."); + return; + } + + if (Files.isDirectory(modelStore)) { + // true if every non-hidden entry is a directory or an archive file + boolean isMultiModelsDirectory; + try (Stream stream = Files.list(modelStore)) { + isMultiModelsDirectory = + stream.filter(p -> !p.getFileName().toString().startsWith(".")) + .allMatch( + p -> + Files.isDirectory(p) + || FilenameUtils.isArchiveFile( + p.toString())); + } + + if (isMultiModelsDirectory) { + // Check folders to see if they can be models as well + try (Stream stream = Files.list(modelStore)) { + urls.addAll( + stream.map(ModelStore::mapModelUrl) + .filter(Objects::nonNull) + .collect(Collectors.toList())); + } + } else { + // Check if root model store folder contains a model + String url = mapModelUrl(modelStore); + if (url != null) { + urls.add(url); + } + } + } else { + logger.warn("Model store path is not found: {}", modelStore); + } + } else { + String[] modelsUrls = loadModels.split("[, ]+"); + urls.addAll(Arrays.asList(modelsUrls)); + } + + String huggingFaceModelId = Utils.getEnvOrSystemProperty("HF_MODEL_ID"); + if (huggingFaceModelId != null) { + urls.add(createHuggingFaceModel(huggingFaceModelId)); + } + + for (String url : urls) { + 
logger.info("Initializing model: {}", url); + Matcher matcher = MODEL_STORE_PATTERN.matcher(url); + if (!matcher.matches()) { + throw new AssertionError("Invalid model store url: " + url); + } + String endpoint = matcher.group(2); + String modelUrl = matcher.group(3); + String version = null; + String engineName = null; + String deviceMapping = null; + String modelName = null; + if (endpoint != null) { + String[] tokens = endpoint.split(":", -1); + modelName = tokens[0]; + if (tokens.length > 1) { + version = tokens[1].isEmpty() ? null : tokens[1]; + } + if (tokens.length > 2) { + engineName = tokens[2].isEmpty() ? null : tokens[2]; + } + if (tokens.length > 3) { + deviceMapping = tokens[3]; + } + } + + URI uri = WorkflowDefinition.toWorkflowUri(modelUrl); + if (uri != null) { + workflows.add(WorkflowDefinition.parse(modelName, uri).toWorkflow()); + } else { + if (modelName == null) { + modelName = ModelInfo.inferModelNameFromUrl(modelUrl); + } + ModelInfo modelInfo = + new ModelInfo<>( + modelName, + modelUrl, + version, + engineName, + deviceMapping, + Input.class, + Output.class, + -1, + -1, + -1, + -1, + -1, + -1); + workflows.add(new Workflow(modelInfo)); + } + startupModels.add(modelName); + } + } + + /** + * Returns a list of workflows to be loaded on startup. + * + * @return a list of workflows to be loaded on startup + */ + public List getWorkflows() { + return workflows; + } + + /** + * Maps model directory to model url. 
+ * + * @param path the model directory + * @return the mapped model url, or {@code null} if the path cannot be used as a model + */ + public static String mapModelUrl(Path path) { + try { + if (!Files.exists(path) + || Files.isHidden(path) + || (!Files.isDirectory(path) + && !FilenameUtils.isArchiveFile(path.toString()))) { + return null; + } + try (Stream stream = Files.list(path)) { + if (stream.findFirst().isEmpty()) { + return null; + } + } + + path = Utils.getNestedModelDir(path); + String url = path.toUri().toURL().toString(); + String modelName = ModelInfo.inferModelNameFromUrl(url); + logger.info("Found model {}={}", modelName, url); + return modelName + '=' + url; + } catch (MalformedURLException e) { + throw new AssertionError("Invalid path: " + path, e); + } catch (IOException e) { + logger.warn("Failed to access file: {}", path, e); + return null; + } + } + + private String createHuggingFaceModel(String modelId) throws IOException { + if (modelId.startsWith("djl://") || modelId.startsWith("s3://")) { + return modelId; + } + Path path = Paths.get(modelId); + if (Files.exists(path)) { + // modelId points to a local file + return mapModelUrl(path); + } + + // TODO: Download the full model from HF + String hash = Utils.hash(modelId); + String downloadDir = Utils.getenv("SERVING_DOWNLOAD_DIR", null); + Path parent = downloadDir == null ? 
Utils.getCacheDir() : Paths.get(downloadDir); + Path huggingFaceModelDir = parent.resolve(hash); + String modelName = modelId.replaceAll("(\\W|^_)", "_"); + if (Files.exists(huggingFaceModelDir)) { + logger.warn("HuggingFace Model {} already exists, use random model name", modelId); + return modelName + '_' + RandomUtils.nextInt() + '=' + huggingFaceModelDir; + } + String huggingFaceModelRevision = Utils.getEnvOrSystemProperty("HF_REVISION"); + Properties huggingFaceProperties = new Properties(); + huggingFaceProperties.put("option.model_id", modelId); + if (huggingFaceModelRevision != null) { + huggingFaceProperties.put("option.revision", huggingFaceModelRevision); + } + String task = Utils.getEnvOrSystemProperty("HF_TASK"); + if (task != null) { + huggingFaceProperties.put("option.task", task); + } + Files.createDirectories(huggingFaceModelDir); + Path propertiesFile = huggingFaceModelDir.resolve("serving.properties"); + try (BufferedWriter writer = Files.newBufferedWriter(propertiesFile)) { + huggingFaceProperties.store(writer, null); + } + logger.debug("Created serving.properties for model at path {}", propertiesFile); + return modelName + '=' + huggingFaceModelDir; + } +} diff --git a/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/src/test/java/ai/djl/serving/ModelServerTest.java index 6f0aecaf8..06fe40c86 100644 --- a/serving/src/test/java/ai/djl/serving/ModelServerTest.java +++ b/serving/src/test/java/ai/djl/serving/ModelServerTest.java @@ -36,6 +36,7 @@ import ai.djl.serving.models.ModelManager; import ai.djl.serving.util.ConfigManager; import ai.djl.serving.util.Connector; +import ai.djl.serving.util.ModelStore; import ai.djl.serving.wlm.util.EventManager; import ai.djl.serving.wlm.util.ModelServerListenerAdapter; import ai.djl.util.JsonUtils; @@ -214,7 +215,7 @@ public void test() try { EventManager.getInstance().addListener(new Listener()); Path notModel = Paths.get("build/non-model"); - String url = server.mapModelUrl(notModel); // 
not a model dir + String url = ModelStore.mapModelUrl(notModel); // not a model dir assertNull(url); assertTrue(server.isRunning()); diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java index 92971e765..18df917e2 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java @@ -165,7 +165,7 @@ static void convertOnnxModel(ModelInfo info) throws IOException { modelId = repo.toString(); } String optimization = info.prop.getProperty("option.optimization"); - info.modelUrl = convertOnnx(modelId, optimization).toUri().toURL().toString(); + info.resolvedModelUrl = convertOnnx(modelId, optimization).toUri().toURL().toString(); } private static Path convertOnnx(String modelId, String optimization) throws IOException { diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index 5f5af488c..0e280b7a1 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -88,6 +88,7 @@ public final class ModelInfo extends WorkerPoolConfig { private boolean dynamicAdapters; transient Path modelDir; + transient String resolvedModelUrl; private transient String artifactName; transient Path downloadDir; @@ -231,7 +232,7 @@ public void load(Device device) throws ModelException, IOException { builder = Criteria.builder() .setTypes(inputClass, outputClass) - .optModelUrls(modelUrl) + .optModelUrls(resolvedModelUrl) .optModelName(modelName) .optEngine(engineName) .optFilters(filters) @@ -502,6 +503,7 @@ public void initialize() throws IOException, ModelException { if (initialize) { return; } + this.resolvedModelUrl = modelUrl; if (adapters == null) { adapters = new ConcurrentHashMap<>(); } @@ -694,6 +696,9 @@ private String inferEngine() throws ModelException { } String prefix = prop.getProperty("option.modelName", artifactName); + if (prefix == 
null) { + prefix = modelDir.toFile().getName(); + } if (Files.isRegularFile(modelDir.resolve("metadata.yaml"))) { eng = SageMakerUtils.inferSageMakerEngine(this); if (eng != null) { @@ -762,15 +767,15 @@ private boolean isPythonModel(String prefix) { } private void downloadModel() throws ModelException, IOException { - if (modelUrl.startsWith("s3://")) { - modelDir = downloadS3ToDownloadDir(modelUrl); - modelUrl = modelDir.toUri().toURL().toString(); + if (resolvedModelUrl.startsWith("s3://")) { + modelDir = downloadS3ToDownloadDir(resolvedModelUrl); + resolvedModelUrl = modelDir.toUri().toURL().toString(); return; } - Repository repository = Repository.newInstance("modelStore", modelUrl); + Repository repository = Repository.newInstance("modelStore", resolvedModelUrl); List mrls = repository.getResources(); if (mrls.isEmpty()) { - throw new ModelNotFoundException("Invalid model url: " + modelUrl); + throw new ModelNotFoundException("Invalid model url: " + resolvedModelUrl); } Artifact artifact = mrls.get(0).getDefaultArtifact(); @@ -1115,7 +1120,7 @@ void downloadS3() throws ModelException, IOException { this.downloadDir = downloadS3ToDownloadDir(modelId.trim()); } else if (modelId.startsWith("djl://")) { logger.info("{}: djl model zoo url found: {}", uid, modelId); - modelUrl = modelId.trim(); + resolvedModelUrl = modelId.trim(); // download real model from model zoo downloadModel(); }