From 738584dfbaa1a51844b12c048a9bd2e56c6baca3 Mon Sep 17 00:00:00 2001 From: jkrumbiegel <22495855+jkrumbiegel@users.noreply.github.com> Date: Mon, 8 Mar 2021 11:32:20 +0100 Subject: [PATCH] allow Module.SubModule.func and similar (#21) --- src/Chain.jl | 70 +++++++++++++++++++++++++++++++------- test/runtests.jl | 87 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 12 deletions(-) diff --git a/src/Chain.jl b/src/Chain.jl index f39036a..7cd6139 100644 --- a/src/Chain.jl +++ b/src/Chain.jl @@ -7,35 +7,81 @@ is_aside(x::Expr) = x.head == :macrocall && x.args[1] == Symbol("@aside") insert_first_arg(symbol::Symbol, firstarg) = Expr(:call, symbol, firstarg) -insert_first_arg(any, firstarg) = error("Can't insert an argument to $any. Needs to be a Symbol or a call expression") +insert_first_arg(any, firstarg) = insertionerror(any) + +function insertionerror(expr) + error( + """Can't insert a first argument into: + $expr. + + First argument insertion works with expressions like these, where [Module.SubModule.] is optional: + + [Module.SubModule.]func + [Module.SubModule.]func(args...) + [Module.SubModule.]func(args...; kwargs...) + [Module.SubModule.]@macro + [Module.SubModule.]@macro(args...) + @. [Module.SubModule.]func + """ + ) +end + +is_moduled_symbol(x) = false +function is_moduled_symbol(e::Expr) + e.head == :. && + length(e.args) == 2 && + (e.args[1] isa Symbol || is_moduled_symbol(e.args[1])) && + e.args[2] isa QuoteNode && + e.args[2].value isa Symbol +end function insert_first_arg(e::Expr, firstarg) head = e.head args = e.args - # f(a, b) --> f(firstarg, a, b) - if head == :call && length(args) > 0 + # Module.SubModule.symbol + if is_moduled_symbol(e) + Expr(:call, e, firstarg) + + # f(args...) --> f(firstarg, args...) + elseif head == :call && length(args) > 0 if length(args) ≥ 2 && Meta.isexpr(args[2], :parameters) Expr(head, args[1:2]..., firstarg, args[3:end]...) else Expr(head, args[1], firstarg, args[2:end]...) end - # f.(a, b) --> f.(firstarg, a, b) - elseif head == :. && length(args) > 1 && - args[1] isa Symbol && args[2] isa Expr && args[2].head == :tuple + + # f.(args...) --> f.(firstarg, args...) + elseif head == :. && + length(args) > 1 && + args[1] isa Symbol && + args[2] isa Expr && + args[2].head == :tuple Expr(head, args[1], Expr(args[2].head, firstarg, args[2].args...)) - # @. somesymbol --> somesymbol.(firstarg) - elseif head == :macrocall && length(args) == 3 && args[1] == Symbol("@__dot__") && - args[2] isa LineNumberNode && args[3] isa Symbol + + # @. [Module.SubModule.]somesymbol --> somesymbol.(firstarg) + elseif head == :macrocall && + length(args) == 3 && + args[1] == Symbol("@__dot__") && + args[2] isa LineNumberNode && + (is_moduled_symbol(args[3]) || args[3] isa Symbol) + Expr(:., args[3], Expr(:tuple, firstarg)) - # @macro(a, b) --> @macro(firstarg, a, b) - elseif head == :macrocall && args[1] isa Symbol && args[2] isa LineNumberNode + # @macro(args...) --> @macro(firstarg, args...) + elseif head == :macrocall && + (is_moduled_symbol(args[1]) || args[1] isa Symbol) && + args[2] isa LineNumberNode + + if args[1] == Symbol("@__dot__") + error("You can only use the @. macro and automatic first argument insertion if what follows is of the form `[Module.SubModule.]func`") + end + Expr(head, args[1], args[2], firstarg, args[3:end]...) else - error("Can't prepend first arg to expression $e that isn't a call.") + insertionerror(e) end end diff --git a/test/runtests.jl b/test/runtests.jl index c2a20f0..7f44162 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -260,3 +260,90 @@ end end @test y == 3 end + +module LocalModule + function square(xs) + xs .^ 2 + end + + function power(xs, pow) + xs .^ pow + end + + add_one(x) = x + 1 + + macro sin(exp) + :(sin($(esc(exp)))) + end + + macro broadcastminus(exp1, exp2) + :(broadcast(-, $(esc(exp1)), $(esc(exp2)))) + end + + module SubModule + function square(xs) + xs .^ 2 + end + + function power(xs, pow) + xs .^ pow + end + + add_one(x) = x + 1 + + macro sin(exp) + :(sin($(esc(exp)))) + end + + macro broadcastminus(exp1, exp2) + :(broadcast(-, $(esc(exp1)), $(esc(exp2)))) + end + end +end + +@testset "Module qualification" begin + + using .LocalModule + + xs = [1, 2, 3] + pow = 4 + y = @chain xs begin + LocalModule.square + LocalModule.power(pow) + Base.sum + end + @test y == sum(LocalModule.power(LocalModule.square(xs), pow)) + + y2 = @chain xs begin + LocalModule.SubModule.square + LocalModule.SubModule.power(pow) + Base.sum + end + @test y == sum(LocalModule.SubModule.power(LocalModule.SubModule.square(xs), pow)) + + y3 = @chain xs begin + @. LocalModule.add_one + @. LocalModule.SubModule.add_one + end + @test y3 == LocalModule.SubModule.add_one.(LocalModule.add_one.(xs)) + + y4 = @chain xs begin + LocalModule.@broadcastminus(2.5) + end + @test y4 == LocalModule.@broadcastminus(xs, 2.5) + + y5 = @chain xs begin + LocalModule.SubModule.@broadcastminus(2.5) + end + @test y5 == LocalModule.SubModule.@broadcastminus(xs, 2.5) + + y6 = @chain 3 begin + LocalModule.@sin + end + @test y6 == LocalModule.@sin(3) + + y7 = @chain 3 begin + LocalModule.SubModule.@sin + end + @test y7 == LocalModule.SubModule.@sin(3) +end \ No newline at end of file