diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 57dacdb580..46e3366aae 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2977,6 +2977,45 @@ class DataReplacement(BaseOperation): replacements: List[LanceOperation.DataReplacementGroup] + @dataclass + class Project(BaseOperation): + """ + Operation that project columns. + Use this operator for drop column or rename/swap column. + + Attributes + ---------- + schema: LanceSchema + The lance schema of the new dataset. + + Examples + -------- + Use the projece operator to swap column: + + >>> import lance + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> from lance.schema import LanceSchema + >>> table = pa.table({"a": [1, 2], "b": ["a", "b"], "b1": ["c", "d"]}) + >>> dataset = lance.write_dataset(table, "example") + >>> dataset.to_table().to_pandas() + a b b1 + 0 1 a c + 1 2 b d + >>> + >>> ## rename column `b` into `b0` and rename b1 into `b` + >>> table = pa.table({"a": [3, 4], "b0": ["a", "b"], "b": ["c", "d"]}) + >>> lance_schema = LanceSchema.from_pyarrow(table.schema) + >>> operation = lance.LanceOperation.Project(lance_schema) + >>> dataset = lance.LanceDataset.commit("example", operation, read_version=1) + >>> dataset.to_table().to_pandas() + a b0 b + 0 1 a c + 1 2 b d + """ + + schema: LanceSchema + class ScannerBuilder: def __init__(self, ds: LanceDataset): diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index d9ba40ce6a..8da516eec7 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -30,6 +30,7 @@ from lance._dataset.sharded_batch_iterator import ShardedBatchIterator from lance.commit import CommitConflictError from lance.debug import format_fragment +from lance.schema import LanceSchema from lance.util import validate_vector_index # Various valid inputs for write_dataset @@ -3023,6 +3024,74 @@ def test_data_replacement(tmp_path: Path): assert tbl == expected +def test_schema_project_drop_column(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100, 200), "b": range(300, 400)}) + base_dir = tmp_path / "test" + + dataset = lance.write_dataset(table, base_dir) + + schema = pa.Table.from_pydict({"a": range(1)}).schema + lance_schema = LanceSchema.from_pyarrow(schema) + + project = lance.LanceOperation.Project(lance_schema) + dataset = lance.LanceDataset.commit(dataset, project, read_version=1) + + tbl = dataset.to_table() + + expected = pa.Table.from_pydict( + { + "a": list(range(100, 200)), + } + ) + assert tbl == expected + + +def test_schema_project_rename_column(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100, 200), "b": range(300, 400)}) + base_dir = tmp_path / "test" + + dataset = lance.write_dataset(table, base_dir) + + schema = pa.Table.from_pydict({"c": range(1), "d": range(1)}).schema + lance_schema = LanceSchema.from_pyarrow(schema) + + project = lance.LanceOperation.Project(lance_schema) + dataset = lance.LanceDataset.commit(dataset, project, read_version=1) + + tbl = dataset.to_table() + + expected = pa.Table.from_pydict( + { + "c": list(range(100, 200)), + "d": list(range(300, 400)), + } + ) + assert tbl == expected + + +def test_schema_project_swap_column(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100, 200), "b": range(300, 400)}) + base_dir = tmp_path / "test" + + dataset = lance.write_dataset(table, base_dir) + + schema = pa.Table.from_pydict({"b": range(1), "a": range(1)}).schema + lance_schema = LanceSchema.from_pyarrow(schema) + + project = lance.LanceOperation.Project(lance_schema) + dataset = lance.LanceDataset.commit(dataset, project, read_version=1) + + tbl = dataset.to_table() + + expected = pa.Table.from_pydict( + { + "b": list(range(100, 200)), + "a": list(range(300, 400)), + } + ) + assert tbl == expected + + def test_empty_structs(tmp_path): schema = pa.schema([pa.field("id", pa.int32()), pa.field("empties", pa.struct([]))]) table = pa.table({"id": [0, 1, 2], "empties": [{}] * 3}, schema=schema) diff --git a/python/src/transaction.rs b/python/src/transaction.rs index 33ed60eaa5..b09087bd06 100644 --- a/python/src/transaction.rs +++ b/python/src/transaction.rs @@ -157,6 +157,12 @@ impl FromPyObject<'_> for PyLance { Ok(Self(op)) } + "Project" => { + let schema = extract_schema(&ob.getattr("schema")?)?; + + let op = Operation::Project { schema }; + Ok(Self(op)) + } unsupported => Err(PyValueError::new_err(format!( "Unsupported operation: {unsupported}", ))),