Skip to content

Commit

Permalink
Desugaring of positional arguments with defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
c42f committed Nov 22, 2024
1 parent ceb3eac commit 760de66
Show file tree
Hide file tree
Showing 4 changed files with 612 additions and 57 deletions.
250 changes: 193 additions & 57 deletions src/desugaring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,36 @@ end

# Return true when `x` and `y` are "the same identifier", but also works with
# bindings (and hence ssa vars). See also `is_identifier_like()`
function is_same_identifier_like(x, y)
return (kind(x) == K"Identifier" && kind(y) == K"Identifier" && NameKey(x) == NameKey(y)) ||
(kind(x) == K"BindingId" && kind(y) == K"BindingId" && x.var_id == y.var_id)
function is_same_identifier_like(ex::SyntaxTree, y::SyntaxTree)
    # Two nodes match only when they are of the same kind *and* refer to the
    # same name (for identifiers) or the same binding id (for bindings).
    kind_ex = kind(ex)
    kind_y = kind(y)
    if kind_ex == K"Identifier" && kind_y == K"Identifier"
        return NameKey(ex) == NameKey(y)
    elseif kind_ex == K"BindingId" && kind_y == K"BindingId"
        return ex.var_id == y.var_id
    else
        # Mixed kinds, or a kind which is not identifier-like at all.
        return false
    end
end

function is_same_identifier_like(x, name::AbstractString)
return kind(x) == K"Identifier" && x.name_val == name
function is_same_identifier_like(ex::SyntaxTree, name::AbstractString)
    # Only plain identifiers can be compared against a bare string name.
    kind(ex) == K"Identifier" || return false
    return ex.name_val == name
end

function contains_identifier(ex, idents...)
return any(is_same_identifier_like(ex, id) for id in idents) ||
(!is_leaf(ex) && any(contains_identifier(e, idents...) for e in children(ex)))
# Return true if any unquoted node within `ex` (including `ex` itself) is the
# same identifier as one of the elements of the vector `idents`.
function contains_identifier(ex::SyntaxTree, idents::AbstractVector{<:SyntaxTree})
    matches_any_ident = node -> any(id -> is_same_identifier_like(node, id), idents)
    return contains_unquoted(matches_any_ident, ex)
end

# Varargs form: return true if any unquoted node within `ex` (including `ex`
# itself) is the same identifier as one of the trailing `idents` arguments.
function contains_identifier(ex::SyntaxTree, idents...)
    matches_any_ident = node -> any(id -> is_same_identifier_like(node, id), idents)
    return contains_unquoted(matches_any_ident, ex)
end

# Return true if `f(e)` is true for any unquoted child of `ex`, recursively.
function contains_unquoted(f::Function, ex::SyntaxTree)
    # The predicate is always tried on `ex` itself first.
    f(ex) && return true
    # Leaves have nothing to descend into, and the children of `quote`,
    # `inert` and `meta` nodes are deliberately not searched.
    if is_leaf(ex) || kind(ex) in KSet"quote inert meta"
        return false
    end
    return any(contains_unquoted(f, child) for child in children(ex))
end

# Identify some expressions that are safe to repeat
Expand Down Expand Up @@ -526,7 +544,7 @@ function expand_unionall_def(ctx, srcref, lhs, rhs)
name = lhs[1]
@ast ctx srcref [K"block"
[K"const_if_global" name]
unionall_type = expand_forms_2(ctx, [K"where" rhs lhs[2:end]...])
unionall_type := expand_forms_2(ctx, [K"where" rhs lhs[2:end]...])
expand_forms_2([K"=" name unionall_type])
]
end
Expand Down Expand Up @@ -1238,7 +1256,9 @@ function match_function_arg(full_ex)
is_nospecialize = true
ex = ex[2]
elseif k == K"="
@chk full_ex isnothing(default) && !is_slurp
if !isnothing(default)
throw(full_ex, "multiple defaults provided with `=` in function argument")
end
default = ex[2]
ex = ex[1]
else
Expand All @@ -1252,27 +1272,140 @@ function match_function_arg(full_ex)
is_nospecialize=is_nospecialize)
end

