forked from flashinfer-ai/flashinfer
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support fused silu mul (flashinfer-ai#427)
### Motivation as titled I implemented a simplified version based on FasterTransformers, and I am considering whether to use optimizations like [half2](https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/activation_kernels.cu), and whether to consider using CUTLASS's [LeftSiLUAndMul](https://github.com/NVIDIA/cutlass/blob/main/examples/45_dual_gemm/thread/left_silu_and_mul.h). Do you have any suggestions? Thanks. @yzh119 ### Modification - [x] fused silu mul - [x] test - [x] benchmark - [ ] error handling - [ ] coding style (PascalCase) --------- Co-authored-by: Zihao Ye <expye@outlook.com>
- Loading branch information
Showing
8 changed files
with
218 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* | ||
* Copyright (c) 2024 by FlashInfer team. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef FLASHINFER_ACTIVATION_CUH_ | ||
#define FLASHINFER_ACTIVATION_CUH_ | ||
|
||
#include "utils.cuh" | ||
#include "vec_dtypes.cuh" | ||
|
||
namespace flashinfer { | ||
|
||
namespace activation { | ||
|
||
// https://github.com/NVIDIA/FasterTransformer/blob/d21dc02bc5f70bc7dc0d18ba5801ae263565e68e/src/fastertransformer/kernels/activation_kernels.cu#L126-L133 | ||
__device__ __forceinline__ float silu_kernel(const float& val) { | ||
// NOTE(Zihao): use __fdividef might be faster, at the cost of precision | ||
return val / (1.0f + __expf(-val)); | ||
} | ||
|
||
template <typename T, float (*Activation)(const float&)> | ||
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) { | ||
constexpr uint32_t vec_size = 16 / sizeof(T); | ||
const int64_t token_idx = blockIdx.x; | ||
const int64_t thread_idx = threadIdx.x; | ||
const int64_t stride = blockDim.x; | ||
const int64_t offset = token_idx * 2 * d; | ||
|
||
#pragma unroll 1 | ||
for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) { | ||
vec_t<float, vec_size> x_vec, y_vec, out_vec; | ||
x_vec.cast_load(input + offset + idx * vec_size); | ||
y_vec.cast_load(input + offset + d + idx * vec_size); | ||
#pragma unroll | ||
for (uint32_t i = 0; i < vec_size; ++i) { | ||
out_vec[i] = Activation(x_vec[i]) * y_vec[i]; | ||
} | ||
out_vec.cast_store(out + token_idx * d + idx * vec_size); | ||
} | ||
|
||
const int64_t remaining_offset = d - d % (stride * vec_size); | ||
// process the remaining elements | ||
#pragma unroll 1 | ||
for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) { | ||
float x = input[offset + remaining_offset + idx], | ||
y = input[offset + remaining_offset + d + idx]; | ||
out[token_idx * d + remaining_offset + idx] = Activation(x) * y; | ||
} | ||
} | ||
|
||
} // namespace activation | ||
} // namespace flashinfer | ||
|
||
#endif // FLASHINFER_ACTIVATION_CUH_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/* | ||
* Copyright (c) 2024 by FlashInfer team. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <c10/cuda/CUDAGuard.h> | ||
|
||
#include <flashinfer/activation.cuh> | ||
|
||
#include "flashinfer_ops.h" | ||
#include "pytorch_extension_utils.h" | ||
|
||
using namespace flashinfer; | ||
|
||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { | ||
int d = input.size(-1) / 2; | ||
int64_t num_tokens = input.numel() / input.size(-1); | ||
dim3 grid(num_tokens); | ||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | ||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
|
||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { | ||
uint32_t vec_size = 16 / sizeof(c_type); | ||
dim3 block(std::min(d / vec_size, 1024U)); | ||
flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::silu_kernel> | ||
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), | ||
static_cast<c_type*>(input.data_ptr()), d); | ||
|
||
return true; | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
""" | ||
Copyright (c) 2024 by FlashInfer team. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import torch | ||
from typing import Optional | ||
|
||
# mypy: disable-error-code="attr-defined" | ||
try: | ||
from . import _kernels | ||
except ImportError as e: | ||
import logging | ||
import os | ||
|
||
if os.environ.get("BUILD_DOC", "0") == "1": | ||
_kernels = None | ||
logging.warning("Kernels are not loaded in documentation build mode.") | ||
else: | ||
raise e | ||
|
||
|
||
def _check_shape(input: torch.Tensor, output: torch.Tensor): | ||
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}" | ||
assert ( | ||
input.shape[:-1] == output.shape[:-1] | ||
), f"{input.shape[:-1]} != {output.shape[:-1]}" | ||
assert ( | ||
input.shape[-1] == 2 * output.shape[-1] | ||
), f"{input.shape[-1]} != {2 * output.shape[-1]}" | ||
|
||
|
||
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: | ||
r"""Fused SiLU and Mul operation. | ||
Parameters | ||
---------- | ||
input: torch.Tensor | ||
Input tensor, shape (..., 2 * hidden_size). | ||
out: Optional[torch.Tensor] | ||
The the output tensor, if specified, the kernel will update this tensor inplace. | ||
Returns | ||
------- | ||
output: torch.Tensor | ||
Output tensor, shape (..., hidden_size). | ||
""" | ||
if input.shape[-1] * input.dtype.itemsize % 16 != 0: | ||
raise ValueError("The pointers must be multiple of 16 bytes.") | ||
if out is not None: | ||
_check_shape(input, out) | ||
else: | ||
out = torch.empty( | ||
input.shape[:-1] + (input.shape[-1] // 2,), | ||
device=input.device, | ||
dtype=input.dtype, | ||
) | ||
_kernels.silu_and_mul(out, input) | ||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
""" | ||
Copyright (c) 2024 by FlashInfer team. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import numpy | ||
import pytest | ||
import torch | ||
|
||
import flashinfer | ||
|
||
|
||
@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) | ||
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) | ||
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) | ||
def test_fused_silu_mul(dim, batch_size, seq_len): | ||
x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) | ||
y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim]) | ||
y = flashinfer.activation.silu_and_mul(x) | ||
numpy.testing.assert_allclose( | ||
y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3 | ||
) |