diff --git a/crates/interpreter/src/lib.rs b/crates/interpreter/src/lib.rs index 5196f999..9676741e 100644 --- a/crates/interpreter/src/lib.rs +++ b/crates/interpreter/src/lib.rs @@ -157,3 +157,4 @@ pub use syntax::{ FuncType, GlobalType, ImportDesc, Limits, Mut, RefType, ResultType, TableType, ValType, }; pub use valid::prepare; +pub use valid::merge; diff --git a/crates/interpreter/src/side_table.rs b/crates/interpreter/src/side_table.rs index 78fee146..217cab8a 100644 --- a/crates/interpreter/src/side_table.rs +++ b/crates/interpreter/src/side_table.rs @@ -15,44 +15,49 @@ use alloc::vec::Vec; use core::ops::Range; +use bytemuck::{Pod, Zeroable}; + use crate::error::*; use crate::module::Parser; +#[derive(Default)] +pub struct BranchTableView<'m> { + pub metadata: Metadata<'m>, + pub branch_idx: usize, +} + +// TODO(dev/fast-interp): Change [u16] to [u8] to not rely on alignment. pub struct SideTableView<'m> { pub func_idx: usize, pub indices: &'m [u16], // including 0 and the length of metadata_array pub metadata: &'m [u16], - pub branch_table_view: Metadata<'m>, + pub branch_table_view: BranchTableView<'m>, } impl<'m> SideTableView<'m> { - pub fn new(parser: &mut crate::valid::Parser<'m>) -> Result { + pub fn new(binary: &'m [u8]) -> Result { + let num_funcs = + bytemuck::pod_read_unaligned::(bytemuck::cast_slice(&binary[0 .. 2])) as usize; + let indices_end = 2 + (num_funcs + 1) * 2; Ok(SideTableView { func_idx: 0, - indices: parse_side_table_field(parser)?, - metadata: parse_side_table_field(parser)?, + indices: bytemuck::cast_slice::<_, u16>(binary.get(2 .. indices_end).unwrap()), + metadata: bytemuck::cast_slice::<_, u16>(binary.get(indices_end ..).unwrap()), branch_table_view: Default::default(), }) } - pub fn metadata(&self, func_idx: usize) -> Metadata<'m> { + pub fn branch_table_view(&mut self, func_idx: usize) -> BranchTableView<'m> { + BranchTableView { metadata: self.metadata(func_idx), ..Default::default() } + } + + fn metadata(&self, func_idx: usize) -> Metadata<'m> { Metadata( &self.metadata[self.indices[func_idx] as usize .. self.indices[func_idx + 1] as usize], ) } } -fn parse_u16(data: &[u8]) -> u16 { - bytemuck::pod_read_unaligned::(bytemuck::cast_slice(&data[0 .. 2])) -} - -fn parse_side_table_field<'m>(parser: &mut crate::valid::Parser<'m>) -> Result<&'m [u16], Error> { - let len = parse_u16(parser.save()) as usize; - let parser = parser.split_at(len)?; - let bytes = parser.save().get(0 .. len * 2).unwrap(); - Ok(bytemuck::cast_slice::<_, u16>(bytes)) -} - #[derive(Default, Copy, Clone)] pub struct Metadata<'m>(&'m [u16]); @@ -86,10 +91,16 @@ pub struct MetadataEntry { pub branch_table: Vec, } -#[derive(Copy, Clone, Debug, bytemuck::AnyBitPattern)] +#[derive(Copy, Clone, Debug, Pod, Zeroable)] #[repr(transparent)] pub struct BranchTableEntry([u16; 3]); +impl BranchTableEntry { + pub fn as_bytes(&self) -> &[u8; size_of::()] { + bytemuck::cast_ref(self) + } +} + pub struct BranchTableEntryView { /// The amount to adjust the instruction pointer by if the branch is taken. pub delta_ip: i32, diff --git a/crates/interpreter/src/valid.rs b/crates/interpreter/src/valid.rs index dab87ede..6d172d41 100644 --- a/crates/interpreter/src/valid.rs +++ b/crates/interpreter/src/valid.rs @@ -26,6 +26,23 @@ use crate::toctou::*; use crate::util::*; use crate::*; +pub fn merge(binary: &[u8], side_table: Vec) -> Result, Error> { + let mut wasm = vec![]; + wasm.extend_from_slice(&binary[0 .. 8]); + wasm.push(0); + wasm.extend_from_slice(&side_table.len().to_le_bytes()); + for entry in side_table { + wasm.extend_from_slice(&entry.type_idx.to_le_bytes()); + wasm.extend_from_slice(&entry.parser_range.start.to_le_bytes()); + wasm.extend_from_slice(&entry.parser_range.end.to_le_bytes()); + for branch in entry.branch_table { + wasm.extend_from_slice(branch.as_bytes()); + } + } + wasm.extend_from_slice(&binary[8 ..]); + Ok(wasm) +} + /// Checks whether a WASM module in binary format is valid, and returns the side table. pub fn prepare(binary: &[u8]) -> Result, Error> { validate::(binary) @@ -138,7 +155,7 @@ impl ValidMode for Verify { /// Contains at most one _target_ branch. Source branches are eagerly patched to /// their target branch using the branch table. type Branches<'m> = Option>; - type BranchTable<'m> = Metadata<'m>; + type BranchTable<'m> = BranchTableView<'m>; type SideTable<'m> = SideTableView<'m>; type Result = (); @@ -146,16 +163,16 @@ impl ValidMode for Verify { check(parser.parse_section_id()? == SectionId::Custom)?; let mut section = parser.split_section()?; check(section.parse_name()? == "wasefire-sidetable")?; - SideTableView::new(parser) + SideTableView::new(parser.save()) } fn next_branch_table<'a, 'm>( side_table: &'a mut Self::SideTable<'m>, type_idx: usize, parser_range: Range, ) -> Result<&'a mut Self::BranchTable<'m>, Error> { - side_table.branch_table_view = side_table.metadata(side_table.func_idx); + side_table.branch_table_view = side_table.branch_table_view(side_table.func_idx); side_table.func_idx += 1; - check(side_table.branch_table_view.type_idx() == type_idx)?; - check(side_table.branch_table_view.parser_range() == parser_range)?; + check(side_table.branch_table_view.metadata.type_idx() == type_idx)?; + check(side_table.branch_table_view.metadata.parser_range() == parser_range)?; Ok(&mut side_table.branch_table_view) } @@ -170,7 +187,7 @@ impl<'m> BranchesApi<'m> for Option> { } } -impl<'m> BranchTableApi<'m> for Metadata<'m> { +impl<'m> BranchTableApi<'m> for BranchTableView<'m> { fn stitch_branch( &mut self, source: SideTableBranch<'m>, target: SideTableBranch<'m>, ) -> CheckResult { @@ -178,7 +195,8 @@ impl<'m> BranchTableApi<'m> for Metadata<'m> { } fn patch_branch(&self, mut source: SideTableBranch<'m>) -> Result, Error> { - let entry = self.branch_table()[source.branch_table].view(); + source.branch_table = self.branch_idx; + let entry = self.metadata.branch_table()[source.branch_table].view(); offset_front(source.parser, entry.delta_ip as isize); source.branch_table += entry.delta_stp as usize; source.result = entry.val_cnt as usize; @@ -186,10 +204,12 @@ impl<'m> BranchTableApi<'m> for Metadata<'m> { Ok(source) } - fn allocate_branch(&mut self) {} + fn allocate_branch(&mut self) { + self.branch_idx += 1; + } fn next_index(&self) -> usize { - 0 + self.branch_idx } } diff --git a/crates/interpreter/tests/spec.rs b/crates/interpreter/tests/spec.rs index de3b9bc2..b5ab57bc 100644 --- a/crates/interpreter/tests/spec.rs +++ b/crates/interpreter/tests/spec.rs @@ -164,6 +164,9 @@ impl<'m> Env<'m> { } fn maybe_instantiate(&mut self, name: &str, wasm: &[u8]) -> Result { + let side_table = prepare(wasm)?; + let merged = merge(wasm, side_table)?; + let wasm = merged.as_slice(); let module = self.alloc(wasm.len()); module.copy_from_slice(wasm); let module = Module::new(module)?;