
add unit test for xfuser attn layer (#335)
feifeibear authored Nov 5, 2024
1 parent 85c426d commit 1c31746
Showing 7 changed files with 217 additions and 14 deletions.
27 changes: 20 additions & 7 deletions tests/core/test_ring_flash_attn.py
@@ -2,9 +2,22 @@
import torch
import torch.distributed as dist
from xfuser.core.long_ctx_attention.ring.ring_flash_attn import xdit_ring_flash_attn_func
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from flash_attn import flash_attn_func
import os

from xfuser.model_executor.layers.attention_processor import (
xFuserAttnProcessor2_0,
)
from diffusers.models.attention_processor import (
Attention,
)
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)


def init_dist(backend='nccl'):
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
@@ -13,7 +26,9 @@ def init_dist(backend='nccl'):
print(f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}")

torch.cuda.set_device(local_rank)
dist.init_process_group(backend=backend)
# dist.init_process_group(backend=backend)
init_distributed_environment(rank=rank, world_size=world_size)

return rank, world_size

class TestRingFlashAttn(unittest.TestCase):
@@ -60,7 +75,7 @@ def _create_test_tensors(self):
local_v = v.chunk(self.world_size, dim=1)[self.rank]
return q, k, v, local_q, local_k, local_v

def test_distributed(self):
def test_xdit_ring_flash_attn_func(self):
"""Test ring flash attention in distributed mode"""
q, k, v, local_q, local_k, local_v = self._create_test_tensors()

@@ -87,8 +102,7 @@ def test_distributed(self):
torch.testing.assert_close(ref_output, output, rtol=1e-3, atol=1e-3)
self.assertEqual(ref_output.shape, output.shape)


def test_joint_strategy_rear(self):
def test_xdit_ring_flash_attn_func_joint_strategy_rear(self):
"""Test ring flash attention with joint strategy"""
q, k, v, local_q, local_k, local_v = self._create_test_tensors()
joint_q, joint_k, joint_v, local_joint_q, local_joint_k, local_joint_v = self._create_test_tensors()
@@ -119,8 +133,7 @@ def test_joint_strategy_rear(self):

torch.testing.assert_close(ref_output, output_rear, rtol=1e-3, atol=1e-3)


def test_joint_strategy_front(self):
def test_xdit_ring_flash_attn_func_joint_strategy_front(self):
"""Test ring flash attention with joint strategy"""
q, k, v, local_q, local_k, local_v = self._create_test_tensors()
joint_q, joint_k, joint_v, local_joint_q, local_joint_k, local_joint_v = self._create_test_tensors()
@@ -153,4 +166,4 @@ def test_joint_strategy_front(self):

# torchrun --nproc_per_node=2 -m unittest tests/core/test_ring_flash_attn.py
if __name__ == '__main__':
unittest.main()
unittest.main()
181 changes: 181 additions & 0 deletions tests/core/test_xfuser_attn.py
@@ -0,0 +1,181 @@
import unittest
import torch
import torch.distributed as dist
from xfuser.core.long_ctx_attention.ring.ring_flash_attn import xdit_ring_flash_attn_func
from xfuser.core.long_ctx_attention import xFuserLongContextAttention, xFuserJointLongContextAttention, xFuserFluxLongContextAttention
from flash_attn import flash_attn_func
import os

from xfuser.model_executor.layers.attention_processor import (
xFuserAttnProcessor2_0,
)
from diffusers.models.attention_processor import (
Attention,
)
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)


def init_dist(backend='nccl'):
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

print(f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}")

torch.cuda.set_device(local_rank)
# dist.init_process_group(backend=backend)
init_distributed_environment(rank=rank, world_size=world_size)

# construct a hybrid sequence parallel config (ulysses=2, ring = world_size // 2)
if world_size > 1:
ring_degree = world_size // 2
ulysses_degree = 2
else:
ring_degree = 1
ulysses_degree = 1

initialize_model_parallel(
sequence_parallel_degree=world_size, ring_degree=ring_degree, ulysses_degree=ulysses_degree
)

return rank, world_size, ring_degree, ulysses_degree

class TestRingFlashAttn(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.batch_size = 1
cls.num_heads = 4
cls.head_dim = 32
cls.seq_len = 128
cls.dtype = torch.float16


cls.rank, cls.world_size, cls.ring_degree, cls.ulysses_degree = init_dist()
cls.device = torch.device(f'cuda:{cls.rank}')

def setUp(self):
torch.manual_seed(42 + self.rank)

@classmethod
def tearDownClass(cls):
dist.destroy_process_group()

def _create_test_tensors(self):
"""Helper to create test input tensors"""
shape = (self.batch_size, self.seq_len, self.num_heads, self.head_dim)

# Prepare inputs
q = torch.randn(
shape, device=self.device, dtype=self.dtype, requires_grad=False
)
k = torch.randn(
shape, device=self.device, dtype=self.dtype, requires_grad=False
)
v = torch.randn(
shape, device=self.device, dtype=self.dtype, requires_grad=False
)

dist.broadcast(q, src=0)
dist.broadcast(k, src=0)
dist.broadcast(v, src=0)

local_q = q.chunk(self.world_size, dim=1)[self.rank]
local_k = k.chunk(self.world_size, dim=1)[self.rank]
local_v = v.chunk(self.world_size, dim=1)[self.rank]
return q, k, v, local_q, local_k, local_v

def test_xfuser_attn_layer_joint_strategy_rear(self):
"""Test xFuserLongContextAttention layer in distributed mode"""
# Create test tensors
q, k, v, local_q, local_k, local_v = self._create_test_tensors()
joint_q, joint_k, joint_v, local_joint_q, local_joint_k, local_joint_v = self._create_test_tensors()
joint_strategy = "rear"

attn = None

# Create attention layer
attn_layer = xFuserJointLongContextAttention(
scatter_idx=2,
gather_idx=1,
ring_impl_type="basic",
use_kv_cache=False,
).to(device=self.device, dtype=self.dtype)

assert attn_layer.ring_pg.size() == self.ring_degree
assert attn_layer.ulysses_pg.size() == self.ulysses_degree

ref_output = flash_attn_func(
torch.cat([q, joint_q], dim=1),
torch.cat([k, joint_k], dim=1),
torch.cat([v, joint_v], dim=1),
dropout_p=0.0,
window_size=(-1, -1),
)

# Split ref_output into base and joint parts
base_out = ref_output[:, :self.seq_len, ::] # First half for base attention
joint_out = ref_output[:, self.seq_len:, ::] # Second half for joint attention

# Get local shard for base output
base_out_shard = base_out.chunk(self.world_size, dim=1)[self.rank]
# Duplicate joint output as specified
ref_output = torch.cat([base_out_shard, joint_out], dim=1)

# Run distributed implementation
output = attn_layer(
attn=None,
query=local_q,
key=local_k,
value=local_v,
dropout_p=0.0,
window_size=(-1, -1),
joint_tensor_query=joint_q,
joint_tensor_key=joint_k,
joint_tensor_value=joint_v,
joint_strategy=joint_strategy,
)
# assert torch.max(torch.abs(output - ref_output)) < 1e-3
torch.testing.assert_close(ref_output, output, rtol=1e-3, atol=1e-3)

def test_xfuser_attn_layer(self):
"""Test xFuserLongContextAttention layer in distributed mode"""
# Create test tensors
q, k, v, local_q, local_k, local_v = self._create_test_tensors()
attn = None

# Create attention layer
attn_layer = xFuserLongContextAttention(
scatter_idx=2,
gather_idx=1,
ring_impl_type="basic",
use_kv_cache=False,
).to(device=self.device, dtype=self.dtype)

assert attn_layer.ring_pg.size() == self.ring_degree
assert attn_layer.ulysses_pg.size() == self.ulysses_degree

ref_output = flash_attn_func(
q, k, v,
dropout_p=0.0,
window_size=(-1, -1),
)
ref_output = ref_output.chunk(self.world_size, dim=1)[self.rank]

# Run distributed implementation
output = attn_layer(
attn=None,
query=local_q,
key=local_k,
value=local_v,
dropout_p=0.0,
window_size=(-1, -1),
)
assert torch.max(torch.abs(output - ref_output)) < 1e-3
torch.testing.assert_close(ref_output, output, rtol=1e-3, atol=1e-3)

# torchrun --nproc_per_node=4 -m unittest tests/core/test_xfuser_attn.py
if __name__ == '__main__':
unittest.main()
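
The init_dist helper above hard-codes a hybrid layout (ulysses_degree=2, ring_degree=world_size // 2) and passes sequence_parallel_degree=world_size to initialize_model_parallel. A minimal sketch of the constraint this encodes, using a hypothetical helper that is not part of the commit:

def hybrid_sp_degrees(world_size: int, ulysses_degree: int = 2):
    # The sequence dimension is split twice, so the two degrees must
    # multiply back to the total number of sequence-parallel ranks.
    if world_size == 1:
        return 1, 1  # single GPU: no ring or ulysses parallelism
    assert world_size % ulysses_degree == 0, "world size must be divisible by the ulysses degree"
    return world_size // ulysses_degree, ulysses_degree

ring, ulysses = hybrid_sp_degrees(4)  # (2, 2) for the 4-GPU torchrun command above
assert ring * ulysses == 4            # equals the sequence_parallel_degree passed to initialize_model_parallel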
4 changes: 4 additions & 0 deletions xfuser/core/fast_attention/fast_attn_state.py
@@ -61,16 +61,20 @@ def get_fast_attn_state() -> FastAttnState:

def get_fast_attn_enable() -> bool:
"""Return whether fast attention is enabled."""
if get_fast_attn_state() is None:
return False
return get_fast_attn_state().enable


def get_fast_attn_step() -> int:
"""Return the fast attention step."""
assert get_fast_attn_state() is not None, "FastAttn state is not initialized"
return get_fast_attn_state().n_step


def get_fast_attn_calib() -> int:
"""Return the fast attention calibration."""
assert get_fast_attn_state() is not None, "FastAttn state is not initialized"
return get_fast_attn_state().n_calib


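The guard added above makes get_fast_attn_enable() safe to call before any FastAttn state exists (it simply reports that fast attention is disabled), while get_fast_attn_step() and get_fast_attn_calib() now fail with a clear assertion message instead of an AttributeError. A small hypothetical call-site sketch, assuming the getters are imported from this module:

from xfuser.core.fast_attention.fast_attn_state import (
    get_fast_attn_enable,
    get_fast_attn_step,
)

# Safe even when no FastAttn state has been initialized: returns False.
if get_fast_attn_enable():
    # Only reached once the state exists, so the assert inside cannot trip here.
    n_step = get_fast_attn_step()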
12 changes: 10 additions & 2 deletions xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -7,7 +7,6 @@
from yunchang.comm.all_to_all import SeqAllToAll4D

from xfuser.logger import init_logger
from xfuser.core.long_ctx_attention.ring import xdit_ring_flash_attn_func


logger = init_logger(__name__)
@@ -46,6 +45,8 @@ def __init__(
raise RuntimeError(
f"ring_impl_type: {ring_impl_type} do not support SP kv cache."
)

from xfuser.core.long_ctx_attention.ring import xdit_ring_flash_attn_func
self.ring_attn_fn = xdit_ring_flash_attn_func

@torch.compiler.disable
@@ -71,14 +72,21 @@ def forward(
"""forward
Arguments:
attn (Attention): the attention module
query (Tensor): query input to the layer
key (Tensor): key input to the layer
value (Tensor): value input to the layer
args: other args
args: other args,
joint_tensor_query: Tensor = None, a tensor replicated among processes, appended to the front or rear of query depending on joint_strategy
joint_tensor_key: Tensor = None, a tensor replicated among processes, appended to the front or rear of key depending on joint_strategy
joint_tensor_value: Tensor = None, a tensor replicated among processes, appended to the front or rear of value depending on joint_strategy
*args: the same args as flash_attn_interface
joint_strategy: str = "none", the joint strategy for joint attention; currently only "front" and "rear" are supported
Returns:
* output (Tensor): context output
"""
assert causal == False, "causal attention is not applied in DiTs"
# 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size)
# scatter 2, gather 1
if self.use_pack_qkv:
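The forward docstring above describes joint_tensor_query/key/value as tensors replicated across ranks and appended to the front or rear of the sharded query/key/value. A rough usage sketch mirroring the new tests/core/test_xfuser_attn.py (shapes, degrees, and flags are taken from the test; launch under torchrun so RANK/WORLD_SIZE/LOCAL_RANK are set):

import os
import torch
from xfuser.core.distributed import init_distributed_environment, initialize_model_parallel
from xfuser.core.long_ctx_attention import xFuserJointLongContextAttention

rank, world_size = int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
init_distributed_environment(rank=rank, world_size=world_size)
initialize_model_parallel(
    sequence_parallel_degree=world_size,
    ring_degree=world_size // 2 if world_size > 1 else 1,
    ulysses_degree=2 if world_size > 1 else 1,
)

bs, seq_len, heads, head_dim = 1, 128, 4, 32
device, dtype = torch.device("cuda"), torch.float16

def replicated(seed):
    torch.manual_seed(seed)  # same seed on every rank, so the tensor is identical everywhere
    return torch.randn(bs, seq_len, heads, head_dim, device=device, dtype=dtype)

def shard(t):
    # each rank keeps its seq_len / world_size slice of the sequence dimension
    return t.chunk(world_size, dim=1)[rank]

q, k, v, joint_q, joint_k, joint_v = (replicated(i) for i in range(6))

attn_layer = xFuserJointLongContextAttention(
    scatter_idx=2, gather_idx=1, ring_impl_type="basic", use_kv_cache=False,
).to(device=device, dtype=dtype)

# Sharded q/k/v plus replicated joint tensors appended at the rear of every rank's sequence.
out = attn_layer(
    attn=None,
    query=shard(q), key=shard(k), value=shard(v),
    dropout_p=0.0, window_size=(-1, -1),
    joint_tensor_query=joint_q, joint_tensor_key=joint_k, joint_tensor_value=joint_v,
    joint_strategy="rear",
)
# out has shape (bs, seq_len / world_size + seq_len, heads, head_dim) on every rank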
4 changes: 0 additions & 4 deletions xfuser/core/long_ctx_attention/hybrid/utils.py

This file was deleted.

2 changes: 1 addition & 1 deletion xfuser/core/long_ctx_attention/ring/__init__.py
@@ -1,5 +1,5 @@
from .ring_flash_attn import xdit_ring_flash_attn_func

__all__ = [
"ring_flash_attn_func",
"xdit_ring_flash_attn_func",
]
1 change: 1 addition & 0 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
@@ -1,5 +1,6 @@
import torch
from flash_attn.flash_attn_interface import _flash_attn_forward
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.cache_manager.cache_manager import get_cache_manager
from yunchang.ring.utils import RingComm, update_out_and_lse
from yunchang.ring.ring_flash_attn import RingFlashAttnFunc
