Skip to content

Commit

Permalink
Record captured locals per lambda
Browse files Browse the repository at this point in the history
  • Loading branch information
c42f committed Dec 5, 2024
1 parent afce50e commit 2a8eb6e
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/closure_conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ struct ClosureConversionCtx{GraphType} <: AbstractLoweringContext
end

function add_lambda_local!(ctx::ClosureConversionCtx, id)
push!(ctx.lambda_bindings.locals, id)
init_lambda_binding(ctx.lambda_bindings, id)
end

# Convert `ex` to `type` by calling `convert(type, ex)` when necessary.
Expand Down
15 changes: 9 additions & 6 deletions src/linear_ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ function compile_conditional(ctx, ex, false_label)
end

function add_lambda_local!(ctx::LinearIRContext, id)
push!(ctx.lambda_bindings.locals, id)
init_lambda_binding(ctx.lambda_bindings, id)
end

# Lowering of exception handling must ensure that
Expand Down Expand Up @@ -941,11 +941,14 @@ function compile_lambda(outer_ctx, ex)
end
end
# Sorting the lambda locals is required to remove dependence on Dict iteration order.
for id in sort(collect(ex.lambda_bindings.locals))
info = lookup_binding(ctx.bindings, id)
@assert info.kind == :local
push!(slots, Slot(info.name, :local, false))
slot_rewrites[id] = length(slots)
for (id, lbinfo) in sort(collect(pairs(ex.lambda_bindings.bindings)), by=first)
if !lbinfo.is_captured
info = lookup_binding(ctx.bindings, id)
if info.kind == :local
push!(slots, Slot(info.name, :local, false))
slot_rewrites[id] = length(slots)
end
end
end
for (i,arg) in enumerate(children(static_parameters))
@assert kind(arg) == K"BindingId"
Expand Down
141 changes: 115 additions & 26 deletions src/scope_analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,49 @@ function NameKey(ex::SyntaxTree)
NameKey(ex.name_val, get(ex, :scope_layer, _lowering_internal_layer))
end

struct CaptureInfo
# Metadata about how a binding is used within some enclosing lambda
struct LambdaBindingInfo
is_captured::Bool
is_read::Bool
is_assigned::Bool
is_called::Bool
end

LambdaBindingInfo() = LambdaBindingInfo(false, false, false, false)

function LambdaBindingInfo(parent::LambdaBindingInfo;
is_captured = nothing,
is_read = nothing,
is_assigned = nothing,
is_called = nothing)
LambdaBindingInfo(
isnothing(is_captured) ? parent.is_captured : is_captured,
isnothing(is_read) ? parent.is_read : is_read,
isnothing(is_assigned) ? parent.is_assigned : is_assigned,
isnothing(is_called) ? parent.is_called : is_called,
)
end

struct LambdaBindings
# Local bindings within the lambda
locals::Set{IdTag}
captures::Dict{IdTag,CaptureInfo}
# Bindings used within the lambda
bindings::Dict{IdTag,LambdaBindingInfo}
end

LambdaBindings() = LambdaBindings(Set{IdTag}(), Dict{IdTag,CaptureInfo}())
function init_lambda_binding(binds::LambdaBindings, id; kws...)
@assert !haskey(binds.bindings, id)
binds.bindings[id] = LambdaBindingInfo(LambdaBindingInfo(); kws...)
end

function update_lambda_binding!(binds::LambdaBindings, id; kws...)
binfo = binds.bindings[id]
binds.bindings[id] = LambdaBindingInfo(binfo; kws...)
end

function update_lambda_binding!(ctx::AbstractLoweringContext, id; kws...)
update_lambda_binding!(last(ctx.scope_stack).lambda_bindings, id; kws...)
end

LambdaBindings() = LambdaBindings(Dict{IdTag,LambdaBindings}())


struct ScopeInfo
Expand Down Expand Up @@ -230,17 +260,19 @@ function add_lambda_args(ctx, var_ids, args, args_kind)
"static parameter name not distinct from function argument"
throw(LoweringError(arg, msg))
end
var_ids[varkey] = init_binding(ctx, varkey, args_kind;
is_nospecialize=getmeta(arg, :nospecialize, false))
id = init_binding(ctx, varkey, args_kind;
is_nospecialize=getmeta(arg, :nospecialize, false))
var_ids[varkey] = id
elseif ka != K"BindingId" && ka != K"Placeholder"
throw(LoweringError(arg, "Unexpected lambda arg kind"))
end
end
end

