From 9a543bf551468d42275768fa48f0faacaab1e263 Mon Sep 17 00:00:00 2001
From: "R. Tyler Croy"
Date: Sat, 8 Feb 2025 19:07:21 +0000
Subject: [PATCH] feat: Use LazyTableProvider by default for write_to_deltalake for memory efficiency

This defaults write_to_deltalake in Python to attempt to use the
LazyTableProvider for a more stream-like execution. It's currently opted out
of when schema evolution is requested, since that is not supported by the
lazy path. Some improvements in schema mismatch detection inside of the
operations::write module are required as well.

Signed-off-by: R. Tyler Croy
---
 crates/core/src/delta_datafusion/mod.rs | 14 +-----
 crates/core/src/operations/write.rs     | 19 +++++--
 python/src/lib.rs                       | 37 ++++++++++++--
 python/src/merge.rs                     | 43 +---------------
 python/src/write.rs                     | 66 +++++++++++++++++++++++++
 python/tests/test_writer.py             | 14 ------
 6 files changed, 116 insertions(+), 77 deletions(-)
 create mode 100644 python/src/write.rs

diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs
index dfa79c09b4..fcaea89465 100644
--- a/crates/core/src/delta_datafusion/mod.rs
+++ b/crates/core/src/delta_datafusion/mod.rs
@@ -876,14 +876,6 @@ impl TableProvider for LazyTableProvider {
         TableType::Base
     }
 
-    fn get_table_definition(&self) -> Option<&str> {
-        None
-    }
-
-    fn get_logical_plan(&self) -> Option<Cow<'_, LogicalPlan>> {
-        None
-    }
-
     async fn scan(
         &self,
         _session: &dyn Session,
@@ -909,7 +901,7 @@ impl TableProvider for LazyTableProvider {
         if projection != &current_projection {
             let execution_props = &ExecutionProps::new();
             let fields: DeltaResult<Vec<(Arc<dyn PhysicalExpr>, String)>> = projection
-                .into_iter()
+                .iter()
                 .map(|i| {
                     let (table_ref, field) = df_schema.qualified_field(*i);
                     create_physical_expr(
@@ -941,10 +933,6 @@ impl TableProvider for LazyTableProvider {
             .map(|_| TableProviderFilterPushDown::Inexact)
             .collect())
     }
-
-    fn statistics(&self) -> Option<Statistics> {
-        None
-    }
 }
 
 // TODO: this will likely also need to perform column mapping later when we support reader protocol v2
diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs
index 58effb1b51..3a49f3397c 100644
--- a/crates/core/src/operations/write.rs
+++ b/crates/core/src/operations/write.rs
@@ -834,11 +834,22 @@ impl std::future::IntoFuture for WriteBuilder {
             .unwrap_or_default();
         let mut schema_drift = false;
         let mut df = if let Some(plan) = this.input {
-            if this.schema_mode == Some(SchemaMode::Merge) {
-                return Err(DeltaTableError::Generic(
-                    "Schema merge not supported yet for Datafusion".to_string(),
-                ));
+            match this.schema_mode {
+                Some(SchemaMode::Merge) => {
+                    return Err(DeltaTableError::Generic(
+                        "Schema merge not supported yet for Datafusion".to_string(),
+                    ));
+                }
+                Some(SchemaMode::Overwrite) => {}
+                None => {
+                    if let Some(snapshot) = &this.snapshot {
+                        let table_schema = snapshot.input_schema()?;
+                        let plan_schema = plan.schema().as_arrow();
+                        try_cast_batch(table_schema.fields(), plan_schema.fields())?
+                    }
+                }
             }
+
             Ok(DataFrame::new(state.clone(), plan.as_ref().clone()))
         } else if let Some(batches) = this.batches {
             if batches.is_empty() {
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 248ad3d2fc..629de3ec25 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -5,6 +5,7 @@ mod merge;
 mod query;
 mod schema;
 mod utils;
+mod write;
 
 use std::collections::{HashMap, HashSet};
 use std::ffi::CString;
@@ -49,6 +50,7 @@ use deltalake::operations::transaction::{
 };
 use deltalake::operations::update::UpdateBuilder;
 use deltalake::operations::vacuum::VacuumBuilder;
+use deltalake::operations::write::{SchemaMode, WriteBuilder};
 use deltalake::operations::{collect_sendable_stream, CustomExecuteHandler};
 use deltalake::parquet::basic::Compression;
 use deltalake::parquet::errors::ParquetError;
@@ -71,6 +73,7 @@ use crate::merge::PyMergeBuilder;
 use crate::query::PyQueryBuilder;
 use crate::schema::{schema_to_pyobject, Field};
 use crate::utils::rt;
+use crate::write::ArrowStreamBatchGenerator;
 use deltalake::operations::update_field_metadata::UpdateFieldMetadataBuilder;
 use deltalake::protocol::DeltaOperation::UpdateFieldMetadata;
 use pyo3::exceptions::{PyRuntimeError, PyValueError};
@@ -2099,7 +2102,6 @@ fn write_to_deltalake(
     post_commithook_properties: Option<PyPostCommitHookProperties>,
 ) -> PyResult<()> {
     py.allow_threads(|| {
-        let batches = data.0.map(|batch| batch.unwrap()).collect::<Vec<RecordBatch>>();
         let save_mode = mode.parse().map_err(PythonError::from)?;
 
         let options = storage_options.clone().unwrap_or_default();
@@ -2112,10 +2114,37 @@ fn write_to_deltalake(
             .map_err(PythonError::from)?
         };
 
-        let mut builder = table.write(batches).with_save_mode(save_mode);
-        if let Some(schema_mode) = schema_mode {
-            builder = builder.with_schema_mode(schema_mode.parse().map_err(PythonError::from)?);
+        let dont_be_so_lazy = match table.0.state.as_ref() {
+            Some(state) => state.table_config().enable_change_data_feed(),
+            // You don't have state somehow, so I guess it's okay to be lazy.
+            _ => false,
+        };
+
+        let mut builder =
+            WriteBuilder::new(table.0.log_store(), table.0.state).with_save_mode(save_mode);
+
+        if let Some(ref schema_mode) = schema_mode {
+            let schema_mode = schema_mode.parse().map_err(PythonError::from)?;
+            builder = builder.with_schema_mode(schema_mode);
         }
+
+        if (schema_mode == Some("merge".into())) || dont_be_so_lazy {
+            debug!(
+                "write_to_deltalake() is not able to lazily perform a write, collecting batches"
+            );
+            builder = builder.with_input_batches(data.0.map(|batch| batch.unwrap()));
+        } else {
+            use deltalake::datafusion::datasource::provider_as_source;
+            use deltalake::datafusion::logical_expr::LogicalPlanBuilder;
+            let table_provider = crate::write::to_lazy_table(data.0).map_err(PythonError::from)?;
+
+            let plan = LogicalPlanBuilder::scan("source", provider_as_source(table_provider), None)
+                .map_err(PythonError::from)?
+                .build()
+                .map_err(PythonError::from)?;
+            builder = builder.with_input_execution_plan(Arc::new(plan));
+        }
+
         if let Some(partition_columns) = partition_by {
             builder = builder.with_partition_columns(partition_columns);
         }
diff --git a/python/src/merge.rs b/python/src/merge.rs
index 8da1aceb6f..7a44eee759 100644
--- a/python/src/merge.rs
+++ b/python/src/merge.rs
@@ -15,12 +15,12 @@ use deltalake::{DeltaResult, DeltaTable};
 use parking_lot::RwLock;
 use pyo3::prelude::*;
 use std::collections::HashMap;
-use std::fmt::{self};
 use std::future::IntoFuture;
 use std::sync::{Arc, Mutex};
 
 use crate::error::PythonError;
 use crate::utils::rt;
+use crate::write::ArrowStreamBatchGenerator;
 use crate::{
     maybe_create_commit_properties, set_writer_properties, PyCommitProperties,
     PyPostCommitHookProperties, PyWriterProperties,
@@ -37,47 +37,6 @@ pub(crate) struct PyMergeBuilder {
     merge_schema: bool,
     arrow_schema: Arc<ArrowSchema>,
 }
-#[derive(Debug)]
-struct ArrowStreamBatchGenerator {
-    pub array_stream: Arc<Mutex<ArrowArrayStreamReader>>,
-}
-
-impl fmt::Display for ArrowStreamBatchGenerator {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(
-            f,
-            "ArrowStreamBatchGenerator {{ array_stream: {:?} }}",
-            self.array_stream
-        )
-    }
-}
-
-impl ArrowStreamBatchGenerator {
-    fn new(array_stream: Arc<Mutex<ArrowArrayStreamReader>>) -> Self {
-        Self { array_stream }
-    }
-}
-
-impl LazyBatchGenerator for ArrowStreamBatchGenerator {
-    fn generate_next_batch(
-        &mut self,
-    ) -> deltalake::datafusion::error::Result<Option<RecordBatch>> {
-        let mut stream_reader = self.array_stream.lock().map_err(|_| {
-            deltalake::datafusion::error::DataFusionError::Execution(
-                "Failed to lock the ArrowArrayStreamReader".to_string(),
-            )
-        })?;
-
-        match stream_reader.next() {
-            Some(Ok(record_batch)) => Ok(Some(record_batch)),
-            Some(Err(err)) => Err(deltalake::datafusion::error::DataFusionError::ArrowError(
-                err, None,
-            )),
-            None => Ok(None), // End of stream
-        }
-    }
-}
-
 impl PyMergeBuilder {
     #[allow(clippy::too_many_arguments)]
     pub fn new(
diff --git a/python/src/write.rs b/python/src/write.rs
new file mode 100644
index 0000000000..b10b208d50
--- /dev/null
+++ b/python/src/write.rs
@@ -0,0 +1,66 @@
+//! The write module contains shared code used for writes by the write_to_deltalake function and
+//! the merge code
+
+use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
+use deltalake::datafusion::catalog::TableProvider;
+use deltalake::datafusion::physical_plan::memory::LazyBatchGenerator;
+use deltalake::delta_datafusion::LazyTableProvider;
+use deltalake::DeltaResult;
+use parking_lot::RwLock;
+use std::fmt::{self};
+use std::sync::{Arc, Mutex};
+
+/// Convert an [ArrowArrayStreamReader] into a [LazyTableProvider]
+pub(crate) fn to_lazy_table(source: ArrowArrayStreamReader) -> DeltaResult<Arc<dyn TableProvider>> {
+    use deltalake::arrow::array::RecordBatchReader;
+    let schema = source.schema();
+    let arrow_stream: Arc<Mutex<ArrowArrayStreamReader>> = Arc::new(Mutex::new(source));
+    let arrow_stream_batch_generator: Arc<RwLock<dyn LazyBatchGenerator>> =
+        Arc::new(RwLock::new(ArrowStreamBatchGenerator::new(arrow_stream)));
+
+    Ok(Arc::new(LazyTableProvider::try_new(
+        schema.clone(),
+        vec![arrow_stream_batch_generator],
+    )?))
+}
+
+#[derive(Debug)]
+pub(crate) struct ArrowStreamBatchGenerator {
+    pub array_stream: Arc<Mutex<ArrowArrayStreamReader>>,
+}
+
+impl fmt::Display for ArrowStreamBatchGenerator {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "ArrowStreamBatchGenerator {{ array_stream: {:?} }}",
+            self.array_stream
+        )
+    }
+}
+
+impl ArrowStreamBatchGenerator {
+    pub fn new(array_stream: Arc<Mutex<ArrowArrayStreamReader>>) -> Self {
+        Self { array_stream }
+    }
+}
+
+impl LazyBatchGenerator for ArrowStreamBatchGenerator {
+    fn generate_next_batch(
+        &mut self,
+    ) -> deltalake::datafusion::error::Result<Option<deltalake::arrow::array::RecordBatch>> {
+        let mut stream_reader = self.array_stream.lock().map_err(|_| {
+            deltalake::datafusion::error::DataFusionError::Execution(
+                "Failed to lock the ArrowArrayStreamReader".to_string(),
+            )
+        })?;
+
+        match stream_reader.next() {
+            Some(Ok(record_batch)) => Ok(Some(record_batch)),
+            Some(Err(err)) => Err(deltalake::datafusion::error::DataFusionError::ArrowError(
+                err, None,
+            )),
+            None => Ok(None), // End of stream
+        }
+    }
+}
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index cbf40dbfd1..f83c7fc164 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -1589,13 +1589,6 @@ def test_schema_cols_diff_order(tmp_path: pathlib.Path, engine):
     assert dt.to_pyarrow_table(columns=["baz", "bar", "foo"]) == expected
 
 
-def test_empty(existing_table: DeltaTable):
-    schema = existing_table.schema().to_pyarrow()
-    empty_table = pa.Table.from_pylist([], schema=schema)
-    with pytest.raises(DeltaError, match="No data source supplied to write command"):
-        write_deltalake(existing_table, empty_table, mode="append", engine="rust")
-
-
 def test_rust_decimal_cast(tmp_path: pathlib.Path):
     import re
 
@@ -1813,13 +1806,6 @@ def test_roundtrip_cdc_evolution(tmp_path: pathlib.Path):
     assert os.path.isdir(os.path.join(tmp_path, "_change_data"))
 
 
-def test_empty_dataset_write(tmp_path: pathlib.Path, sample_data: pa.Table):
-    empty_arrow_table = sample_data.schema.empty_table()
-    empty_dataset = dataset(empty_arrow_table)
-    with pytest.raises(DeltaError, match="No data source supplied to write command"):
-        write_deltalake(tmp_path, empty_dataset, mode="append")
-
-
 @pytest.mark.pandas
 def test_predicate_out_of_bounds(tmp_path: pathlib.Path):
     """See """
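
A minimal Python usage sketch of the behavior described above (not part of the patch itself; the table path and column name are made up): passing a RecordBatchReader to write_deltalake should now stream batches through the LazyTableProvider by default, while schema_mode="merge" or a table with change data feed enabled falls back to collecting batches eagerly as before.

    import pyarrow as pa
    from deltalake import write_deltalake

    schema = pa.schema([("id", pa.int64())])
    batches = [pa.RecordBatch.from_pylist([{"id": i}], schema=schema) for i in range(3)]
    # The reader is consumed batch by batch by the lazy write path rather than
    # being collected into memory up front.
    reader = pa.RecordBatchReader.from_batches(schema, batches)

    write_deltalake("./example_table", reader, mode="append")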