Review notes for this patch (descriptive text placed ahead of the first file header, as in
`git format-patch` output).

Fixes folded into the added lines:

* functions/string.rs: "encode" was registered with Utf8Endswith while "endswith" was left as
  a stub. The two registrations are swapped back: "endswith" now uses Utf8Endswith and
  "encode" is the stub.
* functions/core.rs: the added line registers "monotonically_increasing_id", which is the name
  Spark uses; the removed line keeps the old misspelling "monotically_increasing_id".
* functions/datetime.rs: "dayofmonth" is registered with Day, the same function already used
  for "day"; Spark documents both as the day of the month.

Open questions, not changed here:

* NOTE(review): Spark documents dayofweek as 1 (Sunday) through 7 (Saturday) and weekday as
  0 (Monday) through 6 (Sunday). Confirm which numbering daft_functions::temporal::DayOfWeek
  produces before keeping it bound to "dayofweek".
* NOTE(review): "length" is bound to Utf8LengthBytes and "octet_length" is a stub. Spark's
  length() counts characters for strings; going by the names this looks inverted - confirm.
* NOTE(review): "to_timestamp" casts to millisecond precision; Spark timestamps carry
  microseconds - confirm the unit is intended.

Reconstruction caveats:

* The hunks were re-broken into lines from a copy that had lost its newlines, indentation,
  blank context lines and angle-bracket type arguments. The `register::<...>` type arguments
  and `Lazy<SparkFunctions>` in functions.rs are inferred from the module and struct names -
  confirm against the target file.
* The core.rs hunk `@@ -32,27 +32,26 @@` appears to be four blank context lines short; their
  positions could not be recovered. Re-add them from the target file before applying.
* The blob ids on the `index` lines predate the three fixes above.

diff --git a/src/daft-connect/src/functions.rs b/src/daft-connect/src/functions.rs
index 444d67faae..d0351e3054 100644
--- a/src/daft-connect/src/functions.rs
+++ b/src/daft-connect/src/functions.rs
@@ -5,12 +5,13 @@ use daft_dsl::{
     ExprRef,
 };
 use once_cell::sync::Lazy;
-use partition_transform::PartitionTransformFunctions;
 use spark_connect::Expression;
 
 use crate::{error::ConnectResult, invalid_argument_err, spark_analyzer::SparkAnalyzer};
 mod aggregate;
 mod core;
+mod datetime;
+
 mod math;
 mod partition_transform;
 mod string;
@@ -19,8 +20,9 @@ pub(crate) static CONNECT_FUNCTIONS: Lazy<SparkFunctions> = Lazy::new(|| {
     let mut functions = SparkFunctions::new();
     functions.register::<aggregate::AggregateFunctions>();
     functions.register::<core::CoreFunctions>();
+    functions.register::<datetime::DatetimeFunctions>();
     functions.register::<math::MathFunctions>();
-    functions.register::<PartitionTransformFunctions>();
+    functions.register::<partition_transform::PartitionTransformFunctions>();
     functions.register::<string::StringFunctions>();
     functions
 });
@@ -104,8 +106,10 @@ impl SparkFunction for UnaryFunction {
     }
 }
 
