Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tentative] Adding 192 head dim (step_size = 12) #454

Closed
wants to merge 3 commits into from

Conversation

Narsil
Copy link

@Narsil Narsil commented Aug 19, 2024

Not sure if this PR actually works correctly (I'm going to check it).

Deepseek models use head_dim=192, and this cannot be compiled because of this static assert.
This modification works by jumping like step_size=4 + jumping by 8 every iteration.

Let me know if this is interesting in here.

@zhyncs
Copy link
Member

zhyncs commented Aug 19, 2024

@Narsil May you write unit test for it? And you can ref https://github.com/zhyncs/dl/blob/master/flashinfer_build.sh to compile from source.

@zhyncs zhyncs requested a review from yzh119 August 19, 2024 08:38
@Narsil
Copy link
Author

Narsil commented Aug 19, 2024

Which tests would you like me to add, batch_prefill_kernels ? others ?

Wdym ref to build from source ? I am building already.

@zhyncs
Copy link
Member

zhyncs commented Aug 19, 2024

@Narsil
Copy link
Author

Narsil commented Aug 19, 2024

Are the tests ran anywhere ?

@zhyncs
Copy link
Member

zhyncs commented Aug 19, 2024

Are the tests ran anywhere ?

There is currently no CI configured, you can use pytest in the local development environment to run.

@yzh119
Copy link
Collaborator

yzh119 commented Aug 19, 2024

Hi @Narsil , thanks for your contribution!
You can add 192 to

head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",")

if compilation successes, you can run unittests such as https://github.com/flashinfer-ai/flashinfer/blob/0d618712faff20a84bbd513d02ac01e16be19306/python/tests/test_batch_prefill_kernels.py and see how does it work.

@zhyncs
Copy link
Member

zhyncs commented Sep 19, 2024

Hi @Narsil Any update?

@yzh119 yzh119 mentioned this pull request Sep 25, 2024
yzh119 added a commit that referenced this pull request Oct 7, 2024
This PR implements the JIT compilation (#170 ) of flashinfer, after this
PR, flashinfer will compile kernels just-in-time for different input
data types and shapes, and cached the kernels at the disk, instead of
pre-compile a set of kernels in the wheel.

# Motivation
The pip wheel size is exploding as we add support to more data types,
more head dimensions, more attention variants and more kernel
implementation. Pre-compile everything is not sustainable, and impedes
development speed.

This PR refactors the codebase to use torch's [JIT Compiling
Extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html#jit-compiling-extensions)
feature instead of pre-compile kernels in the wheel.

## Attention Variants
We learned from [FlexAttention](https://pytorch.org/blog/flexattention/)
and describes every attention variant as a template class, each instance
of the struct can carry some closure variable defined in local memory or
shared memory, below are two examples (logits soft cap and alibi
attention, the programming interface is tentative and will be updated as
we improve the programmability of the JIT template):

```cuda
template <typename ParamsT>
struct LogitsSoftCap {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ LogitsSoftCap(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return params.logits_soft_cap * math::log2e * float(math::tanh(logits));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return true;
  }
};

template <typename ParamsT>
struct ALIBIAttention {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;
  using IdType = typename ParamsT::IdType;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ ALIBIAttention(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::log2e;
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return logits + params.alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return true;
  }
};
```
User can customize their own `ParamsT` class and variants class to
define their own attention variants, we hope such refactor will make the
codebase more concise and extensive.

# Roadmap

After this PR, we will add support for:
1. PyPI wheels #153 
2. fp8 tensor cores attention: #502
3. different head dimensions: #142 #454 #455
4. flashattention3 #369 
5. multi-head latency attention #237 
6. Generate ParamsT and Attention variants description from python dsl

The development of this features have been blocked by the limitation of
wheel size (binary size >= 2GB will trigger some linking issues), I hope
this PR will make development easier in the future.
@zhyncs
Copy link
Member

zhyncs commented Oct 9, 2024

cc @yzh119

@yzh119
Copy link
Collaborator

yzh119 commented Oct 9, 2024

Hi @zhyncs, I'll create a PR to support any head_dim that is divisible by 16, and 192 will be supported there.

While I appreciate the effort of this PR, I think the implementation is not correct because the step_size=8 actually aligns with the granularity of CUDA's cp.async instruction. But there is no such instruction that aligns with step_size=12.

@zhyncs
Copy link
Member

zhyncs commented Oct 9, 2024

Ok that's good. I'll close this PR for now. Thanks all! @yzh119 @Narsil

@zhyncs zhyncs closed this Oct 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants