diff --git a/Project.toml b/Project.toml index 9f8b4bfc9..bb646a03f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,4 +1,4 @@ -name = "Taped" +name = "Phi" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt and contributors"] version = "0.1.0" @@ -21,7 +21,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] -TapedSpecialFunctionsExt = "SpecialFunctions" +PhiSpecialFunctionsExt = "SpecialFunctions" [compat] BenchmarkTools = "1" diff --git a/bench/Project.toml b/bench/Project.toml index c124eb895..dc1f40799 100644 --- a/bench/Project.toml +++ b/bench/Project.toml @@ -9,7 +9,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Taped = "07d77754-e150-4737-8c94-cd238a1fb45b" +Phi = "07d77754-e150-4737-8c94-cd238a1fb45b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" diff --git a/bench/README.md b/bench/README.md index 39a5a07c2..28c28a74a 100644 --- a/bench/README.md +++ b/bench/README.md @@ -1,10 +1,10 @@ # Benchmarking There are two flavours of benchmarks implemented in `run_benchmarks.jl`. -One is a set of pass / fail tests designed to check that large performance regressions are avoided in Taped. -The other is a set of comparisons between a variety of frameworks -- this set is designed to give a rough sense of where Taped stands in comparison to other AD frameworks, and the results should not be thought of as pass / fail tests. +One is a set of pass / fail tests designed to check that large performance regressions are avoided in Phi. +The other is a set of comparisons between a variety of frameworks -- this set is designed to give a rough sense of where Phi stands in comparison to other AD frameworks, and the results should not be thought of as pass / fail tests. -## Taped-Only Benchmarking +## Phi-Only Benchmarking The benchmarking runs as part of CI, and evaluates a sequence of pass / fail tests. @@ -21,9 +21,9 @@ to run the benchmarks which test the performance of AD. This will produce a `Dat containing a run-down of the results. It has the following columns: 1. `tag`: a `String` with an automatically generated name for the test 1. `primal_time`: the time taken to run the original code -1. `taped_time`: the time is takes Taped to AD the code -1. `Taped`: `taped_time / primal_time` -1. `range`: a named tuple with fields `lb` and `ub` specifying the acceptable range of values for `Taped`. +1. `phi_time`: the time is takes Phi to AD the code +1. `Phi`: `phi_time / primal_time` +1. `range`: a named tuple with fields `lb` and `ub` specifying the acceptable range of values for `Phi`. From here you can look at whatever properties of the results you are interested in. @@ -33,7 +33,7 @@ Note that the types of all of the columns are very simple, so it is fine to writ CSV.write("file_name.csv", df) ``` -Additionally, the convenience function `plot_ratio_histogram!` can be used to produce a histogram of `Taped` with formatting which is suited to this field. Call it as follows: +Additionally, the convenience function `plot_ratio_histogram!` can be used to produce a histogram of `Phi` with formatting which is suited to this field. Call it as follows: ```julia derived_results = benchmark_derived_rrules!!(Xoshiro) df = DataFrame(df) @@ -42,7 +42,7 @@ plot_ratio_histogram!(df) ## Inter-framework Benchmarking -This comprises a small suite of functions that we AD using Taped, Zygote, ReverseDiff, and Enzyme. This suite of benchmarks is also run as part of CI, and the output is recorded in two ways: +This comprises a small suite of functions that we AD using Phi, Zygote, ReverseDiff, and Enzyme. This suite of benchmarks is also run as part of CI, and the output is recorded in two ways: 1. a table of results is posted as comment in a PR 1. the table and a corresponding graph are stored as github actions artifacts, and can be retrieved by going to the "Checks" tab of your PR, and clicking on the artifact button. diff --git a/bench/run_benchmarks.jl b/bench/run_benchmarks.jl index 7b5d30297..ebaf75024 100644 --- a/bench/run_benchmarks.jl +++ b/bench/run_benchmarks.jl @@ -13,21 +13,21 @@ using PrettyTables, Random, ReverseDiff, - Taped, + Phi, Test, Turing, Zygote -using Taped: +using Phi: CoDual, generate_hand_written_rrule!!_test_cases, generate_derived_rrule!!_test_cases, InterpretedFunction, TestUtils, - TInterp, + PInterp, _typeof -using Taped.TestUtils: _deepcopy, to_benchmark +using Phi.TestUtils: _deepcopy, to_benchmark function zygote_to_benchmark(ctx, x::Vararg{Any, N}) where {N} out, pb = Zygote._pullback(ctx, x...) @@ -84,7 +84,7 @@ _broadcast_sin_cos_exp(x::AbstractArray{<:Real}) = sum(sin.(cos.(exp.(x)))) # about all of the operations. _simple_mlp(W2, W1, Y, X) = sum(abs2, Y - W2 * map(x -> x * (0 <= x), W1 * X)) -# Only Zygote and Taped can actually handle this. Note that Taped only has rules for BLAS +# Only Zygote and Phi can actually handle this. Note that Phi only has rules for BLAS # and LAPACK stuff, not explicit rules for things like the squared euclidean distance. # Consequently, Zygote is at a major advantage. _gp_lml(x, y, s) = logpdf(GP(SEKernel())(x, s), y) @@ -174,12 +174,12 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo evals=1, ) - # Benchmark AD via Taped. - @info "taped" - rule = Taped.build_rrule(args...) + # Benchmark AD via Phi. + @info "phi" + rule = Phi.build_rrule(args...) coduals = map(x -> x isa CoDual ? x : zero_codual(x), args) to_benchmark(rule, coduals...) - suite["taped"] = @benchmark(to_benchmark($rule, $coduals...)) + suite["phi"] = @benchmark(to_benchmark($rule, $coduals...)) if include_other_frameworks @@ -219,16 +219,16 @@ end function combine_results(result, tag, _range, default_range) d = result[2] primal_time = time(minimum(d["primal"])) - taped_time = time(minimum(d["taped"])) + phi_time = time(minimum(d["phi"])) zygote_time = in("zygote", keys(d)) ? time(minimum(d["zygote"])) : missing rd_time = in("rd", keys(d)) ? time(minimum(d["rd"])) : missing ez_time = in("enzyme", keys(d)) ? time(minimum(d["enzyme"])) : missing - fallback_tag = string((result[1][1], map(Taped._typeof, result[1][2:end])...)) + fallback_tag = string((result[1][1], map(Phi._typeof, result[1][2:end])...)) return ( tag=tag === nothing ? fallback_tag : tag, primal_time=primal_time, - taped_time=taped_time, - Taped=taped_time / primal_time, + phi_time=phi_time, + Phi=phi_time / primal_time, zygote_time=zygote_time, Zygote=zygote_time / primal_time, rd_time=rd_time, @@ -283,7 +283,7 @@ end function flag_concerning_performance(ratios) @testset "detect concerning performance" begin @testset for ratio in ratios - @test ratio.range.lb < ratio.Taped < ratio.range.ub + @test ratio.range.lb < ratio.Phi < ratio.range.ub end end end @@ -291,18 +291,18 @@ end """ plot_ratio_histogram!(df::DataFrame) -Constructs a histogram of the `taped_ratio` field of `df`, with formatting that is +Constructs a histogram of the `phi_ratio` field of `df`, with formatting that is well-suited to the numbers typically found in this field. """ function plot_ratio_histogram!(df::DataFrame) bin = 10.0 .^ (0.0:0.05:6.0) xlim = extrema(bin) - histogram(df.Taped; xscale=:log10, xlim, bin, title="log", label="") + histogram(df.Phi; xscale=:log10, xlim, bin, title="log", label="") end function create_inter_ad_benchmarks() results = benchmark_inter_framework_rules() - tools = [:Taped, :Zygote, :ReverseDiff, :Enzyme] + tools = [:Phi, :Zygote, :ReverseDiff, :Enzyme] df = DataFrame(results)[:, [:tag, tools...]] # Plot graph of results. diff --git a/ext/TapedSpecialFunctionsExt.jl b/ext/PhiSpecialFunctionsExt.jl similarity index 62% rename from ext/TapedSpecialFunctionsExt.jl rename to ext/PhiSpecialFunctionsExt.jl index b39449efc..68e02822e 100644 --- a/ext/TapedSpecialFunctionsExt.jl +++ b/ext/PhiSpecialFunctionsExt.jl @@ -1,8 +1,8 @@ -module TapedSpecialFunctionsExt +module PhiSpecialFunctionsExt - using SpecialFunctions, Taped + using SpecialFunctions, Phi - import Taped: @from_rrule, DefaultCtx + import Phi: @from_rrule, DefaultCtx @from_rrule DefaultCtx Tuple{typeof(airyai), Float64} @from_rrule DefaultCtx Tuple{typeof(airyaix), Float64} diff --git a/src/Taped.jl b/src/Phi.jl similarity index 99% rename from src/Taped.jl rename to src/Phi.jl index 1485685be..0f5b578f6 100644 --- a/src/Taped.jl +++ b/src/Phi.jl @@ -1,4 +1,4 @@ -module Taped +module Phi const CC = Core.Compiler diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 58c4ecbe9..3d62f2624 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -4,7 +4,7 @@ __increment_shim!!(x, y) = increment!!(x, y) """ @from_rrule ctx sig -Creates a `Taped.rrule!!` from a `ChainRulesCore.rrule`. `ctx` is the type of the context in +Creates a `Phi.rrule!!` from a `ChainRulesCore.rrule`. `ctx` is the type of the context in which this rule should apply, and `sig` is the type-tuple which specifies which primal the rule should apply to. @@ -12,7 +12,7 @@ For example, ```julia @from_rrule DefaultCtx Tuple{typeof(sin), Float64} ``` -would define a `Taped.rrule!!` for `sin` of `Float64`s, by calling `ChainRulesCore.rrule`. +would define a `Phi.rrule!!` for `sin` of `Float64`s, by calling `ChainRulesCore.rrule`. Health warning: Use this function with care. It has only been tested for `Float64` arguments and arguments @@ -29,13 +29,13 @@ macro from_rrule(ctx, sig) arg_type_symbols = sig.args[2:end] arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols)) - arg_types = map(t -> :(Taped.CoDual{<:$t}), arg_type_symbols) + arg_types = map(t -> :(Phi.CoDual{<:$t}), arg_type_symbols) arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) call_rrule = Expr( :call, - :(Taped.ChainRulesCore.rrule), - map(n -> :(Taped.primal($n)), arg_names)..., + :(Phi.ChainRulesCore.rrule), + map(n -> :(Phi.primal($n)), arg_names)..., ) pb_arg_names = map(n -> Symbol("dx_$(n)"), eachindex(arg_names)) @@ -45,7 +45,7 @@ macro from_rrule(ctx, sig) incrementers = Expr( :tuple, map(pb_arg_names, pb_output_names) do a, b - :(Taped.__increment_shim!!($a, $b)) + :(Phi.__increment_shim!!($a, $b)) end..., ) @@ -62,18 +62,18 @@ macro from_rrule(ctx, sig) rule_expr = ExprTools.combinedef( Dict( :head => :function, - :name => :(Taped.rrule!!), + :name => :(Phi.rrule!!), :args => arg_exprs, :body => quote y, pb = $call_rrule $pb - return Taped.zero_codual(y), pb!! + return Phi.zero_codual(y), pb!! end, ) ) ex = quote - Taped.is_primitive(::Type{$ctx}, ::Type{$sig}) = true + Phi.is_primitive(::Type{$ctx}, ::Type{$sig}) = true $rule_expr end return esc(ex) diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index dad50c93b..a8b8df135 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -11,7 +11,7 @@ end TICache() = TICache(IdDict{Core.MethodInstance, Core.CodeInstance}()) -struct TapedInterpreter{C} <: CC.AbstractInterpreter +struct PhiInterpreter{C} <: CC.AbstractInterpreter meta # additional information world::UInt inf_params::CC.InferenceParams @@ -19,7 +19,7 @@ struct TapedInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult} code_cache::TICache oc_cache::Dict{Any, Any} - function TapedInterpreter( + function PhiInterpreter( ::Type{C}; meta=nothing, world::UInt=Base.get_world_counter(), @@ -32,15 +32,15 @@ struct TapedInterpreter{C} <: CC.AbstractInterpreter end end -TapedInterpreter() = TapedInterpreter(DefaultCtx) +PhiInterpreter() = PhiInterpreter(DefaultCtx) -const TInterp = TapedInterpreter +const PInterp = PhiInterpreter -CC.InferenceParams(interp::TInterp) = interp.inf_params -CC.OptimizationParams(interp::TInterp) = interp.opt_params -CC.get_world_counter(interp::TInterp) = interp.world -CC.get_inference_cache(interp::TInterp) = interp.inf_cache -function CC.code_cache(interp::TInterp) +CC.InferenceParams(interp::PInterp) = interp.inf_params +CC.OptimizationParams(interp::PInterp) = interp.opt_params +CC.get_world_counter(interp::PInterp) = interp.world +CC.get_inference_cache(interp::PInterp) = interp.inf_cache +function CC.code_cache(interp::PInterp) return CC.WorldView(interp.code_cache, CC.WorldRange(interp.world)) end function CC.get(wvc::CC.WorldView{TICache}, mi::Core.MethodInstance, default) @@ -62,7 +62,7 @@ _type(x::CC.PartialStruct) = x.typ _type(x::CC.Conditional) = Union{x.thentype, x.elsetype} function CC.inlining_policy( - interp::TapedInterpreter{C}, + interp::PhiInterpreter{C}, @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::UInt8, @@ -85,4 +85,4 @@ function CC.inlining_policy( ) end -context_type(::TInterp{C}) where {C} = C +context_type(::PInterp{C}) where {C} = C diff --git a/src/interpreter/contexts.jl b/src/interpreter/contexts.jl index 9226f8812..edebdd179 100644 --- a/src/interpreter/contexts.jl +++ b/src/interpreter/contexts.jl @@ -51,5 +51,5 @@ is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true You should implemented more complicated method of `is_primitive` in the usual way. """ macro is_primitive(Tctx, sig) - return esc(:(Taped.is_primitive(::Type{$Tctx}, ::Type{<:$sig}) = true)) + return esc(:(Phi.is_primitive(::Type{$Tctx}, ::Type{<:$sig}) = true)) end diff --git a/src/interpreter/interpreted_function.jl b/src/interpreter/interpreted_function.jl index 40fc3eb5a..529bf4fc3 100644 --- a/src/interpreter/interpreted_function.jl +++ b/src/interpreter/interpreted_function.jl @@ -343,7 +343,7 @@ struct InterpretedFunction{sig<:Tuple, C, Treturn, Targ_info<:ArgInfo} bb_starts::Vector{Int} bb_ends::Vector{Int} ir::IRCode - interp::TapedInterpreter + interp::PhiInterpreter spnames::Any end @@ -417,7 +417,7 @@ end Construct a data structure which can be used to execute the instruction specified by `sig`. For example, ```julia -in_f = InterpretedFunction(DefaultCtx(), Tuple{typeof(sin), Float64}, Taped.TInterp()) +in_f = InterpretedFunction(DefaultCtx(), Tuple{typeof(sin), Float64}, Phi.PInterp()) in_f(sin, 5.0) ``` will yield exactly the same result as running `sin(5.0)`. The advantage of this data @@ -458,7 +458,7 @@ and each `Argument` / `SSAValue` in the IR with a (heap-allocated) `AbstractSlot most part, these slots are `Ref`s). While the details of what each kind of `OpaqueClosure` can be found in the corresponding -`Taped.build_inst` method, they generally have the following structure: +`Phi.build_inst` method, they generally have the following structure: - load data from argument / ssa slots, - do computation, - write result to the instruction's ssa slot, @@ -566,7 +566,7 @@ end # `InterpretedFunction`s operate recursively -- if the types associated to the `args` field # of a `:call` expression have not been inferred successfully, then we must wait until # runtime to determine what code to run. The `DelayedInterpretedFunction` does exactly this. -struct DelayedInterpretedFunction{C, Tlocal_cache, T<:TapedInterpreter} +struct DelayedInterpretedFunction{C, Tlocal_cache, T<:PhiInterpreter} ctx::C local_cache::Tlocal_cache interp::T diff --git a/src/interpreter/ir_normalisation.jl b/src/interpreter/ir_normalisation.jl index 0e8715d03..014b04235 100644 --- a/src/interpreter/ir_normalisation.jl +++ b/src/interpreter/ir_normalisation.jl @@ -4,9 +4,9 @@ Apply a sequence of standardising transformations to `ir` which leaves its semantics unchanged, but makes AD more straightforward. In particular, replace 1. `:invoke` `Expr`s with `:call`s, -2. `:foreigncall` `Expr`s with `:call`s to `Taped._foreigncall_`, -3. `:new` `Expr`s with `:call`s to `Taped._new_`, -4. `Core.IntrinsicFunction`s with counterparts from `Taped.IntrinsicWrappers`, +2. `:foreigncall` `Expr`s with `:call`s to `Phi._foreigncall_`, +3. `:new` `Expr`s with `:call`s to `Phi._new_`, +4. `Core.IntrinsicFunction`s with counterparts from `Phi.IntrinsicWrappers`, 5. `getfield(x, 1)` with `lgetfield(x, Val(1))`, and related transformations. `spnames` are the names associated to the static parameters of `ir`. These are needed when @@ -15,7 +15,7 @@ static parameter names have been translated into either types, or `:static_param expressions. Unfortunately, the static parameter names are not retained in `IRCode`, and the `Method` -from which the `IRCode` is derived must be consulted. `Taped.is_vararg_sig_and_sparam_names` +from which the `IRCode` is derived must be consulted. `Phi.is_vararg_sig_and_sparam_names` provides a convenient way to do this. """ function normalise!(ir::IRCode, spnames::Vector{Symbol}) @@ -34,12 +34,12 @@ end foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState}) If `inst` is a `:foreigncall` expression translate it into an equivalent `:call` expression. -If anything else, just return `inst`. See `Taped._foreigncall_` for details. +If anything else, just return `inst`. See `Phi._foreigncall_` for details. `sp_map` maps the names of the static parameters to their values. This function is intended to be called in the context of an `IRCode`, in which case the values of `sp_map` are given by the `sptypes` field of said `IRCode`. The keys should generally be obtained from the -`Method` from which the `IRCode` is derived. See `Taped.normalise!` for more details. +`Method` from which the `IRCode` is derived. See `Phi.normalise!` for more details. """ function foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState}) if Meta.isexpr(inst, :foreigncall) @@ -109,7 +109,7 @@ end """ new_to_call(x) -If instruction `x` is a `:new` expression, replace if with a `:call` to `Taped._new_`. +If instruction `x` is a `:new` expression, replace if with a `:call` to `Phi._new_`. Otherwise, return `x`. """ new_to_call(x) = Meta.isexpr(x, :new) ? Expr(:call, _new_, x.args...) : x @@ -118,7 +118,7 @@ new_to_call(x) = Meta.isexpr(x, :new) ? Expr(:call, _new_, x.args...) : x intrinsic_to_function(inst) If `inst` is a `:call` expression to a `Core.IntrinsicFunction`, replace it with a call to -the corresponding `function` from `Taped.IntrinsicsWrappers`, else return `inst`. +the corresponding `function` from `Phi.IntrinsicsWrappers`, else return `inst`. `cglobal` is a special case -- it requires that its first argument be static in exactly the same way as `:foreigncall`. See `IntrinsicsWrappers.__cglobal` for more info. diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index 176949860..2e49525c1 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -22,7 +22,7 @@ function ircode( cfg = CC.compute_basic_blocks(insts) insts = __line_numbers_to_block_numbers!(insts, cfg) stmts = __insts_to_instruction_stream(insts) - linetable = [CC.LineInfoNode(Taped, :ircode, :ir_utils, Int32(1), Int32(0))] + linetable = [CC.LineInfoNode(Phi, :ircode, :ir_utils, Int32(1), Int32(0))] meta = Expr[] return CC.IRCode(stmts, cfg, linetable, argtypes, meta, CC.VarState[]) end @@ -89,7 +89,7 @@ the types in your IR are not being refined, you may wish to check that neither o things are happening. """ function infer_ir!(ir::IRCode) - return __infer_ir!(ir, CC.NativeInterpreter(), __get_toplevel_mi_from_ir(ir, Taped)) + return __infer_ir!(ir, CC.NativeInterpreter(), __get_toplevel_mi_from_ir(ir, Phi)) end # Sometimes types in `IRCode` have been replaced by constants, or partially-completed diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index b4b5c63b8..4d9fb5e64 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -76,7 +76,7 @@ This data structure is used to hold "global" information associated to a particu `build_rrule`. It is used as a means of communication between `make_ad_stmts!` and the codegen which produces the forwards- and reverse-passes. -- `interp`: a `TapedInterpreter`. +- `interp`: a `PhiInterpreter`. - `block_stack_id`: the ID associated to the block stack -- the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass to determine which blocks to visit. @@ -98,7 +98,7 @@ codegen which produces the forwards- and reverse-passes. `shared_data_pairs`. =# struct ADInfo - interp::TInterp + interp::PInterp block_stack_id::ID block_stack::Stack{Int32} entry_id::ID @@ -111,7 +111,7 @@ end # The constructor that you should use for ADInfo. function ADInfo( - interp::TInterp, + interp::PInterp, arg_types::Dict{Argument, Any}, ssa_insts::Dict{ID, NewInstruction}, arg_tangent_stacks, @@ -376,12 +376,12 @@ end get_const_primal_value(x::QuoteNode) = x.value get_const_primal_value(x) = x -# Taped does not yet handle `PhiCNode`s. Throw an error if one is encountered. +# Phi does not yet handle `PhiCNode`s. Throw an error if one is encountered. function make_ad_stmts!(stmt::Core.PhiCNode, ::ID, ::ADInfo) unhandled_feature("Encountered PhiCNode: $stmt") end -# Taped does not yet handle `UpsilonNode`s. Throw an error if one is encountered. +# Phi does not yet handle `UpsilonNode`s. Throw an error if one is encountered. function make_ad_stmts!(stmt::Core.UpsilonNode, ::ID, ::ADInfo) unhandled_feature("Encountered UpsilonNode: $stmt") end @@ -667,7 +667,7 @@ end # Compute the concrete type of the rule that will be returned from `build_rrule`. This is # important for performance in dynamic dispatch, and to ensure that recursion works # properly. -function rule_type(interp::TapedInterpreter{C}, ::Type{sig}) where {C, sig} +function rule_type(interp::PhiInterpreter{C}, ::Type{sig}) where {C, sig} is_primitive(C, sig) && return typeof(rrule!!) ir, _ = lookup_ir(interp, sig) @@ -736,16 +736,16 @@ end Helper method. Only uses static information from `args`. """ function build_rrule(args...) - return build_rrule(TapedInterpreter(), _typeof(TestUtils.__get_primals(args))) + return build_rrule(PhiInterpreter(), _typeof(TestUtils.__get_primals(args))) end """ - build_rrule(interp::TInterp{C}, sig::Type{<:Tuple}) where {C} + build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}) where {C} Returns a `DerivedRule` which is an `rrule!!` for `sig` in context `C`. See the docstring for `rrule!!` for more info. """ -function build_rrule(interp::TInterp{C}, sig::Type{<:Tuple}) where {C} +function build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}) where {C} # Reset id count. This ensures that everything in this function is deterministic. seed_id!() @@ -980,7 +980,7 @@ __switch_case(id::Int32, predecessor_id::Int32) = !(id === predecessor_id) #= - DynamicDerivedRule(interp::TapedInterpreter) + DynamicDerivedRule(interp::PhiInterpreter) For internal use only. @@ -994,7 +994,7 @@ struct DynamicDerivedRule{T, V} cache::V end -DynamicDerivedRule(interp::TapedInterpreter) = DynamicDerivedRule(interp, Dict{Any, Any}()) +DynamicDerivedRule(interp::PhiInterpreter) = DynamicDerivedRule(interp, Dict{Any, Any}()) function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N} sig = Tuple{map(_typeof, map(primal, args))...} @@ -1020,7 +1020,7 @@ mutable struct LazyDerivedRule{Trule, T, V} interp::T sig::V rule::Trule - function LazyDerivedRule(interp::T, sig::V) where {T<:TInterp, V<:Type{<:Tuple}} + function LazyDerivedRule(interp::T, sig::V) where {T<:PInterp, V<:Type{<:Tuple}} return new{rule_type(interp, sig), T, V}(interp, sig) end end diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index ba48cba2b..df28db2c7 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -61,7 +61,7 @@ for (fname, elty) in ((:cblas_ddot,:Float64), (:cblas_sdot,:Float32)) end for (fname, elty) in ((:dscal_, :Float64), (:sscal_, :Float32)) - @eval function Taped.rrule!!( + @eval function Phi.rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$(blas_name(fname))}}, ::CoDual, # return type diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index 8a0b77165..c2ec85bb8 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -15,8 +15,8 @@ module IntrinsicsWrappers using Core: Intrinsics -using Taped -import ..Taped: +using Phi +import ..Phi: rrule!!, CoDual, primal, tangent, zero_tangent, NoPullback, tangent_type, increment!!, @is_primitive, MinimalCtx, is_primitive @@ -110,17 +110,17 @@ pointer to, is known statically. In this regard it is like foreigncalls. As a consequence, it requires special handling. The name is converted into a `Val` so that it is available statically, and the function into which `cglobal` calls are converted is -named `Taped.IntrinsicsWrappers.__cglobal`, rather than `Taped.IntrinsicsWrappers.cglobal`. +named `Phi.IntrinsicsWrappers.__cglobal`, rather than `Phi.IntrinsicsWrappers.cglobal`. -If you examine the code associated with `Taped.intrinsic_to_function`, you will see that +If you examine the code associated with `Phi.intrinsic_to_function`, you will see that special handling of `cglobal` is used. =# __cglobal(::Val{s}, x::Vararg{Any, N}) where {s, N} = cglobal(s, x...) translate(::Val{Intrinsics.cglobal}) = __cglobal -Taped.is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(__cglobal), Vararg}}) = true +Phi.is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(__cglobal), Vararg}}) = true function rrule!!(::CoDual{typeof(__cglobal)}, args...) - return Taped.uninit_codual(__cglobal(map(primal, args)...)), NoPullback() + return Phi.uninit_codual(__cglobal(map(primal, args)...)), NoPullback() end @inactive_intrinsic checked_sadd_int diff --git a/src/test_utils.jl b/src/test_utils.jl index 5e3e0a7d0..eb520c3a8 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -11,7 +11,7 @@ module TestTypes using Base.Iterators: product using Core: svec using ExprTools: combinedef -using ..Taped: NoTangent, tangent_type, _typeof +using ..Phi: NoTangent, tangent_type, _typeof const PRIMALS = Tuple{Bool, Any}[] @@ -84,8 +84,8 @@ interfaces that this package defines have been implemented correctly. """ module TestUtils -using JET, Random, Taped, Test, InteractiveUtils -using Taped: +using JET, Random, Phi, Test, InteractiveUtils +using Phi: CoDual, NoTangent, rrule!!, is_init, zero_codual, DefaultCtx, @is_primitive, val, is_always_fully_initialised, get_tangent_field, set_tangent_field!, MutableTangent, Tangent, _typeof @@ -289,7 +289,7 @@ function test_rrule_interface(f_f̄, x_x̄...; is_primitive, ctx::C, rule) where # Verify that the function to which the rrule applies is considered a primitive. # It is not clear that this really belongs here to be frank. if is_primitive - @test Taped.is_primitive(C, _typeof((f, x...))) + @test Phi.is_primitive(C, _typeof((f, x...))) end # Run the primal programme. Bail out early if this doesn't work. @@ -446,7 +446,7 @@ function test_interpreted_rrule!!(rng::AbstractRNG, x...; interp, kwargs...) end function test_derived_rule(rng::AbstractRNG, x...; interp, kwargs...) - rule = Taped.build_rrule(interp, _typeof(__get_primals(x))) + rule = Phi.build_rrule(interp, _typeof(__get_primals(x))) test_rrule!!(rng, x...; rule, kwargs...) end @@ -621,19 +621,19 @@ end function test_set_tangent_field!_correctness(t1::T, t2::T) where {T<:MutableTangent} Tfields = _typeof(t1.fields) for n in 1:fieldcount(Tfields) - !Taped.is_init(t2.fields[n]) && continue + !Phi.is_init(t2.fields[n]) && continue v = get_tangent_field(t2, n) # Int form. - v′ = Taped.set_tangent_field!(t1, n, v) + v′ = Phi.set_tangent_field!(t1, n, v) @test v′ === v - @test Taped.get_tangent_field(t1, n) === v + @test Phi.get_tangent_field(t1, n) === v # Symbol form. s = fieldname(Tfields, n) - g = Taped.set_tangent_field!(t1, s, v) + g = Phi.set_tangent_field!(t1, s, v) @test g === v - @test Taped.get_tangent_field(t1, n) === v + @test Phi.get_tangent_field(t1, n) === v end end @@ -701,7 +701,7 @@ function test_set_tangent_field!_performance(t1::T, t2::T) where {V, T<:MutableT _set_tangent_field!(t1, Val(n), v) JET.@report_opt _set_tangent_field!(t1, Val(n), v) - if all(n -> !(fieldtype(V, n) <: Taped.PossiblyUninitTangent), 1:fieldcount(V)) + if all(n -> !(fieldtype(V, n) <: Phi.PossiblyUninitTangent), 1:fieldcount(V)) i = Val(n) _set_tangent_field!(t1, i, v) @test count_allocs(_set_tangent_field!, t1, i, v) == 0 @@ -712,7 +712,7 @@ function test_set_tangent_field!_performance(t1::T, t2::T) where {V, T<:MutableT @inferred _set_tangent_field!(t1, s, v) JET.@report_opt _set_tangent_field!(t1, s, v) - if all(n -> !(fieldtype(V, n) <: Taped.PossiblyUninitTangent), 1:fieldcount(V)) + if all(n -> !(fieldtype(V, n) <: Phi.PossiblyUninitTangent), 1:fieldcount(V)) _set_tangent_field!(t1, s, v) @test count_allocs(_set_tangent_field!, t1, s, v) == 0 end @@ -720,7 +720,7 @@ function test_set_tangent_field!_performance(t1::T, t2::T) where {V, T<:MutableT end function test_get_tangent_field_performance(t::Union{MutableTangent, Tangent}) - V = Taped._typeof(t.fields) + V = Phi._typeof(t.fields) for n in 1:fieldcount(V) !is_init(t.fields[n]) && continue @@ -757,7 +757,7 @@ __tangent_generation_should_allocate(::Type{P}) where {P<:Array} = true function __increment_should_allocate(::Type{P}) where {P} return any(eachindex(fieldtypes(P))) do n - Taped.tangent_field_type(P, n) <: PossiblyUninitTangent + Phi.tangent_field_type(P, n) <: PossiblyUninitTangent end end @@ -800,10 +800,10 @@ function test_tangent(rng::AbstractRNG, p::P, z_target::T, x::T, y::T) where {P, @test Tt == _typeof(z) # Check that zero_tangent is deterministic. - @test has_equal_data(z, Taped.zero_tangent(p)) + @test has_equal_data(z, Phi.zero_tangent(p)) # Check that zero_tangent infers. - @test has_equal_data(z, @inferred Taped.zero_tangent(p)) + @test has_equal_data(z, @inferred Phi.zero_tangent(p)) # Verify that the zero tangent is zero via its action. zc = deepcopy(z) @@ -878,15 +878,15 @@ function test_equality_comparison(x) end function run_hand_written_rrule!!_test_cases(rng_ctor, v::Val) - test_cases, memory = Taped.generate_hand_written_rrule!!_test_cases(rng_ctor, v) + test_cases, memory = Phi.generate_hand_written_rrule!!_test_cases(rng_ctor, v) GC.@preserve memory @testset "$f, $(_typeof(x))" for (interface_only, perf_flag, _, f, x...) in test_cases test_rrule!!(rng_ctor(123), f, x...; interface_only, perf_flag) end end function run_derived_rrule!!_test_cases(rng_ctor, v::Val) - interp = Taped.TInterp() - test_cases, memory = Taped.generate_derived_rrule!!_test_cases(rng_ctor, v) + interp = Phi.PInterp() + test_cases, memory = Phi.generate_derived_rrule!!_test_cases(rng_ctor, v) GC.@preserve memory @testset "$f, $(typeof(x))" for (interface_only, perf_flag, _, f, x...) in test_cases test_derived_rule( @@ -906,7 +906,7 @@ function to_benchmark(__rrule!!::R, dx::Vararg{CoDual, N}) where {R, N} end """ - set_up_gradient_problem(fargs...; interp=Taped.TInterp()) + set_up_gradient_problem(fargs...; interp=Phi.PInterp()) Constructs a `rule` and `InterpretedFunction` which can be passed to `value_and_gradient!!`. @@ -914,8 +914,8 @@ For example: ```julia f(x) = sum(abs2, x) x = randn(25) -rule, in_f = Taped.TestUtils.set_up_gradient_problem(f, x) -y, dx = Taped.TestUtils.value_and_gradient!!(rule, in_f, f, x) +rule, in_f = Phi.TestUtils.set_up_gradient_problem(f, x) +y, dx = Phi.TestUtils.value_and_gradient!!(rule, in_f, f, x) ``` will yield the value and associated gradient for `f` and `x`. @@ -924,15 +924,15 @@ with the same `rule` and `in_f` arguments, but with different values of `x`. Optionally, an interpreter may be provided via the `interp` kwarg. -See also: `Taped.TestUtils.value_and_gradient!!`. +See also: `Phi.TestUtils.value_and_gradient!!`. """ -function set_up_gradient_problem(fargs...; interp=Taped.TInterp()) +function set_up_gradient_problem(fargs...; interp=Phi.PInterp()) sig = _typeof(__get_primals(fargs)) - if Taped.is_primitive(DefaultCtx, sig) - return rrule!!, Taped._eval + if Phi.is_primitive(DefaultCtx, sig) + return rrule!!, Phi._eval else - in_f = Taped.InterpretedFunction(DefaultCtx(), sig, interp) - return Taped.build_rrule!!(in_f), in_f + in_f = Phi.InterpretedFunction(DefaultCtx(), sig, interp) + return Phi.build_rrule!!(in_f), in_f end end @@ -949,8 +949,8 @@ AD anything. """ module TestResources -using ..Taped -using ..Taped: +using ..Phi +using ..Phi: CoDual, Tangent, MutableTangent, NoTangent, PossiblyUninitTangent, ircode, @is_primitive, MinimalCtx, val @@ -1357,7 +1357,7 @@ end @noinline edge_case_tester(x::Int) = 10 @noinline edge_case_tester(x::String) = "hi" @is_primitive MinimalCtx Tuple{typeof(edge_case_tester), Float64} -function Taped.rrule!!(::CoDual{typeof(edge_case_tester)}, x::CoDual{Float64}) +function Phi.rrule!!(::CoDual{typeof(edge_case_tester)}, x::CoDual{Float64}) edge_case_tester_pb!!(dy, df, dx) = df, dx + 5 * dy return CoDual(5 * primal(x), 0.0), edge_case_tester_pb!! end @@ -1379,7 +1379,7 @@ sr(n) = Xoshiro(n) return a < b ? a * b : test_self_reference(b, a) + a end -# See https://github.com/withbayes/Taped.jl/pull/84 for info +# See https://github.com/withbayes/Phi.jl/pull/84 for info @noinline function test_recursive_sum(x::Vector{Float64}) isempty(x) && return 0.0 return @inbounds x[1] + test_recursive_sum(x[2:end]) @@ -1589,7 +1589,7 @@ function _setfield!(value::MutableTangent, name, x) return x end -function Taped.rrule!!(::Taped.CoDual{typeof(my_setfield!)}, value, name, x) +function Phi.rrule!!(::Phi.CoDual{typeof(my_setfield!)}, value, name, x) _name = primal(name) old_x = isdefined(primal(value), _name) ? getfield(primal(value), _name) : nothing function setfield!_pullback(dy, df, dvalue, ::NoTangent, dx) @@ -1598,7 +1598,7 @@ function Taped.rrule!!(::Taped.CoDual{typeof(my_setfield!)}, value, name, x) old_x !== nothing && setfield!(primal(value), _name, old_x) return df, dvalue, NoTangent(), new_dx end - y = Taped.CoDual( + y = Phi.CoDual( setfield!(primal(value), _name, primal(x)), _setfield!(tangent(value), _name, tangent(x)), ) diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl index fea0f10d9..4547a067f 100644 --- a/test/chain_rules_macro.jl +++ b/test/chain_rules_macro.jl @@ -4,8 +4,8 @@ function ChainRulesCore.rrule(::typeof(bleh), x::Float64, y::Int) return x * y, dz -> (ChainRulesCore.NoTangent(), dz * y, ChainRulesCore.NoTangent()) end -Taped.@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} +Phi.@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} @testset "chain_rules_macro" begin - Taped.TestUtils.test_rrule!!(Xoshiro(1), bleh, 5.0, 4; perf_flag=:stability) + Phi.TestUtils.test_rrule!!(Xoshiro(1), bleh, 5.0, 4; perf_flag=:stability) end diff --git a/test/codual.jl b/test/codual.jl index 1f6a1afcf..20c5f8662 100644 --- a/test/codual.jl +++ b/test/codual.jl @@ -2,7 +2,7 @@ @test CoDual(5.0, 4.0) isa CoDual{Float64, Float64} @test CoDual(Float64, NoTangent()) isa CoDual{Type{Float64}, NoTangent} @test zero_codual(5.0) == CoDual(5.0, 0.0) - @test Taped.uninit_codual(5.0) == CoDual(5.0, 0.0) + @test Phi.uninit_codual(5.0) == CoDual(5.0, 0.0) @test codual_type(Float64) == CoDual{Float64, Float64} @test codual_type(Int) == CoDual{Int, NoTangent} @test codual_type(Real) == CoDual diff --git a/test/front_matter.jl b/test/front_matter.jl index 75436e41b..c37c64185 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -8,7 +8,7 @@ using Random, SpecialFunctions, StableRNGs, - Taped, + Phi, Test import ChainRulesCore @@ -19,7 +19,7 @@ using Core: bitcast, svec, ReturnNode, PhiNode, PiNode, GotoIfNot, GotoNode, SSAValue, Argument using Core.Intrinsics: pointerref, pointerset -using Taped: +using Phi: CC, IntrinsicsWrappers, TestUtils, @@ -81,6 +81,6 @@ const test_group = get(ENV, "TEST_GROUP", "basic") sr(n::Int) = StableRNG(n) # This is annoying and hacky and should be improved. -if isempty(Taped.TestTypes.PRIMALS) - Taped.TestTypes.generate_primals() +if isempty(Phi.TestTypes.PRIMALS) + Phi.TestTypes.generate_primals() end diff --git a/test/integration_testing/array.jl b/test/integration_testing/array.jl index 4eee88cd8..8fc20c436 100644 --- a/test/integration_testing/array.jl +++ b/test/integration_testing/array.jl @@ -502,7 +502,7 @@ _getter() = 5.0 for (interface_only, f, x...) in test_cases f(deepcopy(x)...) end - interp = Taped.TInterp() + interp = Phi.PInterp() @testset for (interface_only, f, x...) in test_cases @info _typeof((f, x...)) TestUtils.test_derived_rule( diff --git a/test/integration_testing/diff_tests.jl b/test/integration_testing/diff_tests.jl index 934877c07..7da695850 100644 --- a/test/integration_testing/diff_tests.jl +++ b/test/integration_testing/diff_tests.jl @@ -1,5 +1,5 @@ @testset "diff_tests" begin - interp = Taped.TInterp() + interp = Phi.PInterp() @testset "$f, $(_typeof(x))" for (n, (interface_only, f, x...)) in enumerate(vcat( TestResources.DIFFTESTS_FUNCTIONS[1:31], # SKIPPING SPARSE_LDIV mat2num_4 and softmax due to `_apply_iterate` handling TestResources.DIFFTESTS_FUNCTIONS[34:66], # SKIPPING SPARSE_LDIV diff --git a/test/integration_testing/distributions.jl b/test/integration_testing/distributions.jl index 34c9aced4..cd9a382d4 100644 --- a/test/integration_testing/distributions.jl +++ b/test/integration_testing/distributions.jl @@ -4,7 +4,7 @@ _sym(A) = A'A _pdmat(A) = PDMat(_sym(A) + 5I) @testset "distributions" begin - interp = Taped.TInterp() + interp = Phi.PInterp() @testset "$(typeof(d))" for (interface_only, d, x) in Any[ # diff --git a/test/integration_testing/gp.jl b/test/integration_testing/gp.jl index 7c0ef1896..a167c0aae 100644 --- a/test/integration_testing/gp.jl +++ b/test/integration_testing/gp.jl @@ -1,7 +1,7 @@ using AbstractGPs, KernelFunctions @testset "gp" begin - interp = Taped.TInterp() + interp = Phi.PInterp() base_kernels = Any[ ZeroKernel(), ConstantKernel(; c=1.0), diff --git a/test/integration_testing/misc.jl b/test/integration_testing/misc.jl index 8ce7362bd..517465e56 100644 --- a/test/integration_testing/misc.jl +++ b/test/integration_testing/misc.jl @@ -1,5 +1,5 @@ @testset "integration_testing" begin - interp = Taped.TInterp() + interp = Phi.PInterp() @testset for (interface_only, f, x...) in vcat( [ (false, getindex, randn(5), 4), diff --git a/test/integration_testing/turing.jl b/test/integration_testing/turing.jl index 17a636c08..0162c0f50 100644 --- a/test/integration_testing/turing.jl +++ b/test/integration_testing/turing.jl @@ -86,7 +86,7 @@ function build_turing_problem(rng, model, example=nothing) end @testset "turing" begin - interp = Taped.TInterp() + interp = Phi.PInterp() @testset "$(typeof(model))" for (interface_only, name, model, ex) in vcat( Any[ (false, "simple_model", simple_model(), nothing), diff --git a/test/interpreter/abstract_interpretation.jl b/test/interpreter/abstract_interpretation.jl index 4229e9f82..88b93c3e8 100644 --- a/test/interpreter/abstract_interpretation.jl +++ b/test/interpreter/abstract_interpretation.jl @@ -1,8 +1,8 @@ a_primitive(x) = sin(x) non_primitive(x) = sin(x) -Taped.is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(a_primitive), Any}}) = true -Taped.is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(non_primitive), Any}}) = false +Phi.is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(a_primitive), Any}}) = true +Phi.is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(non_primitive), Any}}) = false contains_primitive(x) = @inline a_primitive(x) contains_non_primitive(x) = @inline non_primitive(x) @@ -11,7 +11,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @testset "abstract_interpretation" begin # Check that inlining doesn't / does happen as expected. - @testset "TapedInterpreter" begin + @testset "PhiInterpreter" begin @testset "non-primitive continues to be inlined away" begin # A non-primitive is present in the IR for contains_non_primitive. It is @@ -25,7 +25,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @assert usual_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :sin) # Should continue to inline away under AD compilation. - interp = Taped.TapedInterpreter(DefaultCtx) + interp = Phi.PhiInterpreter(DefaultCtx) ad_ir = Base.code_ircode_by_type(sig; interp)[1][1] invoke_line = findfirst(x -> Meta.isexpr(x, :invoke), ad_ir.stmts.inst) @test ad_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :sin) @@ -42,7 +42,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @assert usual_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :sin) # Should not inline away under AD compilation. - interp = Taped.TapedInterpreter(DefaultCtx) + interp = Phi.PhiInterpreter(DefaultCtx) ad_ir = Base.code_ircode_by_type(sig; interp)[1][1] invoke_line = findfirst(x -> Meta.isexpr(x, :invoke), ad_ir.stmts.inst) @test ad_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :a_primitive) @@ -61,7 +61,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @assert usual_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :sin) # Should not inline away under AD compilation. - interp = Taped.TapedInterpreter(DefaultCtx) + interp = Phi.PhiInterpreter(DefaultCtx) ad_ir = Base.code_ircode_by_type(sig; interp)[1][1] invoke_line = findfirst(x -> Meta.isexpr(x, :invoke), ad_ir.stmts.inst) @test ad_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :a_primitive) diff --git a/test/interpreter/bbcode.jl b/test/interpreter/bbcode.jl index 415dfb2a1..b5125874d 100644 --- a/test/interpreter/bbcode.jl +++ b/test/interpreter/bbcode.jl @@ -28,7 +28,7 @@ end bb_copy = copy(bb) @test bb_copy.inst_ids !== bb.inst_ids - @test Taped.terminator(bb) === nothing + @test Phi.terminator(bb) === nothing end @testset "BBCode $f" for (f, P) in [ (TestResources.test_while_loop, Tuple{Float64}), @@ -39,14 +39,14 @@ end bb_code = BBCode(ir) @test bb_code isa BBCode @test length(bb_code.blocks) == length(ir.cfg.blocks) - new_ir = Taped.IRCode(bb_code) + new_ir = Phi.IRCode(bb_code) @test length(new_ir.stmts.inst) == length(ir.stmts.inst) @test all(map(==, ir.stmts.inst, new_ir.stmts.inst)) @test all(map(==, ir.stmts.type, new_ir.stmts.type)) @test all(map(==, ir.stmts.info, new_ir.stmts.info)) @test all(map(==, ir.stmts.line, new_ir.stmts.line)) @test all(map(==, ir.stmts.flag, new_ir.stmts.flag)) - @test length(Taped.collect_stmts(bb_code)) == length(ir.stmts.inst) - @test Taped.id_to_line_map(bb_code) isa Dict{ID, Int} + @test length(Phi.collect_stmts(bb_code)) == length(ir.stmts.inst) + @test Phi.id_to_line_map(bb_code) isa Dict{ID, Int} end end diff --git a/test/interpreter/interpreted_function.jl b/test/interpreter/interpreted_function.jl index 6b1574f7a..1f5491745 100644 --- a/test/interpreter/interpreted_function.jl +++ b/test/interpreter/interpreted_function.jl @@ -32,8 +32,8 @@ Any[Tuple{Float64, Tuple{Int}}, (5.0, 3), true], Any[Tuple{Float64, Tuple{Int, Float64}}, (5.0, 3, 4.0), true], ] - ai = Taped.ArgInfo(Tx, is_va) - @test @inferred Taped.load_args!(ai, x) === nothing + ai = Phi.ArgInfo(Tx, is_va) + @test @inferred Phi.load_args!(ai, x) === nothing end @testset "TypedPhiNode" begin @@ -44,27 +44,27 @@ (1, 2), (ConstSlot(5.0), SlotRef(4.0)), ) - Taped.store_tmp_value!(node, 1) + Phi.store_tmp_value!(node, 1) @test node.tmp_slot[] == 5.0 - Taped.transfer_tmp_value!(node) + Phi.transfer_tmp_value!(node) @test node.ret_slot[] == 5.0 - Taped.store_tmp_value!(node, 2) + Phi.store_tmp_value!(node, 2) @test node.tmp_slot[] == 4.0 @test node.ret_slot[] == 5.0 - Taped.transfer_tmp_value!(node) + Phi.transfer_tmp_value!(node) @test node.ret_slot[] == 4.0 end @testset "phi node with nothing in it" begin node = TypedPhiNode(SlotRef{Union{}}(), SlotRef{Union{}}(), (), ()) - Taped.store_tmp_value!(node, 1) - Taped.transfer_tmp_value!(node) + Phi.store_tmp_value!(node, 1) + Phi.transfer_tmp_value!(node) end @testset "phi node with undefined value" begin node = TypedPhiNode( SlotRef{Float64}(), SlotRef{Float64}(), (1, ), (SlotRef{Float64}(),) ) - Taped.store_tmp_value!(node, 1) - Taped.transfer_tmp_value!(node) + Phi.store_tmp_value!(node, 1) + Phi.transfer_tmp_value!(node) end end @@ -86,7 +86,7 @@ ] val, ret_slot = args oc = build_inst(ReturnNode, ret_slot, val) - @test oc isa Taped.Inst + @test oc isa Phi.Inst output = oc(0) @test output == -1 @test ret_slot[] == val[] @@ -95,7 +95,7 @@ @testset "GotoNode $label" for label in Any[1, 2, 3, 4, 5] oc = build_inst(GotoNode, label) - @test oc isa Taped.Inst + @test oc isa Phi.Inst @test oc(3) == label end @@ -108,7 +108,7 @@ TypedGlobalRef(GlobalRef(Main, :__global_bool)), ] oc = build_inst(GotoIfNot, cond, 1, 2) - @test oc isa Taped.Inst + @test oc isa Phi.Inst @test oc(5) == (cond[] ? 1 : 2) end @@ -119,7 +119,7 @@ (TypedGlobalRef(GlobalRef(Main, :__global_bool)), ConstSlot(true), 2, 2) ] oc = build_inst(PiNode, input, out, next_blk) - @test oc isa Taped.Inst + @test oc isa Phi.Inst @test oc(prev_blk) == next_blk @test out[] == input[] end @@ -130,7 +130,7 @@ (SlotRef{typeof(sin)}(), ConstSlot(sin), 4), ] oc = build_inst(GlobalRef, x, out, next_blk) - @test oc isa Taped.Inst + @test oc isa Phi.Inst @test oc(4) == next_blk @test out[] == x[] end @@ -139,7 +139,7 @@ (ConstSlot(5), SlotRef{Int}(), 5), ] oc = build_inst(nothing, x, out, next_blk) - @test oc isa Taped.Inst + @test oc isa Phi.Inst @test oc(1) == next_blk @test out[] == x[] end @@ -147,25 +147,25 @@ @testset "Val{:boundscheck}" begin val_ref = SlotRef{Bool}() oc = build_inst(Val(:boundscheck), val_ref, 3) - @test oc isa Taped.Inst + @test oc isa Phi.Inst @test oc(5) == 3 @test val_ref[] == true end global __int_output = 5 @testset "Val{:call}" for (arg_slots, evaluator, val_slot, next_blk) in Any[ - ((ConstSlot(sin), SlotRef(5.0)), Taped._eval, SlotRef{Float64}(), 3), - ((ConstSlot(*), SlotRef(4.0), ConstSlot(4.0)), Taped._eval, SlotRef{Any}(), 3), + ((ConstSlot(sin), SlotRef(5.0)), Phi._eval, SlotRef{Float64}(), 3), + ((ConstSlot(*), SlotRef(4.0), ConstSlot(4.0)), Phi._eval, SlotRef{Any}(), 3), ( (ConstSlot(+), ConstSlot(4), ConstSlot(5)), - Taped._eval, + Phi._eval, TypedGlobalRef(Main, :__int_output), 2, ), ( (ConstSlot(getfield), SlotRef((5.0, 5)), ConstSlot(1)), - Taped.get_evaluator( - Taped.MinimalCtx(), + Phi.get_evaluator( + Phi.MinimalCtx(), Tuple{typeof(getfield), Tuple{Float64, Int}, Int}, nothing, false, @@ -175,7 +175,7 @@ ), ] oc = build_inst(Val(:call), arg_slots, evaluator, val_slot, next_blk) - @test oc isa Taped.Inst + @test oc isa Phi.Inst @test oc(0) == next_blk f, args... = map(getindex, arg_slots) @test val_slot[] == f(args...) @@ -183,7 +183,7 @@ @testset "Val{:skipped_expression}" begin oc = build_inst(Val(:skipped_expression), 3) - @test oc isa Taped.Inst + @test oc isa Phi.Inst @test oc(5) == 3 end @@ -191,19 +191,19 @@ @testset "defined" begin slot_to_check = SlotRef(5.0) oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - @test oc isa Taped.Inst + @test oc isa Phi.Inst @test oc(0) == 2 end @testset "undefined (non-isbits)" begin slot_to_check = SlotRef{Any}() oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - @test oc isa Taped.Inst + @test oc isa Phi.Inst @test_throws ErrorException oc(3) end @testset "undefined (isbits)" begin slot_to_check = SlotRef{Float64}() oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - @test oc isa Taped.Inst + @test oc isa Phi.Inst # a placeholder for failing to throw an ErrorException when evaluated @test_broken oc(5) == 1 @@ -212,13 +212,13 @@ end # Check that a suite of test cases run and give the correct answer. - interp = Taped.TInterp() + interp = Phi.PInterp() @testset "$(_typeof((f, x...)))" for (a, b, c, f, x...) in TestResources.generate_test_functions() sig = _typeof((f, x...)) @info "$sig" - in_f = Taped.InterpretedFunction(DefaultCtx(), sig, interp) + in_f = Phi.InterpretedFunction(DefaultCtx(), sig, interp) # Verify correctness. @assert f(x...) == f(x...) # check that the primal runs diff --git a/test/interpreter/ir_normalisation.jl b/test/interpreter/ir_normalisation.jl index a4f248ea0..224c6a6a9 100644 --- a/test/interpreter/ir_normalisation.jl +++ b/test/interpreter/ir_normalisation.jl @@ -12,32 +12,32 @@ 0x0000000000000001, ) sp_map = Dict{Symbol, CC.VarState}() - call = Taped.foreigncall_to_call(foreigncall, sp_map) + call = Phi.foreigncall_to_call(foreigncall, sp_map) @test Meta.isexpr(call, :call) - @test call.args[1] == Taped._foreigncall_ + @test call.args[1] == Phi._foreigncall_ end @testset "new_to_call" begin - new_ex = Expr(:new, GlobalRef(Taped, :Foo), SSAValue(1), :hi) - call_ex = Taped.new_to_call(new_ex) + new_ex = Expr(:new, GlobalRef(Phi, :Foo), SSAValue(1), :hi) + call_ex = Phi.new_to_call(new_ex) @test Meta.isexpr(call_ex, :call) - @test call_ex.args[1] == Taped._new_ + @test call_ex.args[1] == Phi._new_ @test call_ex.args[2:end] == new_ex.args end @testset "intrinsic_to_function" begin @testset "GlobalRef" begin intrinsic_ex = Expr(:call, GlobalRef(Core.Intrinsics, :abs_float), SSAValue(1)) - wrapper_ex = Taped.intrinsic_to_function(intrinsic_ex) - @test wrapper_ex.args[1] == Taped.IntrinsicsWrappers.abs_float + wrapper_ex = Phi.intrinsic_to_function(intrinsic_ex) + @test wrapper_ex.args[1] == Phi.IntrinsicsWrappers.abs_float end @testset "IntrinsicFunction" begin intrinsic_ex = Expr(:call, Core.Intrinsics.abs_float, SSAValue(1)) - wrapper_ex = Taped.intrinsic_to_function(intrinsic_ex) - @test wrapper_ex.args[1] == Taped.IntrinsicsWrappers.abs_float + wrapper_ex = Phi.intrinsic_to_function(intrinsic_ex) + @test wrapper_ex.args[1] == Phi.IntrinsicsWrappers.abs_float end @testset "cglobal" begin cglobal_ex = Expr(:call, cglobal, :jl_uv_stdout, Ptr{Cvoid}) - wrapper_ex = Taped.intrinsic_to_function(cglobal_ex) - @test wrapper_ex.args[1] == Taped.IntrinsicsWrappers.__cglobal + wrapper_ex = Phi.intrinsic_to_function(cglobal_ex) + @test wrapper_ex.args[1] == Phi.IntrinsicsWrappers.__cglobal end end @testset "lift_getfield_and_others $ex" for (ex, target) in Any[ @@ -79,6 +79,6 @@ Expr(:call, sin, SSAValue(1)), ), ] - @test Taped.lift_getfield_and_others(ex) == target + @test Phi.lift_getfield_and_others(ex) == target end end diff --git a/test/interpreter/ir_utils.jl b/test/interpreter/ir_utils.jl index c746d3135..4c48b88c6 100644 --- a/test/interpreter/ir_utils.jl +++ b/test/interpreter/ir_utils.jl @@ -13,9 +13,9 @@ end f, args... = fargs insts = only(code_typed(f, _typeof(args)))[1].code - # Use Taped.ircode to build an `IRCode`. + # Use Phi.ircode to build an `IRCode`. argtypes = Any[map(_typeof, fargs)...] - ir = Taped.ircode(insts, argtypes) + ir = Phi.ircode(insts, argtypes) # Check the validity of the `IRCode`, and that an OpaqueClosure constructed using it # gives the same answer as the original function. @@ -25,7 +25,7 @@ end @testset "infer_ir!" begin # Generate IR without any types. - ir = Taped.ircode( + ir = Phi.ircode( Any[ Expr(:call, GlobalRef(Base, :sin), Argument(2)), Expr(:call, cos, SSAValue(1)), @@ -35,7 +35,7 @@ end ) # Run inference and check that the types are as expected. - ir = Taped.infer_ir!(ir) + ir = Phi.infer_ir!(ir) @test ir.stmts.type[1] == Float64 @test ir.stmts.type[2] == Float64 @@ -68,31 +68,31 @@ end (ReturnNode(SSAValue(3)), ReturnNode(SSAValue(3))), (ReturnNode(), ReturnNode()), ] - @test Taped.replace_uses_with(val, SSAValue(1), SSAValue(2)) == target + @test Phi.replace_uses_with(val, SSAValue(1), SSAValue(2)) == target end @testset "PhiNode with undefined" begin vals_with_undef_1 = Vector{Any}(undef, 2) vals_with_undef_1[2] = SSAValue(1) val = PhiNode(Int32[1, 2], vals_with_undef_1) - result = Taped.replace_uses_with(val, SSAValue(1), SSAValue(2)) + result = Phi.replace_uses_with(val, SSAValue(1), SSAValue(2)) @test result.values[2] == SSAValue(2) @test !isassigned(result.values, 1) end end @testset "globalref_type" begin - @test Taped.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_1)) == Any - @test Taped.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_2)) == Float64 - @test Taped.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_3)) == Float64 - @test Taped.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_4)) == Float64 + @test Phi.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_1)) == Any + @test Phi.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_2)) == Float64 + @test Phi.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_3)) == Float64 + @test Phi.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_4)) == Float64 end @testset "unhandled_feature" begin - @test_throws Taped.UnhandledLanguageFeatureException Taped.unhandled_feature("foo") + @test_throws Phi.UnhandledLanguageFeatureException Phi.unhandled_feature("foo") end @testset "inc_args" begin - @test Taped.inc_args(Expr(:call, sin, Argument(4))) == Expr(:call, sin, Argument(5)) - @test Taped.inc_args(ReturnNode(Argument(2))) == ReturnNode(Argument(3)) + @test Phi.inc_args(Expr(:call, sin, Argument(4))) == Expr(:call, sin, Argument(5)) + @test Phi.inc_args(ReturnNode(Argument(2))) == ReturnNode(Argument(3)) id = ID() - @test Taped.inc_args(IDGotoIfNot(Argument(1), id)) == IDGotoIfNot(Argument(2), id) - @test Taped.inc_args(IDGotoNode(id)) == IDGotoNode(id) + @test Phi.inc_args(IDGotoIfNot(Argument(1), id)) == IDGotoIfNot(Argument(2), id) + @test Phi.inc_args(IDGotoNode(id)) == IDGotoNode(id) end end diff --git a/test/interpreter/registers.jl b/test/interpreter/registers.jl index 4ef420927..7c3df1fa6 100644 --- a/test/interpreter/registers.jl +++ b/test/interpreter/registers.jl @@ -1,8 +1,8 @@ @testset "registers" begin - @test Taped.register_type(Float64) <: Taped.AugmentedRegister{CoDual{Float64, Float64}} - @test Taped.register_type(Bool) <: Taped.AugmentedRegister{CoDual{Bool, NoTangent}} - @test Taped.register_type(Any) == Taped.AugmentedRegister - @test Taped.register_type(Real) == Taped.AugmentedRegister - @test ==(Taped.register_type(Union{Float64, Float32}), Taped.AugmentedRegister) - @test Taped.register_type(Union{Float64, Bool}) <: Union{Taped.AugmentedRegister, Bool} + @test Phi.register_type(Float64) <: Phi.AugmentedRegister{CoDual{Float64, Float64}} + @test Phi.register_type(Bool) <: Phi.AugmentedRegister{CoDual{Bool, NoTangent}} + @test Phi.register_type(Any) == Phi.AugmentedRegister + @test Phi.register_type(Real) == Phi.AugmentedRegister + @test ==(Phi.register_type(Union{Float64, Float32}), Phi.AugmentedRegister) + @test Phi.register_type(Union{Float64, Bool}) <: Union{Phi.AugmentedRegister, Bool} end diff --git a/test/interpreter/reverse_mode_ad.jl b/test/interpreter/reverse_mode_ad.jl index 63dda81ab..831426e13 100644 --- a/test/interpreter/reverse_mode_ad.jl +++ b/test/interpreter/reverse_mode_ad.jl @@ -11,13 +11,13 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} fwds_inst, bwds_inst = build_coinsts(ReturnNode, ret, ret_tangent, val) # Test forwards instruction. - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(5) == -1 @test ret[] == get_codual(val) @test (@allocations fwds_inst(5)) == 0 # Test backwards instruction. - @test bwds_inst isa Taped.BwdsInst + @test bwds_inst isa Phi.BwdsInst ret_tangent[] = 2.0 @test bwds_inst(5) isa Int @test get_tangent_stack(val)[] == 3.0 @@ -30,13 +30,13 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} fwds_inst, bwds_inst = build_coinsts(ReturnNode, ret, ret_tangent, val) # Test forwards instruction. - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(5) == -1 @test ret[] == get_codual(val) @test (@allocations fwds_inst(5)) == 0 # Test backwards instruction. - @test bwds_inst isa Taped.BwdsInst + @test bwds_inst isa Phi.BwdsInst @test bwds_inst(5) isa Int @test (@allocations bwds_inst(5)) == 0 end @@ -46,12 +46,12 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} fwds_inst, bwds_inst = build_coinsts(GotoNode, dest) # Test forwards instructions. - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(1) == dest @test (@allocations fwds_inst(1)) == 0 # Test reverse instructions. - @test bwds_inst isa Taped.BwdsInst + @test bwds_inst isa Phi.BwdsInst @test bwds_inst(1) == 1 @test (@allocations bwds_inst(1)) == 0 end @@ -63,7 +63,7 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} fwds_inst, bwds_inst = build_coinsts(GotoIfNot, dest, next_blk, cond) # Test forwards instructions. - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(1) == next_blk @test (@allocations fwds_inst(1)) == 0 cond[] = (zero_codual(false), get_tangent_stack(cond)) @@ -71,7 +71,7 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} @test (@allocations fwds_inst(1)) == 0 # Test backwards instructions. - @test bwds_inst isa Taped.BwdsInst + @test bwds_inst isa Phi.BwdsInst @test bwds_inst(4) == 4 @test (@allocations bwds_inst(1)) == 0 end @@ -82,12 +82,12 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} fwds_inst, bwds_inst = build_coinsts(GotoIfNot, dest, next_blk, cond) # Test forwards instructions. - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(1) == next_blk @test (@allocations fwds_inst(1)) == 0 # Test backwards instructions. - @test bwds_inst isa Taped.BwdsInst + @test bwds_inst isa Phi.BwdsInst @test bwds_inst(4) == 4 @test (@allocations bwds_inst(1)) == 0 end @@ -122,7 +122,7 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} fwds_inst, bwds_inst = build_coinsts(Vector{PhiNode}, nodes, next_blk) # Test forwards instructions. - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(1) == next_blk @test (@allocations fwds_inst(1)) == 0 @test nodes[1].tmp_slot[] == nodes[1].values[1][] @@ -133,7 +133,7 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} @test nodes[3].ret_slot[] == nodes[3].tmp_slot[] # Test backwards instructions. - @test bwds_inst isa Taped.BwdsInst + @test bwds_inst isa Phi.BwdsInst @test bwds_inst(4) == 4 @test (@allocations bwds_inst(1)) == 0 end @@ -145,7 +145,7 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} fwds_inst, bwds_inst = build_coinsts(PiNode, Float64, val, ret, next_blk) # Test forwards instruction. - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(1) == next_blk @test primal(get_codual(ret)) == primal(get_codual(val)) @test tangent(get_codual(ret)) == tangent(get_codual(val)) @@ -154,11 +154,11 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} # Increment tangent associated to `val`. This is done in order to check that the # tangent to `val` is incremented on the reverse-pass, not replaced. - Taped.increment_ref!(get_tangent_stack(val), 0.1) + Phi.increment_ref!(get_tangent_stack(val), 0.1) # Test backwards instruction. - @test bwds_inst isa Taped.BwdsInst - Taped.increment_ref!(get_tangent_stack(ret), 1.6) + @test bwds_inst isa Phi.BwdsInst + Phi.increment_ref!(get_tangent_stack(ret), 1.6) @test bwds_inst(3) == 3 @test get_tangent_stack(val)[] == 1.6 + 0.1 # check increment has happened. end @@ -180,12 +180,12 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} fwds_inst, bwds_inst = build_coinsts(GlobalRef, P, gref, out, next_blk) # Forwards pass. - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(4) == next_blk @test primal(get_codual(out)) == gref[] # Backwards pass. - @test bwds_inst isa Taped.BwdsInst + @test bwds_inst isa Phi.BwdsInst @test bwds_inst(10) == 10 end @testset "QuoteNode and literals" for (x, out, next_blk) in Any[ @@ -197,13 +197,13 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} ] fwds_inst, bwds_inst = build_coinsts(nothing, x, out, next_blk) - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(1) == next_blk @test get_codual(out) == x[] @test length(get_tangent_stack(out)) == 1 @test get_tangent_stack(out)[] == tangent(get_codual(out)) - @test bwds_inst isa Taped.BwdsInst + @test bwds_inst isa Phi.BwdsInst @test bwds_inst(10) == 10 end @@ -212,11 +212,11 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} next_blk = 3 fwds_inst, bwds_inst = build_coinsts(Val(:boundscheck), val_ref, next_blk) - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(0) == next_blk @test get_codual(val_ref) == zero_codual(true) @test length(get_tangent_stack(val_ref)) == 1 - @test bwds_inst isa Taped.BwdsInst + @test bwds_inst isa Phi.BwdsInst @test bwds_inst(2) == 2 end @@ -263,20 +263,20 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} ), ] sig = _typeof(map(primal ∘ get_codual, arg_slots)) - interp = Taped.TInterp() - evaluator = Taped.get_evaluator(Taped.MinimalCtx(), sig, interp, true) - __rrule!! = Taped.get_rrule!!_evaluator(evaluator) - pb_stack = Taped.build_pb_stack(__rrule!!, evaluator, arg_slots) + interp = Phi.PInterp() + evaluator = Phi.get_evaluator(Phi.MinimalCtx(), sig, interp, true) + __rrule!! = Phi.get_rrule!!_evaluator(evaluator) + pb_stack = Phi.build_pb_stack(__rrule!!, evaluator, arg_slots) fwds_inst, bwds_inst = build_coinsts( Val(:call), P, out, arg_slots, evaluator, __rrule!!, pb_stack, next_blk ) # Test forwards-pass. - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(0) == next_blk # Test reverse-pass. - @test bwds_inst isa Taped.BwdsInst + @test bwds_inst isa Phi.BwdsInst @test bwds_inst(5) == 5 end @@ -285,37 +285,37 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} fwds_inst, bwds_inst = build_coinsts(Val(:skipped_expression), next_blk) # Test forwards pass. - @test fwds_inst isa Taped.FwdsInst + @test fwds_inst isa Phi.FwdsInst @test fwds_inst(1) == next_blk # Test backwards pass. - @test bwds_inst isa Taped.BwdsInst + @test bwds_inst isa Phi.BwdsInst end # @testset "Expr(:throw_undef_if_not)" begin # @testset "defined" begin # slot_to_check = SlotRef(5.0) # oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - # @test oc isa Taped.Inst + # @test oc isa Phi.Inst # @test oc(0) == 2 # end # @testset "undefined (non-isbits)" begin # slot_to_check = SlotRef{Any}() # oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - # @test oc isa Taped.Inst + # @test oc isa Phi.Inst # @test_throws ErrorException oc(3) # end # @testset "undefined (isbits)" begin # slot_to_check = SlotRef{Float64}() # oc = build_inst(Val(:throw_undef_if_not), slot_to_check, 2) - # @test oc isa Taped.Inst + # @test oc isa Phi.Inst # # a placeholder for failing to throw an ErrorException when evaluated # @test_broken oc(5) == 1 # end # end - interp = Taped.TInterp() + interp = Phi.PInterp() # nothings inserted for consistency with generate_test_functions. @testset "$(_typeof((f, x...)))" for (interface_only, perf_flag, bnds, f, x...) in @@ -323,7 +323,7 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} sig = _typeof((f, x...)) @info "$sig" - in_f = Taped.InterpretedFunction(DefaultCtx(), sig, interp); + in_f = Phi.InterpretedFunction(DefaultCtx(), sig, interp); # Verify correctness. @assert f(deepcopy(x)...) == f(deepcopy(x)...) # primal runs @@ -331,7 +331,7 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} x_cpy_2 = deepcopy(x) @test has_equal_data(in_f(f, x_cpy_1...), f(x_cpy_2...)) @test has_equal_data(x_cpy_1, x_cpy_2) - rule = Taped.build_rrule!!(in_f); + rule = Phi.build_rrule!!(in_f); TestUtils.test_rrule!!( Xoshiro(123456), in_f, f, x...; perf_flag, interface_only, is_primitive=false, rule @@ -344,7 +344,7 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} # r = @benchmark $(Ref(in_f))[]($(Ref(f))[], $(Ref(deepcopy(x)))[]...); # # Estimate overal forwards-pass and pullback performance. - # __rrule!! = Taped.build_rrule!!(in_f); + # __rrule!! = Phi.build_rrule!!(in_f); # df = zero_codual(in_f); # codual_x = map(zero_codual, (f, x...)); # overall_timing = @benchmark TestUtils.to_benchmark($__rrule!!, $df, $codual_x...); @@ -353,7 +353,7 @@ array_ref_type(::Type{T}) where {T} = Base.RefArray{T, Vector{T}, Nothing} # println("original") # display(original) # println() - # println("taped") + # println("phi") # display(r) # println() # println("overall") diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index ebfa4f909..139680ff8 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -8,7 +8,7 @@ end @testset "s2s_reverse_mode_ad" begin @testset "SharedDataPairs" begin m = SharedDataPairs() - id = Taped.add_data!(m, 5.0) + id = Phi.add_data!(m, 5.0) @test length(m.pairs) == 1 @test m.pairs[1][1] == id @test m.pairs[1][2] == 5.0 @@ -21,22 +21,22 @@ end id_ssa_1 => CC.NewInstruction(nothing, Float64), id_ssa_2 => CC.NewInstruction(nothing, Any), ) - info = ADInfo(Taped.TInterp(), arg_types, ssa_insts, Any[]) + info = ADInfo(Phi.PInterp(), arg_types, ssa_insts, Any[]) # Verify that we can access the interpreter and terminator block ID. - @test info.interp isa Taped.TInterp + @test info.interp isa Phi.PInterp # Verify that we can get the type associated to Arguments, IDs, and others. global ___x = 5.0 global ___y::Float64 = 5.0 - @test Taped.get_primal_type(info, Argument(1)) == Float64 - @test Taped.get_primal_type(info, Argument(2)) == Int - @test Taped.get_primal_type(info, id_ssa_1) == Float64 - @test Taped.get_primal_type(info, GlobalRef(Base, :sin)) == typeof(sin) - @test Taped.get_primal_type(info, GlobalRef(Main, :___x)) == Any - @test Taped.get_primal_type(info, GlobalRef(Main, :___y)) == Float64 - @test Taped.get_primal_type(info, 5) == Int - @test Taped.get_primal_type(info, QuoteNode(:hello)) == Symbol + @test Phi.get_primal_type(info, Argument(1)) == Float64 + @test Phi.get_primal_type(info, Argument(2)) == Int + @test Phi.get_primal_type(info, id_ssa_1) == Float64 + @test Phi.get_primal_type(info, GlobalRef(Base, :sin)) == typeof(sin) + @test Phi.get_primal_type(info, GlobalRef(Main, :___x)) == Any + @test Phi.get_primal_type(info, GlobalRef(Main, :___y)) == Float64 + @test Phi.get_primal_type(info, 5) == Int + @test Phi.get_primal_type(info, QuoteNode(:hello)) == Symbol end @testset "make_ad_stmts!" begin @@ -45,13 +45,13 @@ end id_line_1 = ID() id_line_2 = ID() info = ADInfo( - Taped.TInterp(), + Phi.PInterp(), Dict{Argument, Any}(Argument(1) => typeof(sin), Argument(2) => Float64), Dict{ID, CC.NewInstruction}( id_line_1 => CC.NewInstruction(Expr(:invoke, nothing, cos, Argument(2)), Float64), id_line_2 => CC.NewInstruction(nothing, Any), ), - Any[Taped.NoTangentStack(), Stack{Float64}()], + Any[Phi.NoTangentStack(), Stack{Float64}()], ) @testset "Nothing" begin @@ -116,7 +116,7 @@ end @testset "PiNode" begin @testset "unhandled case" begin @test_throws( - Taped.UnhandledLanguageFeatureException, + Phi.UnhandledLanguageFeatureException, make_ad_stmts!(PiNode(5.0, Float64), ID(), info), ) end @@ -131,26 +131,26 @@ end @testset "non-const" begin global_ref = GlobalRef(S2SGlobals, :non_const_global) stmt_info = make_ad_stmts!(global_ref, ID(), info) - @test stmt_info isa Taped.ADStmtInfo + @test stmt_info isa Phi.ADStmtInfo @test Meta.isexpr(last(stmt_info.fwds)[2].stmt, :call) - @test last(stmt_info.fwds)[2].stmt.args[1] == Taped.__verify_const + @test last(stmt_info.fwds)[2].stmt.args[1] == Phi.__verify_const end @testset "differentiable const globals" begin stmt_info = make_ad_stmts!(GlobalRef(S2SGlobals, :const_float), ID(), info) - @test stmt_info isa Taped.ADStmtInfo + @test stmt_info isa Phi.ADStmtInfo @test Meta.isexpr(only(stmt_info.fwds)[2].stmt, :call) @test only(stmt_info.fwds)[2].stmt.args[1] == identity end end @testset "PhiCNode" begin @test_throws( - Taped.UnhandledLanguageFeatureException, + Phi.UnhandledLanguageFeatureException, make_ad_stmts!(Core.PhiCNode(Any[]), ID(), info), ) end @testset "UpsilonNode" begin @test_throws( - Taped.UnhandledLanguageFeatureException, + Phi.UnhandledLanguageFeatureException, make_ad_stmts!(Core.UpsilonNode(5), ID(), info), ) end @@ -160,14 +160,14 @@ end ad_stmts = make_ad_stmts!(stmt, id_line_1, info) fwds_stmt = ad_stmts.fwds[2][2].stmt @test Meta.isexpr(fwds_stmt, :call) - @test fwds_stmt.args[1] == Taped.__fwds_pass! + @test fwds_stmt.args[1] == Phi.__fwds_pass! @test Meta.isexpr(ad_stmts.rvs[2][2].stmt, :call) - @test ad_stmts.rvs[2][2].stmt.args[1] == Taped.__rvs_pass! + @test ad_stmts.rvs[2][2].stmt.args[1] == Phi.__rvs_pass! end @testset "copyast" begin stmt = Expr(:copyast, QuoteNode(:(hi))) ad_stmts = make_ad_stmts!(stmt, ID(), info) - @test ad_stmts isa Taped.ADStmtInfo + @test ad_stmts isa Phi.ADStmtInfo @test Meta.isexpr(ad_stmts.fwds[1][2].stmt, :call) @test ad_stmts.fwds[1][2].stmt.args[1] == identity end @@ -191,7 +191,7 @@ end end end - interp = Taped.TInterp() + interp = Phi.PInterp() @testset "$(_typeof((f, x...)))" for (n, (interface_only, perf_flag, bnds, f, x...)) in collect(enumerate(TestResources.generate_test_functions())) @@ -202,15 +202,15 @@ end ) # codual_args = map(zero_codual, (f, x...)) - # rule = Taped.build_rrule(interp, sig) + # rule = Phi.build_rrule(interp, sig) # out, pb!! = rule(codual_args...) # # @code_warntype optimize=true rule(codual_args...) # # @code_warntype optimize=true pb!!(tangent(out), map(tangent, codual_args)...) # primal_time = @benchmark $f($(Ref(x))[]...) # s2s_time = @benchmark $rule($codual_args...)[2]($(tangent(out)), $(map(tangent, codual_args))...) - # in_f = in_f = Taped.InterpretedFunction(DefaultCtx(), sig, interp); - # __rrule!! = Taped.build_rrule!!(in_f); + # in_f = in_f = Phi.InterpretedFunction(DefaultCtx(), sig, interp); + # __rrule!! = Phi.build_rrule!!(in_f); # df = zero_codual(in_f); # codual_x = map(zero_codual, (f, x...)); # interp_time = @benchmark TestUtils.to_benchmark($__rrule!!, $df, $codual_x...) diff --git a/test/rrules/builtins.jl b/test/rrules/builtins.jl index dc8b03d6f..fe3d019d3 100644 --- a/test/rrules/builtins.jl +++ b/test/rrules/builtins.jl @@ -1,12 +1,12 @@ @testset "builtins" begin @test_throws( ErrorException, - Taped.rrule!!(CoDual(IntrinsicsWrappers.add_ptr, NoTangent()), 5.0, 4.0), + Phi.rrule!!(CoDual(IntrinsicsWrappers.add_ptr, NoTangent()), 5.0, 4.0), ) @test_throws( ErrorException, - Taped.rrule!!(CoDual(IntrinsicsWrappers.sub_ptr, NoTangent()), 5.0, 4.0), + Phi.rrule!!(CoDual(IntrinsicsWrappers.sub_ptr, NoTangent()), 5.0, 4.0), ) TestUtils.run_rrule!!_test_cases(StableRNG, Val(:builtins)) diff --git a/test/rrules/foreigncall.jl b/test/rrules/foreigncall.jl index 757a6afc3..ad9b9168a 100644 --- a/test/rrules/foreigncall.jl +++ b/test/rrules/foreigncall.jl @@ -10,7 +10,7 @@ ] @test_throws( ErrorException, - Taped.rrule!!(zero_codual(Taped._foreigncall_), zero_codual(Val(name))), + Phi.rrule!!(zero_codual(Phi._foreigncall_), zero_codual(Val(name))), ) end end diff --git a/test/rrules/misc.jl b/test/rrules/misc.jl index 25a5f1e7d..64a4f8058 100644 --- a/test/rrules/misc.jl +++ b/test/rrules/misc.jl @@ -3,9 +3,9 @@ @testset "misc utility" begin x = randn(4, 5) p = Base.unsafe_convert(Ptr{Float64}, x) - @test Taped.wrap_ptr_as_view(p, 4, 4, 5) == x - @test Taped.wrap_ptr_as_view(p, 4, 2, 5) == x[1:2, :] - @test Taped.wrap_ptr_as_view(p, 4, 2, 3) == x[1:2, 1:3] + @test Phi.wrap_ptr_as_view(p, 4, 4, 5) == x + @test Phi.wrap_ptr_as_view(p, 4, 2, 5) == x[1:2, :] + @test Phi.wrap_ptr_as_view(p, 4, 2, 3) == x[1:2, 1:3] end @testset "lgetfield" begin @@ -19,11 +19,11 @@ end @testset "lsetfield!" begin x = TestResources.MutableFoo(5.0, randn(5)) - @test Taped.lsetfield!(x, Val(:a), 4.0) == 4.0 + @test Phi.lsetfield!(x, Val(:a), 4.0) == 4.0 @test x.a == 4.0 new_b = zeros(10) - @test Taped.lsetfield!(x, Val(:b), new_b) === new_b + @test Phi.lsetfield!(x, Val(:b), new_b) === new_b @test x.b === new_b end diff --git a/test/runtests.jl b/test/runtests.jl index 60c9ee804..81eb35aa5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ include("front_matter.jl") -@testset "Taped.jl" begin +@testset "Phi.jl" begin if test_group == "basic" include("utils.jl") include("tangents.jl") diff --git a/test/stack.jl b/test/stack.jl index 2ad9df2e1..37dd846c2 100644 --- a/test/stack.jl +++ b/test/stack.jl @@ -15,16 +15,16 @@ @test isempty(s) end @testset "tangent_stack_type" begin - @test Taped.tangent_stack_type(Float64) == Stack{Float64} - @test Taped.tangent_stack_type(Int) == Taped.NoTangentStack - @test Taped.tangent_stack_type(Any) == Stack{Any} - @test Taped.tangent_stack_type(DataType) == Stack{Any} - @test Taped.tangent_stack_type(Type{Float64}) == Taped.NoTangentStack + @test Phi.tangent_stack_type(Float64) == Stack{Float64} + @test Phi.tangent_stack_type(Int) == Phi.NoTangentStack + @test Phi.tangent_stack_type(Any) == Stack{Any} + @test Phi.tangent_stack_type(DataType) == Stack{Any} + @test Phi.tangent_stack_type(Type{Float64}) == Phi.NoTangentStack - @test Taped.tangent_ref_type_ub(Float64) == Taped.__array_ref_type(Float64) - @test Taped.tangent_ref_type_ub(Int) == Taped.NoTangentRef - @test Taped.tangent_ref_type_ub(Any) == Ref - @test Taped.tangent_ref_type_ub(DataType) == Ref - @test Taped.tangent_ref_type_ub(Type{Float64}) == Taped.NoTangentRef + @test Phi.tangent_ref_type_ub(Float64) == Phi.__array_ref_type(Float64) + @test Phi.tangent_ref_type_ub(Int) == Phi.NoTangentRef + @test Phi.tangent_ref_type_ub(Any) == Ref + @test Phi.tangent_ref_type_ub(DataType) == Ref + @test Phi.tangent_ref_type_ub(Type{Float64}) == Phi.NoTangentRef end end diff --git a/test/tangents.jl b/test/tangents.jl index 71e39e610..53180233f 100644 --- a/test/tangents.jl +++ b/test/tangents.jl @@ -106,7 +106,7 @@ __x = randn(10) p = pointer(__x, 3) - @testset "set_immutable_to_zero($(Taped._typeof(x)))" for x in Any[ + @testset "set_immutable_to_zero($(Phi._typeof(x)))" for x in Any[ NoTangent(), 5.0, 5f0, @@ -118,11 +118,11 @@ randn_tangent(Xoshiro(1), TestResources.MutableFoo(5.0, randn(3))), p, ] - @test Taped.set_immutable_to_zero(x) isa Taped._typeof(x) + @test Phi.set_immutable_to_zero(x) isa Phi._typeof(x) end # Bulk test auto-generated tangents. - @testset "autogenerated $n $x" for (n, x) in collect(enumerate(Taped.TestTypes.PRIMALS)) + @testset "autogenerated $n $x" for (n, x) in collect(enumerate(Phi.TestTypes.PRIMALS)) (interface_only, p) = x TestUtils.test_tangent_consistency(sr(1), p; interface_only) TestUtils.test_tangent_performance(sr(1), p) diff --git a/test/test_utils.jl b/test/test_utils.jl index 0265aac9d..1efe78815 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -87,7 +87,7 @@ input_addr_map = populate_address_map(x, x̄) f_f̄ = CoDual(f, zero_tangent(f)) x_x̄ = map(CoDual, x, x̄) - y_ȳ, _ = Taped.rrule!!(f_f̄, x_x̄...) + y_ȳ, _ = Phi.rrule!!(f_f̄, x_x̄...) z = (x..., primal(y_ȳ)) z̄ = (x̄..., tangent(y_ȳ)) @test_throws AssertionError populate_address_map(z, z̄) diff --git a/test/utils.jl b/test/utils.jl index ae13667a7..f2eb75012 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -9,10 +9,10 @@ @test _typeof((a=5.0, b=Float64)) == @NamedTuple{a::Float64, b::Type{Float64}} end @testset "tuple_map" begin - @test map(sin, (5.0, 4.0)) == Taped.tuple_map(sin, (5.0, 4.0)) + @test map(sin, (5.0, 4.0)) == Phi.tuple_map(sin, (5.0, 4.0)) @test ==( map(*, (5, 4.0, 3), (5.0, 4, 3.0)), - Taped.tuple_map(*, (5, 4.0, 3), (5.0, 4, 3.0)), + Phi.tuple_map(*, (5, 4.0, 3), (5.0, 4, 3.0)), ) end end