feat: Use LazyTableProvider by default for write_to_deltalake for memory efficiency

This defaults write_to_deltalake in Python to attempt to use the
LazyTableProvider for more stream-like execution. It is currently opted
out for schema evolution since that is not supported by default.

Some improvements to schema mismatch detection inside the
operations::write module are required as well.

Signed-off-by: R. Tyler Croy <rtyler@brokenco.de>
rtyler committed Feb 8, 2025
1 parent 4ef9fb3 commit ba2845b
Showing 6 changed files with 106 additions and 77 deletions.
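Before the diffs, here is a minimal sketch of the new default path, assembled from the hunks below. It assumes `to_lazy_table` (added in python/src/write.rs) is in scope and elides the PyO3 plumbing; the helper name `lazy_write_plan` is invented for illustration and is not part of the commit.

use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
use deltalake::datafusion::datasource::provider_as_source;
use deltalake::datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
use deltalake::{DeltaResult, DeltaTableError};

/// Turn the Arrow FFI stream coming from Python into a logical plan that
/// `WriteBuilder::with_input_execution_plan` can consume. Batches are pulled
/// from the stream on demand instead of being collected into a Vec up front.
fn lazy_write_plan(source: ArrowArrayStreamReader) -> DeltaResult<LogicalPlan> {
    // Wrap the stream in a LazyTableProvider so DataFusion can scan it lazily.
    let table_provider = to_lazy_table(source)?;
    LogicalPlanBuilder::scan("source", provider_as_source(table_provider), None)
        .and_then(|builder| builder.build())
        .map_err(|e| DeltaTableError::Generic(e.to_string()))
}

The write path then attaches the result with `builder.with_input_execution_plan(Arc::new(plan))`, while schema-merge writes keep the eager `with_input_batches` path, since merge is not yet supported through Datafusion.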
14 changes: 1 addition & 13 deletions crates/core/src/delta_datafusion/mod.rs
@@ -876,14 +876,6 @@ impl TableProvider for LazyTableProvider {
TableType::Base
}

fn get_table_definition(&self) -> Option<&str> {
None
}

fn get_logical_plan(&self) -> Option<Cow<'_, LogicalPlan>> {
None
}

async fn scan(
&self,
_session: &dyn Session,
@@ -909,7 +901,7 @@ impl TableProvider for LazyTableProvider {
if projection != &current_projection {
let execution_props = &ExecutionProps::new();
let fields: DeltaResult<Vec<(Arc<dyn PhysicalExpr>, String)>> = projection
.into_iter()
.iter()
.map(|i| {
let (table_ref, field) = df_schema.qualified_field(*i);
create_physical_expr(
@@ -941,10 +933,6 @@ impl TableProvider for LazyTableProvider {
.map(|_| TableProviderFilterPushDown::Inexact)
.collect())
}

fn statistics(&self) -> Option<Statistics> {
None
}
}

// TODO: this will likely also need to perform column mapping later when we support reader protocol v2
19 changes: 15 additions & 4 deletions crates/core/src/operations/write.rs
@@ -834,11 +834,22 @@ impl std::future::IntoFuture for WriteBuilder {
.unwrap_or_default();
let mut schema_drift = false;
let mut df = if let Some(plan) = this.input {
if this.schema_mode == Some(SchemaMode::Merge) {
return Err(DeltaTableError::Generic(
"Schema merge not supported yet for Datafusion".to_string(),
));
match this.schema_mode {
Some(SchemaMode::Merge) => {
return Err(DeltaTableError::Generic(
"Schema merge not supported yet for Datafusion".to_string(),
));
}
Some(SchemaMode::Overwrite) => {}
None => {
if let Some(snapshot) = &this.snapshot {
let table_schema = snapshot.input_schema()?;
let plan_schema = plan.schema().as_arrow();
try_cast_batch(table_schema.fields(), plan_schema.fields())?
}
}
}

Ok(DataFrame::new(state.clone(), plan.as_ref().clone()))
} else if let Some(batches) = this.batches {
if batches.is_empty() {
27 changes: 23 additions & 4 deletions python/src/lib.rs
@@ -5,6 +5,7 @@ mod merge;
mod query;
mod schema;
mod utils;
mod write;

use std::collections::{HashMap, HashSet};
use std::ffi::CString;
@@ -49,6 +50,7 @@ use deltalake::operations::transaction::{
};
use deltalake::operations::update::UpdateBuilder;
use deltalake::operations::vacuum::VacuumBuilder;
use deltalake::operations::write::{SchemaMode, WriteBuilder};
use deltalake::operations::{collect_sendable_stream, CustomExecuteHandler};
use deltalake::parquet::basic::Compression;
use deltalake::parquet::errors::ParquetError;
@@ -71,6 +73,7 @@ use crate::merge::PyMergeBuilder;
use crate::query::PyQueryBuilder;
use crate::schema::{schema_to_pyobject, Field};
use crate::utils::rt;
use crate::write::ArrowStreamBatchGenerator;
use deltalake::operations::update_field_metadata::UpdateFieldMetadataBuilder;
use deltalake::protocol::DeltaOperation::UpdateFieldMetadata;
use pyo3::exceptions::{PyRuntimeError, PyValueError};
@@ -2099,7 +2102,6 @@ fn write_to_deltalake(
post_commithook_properties: Option<PyPostCommitHookProperties>,
) -> PyResult<()> {
py.allow_threads(|| {
let batches = data.0.map(|batch| batch.unwrap()).collect::<Vec<_>>();
let save_mode = mode.parse().map_err(PythonError::from)?;

let options = storage_options.clone().unwrap_or_default();
@@ -2111,11 +2113,28 @@ fn write_to_deltalake(
))
.map_err(PythonError::from)?
};
let mut builder =
WriteBuilder::new(table.0.log_store(), table.0.state).with_save_mode(save_mode);

let mut builder = table.write(batches).with_save_mode(save_mode);
if let Some(schema_mode) = schema_mode {
builder = builder.with_schema_mode(schema_mode.parse().map_err(PythonError::from)?);
if let Some(ref schema_mode) = schema_mode {
let schema_mode = schema_mode.parse().map_err(PythonError::from)?;
builder = builder.with_schema_mode(schema_mode);
}

if schema_mode == Some("merge".into()) {
builder = builder.with_input_batches(data.0.map(|batch| batch.unwrap()));
} else {
use deltalake::datafusion::datasource::provider_as_source;
use deltalake::datafusion::logical_expr::LogicalPlanBuilder;
let table_provider = crate::write::to_lazy_table(data.0).map_err(PythonError::from)?;

let plan = LogicalPlanBuilder::scan("source", provider_as_source(table_provider), None)
.expect("XXX: Failed to make scan")
.build()
.expect("XXX Failed to build logical plan");
builder = builder.with_input_execution_plan(Arc::new(plan));
}

if let Some(partition_columns) = partition_by {
builder = builder.with_partition_columns(partition_columns);
}
43 changes: 1 addition & 42 deletions python/src/merge.rs
@@ -15,12 +15,12 @@ use deltalake::{DeltaResult, DeltaTable};
use parking_lot::RwLock;
use pyo3::prelude::*;
use std::collections::HashMap;
use std::fmt::{self};
use std::future::IntoFuture;
use std::sync::{Arc, Mutex};

use crate::error::PythonError;
use crate::utils::rt;
use crate::write::ArrowStreamBatchGenerator;
use crate::{
maybe_create_commit_properties, set_writer_properties, PyCommitProperties,
PyPostCommitHookProperties, PyWriterProperties,
@@ -37,47 +37,6 @@ pub(crate) struct PyMergeBuilder {
merge_schema: bool,
arrow_schema: Arc<ArrowSchema>,
}
#[derive(Debug)]
struct ArrowStreamBatchGenerator {
pub array_stream: Arc<Mutex<ArrowArrayStreamReader>>,
}

impl fmt::Display for ArrowStreamBatchGenerator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"ArrowStreamBatchGenerator {{ array_stream: {:?} }}",
self.array_stream
)
}
}

impl ArrowStreamBatchGenerator {
fn new(array_stream: Arc<Mutex<ArrowArrayStreamReader>>) -> Self {
Self { array_stream }
}
}

impl LazyBatchGenerator for ArrowStreamBatchGenerator {
fn generate_next_batch(
&mut self,
) -> deltalake::datafusion::error::Result<Option<deltalake::arrow::array::RecordBatch>> {
let mut stream_reader = self.array_stream.lock().map_err(|_| {
deltalake::datafusion::error::DataFusionError::Execution(
"Failed to lock the ArrowArrayStreamReader".to_string(),
)
})?;

match stream_reader.next() {
Some(Ok(record_batch)) => Ok(Some(record_batch)),
Some(Err(err)) => Err(deltalake::datafusion::error::DataFusionError::ArrowError(
err, None,
)),
None => Ok(None), // End of stream
}
}
}

impl PyMergeBuilder {
#[allow(clippy::too_many_arguments)]
pub fn new(
66 changes: 66 additions & 0 deletions python/src/write.rs
@@ -0,0 +1,66 @@
//! The write module contains shared code used for writes by the write_to_deltalake function and
//! the merge code
use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
use deltalake::datafusion::catalog::TableProvider;
use deltalake::datafusion::physical_plan::memory::LazyBatchGenerator;
use deltalake::delta_datafusion::LazyTableProvider;
use deltalake::DeltaResult;
use parking_lot::RwLock;
use std::fmt::{self};
use std::sync::{Arc, Mutex};

/// Convert an [ArrowArrayStreamReader] into a [LazyTableProvider]
pub(crate) fn to_lazy_table(source: ArrowArrayStreamReader) -> DeltaResult<Arc<dyn TableProvider>> {
use deltalake::arrow::array::RecordBatchReader;
let schema = source.schema();
let arrow_stream: Arc<Mutex<ArrowArrayStreamReader>> = Arc::new(Mutex::new(source));
let arrow_stream_batch_generator: Arc<RwLock<dyn LazyBatchGenerator>> =
Arc::new(RwLock::new(ArrowStreamBatchGenerator::new(arrow_stream)));

Ok(Arc::new(LazyTableProvider::try_new(
schema.clone(),
vec![arrow_stream_batch_generator],
)?))
}

#[derive(Debug)]
pub(crate) struct ArrowStreamBatchGenerator {
pub array_stream: Arc<Mutex<ArrowArrayStreamReader>>,
}

impl fmt::Display for ArrowStreamBatchGenerator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"ArrowStreamBatchGenerator {{ array_stream: {:?} }}",
self.array_stream
)
}
}

impl ArrowStreamBatchGenerator {
pub fn new(array_stream: Arc<Mutex<ArrowArrayStreamReader>>) -> Self {
Self { array_stream }
}
}

impl LazyBatchGenerator for ArrowStreamBatchGenerator {
fn generate_next_batch(
&mut self,
) -> deltalake::datafusion::error::Result<Option<deltalake::arrow::array::RecordBatch>> {
let mut stream_reader = self.array_stream.lock().map_err(|_| {
deltalake::datafusion::error::DataFusionError::Execution(
"Failed to lock the ArrowArrayStreamReader".to_string(),
)
})?;

match stream_reader.next() {
Some(Ok(record_batch)) => Ok(Some(record_batch)),
Some(Err(err)) => Err(deltalake::datafusion::error::DataFusionError::ArrowError(
err, None,
)),
None => Ok(None), // End of stream
}
}
}
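As a usage note, the LazyBatchGenerator contract shown above is small enough that other batch sources can be adapted the same way. A hypothetical in-memory variant, not part of this commit and shown only to illustrate the trait, might look like:

use std::collections::VecDeque;
use std::fmt;

use deltalake::arrow::array::RecordBatch;
use deltalake::datafusion::physical_plan::memory::LazyBatchGenerator;

/// Drains a pre-built queue of batches; handy in tests where no FFI stream exists.
#[derive(Debug)]
struct VecBatchGenerator {
    batches: VecDeque<RecordBatch>,
}

impl fmt::Display for VecBatchGenerator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "VecBatchGenerator {{ remaining: {} }}", self.batches.len())
    }
}

impl LazyBatchGenerator for VecBatchGenerator {
    fn generate_next_batch(
        &mut self,
    ) -> deltalake::datafusion::error::Result<Option<RecordBatch>> {
        // Returning Ok(None) signals end of stream, just as ArrowStreamBatchGenerator
        // does when the underlying reader is exhausted.
        Ok(self.batches.pop_front())
    }
}

Such a generator could then be wrapped in an Arc<RwLock<_>> and handed to LazyTableProvider::try_new exactly like the FFI-backed generator in to_lazy_table above.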
14 changes: 0 additions & 14 deletions python/tests/test_writer.py
@@ -1589,13 +1589,6 @@ def test_schema_cols_diff_order(tmp_path: pathlib.Path, engine):
assert dt.to_pyarrow_table(columns=["baz", "bar", "foo"]) == expected


def test_empty(existing_table: DeltaTable):
schema = existing_table.schema().to_pyarrow()
empty_table = pa.Table.from_pylist([], schema=schema)
with pytest.raises(DeltaError, match="No data source supplied to write command"):
write_deltalake(existing_table, empty_table, mode="append", engine="rust")


def test_rust_decimal_cast(tmp_path: pathlib.Path):
import re

@@ -1813,13 +1806,6 @@ def test_roundtrip_cdc_evolution(tmp_path: pathlib.Path):
assert os.path.isdir(os.path.join(tmp_path, "_change_data"))


def test_empty_dataset_write(tmp_path: pathlib.Path, sample_data: pa.Table):
empty_arrow_table = sample_data.schema.empty_table()
empty_dataset = dataset(empty_arrow_table)
with pytest.raises(DeltaError, match="No data source supplied to write command"):
write_deltalake(tmp_path, empty_dataset, mode="append")


@pytest.mark.pandas
def test_predicate_out_of_bounds(tmp_path: pathlib.Path):
"""See <https://github.com/delta-io/delta-rs/issues/2867>"""