Skip to content

Commit

Permalink
feat: Adds list constructor to Expression and SQL APIs (#3737)
Browse files Browse the repository at this point in the history
This PR adds the ability to "merge" multiple series into a series of
lists which enables construction of lists via _daft expressions_ rather
than _python expressions_ via lit.
  • Loading branch information
rchowell authored Jan 31, 2025
1 parent 47a2ece commit 63ffd5e
Show file tree
Hide file tree
Showing 20 changed files with 498 additions and 74 deletions.
3 changes: 2 additions & 1 deletion daft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def refresh_logger() -> None:
from daft.dataframe import DataFrame
from daft.logical.schema import Schema
from daft.datatype import DataType, TimeUnit
from daft.expressions import Expression, col, lit, interval, coalesce
from daft.expressions import Expression, col, list_, lit, interval, coalesce
from daft.io import (
DataCatalogTable,
DataCatalogType,
Expand Down Expand Up @@ -116,6 +116,7 @@ def refresh_logger() -> None:
"from_pylist",
"from_ray_dataset",
"interval",
"list_",
"lit",
"planning_config_ctx",
"read_csv",
Expand Down
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,7 @@ class PyExpr:
def eq(expr1: PyExpr, expr2: PyExpr) -> bool: ...
def col(name: str) -> PyExpr: ...
def lit(item: Any) -> PyExpr: ...
def list_(items: list[PyExpr]) -> PyExpr: ...
def date_lit(item: int) -> PyExpr: ...
def time_lit(item: int, tu: PyTimeUnit) -> PyExpr: ...
def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ...
Expand Down
4 changes: 2 additions & 2 deletions daft/expressions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations

from .expressions import Expression, ExpressionsProjection, col, lit, interval, coalesce
from .expressions import Expression, ExpressionsProjection, col, list_, lit, interval, coalesce

__all__ = ["Expression", "ExpressionsProjection", "coalesce", "col", "interval", "lit"]
__all__ = ["Expression", "ExpressionsProjection", "coalesce", "col", "interval", "list_", "lit"]
32 changes: 32 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,38 @@ def col(name: str) -> Expression:
return Expression._from_pyexpr(_col(name))


def list_(*items: Expression | str):
"""Constructs a list from the item expressions.
Example:
>>> import daft
>>> df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})
>>> df = df.select(daft.list_("x", "y").alias("fwd"), daft.list_("y", "x").alias("rev"))
>>> df.show()
╭─────────────┬─────────────╮
│ fwd ┆ rev │
│ --- ┆ --- │
│ List[Int64] ┆ List[Int64] │
╞═════════════╪═════════════╡
│ [1, 4] ┆ [4, 1] │
├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [2, 5] ┆ [5, 2] │
├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ [3, 6] ┆ [6, 3] │
╰─────────────┴─────────────╯
<BLANKLINE>
(Showing first 3 of 3 rows)
Args:
*items (Union[Expression, str]): item expressions to construct the list
Returns:
Expression: Expression representing the constructed list
"""
assert len(items) > 0, "List constructor requires at least one item"
return Expression._from_pyexpr(native.list_([col(i)._expr if isinstance(i, str) else i._expr for i in items]))


def interval(
years: int | None = None,
months: int | None = None,
Expand Down
1 change: 1 addition & 0 deletions docs/sphinx/source/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Constructors

col
lit
list_

Generic
#######
Expand Down
5 changes: 5 additions & 0 deletions src/daft-core/src/array/growable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ pub trait Growable {
/// Extends this [`Growable`] with null elements
fn add_nulls(&mut self, additional: usize);

/// Extends this [`Growable`] with null elements (same as add_nulls with arrow naming convention).
fn extend_nulls(&mut self, len: usize) {
self.add_nulls(len);
}

/// Builds an array from the [`Growable`]
fn build(&mut self) -> DaftResult<Series>;
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub mod take;
pub mod time;
mod trigonometry;
pub mod utf8;
pub mod zip;

pub fn cast_series_to_supertype(series: &[&Series]) -> DaftResult<Vec<Series>> {
let supertype = series
Expand Down
90 changes: 90 additions & 0 deletions src/daft-core/src/series/ops/zip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use std::cmp::{max, min};

use arrow2::offset::Offsets;
use common_error::{DaftError, DaftResult};
use daft_schema::{dtype::DataType, field::Field};

use crate::{
array::{growable::make_growable, ListArray},
series::{IntoSeries, Series},
};

impl Series {
/// Zips series into a single series of lists.
/// ex:
/// ```text
/// A: Series := ( a_0, a_1, .. , a_n )
/// B: Series := ( b_0, b_1, .. , b_n )
/// C: Series := Zip(A, B) <-> ( [a_0, b_0], [a_1, b_1], [a_2, b_2] )
/// ```
pub fn zip(field: Field, series: &[&Self]) -> DaftResult<Self> {
// err if no series to zip
if series.is_empty() {
return Err(DaftError::ValueError(
"Need at least 1 series to perform zip".to_string(),
));
}

// homogeneity checks naturally happen in make_growable's downcast.
let dtype = match &field.dtype {
DataType::List(dtype) => dtype.as_ref(),
DataType::FixedSizeList(..) => {
return Err(DaftError::ValueError(
"Fixed size list constructor is currently not supported".to_string(),
));
}
_ => {
return Err(DaftError::ValueError(
"Cannot zip field with non-list type".to_string(),
));
}
};

// 0 -> index of child in 'arrays' vector
// 1 -> last index of child
type Child = (usize, usize);

// build a null series mask so we can skip making full_nulls and avoid downcast "Null to T" errors.
let mut mask: Vec<Option<Child>> = vec![];
let mut rows = 0;
let mut capacity = 0;
let mut arrays = vec![];

for s in series {
let len = s.len();
if is_null(s) {
mask.push(None);
} else {
mask.push(Some((arrays.len(), len - 1)));
arrays.push(*s);
}
rows = max(rows, len);
capacity += len;
}

// initialize a growable child
let mut offsets = Offsets::<i64>::with_capacity(capacity);
let mut child = make_growable("list", dtype, arrays, true, capacity);
let sublist_len = series.len() as i64;

// merge each series based upon the mask
for row in 0..rows {
for i in &mask {
if let Some((i, end)) = *i {
child.extend(i, min(row, end), 1);
} else {
child.extend_nulls(1);
}
}
offsets.try_push(sublist_len)?;
}

// create the outer array with offsets
Ok(ListArray::new(field, child.build()?, offsets.into(), None).into_series())
}
}

/// Same null check logic as in Series::concat, but may need an audit since there are other is_null impls.
fn is_null(series: &&Series) -> bool {
series.data_type() == &DataType::Null
}
53 changes: 53 additions & 0 deletions src/daft-dsl/src/expr/display.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use std::fmt::Write;

use itertools::Itertools;

use super::{Expr, ExprRef, Operator};

/// Display for Expr::BinaryOp
pub fn expr_binary_op_display_without_formatter(
op: &Operator,
left: &ExprRef,
right: &ExprRef,
) -> std::result::Result<String, std::fmt::Error> {
let mut f = String::default();
let write_out_expr = |f: &mut String, input: &Expr| match input {
Expr::Alias(e, _) => write!(f, "{e}"),
Expr::BinaryOp { .. } => write!(f, "[{input}]"),
_ => write!(f, "{input}"),
};
write_out_expr(&mut f, left)?;
write!(&mut f, " {op} ")?;
write_out_expr(&mut f, right)?;
Ok(f)
}

/// Display for Expr::IsIn
pub fn expr_is_in_display_without_formatter(
expr: &ExprRef,
inputs: &[ExprRef],
) -> std::result::Result<String, std::fmt::Error> {
let mut f = String::default();
write!(&mut f, "{expr} IN (")?;
for (i, input) in inputs.iter().enumerate() {
if i != 0 {
write!(&mut f, ", ")?;
}
write!(&mut f, "{input}")?;
}
write!(&mut f, ")")?;
Ok(f)
}

/// Display for Expr::List
pub fn expr_list_display_without_formatter(
items: &[ExprRef],
) -> std::result::Result<String, std::fmt::Error> {
let mut f = String::default();
write!(
&mut f,
"list({})",
items.iter().map(|x| x.to_string()).join(", ")
)?;
Ok(f)
}
Loading

0 comments on commit 63ffd5e

Please sign in to comment.