Skip to content

Commit

Permalink
refactor: rewrite Stack::push_slice to allow arbitrary lengths (#812)
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniPopes authored Oct 25, 2023
1 parent 885c0cc commit 0d78d1e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 53 deletions.
10 changes: 5 additions & 5 deletions crates/interpreter/src/instructions/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ pub fn push0<H: Host, SPEC: Spec>(interpreter: &mut Interpreter<'_>, _host: &mut

pub fn push<const N: usize, H: Host>(interpreter: &mut Interpreter<'_>, _host: &mut H) {
gas!(interpreter, gas::VERYLOW);
let start = interpreter.instruction_pointer;
// Safety: In Analysis we appended needed bytes for bytecode so that we are safe to just add without
// checking if it is out of bound. This makes both of our unsafes block safe to do.
// SAFETY: In analysis we append trailing bytes to the bytecode so that this is safe to do
// without bounds checking.
let ip = interpreter.instruction_pointer;
if let Err(result) = interpreter
.stack
.push_slice::<N>(unsafe { core::slice::from_raw_parts(start, N) })
.push_slice(unsafe { core::slice::from_raw_parts(ip, N) })
{
interpreter.instruction_result = result;
return;
}
interpreter.instruction_pointer = unsafe { start.add(N) };
interpreter.instruction_pointer = unsafe { ip.add(N) };
}

pub fn dup<const N: usize, H: Host>(interpreter: &mut Interpreter<'_>, _host: &mut H) {
Expand Down
85 changes: 37 additions & 48 deletions crates/interpreter/src/interpreter/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,63 +222,52 @@ impl Stack {
Ok(())
}

/// Push a slice of bytes of `N` length onto the stack.
///
/// If it will exceed the stack limit, returns `StackOverflow` error and leaves the stack
/// unchanged.
/// Pushes an arbitrary length slice of bytes onto the stack, padding the last word with zeros
/// if necessary.
#[inline]
pub fn push_slice<const N: usize>(&mut self, slice: &[u8]) -> Result<(), InstructionResult> {
let new_len = self.data.len() + 1;
pub fn push_slice(&mut self, slice: &[u8]) -> Result<(), InstructionResult> {
if slice.is_empty() {
return Ok(());
}

let n_words = (slice.len() + 31) / 32;
let new_len = self.data.len() + n_words;
if new_len > STACK_LIMIT {
return Err(InstructionResult::StackOverflow);
}

let slot;
// Safety: check above ensures us that we are okay in increment len.
// SAFETY: length checked above.
unsafe {
self.data.set_len(new_len);
slot = self.data.get_unchecked_mut(new_len - 1);
}
let dst = self.data.as_mut_ptr().add(self.data.len()).cast::<u64>();
let mut i = 0;

unsafe {
*slot.as_limbs_mut() = [0u64; 4];
let mut dangling = [0u8; 8];
if N < 8 {
dangling[8 - N..].copy_from_slice(slice);
slot.as_limbs_mut()[0] = u64::from_be_bytes(dangling);
} else if N < 16 {
slot.as_limbs_mut()[0] =
u64::from_be_bytes(slice[N - 8..N].try_into().expect("Infallible"));
if N != 8 {
dangling[8 * 2 - N..].copy_from_slice(&slice[..N - 8]);
slot.as_limbs_mut()[1] = u64::from_be_bytes(dangling);
}
} else if N < 24 {
slot.as_limbs_mut()[0] =
u64::from_be_bytes(slice[N - 8..N].try_into().expect("Infallible"));
slot.as_limbs_mut()[1] =
u64::from_be_bytes(slice[N - 16..N - 8].try_into().expect("Infallible"));
if N != 16 {
dangling[8 * 3 - N..].copy_from_slice(&slice[..N - 16]);
slot.as_limbs_mut()[2] = u64::from_be_bytes(dangling);
}
} else {
// M<32
slot.as_limbs_mut()[0] =
u64::from_be_bytes(slice[N - 8..N].try_into().expect("Infallible"));
slot.as_limbs_mut()[1] =
u64::from_be_bytes(slice[N - 16..N - 8].try_into().expect("Infallible"));
slot.as_limbs_mut()[2] =
u64::from_be_bytes(slice[N - 24..N - 16].try_into().expect("Infallible"));
if N == 32 {
slot.as_limbs_mut()[3] =
u64::from_be_bytes(slice[..N - 24].try_into().expect("Infallible"));
} else if N != 24 {
dangling[8 * 4 - N..].copy_from_slice(&slice[..N - 24]);
slot.as_limbs_mut()[3] = u64::from_be_bytes(dangling);
}
// write full words
let limbs = slice.rchunks_exact(8);
let rem = limbs.remainder();
for limb in limbs {
*dst.add(i) = u64::from_be_bytes(limb.try_into().unwrap());
i += 1;
}

// write remainder by padding with zeros
if !rem.is_empty() {
let mut tmp = [0u8; 8];
tmp[8 - rem.len()..].copy_from_slice(rem);
*dst.add(i) = u64::from_be_bytes(tmp);
i += 1;
}

debug_assert_eq!((i + 3) / 4, n_words, "wrote beyond end of stack");

// zero out upper bytes of last word
let m = i % 4; // 32 / 8
if m != 0 {
dst.add(i).write_bytes(0, 4 - m);
}

self.data.set_len(new_len);
}

Ok(())
}

Expand Down

0 comments on commit 0d78d1e

Please sign in to comment.