Skip to content

Commit

Permalink
GDV-20: [Java] support varlen types in gandiva (apache#61)
Browse files Browse the repository at this point in the history
- added java bindings for varlen types/literals
- minor cleanups in llvm generator and engine
  (reported by clang-tidy)
  • Loading branch information
pravindra authored Jul 6, 2018
1 parent 7cbfe80 commit a766748
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 102 deletions.
4 changes: 2 additions & 2 deletions include/gandiva/tree_expr_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ class TreeExprBuilder {

/// \brief create a node with a function.
/// returns null if return_type is null
static NodePtr MakeFunction(const std::string &name, const NodeVector &children,
static NodePtr MakeFunction(const std::string &name, const NodeVector &params,
DataTypePtr return_type);

/// \brief create a node with an if-else expression.
/// returns null if any of the inputs is null.
static NodePtr MakeIf(NodePtr condition, NodePtr this_node, NodePtr else_node,
static NodePtr MakeIf(NodePtr condition, NodePtr then_node, NodePtr else_node,
DataTypePtr result_type);

/// \brief create a node with a boolean AND expression.
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/annotator.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class Annotator {
int AddLocalBitMap() { return local_bitmap_count_++; }

/// Prepare an eval batch for the incoming record batch.
EvalBatchPtr PrepareEvalBatch(const arrow::RecordBatch &batch,
EvalBatchPtr PrepareEvalBatch(const arrow::RecordBatch &record_batch,
const ArrayDataVector &out_vector);

private:
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/bitmap_accumulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class BitMapAccumulator : public DexDefaultVisitor {
}

/// Compute the dst_bmap based on the contents and type of the accumulated bitmap dex.
void ComputeResult(uint8_t *dst_bmap);
void ComputeResult(uint8_t *dst_bitmap);

/// Compute the intersection of the accumulated bitmaps and save the result in
/// dst_bmap.
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ std::once_flag init_once_flag;

// One-time initializations.
void Engine::InitOnce() {
assert(!init_once_done_);
DCHECK_EQ(init_once_done_, false);

llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
Expand Down Expand Up @@ -110,7 +110,7 @@ Status Engine::LoadPreCompiledIRFiles(const std::string &byte_code_file_path) {
std::unique_ptr<llvm::Module> ir_module = move(module_or_error.get());

/// Verify the IR module
if (llvm::verifyModule(*ir_module.get(), &llvm::errs())) {
if (llvm::verifyModule(*ir_module, &llvm::errs())) {
return Status::CodeGenError("verify of IR Module failed");
}

Expand Down
2 changes: 1 addition & 1 deletion src/codegen/expr_decomposer.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class ExprDecomposer : public NodeVisitor {
void PopThenEntry(const IfNode &node);

// push 'else entry' into stack.
void PushElseEntry(const IfNode &node, int local_bmap_idx);
void PushElseEntry(const IfNode &node, int local_bitmap_idx);

// pop 'else entry' from stack. returns 'true' if this is a terminal else condition
// i.e no nested if condition below this node.
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/expr_validator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Status ExprValidator::Validate(const ExpressionPtr &expr) {
}

Status ExprValidator::Visit(const FieldNode &node) {
auto llvm_type = types_->IRType(node.return_type()->id());
auto llvm_type = types_.IRType(node.return_type()->id());
if (llvm_type == nullptr) {
std::stringstream ss;
ss << "Field " << node.field()->name() << " has unsupported data type "
Expand Down Expand Up @@ -117,7 +117,7 @@ Status ExprValidator::Visit(const IfNode &node) {
}

Status ExprValidator::Visit(const LiteralNode &node) {
auto llvm_type = types_->IRType(node.return_type()->id());
auto llvm_type = types_.IRType(node.return_type()->id());
if (llvm_type == nullptr) {
std::stringstream ss;
ss << "Value " << node.holder() << " has unsupported data type "
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/expr_validator.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class FunctionRegistry;
/// data types, signatures and return types
class ExprValidator : public NodeVisitor {
public:
explicit ExprValidator(LLVMTypes *types, SchemaPtr schema)
explicit ExprValidator(LLVMTypes &types, SchemaPtr schema)
: types_(types), schema_(schema) {
for (auto &field : schema_->fields()) {
field_map_[field->name()] = field;
Expand All @@ -59,7 +59,7 @@ class ExprValidator : public NodeVisitor {

FunctionRegistry registry_;

LLVMTypes *types_;
LLVMTypes &types_;

SchemaPtr schema_;

Expand Down
58 changes: 28 additions & 30 deletions src/codegen/llvm_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,11 @@ Status LLVMGenerator::Make(std::shared_ptr<Configuration> config,
std::unique_ptr<LLVMGenerator> llvmgen_obj(new LLVMGenerator());
Status status = Engine::Make(config, &(llvmgen_obj->engine_));
GANDIVA_RETURN_NOT_OK(status);
llvmgen_obj->types_ = new LLVMTypes(*(llvmgen_obj->engine_)->context());
llvmgen_obj->types_.reset(new LLVMTypes(*(llvmgen_obj->engine_)->context()));
*llvm_generator = std::move(llvmgen_obj);
return Status::OK();
}

LLVMGenerator::~LLVMGenerator() {
for (auto it = compiled_exprs_.begin(); it != compiled_exprs_.end(); ++it) {
delete *it;
}
delete types_;
}

Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr output) {
int idx = compiled_exprs_.size();

Expand All @@ -68,8 +61,9 @@ Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr out
CodeGenExprValue(value_validity->value_expr(), output, idx, &ir_function);
GANDIVA_RETURN_NOT_OK(status);

CompiledExpr *compiled_expr = new CompiledExpr(value_validity, output, ir_function);
compiled_exprs_.push_back(compiled_expr);
std::unique_ptr<CompiledExpr> compiled_expr(
new CompiledExpr(value_validity, output, ir_function));
compiled_exprs_.push_back(std::move(compiled_expr));
return Status::OK();
}

Expand All @@ -85,7 +79,7 @@ Status LLVMGenerator::Build(const ExpressionVector &exprs) {
GANDIVA_RETURN_NOT_OK(result);

// setup the jit functions for each expression.
for (auto compiled_expr : compiled_exprs_) {
for (auto &compiled_expr : compiled_exprs_) {
llvm::Function *ir_func = compiled_expr->ir_function();
EvalFunc fn = reinterpret_cast<EvalFunc>(engine_->CompiledFunction(ir_func));
compiled_expr->set_jit_function(fn);
Expand All @@ -102,7 +96,7 @@ Status LLVMGenerator::Execute(const arrow::RecordBatch &record_batch,
DCHECK_GT(eval_batch->num_buffers(), 0);

// generate bitmap vectors, by doing an intersection.
for (auto compiled_expr : compiled_exprs_) {
for (auto &compiled_expr : compiled_exprs_) {
// generate data/offset vectors.
EvalFunc jit_function = compiled_expr->jit_function();
jit_function(eval_batch->buffers(), eval_batch->local_bitmaps(),
Expand Down Expand Up @@ -136,12 +130,14 @@ llvm::Value *LLVMGenerator::GetDataReference(llvm::Value *arg_addrs, int idx,
const std::string &name = field->name();
llvm::Value *load = LoadVectorAtIndex(arg_addrs, idx, name);
llvm::Type *base_type = types_->DataVecType(field->type());
llvm::Value *ret;
if (base_type->isPointerTy()) {
return ir_builder().CreateIntToPtr(load, base_type, name + "_darray");
ret = ir_builder().CreateIntToPtr(load, base_type, name + "_darray");
} else {
llvm::Type *pointer_type = types_->ptr_type(base_type);
return ir_builder().CreateIntToPtr(load, pointer_type, name + "_darray");
ret = ir_builder().CreateIntToPtr(load, pointer_type, name + "_darray");
}
return ret;
}

/// Get reference to offsets array at specified index in the args list.
Expand Down Expand Up @@ -355,20 +351,22 @@ llvm::Value *LLVMGenerator::AddFunctionCall(const std::string &full_name,
llvm::Function *fn = module()->getFunction(full_name);
DCHECK(fn != NULL);

if (enable_ir_traces_ && full_name.compare("printf") && full_name.compare("printff")) {
if (enable_ir_traces_ && full_name.compare("printf") != 0 &&
full_name.compare("printff") != 0) {
// Trace for debugging
ADD_TRACE("invoke native fn " + full_name);
}

// build a call to the llvm function.
llvm::Value *value;
if (ret_type->isVoidTy()) {
// void functions can't have a name for the call.
return ir_builder().CreateCall(fn, args);
value = ir_builder().CreateCall(fn, args);
} else {
llvm::Value *value = ir_builder().CreateCall(fn, args, full_name);
value = ir_builder().CreateCall(fn, args, full_name);
DCHECK(value->getType() == ret_type);
return value;
}
return value;
}

#define ADD_VISITOR_TRACE(...) \
Expand Down Expand Up @@ -466,7 +464,7 @@ void LLVMGenerator::Visitor::Visit(const FalseDex &dex) {
}

void LLVMGenerator::Visitor::Visit(const LiteralDex &dex) {
LLVMTypes *types = generator_->types_;
LLVMTypes *types = generator_->types_.get();
llvm::Value *value = nullptr;
llvm::Value *len = nullptr;

Expand Down Expand Up @@ -535,7 +533,7 @@ void LLVMGenerator::Visitor::Visit(const LiteralDex &dex) {
void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex &dex) {
ADD_VISITOR_TRACE("visit NonNullableFunc base function " +
dex.func_descriptor()->name());
LLVMTypes *types = generator_->types_;
LLVMTypes *types = generator_->types_.get();

// build the function params (ignore validity).
auto params = BuildParams(dex.args(), false);
Expand All @@ -549,7 +547,7 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex &dex) {

void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex &dex) {
ADD_VISITOR_TRACE("visit NullableNever base function " + dex.func_descriptor()->name());
LLVMTypes *types = generator_->types_;
LLVMTypes *types = generator_->types_.get();

// build function params along with validity.
auto params = BuildParams(dex.args(), true);
Expand All @@ -565,7 +563,7 @@ void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex &dex) {
ADD_VISITOR_TRACE("visit NullableInternal base function " +
dex.func_descriptor()->name());
llvm::IRBuilder<> &builder = ir_builder();
LLVMTypes *types = generator_->types_;
LLVMTypes *types = generator_->types_.get();

// build function params along with validity.
auto params = BuildParams(dex.args(), true);
Expand Down Expand Up @@ -593,7 +591,7 @@ void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex &dex) {
void LLVMGenerator::Visitor::Visit(const IfDex &dex) {
ADD_VISITOR_TRACE("visit IfExpression");
llvm::IRBuilder<> &builder = ir_builder();
LLVMTypes *types = generator_->types_;
LLVMTypes *types = generator_->types_.get();

// Evaluate condition.
LValuePtr if_condition = BuildValueAndValidity(dex.condition_vv());
Expand Down Expand Up @@ -677,7 +675,7 @@ void LLVMGenerator::Visitor::Visit(const IfDex &dex) {
void LLVMGenerator::Visitor::Visit(const BooleanAndDex &dex) {
ADD_VISITOR_TRACE("visit BooleanAndExpression");
llvm::IRBuilder<> &builder = ir_builder();
LLVMTypes *types = generator_->types_;
LLVMTypes *types = generator_->types_.get();
llvm::LLVMContext &context = generator_->context();

// Create blocks for short-circuit.
Expand Down Expand Up @@ -744,7 +742,7 @@ void LLVMGenerator::Visitor::Visit(const BooleanAndDex &dex) {
void LLVMGenerator::Visitor::Visit(const BooleanOrDex &dex) {
ADD_VISITOR_TRACE("visit BooleanOrExpression");
llvm::IRBuilder<> &builder = ir_builder();
LLVMTypes *types = generator_->types_;
LLVMTypes *types = generator_->types_.get();
llvm::LLVMContext &context = generator_->context();

// Create blocks for short-circuit.
Expand Down Expand Up @@ -842,7 +840,7 @@ std::vector<llvm::Value *> LLVMGenerator::Visitor::BuildParams(
*/
llvm::Value *LLVMGenerator::Visitor::BuildCombinedValidity(const DexVector &validities) {
llvm::IRBuilder<> &builder = ir_builder();
LLVMTypes *types = generator_->types_;
LLVMTypes *types = generator_->types_.get();

llvm::Value *isValid = types->true_constant();
for (auto &dex : validities) {
Expand Down Expand Up @@ -913,7 +911,7 @@ std::string LLVMGenerator::ReplaceFormatInTrace(const std::string &in_msg,
std::string msg = in_msg;
std::size_t pos = msg.find("%T");
if (pos == std::string::npos) {
assert(0);
DCHECK(0);
return msg;
}

Expand All @@ -934,7 +932,7 @@ std::string LLVMGenerator::ReplaceFormatInTrace(const std::string &in_msg,
fmt = "%lf";
*print_fn = "print_double";
} else {
assert(0);
DCHECK(0);
}
msg.replace(pos, 2, fmt);
return msg;
Expand All @@ -947,7 +945,7 @@ void LLVMGenerator::AddTrace(const std::string &msg, llvm::Value *value) {

std::string dmsg = "IR_TRACE:: " + msg + "\n";
std::string print_fn_name = "printf";
if (value) {
if (value != nullptr) {
dmsg = ReplaceFormatInTrace(dmsg, value, &print_fn_name);
}
trace_strings_.push_back(dmsg);
Expand All @@ -960,7 +958,7 @@ void LLVMGenerator::AddTrace(const std::string &msg, llvm::Value *value) {

std::vector<llvm::Value *> args;
args.push_back(str_ptr_cast);
if (value) {
if (value != nullptr) {
args.push_back(value);
}
AddFunctionCall(print_fn_name, types_->i32_type(), args);
Expand Down
16 changes: 7 additions & 9 deletions src/codegen/llvm_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ namespace gandiva {
/// Builds an LLVM module and generates code for the specified set of expressions.
class LLVMGenerator {
public:
~LLVMGenerator();

/// \brief Factory method to initialize the generator.
static Status Make(std::shared_ptr<Configuration> config,
std::unique_ptr<LLVMGenerator> *llvm_generator);
Expand All @@ -51,7 +49,7 @@ class LLVMGenerator {
Status Execute(const arrow::RecordBatch &record_batch,
const ArrayDataVector &output_vector);

LLVMTypes *types() { return types_; }
LLVMTypes &types() { return *types_; }
llvm::Module *module() { return engine_->module(); }

private:
Expand Down Expand Up @@ -143,17 +141,17 @@ class LLVMGenerator {
llvm::Function **fn);

/// Generate code to load the local bitmap specified index and cast it as bitmap.
llvm::Value *GetLocalBitMapReference(llvm::Value *arg_local_bitmaps, int idx);
llvm::Value *GetLocalBitMapReference(llvm::Value *arg_bitmaps, int idx);

/// Generate code to get the bit value at 'position' in the bitmap.
llvm::Value *GetPackedBitValue(llvm::Value *bitMap, llvm::Value *position);
llvm::Value *GetPackedBitValue(llvm::Value *bitmap, llvm::Value *position);

/// Generate code to set the bit value at 'position' in the bitmap to 'value'.
void SetPackedBitValue(llvm::Value *bitMap, llvm::Value *position, llvm::Value *value);
void SetPackedBitValue(llvm::Value *bitmap, llvm::Value *position, llvm::Value *value);

/// Generate code to clear the bit value at 'position' in the bitmap if 'value'
/// is false.
void ClearPackedBitValueIfFalse(llvm::Value *bitMap, llvm::Value *position,
void ClearPackedBitValueIfFalse(llvm::Value *bitmap, llvm::Value *position,
llvm::Value *value);

/// Generate code to make a function call (to a pre-compiled IR function) which takes
Expand All @@ -178,8 +176,8 @@ class LLVMGenerator {
void AddTrace(const std::string &msg, llvm::Value *value = nullptr);

std::unique_ptr<Engine> engine_;
std::vector<CompiledExpr *> compiled_exprs_;
LLVMTypes *types_;
std::vector<std::unique_ptr<CompiledExpr>> compiled_exprs_;
std::unique_ptr<LLVMTypes> types_;
FunctionRegistry function_registry_;
Annotator annotator_;

Expand Down
8 changes: 4 additions & 4 deletions src/codegen/tree_expr_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ NodePtr TreeExprBuilder::MakeBinaryLiteral(const std::string &value) {
}

NodePtr TreeExprBuilder::MakeNull(DataTypePtr data_type) {
static const std::string empty = "";
static const std::string empty;

if (data_type == nullptr) {
return nullptr;
Expand Down Expand Up @@ -88,11 +88,11 @@ NodePtr TreeExprBuilder::MakeField(FieldPtr field) {
}

NodePtr TreeExprBuilder::MakeFunction(const std::string &name, const NodeVector &params,
DataTypePtr result) {
if (result == nullptr) {
DataTypePtr result_type) {
if (result_type == nullptr) {
return nullptr;
}
return FunctionNode::MakeFunction(name, params, result);
return FunctionNode::MakeFunction(name, params, result_type);
}

NodePtr TreeExprBuilder::MakeIf(NodePtr condition, NodePtr then_node, NodePtr else_node,
Expand Down
Loading

0 comments on commit a766748

Please sign in to comment.