Commit
Support multi node for lmi-dist
xyang16 committed Jun 28, 2024
1 parent 7e18d6d commit 6db60b5
Showing 7 changed files with 170 additions and 26 deletions.
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python_engine.py
@@ -58,7 +58,7 @@ def __init__(self, args, service):

self.clean_up()
elif self.sock_type == "tcp":
self.sock_name = "127.0.0.1"
self.sock_name = socket.gethostname()
if self.port is None:
raise ValueError("Missing port argument.")
else:
@@ -97,6 +97,7 @@ def run_server(self):
self.sock.bind(self.sock_name)
else:
self.sock.bind((self.sock_name, int(self.port)))
logging.info(f"Socket bind on address: {self.sock_name}:{self.port}")

self.sock.listen(128)
logging.info("Python engine started.")
88 changes: 77 additions & 11 deletions engines/python/src/main/java/ai/djl/python/engine/Connection.java
@@ -50,6 +50,7 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
@@ -66,18 +67,21 @@ class Connection {
private static final Logger logger = LoggerFactory.getLogger(Connection.class);
private static final String MASTER_ADDR = "127.0.0.1";

private int clusterSize;
private int port;
private SocketAddress socketAddress;
private Channel channel;
private RequestHandler requestHandler;

Connection(PyEnv pyEnv, int basePort, int rank) {
Connection(PyEnv pyEnv, int basePort, int rank, int clusterSize, String hostname) {
requestHandler = new RequestHandler();
this.clusterSize = clusterSize;
port = 19000 + basePort;
socketAddress = getSocketAddress(pyEnv.isMpiMode(), rank);
socketAddress = getSocketAddress(pyEnv.isMpiMode(), rank, clusterSize, hostname);
}

static Process startPython(PyEnv pyEnv, Model model, int workerId, int port)
static Process startPython(
PyEnv pyEnv, Model model, int workerId, int port, int clusterSize, String[] hosts)
throws IOException {
Path tmp = Paths.get(System.getProperty("java.io.tmpdir"));
try (Stream<Path> stream = Files.list(tmp)) {
@@ -100,7 +104,7 @@ static Process startPython(PyEnv pyEnv, Model model, int workerId, int port)
});
}
File modelPath = model.getModelPath().toFile();
String[] args = getPythonStartCmd(pyEnv, model, workerId, port);
String[] args = getPythonStartCmd(pyEnv, model, workerId, port, clusterSize, hosts);
String[] envp = pyEnv.getEnvironmentVars(model);
logger.debug("cmd: {}", (Object) args);

@@ -120,10 +124,60 @@ CompletableFuture<Output> send(Input input) throws InterruptedException {
return f;
}

static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int port) {
static String[] getPythonStartCmd(
PyEnv pyEnv, Model model, int workerId, int port, int clusterSize, String[] hosts) {
Device device = model.getNDManager().getDevice();
int deviceId = device.getDeviceId();
int tensorParallelDegree = pyEnv.getTensorParallelDegree();
// int pipelineParallelDegree = pyEnv.getPipelineParallelDegree();
logger.info("Printing mpi boolean: {}", pyEnv.isMpiMode());

if (clusterSize > 1) {
String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree);
logger.info("Set before mpirun CUDA_VISIBLE_DEVICES={}", cudaDevices);
String[] args = new String[36];
args[0] = "mpirun";
args[1] = "-np";
// TODO: change this to the product of the tensor parallel and pipeline
// parallel degrees once pipeline parallelism is supported.
args[2] = String.valueOf(clusterSize * tensorParallelDegree);
args[3] = "--host";
args[4] = String.join(",", hosts);
args[5] = "--allow-run-as-root";
args[6] = "--bind-to";
args[7] = "none";
args[8] = "--mca";
args[9] = "btl_vader_single_copy_mechanism";
args[10] = "none";
args[11] = "--tag-output";
args[12] = "-x";
args[13] = "FI_PROVIDER=efa";
args[14] = "-x";
args[15] = "RDMAV_FORK_SAFE=1";
args[16] = "-x";
args[17] = "FI_EFA_USE_DEVICE_RDMA=1";
args[18] = "-x";
args[19] = "LD_LIBRARY_PATH";
args[20] = "-x";
args[21] = "PYTHONPATH";
args[22] = "-x";
args[23] = "CUDA_VISIBLE_DEVICES=" + cudaDevices;
args[24] = pyEnv.getPythonExecutable();
args[25] = PyEnv.getEngineCacheDir() + "/djl_python_engine.py";
args[26] = "--model-dir";
args[27] = model.getModelPath().toAbsolutePath().toString();
args[28] = "--entry-point";
args[29] = pyEnv.getEntryPoint();
args[30] = "--sock-type";
args[31] = "tcp";
args[32] = "--port";
args[33] = String.valueOf(port);
args[34] = "--tensor-parallel-degree";
args[35] = String.valueOf(tensorParallelDegree);
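// Illustration (hypothetical values: 2 hosts, tensor parallel degree 4, so -np 8):
// mpirun -np 8 --host node-1,node-2 --allow-run-as-root --bind-to none
//     --mca btl_vader_single_copy_mechanism none --tag-output
//     -x FI_PROVIDER=efa -x RDMAV_FORK_SAFE=1 -x FI_EFA_USE_DEVICE_RDMA=1
//     -x LD_LIBRARY_PATH -x PYTHONPATH -x CUDA_VISIBLE_DEVICES=0,1,2,3
//     python .../djl_python_engine.py --model-dir <model> --entry-point <entry>
//     --sock-type tcp --port <port> --tensor-parallel-degree 4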
return args;
}

if (pyEnv.isMpiMode()) {
String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree);
logger.info("Set CUDA_VISIBLE_DEVICES={}", cudaDevices);
@@ -242,13 +296,15 @@ private static String getNeuronThreads(int tensorParallelDegree) {
return String.valueOf(1);
}

void connect() throws InterruptedException {
void connect() throws InterruptedException, UnknownHostException {
logger.info("Connecting to socket: {}", socketAddress);
EventLoopGroup group = PyEnv.getEventLoopGroup();

Bootstrap clientBootstrap = new Bootstrap();

clientBootstrap
.group(group)
.channel(getClientChannel())
.channel(getClientChannel(this.clusterSize))
.remoteAddress(socketAddress)
.handler(
new ChannelInitializer<>() {
@@ -289,7 +345,11 @@ private static String getSocketPath(int port) {
return System.getProperty("java.io.tmpdir") + "/djl_sock." + port;
}

private SocketAddress getSocketAddress(boolean mpiMode, int rank) {
private SocketAddress getSocketAddress(
boolean mpiMode, int rank, int clusterSize, String hostname) {
if (clusterSize > 1) {
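// Multi-node: each rank listens on its own TCP port (base port + rank) on the
// worker's hostname, rather than on a local Unix domain socket.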
return new InetSocketAddress(hostname, port + rank);
}
if (mpiMode) {
return new DomainSocketAddress(getSocketPath(port) + '.' + rank);
}
@@ -300,17 +360,23 @@ private SocketAddress getSocketAddress(boolean mpiMode, int rank) {
return new InetSocketAddress("127.0.0.1", port);
}

static EventLoopGroup newEventLoopGroup() {
static EventLoopGroup newEventLoopGroup(int clusterSize) {
if (clusterSize > 1) {
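// Unix domain sockets (the Epoll/KQueue transports below) are host-local, so
// cross-node traffic has to go over TCP via the NIO transport.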
return new NioEventLoopGroup(new DaemonThreadFactory());
}
if (Epoll.isAvailable()) {
return new EpollEventLoopGroup(new DaemonThreadFactory());
} else if (KQueue.isAvailable()) {
return new KQueueEventLoopGroup(new DaemonThreadFactory());
}

return new NioEventLoopGroup(new DaemonThreadFactory());
}

private static Class<? extends Channel> getClientChannel() {
private static Class<? extends Channel> getClientChannel(int clusterSize) {
if (clusterSize > 1) {
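// Same reasoning as newEventLoopGroup: cross-node connections are plain TCP,
// hence the NIO socket channel.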
return NioSocketChannel.class;
}
if (Epoll.isAvailable()) {
return EpollDomainSocketChannel.class;
} else if (KQueue.isAvailable()) {
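Taken together, Connection now derives every rank's endpoint from the worker hostname plus a per-rank port offset. A small standalone sketch of that mapping (illustrative names and values, not the engine's API):

def rank_endpoints(hostname, base_port, ranks):
    # Mirrors getSocketAddress for clusterSize > 1: the base port is
    # 19000 + basePort, and each rank listens on base + rank.
    port = 19000 + base_port
    return [(hostname, port + rank) for rank in range(ranks)]

print(rank_endpoints("worker-1", 0, 4))
# [('worker-1', 19000), ('worker-1', 19001), ('worker-1', 19002), ('worker-1', 19003)]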
engines/python/src/main/java/ai/djl/python/engine/PyEngine.java
@@ -26,12 +26,15 @@ public final class PyEngine extends Engine {

private String engineName;
private boolean mpiMode;
private int clusterSize;
private Engine alternativeEngine;
private boolean initialized;

PyEngine(String engineName, boolean mpiMode) {
PyEngine(String engineName, boolean mpiMode, int clusterSize) {
this.engineName = engineName;
this.mpiMode = mpiMode;
this.clusterSize = clusterSize;
}

/** {@inheritDoc} */
@@ -98,4 +101,8 @@ public NDManager newBaseManager(Device device) {
boolean isMpiMode() {
return mpiMode;
}

int getClusterSize() {
return clusterSize;
}
}
engines/python/src/main/java/ai/djl/python/engine/PyEngineProvider.java
@@ -14,6 +14,7 @@

import ai.djl.engine.Engine;
import ai.djl.engine.EngineProvider;
import ai.djl.util.Utils;

/** {@code PyEngineProvider} is the Python implementation of {@link EngineProvider}. */
public class PyEngineProvider implements EngineProvider {
@@ -43,8 +44,9 @@ public Engine getEngine() {
synchronized (this) {
if (!initialized) {
initialized = true;
PyEnv.init();
engine = new PyEngine(getEngineName(), mpiMode);
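// Multi-node support is opt-in through the DJL_CLUSTER_SIZE environment
// variable; any value <= 1 keeps the single-node code paths.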
int clusterSize = Integer.parseInt(Utils.getenv("DJL_CLUSTER_SIZE", "-1"));
PyEnv.init(clusterSize);
engine = new PyEngine(getEngineName(), mpiMode, clusterSize);
}
}
}
26 changes: 24 additions & 2 deletions engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
@@ -53,6 +53,7 @@ public class PyEnv {
private int predictTimeout;
private int modelLoadingTimeout;
private int tensorParallelDegree;
private int pipelineParallelDegree;
private Map<String, String> envs;
private Map<String, String> initParameters;
private boolean initialized;
@@ -78,12 +79,12 @@ public PyEnv(boolean mpiMode) {
initParameters = new ConcurrentHashMap<>();
}

static synchronized void init() {
static synchronized void init(int clusterSize) {
if (eventLoopGroup != null) {
return;
}

eventLoopGroup = Connection.newEventLoopGroup();
eventLoopGroup = Connection.newEventLoopGroup(clusterSize);

Path tmp = null;
try {
@@ -320,6 +321,24 @@ public int getTensorParallelDegree() {
return tensorParallelDegree;
}

/**
* Returns the pipeline parallel degree.
*
* @return the pipeline parallel degree
*/
public int getPipelineParallelDegree() {
if (pipelineParallelDegree == 0) {
String value = Utils.getenv("PIPELINE_PARALLEL_DEGREE");
if (value != null) {
pipelineParallelDegree = Integer.parseInt(value);
} else {
pipelineParallelDegree = 1;
}
}

return pipelineParallelDegree;
}

static int getDefaultTensorParallelDegree() {
int gpus = CudaUtils.getGpuCount();
if (gpus > 0) {
@@ -347,6 +366,9 @@ int getMpiWorkers() {
}
gpuCount = visibleCount;
}


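// e.g. 8 visible GPUs with a tensor parallel degree of 4 -> 2 MPI workers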
return gpuCount / getTensorParallelDegree();
}

engines/python/src/main/java/ai/djl/python/engine/PyModel.java
@@ -48,6 +48,7 @@ public class PyModel extends BaseModel {

private static final Logger logger = LoggerFactory.getLogger(PyModel.class);

private int clusterSize;
private PyEnv pyEnv;
private boolean parallelLoading;
private LinkedBlockingDeque<PyProcess> workerQueue;
@@ -63,6 +64,7 @@ public class PyModel extends BaseModel {
this.manager = manager;
this.manager.setName("pythonModel");
boolean mpiMode = ((PyEngine) manager.getEngine()).isMpiMode();
clusterSize = ((PyEngine) manager.getEngine()).getClusterSize();
pyEnv = new PyEnv(mpiMode);
dataType = DataType.FLOAT32;
workerQueue = new LinkedBlockingDeque<>();
@@ -256,7 +258,7 @@ public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator, Device d
}
return new PyPredictor<>(this, workerQueue.poll(), timeout, translator, device);
}
PyProcess worker = new PyProcess(this, pyEnv, -1);
PyProcess worker = new PyProcess(this, pyEnv, -1, clusterSize);
worker.startPythonProcess();
return new PyPredictor<>(this, worker, timeout, translator, device);
}
@@ -305,7 +307,7 @@ private void createAllPyProcesses(int mpiWorkers, int tp) {
int deviceId = manager.getDevice().getDeviceId();
for (int i = 0; i < mpiWorkers; ++i) {
logger.debug("Pre-creating python worker: {} ", i);
PyProcess worker = new PyProcess(this, pyEnv, deviceId + i * tp);
PyProcess worker = new PyProcess(this, pyEnv, deviceId + i * tp, clusterSize);
workerQueue.offer(worker);
if (pool != null) {
logger.debug("Submitting to pool: {}", i);
