Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Add bucket cache in MemoryManager #260

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions canbench_results.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ benches:
scopes: {}
btreemap_get_blob_512_1024_v2_mem_manager:
total:
instructions: 2624136090
instructions: 2517895062
heap_increase: 0
stable_memory_increase: 0
scopes: {}
Expand Down Expand Up @@ -139,7 +139,7 @@ benches:
scopes: {}
btreemap_get_u64_u64_v2_mem_manager:
total:
instructions: 421366446
instructions: 337472533
heap_increase: 0
stable_memory_increase: 0
scopes: {}
Expand Down Expand Up @@ -223,7 +223,7 @@ benches:
scopes: {}
btreemap_insert_blob_1024_512_v2_mem_manager:
total:
instructions: 5402984503
instructions: 5243334303
heap_increase: 0
stable_memory_increase: 256
scopes: {}
Expand Down Expand Up @@ -379,7 +379,7 @@ benches:
scopes: {}
btreemap_insert_u64_u64_mem_manager:
total:
instructions: 680292499
instructions: 553691634
heap_increase: 0
stable_memory_increase: 0
scopes: {}
Expand Down Expand Up @@ -631,7 +631,7 @@ benches:
scopes: {}
memory_manager_overhead:
total:
instructions: 1182002161
instructions: 1181967369
heap_increase: 0
stable_memory_increase: 8320
scopes: {}
Expand Down Expand Up @@ -661,7 +661,7 @@ benches:
scopes: {}
vec_get_blob_4_mem_manager:
total:
instructions: 9333723
instructions: 7238723
heap_increase: 0
stable_memory_increase: 0
scopes: {}
Expand All @@ -673,7 +673,7 @@ benches:
scopes: {}
vec_get_blob_64_mem_manager:
total:
instructions: 17664902
instructions: 15339702
heap_increase: 0
stable_memory_increase: 0
scopes: {}
Expand Down
168 changes: 160 additions & 8 deletions src/memory_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use crate::{
types::{Address, Bytes},
write, write_struct, Memory, WASM_PAGE_SIZE,
};
use std::cell::RefCell;
use std::cell::{Cell, RefCell};
use std::rc::Rc;

const MAGIC: &[u8; 3] = b"MGR";
Expand Down Expand Up @@ -151,6 +151,7 @@ impl<M: Memory> MemoryManager<M> {
VirtualMemory {
id,
memory_manager: self.inner.clone(),
cache: BucketCache::new(),
}
}

