From 7d63ed77241eafcd1215a683b3e201a3dbdc75c5 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Fri, 28 Jun 2024 10:19:32 -0700 Subject: [PATCH] Support multi node for lmi-dist --- engines/python/setup/djl_python_engine.py | 3 +- .../java/ai/djl/python/engine/Connection.java | 87 ++++++++++++++++--- .../java/ai/djl/python/engine/PyEngine.java | 9 +- .../djl/python/engine/PyEngineProvider.java | 6 +- .../main/java/ai/djl/python/engine/PyEnv.java | 26 +++++- .../java/ai/djl/python/engine/PyModel.java | 6 +- .../java/ai/djl/python/engine/PyProcess.java | 58 +++++++++++-- 7 files changed, 169 insertions(+), 26 deletions(-) diff --git a/engines/python/setup/djl_python_engine.py b/engines/python/setup/djl_python_engine.py index 404070ac30..db8370236f 100644 --- a/engines/python/setup/djl_python_engine.py +++ b/engines/python/setup/djl_python_engine.py @@ -58,7 +58,7 @@ def __init__(self, args, service): self.clean_up() elif self.sock_type == "tcp": - self.sock_name = "127.0.0.1" + self.sock_name = socket.gethostname() if self.port is None: raise ValueError("Missing port argument.") else: @@ -97,6 +97,7 @@ def run_server(self): self.sock.bind(self.sock_name) else: self.sock.bind((self.sock_name, int(self.port))) + logging.info(f"Socket bind on address: {self.sock_name}:{self.port}") self.sock.listen(128) logging.info("Python engine started.") diff --git a/engines/python/src/main/java/ai/djl/python/engine/Connection.java b/engines/python/src/main/java/ai/djl/python/engine/Connection.java index e95879ae7b..4eee68f3bf 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/Connection.java +++ b/engines/python/src/main/java/ai/djl/python/engine/Connection.java @@ -50,6 +50,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.net.UnknownHostException; import java.nio.ByteBuffer; import java.nio.file.Files; import java.nio.file.Path; @@ -66,18 +67,21 @@ class Connection { private static final Logger logger = LoggerFactory.getLogger(Connection.class); private static final String MASTER_ADDR = "127.0.0.1"; + private int clusterSize; private int port; private SocketAddress socketAddress; private Channel channel; private RequestHandler requestHandler; - Connection(PyEnv pyEnv, int basePort, int rank) { + Connection(PyEnv pyEnv, int basePort, int rank, int clusterSize, String hostname) { requestHandler = new RequestHandler(); + this.clusterSize = clusterSize; port = 19000 + basePort; - socketAddress = getSocketAddress(pyEnv.isMpiMode(), rank); + socketAddress = getSocketAddress(pyEnv.isMpiMode(), rank, clusterSize, hostname); } - static Process startPython(PyEnv pyEnv, Model model, int workerId, int port) + static Process startPython( + PyEnv pyEnv, Model model, int workerId, int port, int clusterSize, String[] hosts) throws IOException { Path tmp = Paths.get(System.getProperty("java.io.tmpdir")); try (Stream stream = Files.list(tmp)) { @@ -100,7 +104,7 @@ static Process startPython(PyEnv pyEnv, Model model, int workerId, int port) }); } File modelPath = model.getModelPath().toFile(); - String[] args = getPythonStartCmd(pyEnv, model, workerId, port); + String[] args = getPythonStartCmd(pyEnv, model, workerId, port, clusterSize, hosts); String[] envp = pyEnv.getEnvironmentVars(model); logger.debug("cmd: {}", (Object) args); @@ -120,10 +124,59 @@ CompletableFuture send(Input input) throws InterruptedException { return f; } - static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int port) { + static String[] getPythonStartCmd( + PyEnv pyEnv, Model model, int workerId, int port, int clusterSize, String[] hosts) { Device device = model.getNDManager().getDevice(); int deviceId = device.getDeviceId(); int tensorParallelDegree = pyEnv.getTensorParallelDegree(); + // int pipelineParallelDegree = pyEnv.getPipelineParallelDegree(); + + if (clusterSize > 1) { + String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree); + logger.info("Set before mpirun CUDA_VISIBLE_DEVICES={}", cudaDevices); + String[] args = new String[36]; + args[0] = "mpirun"; + args[1] = "-np"; + // TODO: When we support multi nodes, change it to the product of tensor parallel value + // and + // pipeline parallel value. + args[2] = String.valueOf(clusterSize * tensorParallelDegree); + args[3] = "--host"; + args[4] = String.join(",", hosts); + args[5] = "--allow-run-as-root"; + args[6] = "--bind-to"; + args[7] = "none"; + args[8] = "--mca"; + args[9] = "btl_vader_single_copy_mechanism"; + args[10] = "none"; + args[11] = "--tag-output"; + args[12] = "-x"; + args[13] = "FI_PROVIDER=efa"; + args[14] = "-x"; + args[15] = "RDMAV_FORK_SAFE=1"; + args[16] = "-x"; + args[17] = "FI_EFA_USE_DEVICE_RDMA=1"; + args[18] = "-x"; + args[19] = "LD_LIBRARY_PATH"; + args[20] = "-x"; + args[21] = "PYTHONPATH"; + args[22] = "-x"; + args[23] = "CUDA_VISIBLE_DEVICES=" + cudaDevices; + args[24] = pyEnv.getPythonExecutable(); + args[25] = PyEnv.getEngineCacheDir() + "/djl_python_engine.py"; + args[26] = "--model-dir"; + args[27] = model.getModelPath().toAbsolutePath().toString(); + args[28] = "--entry-point"; + args[29] = pyEnv.getEntryPoint(); + args[30] = "--sock-type"; + args[31] = "tcp"; + args[32] = "--port"; + args[33] = String.valueOf(port); + args[34] = "--tensor-parallel-degree"; + args[35] = String.valueOf(tensorParallelDegree); + return args; + } + if (pyEnv.isMpiMode()) { String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree); logger.info("Set CUDA_VISIBLE_DEVICES={}", cudaDevices); @@ -242,13 +295,15 @@ private static String getNeuronThreads(int tensorParallelDegree) { return String.valueOf(1); } - void connect() throws InterruptedException { + void connect() throws InterruptedException, UnknownHostException { + logger.info("Connecting to socket: {}", socketAddress); EventLoopGroup group = PyEnv.getEventLoopGroup(); Bootstrap clientBootstrap = new Bootstrap(); + clientBootstrap .group(group) - .channel(getClientChannel()) + .channel(getClientChannel(this.clusterSize)) .remoteAddress(socketAddress) .handler( new ChannelInitializer<>() { @@ -289,7 +344,11 @@ private static String getSocketPath(int port) { return System.getProperty("java.io.tmpdir") + "/djl_sock." + port; } - private SocketAddress getSocketAddress(boolean mpiMode, int rank) { + private SocketAddress getSocketAddress( + boolean mpiMode, int rank, int clusterSize, String hostname) { + if (clusterSize > 1) { + return new InetSocketAddress(hostname, port + rank); + } if (mpiMode) { return new DomainSocketAddress(getSocketPath(port) + '.' + rank); } @@ -300,17 +359,23 @@ private SocketAddress getSocketAddress(boolean mpiMode, int rank) { return new InetSocketAddress("127.0.0.1", port); } - static EventLoopGroup newEventLoopGroup() { + static EventLoopGroup newEventLoopGroup(int clusterSize) { + if (clusterSize > 1) { + return new NioEventLoopGroup(new DaemonThreadFactory()); + } if (Epoll.isAvailable()) { return new EpollEventLoopGroup(new DaemonThreadFactory()); } else if (KQueue.isAvailable()) { return new KQueueEventLoopGroup(new DaemonThreadFactory()); } - return new NioEventLoopGroup(new DaemonThreadFactory()); } - private static Class getClientChannel() { + private static Class getClientChannel(int clusterSize) { + + if (clusterSize > 1) { + return NioSocketChannel.class; + } if (Epoll.isAvailable()) { return EpollDomainSocketChannel.class; } else if (KQueue.isAvailable()) { diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyEngine.java b/engines/python/src/main/java/ai/djl/python/engine/PyEngine.java index 65a437649e..25f29bda10 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyEngine.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyEngine.java @@ -26,12 +26,15 @@ public final class PyEngine extends Engine { private String engineName; private boolean mpiMode; + + private int clusterSize; private Engine alternativeEngine; private boolean initialized; - PyEngine(String engineName, boolean mpiMode) { + PyEngine(String engineName, boolean mpiMode, int clusterSize) { this.engineName = engineName; this.mpiMode = mpiMode; + this.clusterSize = clusterSize; } /** {@inheritDoc} */ @@ -98,4 +101,8 @@ public NDManager newBaseManager(Device device) { boolean isMpiMode() { return mpiMode; } + + int getClusterSize() { + return clusterSize; + } } diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyEngineProvider.java b/engines/python/src/main/java/ai/djl/python/engine/PyEngineProvider.java index 04d552f8ff..c06d5e40e7 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyEngineProvider.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyEngineProvider.java @@ -14,6 +14,7 @@ import ai.djl.engine.Engine; import ai.djl.engine.EngineProvider; +import ai.djl.util.Utils; /** {@code PyEngineProvider} is the Python implementation of {@link EngineProvider}. */ public class PyEngineProvider implements EngineProvider { @@ -43,8 +44,9 @@ public Engine getEngine() { synchronized (this) { if (!initialized) { initialized = true; - PyEnv.init(); - engine = new PyEngine(getEngineName(), mpiMode); + int clusterSize = Integer.parseInt(Utils.getenv("DJL_CLUSTER_SIZE", "-1")); + PyEnv.init(clusterSize); + engine = new PyEngine(getEngineName(), mpiMode, clusterSize); } } } diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java b/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java index 57bc1ab63f..8ec6a57b79 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java @@ -53,6 +53,7 @@ public class PyEnv { private int predictTimeout; private int modelLoadingTimeout; private int tensorParallelDegree; + private int pipelineParallelDegree; private Map envs; private Map initParameters; private boolean initialized; @@ -78,12 +79,12 @@ public PyEnv(boolean mpiMode) { initParameters = new ConcurrentHashMap<>(); } - static synchronized void init() { + static synchronized void init(int clusterSize) { if (eventLoopGroup != null) { return; } - eventLoopGroup = Connection.newEventLoopGroup(); + eventLoopGroup = Connection.newEventLoopGroup(clusterSize); Path tmp = null; try { @@ -320,6 +321,24 @@ public int getTensorParallelDegree() { return tensorParallelDegree; } + /** + * Returns the pipeline parallel degree. + * + * @return the pipeline parallel degree + */ + public int getPipelineParallelDegree() { + if (pipelineParallelDegree == 0) { + String value = Utils.getenv("PIPELINE_PARALLEL_DEGREE"); + if (value != null) { + pipelineParallelDegree = Integer.parseInt(value); + } else { + pipelineParallelDegree = 1; + } + } + + return pipelineParallelDegree; + } + static int getDefaultTensorParallelDegree() { int gpus = CudaUtils.getGpuCount(); if (gpus > 0) { @@ -347,6 +366,9 @@ int getMpiWorkers() { } gpuCount = visibleCount; } + + // return 1 + return gpuCount / getTensorParallelDegree(); } diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java index 3abb162aa8..fd5caf38e1 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java @@ -48,6 +48,7 @@ public class PyModel extends BaseModel { private static final Logger logger = LoggerFactory.getLogger(PyModel.class); + private int clusterSize; private PyEnv pyEnv; private boolean parallelLoading; private LinkedBlockingDeque workerQueue; @@ -63,6 +64,7 @@ public class PyModel extends BaseModel { this.manager = manager; this.manager.setName("pythonModel"); boolean mpiMode = ((PyEngine) manager.getEngine()).isMpiMode(); + clusterSize = ((PyEngine) manager.getEngine()).getClusterSize(); pyEnv = new PyEnv(mpiMode); dataType = DataType.FLOAT32; workerQueue = new LinkedBlockingDeque<>(); @@ -256,7 +258,7 @@ public Predictor newPredictor(Translator translator, Device d } return new PyPredictor<>(this, workerQueue.poll(), timeout, translator, device); } - PyProcess worker = new PyProcess(this, pyEnv, -1); + PyProcess worker = new PyProcess(this, pyEnv, -1, clusterSize); worker.startPythonProcess(); return new PyPredictor<>(this, worker, timeout, translator, device); } @@ -305,7 +307,7 @@ private void createAllPyProcesses(int mpiWorkers, int tp) { int deviceId = manager.getDevice().getDeviceId(); for (int i = 0; i < mpiWorkers; ++i) { logger.debug("Pre-creating python worker: {} ", i); - PyProcess worker = new PyProcess(this, pyEnv, deviceId + i * tp); + PyProcess worker = new PyProcess(this, pyEnv, deviceId + i * tp, clusterSize); workerQueue.offer(worker); if (pool != null) { logger.debug("Submitting to pool: {}", i); diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java index a31cbd4b0e..eb6a772596 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java @@ -17,6 +17,7 @@ import ai.djl.metric.Metric; import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -42,6 +43,9 @@ class PyProcess { private PyEnv pyEnv; private Model model; private int workerId; + + private int clusterSize; + private String[] hosts; private Process process; private String pid; private List connections; @@ -53,21 +57,43 @@ class PyProcess { private static AtomicInteger counter = new AtomicInteger(0); - PyProcess(Model model, PyEnv pyEnv, int workerId) { + PyProcess(Model model, PyEnv pyEnv, int workerId, int clusterSize) { this.model = model; this.pyEnv = pyEnv; this.workerId = workerId; + this.clusterSize = clusterSize; int port = counter.getAndIncrement(); - if (pyEnv.isMpiMode()) { + if (clusterSize > 0) { // Multi node + hosts = getHosts(); + int tensorParallelDegree = pyEnv.getTensorParallelDegree(); + connections = new ArrayList<>(hosts.length * tensorParallelDegree); + + for (int i = 0; i < clusterSize; ++i) { + for (int j = 0; j < tensorParallelDegree; ++j) { + connections.add( + new Connection( + pyEnv, + port, + i * tensorParallelDegree + j, + clusterSize, + hosts[i])); + } + } + counter.set(port + tensorParallelDegree); + } else if (pyEnv.isMpiMode()) { // Single node int tensorParallelDegree = pyEnv.getTensorParallelDegree(); connections = new ArrayList<>(tensorParallelDegree); - for (int i = 0; i < tensorParallelDegree; ++i) { - connections.add(new Connection(pyEnv, port, i)); + + for (int j = 0; j < tensorParallelDegree; ++j) { + connections.add( + new Connection(pyEnv, port, tensorParallelDegree + j, clusterSize, null)); } counter.set(port + tensorParallelDegree); } else { - connections = Collections.singletonList(new Connection(pyEnv, port, -1)); + connections = + Collections.singletonList(new Connection(pyEnv, port, -1, 1, "127.0.0.1")); } + restartCount = new AtomicInteger(0); // TODO: avoid using this hack when TRT-LLM improve its behavior trtLlmMode = "trtllm".equals(model.getProperty("rolling_batch")); @@ -132,12 +158,16 @@ Output predict(Input inputs, int timeout, boolean initialLoad) { } synchronized void startPythonProcess() { + // Do not start python process in worker nodes + if (Utils.getenv("WORKER_INDEX") != null) { + return; + } try { int id = restartCount.get(); int port = connections.get(0).getPort(); logger.info("Start process: {} - retry: {}", port, id); pyEnv.installDependency(model.getModelPath()); - process = Connection.startPython(pyEnv, model, workerId, port); + process = Connection.startPython(pyEnv, model, workerId, port, clusterSize, hosts); pid = process.toString().split(", ")[0].replace("Process[pid=", ""); String modelName = model.getName(); @@ -159,8 +189,8 @@ synchronized void startPythonProcess() { throw new IllegalThreadStateException( "Python stream closed unexpectedly, exit code: " + exitCode); } - for (Connection conn : connections) { + logger.warn("Trying to connect to processes"); conn.connect(); } @@ -187,6 +217,20 @@ synchronized void startPythonProcess() { } } + public static String[] getHosts() { + String leaderAddress = Utils.getenv("LWS_LEADER_ADDRESS"); + int clusterSize = Integer.parseInt(Utils.getenv("DJL_CLUSTER_SIZE")); + String lwsName = Utils.getenv("LWS_NAME"); + String namespace = Utils.getenv("NAMESPACE"); + String groupIndex = Utils.getenv("GROUP_INDEX"); + String[] hosts = new String[clusterSize]; + hosts[0] = leaderAddress; + for (int i = 1; i < clusterSize; i++) { + hosts[i] = String.format("%s-%s-%d.%s.%s", lwsName, groupIndex, i, lwsName, namespace); + } + return hosts; + } + synchronized void stopPythonProcess(boolean error) { restartCount.getAndIncrement(); logger.info("Stop process: {}:{}, failure={}", workerId, pid, error);