diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh
index 52c08e169..e11069f50 100644
--- a/include/flashinfer/sampling.cuh
+++ b/include/flashinfer/sampling.cuh
@@ -21,6 +21,7 @@
 #include <cub/block/block_reduce.cuh>
 #include <cub/block/block_scan.cuh>
 
+#include "math.cuh"
 #include "utils.cuh"
 #include "vec_dtypes.cuh"
 
@@ -207,9 +208,9 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
           (probs_vec[j] <= pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
     }
-    aggregate_leq_pivot +=
-        BlockReduce<Pair<DType>, BLOCK_THREADS>(temp_storage.block_prim.reduce_pair)
-            .Sum(probs_leq_pivot);
+    aggregate_leq_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                               temp_storage.block_prim.reduce_pair)
+                               .Sum(probs_leq_pivot);
     if (tx == 0) {
       temp_storage.data.block_aggregate.pair = aggregate_leq_pivot;
     }
 
@@ -421,6 +422,340 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b
   return cudaSuccess;
 }
 
+template <typename T, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
+struct RenormTempStorage {
+  union {
+    typename BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
+    typename BlockReduce<Pair<T>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_pair;
+  } block_prim;
+  struct {
+    T max_val;
+    union {
+      T value;
+      Pair<T> pair;
+    } block_aggregate;
+  } data;
+};
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          typename DType, typename IdType>
+__global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float p, float eps,
+                                     uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+
+  extern __shared__ __align__(alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGORITHM>))
+      uint8_t smem[];
+  auto& temp_storage =
+      reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGORITHM>&>(smem);
+  temp_storage.data.max_val = DType(0);
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType probs_greater_than_pivot[VEC_SIZE];  // pivot initialized to 0
+
+  DType threadlocal_max_val = DType(0);
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_greater_than_pivot[j] = probs_vec[j];
+    }
+    threadlocal_max_val =
+        max(threadlocal_max_val,
+            BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                .Reduce(probs_greater_than_pivot, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.data.max_val = threadlocal_max_val;
+  }
+  __syncthreads();
+  threadlocal_max_val = temp_storage.data.max_val;
+
+  float low = 0, high = threadlocal_max_val;
+  DType sum_low(1);
+  // f(x) = sum(probs[probs > x]), f(x) is non-increasing
+  // loop invariant: f(low) >= p, f(high) < p
+  while (high - low > eps) {
+    DType threadlocal_sum(0);
+    float mid = (low + high) / 2;
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_greater_than_pivot[j] = (probs_vec[j] > mid) ? probs_vec[j] : DType(0);
+      }
+      threadlocal_sum +=
+          BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+              .Sum(probs_greater_than_pivot);
+      __syncthreads();
+    }
+    if (tx == 0) {
+      temp_storage.data.block_aggregate.value = threadlocal_sum;
+    }
+    __syncthreads();
+    threadlocal_sum = temp_storage.data.block_aggregate.value;
+    if (threadlocal_sum >= p) {
+      low = mid;
+      sum_low = float(threadlocal_sum);
+    } else {
+      high = mid;
+    }
+  }
+
+  DType normalizer = math::ptx_rcp(max(sum_low, eps));
+
+  // normalize
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_vec[j] = (probs_vec[j] > low) ? probs_vec[j] * normalizer : DType(0);
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          typename DType, typename IdType>
+__global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32_t k, float eps,
+                                     uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+
+  extern __shared__ __align__(alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGORITHM>))
+      uint8_t smem[];
+  auto& temp_storage =
+      reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGORITHM>&>(smem);
+  temp_storage.data.max_val = DType(0);
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType probs_greater_than_pivot[VEC_SIZE];  // pivot initialized to 0
+
+  DType threadlocal_max_val = DType(0);
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_greater_than_pivot[j] = probs_vec[j];
+    }
+    threadlocal_max_val =
+        max(threadlocal_max_val,
+            BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                .Reduce(probs_greater_than_pivot, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.data.max_val = threadlocal_max_val;
+  }
+  __syncthreads();
+  threadlocal_max_val = temp_storage.data.max_val;
+
+  float low = 0, high = threadlocal_max_val;
+  DType sum_low(1);
+  // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
+  // loop invariant: f(low) >= k, f(high) < k
+  while (high - low > eps) {
+    Pair<DType> threadlocal_sum{DType(0), 0};
+    Pair<DType> probs_greater_than_pivot_pair[VEC_SIZE];  // pivot initialized to 0
+    float mid = (low + high) / 2;
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_greater_than_pivot_pair[j] = {
+            (probs_vec[j] > mid) ? probs_vec[j] : DType(0),
+            (probs_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+      }
+      threadlocal_sum += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                             temp_storage.block_prim.reduce_pair)
+                             .Sum(probs_greater_than_pivot_pair);
+      __syncthreads();
+    }
+    if (tx == 0) {
+      temp_storage.data.block_aggregate.pair = threadlocal_sum;
+    }
+    __syncthreads();
+    threadlocal_sum = temp_storage.data.block_aggregate.pair;
+    if (threadlocal_sum.count >= k) {
+      low = mid;
+      sum_low = float(threadlocal_sum.value);
+    } else {
+      high = mid;
+    }
+  }
+
+  float normalizer = math::ptx_rcp(max(sum_low, eps));
+
+  // normalize
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_vec[j] = (probs_vec[j] > low) ? probs_vec[j] * normalizer : DType(0);
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
+
+template <typename DType, typename IdType>
+cudaError_t TopPRenormProb(DType* probs, IdType* renormed_prob, float p, float eps,
+                           uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) {
+  const uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  const uint32_t smem_size = sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &renormed_prob, &p, &eps, &d};
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = TopPRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
+    FLASHINFER_CUDA_CALL(
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
+template <typename DType, typename IdType>
+cudaError_t TopKRenormProb(DType* probs, IdType* renormed_prob, uint32_t k, float eps,
+                           uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) {
+  const uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  const uint32_t smem_size = sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &renormed_prob, &k, &eps, &d};
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
+    FLASHINFER_CUDA_CALL(
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, typename DType,
+          typename IdType>
+__global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
+                                         DType* uniform_samples, DType* target_probs,
+                                         IdType* output_token_ids, uint32_t num_speculative_tokens,
+                                         uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+
+  extern __shared__ __align__(alignof(
+      SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>)) uint8_t smem[];
+  auto& temp_storage = reinterpret_cast<
+      SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
+
+  bool rejected = false;
+  uint32_t pos = 0;
+  for (pos = 0; pos < num_speculative_tokens; ++pos) {
+    IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + pos];
+    float q = target_probs[(row_idx * (num_speculative_tokens + 1) + pos) * d + draft_id],
+          p = draft_probs[(row_idx * num_speculative_tokens + pos) * d + draft_id];
+    DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + pos];
+    if (u * p < q) {
+      // accept the draft model's output
+      output_token_ids[row_idx * (num_speculative_tokens + 1) + pos] = draft_id;
+    } else {
+      break;
+    }
+  }
+
+  // sample from relu(target_probs - draft_probs)
+  DType sum_relu_q_minus_p(0);
+  vec_t<DType, VEC_SIZE> q_vec, p_vec;
+  DType relu_q_minus_p[VEC_SIZE];
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    q_vec.fill(DType(0));
+    p_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      q_vec.load(target_probs + (row_idx * (num_speculative_tokens + 1) + pos) * d +
+                 i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      if (pos != num_speculative_tokens) {
+        // there is no draft_probs for the bonus token
+        p_vec.load(draft_probs + (row_idx * num_speculative_tokens + pos) * d +
+                   i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      }
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0));
+    }
+    sum_relu_q_minus_p +=
+        BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+            .Sum(relu_q_minus_p);
+  }
+  if (tx == 0) {
+    temp_storage.data.block_aggregate.value = sum_relu_q_minus_p;
+  }
+  // init the first rejected token to (d - 1)
+  temp_storage.data.sampled_id = d - 1;
+  __syncthreads();
+  sum_relu_q_minus_p = temp_storage.data.block_aggregate.value;
+  DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + pos] * sum_relu_q_minus_p;
+
+  DType aggregate_relu_q_minus_p(0);
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    q_vec.fill(DType(0));
+    p_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      q_vec.load(target_probs + (row_idx * (num_speculative_tokens + 1) + pos) * d +
+                 i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      if (pos != num_speculative_tokens) {
+        // there is no draft_probs for the bonus token
+        p_vec.load(draft_probs + (row_idx * num_speculative_tokens + pos) * d +
+                   i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      }
+    }
+
+    vec_t<DType, VEC_SIZE> relu_q_minus_p_vec;
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
+    }
+
+    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DType>(
+        i, d, DType(0), u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
+    if (aggregate_relu_q_minus_p > u) {
+      break;
+    }
+  }
+  __syncthreads();
+  // set the first rejected token
+  output_token_ids[row_idx * (num_speculative_tokens + 1) + pos] = temp_storage.data.sampled_id;
+  // move to the next token
+  pos++;
+
+  // pad remaining tokens with -1
+  for (; pos < num_speculative_tokens + 1; ++pos) {
+    output_token_ids[row_idx * (num_speculative_tokens + 1) + pos] = -1;
+  }
+}
+
 template <typename T, typename IdType>
 cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
                                          bool* success, IdType* row_indices, T* top_p_arr,
@@ -446,6 +781,36 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o
   return cudaSuccess;
 }
 
+template <typename DType, typename IdType>
+cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
+                                     DType* uniform_samples, DType* target_probs,
+                                     IdType* output_token_ids, uint32_t batch_size,
+                                     uint32_t num_speculative_tokens, uint32_t d,
+                                     cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&draft_probs,
+                  &draft_token_ids,
+                  &uniform_samples,
+                  &target_probs,
+                  &output_token_ids,
+                  &num_speculative_tokens,
+                  &d};
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel =
+        ChainSpeculativeSampling<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
+    FLASHINFER_CUDA_CALL(
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
 }  // namespace sampling
 
 }  // namespace flashinfer
diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu
index 54d5f55df..47a0df176 100644
--- a/python/csrc/batch_prefill.cu
+++ b/python/csrc/batch_prefill.cu
@@ -38,8 +38,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
 
   cudaError_t status =
       handler_->BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
-                                 workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
-                                 batch_size, num_qo_heads, num_kv_heads, head_dim);
+                             workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
+                             batch_size, num_qo_heads, num_kv_heads, head_dim);
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
               cudaGetErrorString(status));
 }
@@ -166,8 +166,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
 
   cudaError_t status =
      handler_->BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
-                                 workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
-                                 batch_size, num_qo_heads, num_kv_heads, head_dim);
+                             workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
+                             batch_size, num_qo_heads, num_kv_heads, head_dim);
   TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
               cudaGetErrorString(status));
 }
diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu
index 67daa87dd..05118bd1b 100644
--- a/python/csrc/flashinfer_ops.cu
+++ b/python/csrc/flashinfer_ops.cu
@@ -34,6 +34,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Top-k sampling from probabilities");
   m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
         "Top-p sampling from probabilities");
+  m.def("top_k_renorm_prob", &top_k_renorm_prob, "Renormalize probabilities by top-k mask");
+  m.def("top_p_renorm_prob", &top_p_renorm_prob, "Renormalize probabilities by top-p mask");
+  m.def("chain_speculative_sampling", &chain_speculative_sampling,
+        "Speculative sampling from sequence of probabilities");
   m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
 
   py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m, "BatchDecodeWithPagedKVCachePyTorchWrapper")
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index acfa39bb9..d826d71fd 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -62,6 +62,13 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
                                                      torch::Tensor uniform_samples,
                                                      unsigned int top_k);
 
+torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);
+
+torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);
+
+torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
+                                         torch::Tensor uniform_samples, torch::Tensor target_probs);
+
 torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);
 
 class BatchDecodeWithPagedKVCachePyTorchWrapper {
@@ -83,8 +90,7 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
   BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout,
                                             unsigned int max_workspace_size_in_bytes)
       : kv_layout_(flashinfer::QKVLayout(layout)),
-        handler_(
-            std::make_shared<flashinfer::BatchDecodeHandler>(max_workspace_size_in_bytes)) {}
+        handler_(std::make_shared<flashinfer::BatchDecodeHandler>(max_workspace_size_in_bytes)) {}
 
  protected:
   std::shared_ptr<flashinfer::BatchDecodeHandler> handler_;
diff --git a/python/csrc/sampling.cu b/python/csrc/sampling.cu
index f980fdcd3..db0dcf668 100644
--- a/python/csrc/sampling.cu
+++ b/python/csrc/sampling.cu
@@ -96,3 +96,85 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
 
   return {samples, success};
 }
+
+torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps) {
+  CHECK_INPUT(probs);
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  probs = probs.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  auto renorm_probs =
+      torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(probs.device()));
+
+  cudaError_t status = sampling::TopPRenormProb<float, float>(
+      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()), top_p,
+      eps, batch_size, vocab_size, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
+  return renorm_probs;
+}
+
+torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps) {
+  CHECK_INPUT(probs);
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  probs = probs.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  auto renorm_probs =
+      torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(probs.device()));
+
+  cudaError_t status = sampling::TopKRenormProb<float, float>(
+      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()), top_k,
+      eps, batch_size, vocab_size, torch_current_stream);
+
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
+  return renorm_probs;
+}
+
+torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
+                                         torch::Tensor uniform_samples,
+                                         torch::Tensor target_probs) {
+  CHECK_INPUT(draft_probs);
+  CHECK_INPUT(draft_token_ids);
+  CHECK_INPUT(uniform_samples);
+  CHECK_INPUT(target_probs);
+  CHECK_DIM(3, draft_probs);      // draft_probs: (batch_size, num_speculate_tokens, vocab_size)
+  CHECK_DIM(2, draft_token_ids);  // draft_token_ids: (batch_size, num_speculate_tokens)
+  CHECK_DIM(2, uniform_samples);  // uniform_samples: (batch_size, num_speculate_tokens + 1)
+  CHECK_DIM(3, target_probs);  // target_probs: (batch_size, num_speculate_tokens + 1, vocab_size)
+  unsigned int batch_size = draft_probs.size(0);
+  unsigned int num_speculate_tokens = draft_probs.size(1);
+  unsigned int vocab_size = draft_probs.size(2);
+  CHECK_EQ(batch_size, draft_token_ids.size(0));
+  CHECK_EQ(batch_size, uniform_samples.size(0));
+  CHECK_EQ(batch_size, target_probs.size(0));
+  CHECK_EQ(num_speculate_tokens + 1, uniform_samples.size(1));
+  CHECK_EQ(num_speculate_tokens + 1, target_probs.size(1));
+  CHECK_EQ(vocab_size, target_probs.size(2));
+
+  draft_probs = draft_probs.to(torch::kFloat32);
+  draft_token_ids = draft_token_ids.to(torch::kInt32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+  target_probs = target_probs.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  auto output_token_ids =
+      torch::empty({batch_size, num_speculate_tokens + 1},
+                   torch::dtype(torch::kInt32).device(draft_token_ids.device()));
+
+  cudaError_t status = sampling::ChainSpeculativeSampling<float, int>(
+      static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
+      static_cast<int*>(output_token_ids.data_ptr()), batch_size, num_speculate_tokens, vocab_size,
+      torch_current_stream);
+
+  TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
+                                         std::string(cudaGetErrorString(status)));
+
+  return output_token_ids;
+}
diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py
index f7353d48a..64ce3ce27 100644
--- a/python/flashinfer/__init__.py
+++ b/python/flashinfer/__init__.py
@@ -40,6 +40,9 @@
     sampling_from_probs,
     top_p_sampling_from_probs,
     top_k_sampling_from_probs,
+    top_p_renorm_prob,
+    top_k_renorm_prob,
+    chain_speculative_sampling,
 )
 from .norm import rmsnorm
 
diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py
index cedbefaac..1ba04f1a9 100644
--- a/python/flashinfer/decode.py
+++ b/python/flashinfer/decode.py
@@ -632,9 +632,9 @@ def forward_return_lse(
 
 
 class CUDAGraphBatchDecodeWithPagedKVCacheWrapper:
-    r"""CUDAGraph-compatible Wrapper class for decode attention with paged kv-cache (first 
-    proposed in `vLLM <https://arxiv.org/abs/2309.06180>`_) for batch of requests. 
-    
+    r"""CUDAGraph-compatible Wrapper class for decode attention with paged kv-cache (first
+    proposed in `vLLM <https://arxiv.org/abs/2309.06180>`_) for batch of requests.
+
     Note that this wrapper may not be as efficient as :class:`BatchDecodeWithPagedKVCacheWrapper`
     because we won't dispatch to different kernels for different batch sizes/sequence lengths/etc
     to accomodate the CUDAGraph requirement.
@@ -673,7 +673,7 @@ def __init__(
             during the lifecycle of this wrapper.
         indices_buffer : torch.Tensor
             The user reserved buffer on GPU to store the page indices of the paged kv cache,
-            should be large enough to store the maximum number of page indices 
+            should be large enough to store the maximum number of page indices
             (``max_num_pages``) during the lifecycle of this wrapper.
         last_page_len_buffer : torch.Tensor
             The user reserved buffer on GPU to store the number of entries in the last page,
diff --git a/python/flashinfer/sampling.py b/python/flashinfer/sampling.py
index c1b8eda41..6d9d16e40 100644
--- a/python/flashinfer/sampling.py
+++ b/python/flashinfer/sampling.py
@@ -30,7 +30,7 @@
 
 
 def sampling_from_probs(probs: torch.Tensor, uniform_samples: torch.Tensor):
-    r"""Category sampling from probabilities.
+    r"""Fused GPU kernel for category sampling from probabilities.
 
     Parameters
     ----------
@@ -75,8 +75,11 @@ def sampling_from_probs(probs: torch.Tensor, uniform_samples: torch.Tensor):
 def top_p_sampling_from_probs(
     probs: torch.Tensor, uniform_samples: torch.Tensor, top_p: float
 ):
-    r"""Top-p sampling (nucleus sampling) from probabilities, this operator implements
-    GPU-based rejection sampling without explicit sorting.
+    r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
+    this operator implements GPU-based rejection sampling without explicit sorting.
+
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than the naive implementation that launches a series of kernels.
 
     Parameters
     ----------
@@ -134,8 +137,11 @@ def top_p_sampling_from_probs(
 def top_k_sampling_from_probs(
     probs: torch.Tensor, uniform_samples: torch.Tensor, top_k: int
 ):
-    r"""Top-k sampling from probabilities, this operator implements GPU-based rejection sampling
-    without explicit sorting.
+ r"""Fused GPU kernel for top-k sampling from probabilities, + this operator implements GPU-based rejection sampling without explicit sorting. + + The multiple rounds of rejection sampling are implemented in a single CUDA kernel, + which is more efficient than the naive implementation that launches a series of kernels. Parameters ---------- @@ -188,3 +194,96 @@ def top_k_sampling_from_probs( implementation usually use much fewer rounds for rejection sampling because of early stopping. """ return _kernels.top_k_sampling_from_probs(probs, uniform_samples, top_k) + + +def top_p_renorm_prob(probs: torch.Tensor, top_p: float, eps: float = 1e-5): + r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding. + + Parameters + ---------- + probs: torch.Tensor + Probabilities, shape ``(batch_size, num_classes)``. + top_p: float + The threshold for re-normalizing probabilities, should be in ``(0, 1)``. + We mask out the probabilities less than `threshold` where the cumulative sum + of ``probs[probs >= threshold]`` is `top_p`, and renormalize the probabilities. + eps: float + The epsilon value for numerical stability. + + Returns + ------- + renorm_probs: torch.Tensor + Renormalized probabilities, shape ``(batch_size, num_classes)``. + + This combination of ``top_p_renorm_prob`` and ``sampling_from_probs`` should be equivalent to + ``top_p_sampling_from_probs``. + """ + return _kernels.top_p_renorm_prob(probs, top_p, eps) + + +def top_k_renorm_prob(probs: torch.Tensor, top_k: int, eps: float = 1e-5): + r"""Fused GPU kernel for renormalizing probabilities by top-k thresholding. + + Parameters + ---------- + probs: torch.Tensor + Probabilities, shape ``(batch_size, num_classes)``. + top_k: int + The threshold for re-normalizing probabilities, should be in ``(0, num_classes)``. + We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities. + eps: float + The epsilon value for numerical stability. + + Returns + ------- + renorm_probs: torch.Tensor + Renormalized probabilities, shape ``(batch_size, num_classes)``. + + Note + ---- + This combination of ``top_k_renorm_prob`` and ``sampling_from_probs`` should be equivalent to + ``top_k_sampling_from_probs``. + """ + return _kernels.top_k_renorm_prob(probs, top_k, eps) + + +def chain_speculative_sampling( + draft_probs, + draft_token_ids, + uniform_samples, + target_probs, +): + r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in + paper `Accelerating Large Language Model Decoding with Speculative Sampling `_), + where the draft model generates a sequence(chain) of tokens for each request. + + Parameters + ---------- + draft_probs: torch.Tensor + The probability over vocabulary generated by draft model. + Shape: ``(batch_size, num_speculate_tokens, vocab_size)`` + draft_token_ids: torch.Tensor + The draft model's generated token indices. + Shape: ``(batch_size, num_specutate_tokens)`` + uniform_samples: torch.Tensor + The uniform samples used as needle for sampling, shape ``(batch_size, num_speculate_tokens + 1)``. + Expected to be uniformly distributed in ``[0, 1)``. + target_probs: torch.Tensor + The probability over vocabulary generated by target model. + Compared to input :attr:`draft_probs`, the target model's probability has an additional + slot at the end because the target model will generate one more token than the draft model. 
+ Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)`` + + Returns + ------- + output_token_ids: torch.Tensor + The output token indices verified by the target model, rejected samples are + padded with ``-1``. + Compared to input :attr:`draft_token_ids`, the output tensor has an additional + token index at the end for the final token, if all previous tokens are accepted, + another "bonus" token will be sampled from the target model's probability. + Shape: (batch_size, num_specutate_tokens + 1) + """ + return _kernels.chain_speculative_sampling( + draft_probs, draft_token_ids, uniform_samples, target_probs + ) diff --git a/python/setup.py b/python/setup.py index 990ab57b7..3eec45df7 100644 --- a/python/setup.py +++ b/python/setup.py @@ -64,14 +64,14 @@ def get_instantiation_cu() -> List[str]: (root / prefix).mkdir(parents=True, exist_ok=True) group_sizes = os.environ.get("FLASHINFER_GROUP_SIZES", "1,4,6,8").split(",") - page_sizes = os.environ.get("FLASHINFER_PAGE_SIZES", "1,16,32").split(",") - head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",") + page_sizes = os.environ.get("FLASHINFER_PAGE_SIZES", "1").split(",") + head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "128,256").split(",") kv_layouts = os.environ.get("FLASHINFER_KV_LAYOUTS", "0,1").split(",") pos_encoding_modes = os.environ.get("FLASHINFER_POS_ENCODING_MODES", "0,1,2").split( "," ) allow_fp16_qk_reduction_options = os.environ.get( - "FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0,1" + "FLASHINFER_ALLOW_FP16_QK_REDUCTION_OPTIONS", "0" ).split(",") causal_options = os.environ.get("FLASHINFER_CAUSAL_OPTIONS", "0,1").split(",") # dispatch.inc diff --git a/python/tests/test_sampling.py b/python/tests/test_sampling.py index 025f4588c..961f26819 100644 --- a/python/tests/test_sampling.py +++ b/python/tests/test_sampling.py @@ -95,7 +95,107 @@ def test_top_k_sampling(batch_size, vocab_size, k): ] +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) +def test_top_p_renorm_prob(batch_size, vocab_size, p): + eps = 1e-6 + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask.scatter_add_(1, indices, (cdf >= (1 - p)).int()) + renorm_prob_ground_truth = normalized_prob + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = flashinfer.sampling.top_p_renorm_prob(normalized_prob, p, eps=eps) + numpy.testing.assert_allclose( + renorm_prob_ground_truth.cpu().numpy(), + renorm_prob.cpu().numpy(), + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("k", [10, 100, 500]) +def test_top_k_renorm_prob(batch_size, vocab_size, k): + if k > vocab_size: + pytest.skip("k should be less than vocab_size") + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask = (normalized_prob >= 
+    renorm_prob_ground_truth = normalized_prob
+    renorm_prob_ground_truth[mask == 0] = 0
+    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
+        dim=-1, keepdim=True
+    )
+
+    renorm_prob = flashinfer.sampling.top_k_renorm_prob(normalized_prob, k)
+    numpy.testing.assert_allclose(
+        renorm_prob_ground_truth.cpu().numpy(),
+        renorm_prob.cpu().numpy(),
+        rtol=1e-3,
+        atol=1e-3,
+    )
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("num_speculate_tokens", [1, 3, 5, 7])
+def test_chain_speculative_sampling(
+    batch_size,
+    vocab_size,
+    num_speculate_tokens,
+):
+    pre_norm_draft_prob = torch.rand(batch_size, num_speculate_tokens, vocab_size).to(0)
+    normalized_draft_prob = pre_norm_draft_prob / pre_norm_draft_prob.sum(
+        dim=-1, keepdim=True
+    )
+    draft_token_ids = torch.randint(vocab_size, (batch_size, num_speculate_tokens)).to(
+        0
+    )
+    uniform_samples = torch.empty(batch_size, num_speculate_tokens + 1).to(0)
+    pre_norm_target_prob = torch.rand(
+        batch_size, num_speculate_tokens + 1, vocab_size
+    ).to(0)
+    normalized_target_prob = pre_norm_target_prob / pre_norm_target_prob.sum(
+        dim=-1, keepdim=True
+    )
+
+    # NOTE(Zihao): this is a very simple test that only checks whether output is valid or not.
+    for trials in range(10):
+        uniform_samples.uniform_()
+        output_token_ids = flashinfer.sampling.chain_speculative_sampling(
+            normalized_draft_prob,
+            draft_token_ids,
+            uniform_samples,
+            normalized_target_prob,
+        )
+        assert torch.all(output_token_ids[output_token_ids >= 0] < vocab_size)
+        assert output_token_ids.shape == (batch_size, num_speculate_tokens + 1)
+        matches = output_token_ids[..., :-1] != draft_token_ids
+        for row in range(batch_size):
+            mismatch_idx = torch.nonzero(matches[row], as_tuple=True)[0]
+            if len(mismatch_idx) > 0:
+                # mismatch_idx should be contiguous
+                assert torch.all(mismatch_idx[1:] == mismatch_idx[:-1] + 1)
+                # from the second mismatched token on, the output tokens should be -1
+                assert torch.all(output_token_ids[row, mismatch_idx[0] + 1 :] == -1)
+
+
 if __name__ == "__main__":
     test_sampling(1, 111)
     test_top_p_sampling(3, 111, 0.9)
     test_top_k_sampling(3, 111, 10)
+    test_top_p_renorm_prob(3, 111, 0.9)
+    test_top_k_renorm_prob(3, 111, 10)
+    test_chain_speculative_sampling(3, 111, 3)
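
A minimal usage sketch of the Python APIs introduced by this patch (top_p_renorm_prob, top_k_renorm_prob, chain_speculative_sampling), assuming a CUDA device 0 and a flashinfer build that includes these kernels; the batch/vocab sizes and the way draft/target distributions are produced are illustrative only:

import torch
import flashinfer

torch.manual_seed(42)
batch_size, vocab_size, num_speculate_tokens = 4, 32000, 3

# Renormalize a probability distribution by a top-p / top-k mask, then sample from it.
# top_p_renorm_prob followed by sampling_from_probs should match top_p_sampling_from_probs
# in distribution (and likewise for the top-k pair).
probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda:0"), dim=-1)
top_p_probs = flashinfer.sampling.top_p_renorm_prob(probs, 0.9)
top_k_probs = flashinfer.sampling.top_k_renorm_prob(probs, 50)
samples = flashinfer.sampling.sampling_from_probs(
    top_p_probs, torch.rand(batch_size, device="cuda:0")
)

# Chain speculative sampling: verify a chain of draft tokens against the target distribution.
draft_probs = torch.softmax(
    torch.randn(batch_size, num_speculate_tokens, vocab_size, device="cuda:0"), dim=-1
)
draft_token_ids = (
    torch.multinomial(draft_probs.view(-1, vocab_size), 1)
    .view(batch_size, num_speculate_tokens)
    .int()
)
target_probs = torch.softmax(
    torch.randn(batch_size, num_speculate_tokens + 1, vocab_size, device="cuda:0"), dim=-1
)
uniform_samples = torch.rand(batch_size, num_speculate_tokens + 1, device="cuda:0")
output_token_ids = flashinfer.sampling.chain_speculative_sampling(
    draft_probs, draft_token_ids, uniform_samples, target_probs
)
# output_token_ids has shape (batch_size, num_speculate_tokens + 1); positions after the
# first rejected draft token are padded with -1.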