# Expand `where` clause(s) of a function into (typevar_names, typevar_stmts) where
# - `typevar_names` are the names of the type's type parameters
# - `typevar_stmts` are a list of statements to define a `TypeVar` for each parameter
# name in `typevar_names`, to be emitted prior to uses of `typevar_names`.
# There is exactly one statement for each typevar.
function _split_wheres!(ctx, typevar_names, typevar_stmts, ex)
    # Peel two-child `where` nodes off `ex` from the outside in, appending the
    # type variables they introduce to `typevar_names` / `typevar_stmts`.
    # Returns the first expression reached which is not such a `where` node.
    if kind(ex) == K"where" && numchildren(ex) == 2
        vars_kind = kind(ex[2])
        if vars_kind == K"_typevars"
            # The children of a `_typevars` node are appended as names only;
            # no statements are generated for them here.
            append!(typevar_names, children(ex[2]))
        else
            # `where {A,B}` lists its parameters as children of the braces;
            # `where A` has the single parameter directly as the second child.
            params = vars_kind == K"braces" ? ex[2][1:end] : ex[2:2]
            # Delegate to the shared helper, and only to it: processing the
            # parameters here as well would register every name twice and
            # break the one-statement-per-typevar invariant.
            expand_typevars!(ctx, typevar_names, typevar_stmts, params)
        end
        return _split_wheres!(ctx, typevar_names, typevar_stmts, ex[1])
    else
        return ex
    end
end

# Build the syntax tree which defines a single method.
#
# The result is a `K"block"` (constructed with `@ast`, using `srcref` as the
# source reference) containing, in order:
#   1. an assignment of the method metadata to `method_metadata`
#   2. a `K"method"` node combining `method_table`, the metadata and a
#      `K"lambda"` built from `arg_names`, `typevar_names`, `body` and `ret_var`
#   3. only when `docs` is not `nothing`: a call to `bind_docs!` passing
#      `func_self`, `docs[1]` and the metadata
#
# `callex` supplies the source reference for the metadata call and the source
# location recorded in the metadata.
function _method_def_expr(ctx, srcref, callex, func_self, method_table,
docs, typevar_names, arg_names, arg_types, ret_var, body)
# metadata contains svec(types, sparms, location)
@ast ctx srcref [K"block"
method_metadata := [K"call"(callex)
"svec" ::K"core"
# First element: svec of the argument types
[K"call"
"svec" ::K"core"
arg_types...
]
# Second element: svec of the type variable names
[K"call"
"svec" ::K"core"
typevar_names...
]
# Third element: source location of `callex` as a `LineNumberNode`
QuoteNode(source_location(LineNumberNode, callex))::K"Value"
]
[K"method"
method_table
method_metadata
[K"lambda"(body, is_toplevel_thunk=false)
[K"block" arg_names...]
[K"block" typevar_names...]
body
ret_var # might be `nothing` and hence removed
]
]
# Docs are attached only when the caller supplied them
if !isnothing(docs)
[K"call"(docs)
bind_docs!::K"Value"
func_self
docs[1]
method_metadata
]
end
]
end

# Return the subset of `typevar_names` (as a new `SyntaxList`, in the original
# order) which is needed by the argument types `arg_types`.
#
# A type variable is kept when it
# - occurs anywhere inside one of `arg_types`, or
# - occurs inside the defining statement of a later type variable which is
#   itself kept (transitive use).
#
# `typevar_stmts[i]` must be the single statement defining `typevar_names[i]`.
function trim_used_typevars(ctx, arg_types, typevar_names, typevar_stmts)
    n_typevars = length(typevar_names)
    @assert n_typevars == length(typevar_stmts)
    # Filter typevar names down to those which are directly used in the arg
    # list. The expression being searched is the first argument of
    # `contains_identifier` and the identifier searched for comes second, so
    # each argument type must be searched for the typevar name. The reverse
    # order only ever matches an argument type which *is* the bare typevar,
    # and misses uses such as `Vector{T}`.
    typevar_used = Bool[any(contains_identifier(argtype, tn) for argtype in arg_types)
                        for tn in typevar_names]
    # _Or_ used transitively via other typevars. The following code
    # computes this by incrementally coloring the graph of dependencies
    # between type vars, repeating until a full pass marks nothing new.
    found_used = true
    while found_used
        found_used = false
        for (i, tn) in enumerate(typevar_names)
            if typevar_used[i]
                continue
            end
            # Only typevars after `tn` in the list are checked for references
            # to it.
            for j = i+1:n_typevars
                if typevar_used[j] && contains_identifier(typevar_stmts[j], tn)
                    found_used = true
                    typevar_used[i] = true
                    break
                end
            end
        end
    end
    trimmed_typevar_names = SyntaxList(ctx)
    for (used, tn) in zip(typevar_used, typevar_names)
        if used
            push!(trimmed_typevar_names, tn)
        end
    end
    return trimmed_typevar_names
