Skip to content

Commit

Permalink
Complete Phi rename (#102)
Browse files Browse the repository at this point in the history
* Update README.md

* Rename from Taped to Phi

* Fix typo
  • Loading branch information
willtebbutt authored Mar 25, 2024
1 parent fe95e24 commit b803b49
Show file tree
Hide file tree
Showing 41 changed files with 294 additions and 294 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name = "Taped"
name = "Phi"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt and contributors"]
version = "0.1.0"
Expand All @@ -21,7 +21,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
TapedSpecialFunctionsExt = "SpecialFunctions"
PhiSpecialFunctionsExt = "SpecialFunctions"

[compat]
BenchmarkTools = "1"
Expand Down
2 changes: 1 addition & 1 deletion bench/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Taped = "07d77754-e150-4737-8c94-cd238a1fb45b"
Phi = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Expand Down
16 changes: 8 additions & 8 deletions bench/README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Benchmarking

There are two flavours of benchmarks implemented in `run_benchmarks.jl`.
One is a set of pass / fail tests designed to check that large performance regressions are avoided in Taped.
The other is a set of comparisons between a variety of frameworks -- this set is designed to give a rough sense of where Taped stands in comparison to other AD frameworks, and the results should not be thought of as pass / fail tests.
One is a set of pass / fail tests designed to check that large performance regressions are avoided in Phi.
The other is a set of comparisons between a variety of frameworks -- this set is designed to give a rough sense of where Phi stands in comparison to other AD frameworks, and the results should not be thought of as pass / fail tests.

## Taped-Only Benchmarking
## Phi-Only Benchmarking

The benchmarking runs as part of CI, and evaluates a sequence of pass / fail tests.

Expand All @@ -21,9 +21,9 @@ to run the benchmarks which test the performance of AD. This will produce a `Dat
containing a run-down of the results. It has the following columns:
1. `tag`: a `String` with an automatically generated name for the test
1. `primal_time`: the time taken to run the original code
1. `taped_time`: the time is takes Taped to AD the code
1. `Taped`: `taped_time / primal_time`
1. `range`: a named tuple with fields `lb` and `ub` specifying the acceptable range of values for `Taped`.
1. `phi_time`: the time is takes Phi to AD the code
1. `Phi`: `phi_time / primal_time`
1. `range`: a named tuple with fields `lb` and `ub` specifying the acceptable range of values for `Phi`.

From here you can look at whatever properties of the results you are interested in.

Expand All @@ -33,7 +33,7 @@ Note that the types of all of the columns are very simple, so it is fine to writ
CSV.write("file_name.csv", df)
```

Additionally, the convenience function `plot_ratio_histogram!` can be used to produce a histogram of `Taped` with formatting which is suited to this field. Call it as follows:
Additionally, the convenience function `plot_ratio_histogram!` can be used to produce a histogram of `Phi` with formatting which is suited to this field. Call it as follows:
```julia
derived_results = benchmark_derived_rrules!!(Xoshiro)
df = DataFrame(df)
Expand All @@ -42,7 +42,7 @@ plot_ratio_histogram!(df)

## Inter-framework Benchmarking

This comprises a small suite of functions that we AD using Taped, Zygote, ReverseDiff, and Enzyme. This suite of benchmarks is also run as part of CI, and the output is recorded in two ways:
This comprises a small suite of functions that we AD using Phi, Zygote, ReverseDiff, and Enzyme. This suite of benchmarks is also run as part of CI, and the output is recorded in two ways:
1. a table of results is posted as comment in a PR
1. the table and a corresponding graph are stored as github actions artifacts, and can be retrieved by going to the "Checks" tab of your PR, and clicking on the artifact button.

Expand Down
34 changes: 17 additions & 17 deletions bench/run_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ using
PrettyTables,
Random,
ReverseDiff,
Taped,
Phi,
Test,
Turing,
Zygote

using Taped:
using Phi:
CoDual,
generate_hand_written_rrule!!_test_cases,
generate_derived_rrule!!_test_cases,
InterpretedFunction,
TestUtils,
TInterp,
PInterp,
_typeof

using Taped.TestUtils: _deepcopy, to_benchmark
using Phi.TestUtils: _deepcopy, to_benchmark

function zygote_to_benchmark(ctx, x::Vararg{Any, N}) where {N}
out, pb = Zygote._pullback(ctx, x...)
Expand Down Expand Up @@ -84,7 +84,7 @@ _broadcast_sin_cos_exp(x::AbstractArray{<:Real}) = sum(sin.(cos.(exp.(x))))
# about all of the operations.
_simple_mlp(W2, W1, Y, X) = sum(abs2, Y - W2 * map(x -> x * (0 <= x), W1 * X))

# Only Zygote and Taped can actually handle this. Note that Taped only has rules for BLAS
# Only Zygote and Phi can actually handle this. Note that Phi only has rules for BLAS
# and LAPACK stuff, not explicit rules for things like the squared euclidean distance.
# Consequently, Zygote is at a major advantage.
_gp_lml(x, y, s) = logpdf(GP(SEKernel())(x, s), y)
Expand Down Expand Up @@ -174,12 +174,12 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
evals=1,
)

# Benchmark AD via Taped.
@info "taped"
rule = Taped.build_rrule(args...)
# Benchmark AD via Phi.
@info "phi"
rule = Phi.build_rrule(args...)
coduals = map(x -> x isa CoDual ? x : zero_codual(x), args)
to_benchmark(rule, coduals...)
suite["taped"] = @benchmark(to_benchmark($rule, $coduals...))
suite["phi"] = @benchmark(to_benchmark($rule, $coduals...))

if include_other_frameworks

Expand Down Expand Up @@ -219,16 +219,16 @@ end
function combine_results(result, tag, _range, default_range)
d = result[2]
primal_time = time(minimum(d["primal"]))
taped_time = time(minimum(d["taped"]))
phi_time = time(minimum(d["phi"]))
zygote_time = in("zygote", keys(d)) ? time(minimum(d["zygote"])) : missing
rd_time = in("rd", keys(d)) ? time(minimum(d["rd"])) : missing
ez_time = in("enzyme", keys(d)) ? time(minimum(d["enzyme"])) : missing
fallback_tag = string((result[1][1], map(Taped._typeof, result[1][2:end])...))
fallback_tag = string((result[1][1], map(Phi._typeof, result[1][2:end])...))
return (
tag=tag === nothing ? fallback_tag : tag,
primal_time=primal_time,
taped_time=taped_time,
Taped=taped_time / primal_time,
phi_time=phi_time,
Phi=phi_time / primal_time,
zygote_time=zygote_time,
Zygote=zygote_time / primal_time,
rd_time=rd_time,
Expand Down Expand Up @@ -283,26 +283,26 @@ end
function flag_concerning_performance(ratios)
@testset "detect concerning performance" begin
@testset for ratio in ratios
@test ratio.range.lb < ratio.Taped < ratio.range.ub
@test ratio.range.lb < ratio.Phi < ratio.range.ub
end
end
end

"""
plot_ratio_histogram!(df::DataFrame)
Constructs a histogram of the `taped_ratio` field of `df`, with formatting that is
Constructs a histogram of the `phi_ratio` field of `df`, with formatting that is
well-suited to the numbers typically found in this field.
"""
function plot_ratio_histogram!(df::DataFrame)
bin = 10.0 .^ (0.0:0.05:6.0)
xlim = extrema(bin)
histogram(df.Taped; xscale=:log10, xlim, bin, title="log", label="")
histogram(df.Phi; xscale=:log10, xlim, bin, title="log", label="")
end

function create_inter_ad_benchmarks()
results = benchmark_inter_framework_rules()
tools = [:Taped, :Zygote, :ReverseDiff, :Enzyme]
tools = [:Phi, :Zygote, :ReverseDiff, :Enzyme]
df = DataFrame(results)[:, [:tag, tools...]]

# Plot graph of results.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module TapedSpecialFunctionsExt
module PhiSpecialFunctionsExt

using SpecialFunctions, Taped
using SpecialFunctions, Phi

import Taped: @from_rrule, DefaultCtx
import Phi: @from_rrule, DefaultCtx

@from_rrule DefaultCtx Tuple{typeof(airyai), Float64}
@from_rrule DefaultCtx Tuple{typeof(airyaix), Float64}
Expand Down
2 changes: 1 addition & 1 deletion src/Taped.jl → src/Phi.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module Taped
module Phi

const CC = Core.Compiler

Expand Down
18 changes: 9 additions & 9 deletions src/chain_rules_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ __increment_shim!!(x, y) = increment!!(x, y)
"""
@from_rrule ctx sig
Creates a `Taped.rrule!!` from a `ChainRulesCore.rrule`. `ctx` is the type of the context in
Creates a `Phi.rrule!!` from a `ChainRulesCore.rrule`. `ctx` is the type of the context in
which this rule should apply, and `sig` is the type-tuple which specifies which primal the
rule should apply to.
For example,
```julia
@from_rrule DefaultCtx Tuple{typeof(sin), Float64}
```
would define a `Taped.rrule!!` for `sin` of `Float64`s, by calling `ChainRulesCore.rrule`.
would define a `Phi.rrule!!` for `sin` of `Float64`s, by calling `ChainRulesCore.rrule`.
Health warning:
Use this function with care. It has only been tested for `Float64` arguments and arguments
Expand All @@ -29,13 +29,13 @@ macro from_rrule(ctx, sig)
arg_type_symbols = sig.args[2:end]

arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols))
arg_types = map(t -> :(Taped.CoDual{<:$t}), arg_type_symbols)
arg_types = map(t -> :(Phi.CoDual{<:$t}), arg_type_symbols)
arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types)

call_rrule = Expr(
:call,
:(Taped.ChainRulesCore.rrule),
map(n -> :(Taped.primal($n)), arg_names)...,
:(Phi.ChainRulesCore.rrule),
map(n -> :(Phi.primal($n)), arg_names)...,
)

pb_arg_names = map(n -> Symbol("dx_$(n)"), eachindex(arg_names))
Expand All @@ -45,7 +45,7 @@ macro from_rrule(ctx, sig)
incrementers = Expr(
:tuple,
map(pb_arg_names, pb_output_names) do a, b
:(Taped.__increment_shim!!($a, $b))
:(Phi.__increment_shim!!($a, $b))
end...,
)

Expand All @@ -62,18 +62,18 @@ macro from_rrule(ctx, sig)
rule_expr = ExprTools.combinedef(
Dict(
:head => :function,
:name => :(Taped.rrule!!),
:name => :(Phi.rrule!!),
:args => arg_exprs,
:body => quote
y, pb = $call_rrule
$pb
return Taped.zero_codual(y), pb!!
return Phi.zero_codual(y), pb!!
end,
)
)

ex = quote
Taped.is_primitive(::Type{$ctx}, ::Type{$sig}) = true
Phi.is_primitive(::Type{$ctx}, ::Type{$sig}) = true
$rule_expr
end
return esc(ex)
Expand Down
22 changes: 11 additions & 11 deletions src/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ end

TICache() = TICache(IdDict{Core.MethodInstance, Core.CodeInstance}())

struct TapedInterpreter{C} <: CC.AbstractInterpreter
struct PhiInterpreter{C} <: CC.AbstractInterpreter
meta # additional information
world::UInt
inf_params::CC.InferenceParams
opt_params::CC.OptimizationParams
inf_cache::Vector{CC.InferenceResult}
code_cache::TICache
oc_cache::Dict{Any, Any}
function TapedInterpreter(
function PhiInterpreter(
::Type{C};
meta=nothing,
world::UInt=Base.get_world_counter(),
Expand All @@ -32,15 +32,15 @@ struct TapedInterpreter{C} <: CC.AbstractInterpreter
end
end

TapedInterpreter() = TapedInterpreter(DefaultCtx)
PhiInterpreter() = PhiInterpreter(DefaultCtx)

const TInterp = TapedInterpreter
const PInterp = PhiInterpreter

CC.InferenceParams(interp::TInterp) = interp.inf_params
CC.OptimizationParams(interp::TInterp) = interp.opt_params
CC.get_world_counter(interp::TInterp) = interp.world
CC.get_inference_cache(interp::TInterp) = interp.inf_cache
function CC.code_cache(interp::TInterp)
CC.InferenceParams(interp::PInterp) = interp.inf_params
CC.OptimizationParams(interp::PInterp) = interp.opt_params
CC.get_world_counter(interp::PInterp) = interp.world
CC.get_inference_cache(interp::PInterp) = interp.inf_cache
function CC.code_cache(interp::PInterp)
return CC.WorldView(interp.code_cache, CC.WorldRange(interp.world))
end
function CC.get(wvc::CC.WorldView{TICache}, mi::Core.MethodInstance, default)
Expand All @@ -62,7 +62,7 @@ _type(x::CC.PartialStruct) = x.typ
_type(x::CC.Conditional) = Union{x.thentype, x.elsetype}

function CC.inlining_policy(
interp::TapedInterpreter{C},
interp::PhiInterpreter{C},
@nospecialize(src),
@nospecialize(info::CC.CallInfo),
stmt_flag::UInt8,
Expand All @@ -85,4 +85,4 @@ function CC.inlining_policy(
)
end

context_type(::TInterp{C}) where {C} = C
context_type(::PInterp{C}) where {C} = C
2 changes: 1 addition & 1 deletion src/interpreter/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true
You should implemented more complicated method of `is_primitive` in the usual way.
"""
macro is_primitive(Tctx, sig)
return esc(:(Taped.is_primitive(::Type{$Tctx}, ::Type{<:$sig}) = true))
return esc(:(Phi.is_primitive(::Type{$Tctx}, ::Type{<:$sig}) = true))
end
8 changes: 4 additions & 4 deletions src/interpreter/interpreted_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ struct InterpretedFunction{sig<:Tuple, C, Treturn, Targ_info<:ArgInfo}
bb_starts::Vector{Int}
bb_ends::Vector{Int}
ir::IRCode
interp::TapedInterpreter
interp::PhiInterpreter
spnames::Any
end

Expand Down Expand Up @@ -417,7 +417,7 @@ end
Construct a data structure which can be used to execute the instruction specified by `sig`.
For example,
```julia
in_f = InterpretedFunction(DefaultCtx(), Tuple{typeof(sin), Float64}, Taped.TInterp())
in_f = InterpretedFunction(DefaultCtx(), Tuple{typeof(sin), Float64}, Phi.PInterp())
in_f(sin, 5.0)
```
will yield exactly the same result as running `sin(5.0)`. The advantage of this data
Expand Down Expand Up @@ -458,7 +458,7 @@ and each `Argument` / `SSAValue` in the IR with a (heap-allocated) `AbstractSlot
most part, these slots are `Ref`s).
While the details of what each kind of `OpaqueClosure` can be found in the corresponding
`Taped.build_inst` method, they generally have the following structure:
`Phi.build_inst` method, they generally have the following structure:
- load data from argument / ssa slots,
- do computation,
- write result to the instruction's ssa slot,
Expand Down Expand Up @@ -566,7 +566,7 @@ end
# `InterpretedFunction`s operate recursively -- if the types associated to the `args` field
# of a `:call` expression have not been inferred successfully, then we must wait until
# runtime to determine what code to run. The `DelayedInterpretedFunction` does exactly this.
struct DelayedInterpretedFunction{C, Tlocal_cache, T<:TapedInterpreter}
struct DelayedInterpretedFunction{C, Tlocal_cache, T<:PhiInterpreter}
ctx::C
local_cache::Tlocal_cache
interp::T
Expand Down
Loading

2 comments on commit b803b49

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/103582

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" b803b4944e773312cc54e40cb3ce1fef12abce69
git push origin v0.1.0

Please sign in to comment.