
Add StochasticAD.jl backend #377

Closed
wants to merge 3 commits into from

Conversation

arnauqb

@arnauqb arnauqb commented Jul 24, 2024

This is my attempt at a minimal implementation of a StochasticAD.jl backend. I have only implemented the pushforward operator, using the derivative_estimate method of StochasticAD.jl.

The backend recently started supporting a reverse mode via Enzyme, but it still seems a bit experimental, so I have not implemented it.

I have defined the AutoStochasticAD type with a field n_samples corresponding to the number of samples used to estimate the gradient. This backend is intended for programs with discrete randomness, and in my experience you need quite a few samples to get a good estimate since the variance can be quite high. That is why I opted to make the number of samples a simple input parameter rather than expecting the user to run the program multiple times manually. The results for non-discrete programs should be identical to ForwardDiff, so all the tests pass (at least for me locally).
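
For illustration, here is a rough sketch of the idea (the struct layout and the helper name estimate_derivative are illustrative, not necessarily what the PR contains): derivative_estimate returns one unbiased but noisy sample of the derivative, and the backend simply averages n_samples of them.

using StochasticAD, Statistics

# Hypothetical sketch, not the PR code: average several unbiased
# derivative_estimate samples to reduce the variance of the estimate.
struct AutoStochasticAD
    n_samples::Int
end

estimate_derivative(b::AutoStochasticAD, f, p) =
    mean(derivative_estimate(f, p) for _ in 1:b.n_samples)

# d/dp E[p * rand()] = 0.5; a single derivative_estimate call is one noisy
# sample, and averaging tightens the estimate.
f(p) = p * rand()
estimate_derivative(AutoStochasticAD(1_000), f, 2.0)  # ≈ 0.5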

Do you think it would be a good idea to include tests in the suite specifically designed around discrete distributions (e.g., some stochastic program with categorical variables)?

I took the liberty of opening this PR since having this backend as part of DifferentiationInterface.jl would be very useful to me, but the author @gaurav-arya should have the last word on the implementation.

@gdalle
Member

gdalle commented Jul 24, 2024

Thanks for the contribution! However I am not sure this belongs in DI: in most cases, your function is either stochastic or it is not. When it is stochastic, all the backends offered by DI are incorrect to various degrees, so you know you need something like StochasticAD.jl or DifferentiableExpectations.jl. When it is deterministic, then StochasticAD.jl and friends are useless. In a nutshell, I think the stochastic case doesn't overlap with the deterministic one, but I'd be curious to hear what Gaurav thinks, and @mschauer too.

@gdalle
Member

gdalle commented Jul 24, 2024

Basically I don't want to encourage people to use standard AD backends in the stochastic case: even though some of them can appear to work, they're not designed to handle it and so shouldn't be trusted. If we put StochasticAD on the same shelf as the rest, it makes them look interchangeable.

@arnauqb
Author

arnauqb commented Jul 24, 2024

Hmm, I see the logic.

I was hoping to integrate this into some of the TuringLang packages such as AdvancedVI.jl, substituting these kinds of calls with DifferentiationInterface: https://github.com/TuringLang/AdvancedVI.jl/blob/b25d572fd28567c9c5ea0be575d28489b29bb6a7/src/objectives/elbo/repgradelbo.jl#L121

@gdalle
Member

gdalle commented Jul 24, 2024

From what I can tell, in the file you linked, a deterministic AD backend (probably ForwardDiff) is used to compute the gradient of a stochastic function estimate_repgradelbo_ad_forward?
If so, that's kinda cursed, because it gives very counter-intuitive results (courtesy of Moritz):

julia> ForwardDiff.gradient(x->sum(x*rand()), ones(13))
13-element Vector{Float64}:
 0.480515430966141
 0.480515430966141
 0.480515430966141
 0.480515430966141
 0.480515430966141
 0.480515430966141
 0.480515430966141
 0.7989366586266822
 0.7989366586266822
 0.7989366586266822
 0.7989366586266822
 0.7989366586266822
 0.7989366586266822
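
(The two distinct values above come from ForwardDiff's chunking: the closure, and therefore rand(), is evaluated once per chunk, so entries in the same chunk share a single draw. As a quick check, one can force a single chunk so that all 13 entries agree; the helper h below is just for illustration.)

julia> h(x) = sum(x * rand());

julia> cfg = ForwardDiff.GradientConfig(h, ones(13), ForwardDiff.Chunk{13}());

julia> ForwardDiff.gradient(h, ones(13), cfg)  # one chunk, so all 13 entries share one rand() draw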

@arnauqb
Author

arnauqb commented Jul 24, 2024

The gradients are computed with the reparameterization trick (see e.g. here: https://github.com/TuringLang/AdvancedVI.jl/blob/b25d572fd28567c9c5ea0be575d28489b29bb6a7/src/families/location_scale.jl#L94).
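
For reference, a minimal sketch of the trick for a Gaussian location-scale family (the function name elbo_like is made up for illustration, this is not the AdvancedVI.jl code): the expectation is rewritten over a fixed base distribution, so the Monte Carlo estimate becomes a deterministic function of the parameters and any standard backend can differentiate it.

using ForwardDiff, Statistics

# Reparameterization: E_{x ~ N(μ, σ)}[f(x)] = E_{ε ~ N(0, 1)}[f(μ + σε)].
# With the draws of ε fixed, the estimator below is deterministic in (μ, σ).
ε = randn(1_000)
elbo_like(p) = mean((p[1] + p[2] * e)^2 for e in ε)  # Monte Carlo estimate of E[x^2]

ForwardDiff.gradient(elbo_like, [0.0, 1.0])  # ≈ [2μ, 2σ] = [0.0, 2.0]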

Not sure how the chunk size is controlled in AdvancedVI.jl when using ForwardDiff to avoid the issue you point to...

@gdalle
Member

gdalle commented Jul 24, 2024

DifferentiableExpectations.jl also aims to provide an interface for the reparametrization trick, one that goes beyond location-scale distributions. Not sure how efficient it is, but I think it would make sense to see how we can make it work for Turing.

@arnauqb
Author

arnauqb commented Jul 24, 2024

Ok, then I agree it makes more sense to implement a StochasticAD.jl backend there for discrete distributions.

@arnauqb
Author

arnauqb commented Jul 24, 2024

Maybe it would also be worth linking it with https://github.com/TuringLang/DistributionsAD.jl

@gdalle
Member

gdalle commented Jul 29, 2024

@gdalle gdalle closed this Jul 29, 2024
@gaurav-arya

Apologies for the late reply, I'm trying to finish a thesis :) I've written some thoughts down below in case it helps out.

I agree with @gdalle that it's dangerous to add StochasticAD as just another algorithm here, to be used with exactly the same interface call, because it changes the semantics (discrete randomness suddenly stops getting ignored), which an algorithm choice shouldn't do. If the regular StochasticAD algorithm were to be added, it should probably only be used alongside a separate signal that the semantics of the differentiation function has changed, e.g. a StochasticADSemantics() under which the derivative contribution of discrete random distributions is suddenly considered... but I probably wouldn't recommend taking on the maintenance burden of this, as it's likely still a bit of a minefield :)
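
To make the semantics difference concrete, here is a minimal Bernoulli example (adapted from StochasticAD's documentation, not from this PR): the derivative of the expectation of a discrete draw is recovered by StochasticAD, whereas a backend that ignores discrete randomness treats the draw as locally constant and reports zero.

using StochasticAD, Distributions, Statistics

# d/dp E[rand(Bernoulli(p))] = 1; StochasticAD's estimator recovers it,
# while a standard backend sees only a locally constant 0 or 1.
f(p) = rand(Bernoulli(p))
mean(derivative_estimate(f, 0.5) for _ in 1:10_000)  # ≈ 1.0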

Regarding plain old continuous randomness: afaik, ChainRules/Enzyme/etc. agree on the object they are trying to compute, namely the "almost-sure derivative", i.e. what you get from "differentiating inside the integral"; see e.g. this fun Jax issue jax-ml/jax#676 (comment) :) If you think of a random program's semantics as a mapping from a space of random seeds to a value, then this is a well-defined object, and I think most AD systems treat failures to compute it as a bug, see e.g. EnzymeAD/Enzyme.jl#1388 where an incorrect rand rule is treated as a bug in Enzyme.

The ForwardDiff case still fits the bill if you just look at the marginal distribution of a single element of the gradient; alternatively, one can manually sample and fix the seed of the program before providing it to ForwardDiff or any other backend, and enforce that all backends yield the same result. And funnily enough, one could also use StochasticAD with backend = StrategyWrapperFIsBackend(PrunedFIsBackend(), IgnoreDiscreteStrategy()) and get a valid derivative estimate for these semantics :)
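
As a concrete sketch of the seed-fixing point (the function g is made up for illustration, not from the thread): seeding an RNG inside the function makes it a deterministic map, so the almost-sure derivative is well defined and any correct backend should return the same answer.

using ForwardDiff, Random

# The draw is fixed by the seed, so g is a deterministic function of x.
g(x) = sum(x * rand(Xoshiro(42)))

ForwardDiff.gradient(g, ones(13))  # all 13 entries equal rand(Xoshiro(42))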
