diff --git a/Cargo.lock b/Cargo.lock index 9decb8d10bc47..575739747ad89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -733,6 +733,18 @@ dependencies = [ "regex-syntax 0.8.2", ] +[[package]] +name = "arrow-udf-js" +version = "0.1.1" +source = "git+https://github.com/risingwavelabs/arrow-udf?rev=6c32f71#6c32f710b5948147f8214797fc334a4a3cadef0d" +dependencies = [ + "anyhow", + "arrow-array 50.0.0", + "arrow-buffer 50.0.0", + "arrow-schema 50.0.0", + "rquickjs", +] + [[package]] name = "assert-json-diff" version = "2.0.2" @@ -4353,6 +4365,7 @@ dependencies = [ "arrow-flight", "arrow-ipc 50.0.0", "arrow-schema 50.0.0", + "arrow-udf-js", "async-backtrace", "async-channel 1.9.0", "async-stream", @@ -8241,7 +8254,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2caa5afb8bf9f3a2652760ce7d4f62d21c4d5a423e68466fca30df82f2330164" dependencies = [ "cfg-if", - "windows-targets 0.52.4", + "windows-targets 0.48.5", ] [[package]] @@ -11289,6 +11302,33 @@ dependencies = [ "serde", ] +[[package]] +name = "rquickjs" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad7f63201fa6f2ff8173e4758ea552549d687d8f63003361a8b5c50f7c446ded" +dependencies = [ + "rquickjs-core", +] + +[[package]] +name = "rquickjs-core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cad00eeddc0f88af54ee202c8385fb214fe0423897c056a7df8369fb482e3695" +dependencies = [ + "rquickjs-sys", +] + +[[package]] +name = "rquickjs-sys" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "120dbbc3296de9b96de8890091635d46f3506cd38b4e8f21800c386c035d64fa" +dependencies = [ + "cc", +] + [[package]] name = "rsa" version = "0.7.2" diff --git a/Cargo.toml b/Cargo.toml index d1221316f2bfb..f9d870424d8ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,6 +132,7 @@ async-trait = { version = "0.1.77", package = "async-trait-fn" } bincode = { version = "2.0.0-rc.3", 
features = ["serde", "std", "alloc"] } borsh = { version = "1.2.1", features = ["derive"] } bytes = "1.5.0" +hashbrown = { version = "0.14.3", default-features = false } byteorder = "1.4.3" chrono = { version = "0.4.31", features = ["serde"] } chrono-tz = { version = "0.8", features = ["serde"] } diff --git a/src/meta/app/src/principal/mod.rs b/src/meta/app/src/principal/mod.rs index d0bc54cc8fb50..1e3344c2c99d8 100644 --- a/src/meta/app/src/principal/mod.rs +++ b/src/meta/app/src/principal/mod.rs @@ -48,6 +48,7 @@ pub use user_auth::PasswordHashMethod; pub use user_defined_file_format::UserDefinedFileFormat; pub use user_defined_function::LambdaUDF; pub use user_defined_function::UDFDefinition; +pub use user_defined_function::UDFScript; pub use user_defined_function::UDFServer; pub use user_defined_function::UdfName; pub use user_defined_function::UserDefinedFunction; diff --git a/src/meta/app/src/principal/user_defined_function.rs b/src/meta/app/src/principal/user_defined_function.rs index 8da8ecf2d95ae..3d613f6ec6b1f 100644 --- a/src/meta/app/src/principal/user_defined_function.rs +++ b/src/meta/app/src/principal/user_defined_function.rs @@ -54,10 +54,21 @@ pub struct UDFServer { pub return_type: DataType, } +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct UDFScript { + pub code: String, + pub handler: String, + pub language: String, + pub arg_types: Vec, + pub return_type: DataType, + pub runtime_version: String, +} + #[derive(Clone, Debug, Eq, PartialEq)] pub enum UDFDefinition { LambdaUDF(LambdaUDF), UDFServer(UDFServer), + UDFScript(UDFScript), } #[derive(Clone, Debug, Eq, PartialEq)] @@ -108,6 +119,31 @@ impl UserDefinedFunction { created_on: Utc::now(), } } + + pub fn create_udf_script( + name: &str, + code: &str, + handler: &str, + language: &str, + arg_types: Vec, + return_type: DataType, + runtime_version: &str, + description: &str, + ) -> Self { + Self { + name: name.to_string(), + description: description.to_string(), + definition: 
UDFDefinition::UDFScript(UDFScript { + code: code.to_string(), + handler: handler.to_string(), + language: language.to_string(), + arg_types, + return_type, + runtime_version: runtime_version.to_string(), + }), + created_on: Utc::now(), + } + } } impl Display for UDFDefinition { @@ -144,6 +180,26 @@ impl Display for UDFDefinition { ") RETURNS {return_type} LANGUAGE {language} HANDLER = {handler} ADDRESS = {address}" )?; } + + UDFDefinition::UDFScript(UDFScript { + code, + arg_types, + return_type, + handler, + language, + runtime_version, + }) => { + for (i, item) in arg_types.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{item}")?; + } + write!( + f, + ") RETURNS {return_type} LANGUAGE {language} RUNTIME_VERSION = {runtime_version} HANDLER = {handler} AS $${code}$$" + )?; + } } Ok(()) } diff --git a/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs index d0383da7b1a6e..8b40c1379832d 100644 --- a/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs +++ b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs @@ -106,6 +106,64 @@ impl FromToProto for mt::UDFServer { } } +impl FromToProto for mt::UDFScript { + type PB = pb::UdfScript; + fn get_pb_ver(p: &Self::PB) -> u64 { + p.ver + } + fn from_pb(p: pb::UdfScript) -> Result { + reader_check_msg(p.ver, p.min_reader_ver)?; + + let mut arg_types = Vec::with_capacity(p.arg_types.len()); + for arg_type in p.arg_types { + let arg_type = DataType::from(&TableDataType::from_pb(arg_type)?); + arg_types.push(arg_type); + } + let return_type = DataType::from(&TableDataType::from_pb(p.return_type.ok_or_else( + || Incompatible { + reason: "UDFScript.return_type can not be None".to_string(), + }, + )?)?); + + Ok(mt::UDFScript { + code: p.code, + arg_types, + return_type, + handler: p.handler, + language: p.language, + runtime_version: p.runtime_version, + }) + } + + fn to_pb(&self) -> Result { + let mut arg_types = 
Vec::with_capacity(self.arg_types.len()); + for arg_type in self.arg_types.iter() { + let arg_type = infer_schema_type(arg_type) + .map_err(|e| Incompatible { + reason: format!("Convert DataType to TableDataType failed: {}", e.message()), + })? + .to_pb()?; + arg_types.push(arg_type); + } + let return_type = infer_schema_type(&self.return_type) + .map_err(|e| Incompatible { + reason: format!("Convert DataType to TableDataType failed: {}", e.message()), + })? + .to_pb()?; + + Ok(pb::UdfScript { + ver: VER, + min_reader_ver: MIN_READER_VER, + code: self.code.clone(), + handler: self.handler.clone(), + language: self.language.clone(), + arg_types, + return_type: Some(return_type), + runtime_version: self.runtime_version.clone(), + }) + } +} + impl FromToProto for mt::UserDefinedFunction { type PB = pb::UserDefinedFunction; fn get_pb_ver(p: &Self::PB) -> u64 { @@ -120,6 +178,9 @@ impl FromToProto for mt::UserDefinedFunction { Some(pb::user_defined_function::Definition::UdfServer(udf_server)) => { mt::UDFDefinition::UDFServer(mt::UDFServer::from_pb(udf_server)?) } + Some(pb::user_defined_function::Definition::UdfScript(udf_script)) => { + mt::UDFDefinition::UDFScript(mt::UDFScript::from_pb(udf_script)?) + } None => { return Err(Incompatible { reason: "UserDefinedFunction.definition cannot be None".to_string(), @@ -146,6 +207,9 @@ impl FromToProto for mt::UserDefinedFunction { mt::UDFDefinition::UDFServer(udf_server) => { pb::user_defined_function::Definition::UdfServer(udf_server.to_pb()?) } + mt::UDFDefinition::UDFScript(udf_script) => { + pb::user_defined_function::Definition::UdfScript(udf_script.to_pb()?) 
+ } }; Ok(pb::UserDefinedFunction { diff --git a/src/meta/proto-conv/src/util.rs b/src/meta/proto-conv/src/util.rs index 3ecda3beaffe0..24fe5a21492d1 100644 --- a/src/meta/proto-conv/src/util.rs +++ b/src/meta/proto-conv/src/util.rs @@ -109,7 +109,8 @@ const META_CHANGE_LOG: &[(u64, &str)] = &[ (77, "2024-01-22: Remove: allow_anonymous in S3 Config", ), (78, "2024-01-29: Refactor: GrantEntry::UserPrivilegeType and ShareGrantEntry::ShareGrantObjectPrivilege use from_bits_truncate deserialize", ), (79, "2024-01-31: Add: udf.proto/UserDefinedFunction add created_on field", ), - (80, "2024-02-01: Add: Add: datatype.proto/DataType Geometry type") + (80, "2024-02-01: Add: Add: datatype.proto/DataType Geometry type"), + (81, "2024-03-04: Add: udf.udf_script") // Dear developer: // If you're gonna add a new metadata version, you'll have to add a test for it. // You could just copy an existing test file(e.g., `../tests/it/v024_table_meta.rs`) diff --git a/src/meta/proto-conv/tests/it/main.rs b/src/meta/proto-conv/tests/it/main.rs index 05dfbb3b40f0e..cc608bae8223e 100644 --- a/src/meta/proto-conv/tests/it/main.rs +++ b/src/meta/proto-conv/tests/it/main.rs @@ -83,4 +83,5 @@ mod v076_role_ownership_info; mod v077_s3_remove_allow_anonymous; mod v078_grantentry; mod v079_udf_created_on; mod v080_geometry_datatype; +mod v081_udf_script; diff --git a/src/meta/proto-conv/tests/it/v081_udf_script.rs b/src/meta/proto-conv/tests/it/v081_udf_script.rs new file mode 100644 index 0000000000000..b56cafc5d5cc2 --- /dev/null +++ b/src/meta/proto-conv/tests/it/v081_udf_script.rs @@ -0,0 +1,92 @@ +// Copyright 2023 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::DateTime; +use chrono::Utc; +use databend_common_expression::types::DataType; +use databend_common_expression::types::NumberDataType; +use databend_common_meta_app::principal::LambdaUDF; +use databend_common_meta_app::principal::UDFDefinition; +use databend_common_meta_app::principal::UDFServer; +use databend_common_meta_app::principal::UserDefinedFunction; +use minitrace::func_name; + +use crate::common; + +// These bytes are built when a new version in introduced, +// and are kept for backward compatibility test. +// +// ************************************************************* +// * These messages should never be updated, * +// * only be added when a new version is added, * +// * or be removed when an old version is no longer supported. 
* +// ************************************************************* +// +// The message bytes are built from the output of `test_pb_from_to()` +#[test] +fn test_decode_v81_udf_python() -> anyhow::Result<()> { + let bytes = vec![ + 10, 8, 112, 108, 117, 115, 95, 105, 110, 116, 18, 21, 84, 104, 105, 115, 32, 105, 115, 32, + 97, 32, 100, 101, 115, 99, 114, 105, 112, 116, 105, 111, 110, 34, 107, 10, 21, 104, 116, + 116, 112, 58, 47, 47, 108, 111, 99, 97, 108, 104, 111, 115, 116, 58, 56, 56, 56, 56, 18, + 11, 112, 108, 117, 115, 95, 105, 110, 116, 95, 112, 121, 26, 6, 112, 121, 116, 104, 111, + 110, 34, 17, 154, 2, 8, 58, 0, 160, 6, 81, 168, 6, 24, 160, 6, 81, 168, 6, 24, 34, 17, 154, + 2, 8, 58, 0, 160, 6, 81, 168, 6, 24, 160, 6, 81, 168, 6, 24, 42, 17, 154, 2, 8, 66, 0, 160, + 6, 81, 168, 6, 24, 160, 6, 81, 168, 6, 24, 160, 6, 81, 168, 6, 24, 42, 23, 50, 48, 50, 51, + 45, 49, 50, 45, 49, 53, 32, 48, 49, 58, 50, 54, 58, 48, 57, 32, 85, 84, 67, 160, 6, 81, + 168, 6, 24, + ]; + + let want = || UserDefinedFunction { + name: "plus_int".to_string(), + description: "This is a description".to_string(), + definition: UDFDefinition::UDFServer(UDFServer { + address: "http://localhost:8888".to_string(), + handler: "plus_int_py".to_string(), + language: "python".to_string(), + arg_types: vec![ + DataType::Number(NumberDataType::Int32), + DataType::Number(NumberDataType::Int32), + ], + return_type: DataType::Number(NumberDataType::Int64), + }), + created_on: DateTime::::from_timestamp(1702603569, 0).unwrap(), + }; + + common::test_pb_from_to(func_name!(), want())?; + common::test_load_old(func_name!(), bytes.as_slice(), 81, want()) +} + +#[test] +fn test_decode_v81_udf_sql() -> anyhow::Result<()> { + let bytes = vec![ + 10, 10, 105, 115, 110, 111, 116, 101, 109, 112, 116, 121, 18, 21, 84, 104, 105, 115, 32, + 105, 115, 32, 97, 32, 100, 101, 115, 99, 114, 105, 112, 116, 105, 111, 110, 26, 34, 10, 1, + 112, 18, 23, 40, 112, 41, 32, 45, 62, 32, 40, 78, 79, 84, 32, 105, 115, 95, 110, 
117, 108, + 108, 40, 112, 41, 41, 160, 6, 81, 168, 6, 24, 42, 23, 49, 57, 55, 53, 45, 48, 53, 45, 50, + 53, 32, 49, 54, 58, 51, 57, 58, 52, 52, 32, 85, 84, 67, 160, 6, 81, 168, 6, 24, + ]; + let want = || UserDefinedFunction { + name: "isnotempty".to_string(), + description: "This is a description".to_string(), + definition: UDFDefinition::LambdaUDF(LambdaUDF { + parameters: vec!["p".to_string()], + definition: "(p) -> (NOT is_null(p))".to_string(), + }), + created_on: DateTime::::from_timestamp(170267984, 0).unwrap(), + }; + + common::test_pb_from_to(func_name!(), want())?; + common::test_load_old(func_name!(), bytes.as_slice(), 81, want()) +} diff --git a/src/meta/protos/proto/udf.proto b/src/meta/protos/proto/udf.proto index 214d8c5b3daf0..3ed23b3c4fbbb 100644 --- a/src/meta/protos/proto/udf.proto +++ b/src/meta/protos/proto/udf.proto @@ -37,6 +37,19 @@ message UDFServer { DataType return_type = 5; } +message UDFScript { + uint64 ver = 100; + uint64 min_reader_ver = 101; + + string code = 1; + string handler = 2; + string language = 3; + repeated DataType arg_types = 4; + DataType return_type = 5; + string runtime_version = 6; +} + + message UserDefinedFunction { uint64 ver = 100; uint64 min_reader_ver = 101; @@ -46,6 +59,7 @@ message UserDefinedFunction { oneof definition { LambdaUDF lambda_udf = 3; UDFServer udf_server = 4; + UDFScript udf_script = 6; } // The time udf created. 
optional string created_on = 5; diff --git a/src/query/ast/src/ast/format/ast_format.rs b/src/query/ast/src/ast/format/ast_format.rs index 0bc54230946e5..6a6a13bc572bd 100644 --- a/src/query/ast/src/ast/format/ast_format.rs +++ b/src/query/ast/src/ast/format/ast_format.rs @@ -2189,6 +2189,47 @@ impl<'ast> Visitor<'ast> for AstFormatVisitor { AstFormatContext::new(format!("UdfServerAddress {address}")); children.push(FormatTreeNode::new(address_format_ctx)); } + UDFDefinition::UDFScript { + arg_types, + return_type, + code, + handler, + language, + runtime_version, + } => { + if !arg_types.is_empty() { + let mut arg_types_children = Vec::with_capacity(arg_types.len()); + for arg_type in arg_types.iter() { + let type_format_ctx = AstFormatContext::new(format!("DataType {arg_type}")); + arg_types_children.push(FormatTreeNode::new(type_format_ctx)); + } + let arg_format_ctx = AstFormatContext::with_children( + "UdfArgTypes".to_string(), + arg_types_children.len(), + ); + children.push(FormatTreeNode::with_children( + arg_format_ctx, + arg_types_children, + )); + } + + let return_type_format_ctx = + AstFormatContext::new(format!("UdfReturnType {return_type}")); + children.push(FormatTreeNode::new(return_type_format_ctx)); + + let handler_format_ctx = AstFormatContext::new(format!("UdfHandler {handler}")); + children.push(FormatTreeNode::new(handler_format_ctx)); + + let language_format_ctx = AstFormatContext::new(format!("UdfLanguage {language}")); + children.push(FormatTreeNode::new(language_format_ctx)); + + let code_format_ctx = AstFormatContext::new(format!("UdfCode {code}")); + children.push(FormatTreeNode::new(code_format_ctx)); + + let runtime_format_ctx: AstFormatContext = + AstFormatContext::new(format!("RuntimeVersion {runtime_version}")); + children.push(FormatTreeNode::new(runtime_format_ctx)); + } } if let Some(description) = &stmt.description { @@ -2285,6 +2326,47 @@ impl<'ast> Visitor<'ast> for AstFormatVisitor { 
AstFormatContext::new(format!("UdfServerAddress {address}")); children.push(FormatTreeNode::new(address_format_ctx)); } + + UDFDefinition::UDFScript { + arg_types, + return_type, + code, + handler, + language, + runtime_version, + } => { + if !arg_types.is_empty() { + let mut arg_types_children = Vec::with_capacity(arg_types.len()); + for arg_type in arg_types.iter() { + let type_format_ctx = AstFormatContext::new(format!("DataType {arg_type}")); + arg_types_children.push(FormatTreeNode::new(type_format_ctx)); + } + let arg_format_ctx = AstFormatContext::with_children( + "UdfArgTypes".to_string(), + arg_types_children.len(), + ); + children.push(FormatTreeNode::with_children( + arg_format_ctx, + arg_types_children, + )); + } + + let return_type_format_ctx = + AstFormatContext::new(format!("UdfReturnType {return_type}")); + children.push(FormatTreeNode::new(return_type_format_ctx)); + + let handler_format_ctx = AstFormatContext::new(format!("UdfHandler {handler}")); + children.push(FormatTreeNode::new(handler_format_ctx)); + + let language_format_ctx = AstFormatContext::new(format!("UdfLanguage {language}")); + children.push(FormatTreeNode::new(language_format_ctx)); + + let code_format_ctx = AstFormatContext::new(format!("UdfCode {code}")); + children.push(FormatTreeNode::new(code_format_ctx)); + + let c = AstFormatContext::new(format!("RuntimeVersion {runtime_version}")); + children.push(FormatTreeNode::new(c)); + } } if let Some(description) = &stmt.description { diff --git a/src/query/ast/src/ast/statements/udf.rs b/src/query/ast/src/ast/statements/udf.rs index 5ada9ced2bf70..f249b7184763d 100644 --- a/src/query/ast/src/ast/statements/udf.rs +++ b/src/query/ast/src/ast/statements/udf.rs @@ -35,6 +35,15 @@ pub enum UDFDefinition { handler: String, language: String, }, + + UDFScript { + arg_types: Vec, + return_type: TypeName, + code: String, + handler: String, + language: String, + runtime_version: String, + }, } #[derive(Debug, Clone, PartialEq)] @@ -77,6 +86,21 
@@ impl Display for UDFDefinition { ") RETURNS {return_type} LANGUAGE {language} HANDLER = {handler} ADDRESS = {address}" )?; } + UDFDefinition::UDFScript { + arg_types, + return_type, + code, + handler, + language, + runtime_version, + } => { + write!(f, "(")?; + write_comma_separated_list(f, arg_types)?; + write!( + f, + ") RETURNS {return_type} LANGUAGE {language} runtime_version = {runtime_version} HANDLER = {handler} AS $${code}$$" + )?; + } } Ok(()) } diff --git a/src/query/ast/src/parser/expr.rs b/src/query/ast/src/parser/expr.rs index 2a54cdc647d9f..bcc3878cd538b 100644 --- a/src/query/ast/src/parser/expr.rs +++ b/src/query/ast/src/parser/expr.rs @@ -1292,11 +1292,13 @@ pub fn json_op(i: Input) -> IResult { pub fn literal(i: Input) -> IResult { let string = map(literal_string, Literal::String); + let code_string = map(code_string, Literal::String); let boolean = map(literal_bool, Literal::Boolean); let null = value(Literal::Null, rule! { NULL }); rule!( #string + | #code_string | #boolean | #literal_number | #null @@ -1417,6 +1419,13 @@ pub fn at_string(i: Input) -> IResult { })(i) } +pub fn code_string(i: Input) -> IResult { + map_res(rule! { CodeString }, |token| { + let path = &token.text()[2..token.text().len() - 2]; + Ok(path.to_string()) + })(i) +} + pub fn nullable(i: Input) -> IResult { alt(( value(true, rule! { NULL }), diff --git a/src/query/ast/src/parser/statement.rs b/src/query/ast/src/parser/statement.rs index 469eb98d16b1f..e2e970f6e9b9b 100644 --- a/src/query/ast/src/parser/statement.rs +++ b/src/query/ast/src/parser/statement.rs @@ -3559,9 +3559,32 @@ pub fn udf_definition(i: Input) -> IResult { }, ); + let udf_script = map( + rule! 
{ + "(" ~ #comma_separated_list0(udf_arg_type) ~ ")" + ~ RETURNS ~ #udf_arg_type + ~ LANGUAGE ~ #ident + ~ HANDLER ~ ^"=" ~ ^#literal_string + ~ AS ~ ^#code_string + }, + |(_, arg_types, _, _, return_type, _, language, _, _, handler, _, code)| { + UDFDefinition::UDFScript { + arg_types, + return_type, + code, + handler, + language: language.to_string(), + // TODO inject runtime_version by user + // Now we use fixed runtime version + runtime_version: "".to_string(), + } + }, + ); + rule!( #udf_server: "(, ...) RETURNS LANGUAGE HANDLER= ADDRESS=" | #lambda_udf: "AS (, ...) -> " + | #udf_script: "(, ...) RETURNS LANGUAGE HANDLER= AS " )(i) } diff --git a/src/query/ast/src/parser/token.rs b/src/query/ast/src/parser/token.rs index 233d3b6e01e42..49c02459be19c 100644 --- a/src/query/ast/src/parser/token.rs +++ b/src/query/ast/src/parser/token.rs @@ -151,6 +151,9 @@ pub enum TokenKind { #[regex(r#"'([^'\\]|\\.|'')*'"#)] QuotedString, + #[regex(r#"\$\$([^\$]|(\$[^\$]))*\$\$"#)] + CodeString, + #[regex(r#"@([^\s`;'"()]|\\\s|\\'|\\"|\\\\)+"#)] AtString, @@ -1185,7 +1188,12 @@ impl TokenKind { pub fn is_literal(&self) -> bool { matches!( self, - LiteralInteger | LiteralFloat | QuotedString | PGLiteralHex | MySQLLiteralHex + LiteralInteger + | LiteralFloat + | QuotedString + | CodeString + | PGLiteralHex + | MySQLLiteralHex ) } @@ -1194,6 +1202,7 @@ impl TokenKind { self, Ident | QuotedString + | CodeString | PGLiteralHex | MySQLLiteralHex | LiteralInteger diff --git a/src/query/ast/tests/it/parser.rs b/src/query/ast/tests/it/parser.rs index 6d2e4eeee490e..016cd2f6a6f94 100644 --- a/src/query/ast/tests/it/parser.rs +++ b/src/query/ast/tests/it/parser.rs @@ -573,6 +573,15 @@ fn test_statement() { "CREATE OR REPLACE FUNCTION isnotempty_test_replace AS(p) -> not(is_null(p)) DESC = 'This is a description';", "CREATE FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815';", "CREATE OR REPLACE FUNCTION 
binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815';", + r#"create or replace function addone(int) +returns int +language python +handler = 'addone_py' +as +$$ +def addone_py(i): + return i+1 +$$;"#, "DROP FUNCTION binary_reverse;", "DROP FUNCTION isnotempty;", ]; @@ -779,6 +788,7 @@ fn test_expr() { r#"char(0xD0, 0xBF, 0xD1)"#, r#"[42, 3.5, 4., .001, 5e2, 1.925e-3, .38e+7, 1.e-01, 0xfff, x'deedbeef']"#, r#"123456789012345678901234567890"#, + r#"$$ab123c$$"#, r#"x'123456789012345678901234567890'"#, r#"1e100000000000000"#, r#"100_100_000"#, diff --git a/src/query/ast/tests/it/testdata/expr-error.txt b/src/query/ast/tests/it/testdata/expr-error.txt index 3b895619304e8..8c7e3424a1d15 100644 --- a/src/query/ast/tests/it/testdata/expr-error.txt +++ b/src/query/ast/tests/it/testdata/expr-error.txt @@ -53,7 +53,7 @@ error: --> SQL:1:10 | 1 | CAST(col1) - | ---- ^ unexpected `)`, expecting `AS`, `,`, `(`, `IS`, `NOT`, `IN`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `LIKE`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `DATE_ADD`, `DATE_SUB`, `DATE_TRUNC`, `DATE`, or 29 more ... + | ---- ^ unexpected `)`, expecting `AS`, `,`, `(`, `IS`, `NOT`, `IN`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `LIKE`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `DATE_ADD`, `DATE_SUB`, `DATE_TRUNC`, `DATE`, or 30 more ... | | | while parsing `CAST(... 
AS ...)` | while parsing expression diff --git a/src/query/ast/tests/it/testdata/expr.txt b/src/query/ast/tests/it/testdata/expr.txt index 006c8f6214015..7cb65f90ad668 100644 --- a/src/query/ast/tests/it/testdata/expr.txt +++ b/src/query/ast/tests/it/testdata/expr.txt @@ -229,6 +229,21 @@ Literal { } +---------- Input ---------- +$$ab123c$$ +---------- Output --------- +'ab123c' +---------- AST ------------ +Literal { + span: Some( + 0..10, + ), + lit: String( + "ab123c", + ), +} + + ---------- Input ---------- x'123456789012345678901234567890' ---------- Output --------- diff --git a/src/query/ast/tests/it/testdata/lexer.txt b/src/query/ast/tests/it/testdata/lexer.txt index 2dc23930406c4..75ba7259149ee 100644 --- a/src/query/ast/tests/it/testdata/lexer.txt +++ b/src/query/ast/tests/it/testdata/lexer.txt @@ -4,6 +4,12 @@ [(EOI, "", 0..0)] +---------- Input ---------- +$$ab$cd$$ $$ab$$ +---------- Output --------- +[(CodeString, "$$ab$cd$$", 0..9), (CodeString, "$$ab$$", 11..17), (EOI, "", 17..17)] + + ---------- Input ---------- x'deadbeef' -- a hex string\n 'a string literal\n escape quote by '' or \\\'. ' ---------- Output --------- diff --git a/src/query/ast/tests/it/testdata/statement-error.txt b/src/query/ast/tests/it/testdata/statement-error.txt index 61d2e28e9e556..07a40d07f2612 100644 --- a/src/query/ast/tests/it/testdata/statement-error.txt +++ b/src/query/ast/tests/it/testdata/statement-error.txt @@ -414,7 +414,7 @@ error: --> SQL:1:41 | 1 | SELECT * FROM t GROUP BY GROUPING SETS () - | ------ ^ unexpected `)`, expecting `(`, `IS`, `IN`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `LIKE`, `NOT`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `DATE_ADD`, `DATE_SUB`, `DATE_TRUNC`, `DATE`, `TIMESTAMP`, `INTERVAL`, or 27 more ... 
+ | ------ ^ unexpected `)`, expecting `(`, `IS`, `IN`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `LIKE`, `NOT`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `DATE_ADD`, `DATE_SUB`, `DATE_TRUNC`, `DATE`, `TIMESTAMP`, `INTERVAL`, or 28 more ... | | | while parsing `SELECT ...` @@ -830,7 +830,7 @@ error: --> SQL:1:65 | 1 | CREATE FUNCTION IF NOT EXISTS isnotempty AS(p) -> not(is_null(p) - | ------ -- ---- ^ unexpected end of line, expecting `)`, `OVER`, `(`, `IS`, `NOT`, `IN`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `LIKE`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `DATE_ADD`, `DATE_SUB`, `DATE_TRUNC`, `DATE`, or 30 more ... + | ------ -- ---- ^ unexpected end of line, expecting `)`, `OVER`, `(`, `IS`, `NOT`, `IN`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `LIKE`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `DATE_ADD`, `DATE_SUB`, `DATE_TRUNC`, `DATE`, or 31 more ... 
| | | | | | | | | while parsing `( [, ...])` | | | while parsing expression diff --git a/src/query/ast/tests/it/testdata/statement.txt b/src/query/ast/tests/it/testdata/statement.txt index 965a97a120118..bce94c981b376 100644 --- a/src/query/ast/tests/it/testdata/statement.txt +++ b/src/query/ast/tests/it/testdata/statement.txt @@ -16598,6 +16598,51 @@ CreateUDF( ) +---------- Input ---------- +create or replace function addone(int) +returns int +language python +handler = 'addone_py' +as +$$ +def addone_py(i): + return i+1 +$$; +---------- Output --------- +CREATE OR REPLACE FUNCTION addone (Int32 NULL) RETURNS Int32 NULL LANGUAGE python runtime_version = HANDLER = addone_py AS $$ +def addone_py(i): + return i+1 +$$ +---------- AST ------------ +CreateUDF( + CreateUDFStmt { + create_option: CreateOrReplace, + udf_name: Identifier { + span: Some( + 27..33, + ), + name: "addone", + quote: None, + }, + description: None, + definition: UDFScript { + arg_types: [ + Nullable( + Int32, + ), + ], + return_type: Nullable( + Int32, + ), + code: "\ndef addone_py(i):\n return i+1\n", + handler: "addone_py", + language: "python", + runtime_version: "", + }, + }, +) + + ---------- Input ---------- DROP FUNCTION binary_reverse; ---------- Output --------- diff --git a/src/query/ast/tests/it/token.rs b/src/query/ast/tests/it/token.rs index e7c5cfb830a53..b2c65d6cb9867 100644 --- a/src/query/ast/tests/it/token.rs +++ b/src/query/ast/tests/it/token.rs @@ -51,6 +51,7 @@ fn test_lexer() { let cases = vec![ r#""#, + r#"$$ab$cd$$ $$ab$$"#, r#"x'deadbeef' -- a hex string\n 'a string literal\n escape quote by '' or \\\'. 
'"#, r#"'中文' '日本語'"#, r#"@abc 123"#, diff --git a/src/query/management/tests/it/udf.rs b/src/query/management/tests/it/udf.rs index 6ce2fa706b25a..80943fb1f4cb7 100644 --- a/src/query/management/tests/it/udf.rs +++ b/src/query/management/tests/it/udf.rs @@ -77,6 +77,29 @@ async fn test_add_udf() -> Result<()> { catch => panic!("GetKVActionReply{:?}", catch), } + // udf script + let udf = create_test_udf_script(); + + udf_api.add_udf(udf.clone(), &CreateOption::None).await??; + + let value = kv_api + .get_kv(format!("__fd_udfs/admin/{}", udf.name).as_str()) + .await?; + + match value { + Some(SeqV { + seq: 3, + meta: _, + data: value, + }) => { + assert_eq!( + value, + serialize_struct(&udf, ErrorCode::IllegalUDFFormat, || "")? + ); + } + catch => panic!("GetKVActionReply{:?}", catch), + } + Ok(()) } @@ -186,6 +209,19 @@ fn create_test_udf_server() -> UserDefinedFunction { ) } +fn create_test_udf_script() -> UserDefinedFunction { + UserDefinedFunction::create_udf_script( + "strlen2", + "testcode", + "strlen_py", + "javascript", + vec![DataType::String], + DataType::Number(NumberDataType::Int64), + "3.12.0", + "This is a description", + ) +} + async fn new_udf_api() -> Result<(Arc, UdfMgr)> { let test_api = Arc::new(MetaEmbedded::new_temp().await?); let mgr = UdfMgr::create(test_api.clone(), NonEmptyStr::new("admin").unwrap()); diff --git a/src/query/service/Cargo.toml b/src/query/service/Cargo.toml index 63947db229686..b6e4a636de48c 100644 --- a/src/query/service/Cargo.toml +++ b/src/query/service/Cargo.toml @@ -98,6 +98,10 @@ jsonb = { workspace = true } # GitHub dependencies # Crates.io dependencies +## TODO add python support +# arrow-udf-python = {package = "arrow-udf-python", git = "https://github.com/risingwavelabs/arrow-udf", rev = "6c32f71" } +arrow-udf-js = { package = "arrow-udf-js", git = "https://github.com/risingwavelabs/arrow-udf", rev = "6c32f71" } + arrow-array = { workspace = true } arrow-flight = { workspace = true } arrow-ipc = { workspace = true 
} diff --git a/src/query/service/src/interpreters/interpreter_delete.rs b/src/query/service/src/interpreters/interpreter_delete.rs index ec605f0d61b08..a7c2970e3f0bf 100644 --- a/src/query/service/src/interpreters/interpreter_delete.rs +++ b/src/query/service/src/interpreters/interpreter_delete.rs @@ -414,7 +414,7 @@ fn do_replace_subquery( } } } - ScalarExpr::UDFServerCall(udf) => { + ScalarExpr::UDFCall(udf) => { for arg in &mut udf.arguments { if !do_replace_subquery(filters, arg)? { replace_selection_with_filter = Some(filters.pop_back().unwrap()); @@ -422,6 +422,7 @@ fn do_replace_subquery( } } } + ScalarExpr::SubqueryExpr { .. } => { if data_type == DataType::Nullable(Box::new(DataType::Boolean)) { let filter = filters.pop_back().unwrap(); diff --git a/src/query/service/src/interpreters/interpreter_factory.rs b/src/query/service/src/interpreters/interpreter_factory.rs index 6aedd7187b673..89ecb555924ea 100644 --- a/src/query/service/src/interpreters/interpreter_factory.rs +++ b/src/query/service/src/interpreters/interpreter_factory.rs @@ -378,15 +378,15 @@ impl InterpreterFactory { ctx, *revoke_role.clone(), )?)), - Plan::CreateUDF(create_user_udf) => Ok(Arc::new(CreateUserUDFInterpreter::try_create( + Plan::CreateUDF(create_user_udf) => Ok(Arc::new(CreateUserUDFScript::try_create( ctx, *create_user_udf.clone(), )?)), - Plan::AlterUDF(alter_udf) => Ok(Arc::new(AlterUserUDFInterpreter::try_create( + Plan::AlterUDF(alter_udf) => Ok(Arc::new(AlterUserUDFScript::try_create( ctx, *alter_udf.clone(), )?)), - Plan::DropUDF(drop_udf) => Ok(Arc::new(DropUserUDFInterpreter::try_create( + Plan::DropUDF(drop_udf) => Ok(Arc::new(DropUserUDFScript::try_create( ctx, *drop_udf.clone(), )?)), diff --git a/src/query/service/src/interpreters/interpreter_user_udf_alter.rs b/src/query/service/src/interpreters/interpreter_user_udf_alter.rs index 7eae63fe6fe90..614923dafbdcd 100644 --- a/src/query/service/src/interpreters/interpreter_user_udf_alter.rs +++ 
b/src/query/service/src/interpreters/interpreter_user_udf_alter.rs @@ -25,21 +25,21 @@ use crate::sessions::QueryContext; use crate::sessions::TableContext; #[derive(Debug)] -pub struct AlterUserUDFInterpreter { +pub struct AlterUserUDFScript { ctx: Arc, plan: AlterUDFPlan, } -impl AlterUserUDFInterpreter { +impl AlterUserUDFScript { pub fn try_create(ctx: Arc, plan: AlterUDFPlan) -> Result { - Ok(AlterUserUDFInterpreter { ctx, plan }) + Ok(AlterUserUDFScript { ctx, plan }) } } #[async_trait::async_trait] -impl Interpreter for AlterUserUDFInterpreter { +impl Interpreter for AlterUserUDFScript { fn name(&self) -> &str { - "AlterUserUDFInterpreter" + "AlterUserUDFScript" } fn is_ddl(&self) -> bool { diff --git a/src/query/service/src/interpreters/interpreter_user_udf_create.rs b/src/query/service/src/interpreters/interpreter_user_udf_create.rs index ce93d4f843741..6164ffea18cef 100644 --- a/src/query/service/src/interpreters/interpreter_user_udf_create.rs +++ b/src/query/service/src/interpreters/interpreter_user_udf_create.rs @@ -28,21 +28,21 @@ use crate::sessions::QueryContext; use crate::sessions::TableContext; #[derive(Debug)] -pub struct CreateUserUDFInterpreter { +pub struct CreateUserUDFScript { ctx: Arc, plan: CreateUDFPlan, } -impl CreateUserUDFInterpreter { +impl CreateUserUDFScript { pub fn try_create(ctx: Arc, plan: CreateUDFPlan) -> Result { - Ok(CreateUserUDFInterpreter { ctx, plan }) + Ok(CreateUserUDFScript { ctx, plan }) } } #[async_trait::async_trait] -impl Interpreter for CreateUserUDFInterpreter { +impl Interpreter for CreateUserUDFScript { fn name(&self) -> &str { - "CreateUserUDFInterpreter" + "CreateUserUDFScript" } fn is_ddl(&self) -> bool { diff --git a/src/query/service/src/interpreters/interpreter_user_udf_drop.rs b/src/query/service/src/interpreters/interpreter_user_udf_drop.rs index ad3ea544b7dd9..46899d649f78b 100644 --- a/src/query/service/src/interpreters/interpreter_user_udf_drop.rs +++ 
b/src/query/service/src/interpreters/interpreter_user_udf_drop.rs @@ -28,21 +28,21 @@ use crate::sessions::QueryContext; use crate::sessions::TableContext; #[derive(Debug)] -pub struct DropUserUDFInterpreter { +pub struct DropUserUDFScript { ctx: Arc, plan: DropUDFPlan, } -impl DropUserUDFInterpreter { +impl DropUserUDFScript { pub fn try_create(ctx: Arc, plan: DropUDFPlan) -> Result { - Ok(DropUserUDFInterpreter { ctx, plan }) + Ok(DropUserUDFScript { ctx, plan }) } } #[async_trait::async_trait] -impl Interpreter for DropUserUDFInterpreter { +impl Interpreter for DropUserUDFScript { fn name(&self) -> &str { - "DropUserUDFInterpreter" + "DropUserUDFScript" } fn is_ddl(&self) -> bool { diff --git a/src/query/service/src/interpreters/mod.rs b/src/query/service/src/interpreters/mod.rs index d14be6912adc2..9eaa6e5f2c4d8 100644 --- a/src/query/service/src/interpreters/mod.rs +++ b/src/query/service/src/interpreters/mod.rs @@ -217,9 +217,9 @@ pub use interpreter_user_drop::DropUserInterpreter; pub use interpreter_user_stage_create::CreateUserStageInterpreter; pub use interpreter_user_stage_drop::DropUserStageInterpreter; pub use interpreter_user_stage_remove::RemoveUserStageInterpreter; -pub use interpreter_user_udf_alter::AlterUserUDFInterpreter; -pub use interpreter_user_udf_create::CreateUserUDFInterpreter; -pub use interpreter_user_udf_drop::DropUserUDFInterpreter; +pub use interpreter_user_udf_alter::AlterUserUDFScript; +pub use interpreter_user_udf_create::CreateUserUDFScript; +pub use interpreter_user_udf_drop::DropUserUDFScript; pub use interpreter_vacuum_drop_tables::VacuumDropTablesInterpreter; pub use interpreter_vacuum_temporary_files::VacuumTemporaryFilesInterpreter; pub use interpreter_view_alter::AlterViewInterpreter; diff --git a/src/query/service/src/pipelines/builders/builder_udf.rs b/src/query/service/src/pipelines/builders/builder_udf.rs index 9511ca35b6877..9f30e714ac1be 100644 --- a/src/query/service/src/pipelines/builders/builder_udf.rs +++ 
b/src/query/service/src/pipelines/builders/builder_udf.rs @@ -16,20 +16,32 @@ use databend_common_exception::Result; use databend_common_pipeline_core::processors::ProcessorPtr; use databend_common_sql::executor::physical_plans::Udf; -use crate::pipelines::processors::transforms::TransformUdf; +use crate::pipelines::processors::transforms::TransformUdfScript; +use crate::pipelines::processors::transforms::TransformUdfServer; use crate::pipelines::PipelineBuilder; impl PipelineBuilder { pub(crate) fn build_udf(&mut self, udf: &Udf) -> Result<()> { self.build_pipeline(&udf.input)?; - self.main_pipeline.add_transform(|input, output| { - Ok(ProcessorPtr::create(TransformUdf::try_create( - self.func_ctx.clone(), - udf.udf_funcs.clone(), - input, - output, - )?)) - }) + if udf.script_udf { + self.main_pipeline.add_transform(|input, output| { + Ok(ProcessorPtr::create(TransformUdfScript::try_create( + self.func_ctx.clone(), + udf.udf_funcs.clone(), + input, + output, + )?)) + }) + } else { + self.main_pipeline.add_transform(|input, output| { + Ok(ProcessorPtr::create(TransformUdfServer::try_create( + self.func_ctx.clone(), + udf.udf_funcs.clone(), + input, + output, + )?)) + }) + } } } diff --git a/src/query/service/src/pipelines/processors/transforms/mod.rs b/src/query/service/src/pipelines/processors/transforms/mod.rs index 8c22f7060a7fd..a4ffd249560c2 100644 --- a/src/query/service/src/pipelines/processors/transforms/mod.rs +++ b/src/query/service/src/pipelines/processors/transforms/mod.rs @@ -34,7 +34,8 @@ mod transform_resort_addon_without_source_schema; mod transform_runtime_cast_schema; mod transform_sort_spill; mod transform_srf; -mod transform_udf; +mod transform_udf_script; +mod transform_udf_server; mod window; pub use hash_join::*; @@ -58,7 +59,8 @@ pub use transform_resort_addon_without_source_schema::TransformResortAddOnWithou pub use transform_runtime_cast_schema::TransformRuntimeCastSchema; pub use transform_sort_spill::create_transform_sort_spill; pub use 
transform_srf::TransformSRF; -pub use transform_udf::TransformUdf; +pub use transform_udf_script::TransformUdfScript; +pub use transform_udf_server::TransformUdfServer; pub use window::FrameBound; pub use window::TransformWindow; pub use window::WindowFunctionInfo; diff --git a/src/query/service/src/pipelines/processors/transforms/transform_udf_script.rs b/src/query/service/src/pipelines/processors/transforms/transform_udf_script.rs new file mode 100644 index 0000000000000..cd855597444a1 --- /dev/null +++ b/src/query/service/src/pipelines/processors/transforms/transform_udf_script.rs @@ -0,0 +1,139 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::sync::Arc; + +use arrow_schema::Schema; +use databend_common_exception::ErrorCode; +use databend_common_exception::Result; +use databend_common_expression::variant_transform::contains_variant; +use databend_common_expression::variant_transform::transform_variant; +use databend_common_expression::BlockEntry; +use databend_common_expression::DataBlock; +use databend_common_expression::DataField; +use databend_common_expression::DataSchema; +use databend_common_expression::FunctionContext; +use databend_common_pipeline_transforms::processors::Transform; +use databend_common_pipeline_transforms::processors::Transformer; +use databend_common_sql::executor::physical_plans::UdfFunctionDesc; + +use crate::pipelines::processors::InputPort; +use crate::pipelines::processors::OutputPort; +use crate::pipelines::processors::Processor; +pub struct TransformUdfScript { + funcs: Vec, + js_runtime: Arc, + // TODO: + // py_runtime: Arc, +} + +unsafe impl Send for TransformUdfScript {} + +impl TransformUdfScript { + pub fn try_create( + _func_ctx: FunctionContext, + funcs: Vec, + input: Arc, + output: Arc, + ) -> Result> { + let mut js_runtime = arrow_udf_js::Runtime::new() + .map_err(|err| ErrorCode::UDFDataError(format!("Cannot create js runtime: {err}")))?; + + for func in funcs.iter() { + let tmp_schema = + DataSchema::new(vec![DataField::new("tmp", func.data_type.as_ref().clone())]); + let arrow_schema = Schema::from(&tmp_schema); + + let (_, _, code) = func.udf_type.as_script().unwrap(); + js_runtime + .add_function_with_handler( + &func.name, + arrow_schema.field(0).data_type().clone(), + arrow_udf_js::CallMode::ReturnNullOnNullInput, + code, + &func.func_name, + ) + .map_err(|err| ErrorCode::UDFDataError(format!("Cannot add js function: {err}")))?; + } + + Ok(Transformer::create(input, output, Self { + funcs, + js_runtime: Arc::new(js_runtime), + })) + } +} + +impl Transform for TransformUdfScript { + const NAME: &'static str = "UDFScriptTransform"; + + fn 
transform(&mut self, mut data_block: DataBlock) -> Result { + let num_rows = data_block.num_rows(); + for func in &self.funcs { + // construct input record_batch + let block_entries = func + .arg_indices + .iter() + .map(|i| { + let arg = data_block.get_by_offset(*i).clone(); + if contains_variant(&arg.data_type) { + let new_arg = BlockEntry::new( + arg.data_type.clone(), + transform_variant(&arg.value, true)?, + ); + Ok(new_arg) + } else { + Ok(arg) + } + }) + .collect::>>()?; + + let fields = block_entries + .iter() + .enumerate() + .map(|(idx, arg)| DataField::new(&format!("arg{}", idx + 1), arg.data_type.clone())) + .collect::>(); + let data_schema = DataSchema::new(fields); + + let input_batch = DataBlock::new(block_entries, num_rows) + .to_record_batch_with_dataschema(&data_schema) + .map_err(|err| ErrorCode::from_string(format!("{err}")))?; + + let result_batch = self + .js_runtime + .call(&func.name, &input_batch) + .map_err(|err| ErrorCode::from_string(format!("{err}")))?; + + let schema = DataSchema::try_from(&(*result_batch.schema()))?; + let (result_block, _result_schema) = + DataBlock::from_record_batch(&schema, &result_batch).map_err(|err| { + ErrorCode::UDFDataError(format!( + "Cannot convert arrow record batch to data block: {err}" + )) + })?; + + let col = if contains_variant(&func.data_type) { + let value = transform_variant(&result_block.get_by_offset(0).value, false)?; + BlockEntry { + data_type: func.data_type.as_ref().clone(), + value, + } + } else { + result_block.get_by_offset(0).clone() + }; + + data_block.add_column(col); + } + Ok(data_block) + } +} diff --git a/src/query/service/src/pipelines/processors/transforms/transform_udf.rs b/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs similarity index 95% rename from src/query/service/src/pipelines/processors/transforms/transform_udf.rs rename to src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs index a09080013fe7d..b314ed14b678b 100644 
--- a/src/query/service/src/pipelines/processors/transforms/transform_udf.rs +++ b/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs @@ -32,12 +32,12 @@ use crate::pipelines::processors::InputPort; use crate::pipelines::processors::OutputPort; use crate::pipelines::processors::Processor; -pub struct TransformUdf { +pub struct TransformUdfServer { func_ctx: FunctionContext, funcs: Vec, } -impl TransformUdf { +impl TransformUdfServer { pub fn try_create( func_ctx: FunctionContext, funcs: Vec, @@ -52,7 +52,7 @@ impl TransformUdf { } #[async_trait::async_trait] -impl AsyncTransform for TransformUdf { +impl AsyncTransform for TransformUdfServer { const NAME: &'static str = "UdfTransform"; #[async_backtrace::framed] @@ -60,6 +60,7 @@ impl AsyncTransform for TransformUdf { let connect_timeout = self.func_ctx.external_server_connect_timeout_secs; let request_timeout = self.func_ctx.external_server_request_timeout_secs; for func in &self.funcs { + let server_addr = func.udf_type.as_server().unwrap(); // construct input record_batch let num_rows = data_block.num_rows(); let block_entries = func @@ -91,8 +92,7 @@ impl AsyncTransform for TransformUdf { .map_err(|err| ErrorCode::from_string(format!("{err}")))?; let mut client = - UDFFlightClient::connect(&func.server_addr, connect_timeout, request_timeout) - .await?; + UDFFlightClient::connect(server_addr, connect_timeout, request_timeout).await?; let result_batch = client.do_exchange(&func.func_name, input_batch).await?; let schema = DataSchema::try_from(&(*result_batch.schema()))?; diff --git a/src/query/sql/src/executor/physical_plan_visitor.rs b/src/query/sql/src/executor/physical_plan_visitor.rs index 6fe636fcbd854..e1a258e50fa61 100644 --- a/src/query/sql/src/executor/physical_plan_visitor.rs +++ b/src/query/sql/src/executor/physical_plan_visitor.rs @@ -495,6 +495,7 @@ pub trait PhysicalPlanReplacer { input: Box::new(input), udf_funcs: plan.udf_funcs.clone(), stat_info: plan.stat_info.clone(), 
+ script_udf: plan.script_udf, })) } } diff --git a/src/query/sql/src/executor/physical_plans/physical_udf.rs b/src/query/sql/src/executor/physical_plans/physical_udf.rs index cb251f9365632..8513166094298 100644 --- a/src/query/sql/src/executor/physical_plans/physical_udf.rs +++ b/src/query/sql/src/executor/physical_plans/physical_udf.rs @@ -26,6 +26,7 @@ use crate::executor::explain::PlanStatsInfo; use crate::executor::PhysicalPlan; use crate::executor::PhysicalPlanBuilder; use crate::optimizer::SExpr; +use crate::plans::UDFType; use crate::ColumnSet; use crate::IndexType; use crate::ScalarExpr; @@ -36,7 +37,7 @@ pub struct Udf { pub plan_id: u32, pub input: Box, pub udf_funcs: Vec, - + pub script_udf: bool, // Only used for explain pub stat_info: Option, } @@ -56,12 +57,14 @@ impl Udf { #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] pub struct UdfFunctionDesc { + pub name: String, pub func_name: String, - pub server_addr: String, pub output_column: IndexType, pub arg_indices: Vec, pub arg_exprs: Vec, pub data_type: Box, + + pub udf_type: UDFType, } impl PhysicalPlanBuilder { @@ -86,7 +89,10 @@ impl PhysicalPlanBuilder { if used.is_empty() { return self.build(s_expr.child(0)?, required).await; } - let udf = crate::plans::Udf { items: used }; + let udf = crate::plans::Udf { + items: used, + script_udf: udf.script_udf, + }; let input = self.build(s_expr.child(0)?, required).await?; let input_schema = input.output_schema()?; let mut index = input_schema.num_fields(); @@ -95,9 +101,10 @@ impl PhysicalPlanBuilder { .items .iter() .map(|item| { - if let ScalarExpr::UDFServerCall(func) = &item.scalar { - let arg_indices = func - .arguments + if let ScalarExpr::UDFCall(func) = &item.scalar { + let (arguments, display_name) = (&func.arguments, &func.display_name); + + let arg_indices = arguments .iter() .map(|arg| match arg { ScalarExpr::BoundColumnRef(col) => { @@ -119,11 +126,10 @@ impl PhysicalPlanBuilder { }) .collect::>>()?; - 
udf_index_map.insert(func.display_name.clone(), index); + udf_index_map.insert(display_name.clone(), index); index += 1; - let arg_exprs = func - .arguments + let arg_exprs = arguments .iter() .map(|arg| { let expr = arg.as_expr()?; @@ -133,12 +139,13 @@ impl PhysicalPlanBuilder { .collect::>>()?; let udf_func = UdfFunctionDesc { + name: func.name.clone(), func_name: func.func_name.clone(), - server_addr: func.server_addr.clone(), output_column: item.index, arg_indices, arg_exprs, data_type: func.return_type.clone(), + udf_type: func.udf_type.clone(), }; Ok(udf_func) } else { @@ -151,6 +158,7 @@ impl PhysicalPlanBuilder { plan_id: 0, input: Box::new(input), udf_funcs, + script_udf: udf.script_udf, stat_info: Some(stat_info), })) } diff --git a/src/query/sql/src/planner/binder/binder.rs b/src/query/sql/src/planner/binder/binder.rs index b875f66ab9d62..c503e554dcb2f 100644 --- a/src/query/sql/src/planner/binder/binder.rs +++ b/src/query/sql/src/planner/binder/binder.rs @@ -683,7 +683,7 @@ impl<'a> Binder { scalar, ScalarExpr::WindowFunction(_) | ScalarExpr::AggregateFunction(_) - | ScalarExpr::UDFServerCall(_) + | ScalarExpr::UDFCall(_) | ScalarExpr::SubqueryExpr(_) ) }; @@ -694,7 +694,7 @@ impl<'a> Binder { // add check for SExpr to disable invalid source for copy/insert/merge/replace pub(crate) fn check_sexpr_top(&self, s_expr: &SExpr) -> Result { - let f = |scalar: &ScalarExpr| matches!(scalar, ScalarExpr::UDFServerCall(_)); + let f = |scalar: &ScalarExpr| matches!(scalar, ScalarExpr::UDFCall(_)); let mut finder = Finder::new(&f); Self::check_sexpr(s_expr, &mut finder) } @@ -806,7 +806,7 @@ impl<'a> Binder { scalar, ScalarExpr::WindowFunction(_) | ScalarExpr::AggregateFunction(_) - | ScalarExpr::UDFServerCall(_) + | ScalarExpr::UDFCall(_) ) }; let mut finder = Finder::new(&f); diff --git a/src/query/sql/src/planner/binder/scalar_common.rs b/src/query/sql/src/planner/binder/scalar_common.rs index 04bc62046a7f1..6bc1c2f3b4148 100644 --- 
a/src/query/sql/src/planner/binder/scalar_common.rs +++ b/src/query/sql/src/planner/binder/scalar_common.rs @@ -180,7 +180,7 @@ pub fn contain_subquery(scalar: &ScalarExpr) -> bool { } ScalarExpr::FunctionCall(func) => func.arguments.iter().any(contain_subquery), ScalarExpr::CastExpr(CastExpr { argument, .. }) => contain_subquery(argument), - ScalarExpr::UDFServerCall(udf) => udf.arguments.iter().any(contain_subquery), + ScalarExpr::UDFCall(udf) => udf.arguments.iter().any(contain_subquery), _ => false, } } diff --git a/src/query/sql/src/planner/binder/select.rs b/src/query/sql/src/planner/binder/select.rs index d9313ebc31730..c7229403fc76b 100644 --- a/src/query/sql/src/planner/binder/select.rs +++ b/src/query/sql/src/planner/binder/select.rs @@ -294,8 +294,12 @@ impl Binder { s_expr = self.bind_projection(&mut from_context, &projections, &scalar_items, s_expr)?; - // rewrite udf - let mut udf_rewriter = UdfRewriter::new(self.metadata.clone()); + // rewrite udf for interpreter udf + let mut udf_rewriter = UdfRewriter::new(self.metadata.clone(), true); + s_expr = udf_rewriter.rewrite(&s_expr)?; + + // rewrite udf for server udf + let mut udf_rewriter = UdfRewriter::new(self.metadata.clone(), false); s_expr = udf_rewriter.rewrite(&s_expr)?; // rewrite variant inner fields as virtual columns diff --git a/src/query/sql/src/planner/binder/sort.rs b/src/query/sql/src/planner/binder/sort.rs index 80e0b62220ddb..0c60b216c576e 100644 --- a/src/query/sql/src/planner/binder/sort.rs +++ b/src/query/sql/src/planner/binder/sort.rs @@ -40,7 +40,7 @@ use crate::plans::ScalarExpr; use crate::plans::ScalarItem; use crate::plans::Sort; use crate::plans::SortItem; -use crate::plans::UDFServerCall; +use crate::plans::UDFCall; use crate::plans::VisitorMut as _; use crate::BindContext; use crate::IndexType; @@ -382,7 +382,7 @@ impl Binder { target_type: target_type.clone(), })) } - ScalarExpr::UDFServerCall(udf) => { + ScalarExpr::UDFCall(udf) => { let new_args = udf .arguments .iter() 
@@ -390,12 +390,12 @@ impl Binder { self.rewrite_scalar_with_replacement(bind_context, arg, replacement_fn) }) .collect::>>()?; - Ok(UDFServerCall { + Ok(UDFCall { span: udf.span, name: udf.name.clone(), func_name: udf.func_name.clone(), display_name: udf.display_name.clone(), - server_addr: udf.server_addr.clone(), + udf_type: udf.udf_type.clone(), arg_types: udf.arg_types.clone(), return_type: udf.return_type.clone(), arguments: new_args, diff --git a/src/query/sql/src/planner/binder/udf.rs b/src/query/sql/src/planner/binder/udf.rs index 26e9217a1e5db..ee122142a27cb 100644 --- a/src/query/sql/src/planner/binder/udf.rs +++ b/src/query/sql/src/planner/binder/udf.rs @@ -24,6 +24,7 @@ use databend_common_expression::types::DataType; use databend_common_expression::udf_client::UDFFlightClient; use databend_common_meta_app::principal::LambdaUDF; use databend_common_meta_app::principal::UDFDefinition as PlanUDFDefinition; +use databend_common_meta_app::principal::UDFScript; use databend_common_meta_app::principal::UDFServer; use databend_common_meta_app::principal::UserDefinedFunction; @@ -118,6 +119,45 @@ impl Binder { created_on: Utc::now(), }) } + UDFDefinition::UDFScript { + arg_types, + return_type, + code, + handler, + language, + runtime_version, + } => { + let mut arg_datatypes = Vec::with_capacity(arg_types.len()); + for arg_type in arg_types { + arg_datatypes.push(DataType::from(&resolve_type_name(arg_type, true)?)); + } + let return_type = DataType::from(&resolve_type_name(return_type, true)?); + + if !["python", "javascript"].contains(&language.to_lowercase().as_str()) { + return Err(ErrorCode::InvalidArgument(format!( + "Unallowed UDF language '{language}', must be python or javascript" + ))); + } + + let mut runtime_version = runtime_version.to_string(); + if runtime_version.is_empty() && language.to_lowercase() == "python" { + runtime_version = "3.12.0".to_string(); + } + + Ok(UserDefinedFunction { + name: udf_name.to_string(), + description: 
udf_description.clone().unwrap_or_default(), + definition: PlanUDFDefinition::UDFScript(UDFScript { + code: code.clone(), + arg_types: arg_datatypes, + return_type, + handler: handler.clone(), + language: language.clone(), + runtime_version, + }), + created_on: Utc::now(), + }) + } } } diff --git a/src/query/sql/src/planner/format/display_rel_operator.rs b/src/query/sql/src/planner/format/display_rel_operator.rs index a6b0a27e8ca90..3c8500c52ec5d 100644 --- a/src/query/sql/src/planner/format/display_rel_operator.rs +++ b/src/query/sql/src/planner/format/display_rel_operator.rs @@ -96,7 +96,7 @@ pub fn format_scalar(scalar: &ScalarExpr) -> String { ) } ScalarExpr::SubqueryExpr(_) => "SUBQUERY".to_string(), - ScalarExpr::UDFServerCall(udf) => { + ScalarExpr::UDFCall(udf) => { format!( "{}({})", &udf.func_name, diff --git a/src/query/sql/src/planner/optimizer/decorrelate/flatten_scalar.rs b/src/query/sql/src/planner/optimizer/decorrelate/flatten_scalar.rs index 5c46624b58534..82098a8c95088 100644 --- a/src/query/sql/src/planner/optimizer/decorrelate/flatten_scalar.rs +++ b/src/query/sql/src/planner/optimizer/decorrelate/flatten_scalar.rs @@ -23,7 +23,7 @@ use crate::plans::BoundColumnRef; use crate::plans::CastExpr; use crate::plans::FunctionCall; use crate::plans::ScalarExpr; -use crate::plans::UDFServerCall; +use crate::plans::UDFCall; impl SubqueryRewriter { pub(crate) fn flatten_scalar( @@ -88,18 +88,18 @@ impl SubqueryRewriter { target_type: cast_expr.target_type.clone(), })) } - ScalarExpr::UDFServerCall(udf) => { + ScalarExpr::UDFCall(udf) => { let arguments = udf .arguments .iter() .map(|arg| self.flatten_scalar(arg, correlated_columns)) .collect::>>()?; - Ok(ScalarExpr::UDFServerCall(UDFServerCall { + Ok(ScalarExpr::UDFCall(UDFCall { span: udf.span, name: udf.name.clone(), func_name: udf.func_name.clone(), display_name: udf.display_name.clone(), - server_addr: udf.server_addr.clone(), + udf_type: udf.udf_type.clone(), arg_types: udf.arg_types.clone(), 
return_type: udf.return_type.clone(), arguments, diff --git a/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs b/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs index a95261ce831c8..2759b56c51a96 100644 --- a/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs +++ b/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs @@ -48,8 +48,8 @@ use crate::plans::ScalarExpr; use crate::plans::ScalarItem; use crate::plans::SubqueryExpr; use crate::plans::SubqueryType; +use crate::plans::UDFCall; use crate::plans::UDFLambdaCall; -use crate::plans::UDFServerCall; use crate::plans::WindowFuncType; use crate::IndexType; use crate::MetadataRef; @@ -346,7 +346,7 @@ impl SubqueryRewriter { self.derived_columns.clear(); Ok((scalar, s_expr)) } - ScalarExpr::UDFServerCall(udf) => { + ScalarExpr::UDFCall(udf) => { let mut args = vec![]; let mut s_expr = s_expr.clone(); for arg in udf.arguments.iter() { @@ -355,12 +355,12 @@ impl SubqueryRewriter { args.push(res.0); } - let expr: ScalarExpr = UDFServerCall { + let expr: ScalarExpr = UDFCall { span: udf.span, name: udf.name.clone(), func_name: udf.func_name.clone(), display_name: udf.display_name.clone(), - server_addr: udf.server_addr.clone(), + udf_type: udf.udf_type.clone(), arg_types: udf.arg_types.clone(), return_type: udf.return_type.clone(), arguments: args, @@ -369,6 +369,7 @@ impl SubqueryRewriter { Ok((expr, s_expr)) } + ScalarExpr::UDFLambdaCall(udf) => { let mut s_expr = s_expr.clone(); let res = self.try_rewrite_subquery(&udf.scalar, &s_expr, false)?; diff --git a/src/query/sql/src/planner/optimizer/filter/pull_up_filter.rs b/src/query/sql/src/planner/optimizer/filter/pull_up_filter.rs index 7872841c57dee..857b035c2488a 100644 --- a/src/query/sql/src/planner/optimizer/filter/pull_up_filter.rs +++ b/src/query/sql/src/planner/optimizer/filter/pull_up_filter.rs @@ -233,7 +233,7 @@ impl PullUpFilterOptimizer { ScalarExpr::CastExpr(cast) => { 
Self::replace_predicate(&mut cast.argument, items, metadata)?; } - ScalarExpr::UDFServerCall(udf) => { + ScalarExpr::UDFCall(udf) => { for arg in udf.arguments.iter_mut() { Self::replace_predicate(arg, items, metadata)?; } diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs index 49283f2e1ac1f..6c4ddea151c40 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs @@ -39,7 +39,7 @@ use crate::plans::FunctionCall; use crate::plans::RelOperator; use crate::plans::ScalarItem; use crate::plans::SortItem; -use crate::plans::UDFServerCall; +use crate::plans::UDFCall; use crate::plans::VisitorMut; use crate::ColumnEntry; use crate::ColumnSet; @@ -823,7 +823,7 @@ impl RewriteInfomartion<'_> { .join(", ") ) } - ScalarExpr::UDFServerCall(udf) => format!( + ScalarExpr::UDFCall(udf) => format!( "{}({})", &udf.func_name, udf.arguments @@ -832,6 +832,7 @@ impl RewriteInfomartion<'_> { .collect::>() .join(", ") ), + _ => unreachable!(), // Window function and subquery will not appear in index. 
} } @@ -1124,19 +1125,19 @@ fn rewrite_query_item( .into(), ) } - ScalarExpr::UDFServerCall(udf) => { + ScalarExpr::UDFCall(udf) => { let mut new_args = Vec::with_capacity(udf.arguments.len()); for arg in udf.arguments.iter() { let new_arg = rewrite_by_selection(query_info, arg, index_selection)?; new_args.push(new_arg); } Some( - UDFServerCall { + UDFCall { span: udf.span, name: udf.name.clone(), func_name: udf.func_name.clone(), display_name: udf.display_name.clone(), - server_addr: udf.server_addr.clone(), + udf_type: udf.udf_type.clone(), arg_types: udf.arg_types.clone(), return_type: udf.return_type.clone(), arguments: new_args, @@ -1144,6 +1145,7 @@ fn rewrite_query_item( .into(), ) } + // TODO UDF interpreter ScalarExpr::AggregateFunction(_) => None, /* Aggregate function must appear in index selection. */ _ => unreachable!(), // Window function and subquery will not appear in index. } diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs index 431803d0a88df..cc5c21a36b52f 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs @@ -33,7 +33,7 @@ use crate::plans::NthValueFunction; use crate::plans::RelOp; use crate::plans::ScalarExpr; use crate::plans::ScalarItem; -use crate::plans::UDFServerCall; +use crate::plans::UDFCall; use crate::plans::WindowFunc; use crate::plans::WindowFuncType; use crate::plans::WindowOrderBy; @@ -202,19 +202,19 @@ impl RulePushDownFilterEvalScalar { target_type: cast.target_type.clone(), })) } - ScalarExpr::UDFServerCall(udf) => { + ScalarExpr::UDFCall(udf) => { let arguments = udf .arguments .iter() .map(|arg| Self::replace_predicate(arg, items)) .collect::>>()?; - Ok(ScalarExpr::UDFServerCall(UDFServerCall { + Ok(ScalarExpr::UDFCall(UDFCall { span: udf.span, name: 
udf.name.clone(), func_name: udf.func_name.clone(), display_name: udf.display_name.clone(), - server_addr: udf.server_addr.clone(), + udf_type: udf.udf_type.clone(), arg_types: udf.arg_types.clone(), return_type: udf.return_type.clone(), arguments, diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs index 9ca0e1d6505d0..46fc822c8dff8 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs @@ -32,7 +32,7 @@ use crate::plans::LambdaFunc; use crate::plans::NthValueFunction; use crate::plans::RelOp; use crate::plans::Scan; -use crate::plans::UDFServerCall; +use crate::plans::UDFCall; use crate::plans::WindowFunc; use crate::plans::WindowFuncType; use crate::plans::WindowOrderBy; @@ -305,7 +305,7 @@ impl RulePushDownFilterScan { target_type: cast.target_type.clone(), })) } - ScalarExpr::UDFServerCall(udf) => { + ScalarExpr::UDFCall(udf) => { let arguments = udf .arguments .iter() @@ -319,17 +319,18 @@ impl RulePushDownFilterScan { }) .collect::>>()?; - Ok(ScalarExpr::UDFServerCall(UDFServerCall { + Ok(ScalarExpr::UDFCall(UDFCall { span: udf.span, name: udf.name.clone(), func_name: udf.func_name.clone(), display_name: udf.display_name.clone(), - server_addr: udf.server_addr.clone(), + udf_type: udf.udf_type.clone(), arg_types: udf.arg_types.clone(), return_type: udf.return_type.clone(), arguments, })) } + _ => Ok(predicate.clone()), } } diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_prewhere.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_prewhere.rs index 00eb07d224450..0da98729dccf8 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_prewhere.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_prewhere.rs @@ -82,7 +82,7 @@ impl 
RulePushDownPrewhere { Self::collect_columns_impl(table_index, schema, cast.argument.as_ref(), columns)?; } ScalarExpr::ConstantExpr(_) => {} - ScalarExpr::UDFServerCall(udf) => { + ScalarExpr::UDFCall(udf) => { for arg in udf.arguments.iter() { Self::collect_columns_impl(table_index, schema, arg, columns)?; } diff --git a/src/query/sql/src/planner/optimizer/s_expr.rs b/src/query/sql/src/planner/optimizer/s_expr.rs index 260bc6f846393..0bdf8ef7d441b 100644 --- a/src/query/sql/src/planner/optimizer/s_expr.rs +++ b/src/query/sql/src/planner/optimizer/s_expr.rs @@ -28,8 +28,8 @@ use crate::plans::Exchange; use crate::plans::RelOperator; use crate::plans::Scan; use crate::plans::SubqueryExpr; +use crate::plans::UDFCall; use crate::plans::UDFLambdaCall; -use crate::plans::UDFServerCall; use crate::plans::Visitor; use crate::plans::WindowFuncType; use crate::IndexType; @@ -466,7 +466,7 @@ pub fn get_udf_names(scalar: &ScalarExpr) -> Result> { } impl<'a> Visitor<'a> for FindUdfNamesVisitor<'a> { - fn visit_udf_server_call(&mut self, udf: &'a UDFServerCall) -> Result<()> { + fn visit_udf_call(&mut self, udf: &'a UDFCall) -> Result<()> { for expr in &udf.arguments { self.visit(expr)?; } diff --git a/src/query/sql/src/planner/plans/scalar_expr.rs b/src/query/sql/src/planner/plans/scalar_expr.rs index bf10b58901ac5..837e07075f7ce 100644 --- a/src/query/sql/src/planner/plans/scalar_expr.rs +++ b/src/query/sql/src/planner/plans/scalar_expr.rs @@ -24,6 +24,7 @@ use databend_common_expression::types::DataType; use databend_common_expression::RemoteExpr; use databend_common_expression::Scalar; use educe::Educe; +use enum_as_inner::EnumAsInner; use itertools::Itertools; use super::WindowFuncFrame; @@ -43,7 +44,7 @@ pub enum ScalarExpr { FunctionCall(FunctionCall), CastExpr(CastExpr), SubqueryExpr(SubqueryExpr), - UDFServerCall(UDFServerCall), + UDFCall(UDFCall), UDFLambdaCall(UDFLambdaCall), } @@ -117,7 +118,7 @@ impl ScalarExpr { }), ScalarExpr::CastExpr(expr) => 
expr.span.or(expr.argument.span()), ScalarExpr::SubqueryExpr(expr) => expr.span, - ScalarExpr::UDFServerCall(expr) => expr.span, + ScalarExpr::UDFCall(expr) => expr.span, ScalarExpr::UDFLambdaCall(expr) => expr.span, _ => None, } @@ -142,7 +143,7 @@ impl ScalarExpr { self.evaluable = false; Ok(()) } - fn visit_udf_server_call(&mut self, _: &'a UDFServerCall) -> Result<()> { + fn visit_udf_call(&mut self, _: &'a UDFCall) -> Result<()> { self.evaluable = false; Ok(()) } @@ -326,21 +327,19 @@ impl TryFrom for SubqueryExpr { } } -impl From for ScalarExpr { - fn from(v: UDFServerCall) -> Self { - Self::UDFServerCall(v) +impl From for ScalarExpr { + fn from(v: UDFCall) -> Self { + Self::UDFCall(v) } } -impl TryFrom for UDFServerCall { +impl TryFrom for UDFCall { type Error = ErrorCode; fn try_from(value: ScalarExpr) -> Result { - if let ScalarExpr::UDFServerCall(value) = value { + if let ScalarExpr::UDFCall(value) = value { Ok(value) } else { - Err(ErrorCode::Internal( - "Cannot downcast Scalar to UDFServerCall", - )) + Err(ErrorCode::Internal("Cannot downcast Scalar to UDFCall")) } } } @@ -586,9 +585,10 @@ fn hash_column_set(columns: &ColumnSet, state: &mut H) { columns.iter().for_each(|c| c.hash(state)); } +/// UDFCall includes server & lambda call #[derive(Clone, Debug, Educe)] #[educe(PartialEq, Eq, Hash)] -pub struct UDFServerCall { +pub struct UDFCall { #[educe(Hash(ignore), PartialEq(ignore), Eq(ignore))] pub span: Span, // name in meta @@ -596,10 +596,25 @@ pub struct UDFServerCall { // name in handler pub func_name: String, pub display_name: String, - pub server_addr: String, pub arg_types: Vec, pub return_type: Box, pub arguments: Vec, + pub udf_type: UDFType, +} + +#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize, EnumAsInner)] +pub enum UDFType { + Server(String), // server_addr + Script((String, String, String)), // Lang, Version, Code +} + +impl UDFType { + pub fn match_type(&self, is_script: bool) -> bool { + match self { + 
UDFType::Server(_) => !is_script, + UDFType::Script(_) => is_script, + } + } } #[derive(Clone, Debug, Educe)] @@ -653,7 +668,7 @@ pub trait Visitor<'a>: Sized { } Ok(()) } - fn visit_udf_server_call(&mut self, udf: &'a UDFServerCall) -> Result<()> { + fn visit_udf_call(&mut self, udf: &'a UDFCall) -> Result<()> { for expr in &udf.arguments { self.visit(expr)?; } @@ -795,11 +810,11 @@ pub trait VisitorWithParent<'a>: Sized { Ok(()) } - fn visit_udf_server_call( + fn visit_udf_call( &mut self, _parent: Option<&'a ScalarExpr>, current: &'a ScalarExpr, - udf: &'a UDFServerCall, + udf: &'a UDFCall, ) -> Result<()> { for expr in &udf.arguments { self.visit_with_parent(Some(current), expr)?; @@ -837,7 +852,7 @@ pub fn walk_expr_with_parent<'a, V: VisitorWithParent<'a>>( ScalarExpr::FunctionCall(func) => visitor.visit_function_call(parent, current, func), ScalarExpr::CastExpr(cast_expr) => visitor.visit_cast(parent, current, cast_expr), ScalarExpr::SubqueryExpr(subquery) => visitor.visit_subquery(parent, current, subquery), - ScalarExpr::UDFServerCall(udf) => visitor.visit_udf_server_call(parent, current, udf), + ScalarExpr::UDFCall(udf) => visitor.visit_udf_call(parent, current, udf), ScalarExpr::UDFLambdaCall(udf) => visitor.visit_udf_lambda_call(parent, current, udf), } } @@ -852,7 +867,7 @@ pub fn walk_expr<'a, V: Visitor<'a>>(visitor: &mut V, expr: &'a ScalarExpr) -> R ScalarExpr::FunctionCall(expr) => visitor.visit_function_call(expr), ScalarExpr::CastExpr(expr) => visitor.visit_cast(expr), ScalarExpr::SubqueryExpr(expr) => visitor.visit_subquery(expr), - ScalarExpr::UDFServerCall(expr) => visitor.visit_udf_server_call(expr), + ScalarExpr::UDFCall(expr) => visitor.visit_udf_call(expr), ScalarExpr::UDFLambdaCall(expr) => visitor.visit_udf_lambda_call(expr), } } @@ -924,7 +939,7 @@ pub trait VisitorMut<'a>: Sized { } Ok(()) } - fn visit_udf_server_call(&mut self, udf: &'a mut UDFServerCall) -> Result<()> { + fn visit_udf_call(&mut self, udf: &'a mut UDFCall) -> 
Result<()> { for expr in &mut udf.arguments { self.visit(expr)?; } @@ -949,7 +964,7 @@ pub fn walk_expr_mut<'a, V: VisitorMut<'a>>( ScalarExpr::FunctionCall(expr) => visitor.visit_function_call(expr), ScalarExpr::CastExpr(expr) => visitor.visit_cast_expr(expr), ScalarExpr::SubqueryExpr(expr) => visitor.visit_subquery_expr(expr), - ScalarExpr::UDFServerCall(expr) => visitor.visit_udf_server_call(expr), + ScalarExpr::UDFCall(expr) => visitor.visit_udf_call(expr), ScalarExpr::UDFLambdaCall(expr) => visitor.visit_udf_lambda_call(expr), } } diff --git a/src/query/sql/src/planner/plans/udf.rs b/src/query/sql/src/planner/plans/udf.rs index a963eee37b0c4..6f89a53fb4543 100644 --- a/src/query/sql/src/planner/plans/udf.rs +++ b/src/query/sql/src/planner/plans/udf.rs @@ -31,6 +31,7 @@ use crate::plans::ScalarItem; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Udf { pub items: Vec, + pub script_udf: bool, } impl Udf { diff --git a/src/query/sql/src/planner/semantic/lowering.rs b/src/query/sql/src/planner/semantic/lowering.rs index dfaad7b135128..5033229a61e66 100644 --- a/src/query/sql/src/planner/semantic/lowering.rs +++ b/src/query/sql/src/planner/semantic/lowering.rs @@ -244,7 +244,7 @@ impl ScalarExpr { data_type: subquery.data_type(), display_name: "DUMMY".to_string(), }, - ScalarExpr::UDFServerCall(udf) => RawExpr::ColumnRef { + ScalarExpr::UDFCall(udf) => RawExpr::ColumnRef { span: None, id: ColumnBindingBuilder::new( udf.display_name.clone(), @@ -256,6 +256,7 @@ impl ScalarExpr { data_type: (*udf.return_type).clone(), display_name: udf.display_name.clone(), }, + ScalarExpr::UDFLambdaCall(udf) => { let scalar = &udf.scalar; scalar.as_raw_expr() diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index 91ce6e9551ff8..729789929f330 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -77,6 +77,7 @@ use 
databend_common_functions::GENERAL_LAMBDA_FUNCTIONS; use databend_common_functions::GENERAL_WINDOW_FUNCTIONS; use databend_common_meta_app::principal::LambdaUDF; use databend_common_meta_app::principal::UDFDefinition; +use databend_common_meta_app::principal::UDFScript; use databend_common_meta_app::principal::UDFServer; use databend_common_users::UserApiProvider; use indexmap::IndexMap; @@ -114,8 +115,9 @@ use crate::plans::ScalarExpr; use crate::plans::ScalarItem; use crate::plans::SubqueryExpr; use crate::plans::SubqueryType; +use crate::plans::UDFCall; use crate::plans::UDFLambdaCall; -use crate::plans::UDFServerCall; +use crate::plans::UDFType; use crate::plans::WindowFunc; use crate::plans::WindowFuncFrame; use crate::plans::WindowFuncFrameBound; @@ -2972,6 +2974,10 @@ impl<'a> TypeChecker<'a> { self.resolve_udf_server(span, name, arguments, udf_def) .await?, )), + UDFDefinition::UDFScript(udf_def) => Ok(Some( + self.resolve_udf_script(span, name, arguments, udf_def) + .await?, + )), } } @@ -3025,14 +3031,57 @@ impl<'a> TypeChecker<'a> { self.ctx.set_cacheable(false); Ok(Box::new(( - UDFServerCall { + UDFCall { + span, + name, + func_name: udf_definition.handler, + display_name, + udf_type: UDFType::Server(address.clone()), + arg_types: udf_definition.arg_types, + return_type: Box::new(udf_definition.return_type.clone()), + arguments: args, + } + .into(), + udf_definition.return_type.clone(), + ))) + } + + #[async_recursion::async_recursion] + #[async_backtrace::framed] + async fn resolve_udf_script( + &mut self, + span: Span, + name: String, + arguments: &[Expr], + udf_definition: UDFScript, + ) -> Result> { + let mut args = Vec::with_capacity(arguments.len()); + for (argument, dest_type) in arguments.iter().zip(udf_definition.arg_types.iter()) { + let box (arg, ty) = self.resolve(argument).await?; + if ty != *dest_type { + args.push(wrap_cast(&arg, dest_type)); + } else { + args.push(arg); + } + } + + let arg_names = arguments.iter().map(|arg| format!("{}", 
arg)).join(", "); + let display_name = format!("{}({})", udf_definition.handler, arg_names); + + self.ctx.set_cacheable(false); + Ok(Box::new(( + UDFCall { span, name, func_name: udf_definition.handler, display_name, - server_addr: udf_definition.address, arg_types: udf_definition.arg_types, return_type: Box::new(udf_definition.return_type.clone()), + udf_type: UDFType::Script(( + udf_definition.language, + udf_definition.runtime_version, + udf_definition.code, + )), arguments: args, } .into(), diff --git a/src/query/sql/src/planner/semantic/udf_rewriter.rs b/src/query/sql/src/planner/semantic/udf_rewriter.rs index ca3aacfc17825..1549b1a9ff2f3 100644 --- a/src/query/sql/src/planner/semantic/udf_rewriter.rs +++ b/src/query/sql/src/planner/semantic/udf_rewriter.rs @@ -26,7 +26,7 @@ use crate::plans::EvalScalar; use crate::plans::RelOperator; use crate::plans::ScalarExpr; use crate::plans::ScalarItem; -use crate::plans::UDFServerCall; +use crate::plans::UDFCall; use crate::plans::Udf; use crate::plans::VisitorMut; use crate::ColumnBindingBuilder; @@ -46,16 +46,18 @@ pub(crate) struct UdfRewriter { /// Mapping: (udf function display name) -> (derived index) /// This is used to reuse already generated derived columns udf_functions_index_map: HashMap, + script_udf: bool, } impl UdfRewriter { - pub(crate) fn new(metadata: MetadataRef) -> Self { + pub(crate) fn new(metadata: MetadataRef, script_udf: bool) -> Self { Self { metadata, udf_arguments: Default::default(), udf_functions: Default::default(), udf_functions_map: Default::default(), udf_functions_index_map: Default::default(), + script_udf, } } @@ -74,7 +76,7 @@ impl UdfRewriter { RelOperator::EvalScalar(mut plan) => { for item in &plan.items { // The index of Udf item can be reused. 
- if let ScalarExpr::UDFServerCall(udf) = &item.scalar { + if let ScalarExpr::UDFCall(udf) = &item.scalar { self.udf_functions_index_map .insert(udf.display_name.clone(), item.index); } @@ -115,6 +117,7 @@ impl UdfRewriter { let udf_plan = Udf { items: mem::take(&mut self.udf_functions), + script_udf: self.script_udf, }; Arc::new(SExpr::create_unary(Arc::new(udf_plan.into()), child_expr)) } else { @@ -127,19 +130,23 @@ impl<'a> VisitorMut<'a> for UdfRewriter { fn visit(&mut self, expr: &'a mut ScalarExpr) -> Result<()> { walk_expr_mut(self, expr)?; // replace udf with derived column - if let ScalarExpr::UDFServerCall(udf) = expr { + if let ScalarExpr::UDFCall(udf) = expr { if let Some(column_ref) = self.udf_functions_map.get(&udf.display_name) { *expr = ScalarExpr::BoundColumnRef(column_ref.clone()); - } else { + } else if udf.udf_type.match_type(self.script_udf) { return Err(ErrorCode::Internal("Rewrite udf function failed")); } } Ok(()) } - fn visit_udf_server_call(&mut self, udf: &'a mut UDFServerCall) -> Result<()> { + fn visit_udf_call(&mut self, udf: &'a mut UDFCall) -> Result<()> { + if !udf.udf_type.match_type(self.script_udf) { + return Ok(()); + } + for (i, arg) in udf.arguments.iter_mut().enumerate() { - if let ScalarExpr::UDFServerCall(_) = arg { + if let ScalarExpr::UDFCall(_) = arg { return Err(ErrorCode::InvalidArgument( "the argument of UDF call can't be a UDF call", )); diff --git a/src/query/storages/system/src/user_functions_table.rs b/src/query/storages/system/src/user_functions_table.rs index 51ac6f1947514..e191c5dfb5d34 100644 --- a/src/query/storages/system/src/user_functions_table.rs +++ b/src/query/storages/system/src/user_functions_table.rs @@ -52,6 +52,11 @@ fn encode_arguments(udf_definition: &UDFDefinition) -> jsonb::Value { "return_type": &x.return_type.to_string(), })) .into(), + UDFDefinition::UDFScript(x) => (&json!({ + "arg_types": &x.arg_types.clone().into_iter().map(|dt| dt.to_string()).collect::<Vec<String>>(), + 
"return_type": &x.return_type.to_string(), + })) + .into(), } } @@ -98,6 +103,7 @@ impl AsyncSystemTable for UserFunctionsTable { udfs.get(i).map_or("", |udf| match &udf.definition { UDFDefinition::LambdaUDF(_) => "SQL", UDFDefinition::UDFServer(x) => &x.language, + UDFDefinition::UDFScript(x) => &x.language, }) }) .collect(); diff --git a/tests/sqllogictests/suites/base/03_common/03_0013_select_udf.test b/tests/sqllogictests/suites/base/03_common/03_0013_select_udf.test index 5c53c1363d757..d405d6e94a5c7 100644 --- a/tests/sqllogictests/suites/base/03_common/03_0013_select_udf.test +++ b/tests/sqllogictests/suites/base/03_common/03_0013_select_udf.test @@ -29,3 +29,27 @@ DROP FUNCTION cal statement ok DROP FUNCTION notnull + +## test js udf +statement ok +CREATE FUNCTION gcd (INT, INT) RETURNS BIGINT LANGUAGE javascript HANDLER = 'gcd_js' AS $$ + export function gcd_js(a, b) { + while (b != 0) { + let t = b; + b = a % b; + a = t; + } + return a; + } +$$ + +query I
select number, gcd(number * 3, number * 6) from numbers(5) where number > 0 order by 1; +---- +1 3 +2 6 +3 9 +4 12 + +statement ok +DROP FUNCTION gcd