From 5dd996b91216e2544b81e149247e203a8e64e83b Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 24 May 2021 09:13:37 -0700 Subject: [PATCH] [serving] allow load model with specified engine Change-Id: I0c4113a42dc30e31b334f6bb7ef1e35215a88b6d --- .../main/java/ai/djl/serving/ModelServer.java | 1 + .../serving/http/InferenceRequestHandler.java | 2 ++ .../http/ManagementRequestHandler.java | 23 +++++++++++-------- .../java/ai/djl/serving/wlm/ModelManager.java | 3 +++ 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/serving/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/serving/src/main/java/ai/djl/serving/ModelServer.java index 60352bb0a76..84ea4444618 100644 --- a/serving/serving/src/main/java/ai/djl/serving/ModelServer.java +++ b/serving/serving/src/main/java/ai/djl/serving/ModelServer.java @@ -303,6 +303,7 @@ private void initModelStore() throws IOException { modelManager.registerModel( ModelInfo.inferModelNameFromUrl(url), url, + null, configManager.getBatchSize(), configManager.getMaxBatchDelay(), configManager.getMaxIdleTime()); diff --git a/serving/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java b/serving/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java index 2269e65e580..6845ce9a79c 100644 --- a/serving/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java +++ b/serving/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java @@ -146,6 +146,7 @@ private void predict( throw new ModelNotFoundException("Permission denied: " + modelUrl); } } + String engineName = input.getProperty("engine_name", null); logger.info("Loading model {} from: {}", modelName, modelUrl); @@ -153,6 +154,7 @@ private void predict( .registerModel( modelName, modelUrl, + engineName, ConfigManager.getInstance().getBatchSize(), ConfigManager.getInstance().getMaxBatchDelay(), ConfigManager.getInstance().getMaxIdleTime()) diff --git 
a/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java b/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java index ad338102a8d..725ac3f9b6d 100644 --- a/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java +++ b/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java @@ -35,23 +35,25 @@ */ public class ManagementRequestHandler extends HttpRequestHandler { - /** HTTP Paramater "synchronous". */ + /** HTTP Parameter "synchronous". */ private static final String SYNCHRONOUS_PARAMETER = "synchronous"; - /** HTTP Paramater "initial_workers". */ + /** HTTP Parameter "initial_workers". */ private static final String INITIAL_WORKERS_PARAMETER = "initial_workers"; - /** HTTP Paramater "url". */ + /** HTTP Parameter "url". */ private static final String URL_PARAMETER = "url"; - /** HTTP Paramater "batch_size". */ + /** HTTP Parameter "batch_size". */ private static final String BATCH_SIZE_PARAMETER = "batch_size"; - /** HTTP Paramater "model_name". */ + /** HTTP Parameter "model_name". */ private static final String MODEL_NAME_PARAMETER = "model_name"; - /** HTTP Paramater "max_batch_delay". */ + /** HTTP Parameter "engine_name". */ + private static final String ENGINE_NAME_PARAMETER = "engine_name"; + /** HTTP Parameter "max_batch_delay". */ private static final String MAX_BATCH_DELAY_PARAMETER = "max_batch_delay"; - /** HTTP Paramater "max_idle_time". */ + /** HTTP Parameter "max_idle_time". */ private static final String MAX_IDLE_TIME__PARAMETER = "max_idle_time"; - /** HTTP Paramater "max_worker". */ + /** HTTP Parameter "max_worker". */ private static final String MAX_WORKER_PARAMETER = "max_worker"; - /** HTTP Paramater "min_worker". */ + /** HTTP Parameter "min_worker". 
*/ private static final String MIN_WORKER_PARAMETER = "min_worker"; private static final Pattern PATTERN = Pattern.compile("^/models([/?].*)?"); @@ -147,6 +149,7 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec if (modelName == null || modelName.isEmpty()) { modelName = ModelInfo.inferModelNameFromUrl(modelUrl); } + String engineName = NettyUtils.getParameter(decoder, ENGINE_NAME_PARAMETER, null); int batchSize = NettyUtils.getIntParameter(decoder, BATCH_SIZE_PARAMETER, 1); int maxBatchDelay = NettyUtils.getIntParameter(decoder, MAX_BATCH_DELAY_PARAMETER, 100); int maxIdleTime = NettyUtils.getIntParameter(decoder, MAX_IDLE_TIME__PARAMETER, 60); @@ -159,7 +162,7 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec final ModelManager modelManager = ModelManager.getInstance(); CompletableFuture future = modelManager.registerModel( - modelName, modelUrl, batchSize, maxBatchDelay, maxIdleTime); + modelName, modelUrl, engineName, batchSize, maxBatchDelay, maxIdleTime); CompletableFuture f = future.thenAccept( modelInfo -> diff --git a/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java b/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java index 9b2e5f2fa49..d58a4d7f05f 100644 --- a/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java +++ b/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java @@ -75,6 +75,7 @@ public static ModelManager getInstance() { * * @param modelName the name of the model for HTTP endpoint * @param modelUrl the model url + * @param engineName the engine to load the model * @param batchSize the batch size * @param maxBatchDelay the maximum delay for batching * @param maxIdleTime the maximum idle time of the worker threads before scaling down. 
@@ -83,6 +84,7 @@ public static ModelManager getInstance() { public CompletableFuture registerModel( final String modelName, final String modelUrl, + final String engineName, final int batchSize, final int maxBatchDelay, final int maxIdleTime) { @@ -93,6 +95,7 @@ public CompletableFuture registerModel( Criteria.builder() .setTypes(Input.class, Output.class) .optModelUrls(modelUrl) + .optEngine(engineName) .build(); ZooModel model = ModelZoo.loadModel(criteria); ModelInfo modelInfo =