diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 146c152aa1..a5e2a6ce00 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -67,6 +67,10 @@ public final class ConfigManager { private static final String TS_NETTY_CLIENT_THREADS = "netty_client_threads"; private static final String TS_JOB_QUEUE_SIZE = "job_queue_size"; private static final String TS_NUMBER_OF_GPU = "number_of_gpu"; + + // IPEX config option that can be set at config.properties + private static final String TS_IPEX_ENABLE = "ipex_enable"; + private static final String TS_ASYNC_LOGGING = "async_logging"; private static final String TS_CORS_ALLOWED_ORIGIN = "cors_allowed_origin"; private static final String TS_CORS_ALLOWED_METHODS = "cors_allowed_methods"; @@ -708,7 +712,7 @@ public HashMap getBackendConfiguration() { HashMap config = new HashMap<>(); // Append properties used by backend worker here config.put("TS_DECODE_INPUT_REQUEST", prop.getProperty(TS_DECODE_INPUT_REQUEST, "true")); - + config.put("TS_IPEX_ENABLE", prop.getProperty(TS_IPEX_ENABLE, "false")); return config; } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index e1e9407070..272fb14716 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -4,6 +4,7 @@ import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Scanner; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -48,13 +49,13 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup throw new WorkerInitializationException("Failed get TS home directory", e); } - String[] args = new String[6]; - args[0] = EnvironmentUtils.getPythonRunTime(model); - args[1] = new File(workingDir, "ts/model_service_worker.py").getAbsolutePath(); - args[2] = "--sock-type"; - args[3] = connector.getSocketType(); - args[4] = connector.isUds() ? "--sock-name" : "--port"; - args[5] = connector.getSocketPath(); + ArrayList argl = new ArrayList(); + argl.add(EnvironmentUtils.getPythonRunTime(model)); + argl.add(new File(workingDir, "ts/model_service_worker.py").getAbsolutePath()); + argl.add("--sock-type"); + argl.add(connector.getSocketType()); + argl.add(connector.isUds() ? "--sock-name" : "--port"); + argl.add(connector.getSocketPath()); String[] envp = EnvironmentUtils.getEnvString( @@ -65,6 +66,9 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup try { latch = new CountDownLatch(1); + String[] args = argl.toArray(new String[argl.size()]); + logger.debug("Worker cmdline: {}", argl.toString()); + synchronized (this) { process = Runtime.getRuntime().exec(args, envp, modelPath); diff --git a/ts/arg_parser.py b/ts/arg_parser.py index 87931f24db..97a59c9d28 100644 --- a/ts/arg_parser.py +++ b/ts/arg_parser.py @@ -55,8 +55,7 @@ def ts_parser(): action='store_true') parser.add_argument('--plugins-path', '--ppath', dest='plugins_path', - help='plugin jars to be included in torchserve class path', - ) + help='plugin jars to be included in torchserve class path') return parser diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 199b4a8894..c3d3e59543 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -2,17 +2,25 @@ Base default handler to load torchscript or eager mode [state_dict] models Also, provides handle method per torch serve custom model specification """ + import abc import logging import os import importlib.util import time import torch - from ..utils.util import list_classes_from_module, load_label_mapping + logger = logging.getLogger(__name__) +ipex_enabled = False +if os.environ.get("TS_IPEX_ENABLE", "false") == "true": + try: + import intel_extension_for_pytorch as ipex + ipex_enabled = True + except ImportError as error: + logger.warning("IPEX was not installed. Please install IPEX if wanted.") class BaseHandler(abc.ABC): """ @@ -33,7 +41,7 @@ def __init__(self): def initialize(self, context): """Initialize function loads the model.pt file and initialized the model object. - First try to load torchscript else load eager mode state_dict based model. + First try to load torchscript else load eager mode state_dict based model. Args: context (context): It is a JSON Object containing information @@ -44,7 +52,8 @@ def initialize(self, context): """ properties = context.system_properties - self.map_location = "cuda" if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu" + self.map_location = "cuda" if torch.cuda.is_available( + ) and properties.get("gpu_id") is not None else "cpu" self.device = torch.device( self.map_location + ":" + str(properties.get("gpu_id")) if torch.cuda.is_available() and properties.get("gpu_id") is not None @@ -63,7 +72,8 @@ def initialize(self, context): if model_file: logger.debug("Loading eager model") - self.model = self._load_pickled_model(model_dir, model_file, model_pt_path) + self.model = self._load_pickled_model( + model_dir, model_file, model_pt_path) self.model.to(self.device) else: logger.debug("Loading torchscript model") @@ -73,6 +83,9 @@ def initialize(self, context): self.model = self._load_torchscript_model(model_pt_path) self.model.eval() + if ipex_enabled: + self.model = self.model.to(memory_format=torch.channels_last) + self.model = ipex.optimize(self.model) logger.debug('Model file %s loaded successfully', model_pt_path) @@ -203,7 +216,8 @@ def handle(self, data, context): output = self.explain_handle(data_preprocess, data) stop_time = time.time() - metrics.add_time('HandlerTime', round((stop_time - start_time) * 1000, 2), None, 'ms') + metrics.add_time('HandlerTime', round( + (stop_time - start_time) * 1000, 2), None, 'ms') return output def explain_handle(self, data_preprocess, raw_data): diff --git a/ts/torch_handler/unit_tests/test_mnist_kf.py b/ts/torch_handler/unit_tests/test_mnist_kf.py index 53bb36b04e..8f760d260e 100644 --- a/ts/torch_handler/unit_tests/test_mnist_kf.py +++ b/ts/torch_handler/unit_tests/test_mnist_kf.py @@ -18,8 +18,8 @@ image_processing = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) ]) @pytest.fixture() diff --git a/ts/torch_handler/vision_handler.py b/ts/torch_handler/vision_handler.py index f225ba765b..4a8dcffecd 100644 --- a/ts/torch_handler/vision_handler.py +++ b/ts/torch_handler/vision_handler.py @@ -21,7 +21,7 @@ def initialize(self, context): self.ig = IntegratedGradients(self.model) self.initialized = True properties = context.system_properties - if not properties.get("limit_max_image_pixels") : + if not properties.get("limit_max_image_pixels"): Image.MAX_IMAGE_PIXELS = None def preprocess(self, data):