Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify foreigncall handling #554

Merged
merged 1 commit into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ CodeTracking = "0.5.9, 1"
julia = "1.6"

[extras]
CassetteOverlay = "d78b62d4-37fa-4a6f-acd8-2f19986eb9ee"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
Expand All @@ -29,4 +30,4 @@ Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DataFrames", "Dates", "DeepDiffs", "Distributed", "FunctionWrappers", "HTTP", "LinearAlgebra", "Logging", "Mmap", "PyCall", "SHA", "SparseArrays", "Tensors", "Test"]
test = ["CassetteOverlay", "DataFrames", "Dates", "DeepDiffs", "Distributed", "FunctionWrappers", "HTTP", "LinearAlgebra", "Logging", "Mmap", "PyCall", "SHA", "SparseArrays", "Tensors", "Test"]
165 changes: 59 additions & 106 deletions src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ end

function lookup_getproperties(a::Expr)
if a.head === :call && length(a.args) == 3 &&
a.args[1] isa QuoteNode && a.args[1].value === Base.getproperty &&
a.args[1] isa QuoteNode && a.args[1].value === Base.getproperty &&
a.args[2] isa QuoteNode && a.args[2].value isa Module &&
a.args[3] isa QuoteNode && a.args[3].value isa Symbol
return lookup_global_ref(Core.GlobalRef(a.args[2].value, a.args[3].value))
Expand Down Expand Up @@ -179,7 +179,6 @@ function optimize!(code::CodeInfo, scope)
# Replace :llvmcall and :foreigncall with compiled variants. See
# https://github.com/JuliaDebug/JuliaInterpreter.jl/issues/13#issuecomment-464880123
foreigncalls_idx = Int[]
delete_idxs = Int[]
for (idx, stmt) in enumerate(code.code)
# Foregincalls can be rhs of assignments
if isexpr(stmt, :(=))
Expand All @@ -190,36 +189,18 @@ function optimize!(code::CodeInfo, scope)
# Check for :llvmcall
arg1 = stmt.args[1]
if (arg1 === :llvmcall || lookup_stmt(code.code, arg1) === Base.llvmcall) && isempty(sparams) && scope isa Method
nargs = length(stmt.args)-4
# Call via `invokelatest` to avoid compiling it until we need it
delete_idx = Base.invokelatest(build_compiled_call!, stmt, Base.llvmcall, code, idx, nargs, sparams, evalmod)
delete_idx === nothing && error("llvmcall must be compiled, but exited early from build_compiled_call!")
Base.invokelatest(build_compiled_llvmcall!, stmt, code, idx, evalmod)
push!(foreigncalls_idx, idx)
append!(delete_idxs, delete_idx)
end
elseif stmt.head === :foreigncall && scope isa Method
nargs = length(stmt.args[3]::SimpleVector)
# Call via `invokelatest` to avoid compiling it until we need it
delete_idx = Base.invokelatest(build_compiled_call!, stmt, :ccall, code, idx, nargs, sparams, evalmod)
if delete_idx !== nothing
push!(foreigncalls_idx, idx)
append!(delete_idxs, delete_idx)
end
Base.invokelatest(build_compiled_foreigncall!, stmt, code, sparams, evalmod)
push!(foreigncalls_idx, idx)
end
end
end

if !isempty(delete_idxs)
ssalookup = compute_ssa_mapping_delete_statements!(code, delete_idxs)
let lkup = ssalookup
foreigncalls_idx = map(x -> lkup[x], foreigncalls_idx)
end
deleteat!(codelocs(code), delete_idxs)
deleteat!(code.code, delete_idxs)
code.ssavaluetypes = length(code.code)
renumber_ssa!(code.code, ssalookup)
end

## Un-nest :call expressions (so that there will be only one :call per line)
# This will allow us to re-use args-buffers rather than having to allocate new ones each time.
old_code, old_codelocs = code.code, codelocs(code)
Expand Down Expand Up @@ -273,66 +254,52 @@ function parametric_type_to_expr(@nospecialize(t::Type))
return t
end

