From 1c31746e2f903e791bc2a41a0bc23614958e46cd Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 5 Nov 2024 17:58:32 +0800
Subject: [PATCH] add unitest for xfuser attn layer (#335)

---
 tests/core/test_ring_flash_attn.py            |  27 ++-
 tests/core/test_xfuser_attn.py                | 181 ++++++++++++++++++
 xfuser/core/fast_attention/fast_attn_state.py |   4 +
 .../long_ctx_attention/hybrid/attn_layer.py   |  12 +-
 .../core/long_ctx_attention/hybrid/utils.py   |   4 -
 .../core/long_ctx_attention/ring/__init__.py  |   2 +-
 .../ring/ring_flash_attn.py                   |   1 +
 7 files changed, 217 insertions(+), 14 deletions(-)
 create mode 100644 tests/core/test_xfuser_attn.py
 delete mode 100644 xfuser/core/long_ctx_attention/hybrid/utils.py

diff --git a/tests/core/test_ring_flash_attn.py b/tests/core/test_ring_flash_attn.py
index 4e49cbe..12bf63d 100644
--- a/tests/core/test_ring_flash_attn.py
+++ b/tests/core/test_ring_flash_attn.py
@@ -2,9 +2,22 @@
 import torch
 import torch.distributed as dist
 from xfuser.core.long_ctx_attention.ring.ring_flash_attn import xdit_ring_flash_attn_func
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 from flash_attn import flash_attn_func
 import os
+from xfuser.model_executor.layers.attention_processor import (
+    xFuserAttnProcessor2_0,
+)
+from diffusers.models.attention_processor import (
+    Attention,
+)
+from xfuser.core.distributed import (
+    init_distributed_environment,
+    initialize_model_parallel,
+)
+
+
 def init_dist(backend='nccl'):
     local_rank = int(os.environ["LOCAL_RANK"])
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
@@ -13,7 +26,9 @@ def init_dist(backend='nccl'):
     print(f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}")
 
     torch.cuda.set_device(local_rank)
-    dist.init_process_group(backend=backend)
+    # dist.init_process_group(backend=backend)
+    init_distributed_environment(rank=rank, world_size=world_size)
+
     return rank, world_size
 
 class TestRingFlashAttn(unittest.TestCase):
@@ -60,7 +75,7 @@ def _create_test_tensors(self):
         local_v = v.chunk(self.world_size, dim=1)[self.rank]
         return q, k, v, local_q, local_k, local_v
 
-    def test_distributed(self):
+    def test_xdit_ring_flash_attn_func(self):
         """Test ring flash attention in distributed mode"""
         q, k, v, local_q, local_k, local_v = self._create_test_tensors()
 
@@ -87,8 +102,7 @@
         torch.testing.assert_close(ref_output, output, rtol=1e-3, atol=1e-3)
         self.assertEqual(ref_output.shape, output.shape)
 
-
-    def test_joint_strategy_rear(self):
+    def test_xdit_ring_flash_attn_func_joint_strategy_rear(self):
         """Test ring flash attention with joint strategy"""
         q, k, v, local_q, local_k, local_v = self._create_test_tensors()
         joint_q, joint_k, joint_v, local_joint_q, local_joint_k, local_joint_v = self._create_test_tensors()
@@ -119,8 +133,7 @@
 
         torch.testing.assert_close(ref_output, output_rear, rtol=1e-3, atol=1e-3)
 
-
-    def test_joint_strategy_front(self):
+    def test_xdit_ring_flash_attn_func_joint_strategy_front(self):
         """Test ring flash attention with joint strategy"""
         q, k, v, local_q, local_k, local_v = self._create_test_tensors()
         joint_q, joint_k, joint_v, local_joint_q, local_joint_k, local_joint_v = self._create_test_tensors()
@@ -153,4 +166,4 @@
 
 # torchrun --nproc_per_node=2 -m unittest tests/core/test_ring_flash_attn.py
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
\ No newline at end of file
diff --git a/tests/core/test_xfuser_attn.py b/tests/core/test_xfuser_attn.py
new file mode 100644
index 0000000..be3a353
--- /dev/null
+++ b/tests/core/test_xfuser_attn.py
@@ -0,0 +1,181 @@
+import unittest
+import torch
+import torch.distributed as dist
+from xfuser.core.long_ctx_attention.ring.ring_flash_attn import xdit_ring_flash_attn_func
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention, xFuserJointLongContextAttention, xFuserFluxLongContextAttention
+from flash_attn import flash_attn_func
+import os
+
+from xfuser.model_executor.layers.attention_processor import (
+    xFuserAttnProcessor2_0,
+)
+from diffusers.models.attention_processor import (
+    Attention,
+)
+from xfuser.core.distributed import (
+    init_distributed_environment,
+    initialize_model_parallel,
+)
+
+
+def init_dist(backend='nccl'):
+    local_rank = int(os.environ["LOCAL_RANK"])
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+
+    print(f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}")
+
+    torch.cuda.set_device(local_rank)
+    # dist.init_process_group(backend=backend)
+    init_distributed_environment(rank=rank, world_size=world_size)
+
+    # construct a hybrid sequence parallel config (ulysses=2, ring = world_size // 2)
+    if world_size > 1:
+        ring_degree = world_size // 2
+        ulysses_degree = 2
+    else:
+        ring_degree = 1
+        ulysses_degree = 1
+
+    initialize_model_parallel(
+        sequence_parallel_degree=world_size, ring_degree=ring_degree, ulysses_degree=ulysses_degree
+    )
+
+    return rank, world_size, ring_degree, ulysses_degree
+
+class TestRingFlashAttn(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.batch_size = 1
+        cls.num_heads = 4
+        cls.head_dim = 32
+        cls.seq_len = 128
+        cls.dtype = torch.float16
+
+
+        cls.rank, cls.world_size, cls.ring_degree, cls.ulysses_degree = init_dist()
+        cls.device = torch.device(f'cuda:{cls.rank}')
+
+    def setUp(self):
+        torch.manual_seed(42 + self.rank)
+
+    @classmethod
+    def tearDownClass(cls):
+        dist.destroy_process_group()
+
+    def _create_test_tensors(self):
+        """Helper to create test input tensors"""
+        shape = (self.batch_size, self.seq_len, self.num_heads, self.head_dim)
+
+        # Prepare inputs
+        q = torch.randn(
+            shape, device=self.device, dtype=self.dtype, requires_grad=False
+        )
+        k = torch.randn(
+            shape, device=self.device, dtype=self.dtype, requires_grad=False
+        )
+        v = torch.randn(
+            shape, device=self.device, dtype=self.dtype, requires_grad=False
+        )
+
+        dist.broadcast(q, src=0)
+        dist.broadcast(k, src=0)
+        dist.broadcast(v, src=0)
+
+        local_q = q.chunk(self.world_size, dim=1)[self.rank]
+        local_k = k.chunk(self.world_size, dim=1)[self.rank]
+        local_v = v.chunk(self.world_size, dim=1)[self.rank]
+        return q, k, v, local_q, local_k, local_v
+
+    def test_xfuser_attn_layer_joint_strategy_rear(self):
+        """Test xFuserJointLongContextAttention layer in distributed mode"""
+        # Create test tensors
+        q, k, v, local_q, local_k, local_v = self._create_test_tensors()
+        joint_q, joint_k, joint_v, local_joint_q, local_joint_k, local_joint_v = self._create_test_tensors()
+        joint_strategy = "rear"
+
+        attn = None
+
+        # Create attention layer
+        attn_layer = xFuserJointLongContextAttention(
+            scatter_idx=2,
+            gather_idx=1,
+            ring_impl_type="basic",
+            use_kv_cache=False,
+        ).to(device=self.device, dtype=self.dtype)
+
+        assert attn_layer.ring_pg.size() == self.ring_degree
+        assert attn_layer.ulysses_pg.size() == self.ulysses_degree
+
+        ref_output = flash_attn_func(
+            torch.cat([q, joint_q], dim=1),
+            torch.cat([k, joint_k], dim=1),
+            torch.cat([v, joint_v], dim=1),
+            dropout_p=0.0,
+            window_size=(-1, -1),
+        )
+
+        # Split ref_output into base and joint parts
+        base_out = ref_output[:, :self.seq_len, ::]  # First half for base attention
+        joint_out = ref_output[:, self.seq_len:, ::]  # Second half for joint attention
+
+        # Get local shard for base output
+        base_out_shard = base_out.chunk(self.world_size, dim=1)[self.rank]
+        # Duplicate joint output as specified
+        ref_output = torch.cat([base_out_shard, joint_out], dim=1)
+
+        # Run distributed implementation
+        output = attn_layer(
+            attn=None,
+            query=local_q,
+            key=local_k,
+            value=local_v,
+            dropout_p=0.0,
+            window_size=(-1, -1),
+            joint_tensor_query=joint_q,
+            joint_tensor_key=joint_k,
+            joint_tensor_value=joint_v,
+            joint_strategy=joint_strategy,
+        )
+        # assert torch.max(torch.abs(output - ref_output)) < 1e-3
+        torch.testing.assert_close(ref_output, output, rtol=1e-3, atol=1e-3)
+
+    def test_xfuser_attn_layer(self):
+        """Test xFuserLongContextAttention layer in distributed mode"""
+        # Create test tensors
+        q, k, v, local_q, local_k, local_v = self._create_test_tensors()
+        attn = None
+
+        # Create attention layer
+        attn_layer = xFuserLongContextAttention(
+            scatter_idx=2,
+            gather_idx=1,
+            ring_impl_type="basic",
+            use_kv_cache=False,
+        ).to(device=self.device, dtype=self.dtype)
+
+        assert attn_layer.ring_pg.size() == self.ring_degree
+        assert attn_layer.ulysses_pg.size() == self.ulysses_degree
+
+        ref_output = flash_attn_func(
+            q, k, v,
+            dropout_p=0.0,
+            window_size=(-1, -1),
+        )
+        ref_output = ref_output.chunk(self.world_size, dim=1)[self.rank]
+
+        # Run distributed implementation
+        output = attn_layer(
+            attn=None,
+            query=local_q,
+            key=local_k,
+            value=local_v,
+            dropout_p=0.0,
+            window_size=(-1, -1),
+        )
+        assert torch.max(torch.abs(output - ref_output)) < 1e-3
+        torch.testing.assert_close(ref_output, output, rtol=1e-3, atol=1e-3)
+
+# torchrun --nproc_per_node=4 -m unittest tests/core/test_xfuser_attn.py
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
diff --git a/xfuser/core/fast_attention/fast_attn_state.py b/xfuser/core/fast_attention/fast_attn_state.py
index 91823c4..cb0427e 100644
--- a/xfuser/core/fast_attention/fast_attn_state.py
+++ b/xfuser/core/fast_attention/fast_attn_state.py
@@ -61,16 +61,20 @@ def get_fast_attn_state() -> FastAttnState:
 
 def get_fast_attn_enable() -> bool:
     """Return whether fast attention is enabled."""
+    if get_fast_attn_state() is None:
+        return False
     return get_fast_attn_state().enable
 
 
 def get_fast_attn_step() -> int:
     """Return the fast attention step."""
+    assert get_fast_attn_state() is not None, "FastAttn state is not initialized"
     return get_fast_attn_state().n_step
 
 
 def get_fast_attn_calib() -> int:
     """Return the fast attention calibration."""
+    assert get_fast_attn_state() is not None, "FastAttn state is not initialized"
     return get_fast_attn_state().n_calib
 
 
diff --git a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
index 0f5b068..24acb18 100644
--- a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
+++ b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -7,7 +7,6 @@
 from yunchang.comm.all_to_all import SeqAllToAll4D
 
 from xfuser.logger import init_logger
-from xfuser.core.long_ctx_attention.ring import xdit_ring_flash_attn_func
 
 logger = init_logger(__name__)
 
@@ -46,6 +45,8 @@ def __init__(
             raise RuntimeError(
                 f"ring_impl_type: {ring_impl_type} do not support SP kv cache."
             )
+
+        from xfuser.core.long_ctx_attention.ring import xdit_ring_flash_attn_func
         self.ring_attn_fn = xdit_ring_flash_attn_func
 
     @torch.compiler.disable
@@ -71,14 +72,21 @@ def forward(
         """forward
 
         Arguments:
+            attn (Attention): the attention module
             query (Tensor): query input to the layer
             key (Tensor): key input to the layer
            value (Tensor): value input to the layer
-            args: other args
+            args: other args,
+            joint_tensor_query: Tensor = None, a tensor replicated across processes, appended to the front or rear of query depending on the joint_strategy
+            joint_tensor_key: Tensor = None, a tensor replicated across processes, appended to the front or rear of key depending on the joint_strategy
+            joint_tensor_value: Tensor = None, a tensor replicated across processes, appended to the front or rear of value depending on the joint_strategy
+            *args: the same extra args as flash_attn_interface
+            joint_strategy: str = "none", the joint attention strategy; currently only "front" and "rear" are supported
 
         Returns:
             * output (Tensor): context output
         """
+        assert causal == False, "causal attention is not applied in DiTs"
         # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size)
         # scatter 2, gather 1
         if self.use_pack_qkv:
diff --git a/xfuser/core/long_ctx_attention/hybrid/utils.py b/xfuser/core/long_ctx_attention/hybrid/utils.py
deleted file mode 100644
index fceb2f3..0000000
--- a/xfuser/core/long_ctx_attention/hybrid/utils.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from yunchang.ring import (
-    zigzag_ring_flash_attn_func,
-    stripe_flash_attn_func,
-)
diff --git a/xfuser/core/long_ctx_attention/ring/__init__.py b/xfuser/core/long_ctx_attention/ring/__init__.py
index 6b37c5b..8ed9877 100644
--- a/xfuser/core/long_ctx_attention/ring/__init__.py
+++ b/xfuser/core/long_ctx_attention/ring/__init__.py
@@ -1,5 +1,5 @@
 from .ring_flash_attn import xdit_ring_flash_attn_func
 
 __all__ = [
-    "ring_flash_attn_func",
+    "xdit_ring_flash_attn_func",
 ]
diff --git a/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py b/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
index dd209a5..bdae65d 100644
--- a/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
+++ b/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
@@ -1,5 +1,6 @@
 import torch
 from flash_attn.flash_attn_interface import _flash_attn_forward
+from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 from xfuser.core.cache_manager.cache_manager import get_cache_manager
 from yunchang.ring.utils import RingComm, update_out_and_lse
 from yunchang.ring.ring_flash_attn import RingFlashAttnFunc
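
Usage note: below is a minimal standalone sketch of how the hybrid Ulysses + ring attention layer exercised by the new test can be driven outside unittest. It is distilled from tests/core/test_xfuser_attn.py above; the script name demo_xfuser_attn.py, the 4-GPU launch, the tensor sizes, and the main() wrapper are illustrative assumptions, not part of the xfuser API.

# demo_xfuser_attn.py -- hypothetical file name; launch with:
#   torchrun --nproc_per_node=4 demo_xfuser_attn.py
import os

import torch

from xfuser.core.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention


def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Same hybrid split as the test: ulysses_degree=2, ring_degree=world_size//2
    # (assumes an even world_size >= 2).
    init_distributed_environment(rank=rank, world_size=world_size)
    initialize_model_parallel(
        sequence_parallel_degree=world_size,
        ring_degree=world_size // 2,
        ulysses_degree=2,
    )

    # Layer construction mirrors the test.
    attn_layer = xFuserLongContextAttention(
        scatter_idx=2,
        gather_idx=1,
        ring_impl_type="basic",
        use_kv_cache=False,
    ).to(device="cuda", dtype=torch.float16)

    # Each rank holds a 1/world_size shard of the sequence:
    # (batch, seq_len / world_size, num_heads, head_dim).
    batch, seq_len, num_heads, head_dim = 1, 128, 4, 32
    shard = (batch, seq_len // world_size, num_heads, head_dim)
    q = torch.randn(shard, device="cuda", dtype=torch.float16)
    k = torch.randn(shard, device="cuda", dtype=torch.float16)
    v = torch.randn(shard, device="cuda", dtype=torch.float16)

    # The layer exchanges shards (all-to-all on heads, ring on sequence) and
    # returns this rank's shard of the full-sequence attention output.
    out = attn_layer(
        attn=None,
        query=q,
        key=k,
        value=v,
        dropout_p=0.0,
        window_size=(-1, -1),
    )
    print(f"rank {rank}: output shape {tuple(out.shape)}")


if __name__ == "__main__":
    main()

The joint variant xFuserJointLongContextAttention takes the same call plus the joint_tensor_query/key/value and joint_strategy arguments documented in attn_layer.py, as exercised by test_xfuser_attn_layer_joint_strategy_rear above.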