Skip to content
This repository was archived by the owner on Jun 24, 2024. It is now read-only.

Update GGML dependency #226

Merged
merged 9 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions binaries/generate-ggml-bindings/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@
use std::path::PathBuf;

fn main() {
const HEADER_PATH: &str = "crates/ggml/sys/ggml/include/ggml/ggml.h";

let bindings = bindgen::Builder::default()
.header(HEADER_PATH)
.header("crates/ggml/sys/bindings.h")
// Suppress some warnings
.raw_line("#![allow(non_upper_case_globals)]")
.raw_line("#![allow(non_camel_case_types)]")
.raw_line("#![allow(non_snake_case)]")
.raw_line("#![allow(unused)]")
// Do not generate code for ggml's includes (stdlib)
.allowlist_file(HEADER_PATH)
// Only generate code if it's from GGML
.allowlist_file("crates/ggml/.*")
.generate()
.expect("Unable to generate bindings");

Expand Down
10 changes: 5 additions & 5 deletions binaries/llm-cli/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,13 @@ pub enum FileType {
/// Float 32-bit.
F32,
}
impl From<FileType> for llm::FileType {
impl From<FileType> for llm::FileTypeFormat {
fn from(t: FileType) -> Self {
match t {
FileType::Q4_0 => llm::FileType::MostlyQ4_0,
FileType::Q4_1 => llm::FileType::MostlyQ4_1,
FileType::F16 => llm::FileType::MostlyF16,
FileType::F32 => llm::FileType::F32,
FileType::Q4_0 => llm::FileTypeFormat::MostlyQ4_0,
FileType::Q4_1 => llm::FileTypeFormat::MostlyQ4_1,
FileType::F16 => llm::FileTypeFormat::MostlyF16,
FileType::F32 => llm::FileTypeFormat::F32,
}
}
}
Expand Down
45 changes: 40 additions & 5 deletions crates/ggml/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,26 +143,47 @@ impl Context {
self.new_tensor_raw(tensor)
}

/// In-place, scales `a` by the 1D tensor `b`.
/// Scales `a` by the 1D tensor `b`.
pub fn op_scale(&self, a: &Tensor, b: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_scale(self.ptr.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// In-place, sets the elements above the diagonal to -INF.
/// In-place, scales `a` by the 1D tensor `b`.
pub fn op_scale_inplace(&self, a: &Tensor, b: &Tensor) -> Tensor {
let tensor =
unsafe { sys::ggml_scale_inplace(self.ptr.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// Sets the elements above the diagonal to -INF.
pub fn op_diag_mask_inf(&self, a: &Tensor, n_past: usize) -> Tensor {
let tensor = unsafe {
sys::ggml_diag_mask_inf(self.ptr.as_ptr(), a.ptr.as_ptr(), usize_to_i32(n_past))
};
self.new_tensor_raw(tensor)
}

/// In-place, applies the [Softmax function](https://en.wikipedia.org/wiki/Softmax_function) to `a`.
/// In-place, sets the elements above the diagonal to -INF.
pub fn op_diag_mask_inf_inplace(&self, a: &Tensor, n_past: usize) -> Tensor {
let tensor = unsafe {
sys::ggml_diag_mask_inf_inplace(self.ptr.as_ptr(), a.ptr.as_ptr(), usize_to_i32(n_past))
};
self.new_tensor_raw(tensor)
}

/// Applies the [Softmax function](https://en.wikipedia.org/wiki/Softmax_function) to `a`.
pub fn op_soft_max(&self, a: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_soft_max(self.ptr.as_ptr(), a.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// In-place, applies the [Softmax function](https://en.wikipedia.org/wiki/Softmax_function) to `a`.
pub fn op_soft_max_inplace(&self, a: &Tensor) -> Tensor {
let tensor = unsafe { sys::ggml_soft_max_inplace(self.ptr.as_ptr(), a.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// Creates a new tensor with result of mapping `fun` with `a`.
///
/// `cnt` is the number of `f32` elements to be mapped.
Expand Down Expand Up @@ -332,7 +353,7 @@ impl Context {
self.new_tensor_raw(tensor)
}

/// In-place; applies ROtary Positional Encoding.
/// Applies ROtary Positional Encoding.
pub fn op_rope(&self, a: &Tensor, npast: usize, ndims: usize, mode: i32) -> Tensor {
let tensor = unsafe {
sys::ggml_rope(
Expand All @@ -346,6 +367,20 @@ impl Context {
self.new_tensor_raw(tensor)
}

/// In-place; applies ROtary Positional Encoding.
pub fn op_rope_inplace(&self, a: &Tensor, npast: usize, ndims: usize, mode: i32) -> Tensor {
let tensor = unsafe {
sys::ggml_rope_inplace(
self.ptr.as_ptr(),
a.ptr.as_ptr(),
usize_to_i32(npast),
usize_to_i32(ndims),
mode,
)
};
self.new_tensor_raw(tensor)
}

/// Computes the specified graph. Must be run in order to evaluate the graph.
pub fn graph_compute(&self, graph: &mut ComputationGraph) {
unsafe {
Expand Down Expand Up @@ -380,7 +415,7 @@ impl Context {
}
}

/// TODO: something something
/// Attention with LInear BIases (Ref: <https://arxiv.org/pdf/2108.12409.pdf>)
pub fn op_alibi(&self, a: &Tensor, n_past: usize, n_head: usize) -> Tensor {
let tensor = unsafe {
sys::ggml_alibi(
Expand Down
42 changes: 18 additions & 24 deletions crates/ggml/src/format/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ pub enum LoadError<E: Error> {
#[error("invalid file magic number: {0}")]
/// The file magic number is invalid.
InvalidMagic(u32),
#[error("invalid ggml format: format={0:?} version={1}")]
#[error("invalid ggml format: format={0:?}")]
/// An unsupported format version was found.
InvalidFormatVersion(ContainerType, u32),
InvalidFormatVersion(ContainerType),
#[error("non-specific I/O error")]
/// A non-specific IO error.
Io(#[from] std::io::Error),
Expand Down Expand Up @@ -142,28 +142,20 @@ pub fn load<E: Error, R: BufRead + Seek>(
handler: &mut impl LoadHandler<E>,
) -> Result<(), LoadError<E>> {
// Verify magic
let container_type: ContainerType = match read_u32(reader)? {
crate::FILE_MAGIC_GGMF => ContainerType::Ggmf,
crate::FILE_MAGIC_GGJT => ContainerType::Ggjt,
crate::FILE_MAGIC_UNVERSIONED => ContainerType::Ggml,
crate::FILE_MAGIC_GGLA => ContainerType::Ggla,
magic => return Err(LoadError::InvalidMagic(magic)),
};
handler
.container_type(container_type)
.map_err(LoadError::ImplementationError)?;
let container_type = ContainerType::read(reader)?;

// Load format version
match container_type {
ContainerType::Ggmf | ContainerType::Ggjt | ContainerType::Ggla => {
let _version: u32 = match read_u32(reader)? {
crate::FORMAT_VERSION => crate::FORMAT_VERSION,
version => return Err(LoadError::InvalidFormatVersion(container_type, version)),
};
}
ContainerType::Ggml => {}
ContainerType::Ggml
| ContainerType::Ggmf(1)
| ContainerType::Ggjt(1 | 2)
| ContainerType::Ggla(1) => {}
_ => return Err(LoadError::InvalidFormatVersion(container_type)),
}

handler
.container_type(container_type)
.map_err(LoadError::ImplementationError)?;

// Load hyper params
let hparams = handler
.read_hyperparameters(reader)
Expand All @@ -175,8 +167,8 @@ pub fn load<E: Error, R: BufRead + Seek>(
let len = read_u32(reader)?.try_into()?;
let token = read_bytes_with_len(reader, len)?;
let token_score = match container_type {
ContainerType::Ggmf | ContainerType::Ggjt => read_f32(reader)?,
ContainerType::Ggml | ContainerType::Ggla => {
ContainerType::Ggmf(_version) | ContainerType::Ggjt(_version) => read_f32(reader)?,
ContainerType::Ggml | ContainerType::Ggla(_) => {
// Legacy model, set empty score
0.
}
Expand All @@ -188,8 +180,10 @@ pub fn load<E: Error, R: BufRead + Seek>(

// Load tensor data
match container_type {
ContainerType::Ggmf | ContainerType::Ggml => load_weights(reader, handler, false),
ContainerType::Ggjt | ContainerType::Ggla => load_weights(reader, handler, true),
ContainerType::Ggmf(_) | ContainerType::Ggml => load_weights(reader, handler, false),
ContainerType::Ggjt(_version) | ContainerType::Ggla(_version) => {
load_weights(reader, handler, true)
}
}
}

Expand Down
7 changes: 3 additions & 4 deletions crates/ggml/src/format/saver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
io::{Seek, Write},
};

use crate::{util, ElementType};
use crate::{util, ContainerType, ElementType};

#[derive(Debug, thiserror::Error)]
/// Errors that can occur while writing a model.
Expand Down Expand Up @@ -57,16 +57,15 @@ pub struct TensorSaveInfo {

/// Saves a model to the given writer.
///
/// Only GGJT is supported.
/// Only GGJT version 2 is supported.
pub fn save<E: Error, W: Write + Seek>(
writer: &mut W,
handler: &mut dyn SaveHandler<E>,
vocabulary: &[(Vec<u8>, f32)],
tensor_names: &[String],
) -> Result<(), SaveError<E>> {
// Write header and hyperparameters
util::write_u32(writer, crate::FILE_MAGIC_GGJT)?;
util::write_u32(writer, crate::FORMAT_VERSION)?;
ContainerType::Ggjt(2).write(writer)?;
handler
.write_hyperparameters(writer)
.map_err(SaveError::ImplementationError)?;
Expand Down
95 changes: 95 additions & 0 deletions crates/ggml/src/legacy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//! This module exposes legacy functionality of GGML that has been extracted
//! to help bridge versions.

/// Quantization version 0.
pub mod qnt0 {
use crate::ElementType;

macro_rules! generate_dequantization_function {
($rust_name:ident, $c_name:ident, $doc:literal) => {
#[doc=$doc]
pub fn $rust_name(row: &[u8], out: &mut [f32], row_size: usize) {
assert_eq!(row_size, out.len());
unsafe {
ggml_sys::$c_name(
row.as_ptr() as *const _,
out.as_mut_ptr(),
row_size.try_into().unwrap(),
)
}
}
};
}

generate_dequantization_function!(
dequantize_row_q4_0,
qnt0_ggml_dequantize_row_q4_0,
"Dequantizes a QNT0 q4_0 row to f32."
);

generate_dequantization_function!(
dequantize_row_q4_1,
qnt0_ggml_dequantize_row_q4_1,
"Dequantizes a QNT0 q4_1 row to f32."
);

generate_dequantization_function!(
dequantize_row_q4_2,
qnt0_ggml_dequantize_row_q4_2,
"Dequantizes a QNT0 q4_2 row to f32."
);

generate_dequantization_function!(
dequantize_row_q5_0,
qnt0_ggml_dequantize_row_q5_0,
"Dequantizes a QNT0 q5_0 row to f32."
);

generate_dequantization_function!(
dequantize_row_q5_1,
qnt0_ggml_dequantize_row_q5_1,
"Dequantizes a QNT0 q5_1 row to f32."
);

generate_dequantization_function!(
dequantize_row_q8_0,
qnt0_ggml_dequantize_row_q8_0,
"Dequantizes a QNT0 q8_0 row to f32."
);

/// Dequantizes a QNT0 row to f32.
pub fn dequantize_row(
element_type: ElementType,
row: &[u8],
out: &mut [f32],
row_size: usize,
) -> bool {
match element_type {
crate::Type::Q4_0 => {
dequantize_row_q4_0(row, out, row_size);
true
}
crate::Type::Q4_1 => {
dequantize_row_q4_1(row, out, row_size);
true
}
crate::Type::LegacyQ4_2 => {
dequantize_row_q4_2(row, out, row_size);
true
}
crate::Type::Q5_0 => {
dequantize_row_q5_0(row, out, row_size);
true
}
crate::Type::Q5_1 => {
dequantize_row_q5_1(row, out, row_size);
true
}
crate::Type::Q8_0 => {
dequantize_row_q8_0(row, out, row_size);
true
}
_ => false,
}
}
}
Loading