Skip to content

Commit

Permalink
Better type annotations in fallbacks (#458)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Sep 8, 2024
1 parent ccf5247 commit 1155218
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions DifferentiationInterface/src/fallbacks/no_extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,56 +87,68 @@ for op in (:pushforward, :pullback, :hvp)
HVPExtras
end
# 1-arg
@eval function $prep_op_same_point(f::F, backend::AbstractADType, x, seed) where {F}
@eval function $prep_op_same_point(
f::F, backend::AbstractADType, x, seed::Tangents
) where {F}
ex = $prep_op(f, backend, x, seed)
return $prep_op_same_point(f, ex, backend, x, seed)
end
@eval function $prep_op_same_point(
f::F, ex::$E, backend::AbstractADType, x, seed
f::F, ex::$E, backend::AbstractADType, x, seed::Tangents
) where {F}
return ex
end
@eval function $op(f::F, backend::AbstractADType, x, seed) where {F}
@eval function $op(f::F, backend::AbstractADType, x, seed::Tangents) where {F}
ex = $prep_op(f, backend, x, seed)
return $op(f, ex, backend, x, seed)
end
@eval function $op!(f::F, result, backend::AbstractADType, x, seed) where {F}
@eval function $op!(
f::F, result::Tangents, backend::AbstractADType, x, seed::Tangents
) where {F}
ex = $prep_op(f, backend, x, seed)
return $op!(f, result, ex, backend, x, seed)
end
op == :hvp && continue
@eval function $val_and_op(f::F, backend::AbstractADType, x, seed) where {F}
@eval function $val_and_op(f::F, backend::AbstractADType, x, seed::Tangents) where {F}
ex = $prep_op(f, backend, x, seed)
return $val_and_op(f, ex, backend, x, seed)
end
@eval function $val_and_op!(f::F, result, backend::AbstractADType, x, seed) where {F}
@eval function $val_and_op!(
f::F, result::Tangents, backend::AbstractADType, x, seed::Tangents
) where {F}
ex = $prep_op(f, backend, x, seed)
return $val_and_op!(f, result, ex, backend, x, seed)
end
# 2-arg
@eval function $prep_op_same_point(f!::F, y, backend::AbstractADType, x, seed) where {F}
@eval function $prep_op_same_point(
f!::F, y, backend::AbstractADType, x, seed::Tangents
) where {F}
ex = $prep_op(f!, y, backend, x, seed)
return $prep_op_same_point(f!, y, ex, backend, x, seed)
end
@eval function $prep_op_same_point(
f!::F, y, ex::$E, backend::AbstractADType, x, seed
f!::F, y, ex::$E, backend::AbstractADType, x, seed::Tangents
) where {F}
return ex
end
@eval function $op(f!::F, y, backend::AbstractADType, x, seed) where {F}
@eval function $op(f!::F, y, backend::AbstractADType, x, seed::Tangents) where {F}
ex = $prep_op(f!, y, backend, x, seed)
return $op(f!, y, ex, backend, x, seed)
end
@eval function $op!(f!::F, y, result, backend::AbstractADType, x, seed) where {F}
@eval function $op!(
f!::F, y, result::Tangents, backend::AbstractADType, x, seed::Tangents
) where {F}
ex = $prep_op(f!, y, backend, x, seed)
return $op!(f!, y, result, ex, backend, x, seed)
end
@eval function $val_and_op(f!::F, y, backend::AbstractADType, x, seed) where {F}
@eval function $val_and_op(
f!::F, y, backend::AbstractADType, x, seed::Tangents
) where {F}
ex = $prep_op(f!, y, backend, x, seed)
return $val_and_op(f!, y, ex, backend, x, seed)
end
@eval function $val_and_op!(
f!::F, y, result, backend::AbstractADType, x, seed
f!::F, y, result::Tangents, backend::AbstractADType, x, seed::Tangents
) where {F}
ex = $prep_op(f!, y, backend, x, seed)
return $val_and_op!(f!, y, result, ex, backend, x, seed)
Expand Down

0 comments on commit 1155218

Please sign in to comment.