Code cleanup part 2 (#59)
This PR cleans and simplifies the code.

### Changes:

- simplified warmup by factoring the duplicated lines into a single function call (see the sketch below)
- moved `mask` and `position_ids` from `SENDNNCasualLM` to `SENDNNModelRunner`
- fixed an error in pyproject.toml (malformed `platform_machine` marker on the torch requirement)
- already merged PR #52 and main into this branch for an easier merge.

The code has been tested in client/server mode with `llama 194m` and `granite 3b` on `AIU` and `CPU`.
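
A minimal sketch of the warmup simplification described above; the actual warmup code is not visible in the diff hunks on this page, so the helper name and warmup shapes below are hypothetical:

```python
# Hypothetical sketch of the warmup refactor: instead of repeating the same
# dummy prefill/decode pass once per warmup shape, the duplicated lines are
# factored into a single helper that is called in a loop.
from typing import List, Tuple


def _warmup_one_shape(prompt_len: int, num_decode_tokens: int,
                      batch_size: int) -> None:
    """Run one dummy prefill + decode pass for a single warmup shape."""
    # The real runner would build padded dummy inputs and call the model here;
    # this sketch only reports the shape to stay self-contained.
    print(f"warmup: prompt_len={prompt_len}, "
          f"decode_tokens={num_decode_tokens}, batch_size={batch_size}")


def warmup(shapes: List[Tuple[int, int, int]]) -> None:
    """Warm up every (prompt_len, num_decode_tokens, batch_size) combination."""
    for prompt_len, num_decode_tokens, batch_size in shapes:
        _warmup_one_shape(prompt_len, num_decode_tokens, batch_size)


if __name__ == "__main__":
    warmup([(64, 20, 1), (128, 20, 1)])
```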
yannicks1 authored and GitHub Enterprise committed Nov 1, 2024
1 parent 2a0fb3f commit 9728380
Showing 4 changed files with 110 additions and 140 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ requires = [
"packaging",
"setuptools>=61",
"setuptools-scm>=8.0",
"torch == 2.4.0", platform_machine!='s390x'",
"torch == 2.4.0",
"wheel",
"jinja2",
]
36 changes: 1 addition & 35 deletions vllm/model_executor/model_loader/sendnn.py
@@ -65,47 +65,13 @@ def __init__(
logits_as_input=True)
self.sampler = Sampler()
self.past_key_value_states = None
# key: request_id, value: position_ids of sequence
self.position_ids = dict()
# key: request_id, value: attention mask of sequence
self.mask = dict()
self.dtype = torch.float16 if envs.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn_decoder' else torch.float32
# number of added padding sequences to fill batch to warmed up batch size
self.num_padded_sequences = 0

self.dtype = torch.float16 if envs.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn_decoder' else torch.float32

# Lazy initialized
self.model: nn.Module

def update_mask(self, request_id) -> None:
"""Updating/extending the attention masks of a sequence in a SequenceGroup. Will be called in decoding phase"""

assert self.mask[request_id] is not None
masks = self.mask[request_id]

# expand batch dimension (batch size 1) during inference to use the same function for inference and warmup
is_decoding = False
if len(masks.shape) == 2:
masks = masks.unsqueeze(0)
is_decoding = True

masks_new = []
for mask in masks:
# get the last row of the 3d mask
mask_new = mask[-1:, :]

# extend the mask one slot
mask_new = torch.cat((mask_new, torch.zeros(1, 1, dtype=mask_new.dtype, device=mask_new.device),),dim=1,)
masks_new.append(mask_new)

masks_new_stacked = torch.stack(masks_new, dim=0)

# collapse batch dimension again for decoding phase (scheduler handles batch dimensions there)
if is_decoding:
masks_new_stacked = masks_new_stacked.squeeze(0)

self.mask[request_id] = masks_new_stacked


def forward(
self,
77 changes: 55 additions & 22 deletions vllm/worker/sendnn_model_runner.py
@@ -46,7 +46,11 @@ def __init__(
self._num_decode_tokens = [20]
self._batch_sizes = [1]
self._padded_batch_size = self._batch_sizes[0] # will be set accordingly in prefill phase

# key: request_id, value: position_ids of sequence
self._position_ids = dict()
# key: request_id, value: attention mask of sequence
self._mask = dict()

# Lazy initialization.
self.model: nn.Module # initialize after load_model.

@@ -132,18 +136,18 @@ def _prepare_prompt(
prompt_token_padded = prompt_token_padded_tensor.tolist()[0]

# set padded position ids for request_id
self.model.position_ids[request_id] = padding_kwargs['position_ids'][0].tolist() # there is only one dummy batch dimension
self._position_ids[request_id] = padding_kwargs['position_ids'][0].tolist() # there is only one dummy batch dimension
# set padding attention mask for request_id
self.model.mask[request_id] = padding_kwargs['mask'][0] # there is only one dummy batch dimension
self._mask[request_id] = padding_kwargs['mask'][0] # there is only one dummy batch dimension

input_tokens.append(prompt_token_padded)

seq_len = len(prompt_token_padded)
seq_lens.append(seq_len)

input_positions.append(self.model.position_ids[request_id])
input_positions.append(self._position_ids[request_id])

input_masks.append(self.model.mask[request_id])
input_masks.append(self._mask[request_id])

assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
@@ -185,20 +189,20 @@ def _prepare_prompt(
input_tokens_pad = input_tokens_pad_tensor.tolist()[0]

# set padded position ids for request_id ='padding_request_id'
self.model.position_ids['padding_request_id'] = padding_kwargs_pad['position_ids'][0].tolist() # there is only one dummy batch dimension
self._position_ids['padding_request_id'] = padding_kwargs_pad['position_ids'][0].tolist() # there is only one dummy batch dimension

# set padding attention mask for request_id = 'padding_request_id'
self.model.mask['padding_request_id'] = padding_kwargs_pad['mask'][0] # there is only one dummy batch dimension
self._mask['padding_request_id'] = padding_kwargs_pad['mask'][0] # there is only one dummy batch dimension

# append needed batch dimensions
for i in range(num_batch_pads):
# token ids
input_tokens.append(input_tokens_pad)
seq_lens.append(max_seq_len)
# position ids
input_positions.append(self.model.position_ids['padding_request_id'])
input_positions.append(self._position_ids['padding_request_id'])
# masks
input_masks.append(self.model.mask['padding_request_id'])
input_masks.append(self._mask['padding_request_id'])
# block ids: no usage on AIU yet
input_block_ids.append(0)
# increase padded batches counter
@@ -246,12 +250,12 @@ def _prepare_decode(

seq_len = seq_data.get_len()

position_id = self.model.position_ids[request_id][-1] + 1
self.model.position_ids[request_id] = self.model.position_ids[request_id] + [position_id] # append new position to sequence
position_id = self._position_ids[request_id][-1] + 1
self._position_ids[request_id] = self._position_ids[request_id] + [position_id] # append new position to sequence
input_positions.append([position_id])

self.model.update_mask(request_id)
input_masks.append(self.model.mask[request_id])
self._update_mask(request_id)
input_masks.append(self._mask[request_id])

context_lens.append(seq_len)

@@ -264,8 +268,8 @@ def _prepare_decode(
# TODO ysc: add condition when reaching eos token.
if seq_data.get_output_len() == seq_group_metadata.sampling_params.max_tokens - 1:
# delete attention mask and position ids for corresponding request_id
del self.model.mask[request_id]
del self.model.position_ids[request_id]
del self._mask[request_id]
del self._position_ids[request_id]

actual_batch_size = len(seq_group_metadata_list)
# getting batch size we padded to in prefill stage
@@ -278,10 +282,10 @@

# token_ids and position_ids
token_id_pad = [0]
position_id_pad = [self.model.position_ids['padding_request_id'][-1] + 1]
position_id_pad = [self._position_ids['padding_request_id'][-1] + 1]
# update position ids and mask
self.model.position_ids['padding_request_id'] = self.model.position_ids['padding_request_id'] + position_id_pad
self.model.update_mask('padding_request_id')
self._position_ids['padding_request_id'] = self._position_ids['padding_request_id'] + position_id_pad
self._update_mask('padding_request_id')

# append needed batch dimensions
for i in range(num_batch_pads):
@@ -290,17 +294,17 @@
# position ids
input_positions.append(position_id_pad)
# masks
input_masks.append(self.model.mask['padding_request_id'])
input_masks.append(self._mask['padding_request_id'])
# why is this here, it has no effect?
context_lens.append(0) # padding sequence has context length 0
# block ids: no usage on AIU yet
input_block_ids.append(0)

# delete attention masks and position ids of batch padding in last decoding step to free memory
if len(self.model.mask) == 1 and len(self.model.position_ids) == 1:
if len(self._mask) == 1 and len(self._position_ids) == 1:
# if batch padding was applied and there is only one remaining entry -> end of decoding -> delete padding entry
del self.model.mask['padding_request_id']
del self.model.position_ids['padding_request_id']
del self._mask['padding_request_id']
del self._position_ids['padding_request_id']

input_tokens = make_tensor_with_pad(input_tokens,
pad=0,
@@ -323,6 +327,35 @@ def _prepare_decode(

return input_tokens, input_positions, input_masks, input_block_ids

def _update_mask(self, request_id) -> None:
"""Updating/extending the attention masks of a sequence in a SequenceGroup. Will be called in decoding phase"""

assert self._mask[request_id] is not None
masks = self._mask[request_id]

# expand batch dimension (batch size 1) during inference to use the same function for inference and warmup
is_decoding = False
if len(masks.shape) == 2:
masks = masks.unsqueeze(0)
is_decoding = True

masks_new = []
for mask in masks:
# get the last row of the 3d mask
mask_new = mask[-1:, :]

# extend the mask one slot
mask_new = torch.cat((mask_new, torch.zeros(1, 1, dtype=mask_new.dtype, device=mask_new.device),),dim=1,)
masks_new.append(mask_new)

masks_new_stacked = torch.stack(masks_new, dim=0)

# collapse batch dimension again for decoding phase (scheduler handles batch dimensions there)
if is_decoding:
masks_new_stacked = masks_new_stacked.squeeze(0)

self._mask[request_id] = masks_new_stacked

def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
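For context, the `_update_mask` helper moved into the model runner (shown in the diff above) keeps only the last row of a sequence's attention mask and appends one open slot per generated token. A minimal standalone sketch of that update, assuming a simple additive causal mask (0 = attend, -inf = masked), looks like this:

```python
import torch

# Additive causal mask for a 4-token prompt, as an assumed starting point.
seq_len = 4
mask = torch.zeros(seq_len, seq_len)
mask[torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()] = float("-inf")

for step in range(3):
    # Keep only the last row of the current mask ...
    last_row = mask[-1:, :]                              # shape (1, current_len)
    # ... and extend it by one open (zero) slot for the newly generated token.
    new_slot = torch.zeros(1, 1, dtype=last_row.dtype, device=last_row.device)
    mask = torch.cat((last_row, new_slot), dim=1)        # shape (1, current_len + 1)
    print(f"decode step {step}: mask shape {tuple(mask.shape)}")
```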
