diff --git a/crates/polars-io/src/ipc/ipc_file.rs b/crates/polars-io/src/ipc/ipc_file.rs index 804910627dbc..198e75ab3afe 100644 --- a/crates/polars-io/src/ipc/ipc_file.rs +++ b/crates/polars-io/src/ipc/ipc_file.rs @@ -33,6 +33,7 @@ //! assert!(df.equals(&df_read)); //! ``` use std::io::{Read, Seek}; +use std::path::PathBuf; use arrow::datatypes::ArrowSchemaRef; use arrow::io::ipc::read; @@ -79,7 +80,8 @@ pub struct IpcReader { pub(super) projection: Option>, pub(crate) columns: Option>, pub(super) row_index: Option, - memory_map: bool, + // Stores the as key semaphore to make sure we don't write to the memory mapped file. + pub(super) memory_map: Option, metadata: Option, schema: Option, } @@ -138,8 +140,9 @@ impl IpcReader { } /// Set if the file is to be memory_mapped. Only works with uncompressed files. - pub fn memory_mapped(mut self, toggle: bool) -> Self { - self.memory_map = toggle; + /// The file name must be passed to register the memory mapped file. + pub fn memory_mapped(mut self, path_buf: Option) -> Self { + self.memory_map = path_buf; self } @@ -150,7 +153,7 @@ impl IpcReader { predicate: Option>, verbose: bool, ) -> PolarsResult { - if self.memory_map && self.reader.to_file().is_some() { + if self.memory_map.is_some() && self.reader.to_file().is_some() { if verbose { eprintln!("memory map ipc file") } @@ -199,7 +202,7 @@ impl SerReader for IpcReader { columns: None, projection: None, row_index: None, - memory_map: true, + memory_map: None, metadata: None, schema: None, } @@ -211,7 +214,7 @@ impl SerReader for IpcReader { } fn finish(mut self) -> PolarsResult { - if self.memory_map && self.reader.to_file().is_some() { + if self.memory_map.is_some() && self.reader.to_file().is_some() { match self.finish_memmapped(None) { Ok(df) => return Ok(df), Err(err) => check_mmap_err(err)?, diff --git a/crates/polars-io/src/ipc/mmap.rs b/crates/polars-io/src/ipc/mmap.rs index a8282894787e..dd164baed88e 100644 --- a/crates/polars-io/src/ipc/mmap.rs +++ b/crates/polars-io/src/ipc/mmap.rs @@ -2,11 +2,10 @@ use arrow::io::ipc::read; use arrow::io::ipc::read::{Dictionaries, FileMetadata}; use arrow::mmap::{mmap_dictionaries_unchecked, mmap_unchecked}; use arrow::record_batch::RecordBatch; -use memmap::Mmap; use polars_core::prelude::*; use super::ipc_file::IpcReader; -use crate::mmap::MmapBytesReader; +use crate::mmap::{MMapSemaphore, MmapBytesReader}; use crate::predicates::PhysicalIoExpr; use crate::shared::{finish_reader, ArrowReader}; use crate::utils::{apply_projection, columns_to_projection}; @@ -19,7 +18,10 @@ impl IpcReader { match self.reader.to_file() { Some(file) => { let mmap = unsafe { memmap::Mmap::map(file).unwrap() }; - let metadata = read::read_file_metadata(&mut std::io::Cursor::new(mmap.as_ref()))?; + let mmap_key = self.memory_map.take().unwrap(); + let semaphore = MMapSemaphore::new(mmap_key, mmap); + let metadata = + read::read_file_metadata(&mut std::io::Cursor::new(semaphore.as_ref()))?; if let Some(columns) = &self.columns { let schema = &metadata.schema; @@ -33,7 +35,7 @@ impl IpcReader { metadata.schema.clone() }; - let reader = MMapChunkIter::new(mmap, metadata, &self.projection)?; + let reader = MMapChunkIter::new(Arc::new(semaphore), metadata, &self.projection)?; finish_reader( reader, @@ -53,7 +55,7 @@ impl IpcReader { struct MMapChunkIter<'a> { dictionaries: Dictionaries, metadata: FileMetadata, - mmap: Arc, + mmap: Arc, idx: usize, end: usize, projection: &'a Option>, @@ -61,12 +63,10 @@ struct MMapChunkIter<'a> { impl<'a> MMapChunkIter<'a> { fn new( - mmap: Mmap, + mmap: Arc, metadata: FileMetadata, projection: &'a Option>, ) -> PolarsResult { - let mmap = Arc::new(mmap); - let end = metadata.blocks.len(); // mmap the dictionaries let dictionaries = unsafe { mmap_dictionaries_unchecked(&metadata, mmap.clone())? }; diff --git a/crates/polars-io/src/mmap.rs b/crates/polars-io/src/mmap.rs index bf082b8798cf..cf281a4d358b 100644 --- a/crates/polars-io/src/mmap.rs +++ b/crates/polars-io/src/mmap.rs @@ -1,5 +1,63 @@ +use std::collections::btree_map::Entry; +use std::collections::BTreeMap; use std::fs::File; use std::io::{BufReader, Cursor, Read, Seek}; +use std::path::{Path, PathBuf}; +use std::sync::Mutex; + +use memmap::Mmap; +use once_cell::sync::Lazy; +use polars_error::{polars_bail, PolarsResult}; +use polars_utils::create_file; + +// Keep track of memory mapped files so we don't write to them while reading +// Use a btree as it uses less memory than a hashmap and this thing never shrinks. +static MEMORY_MAPPED_FILES: Lazy>> = + Lazy::new(|| Mutex::new(Default::default())); + +pub(crate) struct MMapSemaphore { + path: PathBuf, + mmap: Mmap, +} + +impl MMapSemaphore { + pub(super) fn new(path: PathBuf, mmap: Mmap) -> Self { + let mut guard = MEMORY_MAPPED_FILES.lock().unwrap(); + guard.insert(path.clone(), 1); + Self { path, mmap } + } +} + +impl AsRef<[u8]> for MMapSemaphore { + #[inline] + fn as_ref(&self) -> &[u8] { + self.mmap.as_ref() + } +} + +impl Drop for MMapSemaphore { + fn drop(&mut self) { + let mut guard = MEMORY_MAPPED_FILES.lock().unwrap(); + if let Entry::Occupied(mut e) = guard.entry(std::mem::take(&mut self.path)) { + let v = e.get_mut(); + *v -= 1; + + if *v == 0 { + e.remove_entry(); + } + } + } +} + +/// Open a file to get write access. This will check if the file is currently registered as memory mapped. +pub fn try_create_file(path: &Path) -> PolarsResult { + let guard = MEMORY_MAPPED_FILES.lock().unwrap(); + if guard.contains_key(path) { + polars_bail!(ComputeError: "cannot write to file: already memory mapped") + } + drop(guard); + create_file(path) +} /// Trait used to get a hold to file handler or to the underlying bytes /// without performing a Read. diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs index fb0b5c7206aa..5b8d20e511e0 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs @@ -95,6 +95,12 @@ impl IpcExec { let file = std::fs::File::open(path)?; + let memory_mapped = if self.options.memory_map { + Some(path.clone()) + } else { + None + }; + let df = IpcReader::new(file) .with_n_rows( // NOTE: If there is any file that by itself exceeds the @@ -108,7 +114,7 @@ impl IpcExec { ) .with_row_index(self.file_options.row_index.clone()) .with_projection(projection.clone()) - .memory_mapped(self.options.memory_map) + .memory_mapped(memory_mapped) .finish()?; row_counter diff --git a/crates/polars-utils/src/io.rs b/crates/polars-utils/src/io.rs index a6c3be1e745e..a943f9e5cbf5 100644 --- a/crates/polars-utils/src/io.rs +++ b/crates/polars-utils/src/io.rs @@ -1,21 +1,30 @@ use std::fs::File; -use std::io::Error; +use std::io; use std::path::Path; use polars_error::*; +fn map_err(path: &Path, err: io::Error) -> PolarsError { + let path = path.to_string_lossy(); + let msg = if path.len() > 88 { + let truncated_path: String = path.chars().skip(path.len() - 88).collect(); + format!("{err}: ...{truncated_path}") + } else { + format!("{err}: {path}") + }; + io::Error::new(err.kind(), msg).into() +} + pub fn open_file

