Skip to content

Commit

Permalink
allow Module.SubModule.func and similar (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkrumbiegel authored Mar 8, 2021
1 parent ebb7223 commit 738584d
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 12 deletions.
70 changes: 58 additions & 12 deletions src/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,81 @@ is_aside(x::Expr) = x.head == :macrocall && x.args[1] == Symbol("@aside")


insert_first_arg(symbol::Symbol, firstarg) = Expr(:call, symbol, firstarg)
insert_first_arg(any, firstarg) = error("Can't insert an argument to $any. Needs to be a Symbol or a call expression")
insert_first_arg(any, firstarg) = insertionerror(any)

function insertionerror(expr)
error(
"""Can't insert a first argument into:
$expr.
First argument insertion works with expressions like these, where [Module.SubModule.] is optional:
[Module.SubModule.]func
[Module.SubModule.]func(args...)
[Module.SubModule.]func(args...; kwargs...)
[Module.SubModule.]@macro
[Module.SubModule.]@macro(args...)
@. [Module.SubModule.]func
"""
)
end

is_moduled_symbol(x) = false
function is_moduled_symbol(e::Expr)
e.head == :. &&
length(e.args) == 2 &&
(e.args[1] isa Symbol || is_moduled_symbol(e.args[1])) &&
e.args[2] isa QuoteNode &&
e.args[2].value isa Symbol
end

function insert_first_arg(e::Expr, firstarg)
head = e.head
args = e.args

# f(a, b) --> f(firstarg, a, b)
if head == :call && length(args) > 0
# Module.SubModule.symbol
if is_moduled_symbol(e)
Expr(:call, e, firstarg)

# f(args...) --> f(firstarg, args...)
elseif head == :call && length(args) > 0
if length(args) 2 && Meta.isexpr(args[2], :parameters)
Expr(head, args[1:2]..., firstarg, args[3:end]...)
else
Expr(head, args[1], firstarg, args[2:end]...)
end
# f.(a, b) --> f.(firstarg, a, b)
elseif head == :. && length(args) > 1 &&
args[1] isa Symbol && args[2] isa Expr && args[2].head == :tuple

# f.(args...) --> f.(firstarg, args...)
elseif head == :. &&
length(args) > 1 &&
args[1] isa Symbol &&
args[2] isa Expr &&
args[2].head == :tuple

Expr(head, args[1], Expr(args[2].head, firstarg, args[2].args...))
# @. somesymbol --> somesymbol.(firstarg)
elseif head == :macrocall && length(args) == 3 && args[1] == Symbol("@__dot__") &&
args[2] isa LineNumberNode && args[3] isa Symbol

# @. [Module.SubModule.]somesymbol --> somesymbol.(firstarg)
elseif head == :macrocall &&
length(args) == 3 &&
args[1] == Symbol("@__dot__") &&
args[2] isa LineNumberNode &&
(is_moduled_symbol(args[3]) || args[3] isa Symbol)

Expr(:., args[3], Expr(:tuple, firstarg))

# @macro(a, b) --> @macro(firstarg, a, b)
elseif head == :macrocall && args[1] isa Symbol && args[2] isa LineNumberNode
# @macro(args...) --> @macro(firstarg, args...)
elseif head == :macrocall &&
(is_moduled_symbol(args[1]) || args[1] isa Symbol) &&
args[2] isa LineNumberNode

if args[1] == Symbol("@__dot__")
error("You can only use the @. macro and automatic first argument insertion if what follows is of the form `[Module.SubModule.]func`")
end

Expr(head, args[1], args[2], firstarg, args[3:end]...)

else
error("Can't prepend first arg to expression $e that isn't a call.")
insertionerror(e)
end
end

Expand Down
87 changes: 87 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,90 @@ end
end
@test y == 3
end

module LocalModule
function square(xs)
xs .^ 2
end

function power(xs, pow)
xs .^ pow
end

add_one(x) = x + 1

macro sin(exp)
:(sin($(esc(exp))))
end

macro broadcastminus(exp1, exp2)
:(broadcast(-, $(esc(exp1)), $(esc(exp2))))
end

module SubModule
function square(xs)
xs .^ 2
end

function power(xs, pow)
xs .^ pow
end

add_one(x) = x + 1

macro sin(exp)
:(sin($(esc(exp))))
end

macro broadcastminus(exp1, exp2)
:(broadcast(-, $(esc(exp1)), $(esc(exp2))))
end
end
end

@testset "Module qualification" begin

using .LocalModule

xs = [1, 2, 3]
pow = 4
y = @chain xs begin
LocalModule.square
LocalModule.power(pow)
Base.sum
end
@test y == sum(LocalModule.power(LocalModule.square(xs), pow))

y2 = @chain xs begin
LocalModule.SubModule.square
LocalModule.SubModule.power(pow)
Base.sum
end
@test y == sum(LocalModule.SubModule.power(LocalModule.SubModule.square(xs), pow))

y3 = @chain xs begin
@. LocalModule.add_one
@. LocalModule.SubModule.add_one
end
@test y3 == LocalModule.SubModule.add_one.(LocalModule.add_one.(xs))

y4 = @chain xs begin
LocalModule.@broadcastminus(2.5)
end
@test y4 == LocalModule.@broadcastminus(xs, 2.5)

y5 = @chain xs begin
LocalModule.SubModule.@broadcastminus(2.5)
end
@test y5 == LocalModule.SubModule.@broadcastminus(xs, 2.5)

y6 = @chain 3 begin
LocalModule.@sin
end
@test y6 == LocalModule.@sin(3)

y7 = @chain 3 begin
LocalModule.SubModule.@sin
end
@test y7 == LocalModule.SubModule.@sin(3)
end

0 comments on commit 738584d

Please sign in to comment.