Comments (6)
The first issue (the method error) is easy to explain: some implementations in DistributionsAD mix different AD backends (I assume because some computations are faster with Zygote and others with ReverseDiff), so to perform these computations you have to load both backends. That's what's happening here: in https://github.com/TuringLang/DistributionsAD.jl/blob/d7ceffbe3cd5ef54f98cdac6f39142ee9d3f8895/src/reversediffx.jl#L168-L170 you can see that the ReverseDiff method actually uses a pullback from Zygote, and that pullback does not exist unless Zygote is loaded.
Since the AD backends are all heavy dependencies, all AD functionality is hidden behind `@require` blocks (from Requires.jl). The "correct" approach would be to put the implementation in a nested `@require` block, so that the Zygote functionality is guaranteed to be available. However, that wouldn't actually fix your problem: the whole function, not just the pullback, would then be unavailable. So I'm not sure what the best way to avoid these issues is, apart from not mixing AD backends at all (which, I assume, would lead to decreased performance in some cases).
BTW, the mix of ReverseDiff and Zygote probably also explains why both yield the same result.
from bijectors.jl.
Thank you very much for the information, that helps a lot!
Regarding problem 2: when I call `get_logpost2(MvNormal([X, Y]), ...)` with X >= 1 and Y >= 1, I get identical results for ForwardDiff and ReverseDiff. The problem only arises as soon as X < 1 and Y < 1.
I noticed the following:

```julia
julia> using DistributionsAD, Distributions, FiniteDifferences, ForwardDiff, Tracker, ReverseDiff, Zygote

julia> function h(theta, x)
           dist = MvNormal(theta)
           return logpdf(dist, x)
       end
h (generic function with 1 method)

julia> x = randn(2)
2-element Array{Float64,1}:
 -0.2728447604886671
  0.2954077831645062

julia> FiniteDifferences.grad(central_fdm(5, 1), x -> h([0.1, 0.1], x), x)[1]
2-element Array{Float64,1}:
  27.28447604886685
 -29.540778316450915

julia> ForwardDiff.gradient(x -> h([0.1, 0.1], x), x)
2-element Array{Float64,1}:
  27.284476048866708
 -29.540778316450616

julia> Tracker.gradient(x -> h([0.1, 0.1], x), x)[1]
Tracked 2-element Array{Float64,1}:
  2728.4476048866704
 -2954.077831645061

julia> ReverseDiff.gradient(x -> h([0.1, 0.1], x), x)
2-element Array{Float64,1}:
  2728.4476048866704
 -2954.077831645061

julia> Zygote.gradient(x -> h([0.1, 0.1], x), x)[1]
2-element Array{Float64,1}:
  27.284476048866708
 -29.54077831645062
```
with

```
  [31c24e10] Distributions v0.23.4
  [26cc04aa] FiniteDifferences v0.10.2
  [f6369f11] ForwardDiff v0.10.10
  [37e2e3b7] ReverseDiff v1.2.0
  [9f7883ad] Tracker v0.2.7
  [e88e6eb3] Zygote v0.5.1
```
Moreover, with standard deviations of 10.0 I get:

```julia
julia> FiniteDifferences.grad(central_fdm(5, 1), x -> h([10.0, 10.0], x), x)[1]
2-element Array{Float64,1}:
  0.0027284476047971617
 -0.0029540778316831193

julia> ForwardDiff.gradient(x -> h([10.0, 10.0], x), x)
2-element Array{Float64,1}:
  0.0027284476048866713
 -0.002954077831645062

julia> Tracker.gradient(x -> h([10.0, 10.0], x), x)[1]
Tracked 2-element Array{Float64,1}:
  2.7284476048866714e-5
 -2.954077831645062e-5

julia> ReverseDiff.gradient(x -> h([10.0, 10.0], x), x)
2-element Array{Float64,1}:
  2.7284476048866714e-5
 -2.954077831645062e-5

julia> Zygote.gradient(x -> h([10.0, 10.0], x), x)[1]
2-element Array{Float64,1}:
  0.0027284476048866713
 -0.002954077831645062
```
So for me, Tracker and ReverseDiff fail consistently, both for standard deviations < 1 and for standard deviations > 1. Something seems to be broken there; I'll open an issue over at DistributionsAD.
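As a sanity check that does not depend on any AD package, the expected gradient can be written down by hand. Judging from the FiniteDifferences and ForwardDiff numbers above, `MvNormal(theta)` with a vector `theta` is a zero-mean Gaussian with diagonal standard deviations `theta`, so the gradient of the log-density with respect to `x` is `-x ./ theta.^2`. The sketch below uses only Base Julia; `logpdf_diag`, `grad_analytic` and `grad_fd` are hypothetical helpers written for this illustration (not part of Distributions or DistributionsAD), and the step size `h = 1e-6` is an arbitrary choice.

```julia
# Minimal sketch (Base Julia only): closed-form gradient of a zero-mean,
# diagonal Gaussian log-density w.r.t. x, checked against central differences.
# All helper names are hypothetical, not Distributions.jl API.

# log N(x | 0, diag(σ.^2)), where σ holds the standard deviations
logpdf_diag(σ, x) = -0.5 * sum((x ./ σ) .^ 2) - sum(log.(σ)) - 0.5 * length(x) * log(2π)

# Analytic gradient of logpdf_diag w.r.t. x
grad_analytic(σ, x) = -x ./ σ .^ 2

# Central finite differences, perturbing one coordinate at a time
function grad_fd(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

x = [-0.2728447604886671, 0.2954077831645062]  # the x from the session above

for σ in ([0.1, 0.1], [10.0, 10.0])
    ga = grad_analytic(σ, x)
    gf = grad_fd(z -> logpdf_diag(σ, z), x)
    println("σ = ", σ, "  analytic: ", ga, "  finite differences: ", gf)
end
```

For the `x` above, `grad_analytic` reproduces the FiniteDifferences, ForwardDiff and Zygote results for both σ = 0.1 and σ = 10, whereas the Tracker and ReverseDiff results equal `-x ./ σ.^4`, i.e. they carry an extra factor of 1/σ².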
With TuringLang/DistributionsAD.jl#92, FiniteDifferences, Tracker, ForwardDiff, and ReverseDiff all yield the same gradient. Zygote still seems to be broken, though; I'm not sure why yet.
The latest commit in TuringLang/DistributionsAD.jl#92 fixes Zygote as well, so the second part of this issue is resolved.
The AD issues are fixed in DistributionsAD 0.6.2.
Related Issues (20)
- MethodError: no method matching `(::Inverse)...`
- Zygote error differentiating `Coupling`
- `PlanarLayer` broken in 0.9.9
- OrderedBijector gives surprising results
- Road to DensityInterface?
- Rewrite: what do we want
- Fitting normalizing flow
- Matrix Variate Bijectors - Lower/Uppertriangular transforms
- Zygote AD & `logpdf` for transformed multivariate
- `Bijectors.ordered(d)` is incorrect if `bijector(d)` is not `Identity`
- Adding support for FunctionChains.jl?
- Problem compiling an app with Turing
- MethodError: no method matching bijector(::MixtureModel{Multivariate, Continuous, MvNormal, Float64})
- CorrBijector makes posterior improper
- SimplexBijector tests fails on boundary
- Remove heavy usage of `@generated`
- Questions on custom bijectors
- Bijector for MatrixNormal
- Remove heavy usage of `@generated`
- Add `rng` to have the reproducibility in `PlanarLayer` and `RadialLayer`