Comments (5)
Yes, I can look into this.
@baggepinnen I'm not familiar with ForwardDiffChainRules. Can you provide an MWE that exhibits the observed failure using just ChainRules?
Using our own testing machinery, I am unable to observe any failures on 1000x the number of random matrices:
julia> using ChainRules, ChainRulesTestUtils, LinearAlgebra, Random, Test

julia> Random.seed!(42);

julia> @testset "exp!" begin
           Xs = (randn(4, 4) for _ in 1:20_000)
           @testset for X in Xs
               test_frule(LinearAlgebra.exp!, X)
           end
       end;
Test Summary: |  Pass  Total   Time
exp!          | 80000  80000  13.2s
Btw, in a fresh environment, your example errors on my machine in the for loop with:
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
iterate(::Union{LinRange, StepRangeLen})
@ Base range.jl:880
iterate(::Union{LinRange, StepRangeLen}, ::Integer)
@ Base range.jl:880
iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
@ Base dict.jl:698
...
Stacktrace:
[1] indexed_iterate(I::Nothing, i::Int64)
@ Base ./tuple.jl:91
[2] exp!(x1::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}})
@ Main ~/.julia/packages/ForwardDiffChainRules/2Xt9G/src/ForwardDiffChainRules.jl:81
[3] test_exp(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}})
@ Main ./REPL[4]:3
[4] chunk_mode_gradient(f::typeof(test_exp), x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}})
@ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:123
[5] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}}, ::Val{true})
@ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:21
[6] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}})
@ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:17
[7] gradient(f::Function, x::Vector{Float64})
@ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:17
[8] top-level scope
@ ./REPL[5]:4
from chainrules.jl.
This only appears to be a problem for the non-symmetric version of `exp!`. When I create a symmetric matrix `X = X'X` and switch to FiniteDifferences.jl for more accurate testing, it works just fine. The non-symmetric input matrix is still problematic, though.
using LinearAlgebra, ForwardDiff, ForwardDiffChainRules, FiniteDifferences

# Let ForwardDiff use the ChainRules frule for exp! on Dual-valued matrices
@ForwardDiff_frule LinearAlgebra.exp!(x1::AbstractMatrix{<:ForwardDiff.Dual})

function test_exp(x)
    X = copy(reshape(x, 4, 4))   # copy so exp! does not touch the caller's data
    X2 = LinearAlgebra.exp!(X)
    sum(X2)
end

for i = 1:20
    X = randn(4, 4)
    X = X'X                      # symmetric input; drop this line for the non-symmetric case
    x = vec(X)
    g1 = ForwardDiff.gradient(test_exp, x)
    g2 = FiniteDifferences.grad(central_fdm(5, 1), test_exp, x)[1]
    @show norm(g1 - g2)
end
I've also tested the reverse rule using Zygote, and there is no problem in reverse mode.
@sethaxen do you think you might have time to look into this?
I think the problem is related to how ForwardDiffChainRules deals with (or rather, doesn't deal with) the fact that `exp!` mutates its input argument. By adding a call to `copy` on the input argument before each invocation of the `frule`, I get the correct results. This is probably an issue with ForwardDiffChainRules then.
Sounds right. This line calls `frule` repeatedly on the same primals, so it assumes the function is nonmutating: https://github.com/ThummeTo/ForwardDiffChainRules.jl/blob/d70301a28f61250c3168446c4b147b195ceee117/src/ForwardDiffChainRules.jl#L88
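For anyone following along, here is a toy sketch of that failure mode. It does not use the actual ChainRules or ForwardDiffChainRules code: `square!` and `toy_frule` are made-up names standing in for a mutating function and a hand-written forward rule. It only shows why reusing one primal across repeated rule calls can go wrong, and how a defensive `copy` avoids it.

```julia
# Toy stand-in for a mutating function: squares x in place.
function square!(x::Vector{Float64})
    x .= x .^ 2
    return x
end

# Hypothetical hand-written forward rule (NOT the ChainRules API).
# Returns the primal output and the tangent 2x .* dx.
function toy_frule(dx::Vector{Float64}, ::typeof(square!), x::Vector{Float64})
    dy = 2 .* x .* dx   # tangent, computed from the primal before it is overwritten
    y = square!(x)      # overwrites x with x.^2
    return y, dy
end

x = [3.0]
seeds = [[1.0], [1.0]]  # two identical tangent seeds, so both tangents should match

# Reusing the same primal: the second call sees x_shared == [9.0].
x_shared = copy(x)
shared = [toy_frule(dx, square!, x_shared)[2] for dx in seeds]

# Copying the primal before each call keeps every call at x == [3.0].
copied = [toy_frule(dx, square!, copy(x))[2] for dx in seeds]

println("shared primal: ", shared)
println("copied primal: ", copied)
```

With the shared primal, the second tangent is evaluated at the already-squared value, so the two results disagree. With a fresh copy per call they agree. The same idea presumably applies to calling the `exp!` rule once per tangent, as long as each call gets its own copy of the input.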