From e1c3e6eba50625b0a98545191364ebaa7e7c5931 Mon Sep 17 00:00:00 2001
From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com>
Date: Thu, 30 May 2024 09:43:20 -0400
Subject: [PATCH] Disable cublaslt if using f16 kernels (#359)

---
 mistralrs-core/src/layers.rs | 132 ++++++++++++++++++++---------------
 1 file changed, 74 insertions(+), 58 deletions(-)

diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index de45a35ebf..a81dc2f028 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -275,6 +275,29 @@ impl MatMul {
     }
 }
 
+/// Computes softmax(QK^T*sqrt(d_k))V
+fn naive_sdpa(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    head_dim: usize,
+    mask: Option<&Tensor>,
+) -> Result<Tensor> {
+    let att = MatMul.matmul_affine_div(
+        &q.contiguous()?,
+        &k.t()?.contiguous()?,
+        (head_dim as f64).sqrt(),
+    )?;
+
+    let att = match mask {
+        Some(m) => att.broadcast_add(m)?,
+        None => att,
+    };
+    let att = candle_nn::ops::softmax_last_dim(&att)?;
+    // Convert to contiguous as matmul doesn't support strided vs for now.
+    MatMul.matmul(&att, &v.contiguous()?)
+}
+
 pub struct ScaledDotProductAttention;
 
 impl ScaledDotProductAttention {
@@ -307,66 +330,59 @@ impl ScaledDotProductAttention {
         }
 
         if let (Device::Cuda(_), Some(cublaslt)) = (q.device(), *CUBLASLT_HANDLE.lock().unwrap()) {
-            #[cfg(feature = "cuda")]
-            {
-                // cuBLASLt batch matmul implementation requires inputs to be dims3
-                let k = k.flatten(0, 1)?;
-                let q = q.flatten(0, 1)?;
-                let v = v.flatten(0, 1)?;
-                let attention_bias = mask.map(|mask| mask.flatten(0, 1)).transpose()?;
-
-                // If attention_bias is set, we fuse the add by giving it as the output matrix
-                // and setting beta to 1.0
-                let beta = match attention_bias.is_some() {
-                    true => Some(1.0),
-                    false => None,
-                };
-
-                // Batch matrix multiplication
-                // Fuse softmax scale and attention_bias add
-                let attention_scores = cublaslt.batch_matmul(
-                    &k,
-                    &q,
-                    attention_bias.as_ref(),
-                    Some((1.0 / (head_dim as f64).sqrt()) as f32),
-                    beta,
-                    None,
-                    None,
-                )?;
-                let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
-
-                let context_layer = cublaslt.batch_matmul(
-                    &v.t()?.contiguous()?,
-                    &attention_probs,
-                    // We save one allocation
-                    Some(&q),
-                    None,
-                    None,
-                    None,
-                    None,
-                )?;
-
-                // Reshape to dims4
-                context_layer.reshape((b_sz, n_attn_heads, seq_len, head_dim))
-            }
-            #[cfg(not(feature = "cuda"))]
-            {
-                candle_core::bail!("`cuda` feature is not enabled")
+            if !get_use_matmul_via_f16() {
+                #[cfg(feature = "cuda")]
+                {
+                    // cuBLASLt batch matmul implementation requires inputs to be dims3
+                    let k = k.flatten(0, 1)?;
+                    let q = q.flatten(0, 1)?;
+                    let v = v.flatten(0, 1)?;
+                    let attention_bias = mask.map(|mask| mask.flatten(0, 1)).transpose()?;
+
+                    // If attention_bias is set, we fuse the add by giving it as the output matrix
+                    // and setting beta to 1.0
+                    let beta = match attention_bias.is_some() {
+                        true => Some(1.0),
+                        false => None,
+                    };
+
+                    // Batch matrix multiplication
+                    // Fuse softmax scale and attention_bias add
+                    let attention_scores = cublaslt.batch_matmul(
+                        &k,
+                        &q,
+                        attention_bias.as_ref(),
+                        Some((1.0 / (head_dim as f64).sqrt()) as f32),
+                        beta,
+                        None,
+                        None,
+                    )?;
+                    let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+
+                    let context_layer = cublaslt.batch_matmul(
+                        &v.t()?.contiguous()?,
+                        &attention_probs,
+                        // We save one allocation
+                        Some(&q),
+                        None,
+                        None,
+                        None,
+                        None,
+                    )?;
+
+                    // Reshape to dims4
+                    context_layer.reshape((b_sz, n_attn_heads, seq_len, head_dim))
+                }
+                #[cfg(not(feature = "cuda"))]
+                {
+                    candle_core::bail!("`cuda` feature is not enabled")
+                }
+            } else {
+                // Use the f16 kernels here if quantized (ISQ or GGML), and a large enough prompt
+                naive_sdpa(q, k, v, head_dim, mask)
             }
         } else {
-            let att = MatMul.matmul_affine_div(
-                &q.contiguous()?,
-                &k.t()?.contiguous()?,
-                (head_dim as f64).sqrt(),
-            )?;
-
-            let att = match mask {
-                Some(m) => att.broadcast_add(m)?,
-                None => att,
-            };
-            let att = candle_nn::ops::softmax_last_dim(&att)?;
-            // Convert to contiguous as matmul doesn't support strided vs for now.
-            MatMul.matmul(&att, &v.contiguous()?)
+            naive_sdpa(q, k, v, head_dim, mask)
         }
     }
 }
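
Note (not part of the patch): the naive_sdpa fallback introduced above computes softmax(QK^T / sqrt(d_k)) V with plain candle ops, and the patched CUDA path now falls back to it whenever get_use_matmul_via_f16() is set (f16 matmul kernels for quantized/ISQ runs), skipping the cuBLASLt fused kernel. The sketch below is a minimal, self-contained illustration of that fallback math. The function name naive_sdpa_sketch, the toy shapes, and the main() driver are assumptions for demonstration, not code from mistralrs-core.

// Standalone sketch; Cargo deps assumed: candle-core, candle-nn (CPU is enough here).
use candle_core::{Device, Result, Tensor};

/// softmax(Q K^T / sqrt(d_k)) V over (batch, heads, seq, head_dim) tensors.
fn naive_sdpa_sketch(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    head_dim: usize,
    mask: Option<&Tensor>,
) -> Result<Tensor> {
    // (b, h, s, d) x (b, h, d, s) -> (b, h, s, s); contiguous() mirrors the diff,
    // which notes that matmul does not handle some strided layouts yet.
    let att = (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (head_dim as f64).sqrt())?;
    // Broadcast-add the (typically causal) mask before the softmax, as naive_sdpa does.
    let att = match mask {
        Some(m) => att.broadcast_add(m)?,
        None => att,
    };
    let att = candle_nn::ops::softmax_last_dim(&att)?;
    // (b, h, s, s) x (b, h, s, d) -> (b, h, s, d)
    att.matmul(&v.contiguous()?)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, heads, seq, head_dim): (usize, usize, usize, usize) = (1, 2, 4, 8);
    let q = Tensor::randn(0f32, 1f32, (b, heads, seq, head_dim), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (b, heads, seq, head_dim), &dev)?;
    let v = Tensor::randn(0f32, 1f32, (b, heads, seq, head_dim), &dev)?;
    let out = naive_sdpa_sketch(&q, &k, &v, head_dim, None)?;
    println!("{:?}", out.dims()); // [1, 2, 4, 8]
    Ok(())
}

The only difference from the patched helper is cosmetic: the diff routes the scaled matmul through MatMul.matmul_affine_div so the existing f16/precision plumbing in layers.rs is reused, while the sketch applies the 1/sqrt(d_k) scale directly.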