From 246bf21d257bb0b32ccef3b02ec92ee0a2005162 Mon Sep 17 00:00:00 2001 From: Viren6 <94880762+Viren6@users.noreply.github.com> Date: Wed, 8 Jan 2025 08:44:29 +0000 Subject: [PATCH] Mmap embed binary (#88) Writes decompressed net files to a folder called Monty in the temp directory, so they can be mmap-ed (shared between separate Monty instances). This also eliminates the decompression time after the first load. Old nets get automatically cleaned up so there is only ever one copy of the policy/value net in the folder at any point. No functional change. Bench: 1733801 --- Cargo.lock | 1 + Cargo.toml | 2 + src/main.rs | 197 ++++++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 165 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e04faa4..dea500d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -190,6 +190,7 @@ version = "1.0.0" dependencies = [ "chrono", "memmap2", + "once_cell", "sha2", "zstd", ] diff --git a/Cargo.toml b/Cargo.toml index 4f02618..07a470a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,8 @@ resolver = "2" [dependencies] memmap2 = "0.9.5" zstd = "0.13.2" +once_cell = "1.20.2" +sha2 = "0.10.8" [build-dependencies] sha2 = "0.10.8" diff --git a/src/main.rs b/src/main.rs index 8da7a02..44aad80 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,72 +8,199 @@ fn main() { #[cfg(feature = "embed")] mod net { + use memmap2::Mmap; use monty::{uci, ChessState, MctsParams, PolicyNetwork, ValueNetwork}; - use std::io::Cursor; - use std::mem::MaybeUninit; - use std::sync::LazyLock; + use once_cell::sync::Lazy; + use sha2::{Digest, Sha256}; + use std::fs::{self, File}; + use std::io::{Cursor, Write}; + use std::path::{Path, PathBuf}; use zstd::stream::decode_all; // Embed compressed byte arrays static COMPRESSED_VALUE: &[u8] = include_bytes!("../value.network.zst"); static COMPRESSED_POLICY: &[u8] = include_bytes!("../policy.network.zst"); - /// Helper function to safely decompress and initialize a Boxed structure. 
- fn decompress_into_boxed<T>(data: &[u8]) -> Box<T> { - // Ensure the decompressed data size matches the target structure size + /// Compute the first 12 hexadecimal characters of the SHA-256 hash of the data. + fn compute_short_sha(data: &[u8]) -> String { + let mut hasher = Sha256::new(); + hasher.update(data); + let result = hasher.finalize(); + // Convert the hash to a hexadecimal string and take the first 12 characters + format!("{:x}", result)[..12].to_string() + } + + /// Get the full path in the OS's temporary directory for the given data. + /// The filename format is "nn-<hash_prefix>.network" + fn get_network_path(data: &[u8]) -> PathBuf { + let mut temp_dir = std::env::temp_dir(); + temp_dir.push("Monty"); + fs::create_dir_all(&temp_dir) + .expect("Failed to create 'Monty' directory in the temp folder"); + let hash_prefix = compute_short_sha(data); + temp_dir.join(format!("nn-{}.network", hash_prefix)) + } + + /// Extract the first 12 characters of the SHA-256 prefix from the filename. + /// Assumes the filename format is "nn-<hash_prefix>.network" + fn extract_sha_prefix(file_name: &str) -> String { + // Ensure the filename starts with "nn-" and ends with ".network" + if file_name.starts_with("nn-") && file_name.ends_with(".network") { + // Extract the hash prefix + let start = 3; // Length of "nn-" + let end = file_name.len() - ".network".len(); + let hash_prefix = &file_name[start..end]; + if hash_prefix.len() == 12 { + return hash_prefix.to_string(); + } + } + panic!("Invalid file name format: {}", file_name); + } + + /// Cleanup old decompressed network files that do not match the current hash prefixes. + fn cleanup_old_files(current_hash_prefixes: &[&str]) -> std::io::Result<()> { + let mut temp_dir = std::env::temp_dir(); + temp_dir.push("Monty"); + fs::create_dir_all(&temp_dir) + .expect("Failed to create 'Monty' directory in the temp folder"); + for entry in fs::read_dir(&temp_dir)? 
{ + let entry = entry?; + let path = entry.path(); + if path.is_file() { + if let Some(fname) = path.file_name().and_then(|s| s.to_str()) { + // Check if the file matches the naming pattern + if fname.starts_with("nn-") && fname.ends_with(".network") { + // Extract the hash prefix from the filename + let extracted_hash = extract_sha_prefix(fname); + // If the extracted hash is not in the current hash prefixes, remove the file + if !current_hash_prefixes.contains(&extracted_hash.as_str()) { + // Attempt to remove the file; ignore errors for now + let _ = fs::remove_file(&path); + } + } + } + } + } + Ok(()) + } + + /// Decompress the data and write it to the specified file path. + /// If the file already exists and its hash prefix matches, do nothing. + /// Otherwise, decompress and write the file. + fn decompress_and_write( + _network_type: &str, + compressed_data: &[u8], + file_path: &Path, + ) -> std::io::Result<()> { + // Compute expected hash prefix + let expected_hash_prefix = compute_short_sha(compressed_data); + + // Note: Removed cleanup_old_files from here to prevent deleting other network files + + // Check if a file with the expected hash prefix already exists + if file_path.exists() { + // Extract the existing file's hash prefix + let existing_file_name = file_path.file_name().unwrap().to_str().unwrap(); + let existing_hash_prefix = extract_sha_prefix(existing_file_name); + + if existing_hash_prefix == expected_hash_prefix { + // Hash prefix matches; no need to overwrite + return Ok(()); + } else { + // Hash prefix mismatch; remove the old file + fs::remove_file(file_path)?; + } + } + + // Decompress the data + let decompressed_data = decode_all(Cursor::new(compressed_data)).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::Other, + format!("Decompression failed: {}", e), + ) + })?; + + // Write the decompressed data to a temporary file first + let temp_file_path = file_path.with_extension("tmp"); + { + let mut temp_file = 
File::create(&temp_file_path)?; + temp_file.write_all(&decompressed_data)?; + } + + // Atomically rename the temporary file to the target path + fs::rename(&temp_file_path, file_path)?; + + Ok(()) + } + + /// Unsafe helper function to interpret the memory-mapped data as the target structure. + /// Ensure that the data layout matches exactly. + unsafe fn read_into_struct_unchecked<T>(mmap: &Mmap) -> &T { assert_eq!( - data.len(), + mmap.len(), std::mem::size_of::<T>(), - "Decompressed data size does not match the target structure size." + "Mapped file size does not match the target structure size." ); + &*(mmap.as_ptr() as *const T) + } - // Create an uninitialized Box - let mut boxed = Box::new(MaybeUninit::<T>::uninit()); + // Initialize and memory-map both policy and value networks together + static NETWORKS: Lazy<(Mmap, Mmap)> = Lazy::new(|| { + // Compute hash prefixes based on compressed data + let policy_hash_prefix = compute_short_sha(COMPRESSED_POLICY); + let value_hash_prefix = compute_short_sha(COMPRESSED_VALUE); - unsafe { - // Copy the decompressed data into the Box's memory - std::ptr::copy_nonoverlapping(data.as_ptr(), boxed.as_mut_ptr() as *mut u8, data.len()); + // Current hash prefixes + let current_hash_prefixes = [policy_hash_prefix.as_str(), value_hash_prefix.as_str()]; - // Assume the Box is now initialized - boxed.assume_init() - } - } + // Cleanup old network files not matching current hash prefixes + cleanup_old_files(&current_hash_prefixes).expect("Failed to cleanup old network files"); - // Lazy initialization for VALUE using LazyLock to ensure heap allocation - static VALUE: LazyLock<Box<ValueNetwork>> = LazyLock::new(|| { - // Decompress the value network - let decompressed_data = - decode_all(Cursor::new(COMPRESSED_VALUE)).expect("Failed to decompress value network"); + // Get file paths in the temporary directory + let policy_path = get_network_path(COMPRESSED_POLICY); + let value_path = get_network_path(COMPRESSED_VALUE); - // Initialize the Box with the decompressed 
data - decompress_into_boxed::<ValueNetwork>(&decompressed_data) - }); + // Decompress and write network files + decompress_and_write("policy", COMPRESSED_POLICY, &policy_path) + .expect("Failed to decompress/write policy network"); - // Lazy initialization for POLICY using LazyLock to ensure heap allocation - static POLICY: LazyLock<Box<PolicyNetwork>> = LazyLock::new(|| { - // Decompress the policy network - let decompressed_data = decode_all(Cursor::new(COMPRESSED_POLICY)) - .expect("Failed to decompress policy network"); + decompress_and_write("value", COMPRESSED_VALUE, &value_path) + .expect("Failed to decompress/write value network"); - // Initialize the Box with the decompressed data - decompress_into_boxed::<PolicyNetwork>(&decompressed_data) + // Memory-map the policy network file + let policy_file = + File::open(&policy_path).expect("Failed to open policy network file for mmap"); + let policy_mmap = + unsafe { Mmap::map(&policy_file).expect("Failed to memory-map policy network file") }; + + // Memory-map the value network file + let value_file = + File::open(&value_path).expect("Failed to open value network file for mmap"); + let value_mmap = + unsafe { Mmap::map(&value_file).expect("Failed to memory-map value network file") }; + + (policy_mmap, value_mmap) }); pub fn run() { let mut args = std::env::args(); let arg1 = args.nth(1); + // Interpret the memory-mapped data as network structures + let policy: &PolicyNetwork = unsafe { read_into_struct_unchecked(&NETWORKS.0) }; + let value: &ValueNetwork = unsafe { read_into_struct_unchecked(&NETWORKS.1) }; + if let Some("bench") = arg1.as_deref() { uci::bench( ChessState::BENCH_DEPTH, - &*POLICY, // Dereference the Box to get &PolicyNetwork - &*VALUE, // Dereference the Box to get &ValueNetwork + policy, + value, &MctsParams::default(), ); return; } - uci::run(&*POLICY, &*VALUE); + uci::run(policy, value); } }