[docker] separate vllm and lmi-dist modes into separate virtual environments
siddvenk authored Jan 31, 2025
1 parent 923a931 commit dc08153
Showing 8 changed files with 92 additions and 57 deletions.
@@ -15,10 +15,10 @@
 from djl_python.chat_completions.vllm_chat_properties import ChatProperties
 from djl_python.properties_manager.properties import Properties
 from djl_python.rolling_batch.rolling_batch_vllm_utils import maybe_serialize_tool_calls
-from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
-                                         apply_hf_chat_template,
+from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
                                          apply_mistral_chat_template,
-                                         parse_chat_messages)
+                                         parse_chat_messages,
+                                         resolve_chat_template_content_format)
 
 
 def is_chat_completions_request(inputs: Dict) -> bool:
@@ -70,9 +70,16 @@ def parse_chat_completions_request_vllm(
     tool_dicts = None if chat_params.tools is None else [
         tool.model_dump() for tool in chat_params.tools
     ]
+    # TODO - figure out what we need to pass for given format
+    content_format = resolve_chat_template_content_format(
+        chat_template=None,
+        given_format="auto",
+        tokenizer=tokenizer,
+    )
 
     conversation, mm_data = parse_chat_messages(
-        chat_params.messages, rolling_batch.get_model_config(), tokenizer)
+        chat_params.messages, rolling_batch.get_model_config(), tokenizer,
+        content_format)
 
     prompt_data: Union[str, List[int]]
     if is_mistral_tokenizer:
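The net effect of this change, as a minimal standalone sketch; it assumes the vLLM 0.7.0 chat_utils API imported above, and messages, model_config, and tokenizer are placeholders supplied by the caller:

# Sketch of the new parsing flow, mirroring the diff above; assumes
# vllm==0.7.0 (see requirements-vllm.txt below).
from vllm.entrypoints.chat_utils import (parse_chat_messages,
                                         resolve_chat_template_content_format)

def build_conversation(messages, model_config, tokenizer):
    # Resolve whether message "content" is treated as a plain string or an
    # OpenAI-style list of parts; "auto" lets vLLM inspect the tokenizer's
    # chat template to decide.
    content_format = resolve_chat_template_content_format(
        chat_template=None,
        given_format="auto",
        tokenizer=tokenizer,
    )
    # In vLLM 0.7.0 the resolved format is a required input here.
    conversation, mm_data = parse_chat_messages(messages, model_config,
                                                tokenizer, content_format)
    return conversation, mm_data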
@@ -132,12 +132,8 @@ def get_model_config(self):
         return self.engine.preprocessor.model_config if not self.is_t5_model else None
 
     def use_vllm_chat_completions(self):
-        return True
-
-    def get_huggingface_model_config(self):
-        # TODO: this is a hack right now to get the model config from the engine. We should expose this as
-        # an interface method and retrieve it from there after v12
-        return self.engine.preprocessor.model_config.hf_config if not self.is_t5_model else None
+        # vllm chat parsing requires 0.7.0 currently, lmi-dist is on 0.6.3.post1
+        return False
 
     def get_huggingface_model_config(self):
         # TODO: this is a hack right now to get the model config from the engine. We should expose this as
32 changes: 32 additions & 0 deletions serving/docker/lmi-container-requirements-common.txt
@@ -0,0 +1,32 @@
peft==0.13.2
protobuf==3.20.3
transformers==4.45.2
hf-transfer
zstandard
datasets==3.0.1
mpi4py
sentencepiece
tiktoken
blobfile
einops
accelerate==1.0.1
bitsandbytes==0.44.1
auto-gptq==0.7.1
pandas
pyarrow
jinja2
retrying
opencv-contrib-python-headless
safetensors
scipy
onnx
sentence_transformers
onnxruntime-gpu==1.20.0
autoawq==0.2.5
tokenizers==0.20.3
pydantic==2.9.2
optimum==1.23.2
torch==2.5.1
torchvision==0.20.1
# sequence scheduler wheel for hf accelerate rolling batch
https://publish.djl.ai/seq_scheduler/seq_scheduler-0.1.0-py3-none-any.whl
25 changes: 10 additions & 15 deletions serving/docker/lmi.Dockerfile
@@ -73,9 +73,6 @@ COPY config.properties /opt/djl/conf/config.properties
 COPY partition /opt/djl/partition
 COPY scripts/telemetry.sh /opt/djl/bin
 
-COPY distribution[s]/ ./
-RUN mv *.deb djl-serving_all.deb || true
-
 RUN apt-get update && apt-get install -yq libaio-dev libopenmpi-dev g++ unzip cuda-compat-12-4 \
     && scripts/install_openssh.sh \
     && scripts/install_python.sh ${python_version} \
@@ -84,24 +81,22 @@ RUN apt-get update && apt-get install -yq libaio-dev libopenmpi-dev g++ unzip cu
     && apt-get clean -y \
     && rm -rf /var/lib/apt/lists/*
 
-COPY requirements-lmi.txt ./requirements.txt
-RUN pip3 install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124 && pip3 cache purge
-RUN pip3 install -r requirements.txt \
-    && pip3 install ${djl_converter_wheel} --no-deps \
-    && git clone https://github.com/neuralmagic/AutoFP8.git \
-    && cd AutoFP8 \
-    && git reset --hard 4b2092c \
-    && pip3 install . \
-    && cd .. \
-    && rm -rf AutoFP8 \
-    && pip3 cache purge
-
 RUN scripts/patch_oss_dlc.sh python \
     && scripts/security_patch.sh lmi \
     && useradd -m -d /home/djl djl \
     && chown -R djl:djl /opt/djl \
     && apt-get clean -y && rm -rf /var/lib/apt/lists/*
 
+COPY lmi-container-requirements-common.txt ./requirements-common.txt
+COPY requirements-lmi.txt ./requirements-lmi.txt
+COPY requirements-vllm.txt ./requirements-vllm.txt
+RUN pip3 install -r requirements-common.txt \
+    && scripts/create_virtual_env.sh /opt/djl/vllm_venv requirements-vllm.txt \
+    && scripts/create_virtual_env.sh /opt/djl/lmi_dist_venv requirements-lmi.txt
+
+COPY distribution[s]/ ./
+RUN mv *.deb djl-serving_all.deb || true
+
 RUN scripts/install_djl_serving.sh $djl_version $djl_serving_version ${djl_torch_version} \
     && djl-serving -i ai.djl.onnxruntime:onnxruntime-engine:$djl_version \
     && djl-serving -i com.microsoft.onnxruntime:onnxruntime_gpu:$djl_onnx_version
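The layering this sets up: shared dependencies are installed once into the system interpreter, and each engine venv (created with --system-site-packages by the script below) adds only its own pins on top. A quick probe of the result, as a sketch assuming the container layout above; per the version notes in this commit it should report 0.7.0 for the vllm venv and 0.6.3.post1 for lmi-dist:

# probe_venvs.py - run inside the built container (illustrative sketch;
# interpreter paths come from the Dockerfile above).
import subprocess

for interpreter in ("/opt/djl/vllm_venv/bin/python",
                    "/opt/djl/lmi_dist_venv/bin/python"):
    result = subprocess.run(
        [interpreter, "-c", "import vllm; print(vllm.__version__)"],
        capture_output=True, text=True, check=False,
    )
    # Each venv sees the common system-site packages but its own vllm pin.
    print(interpreter, "->", result.stdout.strip() or result.stderr.strip())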
33 changes: 1 addition & 32 deletions serving/docker/requirements-lmi.txt
@@ -1,35 +1,4 @@
-peft==0.13.2
-protobuf==3.20.3
-transformers==4.45.2
-hf-transfer
-zstandard
-datasets==3.0.1
-mpi4py
-sentencepiece
-tiktoken
-blobfile
-einops
-accelerate==1.0.1
-bitsandbytes==0.44.1
-auto-gptq==0.7.1
-pandas
-pyarrow
-jinja2
-retrying
-opencv-contrib-python-headless
-safetensors
-scipy
-onnx
-sentence_transformers
-onnxruntime-gpu==1.20.0
-autoawq==0.2.5
-tokenizers==0.20.3
-pydantic==2.9.2
-optimum==1.23.2
-torch==2.5.1
-torchvision==0.20.1
-# sequence scheduler wheel for hf accelerate rolling batch
-https://publish.djl.ai/seq_scheduler/seq_scheduler-0.1.0-py3-none-any.whl
+-r requirements-common.txt
 # flash infer kernels for vllm/lmi-dist
 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu124torch2.4-cp311-cp311-linux_x86_64.whl
 # vllm wheel built with pt2.5.1
2 changes: 2 additions & 0 deletions serving/docker/requirements-vllm.txt
@@ -0,0 +1,2 @@
-r requirements-common.txt
vllm==0.7.0
18 changes: 18 additions & 0 deletions serving/docker/scripts/create_virtual_env.sh
@@ -0,0 +1,18 @@
#!/usr/bin/env bash
# used in the dockerfiles to create virtualenvs per engine
# currently only intended for use in lmi.Dockerfile, need to refactor this to work for trtllm/neuron if needed
venv_directory=$1
requirements_file=$2

# This was copied over from the previous pip install defined in the lmi.Dockerfile, so it's specific to that Dockerfile
python -m venv --system-site-packages $venv_directory
venv_pip="${venv_directory}/bin/pip"
$venv_pip install -r $requirements_file
$venv_pip install https://publish.djl.ai/djl_converter/djl_converter-0.31.0-py3-none-any.whl --no-deps
git clone https://github.com/neuralmagic/AutoFP8.git
cd AutoFP8
git reset --hard 4b2092c
$venv_pip install .
cd ..
rm -rf AutoFP8
$venv_pip cache purge
16 changes: 16 additions & 0 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
@@ -43,6 +43,7 @@ static void configure(Properties lmiProperties, LmiUtils.HuggingFaceModelConfig
         setRollingBatchSize(lmiProperties);
         setIsPeftModel(lmiProperties, modelConfig);
         setPropertiesForLora(lmiProperties);
+        setPythonExecutable(lmiProperties);
     }
 
     private static void setRollingBatch(
@@ -206,4 +207,19 @@ private static boolean isTextGenerationModel(LmiUtils.HuggingFaceModelConfig mod
                 OPTIMIZED_TASK_ARCHITECTURES);
         return false;
     }
+
+    private static void setPythonExecutable(Properties lmiProperties) {
+        if (lmiProperties.containsKey("option.pythonExecutable")) {
+            return;
+        }
+        String rollingBatch = lmiProperties.getProperty("option.rolling_batch");
+        if ("vllm".equals(rollingBatch)) {
+            lmiProperties.setProperty("option.pythonExecutable", "/opt/djl/vllm_venv/bin/python");
+            return;
+        }
+        if ("lmi-dist".equals(rollingBatch)) {
+            lmiProperties.setProperty(
+                    "option.pythonExecutable", "/opt/djl/lmi_dist_venv/bin/python");
+        }
+    }
 }
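For illustration, the selection rule restated as a small Python sketch (property names and interpreter paths are taken from the diff above; the dictionary and function names are invented for this example):

# The venv-selection rule from setPythonExecutable, transcribed to Python.
VENV_INTERPRETERS = {
    "vllm": "/opt/djl/vllm_venv/bin/python",
    "lmi-dist": "/opt/djl/lmi_dist_venv/bin/python",
}

def resolve_python_executable(properties):
    # An explicitly configured option.pythonExecutable always wins.
    if "option.pythonExecutable" not in properties:
        interpreter = VENV_INTERPRETERS.get(properties.get("option.rolling_batch"))
        if interpreter is not None:
            properties["option.pythonExecutable"] = interpreter
    return properties

# Example: resolve_python_executable({"option.rolling_batch": "vllm"})
# adds "option.pythonExecutable": "/opt/djl/vllm_venv/bin/python"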
