Skip to content
This repository has been archived by the owner on Dec 14, 2023. It is now read-only.

Commit

Permalink
Non-branching function calls
Browse files Browse the repository at this point in the history
When a global function pointer's value is taken,
it is now substituted by a wrapper function that
takes a context argument which is simply dropped.
This allows us to treat all stored functions as
having the additional context argument, without
forcing all functions to have that argument.

We may want to force all functions to later have
the argument (i.e. no wrappers), but for now this
is the simpler/more expedient approach.
  • Loading branch information
axw committed Dec 27, 2013
1 parent dbee235 commit bea1627
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 80 deletions.
88 changes: 18 additions & 70 deletions call.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,79 +27,27 @@ func (c *compiler) createCall(fn *LLVMValue, argValues []*LLVMValue) *LLVMValue
resultType = results
}

var fnptr llvm.Value
// Builtins are represented as a raw function pointer.
fnval := fn.LLVMValue()
if fnval.Type().TypeKind() == llvm.PointerTypeKind {
fnptr = fnval
} else {
fnptr = c.builder.CreateExtractValue(fnval, 0, "")
context := c.builder.CreateExtractValue(fnval, 1, "")
llfntyp := fnptr.Type().ElementType()
paramTypes := llfntyp.ParamTypes()

// If the context is not a constant null, and we're not
// dealing with a method (where we don't care about the value
// of the receiver), then we must conditionally call the
// function with the additional receiver/closure.
if !context.IsNull() && fntyp.Recv() == nil {
// Store the blocks for referencing in the Phi below;
// note that we update the block after each createCall,
// since createCall may create new blocks and we want
// the predecessors to the Phi.
var nullctxblock llvm.BasicBlock
var nonnullctxblock llvm.BasicBlock
var endblock llvm.BasicBlock
var nullctxresult llvm.Value

// len(paramTypes) == len(args) iff function is not a method.
if !context.IsConstant() && len(paramTypes) == len(args) {
currblock := c.builder.GetInsertBlock()
endblock = llvm.AddBasicBlock(currblock.Parent(), "")
endblock.MoveAfter(currblock)
nonnullctxblock = llvm.InsertBasicBlock(endblock, "")
nullctxblock = llvm.InsertBasicBlock(nonnullctxblock, "")
nullctx := c.builder.CreateIsNull(context, "")
c.builder.CreateCondBr(nullctx, nullctxblock, nonnullctxblock)

// null context case.
c.builder.SetInsertPointAtEnd(nullctxblock)
nullctxresult = c.builder.CreateCall(fnptr, args, "")
nullctxblock = c.builder.GetInsertBlock()
c.builder.CreateBr(endblock)
c.builder.SetInsertPointAtEnd(nonnullctxblock)
}

// non-null context case.
var result llvm.Value
args := append([]llvm.Value{context}, args...)
if len(paramTypes) < len(args) {
returnType := llfntyp.ReturnType()
ctxType := context.Type()
paramTypes := append([]llvm.Type{ctxType}, paramTypes...)
vararg := llfntyp.IsFunctionVarArg()
llfntyp := llvm.FunctionType(returnType, paramTypes, vararg)
fnptrtyp := llvm.PointerType(llfntyp, 0)
fnptr = c.builder.CreateBitCast(fnptr, fnptrtyp, "")
}
result = c.builder.CreateCall(fnptr, args, "")
return c.NewValue(c.builder.CreateCall(fnval, args, ""), resultType)
}

// If the return type is not void, create a
// PHI node to select which value to return.
if !nullctxresult.IsNil() {
nonnullctxblock = c.builder.GetInsertBlock()
c.builder.CreateBr(endblock)
c.builder.SetInsertPointAtEnd(endblock)
if result.Type().TypeKind() != llvm.VoidTypeKind {
phiresult := c.builder.CreatePHI(result.Type(), "")
values := []llvm.Value{nullctxresult, result}
blocks := []llvm.BasicBlock{nullctxblock, nonnullctxblock}
phiresult.AddIncoming(values, blocks)
result = phiresult
}
}
return c.NewValue(result, resultType)
}
// If context is constant null, then the function does
// not need a context argument.
fnptr := c.builder.CreateExtractValue(fnval, 0, "")
context := c.builder.CreateExtractValue(fnval, 1, "")
llfntyp := fnptr.Type().ElementType()
paramTypes := llfntyp.ParamTypes()
if context.IsNull() {
return c.NewValue(c.builder.CreateCall(fnptr, args, ""), resultType)
}
result := c.builder.CreateCall(fnptr, args, "")
llfntyp = llvm.FunctionType(
llfntyp.ReturnType(),
append([]llvm.Type{context.Type()}, paramTypes...),
llfntyp.IsFunctionVarArg(),
)
fnptr = c.builder.CreateBitCast(fnptr, llvm.PointerType(llfntyp, 0), "")
result := c.builder.CreateCall(fnptr, append([]llvm.Value{context}, args...), "")
return c.NewValue(result, resultType)
}
3 changes: 2 additions & 1 deletion convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ func (v *LLVMValue) convertI2E() *LLVMValue {
c := v.compiler
f := c.runtime.convertI2E.LLVMValue()
args := []llvm.Value{c.coerce(v.LLVMValue(), c.runtime.iface.llvm)}
return c.NewValue(c.builder.CreateCall(f, args, ""), types.NewInterface(nil, nil))
typ := types.NewInterface(nil, nil)
return c.NewValue(c.coerce(c.builder.CreateCall(f, args, ""), c.llvmtypes.ToLLVM(typ)), typ)
}

