feat: use LazyTableProvider by default for write_to_deltalake for memory efficiency #3196

Draft. Wants to merge 1 commit into base: main.
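In short: instead of draining the incoming Arrow stream into a `Vec<RecordBatch>` before writing, the write path now wraps the reader in a generator that the DataFusion plan polls one batch at a time, so peak memory stays near a single batch. A minimal sketch of that contrast, assuming the `arrow-array`/`arrow-schema` crates directly (the `LazyBatches` helper below is illustrative, not the PR's API):

```rust
use std::sync::{Arc, Mutex};

use arrow_array::{Int64Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{ArrowError, DataType, Field, Schema};

// Old behavior: eagerly collect, holding every batch in memory at once.
fn collect_all(reader: impl RecordBatchReader) -> Result<Vec<RecordBatch>, ArrowError> {
    reader.collect()
}

// New behavior (in spirit): keep the reader behind a lock and pull one batch
// per call, so memory usage is bounded by roughly one batch.
struct LazyBatches<R: RecordBatchReader> {
    reader: Arc<Mutex<R>>,
}

impl<R: RecordBatchReader> LazyBatches<R> {
    fn next_batch(&self) -> Result<Option<RecordBatch>, ArrowError> {
        self.reader.lock().unwrap().next().transpose()
    }
}

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let batch =
        RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(vec![1, 2]))])?;

    let reader = RecordBatchIterator::new([Ok(batch.clone())], schema.clone());
    let eager = collect_all(reader)?; // whole stream resident at once
    assert_eq!(eager.len(), 1);

    let reader = RecordBatchIterator::new([Ok(batch)], schema);
    let lazy = LazyBatches { reader: Arc::new(Mutex::new(reader)) };
    while let Some(b) = lazy.next_batch()? { // one batch resident at a time
        println!("pulled {} rows", b.num_rows());
    }
    Ok(())
}
```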
14 changes: 1 addition & 13 deletions crates/core/src/delta_datafusion/mod.rs

```diff
@@ -876,14 +876,6 @@ impl TableProvider for LazyTableProvider {
         TableType::Base
     }
 
-    fn get_table_definition(&self) -> Option<&str> {
-        None
-    }
-
-    fn get_logical_plan(&self) -> Option<Cow<'_, LogicalPlan>> {
-        None
-    }
-
     async fn scan(
         &self,
         _session: &dyn Session,
@@ -909,7 +901,7 @@ impl TableProvider for LazyTableProvider {
         if projection != &current_projection {
             let execution_props = &ExecutionProps::new();
             let fields: DeltaResult<Vec<(Arc<dyn PhysicalExpr>, String)>> = projection
-                .into_iter()
+                .iter()
                 .map(|i| {
                     let (table_ref, field) = df_schema.qualified_field(*i);
                     create_physical_expr(
@@ -941,10 +933,6 @@ impl TableProvider for LazyTableProvider {
             .map(|_| TableProviderFilterPushDown::Inexact)
            .collect())
    }
-
-    fn statistics(&self) -> Option<Statistics> {
-        None
-    }
 }
 
 // TODO: this will likely also need to perform column mapping later when we support reader protocol v2
```
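The deleted methods only restated behavior that the `TableProvider` trait already supplies by default (each default returns `None`), so removing them changes nothing at runtime. A toy illustration of that default-method mechanism, with a stand-in trait rather than DataFusion's actual one:

```rust
// Stand-in types: not DataFusion's real trait, just the default-method pattern.
struct Statistics;

trait ProviderLike {
    fn get_table_definition(&self) -> Option<&str> {
        None // same body the deleted override had
    }
    fn statistics(&self) -> Option<Statistics> {
        None // same body the deleted override had
    }
}

struct Lazy;
impl ProviderLike for Lazy {} // leans on the defaults, as LazyTableProvider now does

fn main() {
    assert!(Lazy.get_table_definition().is_none());
    assert!(Lazy.statistics().is_none());
}
```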
19 changes: 15 additions & 4 deletions crates/core/src/operations/write.rs

```diff
@@ -834,11 +834,22 @@ impl std::future::IntoFuture for WriteBuilder {
             .unwrap_or_default();
         let mut schema_drift = false;
         let mut df = if let Some(plan) = this.input {
-            if this.schema_mode == Some(SchemaMode::Merge) {
-                return Err(DeltaTableError::Generic(
-                    "Schema merge not supported yet for Datafusion".to_string(),
-                ));
+            match this.schema_mode {
+                Some(SchemaMode::Merge) => {
+                    return Err(DeltaTableError::Generic(
+                        "Schema merge not supported yet for Datafusion".to_string(),
+                    ));
+                }
+                Some(SchemaMode::Overwrite) => {}
+                None => {
+                    if let Some(snapshot) = &this.snapshot {
+                        let table_schema = snapshot.input_schema()?;
+                        let plan_schema = plan.schema().as_arrow();
+                        try_cast_batch(table_schema.fields(), plan_schema.fields())?
+                    }
+                }
             }
+
             Ok(DataFrame::new(state.clone(), plan.as_ref().clone()))
         } else if let Some(batches) = this.batches {
             if batches.is_empty() {
```
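The new `None` arm makes the no-schema-mode path fail fast: before executing the plan, the input's Arrow schema is checked for castability against the table's schema via `try_cast_batch`. That function is a delta-rs internal; as a rough, assumed approximation of what it verifies, a standalone check might look like this:

```rust
use arrow_cast::can_cast_types;
use arrow_schema::{DataType, Field, Fields};

// Assumed approximation of delta-rs's internal try_cast_batch: every table
// column must be present in the input with a type that casts to the table's.
fn check_castable(table: &Fields, input: &Fields) -> Result<(), String> {
    for tf in table.iter() {
        let inf = input
            .iter()
            .find(|f| f.name() == tf.name())
            .ok_or_else(|| format!("missing column: {}", tf.name()))?;
        if !can_cast_types(inf.data_type(), tf.data_type()) {
            return Err(format!("cannot cast column: {}", tf.name()));
        }
    }
    Ok(())
}

fn main() {
    let table = Fields::from(vec![Field::new("id", DataType::Int64, false)]);
    let input = Fields::from(vec![Field::new("id", DataType::Int32, false)]);
    // Int32 widens to Int64, so this input is accepted up front.
    assert!(check_castable(&table, &input).is_ok());
}
```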
27 changes: 23 additions & 4 deletions python/src/lib.rs

```diff
@@ -5,6 +5,7 @@ mod merge;
 mod query;
 mod schema;
 mod utils;
+mod write;
 
 use std::collections::{HashMap, HashSet};
 use std::ffi::CString;
@@ -49,6 +50,7 @@ use deltalake::operations::transaction::{
 };
 use deltalake::operations::update::UpdateBuilder;
 use deltalake::operations::vacuum::VacuumBuilder;
+use deltalake::operations::write::{SchemaMode, WriteBuilder};
 use deltalake::operations::{collect_sendable_stream, CustomExecuteHandler};
 use deltalake::parquet::basic::Compression;
 use deltalake::parquet::errors::ParquetError;
@@ -71,6 +73,7 @@ use crate::merge::PyMergeBuilder;
 use crate::query::PyQueryBuilder;
 use crate::schema::{schema_to_pyobject, Field};
 use crate::utils::rt;
+use crate::write::ArrowStreamBatchGenerator;
 use deltalake::operations::update_field_metadata::UpdateFieldMetadataBuilder;
 use deltalake::protocol::DeltaOperation::UpdateFieldMetadata;
 use pyo3::exceptions::{PyRuntimeError, PyValueError};
@@ -2099,7 +2102,6 @@ fn write_to_deltalake(
     post_commithook_properties: Option<PyPostCommitHookProperties>,
 ) -> PyResult<()> {
     py.allow_threads(|| {
-        let batches = data.0.map(|batch| batch.unwrap()).collect::<Vec<_>>();
         let save_mode = mode.parse().map_err(PythonError::from)?;
 
         let options = storage_options.clone().unwrap_or_default();
@@ -2111,11 +2113,28 @@ fn write_to_deltalake(
             ))
             .map_err(PythonError::from)?
         };
+        let mut builder =
+            WriteBuilder::new(table.0.log_store(), table.0.state).with_save_mode(save_mode);
 
-        let mut builder = table.write(batches).with_save_mode(save_mode);
-        if let Some(schema_mode) = schema_mode {
-            builder = builder.with_schema_mode(schema_mode.parse().map_err(PythonError::from)?);
+        if let Some(ref schema_mode) = schema_mode {
+            let schema_mode = schema_mode.parse().map_err(PythonError::from)?;
+            builder = builder.with_schema_mode(schema_mode);
         }
+
+        if schema_mode == Some("merge".into()) {
+            builder = builder.with_input_batches(data.0.map(|batch| batch.unwrap()));
+        } else {
+            use deltalake::datafusion::datasource::provider_as_source;
+            use deltalake::datafusion::logical_expr::LogicalPlanBuilder;
+            let table_provider = crate::write::to_lazy_table(data.0).map_err(PythonError::from)?;
+
+            let plan = LogicalPlanBuilder::scan("source", provider_as_source(table_provider), None)
+                .expect("XXX: Failed to make scan")
+                .build()
+                .expect("XXX Failed to build logical plan");
+            builder = builder.with_input_execution_plan(Arc::new(plan));
+        }
 
         if let Some(partition_columns) = partition_by {
             builder = builder.with_partition_columns(partition_columns);
         }
```
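The `LogicalPlanBuilder::scan(...)` call above is the standard DataFusion way to put any `TableProvider` behind a logical plan: wrap the provider with `provider_as_source` and scan it by name. A self-contained sketch of the same pattern, using a `MemTable` in place of the PR's `LazyTableProvider` and the `datafusion` crate directly rather than the `deltalake::datafusion` re-export:

```rust
use std::sync::Arc;

use datafusion::arrow::array::Int64Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::{provider_as_source, MemTable};
use datafusion::error::Result;
use datafusion::logical_expr::LogicalPlanBuilder;

fn main() -> Result<()> {
    // Any TableProvider works here; the PR substitutes a LazyTableProvider
    // so the scan pulls batches on demand instead of from a buffered Vec.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
    )?;
    let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?);

    // scan(name, source, projection) yields a builder; build() produces the
    // plan that the PR hands to WriteBuilder via with_input_execution_plan.
    let plan = LogicalPlanBuilder::scan("source", provider_as_source(provider), None)?.build()?;
    println!("{}", plan.display_indent());
    Ok(())
}
```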
43 changes: 1 addition & 42 deletions python/src/merge.rs

```diff
@@ -15,12 +15,12 @@ use deltalake::{DeltaResult, DeltaTable};
 use parking_lot::RwLock;
 use pyo3::prelude::*;
 use std::collections::HashMap;
-use std::fmt::{self};
 use std::future::IntoFuture;
 use std::sync::{Arc, Mutex};
 
 use crate::error::PythonError;
 use crate::utils::rt;
+use crate::write::ArrowStreamBatchGenerator;
 use crate::{
     maybe_create_commit_properties, set_writer_properties, PyCommitProperties,
     PyPostCommitHookProperties, PyWriterProperties,
@@ -37,47 +37,6 @@ pub(crate) struct PyMergeBuilder {
     merge_schema: bool,
     arrow_schema: Arc<ArrowSchema>,
 }
-#[derive(Debug)]
-struct ArrowStreamBatchGenerator {
-    pub array_stream: Arc<Mutex<ArrowArrayStreamReader>>,
-}
-
-impl fmt::Display for ArrowStreamBatchGenerator {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(
-            f,
-            "ArrowStreamBatchGenerator {{ array_stream: {:?} }}",
-            self.array_stream
-        )
-    }
-}
-
-impl ArrowStreamBatchGenerator {
-    fn new(array_stream: Arc<Mutex<ArrowArrayStreamReader>>) -> Self {
-        Self { array_stream }
-    }
-}
-
-impl LazyBatchGenerator for ArrowStreamBatchGenerator {
-    fn generate_next_batch(
-        &mut self,
-    ) -> deltalake::datafusion::error::Result<Option<deltalake::arrow::array::RecordBatch>> {
-        let mut stream_reader = self.array_stream.lock().map_err(|_| {
-            deltalake::datafusion::error::DataFusionError::Execution(
-                "Failed to lock the ArrowArrayStreamReader".to_string(),
-            )
-        })?;
-
-        match stream_reader.next() {
-            Some(Ok(record_batch)) => Ok(Some(record_batch)),
-            Some(Err(err)) => Err(deltalake::datafusion::error::DataFusionError::ArrowError(
-                err, None,
-            )),
-            None => Ok(None), // End of stream
-        }
-    }
-}
 
 impl PyMergeBuilder {
     #[allow(clippy::too_many_arguments)]
     pub fn new(
```
66 changes: 66 additions & 0 deletions python/src/write.rs (new file)

```diff
@@ -0,0 +1,66 @@
+//! The write module contains shared code used for writes by the write_to_deltalake function and
+//! the merge code
+use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
+use deltalake::datafusion::catalog::TableProvider;
+use deltalake::datafusion::physical_plan::memory::LazyBatchGenerator;
+use deltalake::delta_datafusion::LazyTableProvider;
+use deltalake::DeltaResult;
+use parking_lot::RwLock;
+use std::fmt::{self};
+use std::sync::{Arc, Mutex};
+
+/// Convert an [ArrowArrayStreamReader] into a [LazyTableProvider]
+pub(crate) fn to_lazy_table(source: ArrowArrayStreamReader) -> DeltaResult<Arc<dyn TableProvider>> {
+    use deltalake::arrow::array::RecordBatchReader;
+    let schema = source.schema();
+    let arrow_stream: Arc<Mutex<ArrowArrayStreamReader>> = Arc::new(Mutex::new(source));
+    let arrow_stream_batch_generator: Arc<RwLock<dyn LazyBatchGenerator>> =
+        Arc::new(RwLock::new(ArrowStreamBatchGenerator::new(arrow_stream)));
+
+    Ok(Arc::new(LazyTableProvider::try_new(
+        schema.clone(),
+        vec![arrow_stream_batch_generator],
+    )?))
+}
+
+#[derive(Debug)]
+pub(crate) struct ArrowStreamBatchGenerator {
+    pub array_stream: Arc<Mutex<ArrowArrayStreamReader>>,
+}
+
+impl fmt::Display for ArrowStreamBatchGenerator {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "ArrowStreamBatchGenerator {{ array_stream: {:?} }}",
+            self.array_stream
+        )
+    }
+}
+
+impl ArrowStreamBatchGenerator {
+    pub fn new(array_stream: Arc<Mutex<ArrowArrayStreamReader>>) -> Self {
+        Self { array_stream }
+    }
+}
+
+impl LazyBatchGenerator for ArrowStreamBatchGenerator {
+    fn generate_next_batch(
+        &mut self,
+    ) -> deltalake::datafusion::error::Result<Option<deltalake::arrow::array::RecordBatch>> {
+        let mut stream_reader = self.array_stream.lock().map_err(|_| {
+            deltalake::datafusion::error::DataFusionError::Execution(
+                "Failed to lock the ArrowArrayStreamReader".to_string(),
+            )
+        })?;
+
+        match stream_reader.next() {
+            Some(Ok(record_batch)) => Ok(Some(record_batch)),
+            Some(Err(err)) => Err(deltalake::datafusion::error::DataFusionError::ArrowError(
+                err, None,
+            )),
+            None => Ok(None), // End of stream
+        }
+    }
+}
```
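This hoists the generator that previously lived privately in `merge.rs` into a shared module. Note the two layers of locking: the generator sits behind a `parking_lot::RwLock` (the shape `LazyTableProvider` expects), while the FFI reader inside it sits behind a `std::sync::Mutex`, whose poisoning error is mapped to a `DataFusionError::Execution`. The `LazyBatchGenerator` contract itself is small: return `Ok(Some(batch))` while data remains and `Ok(None)` at end of stream. A toy generator (illustrative only; `VecBatchGenerator` is not part of the PR) shows the same contract without the FFI stream:

```rust
use std::fmt;

use datafusion::arrow::array::RecordBatch;
use datafusion::error::Result;
use datafusion::physical_plan::memory::LazyBatchGenerator;

#[derive(Debug)]
struct VecBatchGenerator {
    batches: Vec<RecordBatch>,
}

impl fmt::Display for VecBatchGenerator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "VecBatchGenerator({} batches pending)", self.batches.len())
    }
}

impl LazyBatchGenerator for VecBatchGenerator {
    // Polled repeatedly during execution; Ok(None) signals end of stream,
    // mirroring ArrowStreamBatchGenerator above.
    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
        if self.batches.is_empty() {
            Ok(None)
        } else {
            Ok(Some(self.batches.remove(0)))
        }
    }
}

fn main() -> Result<()> {
    let mut generator = VecBatchGenerator { batches: vec![] };
    assert!(generator.generate_next_batch()?.is_none()); // empty source ends immediately
    Ok(())
}
```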
14 changes: 0 additions & 14 deletions python/tests/test_writer.py

```diff
@@ -1589,13 +1589,6 @@ def test_schema_cols_diff_order(tmp_path: pathlib.Path, engine):
     assert dt.to_pyarrow_table(columns=["baz", "bar", "foo"]) == expected
 
 
-def test_empty(existing_table: DeltaTable):
-    schema = existing_table.schema().to_pyarrow()
-    empty_table = pa.Table.from_pylist([], schema=schema)
-    with pytest.raises(DeltaError, match="No data source supplied to write command"):
-        write_deltalake(existing_table, empty_table, mode="append", engine="rust")
-
-
 def test_rust_decimal_cast(tmp_path: pathlib.Path):
     import re
 
@@ -1813,13 +1806,6 @@ def test_roundtrip_cdc_evolution(tmp_path: pathlib.Path):
     assert os.path.isdir(os.path.join(tmp_path, "_change_data"))
 
 
-def test_empty_dataset_write(tmp_path: pathlib.Path, sample_data: pa.Table):
-    empty_arrow_table = sample_data.schema.empty_table()
-    empty_dataset = dataset(empty_arrow_table)
-    with pytest.raises(DeltaError, match="No data source supplied to write command"):
-        write_deltalake(tmp_path, empty_dataset, mode="append")
-
-
 @pytest.mark.pandas
 def test_predicate_out_of_bounds(tmp_path: pathlib.Path):
     """See <https://github.com/delta-io/delta-rs/issues/2867>"""
```