# Handle :llvmcall & :foreigncall (issue #28)
function build_compiled_call!(stmt::Expr, fcall, code, idx, nargs::Int, sparams::Vector{Symbol}, evalmod)
TVal = evalmod == Core.Compiler ? Core.Compiler.Val : Val
delete_idx = Int[]
if fcall === :ccall
cfunc, RetType, ArgType = lookup_stmt(code.code, stmt.args[1]), stmt.args[2], stmt.args[3]::SimpleVector
# delete cconvert and unsafe_convert calls and forward the original values, since
# the same conversions will be applied within the generated compiled variant of this :foreigncall anyway
args = []
for (atype, arg) in zip(ArgType, stmt.args[6:6+nargs-1])
if atype === Any
push!(args, arg)
else
if arg isa SSAValue
unsafe_convert_expr = code.code[arg.id]::Expr
push!(delete_idx, arg.id) # delete the unsafe_convert
cconvert_val = unsafe_convert_expr.args[3]
if isa(cconvert_val, SSAValue)
push!(delete_idx, cconvert_val.id) # delete the cconvert
newarg = (code.code[cconvert_val.id]::Expr).args[3]
push!(args, newarg)
else
@assert isa(cconvert_val, SlotNumber)
push!(args, cconvert_val)
end
elseif arg isa SlotNumber
idx = findfirst(code.code) do expr
Meta.isexpr(expr, :(=)) || return false
lhs = expr.args[1]
return lhs isa SlotNumber && lhs.id === arg.id
end::Int
unsafe_convert_expr = code.code[idx]::Expr
push!(delete_idx, idx) # delete the unsafe_convert
push!(args, unsafe_convert_expr.args[2])
else
error("unexpected foreigncall argument type encountered: $(typeof(arg))")
end
end
end
else
# Run a mini-interpreter to extract the types
framecode = FrameCode(CompiledCalls, code; optimize=false)
frame = Frame(framecode, prepare_framedata(framecode, []))
idxstart = idx
for i = 2:4
idxstart = smallest_ref(code.code, stmt.args[i], idxstart)
end
frame.pc = idxstart
if idxstart < idx
while true
pc = step_expr!(Compiled(), frame)
pc === idx && break
pc === nothing && error("this should never happen")
end
function build_compiled_llvmcall!(stmt::Expr, code, idx, evalmod)
# Run a mini-interpreter to extract the types
framecode = FrameCode(CompiledCalls, code; optimize=false)
frame = Frame(framecode, prepare_framedata(framecode, []))
idxstart = idx
for i = 2:4
idxstart = smallest_ref(code.code, stmt.args[i], idxstart)
end
frame.pc = idxstart
if idxstart < idx
while true
pc = step_expr!(Compiled(), frame)
pc === idx && break
pc === nothing && error("this should never happen")
end
cfunc, RetType, ArgType = @lookup(frame, stmt.args[2]), @lookup(frame, stmt.args[3]), @lookup(frame, stmt.args[4])::DataType
args = stmt.args[5:end]
end
llvmir, RetType, ArgType = @lookup(frame, stmt.args[2]), @lookup(frame, stmt.args[3]), @lookup(frame, stmt.args[4])::DataType
args = stmt.args[5:end]
argnames = Any[Symbol(:arg, i) for i = 1:length(args)]
cc_key = (llvmir, RetType, ArgType, evalmod) # compiled call key
f = get(compiled_calls, cc_key, nothing)
if f === nothing
methname = gensym("compiled_llvmcall")
def = :(
function $methname($(argnames...))
return $(Base.llvmcall)($llvmir, $RetType, $ArgType, $(argnames...))
end)
f = Core.eval(evalmod, def)
compiled_calls[cc_key] = f
end

stmt.args[1] = QuoteNode(f)
stmt.head = :call
deleteat!(stmt.args, 2:length(stmt.args))
append!(stmt.args, args)
end


# Handle :llvmcall & :foreigncall (issue #28)
function build_compiled_foreigncall!(stmt::Expr, code, sparams::Vector{Symbol}, evalmod)
TVal = evalmod == Core.Compiler ? Core.Compiler.Val : Val
cfunc, RetType, ArgType = lookup_stmt(code.code, stmt.args[1]), stmt.args[2], stmt.args[3]::SimpleVector