// convertE2I converts an empty interface value to a non-empty interface.
Expand Down
87 changes: 78 additions & 9 deletions ssa.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,21 @@ type unit struct {
*compiler
pkg *ssa.Package
globals map[ssa.Value]*LLVMValue

// funcvals is a map of *ssa.Function to LLVM functions that
// may be stored. Non-receiver functions in this map will have
// an additional context parameter, to enable non-branching
// calls with a pair-of-pointer function representation,
// without forcing the additional parameter on all functions.
funcvals map[*ssa.Function]*LLVMValue
}

func newUnit(c *compiler, pkg *ssa.Package) *unit {
u := &unit{
compiler: c,
pkg: pkg,
globals: make(map[ssa.Value]*LLVMValue),
funcvals: make(map[*ssa.Function]*LLVMValue),
}
return u
}
Expand Down Expand Up @@ -266,12 +274,23 @@ func (fr *frame) value(v ssa.Value) *LLVMValue {
case nil:
return nil
case *ssa.Function:
result, ok := fr.funcvals[v]
if ok {
return result
}
// fr.globals[v] has the function in raw pointer form;
// we must convert it to <f,ctx> form.
// we must convert it to <f,ctx> form. If the function
// does not have a receiver, then create a wrapper
// function that has an additional "context" parameter.
f := fr.resolveFunction(v)
if v.Signature.Recv() == nil && len(v.FreeVars) == 0 {
f = contextFunction(fr.compiler, f)
}
pair := llvm.ConstNull(fr.llvmtypes.ToLLVM(f.Type()))
pair = llvm.ConstInsertValue(pair, f.LLVMValue(), []uint32{0})
return fr.NewValue(pair, f.Type())
result = fr.NewValue(pair, f.Type())
fr.funcvals[v] = result
return result
case *ssa.Const:
return fr.NewConstValue(v.Value, v.Type())
case *ssa.Global:
Expand Down Expand Up @@ -682,9 +701,17 @@ func (fr *frame) prepareCall(instr ssa.CallInstruction) (fn *LLVMValue, args []*
return fn, args, nil
}

switch call.Value.(type) {
switch v := call.Value.(type) {
case *ssa.Builtin:
// handled below
case *ssa.Function:
// Function handled specially; value() will convert
// a function to one with a context argument.
fn = fr.resolveFunction(v)
pair := llvm.ConstNull(fr.llvmtypes.ToLLVM(fn.Type()))
pair = llvm.ConstInsertValue(pair, fn.LLVMValue(), []uint32{0})
fn = fr.NewValue(pair, fn.Type())
return fn, args, nil
default:
fn = fr.value(call.Value)
return fn, args, nil
Expand All @@ -705,20 +732,19 @@ func (fr *frame) prepareCall(instr ssa.CallInstruction) (fn *LLVMValue, args []*
params[i] = types.NewParam(arg.Pos(), nil, arg.Name(), args[i].Type())
}
sig := types.NewSignature(nil, nil, types.NewTuple(params...), nil, false)
fntyp := fr.llvmtypes.ToLLVM(sig).StructElementTypes()[0].ElementType()
llvmfn := llvm.AddFunction(fr.module.Module, "", fntyp)
llfntyp := fr.llvmtypes.ToLLVM(sig)
llfnptr := llvm.AddFunction(fr.module.Module, "", llfntyp.StructElementTypes()[0].ElementType())
currBlock := fr.builder.GetInsertBlock()
entry := llvm.AddBasicBlock(llvmfn, "entry")
entry := llvm.AddBasicBlock(llfnptr, "entry")
fr.builder.SetInsertPointAtEnd(entry)
internalArgs := make([]Value, len(args))
for i, arg := range args {
internalArgs[i] = fr.NewValue(llvmfn.Param(i), arg.Type())
internalArgs[i] = fr.NewValue(llfnptr.Param(i), arg.Type())
}
fr.printValues(builtin.Name() == "println", internalArgs...)
fr.builder.CreateRetVoid()
fr.builder.SetInsertPointAtEnd(currBlock)
fn = fr.NewValue(llvmfn, sig)
return fn, args, nil
return fr.NewValue(llfnptr, sig), args, nil

case "panic":
panic("TODO: panic")
Expand Down Expand Up @@ -777,3 +803,46 @@ func hasDefer(f *ssa.Function) bool {
}
return false
}

// contextFunction creates a wrapper function that
// has the same signature as the specified function,
// but has an additional first parameter that accepts
// and ignores the function context value.
//
// contextFunction must be called with a global function
// pointer.
func contextFunction(c *compiler, f *LLVMValue) *LLVMValue {
defer c.builder.SetInsertPointAtEnd(c.builder.GetInsertBlock())
resultType := c.llvmtypes.ToLLVM(f.Type())
fnptr := f.LLVMValue()
contextType := resultType.StructElementTypes()[1]
llfntyp := fnptr.Type().ElementType()
llfntyp = llvm.FunctionType(
llfntyp.ReturnType(),
append([]llvm.Type{contextType}, llfntyp.ParamTypes()...),
llfntyp.IsFunctionVarArg(),
)
wrapper := llvm.AddFunction(c.module.Module, fnptr.Name()+".ctx", llfntyp)
entry := llvm.AddBasicBlock(wrapper, "entry")
c.builder.SetInsertPointAtEnd(entry)
args := make([]llvm.Value, len(llfntyp.ParamTypes())-1)
for i := range args {
args[i] = wrapper.Param(i + 1)
}
result := c.builder.CreateCall(fnptr, args, "")
switch nresults := f.Type().(*types.Signature).Results().Len(); nresults {
case 0:
c.builder.CreateRetVoid()
case 1:
c.builder.CreateRet(result)
default:
results := make([]llvm.Value, nresults)
for i := range results {
results[i] = c.builder.CreateExtractValue(result, i, "")
}
c.builder.CreateAggregateRet(results)
}
fnptr = c.builder.CreateBitCast(wrapper, fnptr.Type(), "")
fnval := c.builder.CreateInsertValue(llvm.ConstNull(resultType), fnptr, 0, "")
return c.NewValue(fnval, f.Type())
}

0 comments on commit bea1627

Please sign in to comment.