sethaxen commented on September 25, 2024

Yes, I can look into this.

@baggepinnen I'm not familiar with ForwardDiffChainRules. Can you provide an MWE that exhibits the observed failure using just ChainRules?

Using our own testing machinery, I am unable to observe any failures on 1000x the number of random matrices:

julia> using ChainRules, ChainRulesTestUtils, LinearAlgebra, Random, Test

julia> Random.seed!(42);

julia> @testset "exp!" begin
           Xs = (randn(4, 4) for _ in 1:20_000)
           @testset for X in Xs
               test_frule(LinearAlgebra.exp!, X)
           end
       end;
Test Summary: |  Pass  Total   Time
exp!          | 80000  80000  13.2s
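Roughly speaking, `test_frule` compares the tangent returned by the frule with a finite-difference estimate. A minimal stdlib-only sketch of that kind of check for the matrix exponential is below. `exp_frechet` and `fd_exp` are hypothetical helper names, and the block-triangular identity exp([A E; 0 A]) = [exp(A) L(A,E); 0 exp(A)] is used here as an independent reference for the directional derivative L(A,E); this is an assumption of the sketch, not what ChainRules or ChainRulesTestUtils do internally.

```julia
using LinearAlgebra, Random

# Fréchet derivative L(A, E) of exp at A in direction E, read off the
# upper-right block of exp([A E; 0 A]) (block-triangular identity)
function exp_frechet(A::AbstractMatrix, E::AbstractMatrix)
    n = size(A, 1)
    return exp([A E; zero(A) A])[1:n, n+1:2n]
end

# Central finite-difference estimate of the same directional derivative
fd_exp(A, E; h = 1e-6) = (exp(A + h * E) - exp(A - h * E)) / (2h)

Random.seed!(42)
A = randn(4, 4)   # deliberately non-symmetric primal
E = randn(4, 4)   # tangent direction
L = exp_frechet(A, E)
err = norm(L - fd_exp(A, E)) / norm(L)   # relative discrepancy
```

Note that both references are computed from fresh, unmodified copies of `A`; a check like this says nothing about what happens when the same primal is reused after an in-place call.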

Btw, in a fresh environment, your example errors on my machine in the for loop with:

ERROR: MethodError: no method matching iterate(::Nothing)

Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen})
   @ Base range.jl:880
  iterate(::Union{LinRange, StepRangeLen}, ::Integer)
   @ Base range.jl:880
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
   @ Base dict.jl:698
  ...

Stacktrace:
 [1] indexed_iterate(I::Nothing, i::Int64)
   @ Base ./tuple.jl:91
 [2] exp!(x1::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}})
   @ Main ~/.julia/packages/ForwardDiffChainRules/2Xt9G/src/ForwardDiffChainRules.jl:81
 [3] test_exp(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}})
   @ Main ./REPL[4]:3
 [4] chunk_mode_gradient(f::typeof(test_exp), x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:123
 [5] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}}, ::Val{true})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:21
 [6] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:17
 [7] gradient(f::Function, x::Vector{Float64})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:17
 [8] top-level scope
   @ ./REPL[5]:4

from chainrules.jl.

baggepinnen commented on September 25, 2024

This only appears to be a problem for non-symmetric inputs to exp!. When I create a symmetric matrix with X = X'X and switch to FiniteDifferences.jl for more accurate testing, it works just fine. Non-symmetric input matrices are still problematic, though.

using LinearAlgebra, ForwardDiff, ForwardDiffChainRules, FiniteDifferences

# Route exp! on Dual-valued matrices through the ChainRules frule
@ForwardDiff_frule LinearAlgebra.exp!(x1::AbstractMatrix{<:ForwardDiff.Dual})

function test_exp(x)
    X = copy(reshape(x, 4, 4))  # copy, since exp! overwrites its argument
    X2 = LinearAlgebra.exp!(X)
    sum(X2)
end

for i = 1:20
    X = randn(4,4)
    X = X'X                     # symmetric input
    x = vec(X)
    g1 = ForwardDiff.gradient(test_exp, x)                           # forward mode via the frule
    g2 = FiniteDifferences.grad(central_fdm(5, 1), test_exp, x)[1]  # finite-difference reference
    @show norm(g1-g2)
end
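Because `test_exp` is just `sum` of the matrix exponential, an exact reference gradient is available without AD or finite differences, which may help separate rule errors from finite-difference noise on non-symmetric inputs. The sketch below assumes the identity ∇ sum(exp(X)) = L(Xᵀ, J), where L is the Fréchet derivative of exp and J is the all-ones matrix; `exp_frechet`, `ref_grad_sum_exp`, `fd_grad` and `sum_exp` are hypothetical names, and only the LinearAlgebra stdlib is used.

```julia
using LinearAlgebra, Random

# Fréchet derivative L(A, E) of exp via the block-triangular identity
function exp_frechet(A::AbstractMatrix, E::AbstractMatrix)
    n = size(A, 1)
    return exp([A E; zero(A) A])[1:n, n+1:2n]
end

# Exact gradient of x -> sum(exp(reshape(x, n, n))):
# d sum(exp(X)) = <J, L(X, dX)> = <L(X', J), dX>, with J = ones(n, n)
function ref_grad_sum_exp(x::AbstractVector, n::Integer)
    X = reshape(x, n, n)
    return vec(exp_frechet(Matrix(X'), ones(n, n)))
end

# Central differences, one coordinate at a time
function fd_grad(f, x::AbstractVector; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

sum_exp(x) = sum(exp(reshape(x, 4, 4)))  # same value as test_exp, without exp!

Random.seed!(1)
x = randn(16)                   # non-symmetric 4×4 input, column-major
g_ref = ref_grad_sum_exp(x, 4)
g_fd = fd_grad(sum_exp, x)
```

`g_ref` uses the same column-major `vec` layout as `x`, so it could also be compared entry by entry with `ForwardDiff.gradient(test_exp, x)`.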

I've also tested the reverse rule using Zygote, and there is no problem in reverse mode.

oxinabox commented on September 25, 2024

@sethaxen do you think you might have time to look into this?

baggepinnen commented on September 25, 2024

I think the problem is related to how ForwardDiffChainRules deals with (or rather, doesn't deal with) the fact that exp! mutates its input argument: if I add a call to copy on the input argument before each invocation of the frule, I get the correct results. So this is probably an issue with ForwardDiffChainRules.

sethaxen commented on September 25, 2024

Sounds right. This line calls frule repeatedly on the same primals, so it assumes the function is nonmutating: https://github.com/ThummeTo/ForwardDiffChainRules.jl/blob/d70301a28f61250c3168446c4b147b195ceee117/src/ForwardDiffChainRules.jl#L88
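The failure mode can be reproduced without ForwardDiff or ForwardDiffChainRules. In the sketch below, `mutating_exp_frule` is a hypothetical stand-in for a ChainRules-style frule of a mutating function (it overwrites `A` with `exp(A)` and returns `(primal, tangent)`), `pushforwards_naive` mimics calling it once per tangent direction on the same primal, and `pushforwards_copy` applies the copy workaround described above. None of these names come from either package.

```julia
using LinearAlgebra

# Hypothetical stand-in for a ChainRules-style frule of a mutating function:
# returns (primal, tangent) and, like exp!, overwrites its input A with exp(A).
function mutating_exp_frule(ΔA::AbstractMatrix, A::AbstractMatrix)
    n = size(A, 1)
    ΔΩ = exp([A ΔA; zero(A) A])[1:n, n+1:2n]  # Fréchet derivative L(A, ΔA)
    A .= exp(A)                               # in-place side effect
    return A, ΔΩ
end

# Naive driver: one frule call per tangent direction, all on the same primal.
# It implicitly assumes the function is nonmutating, so calls after the first see exp(A).
pushforwards_naive(A, ΔAs) = [mutating_exp_frule(ΔA, A)[2] for ΔA in ΔAs]

# Workaround: copy the primal before every call.
pushforwards_copy(A, ΔAs) = [mutating_exp_frule(ΔA, copy(A))[2] for ΔA in ΔAs]

A = [0.0 1.0; -2.0 0.3]
ΔAs = [[1.0 0.0; 0.0 0.0], [0.0 0.0; 1.0 0.0]]
good = pushforwards_copy(A, ΔAs)        # A is left untouched
bad = pushforwards_naive(copy(A), ΔAs)  # pass a copy so A survives for comparison
```

Only the first naive tangent is correct, because every later call differentiates at `exp(A)` instead of `A`. Copying costs one allocation per direction but makes each tangent independent of call order.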
