Disable cublaslt if using f16 kernels (#359)
EricLBuehler authored May 30, 2024
1 parent 9f2937c commit e1c3e6e
Showing 1 changed file with 74 additions and 58 deletions.
132 changes: 74 additions & 58 deletions mistralrs-core/src/layers.rs
@@ -275,6 +275,29 @@ impl MatMul {
     }
 }

+/// Computes softmax(QK^T / sqrt(d_k))V
+fn naive_sdpa(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    head_dim: usize,
+    mask: Option<&Tensor>,
+) -> Result<Tensor> {
+    let att = MatMul.matmul_affine_div(
+        &q.contiguous()?,
+        &k.t()?.contiguous()?,
+        (head_dim as f64).sqrt(),
+    )?;
+
+    let att = match mask {
+        Some(m) => att.broadcast_add(m)?,
+        None => att,
+    };
+    let att = candle_nn::ops::softmax_last_dim(&att)?;
+    // Convert to contiguous as matmul doesn't support strided vs for now.
+    MatMul.matmul(&att, &v.contiguous()?)
+}
+
 pub struct ScaledDotProductAttention;

 impl ScaledDotProductAttention {
@@ -307,66 +330,59 @@ impl ScaledDotProductAttention {
         }

         if let (Device::Cuda(_), Some(cublaslt)) = (q.device(), *CUBLASLT_HANDLE.lock().unwrap()) {
-            #[cfg(feature = "cuda")]
-            {
-                // cuBLASLt batch matmul implementation requires inputs to be dims3
-                let k = k.flatten(0, 1)?;
-                let q = q.flatten(0, 1)?;
-                let v = v.flatten(0, 1)?;
-                let attention_bias = mask.map(|mask| mask.flatten(0, 1)).transpose()?;
-
-                // If attention_bias is set, we fuse the add by giving it as the output matrix
-                // and setting beta to 1.0
-                let beta = match attention_bias.is_some() {
-                    true => Some(1.0),
-                    false => None,
-                };
-
-                // Batch matrix multiplication
-                // Fuse softmax scale and attention_bias add
-                let attention_scores = cublaslt.batch_matmul(
-                    &k,
-                    &q,
-                    attention_bias.as_ref(),
-                    Some((1.0 / (head_dim as f64).sqrt()) as f32),
-                    beta,
-                    None,
-                    None,
-                )?;
-                let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
-
-                let context_layer = cublaslt.batch_matmul(
-                    &v.t()?.contiguous()?,
-                    &attention_probs,
-                    // We save one allocation
-                    Some(&q),
-                    None,
-                    None,
-                    None,
-                    None,
-                )?;
-
-                // Reshape to dims4
-                context_layer.reshape((b_sz, n_attn_heads, seq_len, head_dim))
-            }
-            #[cfg(not(feature = "cuda"))]
-            {
-                candle_core::bail!("`cuda` feature is not enabled")
+            if !get_use_matmul_via_f16() {
+                #[cfg(feature = "cuda")]
+                {
+                    // cuBLASLt batch matmul implementation requires inputs to be dims3
+                    let k = k.flatten(0, 1)?;
+                    let q = q.flatten(0, 1)?;
+                    let v = v.flatten(0, 1)?;
+                    let attention_bias = mask.map(|mask| mask.flatten(0, 1)).transpose()?;
+
+                    // If attention_bias is set, we fuse the add by giving it as the output matrix
+                    // and setting beta to 1.0
+                    let beta = match attention_bias.is_some() {
+                        true => Some(1.0),
+                        false => None,
+                    };
+
+                    // Batch matrix multiplication
+                    // Fuse softmax scale and attention_bias add
+                    let attention_scores = cublaslt.batch_matmul(
+                        &k,
+                        &q,
+                        attention_bias.as_ref(),
+                        Some((1.0 / (head_dim as f64).sqrt()) as f32),
+                        beta,
+                        None,
+                        None,
+                    )?;
+                    let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+
+                    let context_layer = cublaslt.batch_matmul(
+                        &v.t()?.contiguous()?,
+                        &attention_probs,
+                        // We save one allocation
+                        Some(&q),
+                        None,
+                        None,
+                        None,
+                        None,
+                    )?;
+
+                    // Reshape to dims4
+                    context_layer.reshape((b_sz, n_attn_heads, seq_len, head_dim))
+                }
+                #[cfg(not(feature = "cuda"))]
+                {
+                    candle_core::bail!("`cuda` feature is not enabled")
+                }
+            } else {
+                // Use the f16 kernels here if quantized (ISQ or GGML), and a large enough prompt
+                naive_sdpa(q, k, v, head_dim, mask)
             }
         } else {
-            let att = MatMul.matmul_affine_div(
-                &q.contiguous()?,
-                &k.t()?.contiguous()?,
-                (head_dim as f64).sqrt(),
-            )?;
-
-            let att = match mask {
-                Some(m) => att.broadcast_add(m)?,
-                None => att,
-            };
-            let att = candle_nn::ops::softmax_last_dim(&att)?;
-            // Convert to contiguous as matmul doesn't support strided vs for now.
-            MatMul.matmul(&att, &v.contiguous()?)
+            naive_sdpa(q, k, v, head_dim, mask)
         }
     }
 }
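
For readers skimming the diff: the helper factored out above computes standard scaled-dot-product attention, softmax(Q K^T / sqrt(d_k) + mask) V. Below is a minimal, self-contained sketch of the same math using plain candle ops rather than the crate's MatMul wrapper; sdpa_reference and the toy shapes are illustrative assumptions, not code from this commit.

use candle_core::{Device, Result, Tensor};

/// Illustrative reference: softmax(Q K^T / sqrt(d_k) + mask) V with plain candle ops.
fn sdpa_reference(q: &Tensor, k: &Tensor, v: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
    let (_b_sz, _n_heads, _seq_len, head_dim) = q.dims4()?;
    // Q K^T scaled by 1/sqrt(d_k); contiguous() mirrors the strided-matmul caveat above.
    let att = (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (head_dim as f64).sqrt())?;
    // Optional additive attention bias (e.g. a causal mask with -inf above the diagonal).
    let att = match mask {
        Some(m) => att.broadcast_add(m)?,
        None => att,
    };
    let att = candle_nn::ops::softmax_last_dim(&att)?;
    att.matmul(&v.contiguous()?)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b_sz, n_heads, seq_len, head_dim) = (1, 2, 4, 8);
    let q = Tensor::randn(0f32, 1f32, (b_sz, n_heads, seq_len, head_dim), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (b_sz, n_heads, seq_len, head_dim), &dev)?;
    let v = Tensor::randn(0f32, 1f32, (b_sz, n_heads, seq_len, head_dim), &dev)?;
    let out = sdpa_reference(&q, &k, &v, None)?;
    // The context layer keeps the (batch, heads, seq, head_dim) shape.
    assert_eq!(out.dims(), &[b_sz, n_heads, seq_len, head_dim]);
    Ok(())
}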

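On the comment "cuBLASLt batch matmul implementation requires inputs to be dims3": the (batch, heads, seq, head_dim) tensors are collapsed into a single leading batch dimension before the batched matmul, and the result is reshaped back to four dimensions at the end. A small sketch of that round-trip, with shapes chosen purely for illustration:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b_sz, n_attn_heads, seq_len, head_dim) = (2, 4, 16, 64);

    // Attention inputs start out as (batch, heads, seq, head_dim).
    let q = Tensor::randn(0f32, 1f32, (b_sz, n_attn_heads, seq_len, head_dim), &dev)?;

    // Collapse batch and head dims into one leading batch dimension (dims3)...
    let q3 = q.flatten(0, 1)?;
    assert_eq!(q3.dims(), &[b_sz * n_attn_heads, seq_len, head_dim]);

    // ...and restore the 4-D layout afterwards, as the cuBLASLt branch does with
    // context_layer.reshape((b_sz, n_attn_heads, seq_len, head_dim)).
    let q4 = q3.reshape((b_sz, n_attn_heads, seq_len, head_dim))?;
    assert_eq!(q4.dims(), q.dims());
    Ok(())
}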
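
On the comments about beta and fusing the bias add: cuBLASLt's matmul epilogue computes D = alpha * (A * B) + beta * C, so passing attention_bias as the output matrix C with beta = 1.0 folds the mask add (and the 1/sqrt(d_k) scale, via alpha) into the matmul kernel. Below is a sketch of the unfused equivalent, using plain candle ops and assumed shapes; the exact operand order of the cuBLASLt wrapper is not reproduced here.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Collapsed (batch*heads, seq, head_dim) layout, as in the cuBLASLt branch.
    let (bh, seq_len, head_dim) = (8, 16, 64);
    let q = Tensor::randn(0f32, 1f32, (bh, seq_len, head_dim), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (bh, seq_len, head_dim), &dev)?;
    // Additive attention bias (e.g. a causal mask), one (seq, seq) matrix per batch entry.
    let bias = Tensor::randn(0f32, 1f32, (bh, seq_len, seq_len), &dev)?;

    let alpha = 1.0 / (head_dim as f64).sqrt();
    // Unfused form of alpha * (Q K^T) + 1.0 * bias; the fused call does this in one kernel.
    let scores = ((q.matmul(&k.t()?.contiguous()?)? * alpha)?).add(&bias)?;
    assert_eq!(scores.dims(), &[bh, seq_len, seq_len]);
    Ok(())
}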