Commit dca745b

feat: support add all null column as metadata-only operation via sql (#3504)
Adds support for adding an all-null column via SQL. If the user passes:

```rs
dataset.add_columns(NewColumnTransform::SqlExpressions(vec![("new_col", "CAST(NULL AS int)")]));
```

we'll detect that the intention is to create an all-null column and optimize the transform to:

```rs
dataset.add_columns(NewColumnTransform::AllNulls(Arc::new(Schema::new(vec![
    Field::new("new_col", DataType::Int32, true),
]))));
```

The motivation is to expose the ability to add an all-null column as a metadata-only operation through the LanceDB SDKs, which currently only support passing SQL expressions. A different option would have been to modify the arguments to the Python `table.add_column` and TypeScript `table.addColumn`, but that seemed like more work, so I wanted to propose this solution first.
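For context, here is a minimal Python sketch of how this surfaces to SDK users, adapted from the test added in this commit (the dataset path is illustrative):

```python
import lance
import pyarrow as pa

tab = pa.table({"a": range(100)})
dataset = lance.write_dataset(
    tab, "/tmp/demo.lance", max_rows_per_file=50, data_storage_version="stable"
)

fragments_before = dataset.get_fragments()
# The all-NULL SQL cast is rewritten to an AllNulls transform, so no data files are written.
dataset.add_columns({"b": "CAST(NULL AS INT)"})
fragments_after = dataset.get_fragments()

# Metadata-only: the fragments and their data files are unchanged.
assert [f.data_files() for f in fragments_before] == [f.data_files() for f in fragments_after]
assert dataset.schema == pa.schema({"a": pa.int64(), "b": pa.int32()})
```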
1 parent a144028 commit dca745b

3 files changed: +307 −12 lines changed


python/python/tests/test_schema_evolution.py (+28)
@@ -512,3 +512,31 @@ def some_udf(batch):
 
     with pytest.raises(ValueError, match="A checkpoint file cannot be used"):
         frag.merge_columns(some_udf, columns=["a"])
+
+
+def test_add_cols_all_null_with_sql(tmp_path: Path):
+    tab = pa.table(
+        {
+            "a": range(100),
+        }
+    )
+    dataset = lance.write_dataset(
+        tab, tmp_path, max_rows_per_file=50, data_storage_version="stable"
+    )
+    fragments_before = dataset.get_fragments()
+    dataset.add_columns({"b": "CAST(NULL AS INT)"})
+    fragments_after = dataset.get_fragments()
+
+    # assert this was a metadata only operation and no data was written
+    assert len(fragments_before) == len(fragments_after)
+    for frag_before, frag_after in zip(fragments_before, fragments_after):
+        assert frag_before.fragment_id == frag_after.fragment_id
+        assert frag_before.data_files() == frag_after.data_files()
+
+    # assert the schema is as expected
+    assert dataset.schema == pa.schema(
+        {
+            "a": pa.int64(),
+            "b": pa.int32(),
+        }
+    )
rust/lance/src/dataset/schema_evolution.rs (+126 −12)
@@ -13,7 +13,6 @@ use futures::stream::{StreamExt, TryStreamExt};
 use lance_arrow::SchemaExt;
 use lance_core::datatypes::{Field, Schema};
 use lance_datafusion::utils::StreamingWriteSource;
-use lance_encoding::version::LanceFileVersion;
 use lance_table::format::Fragment;
 use snafu::location;
 
@@ -23,6 +22,12 @@ use super::{
     Dataset,
 };
 
+mod optimize;
+
+use optimize::{
+    ChainedNewColumnTransformOptimizer, NewColumnTransformOptimizer, SqlToAllNullsOptimizer,
+};
+
 #[derive(Debug, Clone, PartialEq)]
 pub struct BatchInfo {
     pub fragment_id: u32,
@@ -149,6 +154,14 @@ pub(super) async fn add_columns_to_fragments(
        Ok(())
    };
 
+    // Optimize the transforms
+    let mut optimizer = ChainedNewColumnTransformOptimizer::new(vec![]);
+    // The AllNulls transform cannot be performed on legacy files
+    if !dataset.is_legacy_storage() {
+        optimizer.add_optimizer(Box::new(SqlToAllNullsOptimizer::new()));
+    }
+    let transforms = optimizer.optimize(dataset, transforms)?;
+
    let (output_schema, fragments) = match transforms {
        NewColumnTransform::BatchUDF(udf) => {
            check_names(udf.output_schema.as_ref())?;
@@ -262,17 +275,7 @@ pub(super) async fn add_columns_to_fragments(
            // can't add all-null columns as a metadata-only operation. The reason is because we
            // use the NullReader for fragments that have missing columns and we can't mix legacy
            // and non-legacy readers when reading the fragment.
-            if fragments.iter().any(|fragment| {
-                fragment.files.iter().any(|file| {
-                    matches!(
-                        LanceFileVersion::try_from_major_minor(
-                            file.file_major_version,
-                            file.file_minor_version
-                        ),
-                        Ok(LanceFileVersion::Legacy)
-                    )
-                })
-            }) {
+            if dataset.is_legacy_storage() {
                return Err(Error::NotSupported {
                    source: "Cannot add all-null columns to legacy dataset version.".into(),
                    location: location!(),
@@ -1744,4 +1747,115 @@ mod test {
 
        Ok(())
    }
+
+    #[tokio::test]
+    async fn test_new_column_sql_to_all_nulls_transform_optimizer() {
+        let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+            "a",
+            DataType::Int32,
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(Int32Array::from_iter(0..100))],
+        )
+        .unwrap();
+        let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
+        let test_dir = tempfile::tempdir().unwrap();
+        let test_uri = test_dir.path().to_str().unwrap();
+        let mut dataset = Dataset::write(
+            reader,
+            test_uri,
+            Some(WriteParams {
+                max_rows_per_file: 50,
+                max_rows_per_group: 25,
+                data_storage_version: Some(LanceFileVersion::Stable),
+                ..Default::default()
+            }),
+        )
+        .await
+        .unwrap();
+        dataset.validate().await.unwrap();
+
+        let manifest_before = dataset.manifest.clone();
+
+        // Add all null column
+        dataset
+            .add_columns(
+                NewColumnTransform::SqlExpressions(vec![(
+                    "b".to_string(),
+                    "CAST(NULL AS int)".to_string(),
+                )]),
+                None,
+                None,
+            )
+            .await
+            .unwrap();
+        let manifest_after = dataset.manifest.clone();
+
+        // Check that this is a metadata-only operation (the fragments don't change)
+        assert_eq!(&manifest_before.fragments, &manifest_after.fragments);
+
+        // check that the new field was added to the schema
+        let expected_schema = ArrowSchema::new(vec![
+            ArrowField::new("a", DataType::Int32, false),
+            ArrowField::new("b", DataType::Int32, true),
+        ]);
+        assert_eq!(ArrowSchema::from(dataset.schema()), expected_schema);
+    }
+
+    #[tokio::test]
+    async fn test_new_column_sql_to_all_nulls_transform_optimizer_legacy() {
+        let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
+            "a",
+            DataType::Int32,
+            false,
+        )]));
+
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(Int32Array::from_iter(0..100))],
+        )
+        .unwrap();
+        let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
+        let test_dir = tempfile::tempdir().unwrap();
+        let test_uri = test_dir.path().to_str().unwrap();
+        let mut dataset = Dataset::write(
+            reader,
+            test_uri,
+            Some(WriteParams {
+                max_rows_per_file: 50,
+                max_rows_per_group: 25,
+                data_storage_version: Some(LanceFileVersion::Legacy),
+                ..Default::default()
+            }),
+        )
+        .await
+        .unwrap();
+        dataset.validate().await.unwrap();
+
+        // Add all null column ...
+        // This is basically a smoke test to ensure we don't try to use the all-nulls
+        // transform optimizer where it's not supported, and then blow up when we try
+        // to apply the transform
+        dataset
+            .add_columns(
+                NewColumnTransform::SqlExpressions(vec![(
+                    "b".to_string(),
+                    "CAST(NULL AS int)".to_string(),
+                )]),
+                None,
+                None,
+            )
+            .await
+            .unwrap();
+
+        // check that the new field was added to the schema
+        let expected_schema = ArrowSchema::new(vec![
+            ArrowField::new("a", DataType::Int32, false),
+            ArrowField::new("b", DataType::Int32, true),
+        ]);
+        assert_eq!(ArrowSchema::from(dataset.schema()), expected_schema);
+    }
 }
rust/lance/src/dataset/schema_evolution/optimize.rs (new file, +153)

@@ -0,0 +1,153 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+use std::sync::Arc;
+
+use arrow_schema::{DataType, Field, Schema};
+use datafusion::prelude::Expr;
+use datafusion::scalar::ScalarValue;
+use lance_datafusion::planner::Planner;
+
+use crate::error::Result;
+use crate::Dataset;
+
+use super::NewColumnTransform;
+
+/// Optimizes a `NewColumnTransform` into a more efficient form.
+pub(super) trait NewColumnTransformOptimizer: Send + Sync {
+    /// Optimize the passed `NewColumnTransform` to a more efficient form.
+    fn optimize(
+        &self,
+        dataset: &Dataset,
+        transform: NewColumnTransform,
+    ) -> Result<NewColumnTransform>;
+}
+
+/// A `NewColumnTransformOptimizer` that chains multiple `NewColumnTransformOptimizer`s together.
+pub(super) struct ChainedNewColumnTransformOptimizer {
+    optimizers: Vec<Box<dyn NewColumnTransformOptimizer>>,
+}
+
+impl ChainedNewColumnTransformOptimizer {
+    pub(super) fn new(optimizers: Vec<Box<dyn NewColumnTransformOptimizer>>) -> Self {
+        Self { optimizers }
+    }
+
+    pub(super) fn add_optimizer(&mut self, optimizer: Box<dyn NewColumnTransformOptimizer>) {
+        self.optimizers.push(optimizer);
+    }
+}
+
+/// A `NewColumnTransformOptimizer` that chains multiple `NewColumnTransformOptimizer`s together.
+impl NewColumnTransformOptimizer for ChainedNewColumnTransformOptimizer {
+    fn optimize(
+        &self,
+        dataset: &Dataset,
+        transform: NewColumnTransform,
+    ) -> Result<NewColumnTransform> {
+        let mut transform = transform;
+        for optimizer in &self.optimizers {
+            transform = optimizer.optimize(dataset, transform)?;
+        }
+        Ok(transform)
+    }
+}
+
+/// Optimizes a `NewColumnTransform` that is a SQL expression to a `NewColumnTransform::AllNulls` if
+/// the SQL expression is "NULL". For example
+/// `NewColumnTransform::SqlExpressions(vec![("new_col", "CAST(NULL AS int)")])`
+/// would be optimized to
+/// `NewColumnTransform::AllNulls(Schema::new(vec![Field::new("new_col", DataType::Int)]))`.
+///
+pub(super) struct SqlToAllNullsOptimizer;
+
+impl SqlToAllNullsOptimizer {
+    pub(super) fn new() -> Self {
+        Self
+    }
+
+    fn is_all_null(&self, expr: &Expr) -> AllNullsResult {
+        match expr {
+            Expr::Cast(cast) => {
+                if matches!(cast.expr.as_ref(), Expr::Literal(ScalarValue::Null)) {
+                    let data_type = cast.data_type.clone();
+                    AllNullsResult::AllNulls(data_type)
+                } else {
+                    AllNullsResult::NotAllNulls
+                }
+            }
+            _ => AllNullsResult::NotAllNulls,
+        }
+    }
+}
+
+enum AllNullsResult {
+    AllNulls(DataType),
+    NotAllNulls,
+}
+
+impl NewColumnTransformOptimizer for SqlToAllNullsOptimizer {
+    fn optimize(
+        &self,
+        dataset: &Dataset,
+        transform: NewColumnTransform,
+    ) -> Result<NewColumnTransform> {
+        match &transform {
+            NewColumnTransform::SqlExpressions(expressions) => {
+                let arrow_schema = Arc::new(Schema::from(dataset.schema()));
+                let planner = Planner::new(arrow_schema);
+                let mut all_null_schema_fields = vec![];
+                for (name, expr) in expressions {
+                    let expr = planner.parse_expr(expr)?;
+                    if let AllNullsResult::AllNulls(data_type) = self.is_all_null(&expr) {
+                        let field = Field::new(name, data_type, true);
+                        all_null_schema_fields.push(field);
+                    } else {
+                        return Ok(transform);
+                    }
+                }
+
+                let all_null_schema = Schema::new(all_null_schema_fields);
+                Ok(NewColumnTransform::AllNulls(Arc::new(all_null_schema)))
+            }
+            _ => Ok(transform),
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    use arrow_array::RecordBatchIterator;
+
+    #[tokio::test]
+    async fn test_sql_to_all_null_transform() {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
+        let empty_reader = RecordBatchIterator::new(vec![], schema.clone());
+        let dataset = Arc::new(
+            Dataset::write(empty_reader, "memory://", None)
+                .await
+                .unwrap(),
+        );
+
+        let original = NewColumnTransform::SqlExpressions(vec![
+            ("new_col1".to_string(), "CAST(NULL AS int)".to_string()),
+            ("new_col2".to_string(), "CAST(NULL AS bigint)".to_string()),
+        ]);
+
+        let optimizer = SqlToAllNullsOptimizer::new();
+        let result = optimizer.optimize(&dataset, original).unwrap();
+
+        assert!(matches!(result, NewColumnTransform::AllNulls(_)));
+        if let NewColumnTransform::AllNulls(schema) = result {
+            assert_eq!(schema.fields().len(), 2);
+            assert_eq!(schema.field(0).name(), "new_col1");
+            assert_eq!(schema.field(0).data_type(), &DataType::Int32);
+            assert!(schema.field(0).is_nullable());
+            assert_eq!(schema.field(1).name(), "new_col2");
+            assert_eq!(schema.field(1).data_type(), &DataType::Int64);
+            assert!(schema.field(1).is_nullable());
+        }
+    }
+}