end

# Generate a method for every number of allowed optional arguments
# For example for `f(x, y=1, z=2)` we generate two additional methods
# f(x) = f(x, 1, 2)
# f(x, y) = f(x, y, 2)
function _optional_positional_defs!(ctx, method_stmts, srcref, callex, func_self,
                                    method_table, typevar_names, typevar_stmts,
                                    arg_names, arg_types, first_default, arg_defaults, ret_var)
    # Placeholder arguments are swapped for real variables: the forwarding
    # methods must pass them on to the inner method for dispatch even when the
    # inner method body never reads them.
    forwarded_names = map(arg_names) do argname
        if kind(argname) == K"Placeholder"
            new_mutable_var(ctx, argname, argname.name_val; kind=:argument)
        else
            argname
        end
    end
    # One forwarding method per default: the method for `k` accepts every
    # argument before the k-th defaulted one.
    for k = 1:length(arg_defaults)
        omit_from = first_default + k - 1
        kept_names = forwarded_names[1:omit_from-1]
        # When a later default refers to one of the omitted arguments we can't
        # jump straight to the full method. In that case supply just one more
        # default and let the chain of forwarding methods fill in the rest.
        later_defaults_need_args = any(arg_defaults[k+1:end]) do later_default
            contains_identifier(later_default, forwarded_names[omit_from:end])
        end
        if later_defaults_need_args
            last_supplied = k
        else
            last_supplied = lastindex(arg_defaults)
        end
        forwarding_body = @ast ctx callex [K"block"
            [K"call"
                kept_names...
                arg_defaults[k:last_supplied]...
            ]
        ]
        kept_types = arg_types[1:omit_from-1]
        kept_typevars = trim_used_typevars(ctx, kept_types, typevar_names, typevar_stmts)
        # TODO: Ensure we preserve @nospecialize metadata in args
        forwarding_method = _method_def_expr(ctx, srcref, callex, func_self, method_table,
                                             nothing, kept_typevars, kept_names,
                                             kept_types, ret_var, forwarding_body)
        push!(method_stmts, forwarding_method)
    end
end