-struct Todo;
-impl SparkFunction for Todo {
+#[allow(non_camel_case_types)]
+struct TODO_FUNCTION;
+
+impl SparkFunction for TODO_FUNCTION {
     fn to_expr(
         &self,
         _args: &[Expression],
diff --git a/src/daft-connect/src/functions/core.rs b/src/daft-connect/src/functions/core.rs
index 6140f54cea..6d4d38d614 100644
--- a/src/daft-connect/src/functions/core.rs
+++ b/src/daft-connect/src/functions/core.rs
@@ -3,7 +3,7 @@ use daft_functions::{coalesce::Coalesce, float::IsNan};
 use daft_sql::sql_expr;
 use spark_connect::Expression;
 
-use super::{FunctionModule, SparkFunction, Todo, UnaryFunction};
+use super::{FunctionModule, SparkFunction, UnaryFunction, TODO_FUNCTION};
 use crate::{
     error::{ConnectError, ConnectResult},
     invalid_argument_err,
@@ -32,27 +32,26 @@ impl FunctionModule for CoreFunctions {
         parent.add_fn("^", BinaryOpFunction(Operator::Xor));
         parent.add_fn("<<", BinaryOpFunction(Operator::ShiftLeft));
         parent.add_fn(">>", BinaryOpFunction(Operator::ShiftRight));
-
         // Normal Functions
         // https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#normal-functions
         parent.add_fn("coalesce", Coalesce {});
-        parent.add_fn("input_file_name", Todo);
+        parent.add_fn("input_file_name", TODO_FUNCTION);
         parent.add_fn("isnan", IsNan {});
         parent.add_fn("isnull", UnaryFunction(|arg| arg.is_null()));
-        parent.add_fn("monotically_increasing_id", Todo);
-        parent.add_fn("named_struct", Todo);
-        parent.add_fn("nanvl", Todo);
-        parent.add_fn("rand", Todo);
-        parent.add_fn("randn", Todo);
-        parent.add_fn("spark_partition_id", Todo);
-        parent.add_fn("when", Todo);
-        parent.add_fn("bitwise_not", Todo);
-        parent.add_fn("bitwiseNOT", Todo);
+        parent.add_fn("monotonically_increasing_id", TODO_FUNCTION);
+        parent.add_fn("named_struct", TODO_FUNCTION);
+        parent.add_fn("nanvl", TODO_FUNCTION);
+        parent.add_fn("rand", TODO_FUNCTION);
+        parent.add_fn("randn", TODO_FUNCTION);
+        parent.add_fn("spark_partition_id", TODO_FUNCTION);
+        parent.add_fn("when", TODO_FUNCTION);
+        parent.add_fn("bitwise_not", TODO_FUNCTION);
+        parent.add_fn("bitwiseNOT", TODO_FUNCTION);
         parent.add_fn("expr", SqlExpr);
-        parent.add_fn("greatest", Todo);
-        parent.add_fn("least", Todo);
+        parent.add_fn("greatest", TODO_FUNCTION);
+        parent.add_fn("least", TODO_FUNCTION);
         // parent.add_fn("isnan", UnaryFunction(|arg| arg.is_nan()));
diff --git a/src/daft-connect/src/functions/datetime.rs b/src/daft-connect/src/functions/datetime.rs
new file mode 100644
index 0000000000..a3c73a83c3
--- /dev/null
+++ b/src/daft-connect/src/functions/datetime.rs
@@ -0,0 +1,78 @@
+use daft_core::datatypes::TimeUnit;
+use daft_functions::temporal::{Day, DayOfWeek, Hour, Minute, Month, Second, Year};
+use daft_schema::dtype::DataType;
+
+use super::{FunctionModule, UnaryFunction, TODO_FUNCTION};
+
+/// https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#datetime-functions
+pub struct DatetimeFunctions;
+
+impl FunctionModule for DatetimeFunctions {
+    fn register(parent: &mut super::SparkFunctions) {
+        parent.add_fn("add_months", TODO_FUNCTION);
+        parent.add_fn("convert_timezone", TODO_FUNCTION);
+        parent.add_fn("curdate", TODO_FUNCTION);
+        parent.add_fn("current_date", TODO_FUNCTION);
+        parent.add_fn("current_timestamp", TODO_FUNCTION);
+        parent.add_fn("current_timezone", TODO_FUNCTION);
+        parent.add_fn("date_add", TODO_FUNCTION);
+        parent.add_fn("date_diff", TODO_FUNCTION);
+        parent.add_fn("date_format", TODO_FUNCTION);
+        parent.add_fn("date_from_unix_date", TODO_FUNCTION);
+        parent.add_fn("date_part", TODO_FUNCTION);
+        parent.add_fn("date_sub", TODO_FUNCTION);
+        parent.add_fn("date_trunc", TODO_FUNCTION);
+        parent.add_fn("dateadd", TODO_FUNCTION);
+        parent.add_fn("datediff", TODO_FUNCTION);
+        parent.add_fn("datepart", TODO_FUNCTION);
+        parent.add_fn("day", Day);
+        parent.add_fn("dayofmonth", Day);
+        parent.add_fn("dayofweek", DayOfWeek);
+        parent.add_fn("dayofyear", TODO_FUNCTION);
+        parent.add_fn("extract", TODO_FUNCTION);
+        parent.add_fn("from_unixtime", TODO_FUNCTION);
+        parent.add_fn("from_utc_timestamp", TODO_FUNCTION);
+        parent.add_fn("hour", Hour);
+        parent.add_fn("last_day", TODO_FUNCTION);
+        parent.add_fn("localtimestamp", TODO_FUNCTION);
+        parent.add_fn("make_date", TODO_FUNCTION);
+        parent.add_fn("make_dt_interval", TODO_FUNCTION);
+        parent.add_fn("make_interval", TODO_FUNCTION);
+        parent.add_fn("make_timestamp", TODO_FUNCTION);
+        parent.add_fn("make_timestamp_ltz", TODO_FUNCTION);
+        parent.add_fn("make_timestamp_ntz", TODO_FUNCTION);
+        parent.add_fn("make_ym_interval", TODO_FUNCTION);
+        parent.add_fn("minute", Minute);
+        parent.add_fn("month", Month);
+        parent.add_fn("months_between", TODO_FUNCTION);
+        parent.add_fn("next_day", TODO_FUNCTION);
+        parent.add_fn("now", TODO_FUNCTION);
+        parent.add_fn("quarter", TODO_FUNCTION);
+        parent.add_fn("second", Second);
+        parent.add_fn("session_window", TODO_FUNCTION);
+        parent.add_fn("timestamp_micros", TODO_FUNCTION);
+        parent.add_fn("timestamp_millis", TODO_FUNCTION);
+        parent.add_fn("timestamp_seconds", TODO_FUNCTION);
+        parent.add_fn("to_date", UnaryFunction(|arg| arg.cast(&DataType::Date)));
+        parent.add_fn(
+            "to_timestamp",
+            UnaryFunction(|arg| arg.cast(&DataType::Timestamp(TimeUnit::Milliseconds, None))),
+        );
+        parent.add_fn("to_timestamp_ltz", TODO_FUNCTION);
+        parent.add_fn("to_timestamp_ntz", TODO_FUNCTION);
+        parent.add_fn("to_unix_timestamp", TODO_FUNCTION);
+        parent.add_fn("to_utc_timestamp", TODO_FUNCTION);
+        parent.add_fn("trunc", TODO_FUNCTION);
+        parent.add_fn("try_to_timestamp", TODO_FUNCTION);
+        parent.add_fn("unix_date", TODO_FUNCTION);
+        parent.add_fn("unix_micros", TODO_FUNCTION);
+        parent.add_fn("unix_millis", TODO_FUNCTION);
+        parent.add_fn("unix_seconds", TODO_FUNCTION);
+        parent.add_fn("unix_timestamp", TODO_FUNCTION);
+        parent.add_fn("weekday", TODO_FUNCTION);
+        parent.add_fn("weekofyear", TODO_FUNCTION);
+        parent.add_fn("window", TODO_FUNCTION);
+        parent.add_fn("window_time", TODO_FUNCTION);
+        parent.add_fn("year", Year);
+    }
+}
diff --git a/src/daft-connect/src/functions/math.rs b/src/daft-connect/src/functions/math.rs
index 503391def0..ab934a8cb2 100644
--- a/src/daft-connect/src/functions/math.rs
+++ b/src/daft-connect/src/functions/math.rs
@@ -15,7 +15,7 @@ use daft_functions::numeric::{
 };
 use spark_connect::Expression;
 
-use super::{FunctionModule, SparkFunction, Todo};
+use super::{FunctionModule, SparkFunction, TODO_FUNCTION};
 use crate::{
     error::{ConnectError, ConnectResult},
     invalid_argument_err,
@@ -36,60 +36,60 @@ impl FunctionModule for MathFunctions {
         parent.add_fn("atan", ArcTan);
         parent.add_fn("atanh", ArcTanh);
         parent.add_fn("atan2", Atan2 {});
-        parent.add_fn("bin", Todo);
+        parent.add_fn("bin", TODO_FUNCTION);
         parent.add_fn("cbrt", Cbrt {});
         parent.add_fn("ceil", Ceil {});
         parent.add_fn("ceiling", Ceil {});
-        parent.add_fn("conv", Todo);
+        parent.add_fn("conv", TODO_FUNCTION);
         parent.add_fn("cos", Cos {});
-        parent.add_fn("cosh", Todo);
+        parent.add_fn("cosh", TODO_FUNCTION);
         parent.add_fn("cot", Cot {});
-        parent.add_fn("csc", Todo);
-        parent.add_fn("e", Todo);
+        parent.add_fn("csc", TODO_FUNCTION);
+        parent.add_fn("e", TODO_FUNCTION);
         parent.add_fn("exp", Exp {});
-        parent.add_fn("expm1", Todo);
-        parent.add_fn("factorial", Todo);
+        parent.add_fn("expm1", TODO_FUNCTION);
+        parent.add_fn("factorial", TODO_FUNCTION);
         parent.add_fn("floor", Floor {});
-        parent.add_fn("hex", Todo);
-        parent.add_fn("unhex", Todo);
-        parent.add_fn("hypot", Todo);
+        parent.add_fn("hex", TODO_FUNCTION);
+        parent.add_fn("unhex", TODO_FUNCTION);
+        parent.add_fn("hypot", TODO_FUNCTION);
         parent.add_fn("ln", Ln {});
         parent.add_fn("log", LogFunction);
         parent.add_fn("log10", Log10 {});
-        parent.add_fn("log1p", Todo);
+        parent.add_fn("log1p", TODO_FUNCTION);
         parent.add_fn("log2", Log2 {});
-        parent.add_fn("negate", Todo);
-        parent.add_fn("negative", Todo);
-        parent.add_fn("pi", Todo);
-        parent.add_fn("pmod", Todo);
-        parent.add_fn("positive", Todo);
-        parent.add_fn("pow", Todo);
-        parent.add_fn("power", Todo);
-        parent.add_fn("rint", Todo);
+        parent.add_fn("negate", TODO_FUNCTION);
+        parent.add_fn("negative", TODO_FUNCTION);
+        parent.add_fn("pi", TODO_FUNCTION);
+        parent.add_fn("pmod", TODO_FUNCTION);
+        parent.add_fn("positive", TODO_FUNCTION);
+        parent.add_fn("pow", TODO_FUNCTION);
+        parent.add_fn("power", TODO_FUNCTION);
+        parent.add_fn("rint", TODO_FUNCTION);
         parent.add_fn("round", RoundFunction);
-        parent.add_fn("bround", Todo);
-        parent.add_fn("sec", Todo);
-        parent.add_fn("shiftleft", Todo);
-        parent.add_fn("shiftright", Todo);
-        parent.add_fn("sign", Todo);
-        parent.add_fn("signum", Todo);
+        parent.add_fn("bround", TODO_FUNCTION);
+        parent.add_fn("sec", TODO_FUNCTION);
+        parent.add_fn("shiftleft", TODO_FUNCTION);
+        parent.add_fn("shiftright", TODO_FUNCTION);
+        parent.add_fn("sign", TODO_FUNCTION);
+        parent.add_fn("signum", TODO_FUNCTION);
         parent.add_fn("sin", Sin {});
-        parent.add_fn("sinh", Todo);
+        parent.add_fn("sinh", TODO_FUNCTION);
         parent.add_fn("tan", Tan {});
-        parent.add_fn("tanh", Todo);
-        parent.add_fn("toDegrees", Todo);
-        parent.add_fn("try_add", Todo);
-        parent.add_fn("try_avg", Todo);
-        parent.add_fn("try_divide", Todo);
-        parent.add_fn("try_multiply", Todo);
-        parent.add_fn("try_subtract", Todo);
-        parent.add_fn("try_sum", Todo);
-        parent.add_fn("try_to_binary", Todo);
-        parent.add_fn("try_to_number", Todo);
+        parent.add_fn("tanh", TODO_FUNCTION);
+        parent.add_fn("toDegrees", TODO_FUNCTION);
+        parent.add_fn("try_add", TODO_FUNCTION);
+        parent.add_fn("try_avg", TODO_FUNCTION);
+        parent.add_fn("try_divide", TODO_FUNCTION);
+        parent.add_fn("try_multiply", TODO_FUNCTION);
+        parent.add_fn("try_subtract", TODO_FUNCTION);
+        parent.add_fn("try_sum", TODO_FUNCTION);
+        parent.add_fn("try_to_binary", TODO_FUNCTION);
+        parent.add_fn("try_to_number", TODO_FUNCTION);
         parent.add_fn("degrees", Degrees {});
-        parent.add_fn("toRadians", Todo);
+        parent.add_fn("toRadians", TODO_FUNCTION);
         parent.add_fn("radians", Radians {});
-        parent.add_fn("width_bucket", Todo);
+        parent.add_fn("width_bucket", TODO_FUNCTION);
         //
     }
 }
diff --git a/src/daft-connect/src/functions/partition_transform.rs b/src/daft-connect/src/functions/partition_transform.rs
index d78e421794..e045a40880 100644
--- a/src/daft-connect/src/functions/partition_transform.rs
+++ b/src/daft-connect/src/functions/partition_transform.rs
@@ -8,6 +8,7 @@ use crate::{
     spark_analyzer::SparkAnalyzer,
 };
 
+// https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#partition-transformation-functions
 pub struct PartitionTransformFunctions;
 
 impl FunctionModule for PartitionTransformFunctions {
diff --git a/src/daft-connect/src/functions/string.rs b/src/daft-connect/src/functions/string.rs
index 064b70fb96..010ff485a7 100644
--- a/src/daft-connect/src/functions/string.rs
+++ b/src/daft-connect/src/functions/string.rs
@@ -6,7 +6,7 @@ use daft_functions::utf8::{
 };
 use spark_connect::Expression;
 
-use super::{FunctionModule, SparkFunction, Todo};
+use super::{FunctionModule, SparkFunction, TODO_FUNCTION};
 use crate::{error::ConnectResult, invalid_argument_err, spark_analyzer::SparkAnalyzer};
 
 // see https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html#string-functions
@@ -14,73 +14,73 @@ pub struct StringFunctions;
 
 impl FunctionModule for StringFunctions {
     fn register(parent: &mut super::SparkFunctions) {
-        parent.add_fn("ascii", Todo);
-        parent.add_fn("base64", Todo);
-        parent.add_fn("bit_length", Todo);
-        parent.add_fn("btrim", Todo);
-        parent.add_fn("char", Todo);
+        parent.add_fn("ascii", TODO_FUNCTION);
+        parent.add_fn("base64", TODO_FUNCTION);
+        parent.add_fn("bit_length", TODO_FUNCTION);
+        parent.add_fn("btrim", TODO_FUNCTION);
+        parent.add_fn("char", TODO_FUNCTION);
         parent.add_fn("character_length", Utf8Length {});
         parent.add_fn("char_length", Utf8Length {});
-        parent.add_fn("concat_ws", Todo);
+        parent.add_fn("concat_ws", TODO_FUNCTION);
         parent.add_fn("contains", daft_functions::utf8::Utf8Contains {});
-        parent.add_fn("decode", Todo);
-        parent.add_fn("elt", Todo);
-        parent.add_fn("encode", Utf8Endswith {});
-        parent.add_fn("endswith", Todo);
-        parent.add_fn("find_in_set", Todo);
-        parent.add_fn("format_number", Todo);
-        parent.add_fn("format_string", Todo);
+        parent.add_fn("decode", TODO_FUNCTION);
+        parent.add_fn("elt", TODO_FUNCTION);
+        parent.add_fn("encode", TODO_FUNCTION);
+        parent.add_fn("endswith", Utf8Endswith {});
+        parent.add_fn("find_in_set", TODO_FUNCTION);
+        parent.add_fn("format_number", TODO_FUNCTION);
+        parent.add_fn("format_string", TODO_FUNCTION);
         parent.add_fn("ilike", Utf8Ilike {});
-        parent.add_fn("initcap", Todo);
-        parent.add_fn("instr", Todo);
-        parent.add_fn("lcase", Todo);
+        parent.add_fn("initcap", TODO_FUNCTION);
+        parent.add_fn("instr", TODO_FUNCTION);
+        parent.add_fn("lcase", TODO_FUNCTION);
         parent.add_fn("length", Utf8LengthBytes {});
         parent.add_fn("like", Utf8Like {});
         parent.add_fn("lower", Utf8Lower {});
         parent.add_fn("left", Utf8Left {});
-        parent.add_fn("levenshtein", Todo);
-        parent.add_fn("locate", Todo);
+        parent.add_fn("levenshtein", TODO_FUNCTION);
+        parent.add_fn("locate", TODO_FUNCTION);
         parent.add_fn("lpad", Utf8Lpad {});
-        parent.add_fn("ltrim", Todo);
-        parent.add_fn("mask", Todo);
-        parent.add_fn("octet_length", Todo);
-        parent.add_fn("parse_url", Todo);
-        parent.add_fn("position", Todo);
-        parent.add_fn("printf", Todo);
-        parent.add_fn("rlike", Todo);
-        parent.add_fn("regexp", Todo);
-        parent.add_fn("regexp_like", Todo);
-        parent.add_fn("regexp_count", Todo);
+        parent.add_fn("ltrim", TODO_FUNCTION);
+        parent.add_fn("mask", TODO_FUNCTION);
+        parent.add_fn("octet_length", TODO_FUNCTION);
+        parent.add_fn("parse_url", TODO_FUNCTION);
+        parent.add_fn("position", TODO_FUNCTION);
+        parent.add_fn("printf", TODO_FUNCTION);
+        parent.add_fn("rlike", TODO_FUNCTION);
+        parent.add_fn("regexp", TODO_FUNCTION);
+        parent.add_fn("regexp_like", TODO_FUNCTION);
+        parent.add_fn("regexp_count", TODO_FUNCTION);
         parent.add_fn("regexp_extract", RegexpExtract);
         parent.add_fn("regexp_extract_all", RegexpExtractAll);
         parent.add_fn("regexp_replace", Utf8Replace { regex: true });
-        parent.add_fn("regexp_substr", Todo);
-        parent.add_fn("regexp_instr", Todo);
+        parent.add_fn("regexp_substr", TODO_FUNCTION);
+        parent.add_fn("regexp_instr", TODO_FUNCTION);
         parent.add_fn("replace", Utf8Replace { regex: false });
         parent.add_fn("right", Utf8Right {});
-        parent.add_fn("ucase", Todo);
-        parent.add_fn("unbase64", Todo);
+        parent.add_fn("ucase", TODO_FUNCTION);
+        parent.add_fn("unbase64", TODO_FUNCTION);
         parent.add_fn("rpad", Utf8Rpad {});
-        parent.add_fn("repeat", Todo);
-        parent.add_fn("rtrim", Todo);
-        parent.add_fn("soundex", Todo);
+        parent.add_fn("repeat", TODO_FUNCTION);
+        parent.add_fn("rtrim", TODO_FUNCTION);
+        parent.add_fn("soundex", TODO_FUNCTION);
         parent.add_fn("split", Utf8Split { regex: false });
-        parent.add_fn("split_part", Todo);
+        parent.add_fn("split_part", TODO_FUNCTION);
         parent.add_fn("startswith", Utf8Startswith {});
         parent.add_fn("substr", Utf8Substr {});
         parent.add_fn("substring", Utf8Substr {});
-        parent.add_fn("substring_index", Todo);
-        parent.add_fn("overlay", Todo);
-        parent.add_fn("sentences", Todo);
-        parent.add_fn("to_binary", Todo);
-        parent.add_fn("to_char", Todo);
-        parent.add_fn("to_number", Todo);
-        parent.add_fn("to_varchar", Todo);
-        parent.add_fn("translate", Todo);
-        parent.add_fn("trim", Todo);
+        parent.add_fn("substring_index", TODO_FUNCTION);
+        parent.add_fn("overlay", TODO_FUNCTION);
+        parent.add_fn("sentences", TODO_FUNCTION);
+        parent.add_fn("to_binary", TODO_FUNCTION);
+        parent.add_fn("to_char", TODO_FUNCTION);
+        parent.add_fn("to_number", TODO_FUNCTION);
+        parent.add_fn("to_varchar", TODO_FUNCTION);
+        parent.add_fn("translate", TODO_FUNCTION);
+        parent.add_fn("trim", TODO_FUNCTION);
         parent.add_fn("upper", Utf8Upper {});
-        parent.add_fn("url_decode", Todo);
-        parent.add_fn("url_encode", Todo);
+        parent.add_fn("url_decode", TODO_FUNCTION);
+        parent.add_fn("url_encode", TODO_FUNCTION);
     }
 }
 