Skip to content

Commit

Permalink
[serving][post 7/24] Fixes tensor_parallel_degree detection on CPU (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Jul 25, 2024
1 parent a9c32a1 commit b752c31
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ static String[] getPythonStartCmd(
String recommendedEntryPoint = pyEnv.getRecommendedEntryPoint();

if (PyEnv.isMultiNode()) {

int worldSize = tensorParallelDegree * pipelineParallelDegree;

if (tensorParallelDegree * pipelineParallelDegree % clusterSize != 0) {
throw new IllegalArgumentException(
"Error: Cannot use cluster size: "
Expand Down Expand Up @@ -211,9 +209,7 @@ static String[] getPythonStartCmd(
args[46] = "--recommended-entry-point";
args[47] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
return args;
}

if (pyEnv.isMpiMode()) {
} else if (pyEnv.isMpiMode()) {
String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree);
logger.info("Set CUDA_VISIBLE_DEVICES={}", cudaDevices);
String[] args = new String[42];
Expand Down
11 changes: 9 additions & 2 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,16 @@ private static void setTensorParallelDegree(Properties lmiProperties) {
}
String tpDegree = Utils.getenv("TENSOR_PARALLEL_DEGREE", "max");
if ("max".equals(tpDegree)) {
tpDegree = String.valueOf(CudaUtils.getGpuCount());
int numGpus = CudaUtils.getGpuCount();
if (numGpus > 0) {
tpDegree = String.valueOf(numGpus);
} else {
tpDegree = null;
}
}
if (tpDegree != null) {
lmiProperties.setProperty("option.tensor_parallel_degree", tpDegree);
}
lmiProperties.setProperty("option.tensor_parallel_degree", tpDegree);
}

private static void setPipelineParallelDegree(Properties lmiProperties) {
Expand Down
2 changes: 1 addition & 1 deletion wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ static void configureLmiModel(ModelInfo<?, ?> modelInfo) throws ModelException {

LmiConfigRecommender.configure(modelInfo, prop, modelConfig);
logger.info(
"Detected mpi_mode: {}, rolling_batch: {}, tensor_parallel_degree {}, for"
"Detected mpi_mode: {}, rolling_batch: {}, tensor_parallel_degree: {}, for"
+ " modelType: {}",
prop.getProperty("option.mpi_mode"),
prop.getProperty("option.rolling_batch"),
Expand Down
4 changes: 4 additions & 0 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import ai.djl.serving.wlm.util.EventManager;
import ai.djl.serving.wlm.util.WlmConfigManager;
import ai.djl.serving.wlm.util.WlmOutOfMemoryException;
import ai.djl.translate.NoopServingTranslatorFactory;
import ai.djl.translate.TranslateException;
import ai.djl.util.NeuronUtils;
import ai.djl.util.Utils;
Expand Down Expand Up @@ -262,6 +263,9 @@ public void load(Device device) throws ModelException, IOException {
// override model_id
builder.optOption("model_id", downloadDir.toAbsolutePath().toString());
}
if (translator == null && translatorFactory == null && "Python".equals(engineName)) {
builder.optTranslatorFactory(new NoopServingTranslatorFactory());
}
ZooModel<I, O> m = builder.build().loadModel();
m.setProperty("metric_dimension", id);

Expand Down

0 comments on commit b752c31

Please sign in to comment.