Skip to content

Commit

Permalink
Adding Q4_{0/1} support.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Mar 17, 2023
1 parent 454924b commit bd27e75
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
9 changes: 8 additions & 1 deletion bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ fn prepare(tensor_dict: HashMap<String, &PyDict>) -> PyResult<HashMap<String, Te
dtype = match value {
"bool" => Some(Dtype::BOOL),
"int8" => Some(Dtype::I8),
"uint8" => Some(Dtype::U8),
"q4_0" => Some(Dtype::Q4_0),
"q4_1" => Some(Dtype::Q4_1),
"int16" => Some(Dtype::I16),
"uint8" => Some(Dtype::U8),
"uint16" => Some(Dtype::U16),
"int32" => Some(Dtype::I32),
"uint32" => Some(Dtype::U32),
Expand Down Expand Up @@ -874,6 +876,11 @@ fn get_pydtype(module: &PyModule, dtype: Dtype) -> PyResult<PyObject> {
Dtype::U8 => module.getattr(intern!(py, "uint8"))?.into(),
Dtype::I8 => module.getattr(intern!(py, "int8"))?.into(),
Dtype::BOOL => module.getattr(intern!(py, "bool"))?.into(),
Dtype::Q4_1 | Dtype::Q4_0 => {
return Err(SafetensorError::new_err(format!(
"Dtype not supported by framework: {dtype:?}"
)))
}
dtype => {
return Err(SafetensorError::new_err(format!(
"Dtype not understood: {dtype:?}"
Expand Down
43 changes: 39 additions & 4 deletions safetensors/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ pub enum SafeTensorError {
TensorNotFound(String),
/// Invalid information between shape, dtype and the proposed offsets in the file
TensorInvalidInfo,
/// Invalid information between shape and dtype: the total number of bits
/// (number of elements times bits per element) is not a multiple of 8,
/// so the total number of bytes for the buffer is not an integer
QuantizationMisaligned,
/// The offsets declared for tensor with name `String` in the header are invalid
InvalidOffset(String),
/// IoError
Expand Down Expand Up @@ -461,7 +464,11 @@ impl Metadata {
}
start = e;
let nelements: usize = info.shape.iter().product();
let nbytes = nelements * info.dtype.size();
// Size the tensor in bits first so sub-byte dtypes (4 bits per element)
// are accounted for exactly before converting to bytes.
let nbits = nelements * info.dtype.nbits();
// The buffer must cover a whole number of bytes. The earlier form
// `!nbits % 8 == 0` parsed as `(!nbits) % 8 == 0` (`!` is bitwise NOT on
// integers and binds tighter than `%`), so misaligned sizes were accepted.
if nbits % 8 != 0 {
    return Err(SafeTensorError::QuantizationMisaligned);
}
let nbytes = nbits / 8;
if (e - s) != nbytes {
return Err(SafeTensorError::TensorInvalidInfo);
}
Expand Down Expand Up @@ -570,6 +577,10 @@ pub struct TensorInfo {
pub enum Dtype {
/// Boolean type
BOOL,
/// Unsigned int4
Q4_0,
/// Signed int4
Q4_1,
/// Unsigned byte
U8,
/// Signed byte
Expand Down Expand Up @@ -601,20 +612,44 @@ impl Dtype {
/// Size in bytes of one element of this dtype.
///
/// Sub-byte dtypes (`Q4_0`, `Q4_1`) are reported as 1 here; their exact
/// width in bits is given by `nbits`.
pub fn size(&self) -> usize {
    // Each dtype is listed exactly once, grouped by width, so no arm is
    // shadowed by an earlier duplicate.
    match self {
        Dtype::BOOL | Dtype::Q4_0 | Dtype::Q4_1 | Dtype::U8 | Dtype::I8 => 1,
        Dtype::I16 | Dtype::U16 | Dtype::F16 | Dtype::BF16 => 2,
        Dtype::I32 | Dtype::U32 | Dtype::F32 => 4,
        Dtype::I64 | Dtype::U64 | Dtype::F64 => 8,
    }
}

/// Returns the width, in bits, of a single element of this dtype.
/// Sub-byte types such as q4_0 and q4_1 rely on this, since their
/// elements do not occupy a whole number of bytes.
pub fn nbits(&self) -> usize {
    match self {
        // 4-bit quantized types
        Dtype::Q4_0 | Dtype::Q4_1 => 4,
        // one-byte types
        Dtype::BOOL | Dtype::U8 | Dtype::I8 => 8,
        // two-byte types
        Dtype::I16 | Dtype::U16 | Dtype::F16 | Dtype::BF16 => 16,
        // four-byte types
        Dtype::I32 | Dtype::U32 | Dtype::F32 => 32,
        // eight-byte types
        Dtype::I64 | Dtype::U64 | Dtype::F64 => 64,
    }
}
}

#[cfg(test)]
Expand Down

0 comments on commit bd27e75

Please sign in to comment.