Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add config option for ipex (https://github.com/min-jean-cho/serve/blob/ipex_enable/examples/IPEX/README.md) #1319

Merged
merged 3 commits into from
Nov 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ public final class ConfigManager {
private static final String TS_NETTY_CLIENT_THREADS = "netty_client_threads";
private static final String TS_JOB_QUEUE_SIZE = "job_queue_size";
private static final String TS_NUMBER_OF_GPU = "number_of_gpu";

// IPEX config option that can be set at config.properties
private static final String TS_IPEX_ENABLE = "ipex_enable";

private static final String TS_ASYNC_LOGGING = "async_logging";
private static final String TS_CORS_ALLOWED_ORIGIN = "cors_allowed_origin";
private static final String TS_CORS_ALLOWED_METHODS = "cors_allowed_methods";
Expand Down Expand Up @@ -708,7 +712,7 @@ public HashMap<String, String> getBackendConfiguration() {
HashMap<String, String> config = new HashMap<>();
// Append properties used by backend worker here
config.put("TS_DECODE_INPUT_REQUEST", prop.getProperty(TS_DECODE_INPUT_REQUEST, "true"));

config.put("TS_IPEX_ENABLE", prop.getProperty(TS_IPEX_ENABLE, "false"));
return config;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Scanner;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -48,13 +49,13 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup
throw new WorkerInitializationException("Failed get TS home directory", e);
}

String[] args = new String[6];
args[0] = EnvironmentUtils.getPythonRunTime(model);
args[1] = new File(workingDir, "ts/model_service_worker.py").getAbsolutePath();
args[2] = "--sock-type";
args[3] = connector.getSocketType();
args[4] = connector.isUds() ? "--sock-name" : "--port";
args[5] = connector.getSocketPath();
ArrayList<String> argl = new ArrayList<String>();
argl.add(EnvironmentUtils.getPythonRunTime(model));
argl.add(new File(workingDir, "ts/model_service_worker.py").getAbsolutePath());
argl.add("--sock-type");
argl.add(connector.getSocketType());
argl.add(connector.isUds() ? "--sock-name" : "--port");
argl.add(connector.getSocketPath());

String[] envp =
EnvironmentUtils.getEnvString(
Expand All @@ -65,6 +66,9 @@ public void startWorker(int port) throws WorkerInitializationException, Interrup
try {
latch = new CountDownLatch(1);

String[] args = argl.toArray(new String[argl.size()]);
logger.debug("Worker cmdline: {}", argl.toString());

synchronized (this) {
process = Runtime.getRuntime().exec(args, envp, modelPath);

Expand Down
3 changes: 1 addition & 2 deletions ts/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def ts_parser():
action='store_true')
parser.add_argument('--plugins-path', '--ppath',
dest='plugins_path',
help='plugin jars to be included in torchserve class path',
)
help='plugin jars to be included in torchserve class path')

return parser

Expand Down
24 changes: 19 additions & 5 deletions ts/torch_handler/base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,25 @@
Base default handler to load torchscript or eager mode [state_dict] models
Also, provides handle method per torch serve custom model specification
"""

import abc
import logging
import os
import importlib.util
import time
import torch

from ..utils.util import list_classes_from_module, load_label_mapping


logger = logging.getLogger(__name__)

# Enable Intel Extension for PyTorch (IPEX) optimizations only when the
# frontend propagated TS_IPEX_ENABLE=true into the worker environment
# (set via `ipex_enable` in config.properties) AND the package imports.
ipex_enabled = False
if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
    try:
        import intel_extension_for_pytorch as ipex
        ipex_enabled = True
    except ImportError:
        # Best effort: warn and fall back to the unoptimized model path.
        logger.warning("IPEX was not installed. Please install IPEX if wanted.")

class BaseHandler(abc.ABC):
"""
Expand All @@ -33,7 +41,7 @@ def __init__(self):

def initialize(self, context):
"""Initialize function loads the model.pt file and initialized the model object.
First try to load torchscript else load eager mode state_dict based model.
First try to load torchscript else load eager mode state_dict based model.

Args:
context (context): It is a JSON Object containing information
Expand All @@ -44,7 +52,8 @@ def initialize(self, context):

"""
properties = context.system_properties
self.map_location = "cuda" if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu"
self.map_location = "cuda" if torch.cuda.is_available(
) and properties.get("gpu_id") is not None else "cpu"
self.device = torch.device(
self.map_location + ":" + str(properties.get("gpu_id"))
if torch.cuda.is_available() and properties.get("gpu_id") is not None
Expand All @@ -63,7 +72,8 @@ def initialize(self, context):

if model_file:
logger.debug("Loading eager model")
self.model = self._load_pickled_model(model_dir, model_file, model_pt_path)
self.model = self._load_pickled_model(
model_dir, model_file, model_pt_path)
self.model.to(self.device)
else:
logger.debug("Loading torchscript model")
Expand All @@ -73,6 +83,9 @@ def initialize(self, context):
self.model = self._load_torchscript_model(model_pt_path)

self.model.eval()
if ipex_enabled:
self.model = self.model.to(memory_format=torch.channels_last)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this apply only to vision models?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but it is benign (a no-op) for non-CNN models.

self.model = ipex.optimize(self.model)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@min-jean-cho does it choose optimization level O1 by default?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, ipex will choose level 'O1' by default.


logger.debug('Model file %s loaded successfully', model_pt_path)

Expand Down Expand Up @@ -203,7 +216,8 @@ def handle(self, data, context):
output = self.explain_handle(data_preprocess, data)

stop_time = time.time()
metrics.add_time('HandlerTime', round((stop_time - start_time) * 1000, 2), None, 'ms')
metrics.add_time('HandlerTime', round(
(stop_time - start_time) * 1000, 2), None, 'ms')
return output

def explain_handle(self, data_preprocess, raw_data):
Expand Down
4 changes: 2 additions & 2 deletions ts/torch_handler/unit_tests/test_mnist_kf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@


# MNIST preprocessing pipeline: convert PIL image / ndarray to a float
# tensor, then normalize with the standard MNIST mean (0.1307) and
# std (0.3081). (Diff-scrape duplication of the old indented lines removed.)
image_processing = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

@pytest.fixture()
Expand Down
2 changes: 1 addition & 1 deletion ts/torch_handler/vision_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def initialize(self, context):
self.ig = IntegratedGradients(self.model)
self.initialized = True
properties = context.system_properties
if not properties.get("limit_max_image_pixels") :
if not properties.get("limit_max_image_pixels"):
Image.MAX_IMAGE_PIXELS = None

def preprocess(self, data):
Expand Down