function expand_function_def(ctx, ex, docs, rewrite_call=identity, rewrite_body=identity)
@chk numchildren(ex) in (1,2)
name = ex[1]
Expand Down Expand Up @@ -1336,7 +1469,7 @@ function expand_function_def(ctx, ex, docs, rewrite_call=identity, rewrite_body=
if isnothing(info.default)
if !isempty(arg_defaults) && !info.is_slurp
# TODO: Referring to multiple pieces of syntax in one error message is necessary.
# TODO: Poision ASTs with error nodes and continue rather than immediately throwing.
# TODO: Poison ASTs with error nodes and continue rather than immediately throwing.
#
# We should make something like the following kind of thing work!
# arg_defaults[1] = @ast_error ctx arg_defaults[1] """
Expand All @@ -1354,6 +1487,9 @@ function expand_function_def(ctx, ex, docs, rewrite_call=identity, rewrite_body=
end
push!(arg_defaults, info.default)
end
# TODO: Ideally, ensure side effects of evaluating arg_types only
# happen once - we should create an ssavar if there's any following
# defaults. (flisp lowering doesn't ensure this either)
push!(arg_types, atype)
end

Expand Down Expand Up @@ -1387,6 +1523,7 @@ function expand_function_def(ctx, ex, docs, rewrite_call=identity, rewrite_body=
func_self
]
end
# Add self argument
pushfirst!(arg_names, farg_name)
pushfirst!(arg_types, farg_type)

Expand All @@ -1402,42 +1539,33 @@ function expand_function_def(ctx, ex, docs, rewrite_call=identity, rewrite_body=
ret_var = nothing
end

method_table = nothing_(ctx, name) # TODO: method overlays
method_table_val = nothing # TODO: method overlays
method_table = isnothing(method_table_val) ?
@ast(ctx, callex, "nothing"::K"core") :
ssavar(ctx, ex, "method_table")
method_stmts = SyntaxList(ctx)

if !isempty(arg_defaults)
# For self argument added above
first_default += 1
_optional_positional_defs!(ctx, method_stmts, ex, callex, func_self,
method_table, typevar_names, typevar_stmts,
arg_names, arg_types, first_default, arg_defaults, ret_var)
end

# The method with all non-default arguments
push!(method_stmts,
_method_def_expr(ctx, ex, callex, func_self, method_table, docs,
typevar_names, arg_names, arg_types, ret_var, body))

@ast ctx ex [K"scope_block"(scope_type=:hard)
[K"block"
typevar_stmts...
[K"=" func_self func_self_val]
# metadata contains svec(types, sparms, location)
method_metadata := [K"call"(callex)
"svec" ::K"core"
[K"call"
"svec" ::K"core"
arg_types...
]
[K"call"
"svec" ::K"core"
typevar_names...
]
QuoteNode(source_location(LineNumberNode, callex))::K"Value"
]
[K"method"
method_table
method_metadata
[K"lambda"(body, is_toplevel_thunk=false)
[K"block" arg_names...]
[K"block" typevar_names...]
body
ret_var # might be `nothing` and hence removed
]
]
if !isnothing(docs)
[K"call"(docs)
bind_docs!::K"Value"
func_self
docs[1]
method_metadata
]
if !isnothing(method_table_val)
[K"=" method_table method_table_val]
end
[K"=" func_self func_self_val]
method_stmts...
[K"unnecessary" func_self]
]
]
Expand Down Expand Up @@ -1569,19 +1697,27 @@ function analyze_type_sig(ctx, ex)
end

# Expand type_params into (typevar_names, typevar_stmts) where
# - `typevar_names` are the names of the types's type parameters
# - `typevar_names` are the names of the type's type parameters
# - `typevar_stmts` are a list of statements to define a `TypeVar` for each parameter
# name in `typevar_names`, to be emitted prior to uses of `typevar_names`
function expand_typevars(ctx, type_params)
typevar_names = SyntaxList(ctx)
typevar_stmts = SyntaxList(ctx)
# name in `typevar_names`, to be emitted prior to uses of `typevar_names`.
# There is exactly one statement for each typevar.
function expand_typevars!(ctx, typevar_names, typevar_stmts, type_params)
    # For each entry of `type_params`, append its name to `typevar_names` and
    # exactly one defining statement to `typevar_stmts`, so the two lists stay
    # index-aligned. Both lists are mutated in place; returns `nothing`.
    for param in type_params
        bounds = analyze_typevar(ctx, param)
        n = bounds[1]
        push!(typevar_names, n)
        # The `local` declaration and the `TypeVar` assignment are wrapped in
        # one block. Pushing them as separate statements as well would leave
        # several statements per name and break the alignment that callers
        # rely on.
        push!(typevar_stmts, @ast ctx param [K"block"
            [K"local" n]
            [K"=" n bounds_to_TypeVar(ctx, param, bounds)]
        ])
    end
    return nothing
end

# Non-mutating front end to `expand_typevars!`: allocate fresh lists, fill them
# from `type_params` and return them as `(typevar_names, typevar_stmts)`.
function expand_typevars(ctx, type_params)
    names = SyntaxList(ctx)
    stmts = SyntaxList(ctx)
    expand_typevars!(ctx, names, stmts, type_params)
    return (names, stmts)
end

Expand Down
Loading

0 comments on commit 760de66

Please sign in to comment.