Commit

test changes
vwxyzjn committed Aug 18, 2024
1 parent 7341071 commit cdc74b7
Showing 3 changed files with 12 additions and 10 deletions.
vllm/distributed/parallel_state.py: 2 additions & 1 deletion
@@ -219,7 +219,7 @@ def prev_rank(self):
     def graph_capture(
             self, graph_capture_context: Optional[GraphCaptureContext] = None):
         if graph_capture_context is None:
-            stream = torch.cuda.Stream()
+            stream = torch.cuda.Stream(self.device)
             graph_capture_context = GraphCaptureContext(stream)
         else:
             stream = graph_capture_context.stream
@@ -905,6 +905,7 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
+    # world_size: int = 1
     backend = backend or torch.distributed.get_backend(
         get_world_group().device_group)

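Note on the graph_capture change above: torch.cuda.Stream accepts an optional device argument, so passing the group's device pins the capture stream to that GPU instead of whatever the process-wide current device happens to be. A minimal sketch, not vLLM code (the cuda:0 device below is an assumption standing in for self.device):

import torch

device = torch.device("cuda:0")       # assumed device; stands in for self.device
stream = torch.cuda.Stream(device)    # side stream bound to that GPU
with torch.cuda.stream(stream):
    x = torch.ones(4, device=device)  # work enqueued on the side stream
stream.synchronize()                  # wait for the enqueued work to finish
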
vllm/worker/model_runner.py: 7 additions & 6 deletions
@@ -888,7 +888,7 @@ def __init__(
 
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
-        with CudaMemoryProfiler() as m:
+        with CudaMemoryProfiler(self.device) as m:
             self.model = get_model(model_config=self.model_config,
                                    device_config=self.device_config,
                                    load_config=self.load_config,
@@ -1206,12 +1206,12 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
         max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
-        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
+        input_tokens = torch.zeros(max_batch_size, dtype=torch.long, device=self.device)
+        input_positions = torch.zeros(max_batch_size, dtype=torch.long, device=self.device)
+        slot_mapping = torch.empty(max_batch_size, dtype=torch.long, device=self.device)
         slot_mapping.fill_(_PAD_SLOT_ID)
-        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
-        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+        seq_lens = torch.ones(max_batch_size, dtype=torch.int32, device=self.device)
+        block_tables = torch.from_numpy(self.graph_block_tables).to(self.device)
         intermediate_inputs = None
         if not get_pp_group().is_first_rank:
             intermediate_inputs = self.model.make_empty_intermediate_tensors(
@@ -1669,6 +1669,7 @@ def capture(
         torch.cuda.synchronize()
 
         # Capture the graph.
+        # breakpoint()
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
             output_hidden_or_intermediate_states = self.model(
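The capture_model edits above swap tensor.cuda() for an explicit device= argument (and .to(self.device) for the NumPy-backed block table). The distinction matters on multi-GPU workers: .cuda() always targets the current CUDA device, while device=/.to() target the intended GPU regardless of whether torch.cuda.set_device was called. A small sketch under the assumption that a second GPU, cuda:1, exists:

import numpy as np
import torch

device = torch.device("cuda:1")  # assumed second GPU, standing in for self.device

a = torch.zeros(8, dtype=torch.long).cuda()            # lands on the current device (cuda:0 unless set_device changed it)
b = torch.zeros(8, dtype=torch.long, device=device)    # lands on cuda:1 explicitly
c = torch.from_numpy(np.zeros((8, 4), dtype=np.int32)).to(device)  # CPU tensor copied to cuda:1

print(a.device, b.device, c.device)  # e.g. cuda:0 cuda:1 cuda:1
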
vllm/worker/worker.py: 3 additions & 3 deletions
@@ -132,11 +132,11 @@ def init_device(self) -> None:
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
             self.device = torch.device(f"cuda:{self.local_rank}")
-            torch.cuda.set_device(self.device)
+            # torch.cuda.set_device(self.device)
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             torch.cuda.empty_cache()
-            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
+            self.init_gpu_memory = torch.cuda.mem_get_info(self.device)[0]
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -193,7 +193,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
         torch.cuda.synchronize()
-        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info(self.device)
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
         peak_memory = self.init_gpu_memory - free_gpu_memory
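On the worker.py side, torch.cuda.mem_get_info also takes an optional device argument and returns (free, total) bytes for that specific GPU, which keeps the peak-memory bookkeeping tied to the worker's device now that torch.cuda.set_device is commented out. A rough sketch of the same free-memory delta the profiler computes (the cuda:0 device and the dummy workload are assumptions, not vLLM code):

import torch

device = torch.device("cuda:0")                        # assumed; stands in for self.device
free_before, total = torch.cuda.mem_get_info(device)   # bytes free / total on that GPU

workload = torch.randn(1024, 1024, device=device)      # stand-in for the profiling forward pass
torch.cuda.synchronize(device)

free_after, _ = torch.cuda.mem_get_info(device)
peak_memory = free_before - free_after                 # approximation of memory claimed so far
print(f"claimed ~{peak_memory / 1e6:.1f} MB of {total / 1e9:.1f} GB")
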
