From 898d8ea8a21f5850288bc4a860399678131a2d30 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Wed, 7 Aug 2024 22:27:59 -0700
Subject: [PATCH] bugfix: Improve numerical stability of sampling kernels
 (#429)

1. use `sum_of_probs_gt_pivot` rather than `sum_of_probs_leq_pivot`
2. make sure pivot will not decrease during iterations.
---
 include/flashinfer/sampling.cuh | 88 ++++++++++++++-------------------
 1 file changed, 37 insertions(+), 51 deletions(-)

diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh
index ffd5c6492..c7a8104f9 100644
--- a/include/flashinfer/sampling.cuh
+++ b/include/flashinfer/sampling.cuh
@@ -291,13 +291,13 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
 
   vec_t<DType, VEC_SIZE> probs_vec;
   DType aggregate;
-  DType q = DType(0);
+  DType q = DType(1);
   DType pivot = DType(0);
   IdType sampled_id;
   for (uint32_t round = 0; round < max_top_k_rounds; ++round) {
     temp_storage.data.sampled_id = d - 1;
     __syncthreads();
-    DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q);
+    DType u = uniform_samples[round * batch_size + bx] * q;
     aggregate = DType(0);
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(DType(0));
@@ -314,42 +314,38 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
     }
     __syncthreads();
     sampled_id = temp_storage.data.sampled_id;
-    pivot = probs[bx * d + sampled_id];
+    pivot = max(pivot, probs[bx * d + sampled_id]);
 
-    Pair<DType> aggregate_leq_pivot{DType(0), 0};
+    Pair<DType> aggregate_gt_pivot{DType(0), 0};
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(DType(0));
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
         probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
       }
 
-      Pair<DType> probs_leq_pivot[VEC_SIZE];
+      Pair<DType> probs_gt_pivot[VEC_SIZE];
 #pragma unroll
       for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-        probs_leq_pivot[j] = {
-            (probs_vec[j] <= pivot) ? probs_vec[j] : DType(0),
-            (probs_vec[j] <= pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+        probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
+                             (probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
       }
 
-      aggregate_leq_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
-                                 temp_storage.block_prim.reduce_pair)
-                                 .Sum<VEC_SIZE>(probs_leq_pivot);
+      aggregate_gt_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                temp_storage.block_prim.reduce_pair)
+                                .Sum<VEC_SIZE>(probs_gt_pivot);
       if (tx == 0) {
-        temp_storage.data.block_aggregate.pair = aggregate_leq_pivot;
+        temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
       }
       __syncthreads();
-      if (temp_storage.data.block_aggregate.pair.count + k > d) {
-        break;
-      }
     }
     q = temp_storage.data.block_aggregate.pair.value;
-    if (temp_storage.data.block_aggregate.pair.count + k > d) {
+    if (temp_storage.data.block_aggregate.pair.count < k) {
       break;
     }
   }
   __syncthreads();
   if (tx == 0) {
-    if (temp_storage.data.block_aggregate.pair.count + k <= d) {
+    if (temp_storage.data.block_aggregate.pair.count >= k) {
       // failed to sample within MAX_TOP_P_ROUNDS
       if (success != nullptr) {
         success[bx] = false;
@@ -363,8 +359,6 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
   }
 }
 
-constexpr float eps = 1e-5;
-
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, typename DType,
           typename IdType>
@@ -387,13 +381,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
 
   vec_t<DType, VEC_SIZE> probs_vec;
   DType aggregate;
-  DType q = DType(0);
+  DType q = DType(1);
   DType pivot = DType(0);
   IdType sampled_id;
   for (uint32_t round = 0; round < max_top_p_rounds; ++round) {
     temp_storage.data.sampled_id = d - 1;
     __syncthreads();
-    DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q);
+    DType u = uniform_samples[round * batch_size + bx] * q;
     aggregate = DType(0);
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(DType(0));
@@ -410,39 +404,36 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
     }
     __syncthreads();
     sampled_id = temp_storage.data.sampled_id;
-    pivot = probs[row_idx * d + sampled_id];
+    pivot = max(pivot, probs[row_idx * d + sampled_id]);
 
-    DType aggregate_leq_pivot = DType(0);
+    DType aggregate_gt_pivot = DType(0);
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(DType(0));
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
         probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
       }
 
-      DType probs_leq_pivot[VEC_SIZE];
+      DType probs_gt_pivot[VEC_SIZE];
 #pragma unroll
       for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-        probs_leq_pivot[j] = (probs_vec[j] <= pivot) ? probs_vec[j] : DType(0);
+        probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
       }
 
-      aggregate_leq_pivot += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                                 .Sum<VEC_SIZE>(probs_leq_pivot);
+      aggregate_gt_pivot += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                                .Sum<VEC_SIZE>(probs_gt_pivot);
       if (tx == 0) {
-        temp_storage.data.block_aggregate.value = aggregate_leq_pivot;
+        temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
       }
       __syncthreads();
-      if (float(temp_storage.data.block_aggregate.value) + top_p > 1 + eps) {
-        break;
-      }
     }
     q = temp_storage.data.block_aggregate.value;
-    if (float(q) + top_p > 1 + eps) {
+    if (float(q) < top_p) {
       break;
     }
   }
   __syncthreads();
   if (tx == 0) {
-    if (float(q) + top_p <= 1 + eps) {
+    if (float(q) >= top_p) {
       // failed to sample within MAX_TOP_P_ROUNDS
       if (success != nullptr) {
         success[bx] = false;
@@ -475,13 +466,13 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp
 
   vec_t<DType, VEC_SIZE> probs_vec;
   DType aggregate;
-  DType q = DType(0);
+  DType q = DType(1);
   DType pivot = DType(0);
   IdType sampled_id;
   for (uint32_t round = 0; round < max_rounds; ++round) {
     temp_storage.data.sampled_id = d - 1;
     __syncthreads();
-    DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q);
+    DType u = uniform_samples[round * batch_size + bx] * q;
     aggregate = DType(0);
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(DType(0));
@@ -498,43 +489,38 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp
     }
     __syncthreads();
     sampled_id = temp_storage.data.sampled_id;
-    pivot = probs[bx * d + sampled_id];
+    pivot = max(pivot, probs[bx * d + sampled_id]);
 
-    Pair<DType> aggregate_leq_pivot{DType(0), 0};
+    Pair<DType> aggregate_gt_pivot{DType(0), 0};
     for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
       probs_vec.fill(DType(0));
       if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
         probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
       }
 
-      Pair<DType> probs_leq_pivot[VEC_SIZE];
+      Pair<DType> probs_gt_pivot[VEC_SIZE];
 #pragma unroll
       for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-        probs_leq_pivot[j] = {
-            (probs_vec[j] <= pivot) ? probs_vec[j] : DType(0),
-            (probs_vec[j] <= pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+        probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
+                             (probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
       }
 
-      aggregate_leq_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
-                                 temp_storage.block_prim.reduce_pair)
-                                 .Sum<VEC_SIZE>(probs_leq_pivot);
+      aggregate_gt_pivot += BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                temp_storage.block_prim.reduce_pair)
+                                .Sum<VEC_SIZE>(probs_gt_pivot);
       if (tx == 0) {
-        temp_storage.data.block_aggregate.pair = aggregate_leq_pivot;
+        temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
       }
       __syncthreads();
-      if (temp_storage.data.block_aggregate.pair.count + k > d &&
-          float(temp_storage.data.block_aggregate.pair.value) + p > 1 + eps) {
-        break;
-      }
     }
     q = temp_storage.data.block_aggregate.pair.value;
-    if (temp_storage.data.block_aggregate.pair.count + k > d && float(q) + p > 1 + eps) {
+    if (temp_storage.data.block_aggregate.pair.count < k && float(q) < p) {
       break;
     }
   }
   __syncthreads();
   if (tx == 0) {
-    if (temp_storage.data.block_aggregate.pair.count + k <= d || float(q) + p <= 1 + eps) {
+    if (temp_storage.data.block_aggregate.pair.count >= k || float(q) >= p) {
       // failed to sample within MAX_TOP_P_ROUNDS
       if (success != nullptr) {
         success[bx] = false;
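
---

Editor's note on the sampling scheme (a sketch, not part of the patch): all three kernels run the same rejection-sampling loop. `q` tracks the probability mass of the current candidate set (entries strictly greater than `pivot`), so after this patch it starts at `DType(1)` for the full distribution. Each round draws `u = uniform * q`, inverse-CDF samples among the surviving entries, raises the pivot to the sampled probability, then re-measures `q` and the surviving count directly as a sum over `probs > pivot`; the `Pair<DType>` reduction yields mass and count in one pass. Top-k accepts once fewer than `k` entries remain above the pivot, top-p once the remaining mass drops below `top_p`, and `max(pivot, ...)` guarantees the candidate set never grows back, even when a round falls through to the `d - 1` fallback. The sketch below is a hypothetical single-threaded reference of the top-k variant under these assumptions; names such as `TopKSampleRef` and `max_rounds` are illustrative, not flashinfer API:

```cpp
#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

// Returns a sampled index from the top-k of `probs`, or -1 on failure
// (mirrors the kernel writing success[bx] = false after max_top_k_rounds).
int TopKSampleRef(const std::vector<float>& probs, uint32_t k, uint32_t max_rounds,
                  std::mt19937& gen) {
  std::uniform_real_distribution<float> dist(0.f, 1.f);
  float q = 1.f;      // mass of the current candidate set (probs > pivot)
  float pivot = 0.f;  // monotonically non-decreasing across rounds
  for (uint32_t round = 0; round < max_rounds; ++round) {
    // Inverse-CDF sample restricted to entries > pivot, with u in [0, q).
    float u = dist(gen) * q;
    float aggregate = 0.f;
    int sampled_id = static_cast<int>(probs.size()) - 1;  // d - 1 fallback
    for (size_t i = 0; i < probs.size(); ++i) {
      if (probs[i] > pivot) {
        aggregate += probs[i];
        if (aggregate > u) { sampled_id = static_cast<int>(i); break; }
      }
    }
    // Raise the pivot to the sampled probability; never let it decrease.
    pivot = std::max(pivot, probs[sampled_id]);
    // Measure the surviving mass and count directly (sum over probs > pivot)
    // rather than deriving them from the complement.
    q = 0.f;
    uint32_t count_gt = 0;
    for (float prob : probs) {
      if (prob > pivot) { q += prob; ++count_gt; }
    }
    // Fewer than k entries strictly above the pivot => sampled_id is top-k.
    if (count_gt < k) return sampled_id;
  }
  return -1;
}
```

The top-p variant is identical except that the acceptance test compares the surviving mass against the threshold, matching the kernel's `if (float(q) < top_p) break;`.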
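Why summing `probs > pivot` is more stable (an editor's illustration, not from the patch): the old code derived the surviving mass as the complement of `sum_of_probs_leq_pivot`, which both assumes each row sums to exactly 1 and subtracts two nearly equal numbers once most of the mass sits below the pivot. The `eps` slack papered over the first problem but not the cancellation in `u * (DType(1) - q)`; measuring the kept mass directly avoids both, which is why the `constexpr float eps` constant can be deleted. A minimal float demonstration of the cancellation:

```cpp
#include <cstdio>

int main() {
  // Suppose the true mass above the pivot is tiny, say 3e-8, and the mass at
  // or below the pivot was accumulated in float as q_leq.
  float q_leq = 0.99999997f;        // rounds to 1.0f in float
  float derived_gt = 1.0f - q_leq;  // cancellation: yields exactly 0.0f
  float summed_gt = 3e-8f;          // direct summation keeps a usable value
  // Scaling the next round's uniform sample by derived_gt collapses it to
  // zero, while the directly summed mass stays meaningful.
  std::printf("derived = %g, summed = %g\n", derived_gt, summed_gt);
  return 0;
}
```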