From f9618e5cb6172215171345a6fd877fe6d065156d Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 6 Jan 2025 22:55:32 +0000 Subject: [PATCH] Make staged primitive rules possible (#437) * Add build_primitive_rrule function * Permit build primitive * Fix typo * Docstring * Bump patch version * Grammar * Add note to docs * Rename files --- Project.toml | 2 +- docs/make.jl | 2 +- .../{tools_for_rules.md => defining_rules.md} | 16 +++++++++++-- src/Mooncake.jl | 23 +++++++++++++++++++ src/interpreter/s2s_reverse_mode_ad.jl | 16 ++++++++----- 5 files changed, 49 insertions(+), 10 deletions(-) rename docs/src/utilities/{tools_for_rules.md => defining_rules.md} (70%) diff --git a/Project.toml b/Project.toml index fc43c7bb3..c5619e9b4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.71" +version = "0.4.72" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/make.jl b/docs/make.jl index 57808956b..5e0f484ae 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -28,7 +28,7 @@ makedocs(; joinpath("understanding_mooncake", "rule_system.md"), ], "Utilities" => [ - joinpath("utilities", "tools_for_rules.md"), + joinpath("utilities", "defining_rules.md"), joinpath("utilities", "debug_mode.md"), joinpath("utilities", "debugging_and_mwes.md"), ], diff --git a/docs/src/utilities/tools_for_rules.md b/docs/src/utilities/defining_rules.md similarity index 70% rename from docs/src/utilities/tools_for_rules.md rename to docs/src/utilities/defining_rules.md index 73fc48f4f..dfbc42493 100644 --- a/docs/src/utilities/tools_for_rules.md +++ b/docs/src/utilities/defining_rules.md @@ -1,8 +1,8 @@ -# Tools for Rules +# Defining Rules Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. However, this does not always necessitate writing your own `rrule!!` from scratch. -In this section, we detail some useful strategies which can help you avoid having to write `rrule!!`s in many situations. +In this section, we detail some useful strategies which can help you avoid having to write `rrule!!`s in many situations, which we discuss before discussing the more involved process of actually writing rules. ## Simplfiying Code via Overlays @@ -31,3 +31,15 @@ There is enough similarity between these two systems that most of the boilerplat ```@docs Mooncake.@from_rrule ``` + +## Adding Methods To `rrule!!` And `build_primitive_rrule` + +If the above strategies do not work for you, you should first implement a method of [`Mooncake.is_primitive`](@ref) for the signature of interest: +```@docs +Mooncake.is_primitive +``` +Then implement a method of one of the following: +```@docs +Mooncake.rrule!! +Mooncake.build_primitive_rrule +``` \ No newline at end of file diff --git a/src/Mooncake.jl b/src/Mooncake.jl index cd268f166..8670f1e6f 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -75,6 +75,29 @@ pb!!(1.0) """ function rrule!! end +""" + build_primitive_rrule(sig::Type{<:Tuple}) + +Construct an rrule for signature `sig`. For this function to be called in `build_rrule`, you +must also ensure that `is_primitive(context_type, sig)` is `true`. The callable returned by +this must obey the rrule interface, but there are no restrictions on the type of callable +itself. For example, you might return a callable `struct`. By default, this function returns +`rrule!!` so, most of the time, you should just implement a method of `rrule!!`. + +# Extended Help + +The purpose of this function is to permit computation at rule construction time, which can +be re-used at runtime. For example, you might wish to derive some information from `sig` +which you use at runtime (e.g. the fdata type of one of the arguments). While constant +propagation will often optimise this kind of computation away, it will sometimes fail to do +so in hard-to-predict circumstances. Consequently, if you need certain computations not to +happen at runtime in order to guarantee good performance, you might wish to e.g. emit a +callable `struct` with type parameters which are the result of this computation. In this +context, the motivation for using this function is the same as that of using staged +programming (e.g. via `@generated` functions) more generally. +""" +build_primitive_rrule(::Type{<:Tuple}) = rrule!! + include("utils.jl") include("tangents.jl") include("fwds_rvs_data.jl") diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 1ab4e71b9..7398a1dbd 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -948,8 +948,8 @@ end # Rule derivation. # -_is_primitive(C::Type, mi::Core.MethodInstance) = is_primitive(C, mi.specTypes) -_is_primitive(C::Type, sig::Type) = is_primitive(C, sig) +_get_sig(sig::Type) = sig +_get_sig(mi::Core.MethodInstance) = mi.specTypes function forwards_ret_type(primal_ir::IRCode) return fcodual_type(Base.Experimental.compute_ir_rettype(primal_ir)) @@ -967,8 +967,9 @@ important for performance in dynamic dispatch, and to ensure that recursion work properly. """ function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C} - if _is_primitive(C, sig_or_mi) - return debug_mode ? DebugRRule{typeof(rrule!!)} : typeof(rrule!!) + if is_primitive(C, _get_sig(sig_or_mi)) + rule = build_primitive_rrule(_get_sig(sig_or_mi)) + return debug_mode ? DebugRRule{typeof(rule)} : typeof(rule) end ir, _ = lookup_ir(interp, sig_or_mi) @@ -1076,7 +1077,11 @@ function build_rrule( end # If we have a hand-coded rule, just use that. - _is_primitive(C, sig_or_mi) && return (debug_mode ? DebugRRule(rrule!!) : rrule!!) + sig = _get_sig(sig_or_mi) + if is_primitive(C, sig) + rule = build_primitive_rrule(sig) + return (debug_mode ? DebugRRule(rule) : rule) + end # We don't have a hand-coded rule, so derived one. lock(MOONCAKE_INFERENCE_LOCK) @@ -1093,7 +1098,6 @@ function build_rrule( rvs_oc = misty_closure(dri.rvs_ret_type, dri.rvs_ir, dri.shared_data...) # Compute the signature. Needs careful handling with varargs. - sig = sig_or_mi isa Core.MethodInstance ? sig_or_mi.specTypes : sig_or_mi nargs = num_args(dri.info) if dri.isva sig = Tuple{