(path: P) -> PolarsResult where P: AsRef, { - std::fs::File::open(&path).map_err(|err| { - let path = path.as_ref().to_string_lossy(); - let msg = if path.len() > 88 { - let truncated_path: String = path.chars().skip(path.len() - 88).collect(); - format!("{err}: ...{truncated_path}") - } else { - format!("{err}: {path}") - }; - Error::new(err.kind(), msg).into() - }) + File::open(&path).map_err(|err| map_err(path.as_ref(), err)) +} + +pub fn create_file

(path: P) -> PolarsResult +where + P: AsRef, +{ + File::create(&path).map_err(|err| map_err(path.as_ref(), err)) } diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs index 842ea031d32f..9bf269785e6b 100644 --- a/crates/polars-utils/src/lib.rs +++ b/crates/polars-utils/src/lib.rs @@ -41,4 +41,4 @@ pub mod ord; pub mod partitioned; pub use index::{IdxSize, NullableIdxSize}; -pub use io::open_file; +pub use io::*; diff --git a/py-polars/src/dataframe/io.rs b/py-polars/src/dataframe/io.rs index d462259d9be7..1f31e616f412 100644 --- a/py-polars/src/dataframe/io.rs +++ b/py-polars/src/dataframe/io.rs @@ -4,7 +4,7 @@ use std::ops::Deref; #[cfg(feature = "avro")] use polars::io::avro::AvroCompression; -use polars::io::mmap::ReaderBytes; +use polars::io::mmap::{try_create_file, ReaderBytes}; use polars::io::RowIndex; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; @@ -14,7 +14,8 @@ use super::*; use crate::conversion::parse_parquet_compression; use crate::conversion::Wrap; use crate::file::{ - get_either_file, get_file_like, get_mmap_bytes_reader, read_if_bytesio, EitherRustPythonFile, + get_either_file, get_file_like, get_mmap_bytes_reader, get_mmap_bytes_reader_and_path, + read_if_bytesio, EitherRustPythonFile, }; #[pymethods] @@ -279,14 +280,16 @@ impl PyDataFrame { offset, }); py_f = read_if_bytesio(py_f); - let mmap_bytes_r = get_mmap_bytes_reader(&py_f)?; + let (mmap_bytes_r, mmap_path) = get_mmap_bytes_reader_and_path(&py_f)?; + + let mmap_path = if memory_map { mmap_path } else { None }; let df = py.allow_threads(move || { IpcReader::new(mmap_bytes_r) .with_projection(projection) .with_columns(columns) .with_n_rows(n_rows) .with_row_index(row_index) - .memory_mapped(memory_map) + .memory_mapped(mmap_path) .finish() .map_err(PyPolarsErr::from) })?; @@ -488,7 +491,9 @@ impl PyDataFrame { future: bool, ) -> PyResult<()> { if let Ok(s) = py_f.extract::(py) { - let f = std::fs::File::create(&*s)?; + let s: &str = s.as_ref(); + let path = std::path::Path::new(s); + let f = try_create_file(path).map_err(PyPolarsErr::from)?; py.allow_threads(|| { IpcWriter::new(f) .with_compression(compression.0) diff --git a/py-polars/src/file.rs b/py-polars/src/file.rs index c361e9f3392c..b05cfd260fd3 100644 --- a/py-polars/src/file.rs +++ b/py-polars/src/file.rs @@ -1,6 +1,7 @@ use std::fs::File; use std::io; use std::io::{BufReader, Cursor, Read, Seek, SeekFrom, Write}; +use std::path::PathBuf; use polars::io::mmap::MmapBytesReader; use polars_error::polars_warn; @@ -218,17 +219,23 @@ pub fn read_if_bytesio(py_f: Bound) -> Bound { pub fn get_mmap_bytes_reader<'a>( py_f: &'a Bound<'a, PyAny>, ) -> PyResult> { + get_mmap_bytes_reader_and_path(py_f).map(|t| t.0) +} + +pub fn get_mmap_bytes_reader_and_path<'a>( + py_f: &'a Bound<'a, PyAny>, +) -> PyResult<(Box, Option)> { // bytes object if let Ok(bytes) = py_f.downcast::() { - Ok(Box::new(Cursor::new(bytes.as_bytes()))) + Ok((Box::new(Cursor::new(bytes.as_bytes())), None)) } // string so read file else if let Ok(pstring) = py_f.downcast::() { let s = pstring.to_cow()?; let p = std::path::Path::new(&*s); - let p = resolve_homedir(p); - let f = polars_utils::open_file(p).map_err(PyPolarsErr::from)?; - Ok(Box::new(f)) + let p_resolved = resolve_homedir(p); + let f = polars_utils::open_file(p_resolved).map_err(PyPolarsErr::from)?; + Ok((Box::new(f), Some(p.to_path_buf()))) } // hopefully a normal python file: with open(...) as f:. else { @@ -242,6 +249,6 @@ pub fn get_mmap_bytes_reader<'a>( let f = Python::with_gil(|py| { PyFileLikeObject::with_requirements(py_f.to_object(py), true, false, true) })?; - Ok(Box::new(f)) + Ok((Box::new(f), None)) } } diff --git a/py-polars/tests/unit/io/test_ipc.py b/py-polars/tests/unit/io/test_ipc.py index cd273ca1cd94..be9473a050c8 100644 --- a/py-polars/tests/unit/io/test_ipc.py +++ b/py-polars/tests/unit/io/test_ipc.py @@ -337,3 +337,19 @@ def test_ipc_decimal_15920( path = f"{tmp_path}/data" df.write_ipc(path) assert_frame_equal(pl.read_ipc(path), df) + + +@pytest.mark.write_disk() +def test_ipc_raise_on_writing_mmap(tmp_path: Path) -> None: + p = tmp_path / "foo.ipc" + df = pl.DataFrame({"foo": [1, 2, 3]}) + # first write is allowed + df.write_ipc(p) + + # now open as memory mapped + df = pl.read_ipc(p, memory_map=True) + + with pytest.raises( + pl.ComputeError, match="cannot write to file: already memory mapped" + ): + df.write_ipc(p)