diff --git a/examples/intel_extension_for_pytorch/README.md b/examples/intel_extension_for_pytorch/README.md
new file mode 100644
index 0000000000..6ef003f623
--- /dev/null
+++ b/examples/intel_extension_for_pytorch/README.md
@@ -0,0 +1,126 @@
+# TorchServe with Intel® Extension for PyTorch*
+
+TorchServe can be used with Intel® Extension for PyTorch* (IPEX) to give a performance boost on Intel hardware.
+Here we show how to use TorchServe with IPEX.
+
+## Contents of this Document
+* [Install Intel Extension for PyTorch](#install-intel-extension-for-pytorch)
+* [Serving model with Intel Extension for PyTorch](#serving-model-with-intel-extension-for-pytorch)
+* [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex)
+* [Benchmarking with Launcher](#benchmarking-with-launcher)
+
+
+## Install Intel Extension for PyTorch
+Refer to the documentation [here](https://github.com/intel/intel-extension-for-pytorch#installation).
+
+## Serving model with Intel Extension for PyTorch
+After installation, all that needs to be done to use TorchServe with IPEX is to enable it in `config.properties`.
+```
+ipex_enable=true
+```
+Once IPEX is enabled, deploying an IPEX-exported model follows the same procedure shown [here](https://pytorch.org/serve/use_cases.html). TorchServe with IPEX can deploy any model and run inference.
+
+## Creating and Exporting INT8 model for IPEX
+Intel Extension for PyTorch supports both eager mode and TorchScript mode. In this section, we show how to create and deploy an INT8 model with IPEX.
+
+### 1. Creating a serialized file
+First, create a `.pt` serialized file using IPEX INT8 inference. Here we show two examples, with BERT and ResNet50.
+
+#### BERT
+
+```
+import intel_extension_for_pytorch as ipex
+from transformers import AutoModelForSequenceClassification, AutoConfig
+import transformers
+from datasets import load_dataset
+import torch
+
+# load the model
+config = AutoConfig.from_pretrained(
+    "bert-base-uncased", return_dict=False, torchscript=True, num_labels=2)
+model = AutoModelForSequenceClassification.from_pretrained(
+    "bert-base-uncased", config=config)
+model = model.eval()
+
+max_length = 384
+dummy_tensor = torch.ones((1, max_length), dtype=torch.long)
+jit_inputs = (dummy_tensor, dummy_tensor, dummy_tensor)
+conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_affine)
+
+
+# calibration
+with torch.no_grad():
+    for i in range(100):
+        with ipex.quantization.calibrate(conf):
+            model(dummy_tensor, dummy_tensor, dummy_tensor)
+
+# optionally save the configuration for later use
+conf.save('model_conf.json', default_recipe=True)
+
+# conversion
+model = ipex.quantization.convert(model, conf, jit_inputs)
+
+# save to .pt
+torch.jit.save(model, 'bert_int8_jit.pt')
+```
+
+#### ResNet50
+
+```
+import intel_extension_for_pytorch as ipex
+import torchvision.models as models
+import torch
+import torch.fx.experimental.optimization as optimization
+from copy import deepcopy
+
+
+model = models.resnet50(pretrained=True)
+model = model.eval()
+
+dummy_tensor = torch.randn(1, 3, 224, 224).contiguous(memory_format=torch.channels_last)
+jit_inputs = (dummy_tensor,)
+conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_symmetric)
+
+# calibration
+with torch.no_grad():
+    for i in range(100):
+        with ipex.quantization.calibrate(conf):
+            model(dummy_tensor)
+
+# conversion and save to .pt
+model = ipex.quantization.convert(model, conf, jit_inputs)
+torch.jit.save(model, 'rn50_int8_jit.pt')
+```
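+
+Optionally, the exported TorchScript file can be sanity-checked before archiving it by loading it back and running one dummy forward pass. The following is a minimal sketch, assuming the `rn50_int8_jit.pt` file produced above:
+
+```
+import torch
+import intel_extension_for_pytorch as ipex  # import so any IPEX custom ops in the module are registered
+
+# load the serialized INT8 TorchScript module and run one dummy inference
+loaded = torch.jit.load('rn50_int8_jit.pt')
+loaded.eval()
+with torch.no_grad():
+    out = loaded(torch.randn(1, 3, 224, 224).contiguous(memory_format=torch.channels_last))
+print(out.shape)  # for ResNet50 this should be torch.Size([1, 1000])
+```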
+### 2. Creating a Model Archive
+Once the serialized file (`.pt`) is created, it can be used with `torch-model-archiver` as usual. Use the following command to package the model.
+```
+torch-model-archiver --model-name rn50_ipex_int8 --version 1.0 --serialized-file rn50_int8_jit.pt --handler image_classifier
+```
+### 3. Start TorchServe to serve the model
+Make sure to set `ipex_enable=true` in `config.properties`. Use the following command to start TorchServe with IPEX.
+```
+torchserve --start --ncs --model-store model_store --ts-config config.properties
+```
+
+### 4. Registering and Deploying model
+Registering and deploying the model follows the same steps shown [here](https://pytorch.org/serve/use_cases.html).
+
+## Benchmarking with Launcher
+The `intel_extension_for_pytorch.cpu.launch` launcher can be used with the official TorchServe [benchmark](https://github.com/pytorch/serve/tree/master/benchmarks) to launch the server and benchmark requests with an optimal configuration on Intel hardware.
+
+In this section, we provide an example of using the launcher to benchmark with a single instance (worker) on a single socket, using all physical cores on that socket. This avoids thread oversubscription while still using all available resources.
+
+### 1. Launcher configuration
+All that needs to be done to use TorchServe with the launcher is to set its configuration in `config.properties` in the benchmark directory. Note that the number of instances, `--ninstance`, is 1 by default. `--ncore_per_instance` can be set appropriately by checking the number of cores per socket with `lscpu`.
+
+For a full list of tunable launcher configurations, refer to the documentation [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md).
+
+```
+ipex_enable=true
+cpu_launcher_enable=true
+cpu_launcher_args=--ncore_per_instance 28 --socket_id 0
+```
+
+### 2. Benchmarking with Launcher
+The rest of the benchmarking steps are the same as shown [here](https://github.com/pytorch/serve/tree/master/benchmarks).
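+
+Before collecting benchmark numbers, it can help to smoke-test the launcher-enabled server with a single request. The following is a minimal sketch using the standard TorchServe management and inference APIs, assuming the `rn50_ipex_int8.mar` archive from step 2 is in `model_store`, TorchServe is running with the configuration above on the default ports, and `kitten.jpg` is a placeholder for any test image:
+
+```
+# register the model with one initial worker via the management API
+curl -X POST "http://localhost:8081/models?url=rn50_ipex_int8.mar&initial_workers=1&synchronous=true"
+
+# send one test image to the inference API
+curl http://localhost:8080/predictions/rn50_ipex_int8 -T kitten.jpg
+```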
+CPU usage during benchmarking is shown below.
+
+![sample_launcher](https://user-images.githubusercontent.com/93151422/143912711-cacbd38b-4be9-430a-810b-e5d3a9be9732.gif)
diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java
index a5e2a6ce00..7c29879e59 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java
@@ -70,6 +70,8 @@ public final class ConfigManager {
     // IPEX config option that can be set at config.properties
     private static final String TS_IPEX_ENABLE = "ipex_enable";
+    private static final String TS_CPU_LAUNCHER_ENABLE = "cpu_launcher_enable";
+    private static final String TS_CPU_LAUNCHER_ARGS = "cpu_launcher_args";
     private static final String TS_ASYNC_LOGGING = "async_logging";
     private static final String TS_CORS_ALLOWED_ORIGIN = "cors_allowed_origin";
@@ -339,6 +341,14 @@ public boolean isMetricApiEnable() {
         return Boolean.parseBoolean(getProperty(TS_ENABLE_METRICS_API, "true"));
     }
 
+    public boolean isCPULauncherEnabled() {
+        return Boolean.parseBoolean(getProperty(TS_CPU_LAUNCHER_ENABLE, "false"));
+    }
+
+    public String getCPULauncherArgs() {
+        return getProperty(TS_CPU_LAUNCHER_ARGS, null);
+    }
+
     public int getNettyThreads() {
         return getIntProperty(TS_NUMBER_OF_NETTY_THREADS, 0);
     }
diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java
index 272fb14716..22596757be 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java
@@ -51,6 +51,20 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup
         ArrayList<String> argl = new ArrayList<String>();
         argl.add(EnvironmentUtils.getPythonRunTime(model));
+        if (configManager.isCPULauncherEnabled()) {
+            argl.add("-m");
+            argl.add("intel_extension_for_pytorch.cpu.launch");
+            argl.add("--ninstance");
+            argl.add("1");
+            String largs = configManager.getCPULauncherArgs();
+            if (largs != null && largs.length() > 1) {
+                String[] argarray = largs.split(" ");
+                for (int i = 0; i < argarray.length; i++) {
+                    argl.add(argarray[i]);
+                }
+            }
+        }
+
         argl.add(new File(workingDir, "ts/model_service_worker.py").getAbsolutePath());
         argl.add("--sock-type");
         argl.add(connector.getSocketType());
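
With `cpu_launcher_enable=true` and the `cpu_launcher_args` from the README example, the worker command assembled by `WorkerLifeCycle.startWorker` ends up roughly as sketched below; the Python runtime, the worker script path, and the socket arguments are placeholders that depend on the installation and connector settings:

```
python -m intel_extension_for_pytorch.cpu.launch --ninstance 1 --ncore_per_instance 28 --socket_id 0 /path/to/serve/ts/model_service_worker.py --sock-type unix ...
```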