# Analyze identifier usage within a scope, adding all newly discovered
# identifiers to ctx.bindings and returning a lookup table from identifier
# names to their variable IDs
# Analyze identifier usage within a scope
# * Allocate a new binding for each identifier which the scope introduces.
# * Record the identifier=>binding mapping in a lookup table
# * Return a `ScopeInfo` with the mapping plus additional scope metadata
function analyze_scope(ctx, ex, scope_type, is_toplevel_global_scope=false,
lambda_args=nothing, lambda_static_parameters=nothing)
parentscope = isempty(ctx.scope_stack) ? nothing : ctx.scope_stack[end]
Expand All @@ -251,8 +283,13 @@ function analyze_scope(ctx, ex, scope_type, is_toplevel_global_scope=false,
assignments, locals, destructured_args, globals,
used, used_bindings, alias_bindings = find_scope_vars(ex)

# Create new lookup table for variables in this scope which differ from the
# parent scope.
# Construct a mapping from identifiers to bindings
#
# This will contain a binding ID for each variable which is introduced by
# the scope, including
# * Explicit locals
# * Explicit globals
# * Implicit locals created by assignment
var_ids = Dict{NameKey,IdTag}()

if !isnothing(lambda_args)
Expand All @@ -272,8 +309,9 @@ function analyze_scope(ctx, ex, scope_type, is_toplevel_global_scope=false,
end
elseif var_kind(ctx, varkey) === :static_parameter
throw(LoweringError(e, "local variable name `$(varkey.name)` conflicts with a static parameter"))
else
var_ids[varkey] = init_binding(ctx, varkey, :local)
end
var_ids[varkey] = init_binding(ctx, varkey, :local)
end

# Add explicit globals
Expand Down Expand Up @@ -353,24 +391,67 @@ function analyze_scope(ctx, ex, scope_type, is_toplevel_global_scope=false,
end
end

for varkey in used
if lookup_var(ctx, varkey) === nothing
# Add other newly discovered identifiers as globals
init_binding(ctx, varkey, :global)
end
end

#--------------------------------------------------
# At this point we've discovered all the bindings defined in this scope and
# added them to `var_ids`.
#
# Next we record information about how the new bindings relate to the
# enclosing lambda
# * All non-globals are recorded (kind :local and :argument will later be turned into slots)
# * Captured variables are detected and recorded
lambda_bindings = is_outer_lambda_scope ? LambdaBindings() : parentscope.lambda_bindings

for id in values(var_ids)
vk = var_kind(ctx, id)
if vk === :local
push!(lambda_bindings.locals, id)
binfo = lookup_binding(ctx, id)
if !binfo.is_ssa && binfo.kind !== :global
init_lambda_binding(lambda_bindings, id)
end
end

# FIXME: This assumes used bindings are internal to the lambda and cannot
# be from the environment, and also assumes they are assigned. That's
# correct for now but in general we should go by the same code path that
# identifiers do.
for id in used_bindings
info = lookup_binding(ctx, id)
if !info.is_ssa && info.kind == :local
push!(lambda_bindings.locals, id)
binfo = lookup_binding(ctx, id)
if !binfo.is_ssa && binfo.kind !== :global
if !haskey(lambda_bindings.bindings, id)
init_lambda_binding(lambda_bindings, id, is_read=true, is_assigned=true)
end
end
end

for varkey in used
id = haskey(var_ids, varkey) ? var_ids[varkey] : lookup_var(ctx, varkey)
if id === nothing
# Identifiers which are used but not defined in some scope are
# newly discovered global bindings
init_binding(ctx, varkey, :global)
elseif !in_toplevel_thunk
binfo = lookup_binding(ctx, id)
if binfo.kind !== :global
if !haskey(lambda_bindings.bindings, id)
# Used vars from a scope *outside* the current lambda are captured
init_lambda_binding(lambda_bindings, id, is_captured=true, is_read=true)
else
update_lambda_binding!(lambda_bindings, id, is_read=true)
end
end
end
end

if !in_toplevel_thunk
for (varkey,_) in assignments
id = haskey(var_ids, varkey) ? var_ids[varkey] : lookup_var(ctx, varkey)
binfo = lookup_binding(ctx, id)
if binfo.kind !== :global
if !haskey(lambda_bindings.bindings, id)
# Assigned vars from a scope *outside* the current lambda are captured
init_lambda_binding(lambda_bindings, id, is_captured=true, is_assigned=true)
else
update_lambda_binding!(lambda_bindings, id, is_assigned=true)
end
end
end
end

Expand Down Expand Up @@ -410,6 +491,14 @@ function maybe_update_bindings!(ctx, ex)
throw(LoweringError(ex, "unsupported `const` declaration on local variable"))
end
update_binding!(ctx, id; is_const=true)
elseif k == K"call"
name = ex[1]
if kind(name) == K"BindingId"
id = name.var_id
if haskey(last(ctx.scope_stack).lambda_bindings.bindings, id)
update_lambda_binding!(ctx, id, is_called=true)
end
end
end
nothing
end
Expand Down

0 comments on commit 2a8eb6e

Please sign in to comment.