diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 9aea8a9c5..f5051eb86 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -109,7 +109,10 @@ jobs:
           cache-to: ${{ env.CACHE_TO }}
           push: ${{ github.event_name != 'pull_request' }}
           file: Dockerfile.ubi
-
+
+      - name: "List docker images"
+        run: docker images
+
       - name: "Cleanup old cache images"
        uses: actions/delete-package-versions@v5
        if: ${{ github.event_name == 'push' }}
@@ -118,9 +121,6 @@
          package-type: container
          delete-only-untagged-versions: true
 
-      - name: "List docker images"
-        run: docker images
-
      - name: "Check disk usage"
        shell: bash
        run: |
diff --git a/Dockerfile.ubi b/Dockerfile.ubi
index 42156c4b5..22d0b8052 100644
--- a/Dockerfile.ubi
+++ b/Dockerfile.ubi
@@ -123,6 +123,11 @@ RUN microdnf install -y \
 
 ENV LIBRARY_PATH="$CUDA_HOME/lib64/stubs"
 
+# Workaround for https://github.com/openai/triton/issues/2507 and
+# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
+# this won't be needed for future versions of this docker image
+# or future versions of triton.
+RUN ldconfig /usr/local/cuda-12.2/compat/
 
 ## Development #################################################################
 FROM cuda-devel AS dev
@@ -181,8 +186,8 @@ RUN microdnf install -y \
     && microdnf clean all
 
 ARG PYTHON_VERSION
-# 0.3.3 is built for CUDA 12.1 and PyTorch 2.1.2
-ARG VLLM_WHEEL_VERSION=0.3.3
+# 0.4.0.post1 is built for CUDA 12.1 and PyTorch 2.1.2
+ARG VLLM_WHEEL_VERSION=0.4.0.post1
 
 RUN curl -Lo vllm.whl https://github.com/vllm-project/vllm/releases/download/v${VLLM_WHEEL_VERSION}/vllm-${VLLM_WHEEL_VERSION}-cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}-manylinux1_x86_64.whl \
     && unzip vllm.whl \
@@ -263,7 +268,7 @@ RUN umask 002 \
 
 ## Release #####################################################################
 # Note from the non-UBI Dockerfile:
 # We used base cuda image because pytorch installs its own cuda libraries.
-# However cupy depends on cuda libraries so we had to switch to the runtime image
+# However pynccl depends on cuda libraries so we had to switch to the runtime image
 # In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
 FROM cuda-runtime AS vllm-openai
@@ -280,7 +285,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
         # additional dependencies for the TGIS gRPC server
         grpcio-tools==1.62.1 \
         # additional dependencies for openai api_server
-        accelerate==0.28.0
+        accelerate==0.28.0 \
+        # hf_transfer for faster HF hub downloads
+        hf_transfer==0.1.6
 
 # Install flash attention (from pre-built wheel)
 RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
@@ -296,7 +303,8 @@ RUN microdnf install -y gcc \
 ENV HF_HUB_OFFLINE=1 \
     PORT=8000 \
     GRPC_PORT=8033 \
-    HOME=/home/vllm
+    HOME=/home/vllm \
+    VLLM_USAGE_SOURCE=production-docker-image
 
 # setup non-root user for OpenShift
 RUN microdnf install -y shadow-utils \
diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py
index c2a263814..f1db62111 100644
--- a/tests/test_logits_processor.py
+++ b/tests/test_logits_processor.py
@@ -8,7 +8,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.tgis_utils.logits_processors import LengthPenaltyWarper
+from vllm.tgis_utils.logits_processors import ExpDecayLengthPenaltyWarper
 from vllm.worker.model_runner import ModelRunner
 
 
@@ -106,7 +106,7 @@ def test_exponential_decay_length_penalty(seed: int, device: str):
     logits_processor.scale = 1.0
 
     eos_token_id = 100
-    lenpen = LengthPenaltyWarper([2, 2.0], eos_token_id)
+    lenpen = ExpDecayLengthPenaltyWarper((2, 2.0), eos_token_id)
 
     seq_group_metadata_list = []
     prompt_lens = []
diff --git a/vllm/entrypoints/grpc/grpc_server.py b/vllm/entrypoints/grpc/grpc_server.py
index a7f6a5031..eb2150d03 100644
--- a/vllm/entrypoints/grpc/grpc_server.py
+++ b/vllm/entrypoints/grpc/grpc_server.py
@@ -35,7 +35,7 @@
 from vllm.logger import init_logger
 from vllm.sequence import Logprob
 from vllm.tgis_utils import logs
-from vllm.tgis_utils.logits_processors import (LengthPenaltyWarper,
+from vllm.tgis_utils.logits_processors import (ExpDecayLengthPenaltyWarper,
                                                TypicalLogitsWarperWrapper)
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 
@@ -319,6 +319,7 @@ async def _validate_and_convert_params(
         resp_options = params.response
         sampling = params.sampling
         stopping = params.stopping
+        decoding = params.decoding
         greedy = params.method == DecodingMethod.GREEDY
 
         max_new_tokens: Optional[int] = None
@@ -338,9 +339,6 @@ async def _validate_and_convert_params(
 
         logprobs = with_default(logprobs, None)
 
-        # GAPS:
-        # - exp_decay_length_penalty
-
         # NEW FUNCTION TO ADD (later)
         # - presence penalty, freq penalty
         # - min_p
@@ -359,14 +357,15 @@ async def _validate_and_convert_params(
         if not greedy and 0.0 < sampling.typical_p < 1.0:
             logits_processors.append(
                 TypicalLogitsWarperWrapper(mass=sampling.typical_p))
-        if params.decoding.length_penalty is not None:
-            length_penalty = (
-                params.decoding.length_penalty.start_index,
-                params.decoding.length_penalty.decay_factor,
+
+        if decoding.HasField("length_penalty"):
+            length_penalty_tuple = (
+                decoding.length_penalty.start_index,
+                decoding.length_penalty.decay_factor,
             )
             logits_processors.append(
-                LengthPenaltyWarper(length_penalty=length_penalty,
-                                    eos_token_id=self.tokenizer.eos_token_id))
+                ExpDecayLengthPenaltyWarper(length_penalty=length_penalty_tuple,
+                                            eos_token_id=self.tokenizer.eos_token_id))
 
         time_limit_millis = stopping.time_limit_millis
         deadline = time.time(
@@ -385,7 +384,7 @@ async def _validate_and_convert_params(
             top_p=with_default(sampling.top_p, 1.0),
             seed=sampling.seed if sampling.HasField("seed") else None,
             repetition_penalty=with_default(
-                params.decoding.repetition_penalty, 1.0),
+                decoding.repetition_penalty, 1.0),
             logits_processors=logits_processors,
             stop=with_default(stopping.stop_sequences, None),
             include_stop_str_in_output=stopping.include_stop_sequence
diff --git a/vllm/entrypoints/grpc/validation.py b/vllm/entrypoints/grpc/validation.py
index 308fef1d5..17920bdcd 100644
--- a/vllm/entrypoints/grpc/validation.py
+++ b/vllm/entrypoints/grpc/validation.py
@@ -71,13 +71,9 @@ def validate_params(params: Parameters, max_max_new_tokens: int):
     decoding = params.decoding
 
     # Decoding parameter checks
-    if decoding.HasField("length_penalty"):
-        args = [
-            decoding.length_penalty.start_index,
-            decoding.length_penalty.decay_factor
-        ]
-        if None in args or not (1.0 <= args[1] <= 10.0):
-            TGISValidationError.LengthPenalty.error()
+    if decoding.HasField("length_penalty") and not (
+            1.0 <= decoding.length_penalty.decay_factor <= 10.0):
+        TGISValidationError.LengthPenalty.error()
 
     if not (0 <= decoding.repetition_penalty <= 2):
         # (a value of 0 means no penalty / unset)
diff --git a/vllm/tgis_utils/logits_processors.py b/vllm/tgis_utils/logits_processors.py
index 5ea20c442..2a8437e10 100644
--- a/vllm/tgis_utils/logits_processors.py
+++ b/vllm/tgis_utils/logits_processors.py
@@ -10,25 +10,26 @@ def __init__(self, mass: float):
         self.warper = TypicalLogitsWarper(mass=mass)
 
     def __call__(self, token_ids: List[int],
-                 logits: torch.tensor) -> torch.tensor:
+                 logits: torch.Tensor) -> torch.Tensor:
         # transformers warpers assume tensors of shape (batch_size, vocab_size)
         # and the typical warper doesn't use input_ids
-        return self.warper(input_ids=None, scores=logits.reshape((1, -1)))
+        return self.warper(input_ids=None,
+                           scores=logits.reshape(1, -1)).flatten()
 
 
-class LengthPenaltyWarper:
+class ExpDecayLengthPenaltyWarper:
 
     def __init__(self, length_penalty: Tuple[int, float], eos_token_id: int):
-        self.length_penalty = length_penalty
+        self.start, self.penalty = length_penalty
         self.eos_token_id = eos_token_id
 
     def __call__(self, token_ids: List[int],
-                 logits: torch.tensor) -> torch.tensor:
-        tokens_past = len(token_ids) - self.length_penalty[0]
+                 logits: torch.Tensor) -> torch.Tensor:
+        tokens_past = len(token_ids) - self.start
         if tokens_past > 0:
             eos_logit = logits[self.eos_token_id]
             # To support negative logits we compute the penalty of the
             # absolute value and add to the original logit
             logits[self.eos_token_id] = eos_logit + torch.abs(eos_logit) * (
-                pow(self.length_penalty[1], tokens_past) - 1)
+                pow(self.penalty, tokens_past) - 1)
         return logits
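For reviewers, a minimal standalone sketch (not part of the patch) of how the renamed ExpDecayLengthPenaltyWarper behaves once generation passes start_index tokens; it only requires the patched vllm.tgis_utils module, and the token ids, logits values, and vocabulary size below are made up for illustration:

import torch

from vllm.tgis_utils.logits_processors import ExpDecayLengthPenaltyWarper

eos_token_id = 2
# (start_index, decay_factor): no penalty for the first 4 generated tokens,
# then the EOS logit is boosted by a factor that grows as 1.5 ** tokens_past.
warper = ExpDecayLengthPenaltyWarper((4, 1.5), eos_token_id)

logits = torch.full((32,), -2.0)   # toy vocabulary of 32 tokens
logits[eos_token_id] = -1.0

generated = [11, 7, 23, 5, 9, 14]  # 6 tokens generated so far
out = warper(generated, logits.clone())

# tokens_past = 6 - 4 = 2, so the EOS logit becomes
# -1.0 + abs(-1.0) * (1.5 ** 2 - 1) = 0.25
print(out[eos_token_id])           # tensor(0.2500)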