Comments (6)

devmotion commented on June 12, 2024

The first issue (the method error) is easy to explain: some implementations in DistributionsAD mix different AD backends (I assume because some computations are faster with Zygote and others with another backend), and to run those computations you have to load both backends. That's what's happening here: in https://github.com/TuringLang/DistributionsAD.jl/blob/d7ceffbe3cd5ef54f98cdac6f39142ee9d3f8895/src/reversediffx.jl#L168-L170 you can see that the method for ReverseDiff actually uses a pullback from Zygote, and that pullback does not exist unless Zygote is loaded.

Since the AD backends are all heavy dependencies, all AD functionality is hidden behind @require blocks. The "correct" approach would be to put the implementation in a nested @require block, so that the Zygote functionality is guaranteed to be available. However, that wouldn't actually fix your problem, since then the whole function (and not only the pullback) would be unavailable. So I'm not really sure what the best approach is to avoid these issues, apart from not mixing the AD backends at all (which I assume would reduce performance in some cases).

BTW, the mix of ReverseDiff and Zygote probably also explains why they both yield the same result.
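
For readers unfamiliar with Requires.jl, the nested-block idea described above might look roughly like the sketch below. This is only an illustration, not the actual DistributionsAD code: the module name and the included file name are hypothetical, and the UUIDs are assumed to be those of the registered ReverseDiff and Zygote packages.

```julia
module MixedBackendSketch  # hypothetical module name

using Requires

function __init__()
    # Outer block: evaluated once ReverseDiff is loaded in the session.
    @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
        # Inner block: evaluated only once Zygote is loaded as well, so
        # code in the included file may safely call into Zygote.
        @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
            include("reversediff_zygote.jl")  # hypothetical file
        end
    end
end

end # module
```

As noted above, the trade-off is that nothing in the inner block is defined for users who load ReverseDiff alone, so the whole method is missing rather than just the pullback.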

from bijectors.jl.

paschermayr commented on June 12, 2024

Thank you very much for the information, that helps a lot!

Regarding problem 2: When I call get_logpost2(MvNormal([X >= 1, Y >= 1]), ...), I get identical results for ForwardDiff and ReverseDiff. The problem only arises with get_logpost2(MvNormal([X < 1, Y < 1]), ...).

devmotion commented on June 12, 2024

I noticed the following:

julia> using DistributionsAD, Distributions, FiniteDifferences, ForwardDiff, Tracker, ReverseDiff, Zygote

julia> function h(theta, x)
           dist = MvNormal(theta)
           return logpdf(dist, x)
       end
h (generic function with 1 method)

julia> x = randn(2)
2-element Array{Float64,1}:
 -0.2728447604886671
  0.2954077831645062

julia> FiniteDifferences.grad(central_fdm(5, 1), x -> h([0.1, 0.1], x), x)[1]
2-element Array{Float64,1}:
  27.28447604886685
 -29.540778316450915

julia> ForwardDiff.gradient(x -> h([0.1, 0.1], x), x)
2-element Array{Float64,1}:
  27.284476048866708
 -29.540778316450616

julia> Tracker.gradient(x -> h([0.1, 0.1], x), x)[1]
Tracked 2-element Array{Float64,1}:
  2728.4476048866704
 -2954.077831645061

julia> ReverseDiff.gradient(x -> h([0.1, 0.1], x), x)
2-element Array{Float64,1}:
  2728.4476048866704
 -2954.077831645061

julia> Zygote.gradient(x -> h([0.1, 0.1], x), x)[1]
2-element Array{Float64,1}:
  27.284476048866708
 -29.54077831645062

with

  [31c24e10] Distributions v0.23.4
  [26cc04aa] FiniteDifferences v0.10.2
  [f6369f11] ForwardDiff v0.10.10
  [37e2e3b7] ReverseDiff v1.2.0
  [9f7883ad] Tracker v0.2.7
  [e88e6eb3] Zygote v0.5.1

Moreover, I get

julia> FiniteDifferences.grad(central_fdm(5, 1), x -> h([10.0, 10.0], x), x)[1]
2-element Array{Float64,1}:
  0.0027284476047971617
 -0.0029540778316831193

julia> ForwardDiff.gradient(x -> h([10.0, 10.0], x), x)
2-element Array{Float64,1}:
  0.0027284476048866713
 -0.002954077831645062

julia> Tracker.gradient(x -> h([10.0, 10.0], x), x)[1]
Tracked 2-element Array{Float64,1}:
  2.7284476048866714e-5
 -2.954077831645062e-5

julia> ReverseDiff.gradient(x -> h([10.0, 10.0], x), x)
2-element Array{Float64,1}:
  2.7284476048866714e-5
 -2.954077831645062e-5

julia> Zygote.gradient(x -> h([10.0, 10.0], x), x)[1]
2-element Array{Float64,1}:
  0.0027284476048866713
 -0.002954077831645062

So for me it fails consistently for Tracker and ReverseDiff, both for standard deviations < 1 and for standard deviations > 1. Something seems to be broken there; I'll open an issue over at DistributionsAD.
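
As a sanity check on the numbers above, here is a minimal sketch that compares the reported Tracker/ReverseDiff values with the analytic gradient. It assumes that MvNormal(theta) treats theta as a vector of standard deviations with zero mean, so the gradient of the log-density with respect to x is -x / σ². The helper names are made up for this illustration; the numeric values are copied from the session.

```julia
# x as printed in the REPL session.
x = [-0.2728447604886671, 0.2954077831645062]

# Analytic gradient of logpdf w.r.t. x for a zero-mean normal whose
# components are independent with common standard deviation σ (assumption).
analytic_grad(σ, x) = -x ./ σ^2

# Element-wise ratio between a reported gradient and the analytic one.
ratio_to_analytic(σ, reported) = reported ./ analytic_grad(σ, x)

# Values reported by Tracker and ReverseDiff in the session.
reported_small = [2728.4476048866704, -2954.077831645061]      # σ = 0.1
reported_large = [2.7284476048866714e-5, -2.954077831645062e-5] # σ = 10.0

println(analytic_grad(0.1, x))
println(ratio_to_analytic(0.1, reported_small))
println(ratio_to_analytic(10.0, reported_large))
```

Under that assumption, the FiniteDifferences, ForwardDiff and Zygote results match the analytic gradient, while the Tracker and ReverseDiff results differ from it by a factor of 1/σ² in both cases. That would also be consistent with the earlier observation that the results agree when the standard deviations equal 1.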

devmotion commented on June 12, 2024

With TuringLang/DistributionsAD.jl#92, FiniteDifferences, Tracker, ForwardDiff, and ReverseDiff all yield the same gradient. Zygote still seems to be broken though; I'm not sure why yet.

devmotion commented on June 12, 2024

The latest commit in TuringLang/DistributionsAD.jl#92 fixes Zygote as well, so the second part of this issue is resolved.

devmotion commented on June 12, 2024

The AD issues are fixed in DistributionsAD 0.6.2.