Expand Down Expand Up @@ -193,6 +194,7 @@ impl Header {
pub struct VirtualMemory<M: Memory> {
id: MemoryId,
memory_manager: Rc<RefCell<MemoryManagerInner<M>>>,
cache: BucketCache,
}

impl<M: Memory> Memory for VirtualMemory<M> {
Expand All @@ -205,17 +207,21 @@ impl<M: Memory> Memory for VirtualMemory<M> {
}

fn read(&self, offset: u64, dst: &mut [u8]) {
self.memory_manager.borrow().read(self.id, offset, dst)
self.memory_manager
.borrow()
.read(self.id, offset, dst, &self.cache)
}

unsafe fn read_unsafe(&self, offset: u64, dst: *mut u8, count: usize) {
self.memory_manager
.borrow()
.read_unsafe(self.id, offset, dst, count)
.read_unsafe(self.id, offset, dst, count, &self.cache)
}

fn write(&self, offset: u64, src: &[u8]) {
self.memory_manager.borrow().write(self.id, offset, src)
self.memory_manager
.borrow()
.write(self.id, offset, src, &self.cache)
}
}

Expand Down Expand Up @@ -373,7 +379,15 @@ impl<M: Memory> MemoryManagerInner<M> {
old_size as i64
}

fn write(&self, id: MemoryId, offset: u64, src: &[u8]) {
fn write(&self, id: MemoryId, offset: u64, src: &[u8], bucket_cache: &BucketCache) {
if let Some(real_address) = bucket_cache.get(VirtualSegment {
address: offset.into(),
length: src.len().into(),
}) {
self.memory.write(real_address.get(), src);
return;
}

if (offset + src.len() as u64) > self.memory_size(id) * WASM_PAGE_SIZE {
panic!("{id:?}: write out of bounds");
}
Expand All @@ -385,6 +399,7 @@ impl<M: Memory> MemoryManagerInner<M> {
address: offset.into(),
length: src.len().into(),
},
bucket_cache,
|RealSegment { address, length }| {
self.memory.write(
address.get(),
Expand All @@ -397,17 +412,33 @@ impl<M: Memory> MemoryManagerInner<M> {
}

#[inline]
fn read(&self, id: MemoryId, offset: u64, dst: &mut [u8]) {
fn read(&self, id: MemoryId, offset: u64, dst: &mut [u8], bucket_cache: &BucketCache) {
// SAFETY: this is trivially safe because dst has dst.len() space.
unsafe { self.read_unsafe(id, offset, dst.as_mut_ptr(), dst.len()) }
unsafe { self.read_unsafe(id, offset, dst.as_mut_ptr(), dst.len(), bucket_cache) }
}

/// # Safety
///
/// Callers must guarantee that
/// * it is valid to write `count` number of bytes starting from `dst`,
/// * `dst..dst + count` does not overlap with `self`.
unsafe fn read_unsafe(&self, id: MemoryId, offset: u64, dst: *mut u8, count: usize) {
unsafe fn read_unsafe(
&self,
id: MemoryId,
offset: u64,
dst: *mut u8,
count: usize,
bucket_cache: &BucketCache,
) {
// First try to find the virtual segment in the cache.
if let Some(real_address) = bucket_cache.get(VirtualSegment {
address: offset.into(),
length: count.into(),
}) {
self.memory.read_unsafe(real_address.get(), dst, count);
return;
}

if (offset + count as u64) > self.memory_size(id) * WASM_PAGE_SIZE {
panic!("{id:?}: read out of bounds");
}
Expand All @@ -419,6 +450,7 @@ impl<M: Memory> MemoryManagerInner<M> {
address: offset.into(),
length: count.into(),
},
bucket_cache,
|RealSegment { address, length }| {
self.memory.read_unsafe(
address.get(),
Expand Down Expand Up @@ -465,6 +497,7 @@ impl<M: Memory> MemoryManagerInner<M> {
&self,
MemoryId(id): MemoryId,
virtual_segment: VirtualSegment,
bucket_cache: &BucketCache,
mut func: impl FnMut(RealSegment),
) {
// Get the buckets allocated to the given memory id.
Expand All @@ -482,8 +515,19 @@ impl<M: Memory> MemoryManagerInner<M> {
while length > 0 {
let bucket_address =
self.bucket_address(buckets.get(bucket_idx).expect("bucket idx out of bounds"));

let bucket_start = bucket_idx as u64 * bucket_size_in_bytes;
let segment_len = (bucket_size_in_bytes - start_offset_in_bucket).min(length);

// Cache this bucket.
bucket_cache.store(
VirtualSegment {
address: bucket_start.into(),
length: self.bucket_size_in_bytes(),
},
bucket_address,
);

func(RealSegment {
address: bucket_address + start_offset_in_bucket.into(),
length: segment_len.into(),
Expand Down Expand Up @@ -516,11 +560,18 @@ impl<M: Memory> MemoryManagerInner<M> {
}
}

#[derive(Copy, Clone)]
struct VirtualSegment {
address: Address,
length: Bytes,
}

impl VirtualSegment {
fn contains_segment(&self, other: &VirtualSegment) -> bool {
self.address <= other.address && other.address + other.length <= self.address + self.length
}
}

struct RealSegment {
address: Address,
length: Bytes,
Expand All @@ -547,6 +598,49 @@ fn bucket_allocations_address(id: BucketId) -> Address {
Address::from(0) + Header::size() + Bytes::from(id.0)
}

/// Cache which stores the last touched bucket and the corresponding real address.
///
/// If a segment from this bucket is accessed, we can return the real address faster.
#[derive(Clone)]
struct BucketCache {
bucket: Cell<VirtualSegment>,
/// The real address that corresponds to bucket.address
real_address: Cell<Address>,
}

impl BucketCache {
#[inline]
fn new() -> Self {
BucketCache {
bucket: Cell::new(VirtualSegment {
address: Address::from(0),
length: Bytes::new(0),
}),
real_address: Cell::new(Address::from(0)),
}
}
}

impl BucketCache {
/// Returns the real address corresponding to `virtual_segment.address` if `virtual_segment`
/// is fully contained within the cached bucket, otherwise `None`.
#[inline]
fn get(&self, virtual_segment: VirtualSegment) -> Option<Address> {
let cached_bucket = self.bucket.get();

cached_bucket
.contains_segment(&virtual_segment)
.then(|| self.real_address.get() + (virtual_segment.address - cached_bucket.address))
}

/// Stores the mapping of a bucket to a real address.
#[inline]
fn store(&self, bucket: VirtualSegment, real_address: Address) {
self.bucket.set(bucket);
self.real_address.set(real_address);
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -950,4 +1044,62 @@ mod test {
let expected_read = include_bytes!("memory_manager/stability_read.golden");
assert!(expected_read.as_slice() == read.as_slice());
}

#[test]
fn bucket_cache() {
let bucket_cache = BucketCache::new();

// No match, nothing has been stored.
assert_eq!(
bucket_cache.get(VirtualSegment {
address: Address::from(0),
length: Bytes::from(1u64)
}),
None
);

bucket_cache.store(
VirtualSegment {
address: Address::from(0),
length: Bytes::from(335u64),
},
Address::from(983),
);

// Match at the beginning
assert_eq!(
bucket_cache.get(VirtualSegment {
address: Address::from(1),
length: Bytes::from(2u64)
}),
Some(Address::from(984))
);

// Match at the end
assert_eq!(
bucket_cache.get(VirtualSegment {
address: Address::from(334),
length: Bytes::from(1u64)
}),
Some(Address::from(1317))
);

// Match entire segment
assert_eq!(
bucket_cache.get(VirtualSegment {
address: Address::from(0),
length: Bytes::from(335u64),
}),
Some(Address::from(983))
);

// No match - outside cached segment
assert_eq!(
bucket_cache.get(VirtualSegment {
address: Address::from(1),
length: Bytes::from(335u64)
}),
None
);
}
}
8 changes: 8 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ impl Add<Bytes> for Address {
}
}

impl Sub<Address> for Address {
type Output = Bytes;

fn sub(self, address: Address) -> Self::Output {
Bytes(self.0 - address.0)
}
}

impl Sub<Bytes> for Address {
type Output = Self;

Expand Down
Loading