Enable mypy checking on V1 code #11105

Merged: 31 commits, Dec 14, 2024

Commits
536d641
v1: fix overloaded getitem/setitem in ConstantList
markmc Dec 11, 2024
d3c7c66
v1: fix already defined error in Request
markmc Dec 11, 2024
655707d
v1: add type hints to EngineCoreProc queues
markmc Dec 11, 2024
9885552
v1: ignore not defined errors with zmq.Socket/Context
markmc Dec 11, 2024
08fa203
v1: make Processor.process_inputs(arrival_time) optional
markmc Dec 11, 2024
0b64357
v1: add profile_async() to EngineCoreClient/AsyncMPClient
markmc Dec 11, 2024
37489ed
v1: add type hint to AsyncLLM.output_handler
markmc Dec 11, 2024
db1edfb
v1: fix type error with AsyncLLM.dead_error
markmc Dec 11, 2024
4325936
v1: make FreeKVCacheBlockQueue free_list_head/tail optional
markmc Dec 11, 2024
a2273c6
v1: fix reversed type hint in KVCacheManager
markmc Dec 11, 2024
e30d72a
v1: add new_blocks assertion in scheduler
markmc Dec 11, 2024
bf7bda4
v1: fix stop_token type hint in add_tokens()
markmc Dec 11, 2024
c3d8ba7
v1: copy flash atten forward pass assertion from v0
markmc Dec 11, 2024
9038742
v1: don't reuse req_data with different types
markmc Dec 11, 2024
854292b
v1: add req_id is not None assertion in GPUModelRunner
markmc Dec 11, 2024
632738b
v1: fix req_input_ids type
markmc Dec 11, 2024
df6c8af
v1: add some of req_id is not None assertions
markmc Dec 11, 2024
9dd22bc
v1: fix various mypy issues with MultiprocExecutor
markmc Dec 11, 2024
a7ea4a2
v1: tell mypy to ignore errors with msgspec.Struct kwargs
markmc Dec 11, 2024
c96d77f
[V1] Run mypy on
WoosukKwon Dec 6, 2024
b2d3729
v1: add type hints to LRUDictCache
markmc Dec 12, 2024
8e85f42
v1: remove return type from cache_hit_ratio()
markmc Dec 12, 2024
89a75fc
v1: mm_hashes type hints
markmc Dec 12, 2024
7839436
v1: add another MultiprocExecutor.workers assertion
markmc Dec 12, 2024
1c857e2
v1: make mm_hashes List[str] not List[Optional[str]]
markmc Dec 12, 2024
583385d
v1: fix block_hash mypy errors
markmc Dec 13, 2024
b8c03e9
v1: ignore duplicate keyword argument mypy errors
markmc Dec 13, 2024
4276cab
v1: Make MultiprocExecutor.workers non-optional
markmc Dec 13, 2024
fdf070d
v1: Make UniprocExecutor.worker non-optional
markmc Dec 14, 2024
e953f47
v1: fix MultiprocExecutor.collective_rpc() return type
markmc Dec 14, 2024
d140a50
v1: MPClient.proc-handle can be None
markmc Dec 14, 2024
1 change: 1 addition & 0 deletions tools/mypy.sh
@@ -29,3 +29,4 @@ run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/worker
run_mypy vllm/v1
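
With vllm/v1 added here, the tools/mypy.sh wrapper now type-checks the V1 tree alongside the existing packages. As a rough illustration of what that buys (a hypothetical snippet, not vLLM code), mypy will now reject mismatches like this anywhere under vllm/v1:

from typing import Dict

def lookup(table: Dict[str, int], key: str) -> int:
    # dict.get() returns Optional[int], so mypy reports:
    # error: Incompatible return value type (got "Optional[int]", expected "int")
    return table.get(key)
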
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/flash_attn.py
@@ -135,6 +135,8 @@ def forward(
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

assert output is not None, "Output tensor must be provided."

if attn_metadata is None:
# Profiling run.
return output
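
The new assertion does double duty: it guards against a missing output buffer at runtime, and it narrows the type for mypy, since `output` is presumably declared as an Optional parameter and is returned directly on the profiling path. A minimal sketch of the narrowing pattern (generic types, not the real FlashAttention signature):

from typing import List, Optional

def forward(query: List[float], output: Optional[List[float]] = None) -> List[float]:
    assert output is not None, "Output tensor must be provided."
    # After the assert, mypy treats `output` as List[float], so returning it type-checks.
    return output
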
10 changes: 5 additions & 5 deletions vllm/v1/core/kv_cache_manager.py
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Dict, List, Optional
from typing import Dict, Iterable, List, Optional

from vllm.logger import init_logger
from vllm.utils import cdiv
@@ -263,12 +263,13 @@ def free(self, request: Request) -> None:
"""
# Default to [] in case a request is freed (aborted) before alloc.
blocks = self.req_to_blocks.pop(request.request_id, [])
ordered_blocks: Iterable[KVCacheBlock] = blocks
if self.enable_caching:
# Free blocks in reverse order so that the tail blocks are
# freed first.
blocks = reversed(blocks)
ordered_blocks = reversed(blocks)

for block in blocks:
for block in ordered_blocks:
block.decr_ref()
if block.ref_cnt == 0:
self.free_block_queue.append(block)
@@ -396,8 +397,7 @@ def _cache_full_blocks(
f"{request.request_id}({request})")

# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
tuple(block_tokens))
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)

# Update and added the full block to the cache.
blk.block_hash = block_hash
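
The `ordered_blocks` variable exists because `reversed()` returns an iterator, not a list, so rebinding the List-typed `blocks` name to it fails mypy; a separately annotated `Iterable` name keeps both branches valid. A standalone sketch of the idea, using a simple element type:

from typing import Iterable, List

def free_blocks(blocks: List[str], reverse: bool) -> None:
    # reversed() yields an iterator, so give the loop variable source a wider
    # Iterable annotation instead of rebinding the List-typed `blocks` name.
    ordered_blocks: Iterable[str] = blocks
    if reverse:
        ordered_blocks = reversed(blocks)
    for block in ordered_blocks:
        print("freeing", block)
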
17 changes: 9 additions & 8 deletions vllm/v1/core/kv_cache_utils.py
@@ -1,4 +1,5 @@
"""KV-Cache Utilities."""
from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Tuple

@@ -13,7 +14,7 @@ class BlockHashType(NamedTuple):
collision happens when the hash value is the same.
"""
hash_value: int
token_ids: Tuple[int]
token_ids: Tuple[int, ...]


@dataclass
@@ -79,8 +80,8 @@ def __init__(self, blocks: List[KVCacheBlock]) -> None:
self.num_free_blocks = len(blocks)

# Initialize the doubly linked list of free blocks.
self.free_list_head = blocks[0]
self.free_list_tail = blocks[-1]
self.free_list_head: Optional[KVCacheBlock] = blocks[0]
self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
for i in range(self.num_free_blocks):
if i > 0:
blocks[i].prev_free_block = blocks[i - 1]
@@ -159,7 +160,7 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]:


def hash_block_tokens(parent_block_hash: Optional[int],
curr_block_token_ids: Tuple[int]) -> BlockHashType:
curr_block_token_ids: Sequence[int]) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
@@ -171,19 +172,19 @@ def hash_block_tokens(parent_block_hash: Optional[int],
Args:
parent_block_hash: The hash of the parent block. None
if this is the first block.
curr_block_token_ids: A tuple of token ids in the current
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.

Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
curr_block_token_ids)
tuple(curr_block_token_ids))


def hash_request_tokens(block_size: int,
token_ids: List[int]) -> List[BlockHashType]:
token_ids: Sequence[int]) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.

@@ -198,7 +199,7 @@ def hash_request_tokens(block_size: int,
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = tuple(token_ids[start:end])
block_token_ids = token_ids[start:end]
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
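
Two typing details drive these changes: `Tuple[int]` means a tuple of exactly one int, so the hash field becomes `Tuple[int, ...]` for arbitrary length, and the helpers now accept `Sequence[int]` so callers can pass a list or a list slice, with the conversion to a tuple done once inside `hash_block_tokens`. Roughly:

from typing import Sequence, Tuple

# Tuple[int] is a one-element tuple; Tuple[int, ...] allows any length.
single: Tuple[int] = (7,)
token_ids: Tuple[int, ...] = (1, 2, 3, 4)

def hash_tokens(curr_block_token_ids: Sequence[int]) -> int:
    # Accepts lists, tuples, or slices; convert only where hashing needs a tuple.
    return hash(tuple(curr_block_token_ids))

print(hash_tokens([1, 2, 3]), hash_tokens(token_ids))
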
1 change: 1 addition & 0 deletions vllm/v1/core/scheduler.py
@@ -152,6 +152,7 @@ def schedule(self) -> "SchedulerOutput":
break
if not can_schedule:
break
assert new_blocks is not None

# Schedule the request.
scheduled_running_reqs.append(request)
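
The assertion records an invariant mypy cannot see on its own: by this point allocation succeeded, but `new_blocks` still carries an Optional type from the allocating call, and the checker does not connect the `can_schedule` flag back to it. A hedged sketch of the shape of the problem (hypothetical allocator, not the real scheduler API):

from typing import List, Optional

def allocate_slots(n: int) -> Optional[List[int]]:
    # Hypothetical stand-in: None signals that allocation failed.
    return list(range(n)) if n > 0 else None

new_blocks = allocate_slots(4)
can_schedule = new_blocks is not None
if can_schedule:
    # mypy does not infer non-None from the boolean flag, so narrow explicitly.
    assert new_blocks is not None
    print("scheduled", len(new_blocks), "blocks")
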
23 changes: 14 additions & 9 deletions vllm/v1/engine/__init__.py
@@ -36,18 +36,19 @@ class EngineCoreRequest:
prompt: Optional[str]
prompt_token_ids: List[int]
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
mm_hashes: Optional[List[Optional[str]]]
mm_hashes: Optional[List[str]]
mm_placeholders: Optional[MultiModalPlaceholderDict]
sampling_params: SamplingParams
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]


class EngineCoreOutput(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
class EngineCoreOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]

request_id: str
new_token_ids: List[int]
@@ -56,10 +57,11 @@ class EngineCoreOutput(msgspec.Struct,
stop_reason: Union[int, str, None] = None


class EngineCoreOutputs(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
class EngineCoreOutputs(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]

#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout and using an int enum for finish/stop reason
@@ -81,3 +83,6 @@ class EngineCoreRequestType(enum.Enum):
ADD = b'\x00'
ABORT = b'\x01'
PROFILE = b'\x02'


EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
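
Two independent fixes in this file: the `# type: ignore[call-arg]` comments silence mypy errors on the `msgspec.Struct` class keywords (array_like, omit_defaults, gc), and the repeated `Union[EngineCoreRequest, EngineCoreProfile, List[str]]` is collected into the `EngineCoreRequestUnion` alias reused by the engine core and client modules below. The alias pattern in isolation, with placeholder request types:

from typing import List, Union

class AddRequest:
    pass

class ProfileRequest:
    pass

# One named alias instead of repeating the Union in every signature.
RequestUnion = Union[AddRequest, ProfileRequest, List[str]]

def handle(request: RequestUnion) -> None:
    if isinstance(request, AddRequest):
        print("add")
    elif isinstance(request, ProfileRequest):
        print("profile")
    else:
        print("abort", request)
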
11 changes: 6 additions & 5 deletions vllm/v1/engine/async_llm.py
@@ -81,7 +81,7 @@ def __init__(
asyncio_mode=True,
)

self.output_handler = None
self.output_handler: Optional[asyncio.Task] = None

def __del__(self):
self.shutdown()
@@ -126,7 +126,8 @@ def shutdown(self):
handler.cancel()

@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
executor_class: Type[Executor]
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "mp":
@@ -361,10 +362,10 @@ async def check_health(self) -> None:
logger.debug("Called check_health.")

async def start_profile(self) -> None:
await self.engine_core.profile(True)
await self.engine_core.profile_async(True)

async def stop_profile(self) -> None:
await self.engine_core.profile(False)
await self.engine_core.profile_async(False)

@property
def is_running(self) -> bool:
@@ -380,7 +381,7 @@ def errored(self) -> bool:

@property
def dead_error(self) -> BaseException:
return Exception
return Exception() # TODO: implement


# Retain V0 name for backwards compatibility.
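
Two annotation patterns appear in this file: `output_handler` is declared `Optional[asyncio.Task]` so that the `None` placeholder in `__init__` and the task assigned later both type-check, and `executor_class: Type[Executor]` is declared once before the branches so each branch assigns into a single declared type. A small sketch of both, with hypothetical class names:

import asyncio
from typing import Optional, Type

class Executor:
    pass

class MultiprocExecutor(Executor):
    pass

class UniprocExecutor(Executor):
    pass

class Engine:
    def __init__(self) -> None:
        # None until the background output-handling task is started.
        self.output_handler: Optional[asyncio.Task] = None

    @classmethod
    def _get_executor_cls(cls, backend: Optional[str]) -> Type[Executor]:
        executor_class: Type[Executor]  # declare once, assign per branch
        if backend == "mp":
            executor_class = MultiprocExecutor
        else:
            executor_class = UniprocExecutor
        return executor_class

print(Engine._get_executor_cls("mp"))
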
20 changes: 10 additions & 10 deletions vllm/v1/engine/core.py
@@ -5,7 +5,7 @@
import time
from dataclasses import dataclass
from multiprocessing.process import BaseProcess
from typing import List, Tuple, Type, Union
from typing import List, Tuple, Type

import zmq
import zmq.asyncio
@@ -20,7 +20,7 @@
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
@@ -97,8 +97,10 @@ def add_request(self, request: EngineCoreRequest):
# Note that the cache here is mirrored with the client side of the
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
request.mm_inputs = self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes)
assert request.mm_inputs is not None
request.mm_inputs, request.mm_hashes = (
self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes))

req = Request.from_engine_core_request(request)

@@ -128,7 +130,7 @@ def step(self) -> List[EngineCoreOutput]:
def shutdown(self):
self.model_executor.shutdown()

def profile(self, is_start=True):
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)


@@ -161,8 +163,8 @@ def __init__(
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = queue.Queue()
self.output_queue = queue.Queue()
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
daemon=True).start()
@@ -318,9 +320,7 @@ def _log_stats(self):

self._last_logging_time = now

def _handle_client_request(
self, request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""

if isinstance(request, EngineCoreRequest):
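
`queue.Queue` is generic to mypy, so annotating the element types lets the checker verify what the socket threads put on the input queue and what the busy loop produces on the output queue. The bare pattern (string annotations keep it valid on older Python versions):

import queue
from typing import List

input_queue: "queue.Queue[str]" = queue.Queue()
output_queue: "queue.Queue[List[int]]" = queue.Queue()

input_queue.put("request-1")
output_queue.put([1, 2, 3])
# output_queue.put("oops")  # mypy: expected "List[int]", got "str"
print(input_queue.get(), output_queue.get())
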
43 changes: 24 additions & 19 deletions vllm/v1/engine/core_client.py
@@ -1,6 +1,6 @@
import atexit
import os
from typing import List, Union
from typing import List, Optional

import msgspec
import zmq
@@ -10,8 +10,9 @@
from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine.core import (EngineCore, EngineCoreProc,
EngineCoreProcHandle)
from vllm.v1.serial_utils import PickleEncoder

logger = init_logger(__name__)
@@ -59,7 +60,7 @@ def get_output(self) -> List[EngineCoreOutput]:
def add_request(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

async def profile(self, is_start=True) -> None:
def profile(self, is_start: bool = True) -> None:
raise NotImplementedError

def abort_requests(self, request_ids: List[str]) -> None:
@@ -71,6 +72,9 @@ async def get_output_async(self) -> List[EngineCoreOutput]:
async def add_request_async(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

async def profile_async(self, is_start: bool = True) -> None:
raise NotImplementedError

async def abort_requests_async(self, request_ids: List[str]) -> None:
raise NotImplementedError

@@ -105,7 +109,7 @@ def shutdown(self):
def __del__(self):
self.shutdown()

def profile(self, is_start=True) -> None:
def profile(self, is_start: bool = True) -> None:
self.engine_core.profile(is_start)


@@ -133,7 +137,7 @@ def __init__(
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)

# ZMQ setup.
self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context())
if asyncio_mode:
self.ctx = zmq.asyncio.Context()
else:
self.ctx = zmq.Context() # type: ignore[attr-defined]

# Path for IPC.
ready_path = get_open_zmq_ipc_path()
@@ -149,11 +156,13 @@ def __init__(
self.input_socket.bind(input_path)

# Start EngineCore in background process.
self.proc_handle: Optional[EngineCoreProcHandle]
self.proc_handle = EngineCoreProc.make_engine_core_process(
*args,
input_path=input_path,
output_path=output_path,
ready_path=ready_path,
input_path=
input_path, # type: ignore[misc] # MyPy incorrectly flags duplicate keywords
output_path=output_path, # type: ignore[misc]
ready_path=ready_path, # type: ignore[misc]
**kwargs,
)
atexit.register(self.shutdown)
@@ -204,10 +213,8 @@ def get_output(self) -> List[EngineCoreOutput]:
engine_core_outputs = self.decoder.decode(frame.buffer).outputs
return engine_core_outputs

def _send_input(
self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:
def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:

# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
@@ -219,7 +226,7 @@ def add_request(self, request: EngineCoreRequest) -> None:
def abort_requests(self, request_ids: List[str]) -> None:
self._send_input(EngineCoreRequestType.ABORT, request_ids)

def profile(self, is_start=True) -> None:
def profile(self, is_start: bool = True) -> None:
self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))

@@ -237,10 +244,8 @@ async def get_output_async(self) -> List[EngineCoreOutput]:

return engine_core_outputs

async def _send_input(
self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:
async def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:

msg = (request_type.value, self.encoder.encode(request))
await self.input_socket.send_multipart(msg, copy=False)
@@ -252,6 +257,6 @@ async def abort_requests_async(self, request_ids: List[str]) -> None:
if len(request_ids) > 0:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)

async def profile(self, is_start=True) -> None:
async def profile_async(self, is_start: bool = True) -> None:
await self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))
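
The base class previously declared `async def profile()`, so the synchronous InprocClient override and the async MP client could not both satisfy it. The fix makes `profile()` synchronous and adds a separate `profile_async()` stub, which AsyncLLM now calls. A stripped-down sketch of the split, with placeholder bodies:

class EngineCoreClient:
    def profile(self, is_start: bool = True) -> None:
        raise NotImplementedError

    async def profile_async(self, is_start: bool = True) -> None:
        raise NotImplementedError

class InprocClient(EngineCoreClient):
    def profile(self, is_start: bool = True) -> None:
        print("profile:", is_start)

class AsyncMPClient(EngineCoreClient):
    async def profile_async(self, is_start: bool = True) -> None:
        print("profile (async):", is_start)
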