Commit bfcd7be
[lmi][lmi-dist] Add pipeline-parallel support to lmi-dist, and use_passive_workers by default (#2445)

Co-authored-by: Siddharth Venkatesan <siddhave@amazon.com>
davidthomas426 and siddvenk authored Oct 17, 2024
1 parent e95e407 commit bfcd7be
Showing 10 changed files with 94 additions and 49 deletions.
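Note: with these changes, the total worker world size becomes tensor_parallel_degree × pipeline_parallel_degree. A minimal serving.properties sketch exercising the options this commit touches (the values and the engine line are illustrative assumptions, not taken from the commit itself; the option names match the properties added below):

    engine=MPI
    option.rolling_batch=lmi-dist
    option.tensor_parallel_degree=4
    option.pipeline_parallel_degree=2
    # Passive workers are now the default in MPI mode; set to false for the old behavior.
    option.use_passive_workers=true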
@@ -64,6 +64,7 @@ class LmiDistRbProperties(Properties):
     enable_prefix_caching: Optional[bool] = False
     disable_sliding_window: Optional[bool] = False
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
+    use_passive_workers: Optional[bool] = True

     @model_validator(mode='after')
     def validate_mpi(self):
@@ -37,6 +37,7 @@ class VllmRbProperties(Properties):
     load_format: Optional[str] = "auto"
     quantize: Optional[VllmQuantizeMethods] = None
     tensor_parallel_degree: Optional[int] = None
+    pipeline_parallel_degree: int = 1
     max_rolling_batch_prefill_tokens: Optional[int] = None
     # Adjustable prefix model length for certain 32k or longer model
     max_model_len: Optional[int] = None
@@ -111,3 +112,11 @@ def validate_speculative_model(self):
                 "Speculative decoding requires usage of the V2 block manager. Enable it with option.use_v2_block_manager=true."
             )
         return self
+
+    @model_validator(mode='after')
+    def validate_pipeline_parallel(self):
+        if self.pipeline_parallel_degree != 1:
+            raise ValueError(
+                "Pipeline parallelism is not supported in vLLM's LLMEngine used in rolling_batch implementation"
+            )
+        return self
@@ -86,6 +86,7 @@ def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
             enable_prefix_caching=self.lmi_dist_config.enable_prefix_caching,
             disable_sliding_window=self.lmi_dist_config.disable_sliding_window,
             limit_mm_per_prompt=self.lmi_dist_config.limit_mm_per_prompt,
+            use_passive_workers=self.lmi_dist_config.use_passive_workers,
             **engine_kwargs)

         kwargs = {}
@@ -245,6 +245,7 @@ def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs:
     return EngineArgs(
         model=config.model_id_or_path,
         tensor_parallel_size=config.tensor_parallel_degree,
+        pipeline_parallel_size=config.pipeline_parallel_degree,
         dtype=DTYPE_MAPPER[config.dtype],
         seed=0,
         max_model_len=config.max_model_len,
39 changes: 23 additions & 16 deletions engines/python/src/main/java/ai/djl/python/engine/Connection.java
@@ -127,21 +127,21 @@ static String[] getPythonStartCmd(
         int clusterSize = PyEnv.getClusterSize();
         int tensorParallelDegree = pyEnv.getTensorParallelDegree();
         int pipelineParallelDegree = pyEnv.getPipelineParallelDegree();
+        int worldSize = tensorParallelDegree * pipelineParallelDegree;
         String entryPoint = pyEnv.getEntryPoint();
         String recommendedEntryPoint = pyEnv.getRecommendedEntryPoint();
         String pythonLogLevel = pyEnv.getPythonLogLevel();

         if (PyEnv.isMultiNode()) {
-            int worldSize = tensorParallelDegree * pipelineParallelDegree;
-            if (tensorParallelDegree * pipelineParallelDegree % clusterSize != 0) {
+            if (worldSize % clusterSize != 0) {
                 throw new IllegalArgumentException(
                         "Error: Cannot use cluster size: "
                                 + clusterSize
                                 + "for world size (number of total GPUs): "
                                 + worldSize);
             }

-            int localSize = (tensorParallelDegree * pipelineParallelDegree) / clusterSize;
+            int localSize = worldSize / clusterSize;

             String cudaDevices = getVisibleDevices(workerId, localSize);
             logger.info("Set before mpirun CUDA_VISIBLE_DEVICES={}", cudaDevices);
@@ -195,6 +195,7 @@ static String[] getPythonStartCmd(
             args[31] = model.getModelPath().toAbsolutePath().toString();
             args[32] = "--entry-point";
             args[33] = entryPoint == null ? "" : entryPoint;
+            // TODO: Use mix of Unix and TCP sockets for local/remote processes
             args[34] = "--sock-type";
             args[35] = "tcp";
             args[36] = "--sock-name";
@@ -213,12 +214,12 @@ static String[] getPythonStartCmd(
             args[49] = pythonLogLevel;
             return args;
         } else if (pyEnv.isMpiMode()) {
-            String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree);
+            String cudaDevices = getVisibleDevices(workerId, worldSize);
             logger.info("Set CUDA_VISIBLE_DEVICES={}", cudaDevices);
-            String[] args = new String[44];
+            String[] args = new String[46];
             args[0] = "mpirun";
             args[1] = "-np";
-            args[2] = String.valueOf(tensorParallelDegree);
+            args[2] = String.valueOf(worldSize);
             args[3] = "--allow-run-as-root";
             args[4] = "--bind-to";
             args[5] = "none";
@@ -256,21 +257,27 @@ static String[] getPythonStartCmd(
             args[37] = getSocketPath(port);
             args[38] = "--tensor-parallel-degree";
             args[39] = String.valueOf(tensorParallelDegree);
-            args[40] = "--recommended-entry-point";
-            args[41] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
-            args[42] = "--log-level";
-            args[43] = pythonLogLevel;
+            args[40] = "--pipeline-parallel-degree";
+            args[41] = String.valueOf(pipelineParallelDegree);
+            args[42] = "--recommended-entry-point";
+            args[43] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
+            args[44] = "--log-level";
+            args[45] = pythonLogLevel;
             return args;
         }

-        // TP settings
-        if (tensorParallelDegree > 0 && device.isGpu()) {
-            String cudaDevices = getVisibleDevices(deviceId, tensorParallelDegree);
+        // TP/PP settings
+        if (worldSize > 0 && device.isGpu()) {
+            String cudaDevices = getVisibleDevices(deviceId, worldSize);
             deviceId = 0; // re-map logic device to 0
             pyEnv.addEnv("CUDA_VISIBLE_DEVICES", cudaDevices);
             logger.info("Set CUDA_VISIBLE_DEVICES={}", cudaDevices);
         }
         if ("nc".equals(device.getDeviceType())) {
+            if (pipelineParallelDegree > 1) {
+                throw new IllegalArgumentException(
+                        "Error: Neuron does not currently support pipeline parallel degree > 1");
+            }
             String visibleCores = getNeuronVisibleCores(deviceId, tensorParallelDegree);
             // TODO: re-map logic device once neuron fixed bug
             pyEnv.addEnv("NEURON_RT_VISIBLE_CORES", visibleCores);
@@ -303,19 +310,19 @@ static String[] getPythonStartCmd(
         return args;
     }

-    private static String getVisibleDevices(int deviceId, int tensorParallelDegree) {
+    private static String getVisibleDevices(int deviceId, int localDevicesPerWorker) {
         StringBuilder sb = new StringBuilder(20);
         // CUDA_VISIBLE_DEVICES=0,2,3,7 TP2
         // -> 0,2 and 3,7
         if (Utils.getenv("CUDA_VISIBLE_DEVICES") != null) {
             String[] devices = Utils.getenv("CUDA_VISIBLE_DEVICES").split(",");
             sb.append(devices[deviceId]);
-            for (int i = 1; i < tensorParallelDegree; ++i) {
+            for (int i = 1; i < localDevicesPerWorker; ++i) {
                 sb.append(',').append(devices[deviceId + i]);
             }
         } else {
             sb.append(deviceId);
-            for (int i = 1; i < tensorParallelDegree; ++i) {
+            for (int i = 1; i < localDevicesPerWorker; ++i) {
                 sb.append(',').append(deviceId + i);
             }
         }
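Aside: the renamed parameter reflects that each worker now claims worldSize (TP × PP) consecutive devices rather than only the TP degree. A self-contained sketch of the same mapping (hypothetical class and values, not part of the commit), reproducing the CUDA_VISIBLE_DEVICES=0,2,3,7 example from the comment above:

    // VisibleDevicesDemo.java — hypothetical sketch of the getVisibleDevices() mapping.
    public final class VisibleDevicesDemo {

        // Same contract as Connection.getVisibleDevices(): take a contiguous slice of
        // localDevicesPerWorker devices starting at deviceId, honoring a pre-set mask.
        static String visibleDevices(String mask, int deviceId, int localDevicesPerWorker) {
            StringBuilder sb = new StringBuilder();
            if (mask != null) {
                String[] devices = mask.split(",");
                for (int i = 0; i < localDevicesPerWorker; ++i) {
                    if (i > 0) sb.append(',');
                    sb.append(devices[deviceId + i]);
                }
            } else {
                for (int i = 0; i < localDevicesPerWorker; ++i) {
                    if (i > 0) sb.append(',');
                    sb.append(deviceId + i);
                }
            }
            return sb.toString();
        }

        public static void main(String[] args) {
            // CUDA_VISIBLE_DEVICES=0,2,3,7 with 2 local devices per worker:
            System.out.println(visibleDevices("0,2,3,7", 0, 2)); // "0,2" (worker 0)
            System.out.println(visibleDevices("0,2,3,7", 2, 2)); // "3,7" (worker 1)
        }
    }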
7 changes: 5 additions & 2 deletions engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
@@ -339,7 +339,8 @@ public int getTensorParallelDegree() {
         if (tensorParallelDegree == 0) {
             String value = Utils.getenv("TENSOR_PARALLEL_DEGREE");
             if ("max".equals(value)) {
-                tensorParallelDegree = getDefaultTensorParallelDegree();
+                tensorParallelDegree =
+                        getDefaultTensorParallelDegree() / getPipelineParallelDegree();
             } else if (value != null) {
                 tensorParallelDegree = Integer.parseInt(value);
             }
@@ -350,7 +351,7 @@ public int getTensorParallelDegree() {
     static int getDefaultTensorParallelDegree() {
         int gpus = CudaUtils.getGpuCount();
         if (gpus > 0) {
-            return gpus;
+            return gpus * clusterSize;
         }
         return NeuronUtils.getNeuronCores();
     }
@@ -375,6 +376,8 @@ public int getPipelineParallelDegree() {
         if (value != null) {
             pipelineParallelDegree = Integer.parseInt(value);
         } else {
+            // TODO: Use clusterSize as default value of pipelineParallelDegree, but only when
+            // supported
             pipelineParallelDegree = 1;
         }
     }
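Note: getDefaultTensorParallelDegree() now returns the cluster-wide GPU count, and TENSOR_PARALLEL_DEGREE=max is divided by the pipeline-parallel degree, so TP × PP still covers all available GPUs. A worked example with assumed values (not from the commit):

    // TpMaxDemo.java — hypothetical numbers illustrating the "max" TP derivation.
    public final class TpMaxDemo {
        public static void main(String[] args) {
            int gpusPerNode = 8;  // CudaUtils.getGpuCount() on one node (assumed)
            int clusterSize = 2;  // two-node setup (assumed)
            int ppDegree = 2;     // PIPELINE_PARALLEL_DEGREE=2 (assumed)

            int defaultTp = gpusPerNode * clusterSize; // 16 GPUs across the cluster
            int tpDegree = defaultTp / ppDegree;       // 8
            int worldSize = tpDegree * ppDegree;       // 16 == total GPUs

            System.out.printf("TP=%d, PP=%d, world size=%d%n", tpDegree, ppDegree, worldSize);
        }
    }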
37 changes: 27 additions & 10 deletions engines/python/src/main/java/ai/djl/python/engine/PyModel.java
@@ -21,7 +21,6 @@
 import ai.djl.ndarray.types.DataType;
 import ai.djl.translate.Translator;
 import ai.djl.util.Utils;
-import ai.djl.util.cuda.CudaUtils;

 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -79,6 +78,9 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
         String entryPoint = null;
         String recommendedEntryPoint = null;
         if (options != null) {
+            // If tp_degree set to "max", we defer and set it at the end to ensure we take pp degree
+            // into account.
+            boolean setTensorParallelDegreeToMax = false;
             logger.debug("options in serving.properties for model: {}", modelName);
             for (Map.Entry<String, ?> entry : options.entrySet()) {
                 String key = entry.getKey();
@@ -125,7 +127,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
                         break;
                     case "tensor_parallel_degree":
                         if ("max".equals(value)) {
-                            pyEnv.setTensorParallelDegree(PyEnv.getDefaultTensorParallelDegree());
+                            setTensorParallelDegreeToMax = true;
                         } else {
                             pyEnv.setTensorParallelDegree(Integer.parseInt(value));
                         }
@@ -150,6 +152,12 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
                         break;
                 }
             }
+
+            if (setTensorParallelDegreeToMax) {
+                int tpDegree =
+                        PyEnv.getDefaultTensorParallelDegree() / pyEnv.getPipelineParallelDegree();
+                pyEnv.setTensorParallelDegree(tpDegree);
+            }
         }

         // MMS and TorchServe Bcc
@@ -206,15 +214,24 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
         }

         if (pyEnv.isMpiMode()) {
-            int partitions = pyEnv.getTensorParallelDegree();
+            int tpDegree = pyEnv.getTensorParallelDegree();
+            int ppDegree = pyEnv.getPipelineParallelDegree();
+            int partitions = tpDegree * ppDegree;
             if (partitions == 0) {
-                partitions = CudaUtils.getGpuCount();
-                pyEnv.setTensorParallelDegree(partitions);
-                setProperty("tensor_parallel_degree", String.valueOf(partitions));
+                partitions = PyEnv.getDefaultTensorParallelDegree();
+                tpDegree = partitions / ppDegree;
+                pyEnv.setTensorParallelDegree(tpDegree);
+                setProperty("tensor_parallel_degree", String.valueOf(tpDegree));
+                setProperty("pipeline_parallel_degree", String.valueOf(ppDegree));
                 logger.info(
-                        "No tensor parallel degree specified. Defaulting to all available GPUs.");
+                        "No tensor parallel degree specified. Defaulting to use all available"
+                                + " GPUs.");
             }
-            logger.info("Loading model in MPI mode with TP: {}.", partitions);
+            logger.info(
+                    "Loading model in MPI mode with world size {} (TP {}, PP {}).",
+                    partitions,
+                    tpDegree,
+                    ppDegree);

             int mpiWorkers = pyEnv.getMpiWorkers();
             if (mpiWorkers <= 0) {
@@ -306,7 +323,7 @@ private Path findModelFile(String prefix) {
         return modelFile;
     }

-    private void createAllPyProcesses(int mpiWorkers, int tp) {
+    private void createAllPyProcesses(int mpiWorkers, int worldSize) {
         long begin = System.currentTimeMillis();
         ExecutorService pool = null;
         List<Future<?>> futures = new ArrayList<>();
@@ -317,7 +334,7 @@ private void createAllPyProcesses(int mpiWorkers, int tp) {
         int deviceId = manager.getDevice().getDeviceId();
         for (int i = 0; i < mpiWorkers; ++i) {
             logger.debug("Pre-creating python worker: {} ", i);
-            PyProcess worker = new PyProcess(this, pyEnv, deviceId + i * tp);
+            PyProcess worker = new PyProcess(this, pyEnv, deviceId + i * worldSize);
             workerQueue.offer(worker);
             if (pool != null) {
                 logger.debug("Submitting to pool: {}", i);
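For illustration: createAllPyProcesses now strides device IDs by world size rather than TP degree, so MPI worker groups never overlap GPUs. A short sketch with assumed values (hypothetical class, not from the commit):

    // DeviceOffsetDemo.java — hypothetical sketch of the createAllPyProcesses() offsets.
    public final class DeviceOffsetDemo {
        public static void main(String[] args) {
            int mpiWorkers = 2;                  // assumed
            int tpDegree = 2, ppDegree = 2;      // assumed
            int worldSize = tpDegree * ppDegree; // 4 devices per worker group
            for (int i = 0; i < mpiWorkers; ++i) {
                // Matches: new PyProcess(this, pyEnv, deviceId + i * worldSize), deviceId=0
                System.out.printf("worker %d -> first device %d%n", i, i * worldSize);
            }
        }
    }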
33 changes: 15 additions & 18 deletions engines/python/src/main/java/ai/djl/python/engine/PyProcess.java
@@ -53,7 +53,7 @@ class PyProcess {
     private volatile boolean modelLoaded; // NOPMD
     private AtomicInteger restartCount;
     private CompletableFuture<Void> restartFuture;
-    private boolean trtLlmMode;
+    private boolean passiveWorkersMode;

     private static AtomicInteger counter = new AtomicInteger(0);
@@ -65,36 +65,33 @@ class PyProcess {
         if (pyEnv.isMpiMode()) {
             int tensorParallelDegree = pyEnv.getTensorParallelDegree();
             int pipelineParallelDegree = pyEnv.getPipelineParallelDegree();
+            int worldSize = tensorParallelDegree * pipelineParallelDegree;
             int clusterSize = PyEnv.getClusterSize();
-            connections = new ArrayList<>(tensorParallelDegree * pipelineParallelDegree);
+            connections = new ArrayList<>(worldSize);

             if (clusterSize > 1) {
                 hosts = getHosts(clusterSize);
-                for (int i = 0; i < tensorParallelDegree * pipelineParallelDegree; ++i) {
-                    connections.add(
-                            new Connection(
-                                    pyEnv,
-                                    port,
-                                    i,
-                                    hosts[
-                                            i
-                                                    / (tensorParallelDegree
-                                                            * pipelineParallelDegree
-                                                            / clusterSize)]));
+                for (int i = 0; i < worldSize; ++i) {
+                    int connectionsPerHost = worldSize / clusterSize;
+                    connections.add(new Connection(pyEnv, port, i, hosts[i / connectionsPerHost]));
                 }
             } else {
-                for (int i = 0; i < tensorParallelDegree * pipelineParallelDegree; ++i) {
+                for (int i = 0; i < worldSize; ++i) {
                     connections.add(new Connection(pyEnv, port, i, "127.0.0.1"));
                 }
             }
-            counter.set(port + tensorParallelDegree);
+            counter.set(port + worldSize);
         } else {
             connections = Collections.singletonList(new Connection(pyEnv, port, -1, "127.0.0.1"));
         }

         restartCount = new AtomicInteger(0);
-        // TODO: avoid using this hack when TRT-LLM improve its behavior
-        trtLlmMode = "trtllm".equals(model.getProperty("rolling_batch"));
+        // Note: Now, by default, we use passive worker behavior in MPI mode.
+        // We can get the old behavior by setting OPTION_USE_PASSIVE_WORKERS=false.
+        passiveWorkersMode =
+                "trtllm".equals(model.getProperty("rolling_batch"))
+                        || Boolean.parseBoolean(model.getProperty("use_passive_workers", "true"));
     }

     Output predict(Input inputs, int timeout, boolean initialLoad) {
@@ -107,7 +104,7 @@ Output predict(Input inputs, int timeout, boolean initialLoad) {
         }

         List<CompletableFuture<Output>> futures = new ArrayList<>(connections.size());
-        if (initialLoad || !trtLlmMode) {
+        if (initialLoad || !passiveWorkersMode) {
             for (Connection conn : connections) {
                 futures.add(conn.send(inputs));
             }
@@ -116,7 +113,7 @@ Output predict(Input inputs, int timeout, boolean initialLoad) {
         }

         Output output = null;
-        if (trtLlmMode) {
+        if (passiveWorkersMode) {
             output = futures.get(0).get(timeout, TimeUnit.SECONDS);
         } else {
             for (CompletableFuture<Output> future : futures) {
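Note: in passive-workers mode every rank still receives the initial load request, but subsequent inference requests go only to rank 0; the remaining ranks are presumed to be driven internally by the engine, as the TRT-LLM backend already required. A hypothetical sketch of that gating (standalone class, assumed names, not from the commit):

    // PassiveDispatchDemo.java — hypothetical sketch of the predict() gating above.
    import java.util.List;

    public final class PassiveDispatchDemo {
        static List<String> dispatch(List<String> ranks, boolean initialLoad, boolean passive) {
            // On initial load, every rank must see the request so all processes load the model;
            // afterwards, in passive-workers mode only rank 0 is driven directly.
            if (initialLoad || !passive) {
                return ranks;           // send to every connection
            }
            return ranks.subList(0, 1); // send to rank 0 only
        }

        public static void main(String[] args) {
            List<String> ranks = List.of("rank0", "rank1", "rank2", "rank3");
            System.out.println(dispatch(ranks, true, true));  // [rank0, rank1, rank2, rank3]
            System.out.println(dispatch(ranks, false, true)); // [rank0]
        }
    }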
@@ -38,8 +38,8 @@ static void configure(Properties lmiProperties, LmiUtils.HuggingFaceModelConfig
         setRollingBatch(lmiProperties, modelConfig, features);
         setMpiMode(lmiProperties);
         setHeuristicDefaults(lmiProperties, modelConfig);
-        setTensorParallelDegree(lmiProperties);
         setPipelineParallelDegree(lmiProperties);
+        setTensorParallelDegree(lmiProperties);
         setRollingBatchSize(lmiProperties);
         setIsPeftModel(lmiProperties, modelConfig);
     }
@@ -93,10 +93,12 @@ private static void setTensorParallelDegree(Properties lmiProperties) {
             return;
         }
         String tpDegree = Utils.getenv("TENSOR_PARALLEL_DEGREE", "max");
+        int ppDegree =
+                Integer.parseInt(lmiProperties.getProperty("option.pipeline_parallel_degree"));
         if ("max".equals(tpDegree)) {
             int numGpus = CudaUtils.getGpuCount();
             if (numGpus > 0) {
-                tpDegree = String.valueOf(numGpus);
+                tpDegree = String.valueOf(numGpus / ppDegree);
             } else if (NeuronUtils.hasNeuron()) {
                 int numAccelerators = NeuronUtils.getNeuronCores();
                 if (numAccelerators > 0) {