dynamic_ccall = false
if isa(cfunc, Expr) # specification by tuple, e.g., (:clock, "libc")
oldcfunc = nothing
if isa(cfunc, Expr) # specification by tuple, e.g., (:clock, "libc")
cfunc = something(static_eval(cfunc), cfunc)
end
if isa(cfunc, Symbol)
Expand All @@ -348,14 +315,12 @@ function build_compiled_call!(stmt::Expr, fcall, code, idx, nargs::Int, sparams:
@assert length(RetType) == 1
RetType = RetType[1]
end
args = stmt.args[6:end]
# When the ccall is dynamic we pass the pointer as an argument so can reuse the function
cc_key = (dynamic_ccall ? :ptr : cfunc, RetType, ArgType, evalmod, length(sparams)) # compiled call key
cc_key = ((dynamic_ccall ? :ptr : cfunc), RetType, ArgType, evalmod, length(sparams)) # compiled call key
f = get(compiled_calls, cc_key, nothing)
argnames = Any[Symbol(:arg, i) for i = 1:nargs]
if f === nothing
if fcall === :ccall
ArgType = Expr(:tuple, Any[parametric_type_to_expr(t) for t in ArgType::SimpleVector]...)
end
ArgType = Expr(:tuple, Any[parametric_type_to_expr(t) for t in ArgType::SimpleVector]...)
RetType = parametric_type_to_expr(RetType)
# #285: test whether we can evaluate an type constraints on parametric expressions
# this essentially comes down to having the names be available in CompiledCalls,
Expand All @@ -366,31 +331,19 @@ function build_compiled_call!(stmt::Expr, fcall, code, idx, nargs::Int, sparams:
catch
return nothing
end
argnames = Any[Symbol(:arg, i) for i = 1:length(args)]
wrapargs = copy(argnames)
if dynamic_ccall
pushfirst!(wrapargs, cfunc)
end
for sparam in sparams
push!(wrapargs, :(::$TVal{$sparam}))
end
methname = gensym("compiledcall")
calling_convention = stmt.args[5]
if calling_convention === :(:llvmcall)
def = :(
function $methname($(wrapargs...)) where {$(sparams...)}
return $fcall($cfunc, llvmcall, $RetType, $ArgType, $(argnames...))
end)
elseif calling_convention === :(:stdcall)
def = :(
function $methname($(wrapargs...)) where {$(sparams...)}
return $fcall($cfunc, stdcall, $RetType, $ArgType, $(argnames...))
end)
else
def = :(
function $methname($(wrapargs...)) where {$(sparams...)}
return $fcall($cfunc, $RetType, $ArgType, $(argnames...))
end)
if dynamic_ccall
pushfirst!(wrapargs, cfunc)
end
methname = gensym("compiled_ccall")
def = :(
function $methname($(wrapargs...)) where {$(sparams...)}
return $(Expr(:foreigncall, cfunc, RetType, stmt.args[3:5]..., argnames...))
end)
f = Core.eval(evalmod, def)
compiled_calls[cc_key] = f
end
Expand All @@ -404,7 +357,7 @@ function build_compiled_call!(stmt::Expr, fcall, code, idx, nargs::Int, sparams:
for i in 1:length(sparams)
push!(stmt.args, :($TVal($(Expr(:static_parameter, i)))))
end
return delete_idx
return nothing
end

function replace_coretypes!(src; rev::Bool=false)
Expand Down
17 changes: 17 additions & 0 deletions test/interpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -937,3 +937,20 @@ end
@static if isdefined(Base.Experimental, Symbol("@opaque"))
@test @interpret (Base.Experimental.@opaque x->3*x)(4) == 12
end

# CassetteOverlay, issue #552
@static if VERSION >= v"1.8"
using CassetteOverlay
end

@static if VERSION >= v"1.8"
function foo()
x = IdDict()
x[:foo] = 1
end
@MethodTable SinTable;
@testset "CassetteOverlay" begin
pass = @overlaypass SinTable;
@test (@interpret pass(foo)) == 1
end
end