diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..9fa42fb --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,90 @@ +name: CI + +on: + push: + branches: + - main + tags: + - '**' + pull_request: {} + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy + + - id: cache-rust + uses: Swatinem/rust-cache@v2 + + - uses: pre-commit/action@v3.0.0 + with: + extra_args: --all-files --verbose + env: + PRE_COMMIT_COLOR: always + SKIP: test + + test: + name: test rust-${{ matrix.rust-version }} + strategy: + fail-fast: false + matrix: + rust-version: [stable, nightly] + + runs-on: ubuntu-latest + + env: + RUST_VERSION: ${{ matrix.rust-version }} + + steps: + - uses: actions/checkout@v3 + + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust-version }} + + - id: cache-rust + uses: Swatinem/rust-cache@v2 + + - run: cargo test --all-features +# - uses: taiki-e/install-action@cargo-llvm-cov +# +# - run: cargo llvm-cov --all-features --codecov --output-path codecov.json +# +# - uses: codecov/codecov-action@v3 +# with: +# files: codecov.json +# env_vars: RUST_VERSION + + # https://github.com/marketplace/actions/alls-green#why used for branch protection checks + check: + if: always() + needs: [test, lint] + runs-on: ubuntu-latest + steps: + - name: Decide whether the needed jobs succeeded or failed + uses: re-actors/alls-green@release/v1 + with: + jobs: ${{ toJSON(needs) }} + + release: + needs: [check] + if: "success() && startsWith(github.ref, 'refs/tags/')" + runs-on: ubuntu-latest + environment: release + + steps: + - uses: actions/checkout@v2 + + - name: install rust stable + uses: dtolnay/rust-toolchain@stable + + - uses: Swatinem/rust-cache@v2 + + - run: cargo publish + env: + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 
100644 index 0000000..96ef6c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..ec099d4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,32 @@ +fail_fast: true + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-yaml + - id: check-toml + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-added-large-files + +- repo: local + hooks: + - id: format-check + name: Format Check + entry: cargo fmt + types: [rust] + language: system + pass_filenames: false + - id: clippy + name: Clippy + entry: cargo clippy + types: [rust] + language: system + pass_filenames: false + - id: test + name: Test + entry: cargo test + types: [rust] + language: system + pass_filenames: false diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..7530651 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1 @@ +max_width = 120 diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..8438382 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "datafusion-functions-json" +version = "0.1.0" +edition = "2021" + +[dependencies] +arrow-schema = "51.0.0" +datafusion-common = "37.0.0" +datafusion-expr = "37.0.0" +jiter = { git = "https://github.com/pydantic/jiter.git", branch = "next_skip" } +paste = "1.0.14" +log = "0.4.21" +datafusion-execution = "37.0.0" + +[dev-dependencies] +arrow = "51.0.0" +datafusion = "37.0.0" +tokio = { version = "1.37.0", features = ["full"] } + +[lints.clippy] +dbg_macro = "warn" +print_stdout = "warn" + +# in general we lint against the pedantic group, but we will whitelist +# certain lints which we don't want to enforce (for now) +pedantic = { level = "warn", priority = -1 } +missing_errors_doc = "allow" diff --git a/README.md b/README.md new file mode 100644 index 0000000..d6cce16 --- /dev/null +++ b/README.md @@ -0,0 +1,18 @@ +# 
datafusion-functions-json + +methods to implement: + +* [x] `json_obj_contains(json: str, key: str) -> bool` - true if a JSON object has a specific key +* [ ] `json_obj_contains_all(json: str, keys: list[str]) -> bool` - true if a JSON object has all of a list of keys +* [ ] `json_obj_contains_any(json: str, keys: list[str]) -> bool` - true if a JSON object has any of a list of keys +* [ ] `json_obj_keys(json: str) -> list[str]` - get the keys of a JSON object +* [ ] `json_obj_values(json: str) -> list[Any]` - get the values of a JSON object +* [ ] `json_is_obj(json: str) -> bool` - true if the JSON is an object +* [ ] `json_array_contains(json: str, key: Any) -> bool` - true if a JSON array has a specific value +* [ ] `json_array_items(json: str) -> list[Any]` - get the items of a JSON array +* [ ] `json_is_array(json: str) -> bool` - true if the JSON is an array +* [ ] `json_get(json: str, key: str | int) -> Any` - get the value of a key in a JSON object or array +* [ ] `json_get_path(json: str, key: list[str | int]) -> Any` - is this possible? +* [ ] `json_length(json: str) -> int` - get the length of a JSON object or array +* [ ] `json_valid(json: str) -> bool` - true if the JSON is valid +* [ ] `json_cast(json: str) -> Any` - cast the JSON to a native type??? 
diff --git a/src/json_obj_contains.rs b/src/json_obj_contains.rs new file mode 100644 index 0000000..aa9d619 --- /dev/null +++ b/src/json_obj_contains.rs @@ -0,0 +1,103 @@ +use crate::macros::make_udf_function; +use arrow_schema::DataType; +use arrow_schema::DataType::{LargeUtf8, Utf8}; +use datafusion_common::arrow::array::{as_string_array, ArrayRef, BooleanArray}; +use datafusion_common::{exec_err, plan_err, Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use jiter::Jiter; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + JsonObjContains, + json_obj_contains, + json_data key, // arg name + "Does the string exist as a top-level key within the JSON value?", // doc + json_obj_contains_udf // internal function name +); + +#[derive(Debug)] +pub(super) struct JsonObjContains { + signature: Signature, + aliases: Vec, +} + +impl JsonObjContains { + pub fn new() -> Self { + Self { + signature: Signature::uniform(2, vec![Utf8, LargeUtf8], Volatility::Immutable), + aliases: vec!["json_obj_contains".to_string(), "json_object_contains".to_string()], + } + } +} + +impl ScalarUDFImpl for JsonObjContains { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "json_obj_contains" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + Utf8 | LargeUtf8 => Ok(DataType::Boolean), + _ => { + plan_err!("The json_obj_contains function can only accept Utf8 or LargeUtf8.") + } + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let json_haystack = match &args[0] { + ColumnarValue::Array(array) => as_string_array(array), + ColumnarValue::Scalar(_) => { + return exec_err!("json_obj_contains first argument: unexpected argument type, expected string array") + } + }; + + let needle = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.clone(), + _ => return 
exec_err!("json_obj_contains second argument: unexpected argument type, expected string"), + }; + + let array = json_haystack + .iter() + .map(|opt_json| opt_json.map(|json| jiter_json_contains(json.as_bytes(), &needle))) + .collect::(); + + Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +fn jiter_json_contains(json_data: &[u8], expected_key: &str) -> bool { + let mut jiter = Jiter::new(json_data, false); + let Ok(Some(first_key)) = jiter.next_object() else { + return false; + }; + + if first_key == expected_key { + return true; + } + if jiter.next_skip().is_err() { + return false; + } + while let Ok(Some(key)) = jiter.next_key() { + if key == expected_key { + return true; + } + if jiter.next_skip().is_err() { + return false; + } + } + false +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..09be0f9 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,26 @@ +use datafusion_common::Result; +use datafusion_execution::FunctionRegistry; +use datafusion_expr::ScalarUDF; +use log::debug; +use std::sync::Arc; + +mod json_obj_contains; +mod macros; + +pub mod functions { + pub use crate::json_obj_contains::json_obj_contains; +} + +/// Register all JSON UDFs +pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { + let functions: Vec> = vec![json_obj_contains::json_obj_contains_udf()]; + functions.into_iter().try_for_each(|udf| { + let existing_udf = registry.register_udf(udf)?; + if let Some(existing_udf) = existing_udf { + debug!("Overwrite existing UDF: {}", existing_udf.name()); + } + Ok(()) as Result<()> + })?; + + Ok(()) +} diff --git a/src/macros.rs b/src/macros.rs new file mode 100644 index 0000000..240c0da --- /dev/null +++ b/src/macros.rs @@ -0,0 +1,98 @@ +#[allow(clippy::doc_markdown)] +/// Currently copied verbatim, can hopefully be replaced or simplified +/// 
https://github.com/apache/datafusion/blob/19356b26f515149f96f9b6296975a77ac7260149/datafusion/functions-array/src/macros.rs +/// +/// Creates external API functions for an array UDF. Specifically, creates +/// +/// 1. Single `ScalarUDF` instance +/// +/// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a +/// function named `$NAME` which returns that function named $NAME. +/// +/// This is used to ensure creating the list of `ScalarUDF` only happens once. +/// +/// # 2. `expr_fn` style function +/// +/// These are functions that create an `Expr` that invokes the UDF, used +/// primarily to programmatically create expressions. +/// +/// For example: +/// ```text +/// pub fn array_to_string(delimiter: Expr) -> Expr { +/// ... +/// } +/// ``` +/// # Arguments +/// * `UDF`: name of the [`ScalarUDFImpl`] +/// * `EXPR_FN`: name of the `expr_fn` function to be created +/// * `arg`: 0 or more named arguments for the function +/// * `DOC`: documentation string for the function +/// * `SCALAR_UDF_FUNC`: name of the function to create (just) the `ScalarUDF` +/// * `GNAME`: name for the single static instance of the `ScalarUDF` +/// +/// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl +macro_rules! make_udf_function { + ($UDF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr , $SCALAR_UDF_FN:ident) => { + paste::paste! { + // "fluent expr_fn" style function + #[doc = $DOC] + #[must_use] pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { + datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf( + $SCALAR_UDF_FN(), + vec![$($arg),*], + )) + } + + /// Singleton instance of [`$UDF`], ensures the UDF is only created once + /// named STATIC_$(UDF). 
For example `STATIC_ArrayToString` + #[allow(non_upper_case_globals)] + static [< STATIC_ $UDF >]: std::sync::OnceLock> = + std::sync::OnceLock::new(); + + /// ScalarFunction that returns a [`ScalarUDF`] for [`$UDF`] + /// + /// [`ScalarUDF`]: datafusion_expr::ScalarUDF + pub fn $SCALAR_UDF_FN() -> std::sync::Arc { + [< STATIC_ $UDF >] + .get_or_init(|| { + std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( + <$UDF>::new(), + )) + }) + .clone() + } + } + }; + ($UDF:ty, $EXPR_FN:ident, $DOC:expr , $SCALAR_UDF_FN:ident) => { + paste::paste! { + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN(arg: Vec) -> datafusion_expr::Expr { + datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf( + $SCALAR_UDF_FN(), + arg, + )) + } + + /// Singleton instance of [`$UDF`], ensures the UDF is only created once + /// named STATIC_$(UDF). For example `STATIC_ArrayToString` + #[allow(non_upper_case_globals)] + static [< STATIC_ $UDF >]: std::sync::OnceLock> = + std::sync::OnceLock::new(); + /// ScalarFunction that returns a [`ScalarUDF`] for [`$UDF`] + /// + /// [`ScalarUDF`]: datafusion_expr::ScalarUDF + pub fn $SCALAR_UDF_FN() -> std::sync::Arc { + [< STATIC_ $UDF >] + .get_or_init(|| { + std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( + <$UDF>::new(), + )) + }) + .clone() + } + } + }; +} + +pub(crate) use make_udf_function; diff --git a/tests/test_json_obj_contains.rs b/tests/test_json_obj_contains.rs new file mode 100644 index 0000000..20c704c --- /dev/null +++ b/tests/test_json_obj_contains.rs @@ -0,0 +1,69 @@ +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::{array::StringArray, record_batch::RecordBatch}; +use std::sync::Arc; + +use datafusion::assert_batches_eq; +use datafusion::error::Result; +use datafusion::execution::context::SessionContext; +use datafusion_functions_json::register_all; + +async fn create_test_table() -> Result { + let schema = Arc::new(Schema::new(vec![ + 
Field::new("name", DataType::Utf8, false), + Field::new("json_data", DataType::Utf8, false), + ])); + + let data = [ + ("object_foo", r#" {"foo": 123} "#), + ("object_bar", r#" {"bar": true} "#), + ("list_foo", r#" ["foo"] "#), + ("invalid_json", "is not json"), + ]; + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(StringArray::from( + data.iter().map(|(name, _)| *name).collect::>(), + )), + Arc::new(StringArray::from( + data.iter().map(|(_, json)| *json).collect::>(), + )), + ], + )?; + + let mut ctx = SessionContext::new(); + register_all(&mut ctx)?; + ctx.register_batch("test", batch)?; + Ok(ctx) +} + +/// Executes an expression on the test dataframe as a select. +/// Compares formatted output of a record batch with an expected +/// vector of strings, using the `assert_batch_eq`! macro +macro_rules! query { + ($sql:expr, $expected: expr) => { + let ctx = create_test_table().await?; + let df = ctx.sql($sql).await?; + let batches = df.collect().await?; + + assert_batches_eq!($expected, &batches); + }; +} + +#[tokio::test] +async fn test_json_obj_contains() -> Result<()> { + let expected = [ + "+--------------+-----------------------------------------------+", + "| name | json_obj_contains(test.json_data,Utf8(\"foo\")) |", + "+--------------+-----------------------------------------------+", + "| object_foo | true |", + "| object_bar | false |", + "| list_foo | false |", + "| invalid_json | false |", + "+--------------+-----------------------------------------------+", + ]; + + query!("select name, json_obj_contains(json_data, 'foo') from test", expected); + Ok(()) +}