Skip to content

Commit

Permalink
Merge pull request #109 from byeongkeunahn/main
Browse files (browse the repository at this point in the history)
Enable Rust 2024 edition
  • Loading branch information
byeongkeunahn authored Sep 7, 2024
2 parents 74e674a + ce4ac82 commit c4eaee5
Show file tree
Hide file tree
Showing 18 changed files with 2,194 additions and 1,967 deletions.
4 changes: 3 additions & 1 deletion basm-macro/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
cargo-features = ["edition2024"]

[package]
name = "basm-macro"
version = "0.1.0"
edition = "2021"
edition = "2024"
autobins = false

[lib]
Expand Down
4 changes: 3 additions & 1 deletion basm-std/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
cargo-features = ["edition2024"]

[package]
name = "basm-std"
version = "0.1.0"
edition = "2021"
edition = "2024"
autobins = false

[lib]
Expand Down
170 changes: 90 additions & 80 deletions basm-std/src/math/ntt/nttcore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,14 @@ unsafe fn ntt2_single_block<const P: u64, const INV: bool, const TWIDDLE: bool>(
mut px: *mut u64,
ptf: *const u64,
) -> (*mut u64, *const u64) {
let w1 = if TWIDDLE { *ptf } else { 0 };
for _ in 0..s1 {
(*px, *px.add(s1)) = ntt2_kernel::<P, INV, TWIDDLE>(w1, *px, *px.add(s1));
px = px.add(1);
unsafe {
let w1 = if TWIDDLE { *ptf } else { 0 };
for _ in 0..s1 {
(*px, *px.add(s1)) = ntt2_kernel::<P, INV, TWIDDLE>(w1, *px, *px.add(s1));
px = px.add(1);
}
(px.add(s1), ptf.add(1))
}
(px.add(s1), ptf.add(1))
}
const fn ntt3_kernel<const P: u64, const INV: bool, const TWIDDLE: bool>(
w1: u64,
Expand Down Expand Up @@ -414,14 +416,16 @@ unsafe fn ntt3_single_block<const P: u64, const INV: bool, const TWIDDLE: bool>(
mut px: *mut u64,
ptf: *const u64,
) -> (*mut u64, *const u64) {
let w1 = if TWIDDLE { *ptf } else { 0 };
let w2 = Arith::<P>::mmulmod(w1, w1);
for _ in 0..s1 {
(*px, *px.add(s1), *px.add(2 * s1)) =
ntt3_kernel::<P, INV, TWIDDLE>(w1, w2, *px, *px.add(s1), *px.add(2 * s1));
px = px.add(1);
}
(px.add(2 * s1), ptf.add(1))
unsafe {
let w1 = if TWIDDLE { *ptf } else { 0 };
let w2 = Arith::<P>::mmulmod(w1, w1);
for _ in 0..s1 {
(*px, *px.add(s1), *px.add(2 * s1)) =
ntt3_kernel::<P, INV, TWIDDLE>(w1, w2, *px, *px.add(s1), *px.add(2 * s1));
px = px.add(1);
}
(px.add(2 * s1), ptf.add(1))
}
}
const fn ntt4_kernel<const P: u64, const INV: bool, const TWIDDLE: bool>(
w1: u64,
Expand Down Expand Up @@ -456,22 +460,24 @@ unsafe fn ntt4_single_block<const P: u64, const INV: bool, const TWIDDLE: bool>(
mut px: *mut u64,
ptf: *const u64,
) -> (*mut u64, *const u64) {
let w1 = if TWIDDLE { *ptf } else { 0 };
let w2 = Arith::<P>::mmulmod(w1, w1);
let w3 = Arith::<P>::mmulmod(w1, w2);
for _ in 0..s1 {
(*px, *px.add(s1), *px.add(2 * s1), *px.add(3 * s1)) = ntt4_kernel::<P, INV, TWIDDLE>(
w1,
w2,
w3,
*px,
*px.add(s1),
*px.add(2 * s1),
*px.add(3 * s1),
);
px = px.add(1);
unsafe {
let w1 = if TWIDDLE { *ptf } else { 0 };
let w2 = Arith::<P>::mmulmod(w1, w1);
let w3 = Arith::<P>::mmulmod(w1, w2);
for _ in 0..s1 {
(*px, *px.add(s1), *px.add(2 * s1), *px.add(3 * s1)) = ntt4_kernel::<P, INV, TWIDDLE>(
w1,
w2,
w3,
*px,
*px.add(s1),
*px.add(2 * s1),
*px.add(3 * s1),
);
px = px.add(1);
}
(px.add(3 * s1), ptf.add(1))
}
(px.add(3 * s1), ptf.add(1))
}
const fn ntt5_kernel<const P: u64, const INV: bool, const TWIDDLE: bool>(
w1: u64,
Expand Down Expand Up @@ -523,31 +529,33 @@ unsafe fn ntt5_single_block<const P: u64, const INV: bool, const TWIDDLE: bool>(
mut px: *mut u64,
ptf: *const u64,
) -> (*mut u64, *const u64) {
let w1 = if TWIDDLE { *ptf } else { 0 };
let w2 = Arith::<P>::mmulmod(w1, w1);
let w3 = Arith::<P>::mmulmod(w1, w2);
let w4 = Arith::<P>::mmulmod(w2, w2);
for _ in 0..s1 {
(
*px,
*px.add(s1),
*px.add(2 * s1),
*px.add(3 * s1),
*px.add(4 * s1),
) = ntt5_kernel::<P, INV, TWIDDLE>(
w1,
w2,
w3,
w4,
*px,
*px.add(s1),
*px.add(2 * s1),
*px.add(3 * s1),
*px.add(4 * s1),
);
px = px.add(1);
unsafe {
let w1 = if TWIDDLE { *ptf } else { 0 };
let w2 = Arith::<P>::mmulmod(w1, w1);
let w3 = Arith::<P>::mmulmod(w1, w2);
let w4 = Arith::<P>::mmulmod(w2, w2);
for _ in 0..s1 {
(
*px,
*px.add(s1),
*px.add(2 * s1),
*px.add(3 * s1),
*px.add(4 * s1),
) = ntt5_kernel::<P, INV, TWIDDLE>(
w1,
w2,
w3,
w4,
*px,
*px.add(s1),
*px.add(2 * s1),
*px.add(3 * s1),
*px.add(4 * s1),
);
px = px.add(1);
}
(px.add(4 * s1), ptf.add(1))
}
(px.add(4 * s1), ptf.add(1))
}
const fn ntt6_kernel<const P: u64, const INV: bool, const TWIDDLE: bool>(
w1: u64,
Expand Down Expand Up @@ -602,35 +610,37 @@ unsafe fn ntt6_single_block<const P: u64, const INV: bool, const TWIDDLE: bool>(
mut px: *mut u64,
ptf: *const u64,
) -> (*mut u64, *const u64) {
let w1 = if TWIDDLE { *ptf } else { 0 };
let w2 = Arith::<P>::mmulmod(w1, w1);
let w3 = Arith::<P>::mmulmod(w1, w2);
let w4 = Arith::<P>::mmulmod(w2, w2);
let w5 = Arith::<P>::mmulmod(w2, w3);
for _ in 0..s1 {
(
*px,
*px.add(s1),
*px.add(2 * s1),
*px.add(3 * s1),
*px.add(4 * s1),
*px.add(5 * s1),
) = ntt6_kernel::<P, INV, TWIDDLE>(
w1,
w2,
w3,
w4,
w5,
*px,
*px.add(s1),
*px.add(2 * s1),
*px.add(3 * s1),
*px.add(4 * s1),
*px.add(5 * s1),
);
px = px.add(1);
unsafe {
let w1 = if TWIDDLE { *ptf } else { 0 };
let w2 = Arith::<P>::mmulmod(w1, w1);
let w3 = Arith::<P>::mmulmod(w1, w2);
let w4 = Arith::<P>::mmulmod(w2, w2);
let w5 = Arith::<P>::mmulmod(w2, w3);
for _ in 0..s1 {
(
*px,
*px.add(s1),
*px.add(2 * s1),
*px.add(3 * s1),
*px.add(4 * s1),
*px.add(5 * s1),
) = ntt6_kernel::<P, INV, TWIDDLE>(
w1,
w2,
w3,
w4,
w5,
*px,
*px.add(s1),
*px.add(2 * s1),
*px.add(3 * s1),
*px.add(4 * s1),
*px.add(5 * s1),
);
px = px.add(1);
}
(px.add(5 * s1), ptf.add(1))
}
(px.add(5 * s1), ptf.add(1))
}

fn ntt_dif_dit<const P: u64, const INV: bool>(plan: &NttPlan, x: &mut [u64], tf_list: &[u64]) {
Expand Down
18 changes: 10 additions & 8 deletions basm-std/src/platform/allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,28 @@ pub unsafe fn install_malloc_impl(
ptr_dealloc: unsafe fn(*mut u8, usize, usize),
ptr_realloc: unsafe fn(*mut u8, usize, usize, usize) -> *mut u8,
) {
PTR_ALLOC = ptr_alloc;
PTR_ALLOC_ZEROED = ptr_alloc_zeroed;
PTR_DEALLOC = ptr_dealloc;
PTR_REALLOC = ptr_realloc;
unsafe {
PTR_ALLOC = ptr_alloc;
PTR_ALLOC_ZEROED = ptr_alloc_zeroed;
PTR_DEALLOC = ptr_dealloc;
PTR_REALLOC = ptr_realloc;
}
}

/// Unit type used as the allocator handle: its `GlobalAlloc` implementation
/// forwards every request to the `PTR_ALLOC`, `PTR_ALLOC_ZEROED`,
/// `PTR_DEALLOC` and `PTR_REALLOC` function pointers, which
/// `install_malloc_impl` assigns.
pub struct Allocator;

unsafe impl GlobalAlloc for Allocator {
#[inline(always)]
unsafe fn alloc(&self, layout: core::alloc::Layout) -> *mut u8 {
PTR_ALLOC(layout.size(), layout.align())
unsafe { PTR_ALLOC(layout.size(), layout.align()) }
}
#[inline(always)]
unsafe fn alloc_zeroed(&self, layout: core::alloc::Layout) -> *mut u8 {
PTR_ALLOC_ZEROED(layout.size(), layout.align())
unsafe { PTR_ALLOC_ZEROED(layout.size(), layout.align()) }
}
#[inline(always)]
unsafe fn dealloc(&self, ptr: *mut u8, layout: core::alloc::Layout) {
PTR_DEALLOC(ptr, layout.size(), layout.align())
unsafe { PTR_DEALLOC(ptr, layout.size(), layout.align()) }
}
#[inline(always)]
unsafe fn realloc(
Expand All @@ -40,6 +42,6 @@ unsafe impl GlobalAlloc for Allocator {
layout: core::alloc::Layout,
new_size: usize,
) -> *mut u8 {
PTR_REALLOC(ptr, layout.size(), layout.align(), new_size)
unsafe { PTR_REALLOC(ptr, layout.size(), layout.align(), new_size) }
}
}
Loading

0 comments on commit c4eaee5

Please sign in to comment.