WIP Add test that exports to IREE a small sharded Llama model
sogartar committed Sep 26, 2024
1 parent 2818131 commit 5fc031d
Showing 3 changed files with 183 additions and 93 deletions.
8 changes: 5 additions & 3 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -203,9 +203,11 @@ def compute_batch_mask(
Tensor of [bs, sl, 1, d] that will be later passed to apply_batch_mask.
"""
self.trace_tensor("rope.start_positions", start_positions)
positions_seq = torch.arange(0, batch_seq_len, device=self.device).unsqueeze(
0
) + start_positions.unsqueeze(1)
positions_seq = ops.elementwise(
torch.add,
torch.arange(0, batch_seq_len, device=self.device).unsqueeze(0),
start_positions.unsqueeze(1),
)
# Broadcast lookup to [b, ...].
self.trace_tensor("rope.positions_seq", positions_seq)
freqs_cis = self.rotary_embed_table[positions_seq]
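The rotary-embedding change above replaces the plain `+` with `ops.elementwise(torch.add, ...)`. A plain Python `+` resolves through `torch.Tensor.__add__`, which knows nothing about sharktank's sharded tensor types; routing the addition through `ops.elementwise` lets it dispatch on the operand types instead (see the new `(Tensor, ReplicatedTensor)` override in `sharded_impls.py` below). The following is a minimal, self-contained toy sketch of that kind of type-keyed dispatch; the names (`override`, `FakeReplicatedTensor`) are hypothetical illustrations and do not mirror sharktank's real machinery.

# Toy illustration of type-dispatched elementwise ops; names here are
# hypothetical and do not reflect sharktank's actual implementation.
import torch

_ELEMENTWISE_OVERRIDES = {}

def override(*types):
    """Register a handler for the given operand types."""
    def decorator(fn):
        _ELEMENTWISE_OVERRIDES[types] = fn
        return fn
    return decorator

def elementwise(operator, x, y):
    """Dispatch on (type(x), type(y)); fall back to the plain operator."""
    handler = _ELEMENTWISE_OVERRIDES.get((type(x), type(y)))
    if handler is not None:
        return handler(operator, x, y)
    return operator(x, y)

class FakeReplicatedTensor:
    """Stand-in for a tensor replicated across devices (one copy per shard)."""
    def __init__(self, shards):
        self.shards = list(shards)

@override(torch.Tensor, FakeReplicatedTensor)
def add_unsharded_lhs_replicated_rhs(operator, x, y):
    # Apply the op against each replica; the result stays replicated.
    return FakeReplicatedTensor(operator(x, s) for s in y.shards)

positions = torch.arange(4).unsqueeze(0)                # plain torch.Tensor, shape [1, 4]
starts = FakeReplicatedTensor([torch.tensor([[3], [7]])] * 2)
summed = elementwise(torch.add, positions, starts)      # routes through the override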
9 changes: 8 additions & 1 deletion sharktank/sharktank/ops/sharded_impls.py
@@ -10,7 +10,6 @@
import itertools
from numbers import Number
import math
import numpy as np

from ..types import (
AnyTensor,
@@ -365,6 +364,14 @@ def elementwise_binary_replicated_lhs_unsharded_rhs(
return elementwise(operator, x, y_replicated, *args, **kwargs)


@elementwise.override(Tensor, ReplicatedTensor)
def elementwise_binary_replicated_lhs_unsharded_rhs(
operator, x: Tensor, y: ReplicatedTensor, *args, **kwargs
):
x_replicated = reshard_like(x, like=y)
return elementwise(operator, x_replicated, y, *args, **kwargs)


# Embedding Lookup
@embedding_lookup.override(ReplicatedTensor, ReplicatedTensor)
def embedding_lookup_default(
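The new `(Tensor, ReplicatedTensor)` override mirrors the existing registration for the reversed operand order: the unsharded operand is resharded to match the replicated one, and the elementwise op is then re-dispatched. Below is a hedged usage sketch; constructing a replicated tensor via `ops.replicate(..., count=...)` is an assumption about the API and may need adjusting.

# Hedged sketch: assumes ops.replicate(tensor, count=N) yields a ReplicatedTensor;
# verify against the actual sharktank API before relying on it.
import torch
import sharktank.ops as ops

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # plain, unsharded tensor
y = ops.replicate(torch.ones(2, 3), count=2)            # replicated across 2 shards

# Before this commit only the (ReplicatedTensor, Tensor) ordering was overridden;
# the new registration makes the mirrored ordering dispatch through the same path.
z = ops.elementwise(torch.add, x, y)                     # x is replicated like y, then added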
259 changes: 170 additions & 89 deletions sharktank/tests/models/llama/sharded_llama_test.py
@@ -5,121 +5,118 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import unittest
from typing import Any, Dict, List, Tuple
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
import sharktank.ops as ops
from sharktank.types import Dataset
from sharktank.models.llama.testing import make_random_llama_theta
from sharktank.models.llama.sharding import shard_theta
from sharktank.layers.configs import LlamaHParams
from sharktank.utils.math import round_up_to_multiple_of
from sharktank.utils import iterables_equal
import tempfile
import torch
from copy import deepcopy
from shark_turbine.aot import FxProgramsBuilder


class AttentionBlockTest(unittest.TestCase):
def testToyModelCompareToUnsharded(self):
"""Run a sharded variant of a toy model size and compare it against the
unsharded variant."""
class ShardedLlamaTest(unittest.TestCase):
def setUp(self):
torch.random.manual_seed(123456)
dtype = torch.float32
torch.set_default_dtype(dtype)
batch_size = 3
attention_head_count_kv = 4
attention_head_count = attention_head_count_kv * 5
vocabulary_size = 19
rope_dimension_count = 7 * 2
attn_head_dim = rope_dimension_count
block_seq_stride = 13
cache_page_count = 11
config = LlamaModelConfig(
self.dtype = torch.float32
torch.set_default_dtype(self.dtype)
self.batch_size = 3
self.attention_head_count_kv = 4
self.attention_head_count = self.attention_head_count_kv * 5
self.vocabulary_size = 19
self.rope_dimension_count = 7 * 2
self.attn_head_dim = self.rope_dimension_count
self.block_seq_stride = 13
self.cache_page_count = 11
self.config = LlamaModelConfig(
hp=LlamaHParams(
context_length=block_seq_stride * 2,
embedding_length=attention_head_count * attn_head_dim,
context_length=self.block_seq_stride * 2,
embedding_length=self.attention_head_count * self.attn_head_dim,
block_count=3,
feed_forward_length=23,
rope_dimension_count=rope_dimension_count,
rope_dimension_count=self.rope_dimension_count,
rope_freq_base=500000.0,
attention_head_count=attention_head_count,
attn_head_dim=attn_head_dim,
attention_head_count=self.attention_head_count,
attn_head_dim=self.attn_head_dim,
attention_layer_norm_rms_epsilon=0.01,
attention_head_count_kv=attention_head_count_kv,
attention_head_count_kv=self.attention_head_count_kv,
expert_count=0,
expert_used_count=0,
),
block_seq_stride=block_seq_stride,
activation_dtype=dtype,
attention_dtype=dtype,
block_seq_stride=self.block_seq_stride,
activation_dtype=self.dtype,
attention_dtype=self.dtype,
)
theta = make_random_llama_theta(
config=config,
vocab_size=vocabulary_size,
self.theta = make_random_llama_theta(
config=self.config,
vocab_size=self.vocabulary_size,
)

model = PagedLlamaModelV1(theta, config)
seq_lens = torch.randint(high=config.hp.context_length + 1, size=[batch_size])
seq_lens[batch_size - 1] = config.hp.context_length
cache_state = model.cache.paged.allocate(page_count=cache_page_count)
cache_state = [torch.rand_like(cache_state[0])]
cache_state_snapshot = deepcopy(cache_state)
def make_prefill_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]:
seq_lens = torch.randint(
high=self.config.hp.context_length + 1, size=[self.batch_size]
)
seq_lens[self.batch_size - 1] = self.config.hp.context_length
batch_seq_len = round_up_to_multiple_of(
int(torch.max(seq_lens)), model.cache.pad_sequence_stride
)
token_ids = torch.randint(
low=0,
high=vocabulary_size,
size=[batch_size, batch_seq_len],
high=self.vocabulary_size,
size=[self.batch_size, batch_seq_len],
dtype=torch.int32,
)
attention_mask = model.attention_mask(model.input_mask(seq_lens, batch_seq_len))
seq_block_ids = torch.arange(
batch_size * batch_seq_len // config.block_seq_stride
).view(batch_size, -1)
self.batch_size * batch_seq_len // self.config.block_seq_stride
).view(self.batch_size, -1)
cache_state = model.cache.paged.allocate(page_count=self.cache_page_count)
cache_state = [torch.rand_like(cache_state[0])]
return {
"tokens": token_ids,
"attention_mask": attention_mask,
"seq_block_ids": seq_block_ids,
"cache_state": cache_state,
}

# Verify prefill step.
sharded_config = deepcopy(config)
sharded_config.tensor_parallelism_size = 2
sharded_theta = shard_theta(theta, sharded_config)
sharded_model = PagedLlamaModelV1(sharded_theta, sharded_config)
def make_equal_unsharded_and_sharded_prefill_args(
self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
prefill_args = self.make_prefill_args(model)
sharded_cache_state = sharded_model.cache.paged.allocate(
page_count=cache_page_count
page_count=self.cache_page_count
)
assert iterables_equal(
prefill_args["cache_state"][0].shape, sharded_cache_state[0].shape
)
sharded_prefill_args = deepcopy(prefill_args)
sharded_cache_state = sharded_model.cache.paged.shard_state(
deepcopy(cache_state)
sharded_prefill_args["cache_state"]
)
sharded_prefill_args["cache_state"] = sharded_cache_state
return prefill_args, sharded_prefill_args

expected_prefill_result = model.prefill(
token_ids,
attention_mask=attention_mask,
seq_block_ids=seq_block_ids,
cache_state=cache_state,
)
sharded_prefill_result = sharded_model.prefill(
token_ids,
attention_mask=attention_mask,
seq_block_ids=seq_block_ids,
cache_state=sharded_cache_state,
)
# The errors are quite high, but for float64 both errors drop to < 1e-12.
# The numerics are probably correct.
torch.testing.assert_close(
sharded_prefill_result, expected_prefill_result, atol=1e-3, rtol=1e-2
def make_decode_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]:
seq_lens = torch.randint(
high=self.config.hp.context_length + 1, size=[self.batch_size]
)
expected_cache_state = cache_state[0]
actual_cache_state = ops.unshard(
sharded_model.cache.paged.unflatten_page_table(sharded_cache_state)
).flatten(start_dim=1)
torch.testing.assert_close(
actual_cache_state, expected_cache_state, atol=1e-4, rtol=1e-1
seq_lens[self.batch_size - 1] = self.config.hp.context_length
batch_seq_len = round_up_to_multiple_of(
int(torch.max(seq_lens)), model.cache.pad_sequence_stride
)

# Verify decode step.
decode_token_ids = torch.randint(
low=0,
high=vocabulary_size,
size=[batch_size, 1],
high=self.vocabulary_size,
size=[self.batch_size, 1],
dtype=torch.int32,
)
decode_seq_lens = torch.randint(
high=config.hp.context_length - 2, size=[batch_size]
high=self.config.hp.context_length - 2, size=[self.batch_size]
)
start_positions = decode_seq_lens + 1
decode_batch_seq_len = round_up_to_multiple_of(
@@ -128,32 +125,116 @@ def testToyModelCompareToUnsharded(self):
decode_attention_mask = model.decode_attention_mask(
model.input_mask(decode_seq_lens, decode_batch_seq_len)
)
decode_cache_state = deepcopy(cache_state_snapshot)
decode_sharded_cache_state = sharded_model.cache.paged.shard_state(
deepcopy(decode_cache_state)
)
expected_decode_result = model.decode(
decode_token_ids,
attention_mask=decode_attention_mask,
start_positions=start_positions,
seq_block_ids=seq_block_ids,
cache_state=decode_cache_state,
seq_block_ids = torch.arange(
self.batch_size * batch_seq_len // self.config.block_seq_stride
).view(self.batch_size, -1)
cache_state = model.cache.paged.allocate(page_count=self.cache_page_count)
cache_state = [torch.rand_like(cache_state[0])]
return {
"tokens": decode_token_ids,
"attention_mask": decode_attention_mask,
"start_positions": start_positions,
"seq_block_ids": seq_block_ids,
"cache_state": cache_state,
}

def make_equal_unsharded_and_sharded_decode_args(
self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1
):
decode_args = self.make_decode_args(model)
sharded_decode_args = deepcopy(decode_args)
sharded_decode_args["cache_state"] = sharded_model.cache.paged.shard_state(
sharded_decode_args["cache_state"]
)
return decode_args, sharded_decode_args

def testCompareToySizedModelToUnsharded(self):
"""Run a sharded variant of a toy model size and compare it against the
unsharded variant."""
model = PagedLlamaModelV1(self.theta, self.config)
sharded_config = deepcopy(self.config)
sharded_config.tensor_parallelism_size = 2
sharded_theta = shard_theta(self.theta, sharded_config)
sharded_model = PagedLlamaModelV1(sharded_theta, sharded_config)

# Verify prefill step.
(
prefill_args,
sharded_prefill_args,
) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model)

expected_prefill_result = model.prefill(**prefill_args)
sharded_prefill_result = sharded_model.prefill(**sharded_prefill_args)
# The errors are quite high, but for float64 both errors drop to < 1e-12.
# The numerics are probably correct.
torch.testing.assert_close(
sharded_prefill_result, expected_prefill_result, atol=1e-3, rtol=1e-2
)
sharded_decode_result = sharded_model.decode(
decode_token_ids,
attention_mask=decode_attention_mask,
start_positions=start_positions,
seq_block_ids=seq_block_ids,
cache_state=decode_sharded_cache_state,
expected_cache_state = prefill_args["cache_state"][0]
actual_cache_state = ops.unshard(
sharded_model.cache.paged.unflatten_page_table(
sharded_prefill_args["cache_state"]
)
).flatten(start_dim=1)
torch.testing.assert_close(
actual_cache_state, expected_cache_state, atol=1e-4, rtol=1e-1
)

# Verify decode step.
(
decode_args,
sharded_decode_args,
) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model)
expected_decode_result = model.decode(**decode_args)
sharded_decode_result = sharded_model.decode(**sharded_decode_args)
torch.testing.assert_close(sharded_decode_result, expected_decode_result)
expected_decode_cache_state = decode_cache_state[0]
expected_decode_cache_state = decode_args["cache_state"][0]
actual_decode_cache_state = ops.unshard(
sharded_model.cache.paged.unflatten_page_table(decode_sharded_cache_state)
sharded_model.cache.paged.unflatten_page_table(
sharded_decode_args["cache_state"]
)
).flatten(start_dim=1)
# TODO: investigate why the Windows machine CI is producing a larger numerical
# error.
# The Ubuntu CI runs fine with default tolerances.
torch.testing.assert_close(
actual_decode_cache_state, expected_decode_cache_state, atol=1e-4, rtol=1e-4
)

def testExportToySizedModelToIree(self):
with tempfile.TemporaryDirectory() as temp_dir:
sharded_config = deepcopy(self.config)
sharded_config.tensor_parallelism_size = 2
sharded_theta = shard_theta(self.theta, sharded_config)
sharded_theta.rename_tensors_to_paths()
sharded_dataset = Dataset({}, sharded_theta)
parameters_path = f"{temp_dir}/parameters.irpa"
sharded_dataset.save(f"{temp_dir}/parameters.irpa")
sharded_dataset = Dataset.load(parameters_path)

model = PagedLlamaModelV1(self.theta, self.config)
sharded_model = PagedLlamaModelV1(
sharded_dataset.root_theta, sharded_config
)
sharded_fxb = FxProgramsBuilder(sharded_model)

(
_,
sharded_prefill_args,
) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model)

@sharded_fxb.export_program(
name="sharded_llama_prefill", args=tuple(), kwargs=sharded_prefill_args
)
def _(model, *args, **kwargs) -> torch.Tensor:
return model.prefill(*args, **kwargs)

_, sharded_decode_args = self.make_equal_unsharded_and_sharded_decode_args(
model, sharded_model
)

@sharded_fxb.export_program(
name="sharded_llama_decode", args=tuple(), kwargs=sharded_decode_args
)
def _(model, *args, **kwargs) -> torch.Tensor:
return model.decode(*args, **kwargs)
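The WIP export test stops at building FX programs for the sharded prefill and decode steps; it does not yet lower them through IREE. The sketch below shows the presumed follow-up step, assuming `shark_turbine.aot.export(...)` accepts an `FxProgramsBuilder` and produces an object with `save_mlir(...)`, and that the `iree-compile` flags are valid for a CPU target; treat all of these as assumptions rather than something this commit exercises.

# Hedged sketch of lowering the exported programs through IREE.
# The aot.export(...)/save_mlir(...) usage and the iree-compile flags
# below are assumptions, not part of this commit.
import subprocess
from shark_turbine import aot

def export_and_compile(fxb: "aot.FxProgramsBuilder", temp_dir: str) -> str:
    mlir_path = f"{temp_dir}/sharded_llama.mlir"
    vmfb_path = f"{temp_dir}/sharded_llama.vmfb"

    # Materialize the exported FX programs as an MLIR module on disk.
    export_output = aot.export(fxb)
    export_output.save_mlir(mlir_path)

    # Compile for a CPU target; the backend flag is a placeholder.
    subprocess.check_call(
        [
            "iree-compile",
            mlir_path,
            "--iree-hal-target-backends=llvm-cpu",
            "-o",
            vmfb_path,
        ]
    )
    return vmfb_path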
