diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 164b9f628..a6fc2a1ed 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -41,7 +41,6 @@ include("utils/traits.jl") include("utils/basis.jl") include("utils/batchsize.jl") include("utils/check.jl") -include("utils/exceptions.jl") include("utils/printing.jl") include("utils/context.jl") include("utils/linalg.jl") @@ -123,4 +122,6 @@ export AutoSparse @public inner, outer +include("init.jl") + end # module diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index ebf74a26b..802e4534e 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -145,18 +145,6 @@ function _prepare_pullback_aux( return PushforwardPullbackPrep(pushforward_prep) end -function _prepare_pullback_aux( - ::PullbackFast, f, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context} -) - throw(MissingBackendError(backend)) -end - -function _prepare_pullback_aux( - ::PullbackFast, f!, y, backend::AbstractADType, x, ty::NTuple, contexts::Vararg{Context} -) - throw(MissingBackendError(backend)) -end - ## One argument function _pullback_via_pushforward( diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 83475757b..b5da406aa 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -146,24 +146,6 @@ function _prepare_pushforward_aux( return PullbackPushforwardPrep(pullback_prep) end -function _prepare_pushforward_aux( - ::PushforwardFast, f, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context} -) - throw(MissingBackendError(backend)) -end - -function _prepare_pushforward_aux( - ::PushforwardFast, - f!, - y, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context}, -) - throw(MissingBackendError(backend)) -end - ## One argument function _pushforward_via_pullback( diff --git a/DifferentiationInterface/src/init.jl b/DifferentiationInterface/src/init.jl new file mode 100644 index 000000000..619581f81 --- /dev/null +++ b/DifferentiationInterface/src/init.jl @@ -0,0 +1,13 @@ +function __init__() + Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs + if exc.f in (_prepare_pushforward_aux, _prepare_pullback_aux) + B = first(T for T in argtypes if T <: AbstractADType) + printstyled( + io, + "\n\nThe autodiff backend package you want to use may not be loaded. Please run the following command and try again:"; + bold=true, + ) + printstyled(io, "\n\n\timport $(package_name(B))"; color=:cyan, bold=true) + end + end +end diff --git a/DifferentiationInterface/src/utils/exceptions.jl b/DifferentiationInterface/src/utils/exceptions.jl deleted file mode 100644 index 308731cfc..000000000 --- a/DifferentiationInterface/src/utils/exceptions.jl +++ /dev/null @@ -1,21 +0,0 @@ -struct MissingBackendError <: Exception - backend::AbstractADType -end - -function Base.showerror(io::IO, e::MissingBackendError) - println(io, "MissingBackendError: Failed to use $(e.backend).") - if !check_available(e.backend) - print( - io, - """Backend package is probably not loaded. To fix this, try to run - - import $(package_name(e.backend)) - """, - ) - else - print( - io, - "Please open an issue: https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/new", - ) - end -end diff --git a/DifferentiationInterface/src/utils/printing.jl b/DifferentiationInterface/src/utils/printing.jl index fc63469ad..77d083a7e 100644 --- a/DifferentiationInterface/src/utils/printing.jl +++ b/DifferentiationInterface/src/utils/printing.jl @@ -1,19 +1,24 @@ -function package_name(b::AbstractADType) - s = string(b) +package_name(b::AbstractADType) = package_name(typeof(b)) + +function package_name(::Type{B}) where {B<:AbstractADType} + s = string(B) s = chopprefix(s, "ADTypes.") s = chopprefix(s, "Auto") - k = findfirst('(', s) - isnothing(k) && throw(ArgumentError("Cannot parse backend into package")) - return s[begin:(k - 1)] + k = findfirst('{', s) + if isnothing(k) + return s + else + return s[begin:(k - 1)] + end end -function package_name(b::SecondOrder) - p1 = package_name(outer(b)) - p2 = package_name(inner(b)) +function package_name(::Type{SecondOrder{O,I}}) where {O,I} + p1 = package_name(O) + p2 = package_name(I) return p1 == p2 ? p1 : "$p1, $p2" end -package_name(b::AutoSparse) = package_name(dense_ad(b)) +package_name(::Type{<:AutoSparse{D}}) where {D} = package_name(D) function document_preparation(operator_name::AbstractString; same_point=false) if same_point diff --git a/DifferentiationInterface/test/Core/Internals/exceptions.jl b/DifferentiationInterface/test/Core/Internals/exceptions.jl deleted file mode 100644 index 909f8ca2f..000000000 --- a/DifferentiationInterface/test/Core/Internals/exceptions.jl +++ /dev/null @@ -1,41 +0,0 @@ -using ADTypes: ADTypes, AbstractADType -using DifferentiationInterface -import DifferentiationInterface as DI -using Test - -""" - AutoBrokenForward <: ADTypes.AbstractADType - -Available forward-mode backend with no pushforward implementation. -Used to test error messages. -""" -struct AutoBrokenForward <: AbstractADType end -ADTypes.mode(::AutoBrokenForward) = ADTypes.ForwardMode() -DI.check_available(::AutoBrokenForward) = true - -""" - AutoBrokenReverse <: ADTypes.AbstractADType - -Available reverse-mode backend with no pullback implementation. -Used to test error messages. -""" -struct AutoBrokenReverse <: AbstractADType end -ADTypes.mode(::AutoBrokenReverse) = ADTypes.ReverseMode() - -## Test exceptions -@testset "MissingBackendError" begin - x = [1.0] - y = similar(x) - - @test_throws DI.MissingBackendError jacobian(copy, AutoBrokenForward(), x) - @test_throws DI.MissingBackendError jacobian(copy, AutoBrokenReverse(), x) - - @test_throws DI.MissingBackendError jacobian(copyto!, y, AutoBrokenForward(), x) - @test_throws DI.MissingBackendError jacobian(copyto!, y, AutoBrokenReverse(), x) - - @test_throws DI.MissingBackendError hessian(sum, AutoBrokenForward(), x) - @test_throws DI.MissingBackendError hessian(sum, AutoBrokenReverse(), x) - - sprint(showerror, DI.MissingBackendError(AutoBrokenForward())) - sprint(showerror, DI.MissingBackendError(AutoBrokenReverse())) -end