Add StochasticAD.jl backend #377
Conversation
Thanks for the contribution! However, I am not sure this belongs in DI: in most cases, your function is either stochastic or it is not. When it is stochastic, all the backends offered by DI are incorrect to various degrees, so you know you need something like StochasticAD.jl or DifferentiableExpectations.jl. When it is deterministic, then StochasticAD.jl and friends are useless. In a nutshell, I think the stochastic case doesn't overlap with the deterministic one, but I'd be curious to hear what Gaurav thinks, and @mschauer too.
Basically I don't want to encourage people to use standard AD backends in the stochastic case: even though some of them can appear to work, they're not designed to handle it and so shouldn't be trusted. If we put StochasticAD on the same shelf as the rest, it makes them look interchangeable.
hmm, I see the logic. I was hoping to integrate this into some of the TuringLang packages such as
From what I can tell, in the file you linked, a deterministic AD backend (probably ForwardDiff) is used to compute the gradient of a stochastic function:

```julia
julia> ForwardDiff.gradient(x -> sum(x * rand()), ones(13))
13-element Vector{Float64}:
 0.480515430966141
 0.480515430966141
 0.480515430966141
 0.480515430966141
 0.480515430966141
 0.480515430966141
 0.480515430966141
 0.7989366586266822
 0.7989366586266822
 0.7989366586266822
 0.7989366586266822
 0.7989366586266822
 0.7989366586266822
```
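The split into two repeated values comes from ForwardDiff evaluating the function once per chunk, with each evaluation drawing a fresh `rand()`, so entries in the same chunk share a draw. As a minimal sketch (an illustration rather than a recommendation), pinning the chunk size to the input length via ForwardDiff's `GradientConfig` makes all entries share a single draw:

```julia
using ForwardDiff

f(x) = sum(x * rand())
x = ones(13)
# One chunk of size 13 means a single evaluation of f, hence a single rand() draw.
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{13}())
ForwardDiff.gradient(f, x, cfg)  # all 13 entries are now equal
```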
The gradients are computed with the reparameterization trick (see e.g., here: https://github.com/TuringLang/AdvancedVI.jl/blob/b25d572fd28567c9c5ea0be575d28489b29bb6a7/src/families/location_scale.jl#L94). Not sure how the chunk size is controlled in AdvancedVI.jl when using ForwardDiff to avoid the issue you point to...
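For illustration, here is a minimal sketch of the reparameterization trick for a location-scale family with a toy objective; the function names and the objective are hypothetical and this is not the AdvancedVI.jl code linked above. The key point is that the base noise is drawn outside the differentiated function, so ForwardDiff only ever sees a deterministic map.

```julia
using ForwardDiff, Random

# Toy objective: E_{z ~ N(mu, sigma^2)}[z^2], estimated by reparameterizing
# z = mu + sigma * eps with eps ~ N(0, 1) drawn once, outside the gradient call.
function reparam_objective(params, eps)
    mu, sigma = params
    z = mu .+ sigma .* eps          # deterministic given eps
    return sum(abs2, z) / length(eps)
end

eps = randn(Xoshiro(0), 10_000)     # fixed base samples
ForwardDiff.gradient(p -> reparam_objective(p, eps), [0.5, 1.0])
# ≈ [2*mu, 2*sigma] = [1.0, 2.0] up to Monte Carlo error
```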
DifferentiableExpectations.jl also aims to provide an interface for the reparametrization trick, which goes beyond location-scale distributions. Not sure how efficient it is, but I think it would make sense to see how we can make it work for Turing.
Ok, then I agree it makes more sense to implement a StochasticAD.jl backend there for discrete distributions.
Maybe it would also be worth linking it with https://github.com/TuringLang/DistributionsAD.jl
Closing in favor of JuliaDecisionFocusedLearning/DifferentiableExpectations.jl#18
Apologies for the late reply, I'm trying to finish a thesis :) I've written some thoughts down below though in case it helps out. I agree with @gdalle that it's dangerous to add this to DI.

Regarding plain old continuous randomness: afaik, ChainRules/Enzyme/etc. agree on the object which they are trying to compute, namely the "almost-sure derivative" / what you get from "differentiating inside the integral", e.g. see this fun Jax issue jax-ml/jax#676 (comment) :) If you think of a random program's semantics as a mapping from a space of random seeds to a value, then this is a well-defined object, and I think most AD systems treat failures to compute it as a bug, see e.g. EnzymeAD/Enzyme.jl#1388 where an incorrect
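To make the "mapping from seeds to values" picture concrete, here is a small hypothetical example: once the seed is fixed, the program is an ordinary deterministic function, and standard AD computes the derivative of that particular sample path (the "derivative inside the integral" evaluated at that seed).

```julia
using ForwardDiff, Random

# With the seed held fixed, g is a plain deterministic function of x.
function g(x, seed)
    rng = Xoshiro(seed)
    return sum(x .* rand(rng, length(x)))
end

x = ones(3)
ForwardDiff.gradient(x -> g(x, 42), x)
# Equals the realization rand(Xoshiro(42), 3): the almost-sure derivative at this seed.
```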
This is my attempt at a minimal implementation of the StochasticAD.jl backend. I have only implemented the `pushforward` operator, using the `derivative_estimate` method of StochasticAD.jl. The backend recently started supporting a reverse mode via Enzyme, but it still seems to be a bit experimental, so I have not implemented it.
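As a rough illustration only (not the actual extension code in this PR), a single unbiased pushforward sample could be obtained by reducing the directional derivative to a scalar parameter and calling `derivative_estimate` on it; whether this reduction is the best way to express a JVP with StochasticAD.jl is an assumption here.

```julia
using StochasticAD

# One unbiased sample of d/dt E[f(x + t * dx)] at t = 0, i.e. the pushforward
# (JVP) of f at x in the direction dx. Averaging over n_samples is omitted.
pushforward_sample(f, x, dx) = derivative_estimate(t -> f(x .+ t .* dx), 0.0)
```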
I have defined the `AutoStochasticAD` type with a field `n_samples` corresponding to the number of samples used to estimate the gradient. This backend is intended for programs with discrete randomness, and in my experience you need quite a few samples to get a good estimate, since the variance can be quite high. That is why I opted to expose it as a simple input parameter rather than expecting the user to run the program multiple times manually. The results for non-discrete programs should be identical to ForwardDiff, and so all the tests pass (at least for me locally).

Do you think it would be a good idea to include tests in the suite specifically designed around discrete distributions (i.e., some stochastic program with categorical variables)?
I took the liberty of opening this PR since having this backend as part of DifferentiationInterface.jl would be very useful to me, but the author @gaurav-arya should have the final say on the implementation.
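For reference, here is a sketch of what the backend type with an `n_samples` field and a discrete-distribution test could look like; the field name follows the description above, while the rest (including subtyping `ADTypes.AbstractADType` and the helper names) is an assumption for illustration, not the PR's actual code.

```julia
using ADTypes: AbstractADType
using StochasticAD, Distributions
using Statistics: mean

Base.@kwdef struct AutoStochasticAD <: AbstractADType
    n_samples::Int = 100   # Monte Carlo samples averaged per derivative call
end

# Average several unbiased single-sample estimates to reduce the variance.
averaged_derivative(b::AutoStochasticAD, f, p) =
    mean(derivative_estimate(f, p) for _ in 1:b.n_samples)

# Discrete test program: E[f(p)] = 3p, so the exact derivative is 3.
f(p) = rand(Bernoulli(p)) * 3
averaged_derivative(AutoStochasticAD(n_samples = 1_000), f, 0.5)  # ≈ 3
```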