diff --git a/programs/bpf_loader/src/syscalls/mod.rs b/programs/bpf_loader/src/syscalls/mod.rs index 64d1d85e5ee964..4193b9fcfc97a8 100644 --- a/programs/bpf_loader/src/syscalls/mod.rs +++ b/programs/bpf_loader/src/syscalls/mod.rs @@ -43,7 +43,7 @@ use { remaining_compute_units_syscall_enabled, stop_sibling_instruction_search_at_parent, stop_truncating_strings_in_syscalls, switch_to_new_elf_parser, }, - hash::{Hasher, HASH_BYTES}, + hash::{Hash, Hasher}, instruction::{ AccountMeta, InstructionError, ProcessedSiblingInstruction, TRANSACTION_LEVEL_STACK_HEIGHT, @@ -132,6 +132,103 @@ pub enum SyscallError { type Error = Box; +pub trait HasherImpl { + const NAME: &'static str; + type Output: AsRef<[u8]>; + + fn create_hasher() -> Self; + fn hash(&mut self, val: &[u8]); + fn result(self) -> Self::Output; + fn get_base_cost(compute_budget: &ComputeBudget) -> u64; + fn get_byte_cost(compute_budget: &ComputeBudget) -> u64; + fn get_max_slices(compute_budget: &ComputeBudget) -> u64; +} + +pub struct Sha256Hasher(Hasher); +pub struct Blake3Hasher(blake3::Hasher); +pub struct Keccak256Hasher(keccak::Hasher); + +impl HasherImpl for Sha256Hasher { + const NAME: &'static str = "Sha256"; + type Output = Hash; + + fn create_hasher() -> Self { + Sha256Hasher(Hasher::default()) + } + + fn hash(&mut self, val: &[u8]) { + self.0.hash(val); + } + + fn result(self) -> Self::Output { + self.0.result() + } + + fn get_base_cost(compute_budget: &ComputeBudget) -> u64 { + compute_budget.sha256_base_cost + } + fn get_byte_cost(compute_budget: &ComputeBudget) -> u64 { + compute_budget.sha256_byte_cost + } + fn get_max_slices(compute_budget: &ComputeBudget) -> u64 { + compute_budget.sha256_max_slices + } +} + +impl HasherImpl for Blake3Hasher { + const NAME: &'static str = "Blake3"; + type Output = blake3::Hash; + + fn create_hasher() -> Self { + Blake3Hasher(blake3::Hasher::default()) + } + + fn hash(&mut self, val: &[u8]) { + self.0.hash(val); + } + + fn result(self) -> Self::Output { + self.0.result() + } + + fn get_base_cost(compute_budget: &ComputeBudget) -> u64 { + compute_budget.sha256_base_cost + } + fn get_byte_cost(compute_budget: &ComputeBudget) -> u64 { + compute_budget.sha256_byte_cost + } + fn get_max_slices(compute_budget: &ComputeBudget) -> u64 { + compute_budget.sha256_max_slices + } +} + +impl HasherImpl for Keccak256Hasher { + const NAME: &'static str = "Keccak256"; + type Output = keccak::Hash; + + fn create_hasher() -> Self { + Keccak256Hasher(keccak::Hasher::default()) + } + + fn hash(&mut self, val: &[u8]) { + self.0.hash(val); + } + + fn result(self) -> Self::Output { + self.0.result() + } + + fn get_base_cost(compute_budget: &ComputeBudget) -> u64 { + compute_budget.sha256_base_cost + } + fn get_byte_cost(compute_budget: &ComputeBudget) -> u64 { + compute_budget.sha256_byte_cost + } + fn get_max_slices(compute_budget: &ComputeBudget) -> u64 { + compute_budget.sha256_max_slices + } +} + fn consume_compute_meter(invoke_context: &InvokeContext, amount: u64) -> Result<(), Error> { invoke_context.consume_checked(amount)?; Ok(()) @@ -220,10 +317,10 @@ pub fn create_program_runtime_environment_v1<'a>( )?; // Sha256 - result.register_function_hashed(*b"sol_sha256", SyscallSha256::call)?; + result.register_function_hashed(*b"sol_sha256", SyscallHash::call::)?; // Keccak256 - result.register_function_hashed(*b"sol_keccak256", SyscallKeccak256::call)?; + result.register_function_hashed(*b"sol_keccak256", SyscallHash::call::)?; // Secp256k1 Recover result.register_function_hashed(*b"sol_secp256k1_recover", SyscallSecp256k1Recover::call)?; @@ -233,7 +330,7 @@ pub fn create_program_runtime_environment_v1<'a>( result, blake3_syscall_enabled, *b"sol_blake3", - SyscallBlake3::call, + SyscallHash::call::, )?; // Elliptic Curve Operations @@ -519,6 +616,32 @@ macro_rules! declare_syscall { }; } +#[macro_export] +macro_rules! declare_syscallhash { + ($(#[$attr:meta])* $name:ident, $inner_call:item) => { + $(#[$attr])* + pub struct $name {} + impl $name { + $inner_call + pub fn call( + invoke_context: &mut InvokeContext, + arg_a: u64, + arg_b: u64, + arg_c: u64, + arg_d: u64, + arg_e: u64, + memory_mapping: &mut MemoryMapping, + result: &mut ProgramResult, + ) { + let converted_result: ProgramResult = Self::inner_call::( + invoke_context, arg_a, arg_b, arg_c, arg_d, arg_e, memory_mapping, + ).into(); + *result = converted_result; + } + } + }; +} + declare_syscall!( /// Abort syscall functions, called when the SBF program calls `abort()` /// LLVM will insert calls to `abort()` if it detects an untenable situation, @@ -750,136 +873,6 @@ declare_syscall!( } ); -declare_syscall!( - /// SHA256 - SyscallSha256, - fn inner_call( - invoke_context: &mut InvokeContext, - vals_addr: u64, - vals_len: u64, - result_addr: u64, - _arg4: u64, - _arg5: u64, - memory_mapping: &mut MemoryMapping, - ) -> Result { - let compute_budget = invoke_context.get_compute_budget(); - if compute_budget.sha256_max_slices < vals_len { - ic_msg!( - invoke_context, - "Sha256 hashing {} sequences in one syscall is over the limit {}", - vals_len, - compute_budget.sha256_max_slices, - ); - return Err(SyscallError::TooManySlices.into()); - } - - consume_compute_meter(invoke_context, compute_budget.sha256_base_cost)?; - - let hash_result = translate_slice_mut::( - memory_mapping, - result_addr, - HASH_BYTES as u64, - invoke_context.get_check_aligned(), - invoke_context.get_check_size(), - )?; - let mut hasher = Hasher::default(); - if vals_len > 0 { - let vals = translate_slice::<&[u8]>( - memory_mapping, - vals_addr, - vals_len, - invoke_context.get_check_aligned(), - invoke_context.get_check_size(), - )?; - for val in vals.iter() { - let bytes = translate_slice::( - memory_mapping, - val.as_ptr() as u64, - val.len() as u64, - invoke_context.get_check_aligned(), - invoke_context.get_check_size(), - )?; - let cost = compute_budget.mem_op_base_cost.max( - compute_budget.sha256_byte_cost.saturating_mul( - (val.len() as u64) - .checked_div(2) - .expect("div by non-zero literal"), - ), - ); - consume_compute_meter(invoke_context, cost)?; - hasher.hash(bytes); - } - } - hash_result.copy_from_slice(&hasher.result().to_bytes()); - Ok(0) - } -); - -declare_syscall!( - // Keccak256 - SyscallKeccak256, - fn inner_call( - invoke_context: &mut InvokeContext, - vals_addr: u64, - vals_len: u64, - result_addr: u64, - _arg4: u64, - _arg5: u64, - memory_mapping: &mut MemoryMapping, - ) -> Result { - let compute_budget = invoke_context.get_compute_budget(); - if compute_budget.sha256_max_slices < vals_len { - ic_msg!( - invoke_context, - "Keccak256 hashing {} sequences in one syscall is over the limit {}", - vals_len, - compute_budget.sha256_max_slices, - ); - return Err(SyscallError::TooManySlices.into()); - } - - consume_compute_meter(invoke_context, compute_budget.sha256_base_cost)?; - - let hash_result = translate_slice_mut::( - memory_mapping, - result_addr, - keccak::HASH_BYTES as u64, - invoke_context.get_check_aligned(), - invoke_context.get_check_size(), - )?; - let mut hasher = keccak::Hasher::default(); - if vals_len > 0 { - let vals = translate_slice::<&[u8]>( - memory_mapping, - vals_addr, - vals_len, - invoke_context.get_check_aligned(), - invoke_context.get_check_size(), - )?; - for val in vals.iter() { - let bytes = translate_slice::( - memory_mapping, - val.as_ptr() as u64, - val.len() as u64, - invoke_context.get_check_aligned(), - invoke_context.get_check_size(), - )?; - let cost = compute_budget.mem_op_base_cost.max( - compute_budget.sha256_byte_cost.saturating_mul( - (val.len() as u64) - .checked_div(2) - .expect("div by non-zero literal"), - ), - ); - consume_compute_meter(invoke_context, cost)?; - hasher.hash(bytes); - } - } - hash_result.copy_from_slice(&hasher.result().to_bytes()); - Ok(0) - } -); - declare_syscall!( /// secp256k1_recover SyscallSecp256k1Recover, @@ -1314,71 +1307,6 @@ declare_syscall!( } ); -declare_syscall!( - // Blake3 - SyscallBlake3, - fn inner_call( - invoke_context: &mut InvokeContext, - vals_addr: u64, - vals_len: u64, - result_addr: u64, - _arg4: u64, - _arg5: u64, - memory_mapping: &mut MemoryMapping, - ) -> Result { - let compute_budget = invoke_context.get_compute_budget(); - if compute_budget.sha256_max_slices < vals_len { - ic_msg!( - invoke_context, - "Blake3 hashing {} sequences in one syscall is over the limit {}", - vals_len, - compute_budget.sha256_max_slices, - ); - return Err(SyscallError::TooManySlices.into()); - } - - consume_compute_meter(invoke_context, compute_budget.sha256_base_cost)?; - - let hash_result = translate_slice_mut::( - memory_mapping, - result_addr, - blake3::HASH_BYTES as u64, - invoke_context.get_check_aligned(), - invoke_context.get_check_size(), - )?; - let mut hasher = blake3::Hasher::default(); - if vals_len > 0 { - let vals = translate_slice::<&[u8]>( - memory_mapping, - vals_addr, - vals_len, - invoke_context.get_check_aligned(), - invoke_context.get_check_size(), - )?; - for val in vals.iter() { - let bytes = translate_slice::( - memory_mapping, - val.as_ptr() as u64, - val.len() as u64, - invoke_context.get_check_aligned(), - invoke_context.get_check_size(), - )?; - let cost = compute_budget.mem_op_base_cost.max( - compute_budget.sha256_byte_cost.saturating_mul( - (val.len() as u64) - .checked_div(2) - .expect("div by non-zero literal"), - ), - ); - consume_compute_meter(invoke_context, cost)?; - hasher.hash(bytes); - } - } - hash_result.copy_from_slice(&hasher.result().to_bytes()); - Ok(0) - } -); - declare_syscall!( /// Set return data SyscallSetReturnData, @@ -2020,6 +1948,75 @@ declare_syscall!( } ); +declare_syscallhash!( + // Generic Hashing Syscall + SyscallHash, + fn inner_call( + invoke_context: &mut InvokeContext, + vals_addr: u64, + vals_len: u64, + result_addr: u64, + _arg4: u64, + _arg5: u64, + memory_mapping: &mut MemoryMapping, + ) -> Result { + let compute_budget = invoke_context.get_compute_budget(); + let hash_base_cost = H::get_base_cost(compute_budget); + let hash_byte_cost = H::get_byte_cost(compute_budget); + let hash_max_slices = H::get_max_slices(compute_budget); + if hash_max_slices < vals_len { + ic_msg!( + invoke_context, + "{} Hashing {} sequences in one syscall is over the limit {}", + H::NAME, + vals_len, + hash_max_slices, + ); + return Err(SyscallError::TooManySlices.into()); + } + + consume_compute_meter(invoke_context, hash_base_cost)?; + + let hash_result = translate_slice_mut::( + memory_mapping, + result_addr, + std::mem::size_of::() as u64, + invoke_context.get_check_aligned(), + invoke_context.get_check_size(), + )?; + let mut hasher = H::create_hasher(); + if vals_len > 0 { + let vals = translate_slice::<&[u8]>( + memory_mapping, + vals_addr, + vals_len, + invoke_context.get_check_aligned(), + invoke_context.get_check_size(), + )?; + for val in vals.iter() { + let bytes = translate_slice::( + memory_mapping, + val.as_ptr() as u64, + val.len() as u64, + invoke_context.get_check_aligned(), + invoke_context.get_check_size(), + )?; + let cost = compute_budget.mem_op_base_cost.max( + hash_byte_cost.saturating_mul( + (val.len() as u64) + .checked_div(2) + .expect("div by non-zero literal"), + ), + ); + consume_compute_meter(invoke_context, cost)?; + hasher.hash(bytes); + } + } + hash_result.copy_from_slice(hasher.result().as_ref()); + Ok(0) + } +); + #[cfg(test)] #[allow(clippy::arithmetic_side_effects)] #[allow(clippy::indexing_slicing)] @@ -2042,7 +2039,7 @@ mod tests { account::{create_account_shared_data_for_test, AccountSharedData}, bpf_loader, fee_calculator::FeeCalculator, - hash::hashv, + hash::{hashv, HASH_BYTES}, instruction::Instruction, program::check_type_assumptions, stable_layout::stable_instruction::StableInstruction, @@ -2705,7 +2702,7 @@ mod tests { ); let mut result = ProgramResult::Ok(0); - SyscallSha256::call( + SyscallHash::call::( &mut invoke_context, ro_va, ro_len, @@ -2720,7 +2717,7 @@ mod tests { let hash_local = hashv(&[bytes1.as_ref(), bytes2.as_ref()]).to_bytes(); assert_eq!(hash_result, hash_local); let mut result = ProgramResult::Ok(0); - SyscallSha256::call( + SyscallHash::call::( &mut invoke_context, ro_va - 1, // AccessViolation ro_len, @@ -2732,7 +2729,7 @@ mod tests { ); assert_access_violation!(result, ro_va - 1, 32); let mut result = ProgramResult::Ok(0); - SyscallSha256::call( + SyscallHash::call::( &mut invoke_context, ro_va, ro_len + 1, // AccessViolation @@ -2744,7 +2741,7 @@ mod tests { ); assert_access_violation!(result, ro_va, 48); let mut result = ProgramResult::Ok(0); - SyscallSha256::call( + SyscallHash::call::( &mut invoke_context, ro_va, ro_len, @@ -2756,7 +2753,7 @@ mod tests { ); assert_access_violation!(result, rw_va - 1, HASH_BYTES as u64); let mut result = ProgramResult::Ok(0); - SyscallSha256::call( + SyscallHash::call::( &mut invoke_context, ro_va, ro_len,