Bamba VLLM Draft #2

Draft · fabianlim wants to merge 9 commits into main from pr-draft

Conversation

fabianlim (Owner) commented Nov 28, 2024

NOTES

  • refactor to make it TP-able
  • add tests and make sure the non-chunked prefill tests pass
  • ensure the chunked prefill tests pass
  • investigate the CUDA invalid access for long input sequences
  • fix the precision problem for the gated norm (see the sketch after this list)
  • fix the mamba kernels for long sequences
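
On the gated norm item: a minimal sketch of what a precision fix could look like, assuming the "gated norm" is a Mamba-2-style gated RMSNorm and that the problem is fp16/bf16 accumulation. The function name, eps, and the decision to upcast are illustrative assumptions, not the code in this PR.

import torch
import torch.nn.functional as F

def gated_rms_norm_fp32(x, gate, weight, eps=1e-5):
    # Hypothetical fix: run the gating and normalization in float32 so the
    # variance/rsqrt do not lose precision when the model runs in fp16/bf16.
    orig_dtype = x.dtype
    x = x.float() * F.silu(gate.float())           # gate the hidden states
    var = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(var + eps)                 # RMS-normalize in fp32
    return (x * weight.float()).to(orig_dtype)     # cast back at the end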

Tests

Currently all tests pass except the chunked prefill ones:

================================================================================================== short test summary info ==================================================================================================
FAILED tests/models/decoder_only/language/test_bamba.py::test_chunked_prefill_with_parallel_sampling[10-float-/workspace/bamba-ckpt-fp16] - ValueError: too many values to unpack (expected 2)
FAILED tests/models/decoder_only/language/test_bamba.py::test_chunked_prefill[1-32-float-/workspace/bamba-ckpt-fp16] - AssertionError: Test0:
FAILED tests/models/decoder_only/language/test_bamba.py::test_chunked_prefill[4-32-float-/workspace/bamba-ckpt-fp16] - AssertionError: Test0:
FAILED tests/models/decoder_only/language/test_bamba.py::test_chunked_prefill[16-32-float-/workspace/bamba-ckpt-fp16] - AssertionError: Test0:
==================================================================================== 4 failed, 9 passed, 1 warning in 496.02s (0:08:16) =====================================================================================
(mamba-vllm) 1000960000@flim-mamba-master-0:~/data/vllm$ pytest tests/models/decoder_only/language/test_bamba.py::test_chunked_prefill
==================================================================================================== test session starts ====================================================================================================
platform linux -- Python 3.10.12, pytest-8.3.3, 

fabianlim marked this pull request as draft on November 28, 2024 14:33
fabianlim force-pushed the pr-draft branch 2 times, most recently from 4d67c31 to 98ba4fa on November 30, 2024 05:35
z = z.contiguous()
if D is not None and D.stride(-1) != 1:
    D = D.contiguous()
if initial_states is not None:
fabianlim (Owner, Author) commented on the code above:

I feel chunked prefill support requires updating the kernels. This is because initial_states seems to be implemented to handle batch > 1 only when cu_seqlens == None; the cu_seqlens path is only supported when we flatten the input x such that batch == 1.
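
To make the shape issue concrete, a hedged sketch (not vLLM or kernel code; shapes follow my reading of the mamba_chunk_scan_combined convention, and the sizes are made up):

import torch

nheads, headdim, dstate = 8, 64, 16

# Chunked prefill packs several prompts into one flattened batch of size 1,
# with sequence boundaries carried in cu_seqlens.
seqlens = [5, 3, 7]
cu_seqlens = torch.tensor([0, 5, 8, 15])            # prefix sums of seqlens
x = torch.randn(1, sum(seqlens), nheads, headdim)   # batch == 1 after flattening

# initial_states is indexed by the leading batch dimension, so in the flattened
# layout there is room for only ONE initial state, not one per packed sequence.
# The kernel would have to select per-sequence init states via seq_idx instead.
initial_states = torch.randn(1, nheads, headdim, dstate)             # what the kernel handles today
per_seq_states = torch.randn(len(seqlens), nheads, headdim, dstate)  # what chunked prefill needs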

fabianlim (Owner, Author) commented:

For the chunked prefill effort, I made some mind maps of the _mamba_chunk_scan_combined_fwd function:

  • since this is inference, we only care about the forward pass
  • we only need to target _state_passing_fwd, as that is the function that accepts initial_states
  • in _state_passing_fwd, initial_states overrides states in the computation states = scale * states + new_states, so we need to detect when a sequence has finished. We can do this using seq_idx_new and seq_idx: a new sequence is starting when their indices differ (see the kernel excerpt below and the reference sketch after it)
   for c in range(nchunks):
        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
            scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
            seq_idx = seq_idx_new
        states = scale * states + new_states
        if c < nchunks - 1:
            tl.store(out_ptrs, states, mask=offs_m < dim)
        else:
            tl.store(final_states_ptrs, states, mask=offs_m < dim)
        states_ptrs += stride_states_chunk
        dA_cs_ptr += stride_dA_cs_chunk
        out_ptrs += stride_out_chunk

    # modification: where we detect that seq_idx and seq_idx_new differ (a new
    # sequence has started), load that sequence's initial states instead of only
    # zeroing the carried state via scale:
    initstates_ptrs = initstates_ptr + offs_m * stride_batch
    states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
[mind map images of _mamba_chunk_scan_combined_fwd]
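
A plain-Python reading of that loop and the proposed change, to make the boundary handling explicit. This is a hedged reference sketch with made-up names, not the kernel; in particular, how the re-seeded state should be decayed across the boundary chunk is exactly what the kernel update has to define.

import numpy as np

def state_passing_fwd_ref(new_states, dA_cs, chunk_seq_idx, init_states=None):
    """new_states: (nchunks, dim) per-chunk contributions,
    dA_cs: (nchunks,) cumulative decay per chunk,
    chunk_seq_idx: (nchunks,) sequence id at the end of each chunk,
    init_states: optional (nseq, dim) per-sequence initial states."""
    nchunks, dim = new_states.shape
    out = np.zeros((nchunks, dim), dtype=np.float32)  # out[c]: state at the end of chunk c
    states = np.zeros(dim, dtype=np.float32)
    seq_idx = chunk_seq_idx[0]
    if init_states is not None:
        states = init_states[seq_idx].astype(np.float32)   # seed the first sequence
    for c in range(nchunks):
        scale = np.exp(dA_cs[c])
        if chunk_seq_idx[c] != seq_idx:
            # Sequence boundary: the current kernel zeroes the carried state
            # (scale = 0). The proposal is to instead re-seed from the new
            # sequence's initial state before continuing the recurrence.
            seq_idx = chunk_seq_idx[c]
            if init_states is None:
                scale = 0.0
            else:
                states = init_states[seq_idx].astype(np.float32)
                scale = 1.0   # how to decay the re-seeded state is left open here
        states = scale * states + new_states[c]
        out[c] = states
    return out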

cyang49 commented Dec 4, 2024

To get stable latency measurements in Nsight Systems we rely on enabling CUDA graphs. However, in our test with long context (64k), CUDA graph usage is not seen in the profile. After digging in the code, I found that the usage is controlled by the engine argument max_seq_len_to_capture, which defaults to 8192. Overriding this config enables CUDA graph usage correctly (see the sketch below).

[screenshot]
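
For reference, a hedged sketch of overriding this at the offline LLM entry point (the checkpoint path is just the placeholder reused from the test output above; 65536 stands in for the 64k case mentioned here):

from vllm import LLM, SamplingParams

llm = LLM(
    model="/workspace/bamba-ckpt-fp16",   # placeholder path from the tests above
    max_model_len=65536,                  # 64k context
    max_seq_len_to_capture=65536,         # raise from the 8192 default so CUDA graphs cover long sequences
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))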
