diff --git a/cpp/src/gandiva/codegen/CMakeLists.txt b/cpp/src/gandiva/codegen/CMakeLists.txt index 5083b9c5e1bf2..7ebf08618cd9c 100644 --- a/cpp/src/gandiva/codegen/CMakeLists.txt +++ b/cpp/src/gandiva/codegen/CMakeLists.txt @@ -35,8 +35,10 @@ set(SRC_FILES annotator.cc function_signature.cc llvm_generator.cc llvm_types.cc + like_holder.cc projector.cc selection_vector.cc + regex_util.cc status.cc tree_expr_builder.cc ${BC_FILE_PATH_CC}) @@ -84,17 +86,16 @@ install( #args: label test-file src-files add_gandiva_unit_test(bitmap_accumulator_test.cc bitmap_accumulator.cc) -add_gandiva_unit_test(dex_llvm_test.cc) add_gandiva_unit_test(engine_llvm_test.cc engine.cc llvm_types.cc status.cc configuration.cc ${BC_FILE_PATH_CC}) add_gandiva_unit_test(function_signature_test.cc function_signature.cc) add_gandiva_unit_test(function_registry_test.cc function_registry.cc function_signature.cc) add_gandiva_unit_test(llvm_types_test.cc llvm_types.cc) -add_gandiva_unit_test(llvm_generator_test.cc llvm_generator.cc engine.cc llvm_types.cc expr_decomposer.cc function_registry.cc annotator.cc status.cc bitmap_accumulator.cc configuration.cc function_signature.cc ${BC_FILE_PATH_CC}) +add_gandiva_unit_test(llvm_generator_test.cc llvm_generator.cc regex_util.cc engine.cc llvm_types.cc expr_decomposer.cc function_registry.cc annotator.cc status.cc bitmap_accumulator.cc configuration.cc function_signature.cc like_holder.cc regex_util.cc ${BC_FILE_PATH_CC}) add_gandiva_unit_test(annotator_test.cc annotator.cc function_signature.cc) -add_gandiva_unit_test(tree_expr_test.cc tree_expr_builder.cc expr_decomposer.cc annotator.cc function_registry.cc function_signature.cc) -add_gandiva_unit_test(expr_decomposer_test.cc expr_decomposer.cc tree_expr_builder.cc annotator.cc function_registry.cc function_signature.cc) +add_gandiva_unit_test(tree_expr_test.cc tree_expr_builder.cc expr_decomposer.cc annotator.cc function_registry.cc function_signature.cc like_holder.cc regex_util.cc status.cc) 
+add_gandiva_unit_test(expr_decomposer_test.cc expr_decomposer.cc tree_expr_builder.cc annotator.cc function_registry.cc function_signature.cc like_holder.cc regex_util.cc status.cc) add_gandiva_unit_test(status_test.cc status.cc) add_gandiva_unit_test(expression_registry_test.cc llvm_types.cc expression_registry.cc function_signature.cc function_registry.cc) add_gandiva_unit_test(selection_vector_test.cc selection_vector.cc status.cc) add_gandiva_unit_test(lru_cache_test.cc) - +add_gandiva_unit_test(like_holder_test.cc like_holder.cc regex_util.cc status.cc) diff --git a/cpp/src/gandiva/codegen/dex.h b/cpp/src/gandiva/codegen/dex.h index fef004460eab7..4484d37db165c 100644 --- a/cpp/src/gandiva/codegen/dex.h +++ b/cpp/src/gandiva/codegen/dex.h @@ -21,6 +21,7 @@ #include "codegen/dex_visitor.h" #include "codegen/field_descriptor.h" #include "codegen/func_descriptor.h" +#include "codegen/function_holder.h" #include "codegen/literal_holder.h" #include "codegen/native_function.h" #include "codegen/value_validity_pair.h" @@ -104,20 +105,24 @@ class LocalBitMapValidityDex : public Dex { class FuncDex : public Dex { public: FuncDex(FuncDescriptorPtr func_descriptor, const NativeFunction *native_function, - const ValueValidityPairVector &args) + FunctionHolderPtr function_holder, const ValueValidityPairVector &args) : func_descriptor_(func_descriptor), native_function_(native_function), + function_holder_(function_holder), args_(args) {} FuncDescriptorPtr func_descriptor() const { return func_descriptor_; } const NativeFunction *native_function() const { return native_function_; } + FunctionHolderPtr function_holder() const { return function_holder_; } + const ValueValidityPairVector &args() const { return args_; } private: FuncDescriptorPtr func_descriptor_; const NativeFunction *native_function_; + FunctionHolderPtr function_holder_; ValueValidityPairVector args_; }; @@ -127,8 +132,9 @@ class NonNullableFuncDex : public FuncDex { public: 
NonNullableFuncDex(FuncDescriptorPtr func_descriptor, const NativeFunction *native_function, + FunctionHolderPtr function_holder, const ValueValidityPairVector &args) - : FuncDex(func_descriptor, native_function, args) {} + : FuncDex(func_descriptor, native_function, function_holder, args) {} void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } }; @@ -139,8 +145,9 @@ class NullableNeverFuncDex : public FuncDex { public: NullableNeverFuncDex(FuncDescriptorPtr func_descriptor, const NativeFunction *native_function, + FunctionHolderPtr function_holder, const ValueValidityPairVector &args) - : FuncDex(func_descriptor, native_function, args) {} + : FuncDex(func_descriptor, native_function, function_holder, args) {} void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } }; @@ -151,8 +158,9 @@ class NullableInternalFuncDex : public FuncDex { public: NullableInternalFuncDex(FuncDescriptorPtr func_descriptor, const NativeFunction *native_function, + FunctionHolderPtr function_holder, const ValueValidityPairVector &args, int local_bitmap_idx) - : FuncDex(func_descriptor, native_function, args), + : FuncDex(func_descriptor, native_function, function_holder, args), local_bitmap_idx_(local_bitmap_idx) {} void Accept(DexVisitor &visitor) override { visitor.Visit(*this); } diff --git a/cpp/src/gandiva/codegen/dex_llvm_test.cc b/cpp/src/gandiva/codegen/dex_llvm_test.cc deleted file mode 100644 index 78278967b81e1..0000000000000 --- a/cpp/src/gandiva/codegen/dex_llvm_test.cc +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright (C) 2017-2018 Dremio Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "codegen/dex.h" - -#include -#include - -#include - -namespace gandiva { - -class TestDex : public ::testing::Test { - protected: - void SetUp() { - name_map_[&typeid(VectorReadValidityDex)] = "VectorReadValidityDex"; - name_map_[&typeid(VectorReadFixedLenValueDex)] = "VectorReadFixedLenValueDex"; - name_map_[&typeid(VectorReadVarLenValueDex)] = "VectorReadVarLenValueDex"; - name_map_[&typeid(LocalBitMapValidityDex)] = "LocalBitMapValidityDex"; - name_map_[&typeid(NonNullableFuncDex)] = "NonNullableFuncDex"; - name_map_[&typeid(NullableNeverFuncDex)] = "NullableNeverFuncDex"; - name_map_[&typeid(NullableInternalFuncDex)] = "NullableInternalFuncDex"; - name_map_[&typeid(TrueDex)] = "TrueDex"; - name_map_[&typeid(FalseDex)] = "FalseDex"; - name_map_[&typeid(LiteralDex)] = "LiteralDex"; - name_map_[&typeid(IfDex)] = "IfDex"; - name_map_[&typeid(BooleanAndDex)] = "BooleanAndDex"; - name_map_[&typeid(BooleanOrDex)] = "BooleanOrDex"; - } - - std::map name_map_; -}; - -TEST_F(TestDex, TestVisitor) { - class TestVisitor : public DexVisitor { - public: - TestVisitor(std::map *map, std::string *result) - : map_(map), result_(result) {} - - void Visit(const VectorReadValidityDex &dex) override { - *result_ = (*map_)[&typeid(dex)]; - } - - void Visit(const VectorReadFixedLenValueDex &dex) override { - *result_ = (*map_)[&typeid(dex)]; - } - - void Visit(const VectorReadVarLenValueDex &dex) override { - *result_ = (*map_)[&typeid(dex)]; - } - - void Visit(const LocalBitMapValidityDex &dex) override { - *result_ = (*map_)[&typeid(dex)]; - } - - void 
Visit(const TrueDex &dex) override { *result_ = (*map_)[&typeid(dex)]; } - - void Visit(const FalseDex &dex) override { *result_ = (*map_)[&typeid(dex)]; } - - void Visit(const LiteralDex &dex) override { *result_ = (*map_)[&typeid(dex)]; } - - void Visit(const NonNullableFuncDex &dex) override { - *result_ = (*map_)[&typeid(dex)]; - } - - void Visit(const NullableNeverFuncDex &dex) override { - *result_ = (*map_)[&typeid(dex)]; - } - - void Visit(const NullableInternalFuncDex &dex) override { - *result_ = (*map_)[&typeid(dex)]; - } - - void Visit(const IfDex &dex) override { *result_ = (*map_)[&typeid(dex)]; } - - void Visit(const BooleanAndDex &dex) override { *result_ = (*map_)[&typeid(dex)]; } - - void Visit(const BooleanOrDex &dex) override { *result_ = (*map_)[&typeid(dex)]; } - - private: - std::map *map_; - std::string *result_; - }; - - std::string desc; - TestVisitor visitor(&name_map_, &desc); - - FieldPtr field = arrow::field("abc", arrow::int32()); - FieldDescriptorPtr field_desc = std::make_shared(field, 0, 1, 2); - VectorReadValidityDex vv_dex(field_desc); - vv_dex.Accept(visitor); - EXPECT_EQ(desc, name_map_[&typeid(VectorReadValidityDex)]); - - VectorReadFixedLenValueDex vd_dex(field_desc); - vd_dex.Accept(visitor); - EXPECT_EQ(desc, name_map_[&typeid(VectorReadFixedLenValueDex)]); - - LocalBitMapValidityDex local_bitmap_dex(0); - local_bitmap_dex.Accept(visitor); - EXPECT_EQ(desc, name_map_[&typeid(LocalBitMapValidityDex)]); - - std::vector params{arrow::int32()}; - FuncDescriptorPtr my_func = - std::make_shared("abc", params, arrow::boolean()); - - NonNullableFuncDex non_nullable_func(my_func, nullptr, {nullptr}); - non_nullable_func.Accept(visitor); - EXPECT_EQ(desc, name_map_[&typeid(NonNullableFuncDex)]); - - NullableNeverFuncDex nullable_func(my_func, nullptr, {nullptr}); - nullable_func.Accept(visitor); - EXPECT_EQ(desc, name_map_[&typeid(NullableNeverFuncDex)]); - - NullableInternalFuncDex nullable_internal_func(my_func, nullptr, {nullptr}, 
0); - nullable_internal_func.Accept(visitor); - EXPECT_EQ(desc, name_map_[&typeid(NullableInternalFuncDex)]); - - IfDex if_dex(nullptr, nullptr, nullptr, arrow::int32(), 0, false); - if_dex.Accept(visitor); - EXPECT_EQ(desc, name_map_[&typeid(IfDex)]); - - BooleanAndDex and_dex({nullptr}, 0); - and_dex.Accept(visitor); - EXPECT_EQ(desc, name_map_[&typeid(BooleanAndDex)]); - - BooleanOrDex or_dex({nullptr}, 0); - or_dex.Accept(visitor); - EXPECT_EQ(desc, name_map_[&typeid(BooleanOrDex)]); -} - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} - -} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/expr_decomposer.cc b/cpp/src/gandiva/codegen/expr_decomposer.cc index 905fef43dd5d0..d477407c76e2e 100644 --- a/cpp/src/gandiva/codegen/expr_decomposer.cc +++ b/cpp/src/gandiva/codegen/expr_decomposer.cc @@ -21,6 +21,7 @@ #include "codegen/annotator.h" #include "codegen/dex.h" +#include "codegen/function_holder_registry.h" #include "codegen/function_registry.h" #include "codegen/node.h" #include "gandiva/function_signature.h" @@ -53,10 +54,19 @@ Status ExprDecomposer::Visit(const FunctionNode &node) { // decompose the children. std::vector args; for (auto &child : node.children()) { - child->Accept(*this); + auto status = child->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + args.push_back(result()); } + // Make a function holder, if required. + std::shared_ptr holder; + if (native_function->needs_holder()) { + auto status = FunctionHolderRegistry::Make(desc->name(), node, &holder); + GANDIVA_RETURN_NOT_OK(status); + } + if (native_function->result_nullable_type() == RESULT_NULL_IF_NULL) { // These functions are decomposable, merge the validity bits of the children. 
@@ -68,11 +78,13 @@ Status ExprDecomposer::Visit(const FunctionNode &node) { decomposed->validity_exprs().end()); } - auto value_dex = std::make_shared(desc, native_function, args); + auto value_dex = + std::make_shared(desc, native_function, holder, args); result_ = std::make_shared(merged_validity, value_dex); } else if (native_function->result_nullable_type() == RESULT_NULL_NEVER) { // These functions always output valid results. So, no validity dex. - auto value_dex = std::make_shared(desc, native_function, args); + auto value_dex = + std::make_shared(desc, native_function, holder, args); result_ = std::make_shared(value_dex); } else { DCHECK(native_function->result_nullable_type() == RESULT_NULL_INTERNAL); @@ -81,8 +93,8 @@ Status ExprDecomposer::Visit(const FunctionNode &node) { int local_bitmap_idx = annotator_.AddLocalBitMap(); auto validity_dex = std::make_shared(local_bitmap_idx); - auto value_dex = std::make_shared(desc, native_function, - args, local_bitmap_idx); + auto value_dex = std::make_shared( + desc, native_function, holder, args, local_bitmap_idx); result_ = std::make_shared(validity_dex, value_dex); } return Status::OK(); @@ -91,16 +103,19 @@ Status ExprDecomposer::Visit(const FunctionNode &node) { // Decompose an IfNode Status ExprDecomposer::Visit(const IfNode &node) { // Add a local bitmap to track the output validity. 
- node.condition()->Accept(*this); + auto status = node.condition()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); auto condition_vv = result(); int local_bitmap_idx = PushThenEntry(node); - node.then_node()->Accept(*this); + status = node.then_node()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); auto then_vv = result(); PopThenEntry(node); PushElseEntry(node, local_bitmap_idx); - node.else_node()->Accept(*this); + status = node.else_node()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); auto else_vv = result(); bool is_terminal_else = PopElseEntry(node); @@ -118,7 +133,9 @@ Status ExprDecomposer::Visit(const BooleanNode &node) { // decompose the children. std::vector args; for (auto &child : node.children()) { - child->Accept(*this); + auto status = child->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + args.push_back(result()); } diff --git a/cpp/src/gandiva/codegen/expr_decomposer.h b/cpp/src/gandiva/codegen/expr_decomposer.h index b9cc889a76657..12fcd9e3baf00 100644 --- a/cpp/src/gandiva/codegen/expr_decomposer.h +++ b/cpp/src/gandiva/codegen/expr_decomposer.h @@ -36,9 +36,12 @@ class ExprDecomposer : public NodeVisitor { explicit ExprDecomposer(const FunctionRegistry &registry, Annotator &annotator) : registry_(registry), annotator_(annotator) {} - ValueValidityPairPtr Decompose(const Node &root) { - root.Accept(*this); - return result(); + Status Decompose(const Node &root, ValueValidityPairPtr *out) { + auto status = root.Accept(*this); + if (status.ok()) { + *out = std::move(result_); + } + return status; } private: diff --git a/cpp/src/gandiva/codegen/expression_registry.cc b/cpp/src/gandiva/codegen/expression_registry.cc index 8791170476a95..0a5875b5b9341 100644 --- a/cpp/src/gandiva/codegen/expression_registry.cc +++ b/cpp/src/gandiva/codegen/expression_registry.cc @@ -29,27 +29,26 @@ ExpressionRegistry::~ExpressionRegistry() {} const ExpressionRegistry::FunctionSignatureIterator ExpressionRegistry::function_signature_begin() { - return 
FunctionSignatureIterator(function_registry_->begin(), - function_registry_->end()); + return FunctionSignatureIterator(function_registry_->begin()); } const ExpressionRegistry::FunctionSignatureIterator ExpressionRegistry::function_signature_end() const { - return FunctionSignatureIterator(function_registry_->end(), function_registry_->end()); + return FunctionSignatureIterator(function_registry_->end()); } bool ExpressionRegistry::FunctionSignatureIterator::operator!=( const FunctionSignatureIterator &func_sign_it) { - return func_sign_it.it != this->it; + return func_sign_it.it_ != this->it_; } FunctionSignature ExpressionRegistry::FunctionSignatureIterator::operator*() { - return (*it).signature(); + return (*it_).signature(); } ExpressionRegistry::iterator ExpressionRegistry::FunctionSignatureIterator::operator++( int increment) { - return it++; + return it_++; } DataTypeVector ExpressionRegistry::supported_types_ = diff --git a/cpp/src/gandiva/codegen/expression_registry.h b/cpp/src/gandiva/codegen/expression_registry.h index 3de870f602046..dba698a117d58 100644 --- a/cpp/src/gandiva/codegen/expression_registry.h +++ b/cpp/src/gandiva/codegen/expression_registry.h @@ -39,7 +39,7 @@ class ExpressionRegistry { static DataTypeVector supported_types() { return supported_types_; } class FunctionSignatureIterator { public: - FunctionSignatureIterator(iterator begin, iterator end) : it(begin), end(end) {} + FunctionSignatureIterator(iterator it) : it_(it) {} bool operator!=(const FunctionSignatureIterator &func_sign_it); @@ -48,8 +48,7 @@ class ExpressionRegistry { iterator operator++(int); private: - iterator it; - iterator end; + iterator it_; }; const FunctionSignatureIterator function_signature_begin(); const FunctionSignatureIterator function_signature_end() const; diff --git a/cpp/src/gandiva/codegen/function_holder.h b/cpp/src/gandiva/codegen/function_holder.h new file mode 100644 index 0000000000000..d5f9c4ee42543 --- /dev/null +++ 
b/cpp/src/gandiva/codegen/function_holder.h @@ -0,0 +1,30 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_FUNCTION_HOLDER_H +#define GANDIVA_FUNCTION_HOLDER_H + +namespace gandiva { + +/// Holder for a function that can be invoked from LLVM. +class FunctionHolder { + public: + virtual ~FunctionHolder() = default; +}; + +using FunctionHolderPtr = std::shared_ptr; + +} // namespace gandiva + +#endif // GANDIVA_FUNCTION_HOLDER_H diff --git a/cpp/src/gandiva/codegen/function_holder_registry.h b/cpp/src/gandiva/codegen/function_holder_registry.h new file mode 100644 index 0000000000000..876bfee3ecf30 --- /dev/null +++ b/cpp/src/gandiva/codegen/function_holder_registry.h @@ -0,0 +1,62 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef GANDIVA_FUNCTION_HOLDER_REGISTRY_H +#define GANDIVA_FUNCTION_HOLDER_REGISTRY_H + +#include "codegen/function_holder.h" +#include "codegen/like_holder.h" +#include "codegen/node.h" +#include "gandiva/status.h" + +namespace gandiva { + +#define LAMBDA_MAKER(derived) \ + [](const FunctionNode &node, FunctionHolderPtr *holder) { \ + std::shared_ptr derived_instance; \ + auto status = derived::Make(node, &derived_instance); \ + if (status.ok()) { \ + *holder = derived_instance; \ + } \ + return status; \ + } + +/// Static registry of function holders. +class FunctionHolderRegistry { + public: + using maker_type = std::function; + using map_type = std::unordered_map; + + static Status Make(const std::string &name, const FunctionNode &node, + FunctionHolderPtr *holder) { + auto found = makers().find(name); + if (found == makers().end()) { + return Status::Invalid("function holder not registered for function " + name); + } + + return found->second(node, holder); + } + + private: + static map_type &makers() { + static map_type maker_map = { + {"like", LAMBDA_MAKER(LikeHolder)}, + }; + return maker_map; + } +}; + +} // namespace gandiva + +#endif // GANDIVA_FUNCTION_HOLDER_REGISTRY_H diff --git a/cpp/src/gandiva/codegen/function_registry.cc b/cpp/src/gandiva/codegen/function_registry.cc index cd9aac225bbe6..d9cfcbc689767 100644 --- a/cpp/src/gandiva/codegen/function_registry.cc +++ b/cpp/src/gandiva/codegen/function_registry.cc @@ -339,8 +339,11 @@ NativeFunction FunctionRegistry::pc_registry_[] = { VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, greater_than), VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, greater_than_or_equal_to), + NativeFunction("like", DataTypeVector{utf8(), utf8()}, boolean(), true /*null_safe*/, + RESULT_NULL_IF_NULL, "like_utf8_utf8", true /*needs_holder*/), + // Null internal (sample) - NativeFunction("half_or_null", DataTypeVector{int32()}, int32(), true, + NativeFunction("half_or_null", DataTypeVector{int32()}, int32(), true 
/*null_safe*/, RESULT_NULL_INTERNAL, "half_or_null_int32"), }; // namespace gandiva diff --git a/cpp/src/gandiva/codegen/like_holder.cc b/cpp/src/gandiva/codegen/like_holder.cc new file mode 100644 index 0000000000000..34bdc61722c71 --- /dev/null +++ b/cpp/src/gandiva/codegen/like_holder.cc @@ -0,0 +1,59 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/like_holder.h" + +#include +#include "codegen/node.h" +#include "codegen/regex_util.h" + +namespace gandiva { + +Status LikeHolder::Make(const FunctionNode &node, std::shared_ptr *holder) { + if (node.children().size() != 2) { + return Status::Invalid("'like' function requires two parameters"); + } + + auto literal = dynamic_cast(node.children().at(1).get()); + if (literal == nullptr) { + return Status::Invalid("'like' function requires a literal as the second parameter"); + } + + auto literal_type = literal->return_type()->id(); + if (literal_type != arrow::Type::STRING && literal_type != arrow::Type::BINARY) { + return Status::Invalid( + "'like' function requires a string literal as the second parameter"); + } + auto pattern = boost::get(literal->holder()); + return Make(pattern, holder); +} + +Status LikeHolder::Make(const std::string &sql_pattern, + std::shared_ptr *holder) { + std::string posix_pattern; + auto status = RegexUtil::SqlLikePatternToPosix(sql_pattern, posix_pattern); + GANDIVA_RETURN_NOT_OK(status); + + *holder = 
std::shared_ptr(new LikeHolder(posix_pattern)); + return Status::OK(); +} + +// Wrapper C functions for "like" to be invoked from LLVM. +extern "C" bool like_utf8_utf8(int64_t ptr, const char *data, int data_len, + const char *pattern, int pattern_len) { + LikeHolder *holder = reinterpret_cast(ptr); + return (*holder)(std::string(data, data_len)); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/like_holder.h b/cpp/src/gandiva/codegen/like_holder.h new file mode 100644 index 0000000000000..5875a3bcf64d4 --- /dev/null +++ b/cpp/src/gandiva/codegen/like_holder.h @@ -0,0 +1,46 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_LIKE_HOLDER_H +#define GANDIVA_LIKE_HOLDER_H + +#include +#include "codegen/function_holder.h" +#include "codegen/node.h" +#include "gandiva/status.h" + +namespace gandiva { + +/// Function Holder for SQL 'like' +class LikeHolder : public FunctionHolder { + public: + ~LikeHolder() override = default; + + static Status Make(const FunctionNode &node, std::shared_ptr *holder); + + static Status Make(const std::string &sql_pattern, std::shared_ptr *holder); + + /// Return true if the data matches the pattern. 
+ bool operator()(const std::string &data) { return std::regex_match(data, regex_); } + + private: + LikeHolder(const std::string &pattern) : pattern_(pattern), regex_(pattern) {} + + std::string pattern_; // posix pattern string, to help debugging + std::regex regex_; // compiled regex for the pattern +}; + +} // namespace gandiva + +#endif // GANDIVA_LIKE_HOLDER_H diff --git a/cpp/src/gandiva/codegen/like_holder_test.cc b/cpp/src/gandiva/codegen/like_holder_test.cc new file mode 100644 index 0000000000000..d349e4f0726d8 --- /dev/null +++ b/cpp/src/gandiva/codegen/like_holder_test.cc @@ -0,0 +1,81 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "codegen/like_holder.h" +#include "codegen/regex_util.h" + +#include +#include + +#include + +namespace gandiva { + +class TestLikeHolder : public ::testing::Test {}; + +TEST_F(TestLikeHolder, TestMatchAny) { + std::shared_ptr like_holder; + + auto status = LikeHolder::Make("ab%", &like_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto like = *like_holder; + EXPECT_TRUE(like("ab")); + EXPECT_TRUE(like("abc")); + EXPECT_TRUE(like("abcd")); + + EXPECT_FALSE(like("a")); + EXPECT_FALSE(like("cab")); +} + +TEST_F(TestLikeHolder, TestMatchOne) { + std::shared_ptr like_holder; + + auto status = LikeHolder::Make("ab_", &like_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto like = *like_holder; + EXPECT_TRUE(like("abc")); + EXPECT_TRUE(like("abd")); + + EXPECT_FALSE(like("a")); + EXPECT_FALSE(like("abcd")); + EXPECT_FALSE(like("dabc")); +} + +TEST_F(TestLikeHolder, TestPosixSpecial) { + std::shared_ptr like_holder; + + auto status = LikeHolder::Make(".*ab_", &like_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto like = *like_holder; + EXPECT_TRUE(like(".*abc")); // . and * aren't special in sql regex + EXPECT_FALSE(like("xxabc")); +} + +TEST_F(TestLikeHolder, TestRegexEscape) { + std::string res; + auto status = RegexUtil::SqlLikePatternToPosix("#%hello#_abc_def##", '#', res); + EXPECT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(res, "%hello_abc.def#"); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/codegen/llvm_generator.cc b/cpp/src/gandiva/codegen/llvm_generator.cc index 8206b229a5b83..6e92a8420e8cd 100644 --- a/cpp/src/gandiva/codegen/llvm_generator.cc +++ b/cpp/src/gandiva/codegen/llvm_generator.cc @@ -52,13 +52,13 @@ Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr out // decompose the expression to separate out value and validities. 
ExprDecomposer decomposer(function_registry_, annotator_); - ValueValidityPairPtr value_validity = decomposer.Decompose(*expr->root()); + ValueValidityPairPtr value_validity; + auto status = decomposer.Decompose(*expr->root(), &value_validity); + GANDIVA_RETURN_NOT_OK(status); // Generate the IR function for the decomposed expression. llvm::Function *ir_function = nullptr; - - Status status = - CodeGenExprValue(value_validity->value_expr(), output, idx, &ir_function); + status = CodeGenExprValue(value_validity->value_expr(), output, idx, &ir_function); GANDIVA_RETURN_NOT_OK(status); std::unique_ptr compiled_expr( @@ -69,14 +69,17 @@ Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr out /// Build and optimise module for projection expression. Status LLVMGenerator::Build(const ExpressionVector &exprs) { + Status status; + for (auto &expr : exprs) { auto output = annotator_.AddOutputFieldDescriptor(expr->result()); - Add(expr, output); + status = Add(expr, output); + GANDIVA_RETURN_NOT_OK(status); } // optimise, compile and finalize the module - Status result = engine_->FinalizeModule(optimise_ir_, dump_ir_); - GANDIVA_RETURN_NOT_OK(result); + status = engine_->FinalizeModule(optimise_ir_, dump_ir_); + GANDIVA_RETURN_NOT_OK(status); // setup the jit functions for each expression. for (auto &compiled_expr : compiled_exprs_) { @@ -340,18 +343,45 @@ void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr &compiled_expr, accumulator.ComputeResult(dst_bitmap); } +void LLVMGenerator::CheckAndAddPrototype(const std::string &full_name, + llvm::Type *ret_type, + const std::vector &args) { + auto fn = module()->getFunction(full_name); + if (fn != nullptr) { + // prototype already added to module. 
+ return; + } + + // Create fn prototype for evaluation + std::vector arg_types; + for (auto &value : args) { + arg_types.push_back(value->getType()); + } + llvm::FunctionType *prototype = + llvm::FunctionType::get(ret_type, arg_types, false /*isVarArg*/); + + fn = llvm::Function::Create(prototype, llvm::GlobalValue::ExternalLinkage, full_name, + module()); + DCHECK_NE(fn, nullptr) << " cpp function " << full_name << " does not exist"; +} + llvm::Value *LLVMGenerator::AddFunctionCall(const std::string &full_name, llvm::Type *ret_type, - const std::vector &args) { - // add to list of functions that need to be compiled - engine_->AddFunctionToCompile(full_name); + const std::vector &args, + bool has_holder) { + if (has_holder) { + CheckAndAddPrototype(full_name, ret_type, args); + } else { + // add to list of functions that need to be compiled + engine_->AddFunctionToCompile(full_name); + } // find the llvm function. llvm::Function *fn = module()->getFunction(full_name); - DCHECK(fn != NULL); + DCHECK_NE(fn, nullptr) << "missing function " + full_name; - if (enable_ir_traces_ && full_name.compare("printf") != 0 && - full_name.compare("printff") != 0) { + if (enable_ir_traces_ && !full_name.compare("printf") && + !full_name.compare("printff")) { // Trace for debugging ADD_TRACE("invoke native fn " + full_name); } @@ -551,12 +581,13 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex &dex) { LLVMTypes *types = generator_->types_.get(); // build the function params (ignore validity). 
- auto params = BuildParams(dex.args(), false); + auto params = BuildParams(dex.function_holder().get(), dex.args(), false); const NativeFunction *native_function = dex.native_function(); llvm::Type *ret_type = types->IRType(native_function->signature().ret_type()->id()); - llvm::Value *value = - generator_->AddFunctionCall(native_function->pc_name(), ret_type, params); + + llvm::Value *value = generator_->AddFunctionCall( + native_function->pc_name(), ret_type, params, native_function->needs_holder()); result_.reset(new LValue(value)); } @@ -565,12 +596,12 @@ void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex &dex) { LLVMTypes *types = generator_->types_.get(); // build function params along with validity. - auto params = BuildParams(dex.args(), true); + auto params = BuildParams(dex.function_holder().get(), dex.args(), true); const NativeFunction *native_function = dex.native_function(); llvm::Type *ret_type = types->IRType(native_function->signature().ret_type()->id()); - llvm::Value *value = - generator_->AddFunctionCall(native_function->pc_name(), ret_type, params); + llvm::Value *value = generator_->AddFunctionCall( + native_function->pc_name(), ret_type, params, native_function->needs_holder()); result_.reset(new LValue(value)); } @@ -581,7 +612,7 @@ void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex &dex) { LLVMTypes *types = generator_->types_.get(); // build function params along with validity. - auto params = BuildParams(dex.args(), true); + auto params = BuildParams(dex.function_holder().get(), dex.args(), true); // add an extra arg for validity (alloced on stack). 
llvm::AllocaInst *result_valid_ptr = @@ -590,8 +621,8 @@ void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex &dex) { const NativeFunction *native_function = dex.native_function(); llvm::Type *ret_type = types->IRType(native_function->signature().ret_type()->id()); - llvm::Value *value = - generator_->AddFunctionCall(native_function->pc_name(), ret_type, params); + llvm::Value *value = generator_->AddFunctionCall( + native_function->pc_name(), ret_type, params, native_function->needs_holder()); // load the result validity and truncate to i1. llvm::Value *result_valid_i8 = builder.CreateLoad(result_valid_ptr); @@ -826,9 +857,18 @@ LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair } std::vector LLVMGenerator::Visitor::BuildParams( - const ValueValidityPairVector &args, bool with_validity) { - // build the function params, along with the validities. + FunctionHolder *holder, const ValueValidityPairVector &args, bool with_validity) { + LLVMTypes *types = generator_->types_.get(); std::vector params; + + // if the function has holder, add the holder pointer first. + if (holder != nullptr) { + llvm::Constant *ptr_int_cast = types->i64_constant((int64_t)holder); + auto ptr = llvm::ConstantExpr::getIntToPtr(ptr_int_cast, types->i8_ptr_type()); + params.push_back(ptr); + } + + // build the function params, along with the validities. for (auto &pair : args) { // build value. DexPtr value_expr = pair->value_expr(); @@ -850,9 +890,7 @@ std::vector LLVMGenerator::Visitor::BuildParams( return params; } -/* - * Bitwise-AND of a vector of bits to get the combined validity. - */ +// Bitwise-AND of a vector of bits to get the combined validity. 
llvm::Value *LLVMGenerator::Visitor::BuildCombinedValidity(const DexVector &validities) { llvm::IRBuilder<> &builder = ir_builder(); LLVMTypes *types = generator_->types_.get(); diff --git a/cpp/src/gandiva/codegen/llvm_generator.h b/cpp/src/gandiva/codegen/llvm_generator.h index 3ba6c9a5adf5d..b5d1292027866 100644 --- a/cpp/src/gandiva/codegen/llvm_generator.h +++ b/cpp/src/gandiva/codegen/llvm_generator.h @@ -34,6 +34,8 @@ namespace gandiva { +class FunctionHolder; + /// Builds an LLVM module and generates code for the specified set of expressions. class LLVMGenerator { public: @@ -98,7 +100,8 @@ class LLVMGenerator { LValuePtr BuildValueAndValidity(const ValueValidityPair &pair); // Generate code to build the params. - std::vector BuildParams(const ValueValidityPairVector &args, + std::vector BuildParams(FunctionHolder *holder, + const ValueValidityPairVector &args, bool with_validity); // Switch to the entry_block and get reference of the validity/value/offsets buffer @@ -154,10 +157,15 @@ class LLVMGenerator { void ClearPackedBitValueIfFalse(llvm::Value *bitmap, llvm::Value *position, llvm::Value *value); + /// For non-IR functions, add prototype to the module on first encounter. + void CheckAndAddPrototype(const std::string &full_name, llvm::Type *ret_type, + const std::vector &args); + /// Generate code to make a function call (to a pre-compiled IR function) which takes /// 'args' and has a return type 'ret_type'. llvm::Value *AddFunctionCall(const std::string &full_name, llvm::Type *ret_type, - const std::vector &args); + const std::vector &args, + bool has_holder = false); /// Compute the result bitmap for the expression. 
/// diff --git a/cpp/src/gandiva/codegen/llvm_generator_test.cc b/cpp/src/gandiva/codegen/llvm_generator_test.cc index b02c4ba156cab..e397b250b59c5 100644 --- a/cpp/src/gandiva/codegen/llvm_generator_test.cc +++ b/cpp/src/gandiva/codegen/llvm_generator_test.cc @@ -42,6 +42,11 @@ TEST_F(TestLLVMGenerator, VerifyPCFunctions) { llvm::Module *module = generator->module(); for (auto &iter : registry_) { + if (iter.needs_holder()) { + // TODO : need a way to verify these too. + continue; + } + llvm::Function *fn = module->getFunction(iter.pc_name()); EXPECT_NE(fn, nullptr) << "function " << iter.pc_name() << " missing in precompiled module\n"; @@ -76,7 +81,8 @@ TEST_F(TestLLVMGenerator, TestAdd) { generator->function_registry_.LookupSignature(signature); std::vector pairs{pair0, pair1}; - auto func_dex = std::make_shared(func_desc, native_func, pairs); + auto func_dex = std::make_shared(func_desc, native_func, + FunctionHolderPtr(nullptr), pairs); auto field_sum = std::make_shared("out", arrow::int32()); auto desc_sum = annotator.CheckAndAddInputFieldDescriptor(field_sum); @@ -135,8 +141,8 @@ TEST_F(TestLLVMGenerator, TestNullInternal) { int local_bitmap_idx = annotator.AddLocalBitMap(); std::vector pairs{pair0}; - auto func_dex = std::make_shared(func_desc, native_func, pairs, - local_bitmap_idx); + auto func_dex = std::make_shared( + func_desc, native_func, FunctionHolderPtr(nullptr), pairs, local_bitmap_idx); auto field_result = std::make_shared("out", arrow::int32()); auto desc_result = annotator.CheckAndAddInputFieldDescriptor(field_result); diff --git a/cpp/src/gandiva/codegen/native_function.h b/cpp/src/gandiva/codegen/native_function.h index 6db44d610db16..8e846dad619a7 100644 --- a/cpp/src/gandiva/codegen/native_function.h +++ b/cpp/src/gandiva/codegen/native_function.h @@ -40,14 +40,16 @@ class NativeFunction { std::string pc_name() const { return pc_name_; } ResultNullableType result_nullable_type() const { return result_nullable_type_; } bool 
param_null_safe() const { return param_null_safe_; } + bool needs_holder() const { return needs_holder_; } private: NativeFunction(const std::string &base_name, const DataTypeVector ¶m_types, DataTypePtr ret_type, bool param_null_safe, const ResultNullableType &result_nullable_type, - const std::string &pc_name) + const std::string &pc_name, bool needs_holder = false) : signature_(base_name, param_types, ret_type), param_null_safe_(param_null_safe), + needs_holder_(needs_holder), result_nullable_type_(result_nullable_type), pc_name_(pc_name) {} @@ -55,6 +57,7 @@ class NativeFunction { /// attributes bool param_null_safe_; + bool needs_holder_; ResultNullableType result_nullable_type_; /// pre-compiled function name. diff --git a/cpp/src/gandiva/codegen/regex_util.cc b/cpp/src/gandiva/codegen/regex_util.cc new file mode 100644 index 0000000000000..8253cf072d4ad --- /dev/null +++ b/cpp/src/gandiva/codegen/regex_util.cc @@ -0,0 +1,64 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "codegen/regex_util.h" + +namespace gandiva { + +/// Characters that are considered special by posix regex. These needs to be +/// escaped with '\\'. 
+const std::set<char> RegexUtil::posix_regex_specials_ = {
+    '[', ']', '(', ')', '|', '^', '-', '+', '*', '?', '{', '}', '$', '\\', '.'};
+
+/// Translate a SQL LIKE pattern into a regex pattern.
+///
+/// '_' is translated to '.', '%' to ".*"; every other character is emitted as a
+/// literal, prefixed with a backslash when it is a regex special character.
+/// A character preceded by escape_char is always emitted as a literal; the escape
+/// char may only be followed by '_', '%' or the escape char itself.
+///
+/// \param[in] sql_pattern the SQL LIKE pattern.
+/// \param[in] escape_char the escape character used in sql_pattern.
+/// \param[out] posix_pattern cleared on entry, then filled with the translation.
+///             Holds a partial translation if an error is returned.
+/// \return Status::Invalid if the pattern ends with the escape char or contains an
+///         invalid escape sequence, Status::OK otherwise.
+Status RegexUtil::SqlLikePatternToPosix(const std::string &sql_pattern, char escape_char,
+                                        std::string &posix_pattern) {
+  posix_pattern.clear();
+  for (size_t idx = 0; idx < sql_pattern.size(); ++idx) {
+    char cur = sql_pattern.at(idx);
+
+    // Resolve the escape sequence first, so that the regex-escaping below is
+    // applied to the character that is actually emitted (and not to the escape
+    // char, which would produce sequences like "\_").
+    if (cur == escape_char) {
+      // escape char must be followed by '_', '%' or the escape char itself.
+      ++idx;
+      if (idx == sql_pattern.size()) {
+        return Status::Invalid("unexpected escape char at the end of pattern " +
+                               sql_pattern);
+      }
+
+      cur = sql_pattern.at(idx);
+      if (cur != '_' && cur != '%' && cur != escape_char) {
+        return Status::Invalid("invalid escape sequence in pattern " + sql_pattern +
+                               " at offset " + std::to_string(idx));
+      }
+    } else if (cur == '_') {
+      // SQL single-char wildcard.
+      posix_pattern += '.';
+      continue;
+    } else if (cur == '%') {
+      // SQL multi-char wildcard.
+      posix_pattern += ".*";
+      continue;
+    }
+
+    // Literal character : escape it if it is special for the regex engine.
+    if (posix_regex_specials_.count(cur) > 0) {
+      posix_pattern += '\\';
+    }
+    posix_pattern += cur;
+  }
+  return Status::OK();
+}
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/codegen/regex_util.h b/cpp/src/gandiva/codegen/regex_util.h
new file mode 100644
index 0000000000000..f3b405f3e21ab
--- /dev/null
+++ b/cpp/src/gandiva/codegen/regex_util.h
@@ -0,0 +1,42 @@
+// Copyright (C) 2017-2018 Dremio Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GANDIVA_REGEX_UTIL_H
+#define GANDIVA_REGEX_UTIL_H
+
+#include <set>
+#include <string>
+#include "gandiva/status.h"
+
+namespace gandiva {
+
+/// \brief Utility class for converting sql patterns to posix patterns.
+class RegexUtil {
+ public:
+  /// \brief Convert an sql LIKE pattern to an std::regex pattern.
+  ///
+  /// \param[in] like_pattern the sql LIKE pattern.
+  /// \param[in] escape_char the escape character used in like_pattern.
+  /// \param[out] posix_pattern receives the converted pattern.
+  /// \return a non-OK status if like_pattern has an invalid escape sequence.
+  static Status SqlLikePatternToPosix(const std::string &like_pattern, char escape_char,
+                                      std::string &posix_pattern);
+
+  /// \brief Same as above, passing 0 as the escape character.
+  static Status SqlLikePatternToPosix(const std::string &like_pattern,
+                                      std::string &posix_pattern) {
+    return SqlLikePatternToPosix(like_pattern, 0 /*escape_char*/, posix_pattern);
+  }
+
+ private:
+  // set of characters that std::regex treats as special.
+  static const std::set<char> posix_regex_specials_;
+};
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_REGEX_UTIL_H
diff --git a/cpp/src/gandiva/codegen/selection_vector.h b/cpp/src/gandiva/codegen/selection_vector.h
index 65114b0856fdd..421d9ea3534ef 100644
--- a/cpp/src/gandiva/codegen/selection_vector.h
+++ b/cpp/src/gandiva/codegen/selection_vector.h
@@ -25,7 +25,7 @@ namespace gandiva {
 /// backed by an arrow-array.
 class SelectionVector {
  public:
-  ~SelectionVector() = default;
+  virtual ~SelectionVector() = default;
 
   /// Get the value at a given index.
virtual uint GetIndex(int index) const = 0; diff --git a/cpp/src/gandiva/codegen/tree_expr_test.cc b/cpp/src/gandiva/codegen/tree_expr_test.cc index e9055281d27ca..295ac0186c2d7 100644 --- a/cpp/src/gandiva/codegen/tree_expr_test.cc +++ b/cpp/src/gandiva/codegen/tree_expr_test.cc @@ -55,7 +55,10 @@ TEST_F(TestExprTree, TestField) { EXPECT_EQ(n1->return_type(), boolean()); ExprDecomposer decomposer(registry_, annotator); - auto pair = decomposer.Decompose(*n1); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*n1, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + auto value = pair->value_expr(); auto value_dex = std::dynamic_pointer_cast(value); EXPECT_EQ(value_dex->FieldType(), boolean()); @@ -83,7 +86,10 @@ TEST_F(TestExprTree, TestBinary) { EXPECT_TRUE(sign == FunctionSignature("add", {int32(), int32()}, int32())); ExprDecomposer decomposer(registry_, annotator); - auto pair = decomposer.Decompose(*n); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*n, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + auto value = pair->value_expr(); auto null_if_null = std::dynamic_pointer_cast(value); @@ -106,7 +112,10 @@ TEST_F(TestExprTree, TestUnary) { EXPECT_TRUE(sign == FunctionSignature("isnumeric", {int32()}, boolean())); ExprDecomposer decomposer(registry_, annotator); - auto pair = decomposer.Decompose(*n); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*n, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + auto value = pair->value_expr(); auto never_null = std::dynamic_pointer_cast(value); @@ -132,7 +141,10 @@ TEST_F(TestExprTree, TestExpression) { EXPECT_TRUE(sign == FunctionSignature("add", {int32(), int32()}, int32())); ExprDecomposer decomposer(registry_, annotator); - auto pair = decomposer.Decompose(*root_node); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*root_node, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + auto value = 
pair->value_expr();
   auto null_if_null = std::dynamic_pointer_cast(value);
 
diff --git a/cpp/src/gandiva/integ/utf8_test.cc b/cpp/src/gandiva/integ/utf8_test.cc
index 1a2fe56637d35..e373f0285cf3d 100644
--- a/cpp/src/gandiva/integ/utf8_test.cc
+++ b/cpp/src/gandiva/integ/utf8_test.cc
@@ -167,4 +167,45 @@ TEST_F(TestUtf8, TestNullLiteral) {
   EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
 }
 
+TEST_F(TestUtf8, TestLike) {
+  // schema for input fields
+  auto field_a = field("a", utf8());
+  auto schema = arrow::schema({field_a});
+
+  // output fields
+  auto res = field("res", boolean());
+
+  // build expressions.
+  // like(a, literal(s)) : the field is the value, the literal is the pattern.
+
+  auto node_a = TreeExprBuilder::MakeField(field_a);
+  auto literal_s = TreeExprBuilder::MakeStringLiteral("%spark%");
+  auto is_like = TreeExprBuilder::MakeFunction("like", {node_a, literal_s}, boolean());
+  auto expr = TreeExprBuilder::MakeExpression(is_like, res);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+  Status status = Projector::Make(schema, {expr}, &projector);
+  // fatal on failure : the projector is dereferenced below.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  // Create a row-batch with some sample data
+  int num_records = 4;
+  auto array_a = MakeArrowArrayUtf8({"park", "sparkle", "bright spark and fire", "spark"},
+                                    {true, true, true, true});
+
+  // expected output : only the first value lacks the substring "spark".
+  auto exp = MakeArrowArrayBool({false, true, true, true}, {true, true, true, true});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  // fatal on failure : outputs.at(0) is read below.
+  ASSERT_TRUE(status.ok()) << status.message();
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
 }  // namespace gandiva