diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi
index d0f020cbbf..acac2c2c1f 100644
--- a/python/deltalake/_internal.pyi
+++ b/python/deltalake/_internal.pyi
@@ -225,6 +225,13 @@ class RawDeltaTable:
         allow_out_of_range: bool = False,
     ) -> pyarrow.RecordBatchReader: ...
     def transaction_versions(self) -> Dict[str, Transaction]: ...
+    def set_column_metadata(
+        self,
+        field_name: str,
+        metadata: dict[str, str],
+        commit_properties: Optional[CommitProperties],
+        post_commithook_properties: Optional[PostCommitHookProperties],
+    ) -> None: ...
     def __datafusion_table_provider__(self) -> Any: ...
 
 def rust_core_version() -> str: ...
diff --git a/python/deltalake/table.py b/python/deltalake/table.py
index a8a8c72805..92ec979cd7 100644
--- a/python/deltalake/table.py
+++ b/python/deltalake/table.py
@@ -2167,6 +2167,28 @@ def set_table_properties(
             commit_properties,
         )
 
+    def set_column_metadata(
+        self,
+        column: str,
+        metadata: dict[str, str],
+        commit_properties: Optional[CommitProperties] = None,
+        post_commithook_properties: Optional[PostCommitHookProperties] = None,
+    ) -> None:
+        """
+        Set metadata on a column of the table schema; keys that do not exist yet are inserted.
+
+        An error is raised if the column name does not exist in the schema.
+
+        :param column: name of the column to update metadata for.
+        :param metadata: the metadata to be added or modified on the column.
+        :param commit_properties: properties of the transaction commit. If None, default values are used.
+        :param post_commithook_properties: properties for the post commit hook. If None, default values are used.
+        :raises DeltaError: if the column does not exist in the table schema.
+        """
+        self.table._table.set_column_metadata(
+            column, metadata, commit_properties, post_commithook_properties
+        )
+
 
 class TableOptimizer:
     """API for various table optimization commands."""
diff --git a/python/src/lib.rs b/python/src/lib.rs
index fa01e28fbf..3c36da9ddf 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -18,7 +18,7 @@ use arrow::pyarrow::PyArrowType;
 use chrono::{DateTime, Duration, FixedOffset, Utc};
 use datafusion_ffi::table_provider::FFI_TableProvider;
 use delta_kernel::expressions::Scalar;
-use delta_kernel::schema::StructField;
+use delta_kernel::schema::{MetadataValue, StructField};
 use deltalake::arrow::compute::concat_batches;
 use deltalake::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
 use deltalake::arrow::record_batch::{RecordBatch, RecordBatchIterator};
@@ -63,13 +63,6 @@ use error::DeltaError;
 use futures::future::join_all;
 use tracing::log::*;
 
-use pyo3::exceptions::{PyRuntimeError, PyValueError};
-use pyo3::prelude::*;
-use pyo3::pybacked::PyBackedStr;
-use pyo3::types::{PyCapsule, PyDict, PyFrozenSet};
-use serde_json::{Map, Value};
-use uuid::Uuid;
-
 use crate::error::DeltaProtocolError;
 use crate::error::PythonError;
 use crate::features::TableFeatures;
@@ -78,6 +71,14 @@ use crate::merge::PyMergeBuilder;
 use crate::query::PyQueryBuilder;
 use crate::schema::{schema_to_pyobject, Field};
 use crate::utils::rt;
+use deltalake::operations::update_field_metadata::UpdateFieldMetadataBuilder;
+use deltalake::protocol::DeltaOperation::UpdateFieldMetadata;
+use pyo3::exceptions::{PyRuntimeError, PyValueError};
+use pyo3::prelude::*;
+use pyo3::pybacked::PyBackedStr;
+use pyo3::types::{PyCapsule, PyDict, PyFrozenSet};
+use serde_json::{Map, Value};
+use uuid::Uuid;
 
 #[cfg(all(target_family = "unix", not(target_os = "emscripten")))]
 use jemallocator::Jemalloc;
@@ -1521,6 +1522,47 @@ impl RawDeltaTable {
         }
     }
 
+    /// Update the metadata of the schema field `field_name`.
+    ///
+    /// Each entry of `metadata` is passed to the update operation as a `MetadataValue::String`;
+    /// afterwards the in-memory table state is replaced with the state returned by the operation.
+    #[pyo3(signature = (field_name, metadata, commit_properties=None, post_commithook_properties=None))]
+    pub fn set_column_metadata(
+        &self,
+        py: Python,
+        field_name: &str,
+        metadata: HashMap<String, String>,
+        commit_properties: Option<PyCommitProperties>,
+        post_commithook_properties: Option<PyPostCommitHookProperties>,
+    ) -> PyResult<()> {
+        let table = py.allow_threads(|| {
+            let mut cmd = UpdateFieldMetadataBuilder::new(self.log_store()?, self.cloned_state()?);
+
+            cmd = cmd.with_field_name(field_name).with_metadata(
+                metadata
+                    .iter()
+                    .map(|(k, v)| (k.clone(), MetadataValue::String(v.clone())))
+                    .collect(),
+            );
+
+            if let Some(commit_properties) =
+                maybe_create_commit_properties(commit_properties, post_commithook_properties)
+            {
+                cmd = cmd.with_commit_properties(commit_properties)
+            }
+
+            if self.log_store()?.name() == "LakeFSLogStore" {
+                cmd = cmd.with_custom_execute_handler(Arc::new(LakeFSCustomExecuteHandler {}))
+            }
+
+            rt().block_on(cmd.into_future())
+                .map_err(PythonError::from)
+                .map_err(PyErr::from)
+        })?;
+        self.set_state(table.state)?;
+        Ok(())
+    }
+
     fn __datafusion_table_provider__<'py>(
         &self,
         py: Python<'py>,
diff --git a/python/tests/test_alter.py b/python/tests/test_alter.py
index acc37db822..7e244c4e83 100644
--- a/python/tests/test_alter.py
+++ b/python/tests/test_alter.py
@@ -457,3 +457,18 @@ def test_add_feautres(existing_sample_table: DeltaTable):
             "v2Checkpoint",
         ]
     )  # type: ignore
+
+
+def test_set_column_metadata(tmp_path: pathlib.Path, sample_table: pa.Table):
+    write_deltalake(tmp_path, sample_table)
+
+    dt = DeltaTable(tmp_path)
+
+    dt.alter.set_column_metadata("price", {"comment": "my comment"})
+
+    fields_by_name = {field.name: field for field in dt.schema().fields}
+    assert fields_by_name["price"].metadata == {"comment": "my comment"}
+
+    with pytest.raises(DeltaError):
+        # Can't set metadata for non existing column.
+        dt.alter.set_column_metadata("non_existing_column", {"comment": "my comment"})