juliadiff/ChainRules.jl: forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs. License: Other.
In porting Nabla's reverse mode sensitivity definitions to ChainRules, I've found that it's not particularly straightforward to port those which directly update a value on the tape. The equivalent operation in ChainRules appears to be accumulate!(value, rule, args...), which accumulates the result of rule(args...) to value. However, most rules are defined in terms of anonymous functions which are themselves defined within frule/rrule, which means we can't specialize accumulate! on particular rules in the outer scope.
I see a few options here:
Option 1: Move rule definitions which have updating counterparts to separate functions, so that one can write
dfdx(x) = something
rrule(f, x) = (f(x), Rule(dfdx))
accumulate!(value, ::Rule{typeof(dfdx)}, x) = whatever
One downside to this is that evaluation of the rule and accumulation of the result into an existing value can't share any intermediate steps, which could mean using more memory. Another downside is that it would require reshuffling all of the current definitions outside of frule and rrule into a bunch of top-level named functions.
On the other hand, an upside to this is that we could precompile the functions defined at the top level, which would likely help with #35.
Option 2: Have Rules store a second function alongside the sensitivity propagation function which can be used for in-place updating. Then we would have

struct Rule{F<:Function,U<:Union{Function,Nothing}} <: AbstractRule
    f::F
    u::U
end

Rule(f) = Rule{typeof(f),Nothing}(f, nothing)

accumulate! would then use the u function for updating if it isn't nothing, otherwise it would fall back to the current definition in terms of materialize!. The function u would have a signature like u(value, args...), which accumulates f(args...) to value.
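A minimal self-contained sketch of what option 2 could look like. All names here are stand-ins (the real AbstractRule, accumulate!, and materialize! live in ChainRules), so this is an illustration of the dispatch pattern, not the actual implementation:

```julia
abstract type AbstractRule end

# A rule carrying an optional in-place updater `u` alongside the propagator `f`.
struct Rule{F<:Function,U<:Union{Function,Nothing}} <: AbstractRule
    f::F
    u::U
end
Rule(f) = Rule{typeof(f),Nothing}(f, nothing)

# No updater: fall back to evaluating the rule and adding the result
# (stand-in for the materialize!-based definition).
accumulate!(value, r::Rule{<:Any,Nothing}, args...) = (value .+= r.f(args...); value)
# Updater present: let `u` accumulate f(args...) into `value` in place.
accumulate!(value, r::Rule, args...) = r.u(value, args...)

double_rule = Rule(x -> 2 .* x, (v, x) -> (v .+= 2 .* x; v))
v = [1.0, 1.0]
accumulate!(v, double_rule, [1.0, 2.0])   # v is now [3.0, 5.0]
```

Dispatching on the `Nothing` type parameter means the fallback costs nothing when an updater is supplied.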
Option 3: Something else entirely.
I kind of like option 2, though 1 seems fine as well, if a bit suboptimal in terms of the work required to implement that change up front. Thoughts?
We have rules for mean
, except for mean(f, x; dims)
which is new as of Julia v1.3
New rules are introduced in #80
It would be good to have support for @fastmath
, especially the functions that ccall
libm.
Not sure if it's the right approach here, but in Zygote we did this by looping over the usual DiffRules
.
This is a bit of a blocker for FluxML/Zygote.jl#366, because we want to remove DiffRules as part of that, but currently doing so would cause a regression on this.
Currently it uses sgn, which is apparently not correct for complex inputs. Quoting Wikipedia:
"The complex absolute value function is continuous everywhere but complex differentiable nowhere because it violates the Cauchy–Riemann equations."
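To make the quote concrete: abs on a complex input is differentiable only in the real sense, with partials x/|z| and y/|z|. A quick self-contained finite-difference check (not ChainRules code):

```julia
f(z) = abs(z)
z = 1.6 - 0.8im
h = 1e-6

# Real-sense partial derivatives of |z| at z = x + iy:
dx = (f(z + h) - f(z - h)) / (2h)        # ≈ real(z) / abs(z)
dy = (f(z + h*im) - f(z - h*im)) / (2h)  # ≈ imag(z) / abs(z)
```

Note that for complex z, Julia's sign(z) is z/abs(z), a single complex number, which does not by itself encode these two real partials.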
In #23, we realized that while it's straightforward to test sensitivities for one-argument UnionAll constructors, e.g. Symmetric(X) and Diagonal(X), things can get more complicated when attempting to test constructors for concrete subtypes. The example in the linked PR was for Symmetric{T,M}, which requires a second argument when used as a constructor. We should find a way to make it easier to test such things, perhaps by refactoring rrule_test?
I see in https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/LinearAlgebra/factorization.jl we use a fair bit of A', which is adjoint. In other places @simeonschaub has said we need to not use that, and to use transpose instead, so we don't recursively conjugate the complex numbers.
Is there a more reasonable way to statically determine whether a function has a rule (with a certain number of arguments) than the following?

@generated function hasfrule(f::F, ::Val{N}) where {F,N}
    args = fill(0.0, N)  # concrete placeholder arguments (zero(Real) has no method)
    ChainRules.frule(Core.Compiler.singleton_type(f), args...) === nothing ? :false : :true
end
Because currently, you return nothing as the fallback, so I see no other possibility than just calling the function. I imagined there might be something generated by the rule definition macro, but found nothing. Or is there an important reason not to (be able to) perform such a check?
There are currently instances where we want to use arrays of zeros in which size information is retained. This is not something supported by the Zero differential, but it is supported by FillArrays.jl's Zeros type. It would be nice to use this here, but are there perhaps issues with it not defining setindex!?
Related to #15
Quoting Will in #29:
FWIW, the other thing to think about is what is actually happening computationally under the hood. Ultimately the Diagonal matrix type doesn't use any off-diagonal elements when used in e.g. a matrix-matrix multiply - the Diagonal type simply doesn't allow you to have non-zero off-diagonal elements, so it's a slightly odd question to ask what happens if you perturb the off-diagonals by an infinitesimal amount (i.e. compute the gradient w.r.t. them). It's this slightly weird situation in which thinking about a Diagonal matrix as a regular dense matrix that happens to contain zeros on its off-diagonals isn't really faithful to the semantics of the type (not sure if I've really phrased that correctly, but hopefully the gist is clear).
So as to match the structure used for StdLibs, following the pattern of LinearAlgebra.
If the function being broadcasted is a closure, then the derivative w.r.t. that function is not DNE, and for now we should maybe just bail out, by adding something like the following to the start of each rrule and frule:

fieldcount(typeof(f)) > 0 && return nothing
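The fieldcount check works because closures are lowered to structs whose fields are the captured variables, while plain top-level functions are singleton types with no fields. A standalone illustration:

```julia
plain(x) = x + 1                 # top-level function: singleton type, no fields
make_adder(a) = x -> x + a
adder = make_adder(2)            # closure: captures `a` as a field

fieldcount(typeof(plain))        # 0 - safe to treat the derivative w.r.t. f as DNE
fieldcount(typeof(adder))        # 1 - has differentiable state, so bail out
```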
The current implementation of broadcast assumes that the function being broadcasted doesn't contain any differentiable bits, and that we can therefore safely assume that there is no gradient information to be associated with it. It also assumes that the forwards-inside-reverse-mode trick is the correct choice for implementing the adjoint, which isn't necessarily the case.
Presumably this implementation is a placeholder, however, it will definitely be necessary to relax the above assumptions before e.g. Zygote is able to adopt ChainRules, so I believe it should be addressed as a priority.
Zygote.jl has some definitions for broadcasted, such as this:
@adjoint function broadcasted(::typeof(tanh), x::Numeric)
y = tanh.(x)
y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2))
end
Notice broadcasted, not broadcast. It's converting a lazy broadcast to an eager one to store the temporary values.
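The broadcasted-vs-broadcast distinction can be seen directly in Base (no Zygote needed): broadcasted builds a lazy Broadcasted object, and materialize makes it eager:

```julia
xs = [0.0, 1.0]
bc = Broadcast.broadcasted(tanh, xs)   # lazy Broadcast.Broadcasted object, nothing computed yet
y  = Broadcast.materialize(bc)         # eager: same result as tanh.(xs)
```

Defining the rule on broadcasted (rather than broadcast) is what lets it intercept the fused expression before materialization.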
I'm wondering if this kind of rule should be ported to ChainRules.jl. I think the way ChainRulesCore.jl expresses the rules cannot be used to auto-generate this kind of specialization (both before and after JuliaDiff/ChainRulesCore.jl#30). However, implementing this rule in ChainRules.jl means that AD engines cannot choose whether or not to use it.
Does it make sense to have it in ChainRules.jl? Or should it be done for each AD implementation?
(Maybe related to @willtebbutt's question here #12 (comment) ?)
Julia 1.1.0 - is this expected?
(v1.1) pkg> add ChainRules.jl
Updating registry at `C:\Users\awf\.julia\registries\General`
Updating git-repo `https://github.com/JuliaRegistries/General.git`
Resolving package versions...
Installed ChainRulesCore ─ v0.2.0
Installed FFTW ─────────── v1.0.1
Installed FillArrays ───── v0.7.2
Installed ChainRules ───── v0.1.1
Installed DataStructures ─ v0.17.1
Updating `C:\Users\awf\.julia\environments\v1.1\Project.toml`
[082447d4] + ChainRules v0.1.1
Updating `C:\Users\awf\.julia\environments\v1.1\Manifest.toml`
[082447d4] + ChainRules v0.1.1
[d360d2e6] + ChainRulesCore v0.2.0
[864edb3b] ↑ DataStructures v0.17.0 ⇒ v0.17.1
[7a1cc6ca] ↑ FFTW v0.3.0 ⇒ v1.0.1
[1a297f60] ↑ FillArrays v0.7.0 ⇒ v0.7.2
Building FFTW → `C:\Users\awf\.julia\packages\FFTW\MJ7kl\deps\build.log`
julia> using ChainRules
[ Info: Precompiling ChainRules [082447d4-558c-5d27-93f4-14fc19e9eca2]
┌ Warning: Error requiring NaNMath from ChainRules:
│ LoadError: ArgumentError: Package ChainRules does not have NaNMath in its dependencies:
│ - If you have ChainRules checked out for development and have
│ added NaNMath as a dependency but haven't updated your primary
│ environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ChainRules
│ Stacktrace:
│ [1] require(::Module, ::Symbol) at .\loading.jl:836
│ [2] include at .\boot.jl:326 [inlined]
Not really a new rule, just a better one?
https://papers.nips.cc/paper/8579-backpropagation-friendly-eigendecomposition.pdf
Probably means just renaming lgamma to loggamma, but I'm not 100% sure that will get things right for both Complex and Real values.
See JuliaMath/SpecialFunctions.jl#156
Here's the noisy warning I currently see (when running tests):
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = rrule(::typeof(lgamma), ::Int64) at rule_definition_tools.jl:104
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:104
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = #test_scalar#3(::Float64, ::Float64, ::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(test_scalar), ::typeof(lgamma), ::Int64) at test_util.jl:27
└ @ Main ~/projects/autodiff/dev/ChainRules/test/test_util.jl:27
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = fdm(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::typeof(lgamma), ::Int64, ::Val{true}) at methods.jl:222
└ @ FiniteDifferences ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:222
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Int64,Array{Int64,1},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Int64,UnitRange{Int64},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = frule(::typeof(lgamma), ::Int64) at rule_definition_tools.jl:99
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:99
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = rrule(::typeof(lgamma), ::Float64) at rule_definition_tools.jl:104
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:104
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = #test_scalar#3(::Float64, ::Float64, ::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(test_scalar), ::typeof(lgamma), ::Float64) at test_util.jl:27
└ @ Main ~/projects/autodiff/dev/ChainRules/test/test_util.jl:27
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = fdm(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::typeof(lgamma), ::Float64, ::Val{true}) at methods.jl:222
└ @ FiniteDifferences ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:222
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Float64,Array{Int64,1},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Float64,UnitRange{Int64},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Float64,UnitRange{Int64},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Float64,UnitRange{Int64},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│ caller = frule(::typeof(lgamma), ::Float64) at rule_definition_tools.jl:99
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:99
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = rrule(::typeof(lgamma), ::Complex{Float64}) at rule_definition_tools.jl:104
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:104
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #test_scalar#3(::Float64, ::Float64, ::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(test_scalar), ::typeof(lgamma), ::Complex{Float64}) at test_util.jl:27
└ @ Main ~/projects/autodiff/dev/ChainRules/test/test_util.jl:27
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│ caller = frule(::typeof(lgamma), ::Complex{Float64}) at rule_definition_tools.jl:99
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:99
For functions like map(f, xs), if we don't have an rrule/frule defined for f, we should bail out early by returning nothing, or just not define the rrule/frule at all.
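A toy, self-contained sketch of the "return nothing to bail out" pattern. The real rrule generic function and pullback signatures live in ChainRulesCore; all names here are illustrative:

```julia
toy_rrule(f, x) = nothing                                 # fallback: no rule known
toy_rrule(::typeof(sin), x) = (sin(x), dy -> dy * cos(x)) # a registered rule

function toy_map_rrule(f, xs)
    toy_rrule(f, first(xs)) === nothing && return nothing  # no rule for f: bail out early
    ys_pbs = map(x -> toy_rrule(f, x), xs)                 # (primal, pullback) pairs
    ys = first.(ys_pbs)
    map_pullback(dys) = map((yp, dy) -> last(yp)(dy), ys_pbs, dys)
    return ys, map_pullback
end
```

Here toy_map_rrule(sin, xs) returns a result, while toy_map_rrule(x -> x^2, xs) returns nothing because the anonymous function hits the fallback.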
Since checkpointing can be implemented as a sensitivity,
it may very well belong here.
See:
Many repos have a roadmap to v1. This plan goes a little further. The idea is:
AbstractDifferential, which conflates the idea of scaling and adding.
This roadmap needs some refining, but this was the big picture @willtebbutt and I were talking about.
It would brighten up the docs to have a nice logo.
This is what I have so far.
using Luxor
using Random
const bridge_len = 50
function chain(jiggle=0)
    shaky_rotate(θ) = rotate(θ + jiggle*(rand()-0.5))

    ### 1
    shaky_rotate(0)
    sethue(Luxor.julia_red)
    translate(0, -150)
    link()
    m1 = getmatrix()

    ### 2
    sethue(Luxor.julia_green)
    translate(-50, 130)
    shaky_rotate(π/3)
    link()
    m2 = getmatrix()
    setmatrix(m1)
    sethue(Luxor.julia_red)
    sector(Point(0, bridge_len), 50, 90, 0, -1.3π, :fill)
    setmatrix(m2)

    ### 3
    shaky_rotate(-π/3)
    translate(-120, 80)
    sethue(Luxor.julia_purple)
    link()
    setmatrix(m2)
    setcolor(Luxor.julia_green)
    sector(Point(0, bridge_len), 50, 90, -0.0, -1.5π, :fill)
end

function link()
    sector(50, 90, π, 0, :fill)
    sector(Point(0, bridge_len), 50, 90, 0, -π, :fill)
    sector(54, 90, π, 0, :fill)
    sector(Point(0, bridge_len), 50, 90, 0, -π, :fill)
    rect(50, -3, 40, bridge_len + 6, :fill)
    rect(-50 - 40, -3, 40, bridge_len + 6, :fill)
end
Random.seed!(1)
@png begin
background("black")
chain(0.5)
end
It appears there are adjoint rules for FFT
in Zygote here:
FluxML/Zygote.jl#215
As found for hypot in #74, we are not testing some of our simpler scalar rules with FiniteDifferences. We have easy helpers for doing that in test/test_utils.jl, but we are just not using them.
Lots has changed since the docs were first written. #152 addresses a number of things, but there are a few more things that we might want to consider:
Additionally, the section of the docs on Differentials needs to be expanded / modified to make the following points:
Found when doing #76.

asec:
x = 1.6 - 0.8im
(central_fdm(5, 1))(asec, x) = 0.16078185380606091 + 0.29837305540037473im
(last(frule(asec, x)))(1) = 0.27724414876919296 + 0.19496914288010564im
WolframAlpha says 0.160782 + 0.298373im, agreeing with central_fdm.

acsc:
x = 1.6 - 0.8im
(central_fdm(5, 1))(acsc, x) = -0.16078185380578297 - 0.29837305540024617im
(last(frule(acsc, x)))(1) = -0.27724414876919296 - 0.19496914288010564im
WolframAlpha says -0.160782 - 0.298373im, agreeing with central_fdm.

@simeonschaub do you know what is going on here? It has been years since I did complex calculus.
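For what it's worth, this looks like a real-axis formula being applied off the real line, which can be checked without ChainRules: the real-domain derivative d/dx asec(x) = 1/(|x|·sqrt(x^2 - 1)) numerically reproduces the frule output above, while the branch-correct complex form 1/(x^2·sqrt(1 - x^-2)) reproduces the central_fdm/WolframAlpha value:

```julia
dasec_complex(x)  = 1 / (x^2 * sqrt(1 - x^-2))    # valid for complex arguments
dasec_realaxis(x) = 1 / (abs(x) * sqrt(x^2 - 1))  # only valid on the real line

x = 1.6 - 0.8im
dasec_complex(x)    # ≈ 0.160782 + 0.298373im  (matches central_fdm / WolframAlpha)
dasec_realaxis(x)   # ≈ 0.277244 + 0.194969im  (matches the frule output above)
```

The acsc numbers fit the same pattern, since its derivative is the negative of asec's.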
Most of the functions in https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/LinearAlgebra/dense.jl don't belong there, as they are loaded even if LinearAlgebra isn't. Right now we don't have a Requires.jl block around that, so it doesn't matter, but conceptually we could. Also, it is surprising to have to go look for them there; I know it has caught people out before.
When this issue is closed: JuliaDiff/ChainRulesTestUtils.jl#2, ChainRulesTestUtils.jl will be a registered package. We can then remove this file and introduce ChainRulesTestUtils.jl as a testing package dependency.
If I am correct, the rrule for getfield is always:

function rrule(::typeof(getfield), val::P, field::Symbol) where P
    getfield_pullback(dy) = NO_FIELDS, Composite{P}(; (field => dy,)...)
    return getfield(val, field), getfield_pullback
end

So we don't need that dependency; we can keep this as a lightweight recipe function.
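The keyword-splat in that pullback, building keyword arguments from a runtime Symbol, can be checked standalone with a toy stand-in for Composite:

```julia
# Toy stand-in: collect keyword arguments into a NamedTuple.
toy_composite(; kwargs...) = (; kwargs...)

field = :a
dy = 1.5
nt = toy_composite(; (field => dy,)...)   # splatting a tuple of Pairs as kwargs
```

Splatting an iterable of Symbol => value pairs into a keyword position is supported Julia syntax, so the runtime field name flows through unchanged.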
One of the benefits of ChainRules' design is that rules for multiple arguments can share intermediate computations, by virtue of defining variables outside of the individual Rules and then capturing them in the wrapped closures. However, this approach likely incurs the infamous JuliaLang/julia#15276. To get around this, we could potentially define a macro that does a Rule definition but scans the closure expression for uses of variables not defined therein, wrapping those in let. That would be heinously hacky, but might buy us some performance improvements.
Seems the last set of changes to the docs are not showing up on
http://www.juliadiff.org/ChainRules.jl/dev/
Shouldn't the derivative of conj with a complex argument be Wirtinger(Zero(), One())? Instead I get this:
julia> w, dw = frule(conj, 1+im)
(1 - 1im, ChainRules.WirtingerRule{ChainRules.Rule{getfield(ChainRules, Symbol("##304#308"))},ChainRules.Rule{getfield(ChainRules, Symbol("##305#309"))}}(ChainRules.Rule{getfield(ChainRules, Symbol("##304#308"))}(getfield(ChainRules, Symbol("##304#308"))(Core.Box(One()))), ChainRules.Rule{getfield(ChainRules, Symbol("##305#309"))}(getfield(ChainRules, Symbol("##305#309"))(Core.Box(One())))))
julia> dw(One())
One()
julia> dw(1 + 0im)
ChainRules.Wirtinger{Complex{Int64},ChainRules.Zero}(1 + 0im, ChainRules.Zero())
In #44, I directly ported over code from Nabla, which contains the following comment:
See [1] for implementation details: pages 5-9 in particular. The derivations presented in
[1] assume column-major layout, whereas Julia primarily uses row-major. We therefore
implement both the derivations in [1] and their transpose, which is more appropriate to
Julia.[1] - "Differentiation of the Cholesky decomposition", Murray 2016
Julia actually uses column-major layout for its arrays. I haven't dug into the code to determine whether the implementation matches the comment, i.e. assumes row-major, but it would be worthwhile to do so to ensure we're not being unnecessarily inefficient.
We currently have some deployed docs at http://www.juliadiff.org/ChainRules.jl/latest/
crying out for "getting started" documentation.
Currently 4 sections are suggested
Adding even one section would be a great start.
Lots of things have docstrings now... but adding more would also be helpful!
Related: #198
Some of the code used zero(X), or @thunk(zero(X)). I feel like if Zero is working right, we should be able to replace that with Zero().
E.g.
https://github.com/JuliaDiff/ChainRules.jl/blob/master/test/rules/linalg/factorization.jl
ChainRules.TestUtils, so people can use this functionality in their own tests.

Like for #153, there is a generic frule for default constructors. It relies on the arguments being the same as the fields, so it would have to detect that. It's basically:
function frule(::Type{P}, args..., _, dargs...) where P  # pseudocode: a real signature allows only one slurp
    y = P(args...)
    dy = Composite{P}(; zip(fieldnames(P), dargs)...)
    return y, dy
end
the more serious implementation might be:
function frule(::Type{P}, all_args...) where P
    nargs = fieldcount(P)
    length(all_args) == 2nargs + 1 || return nothing
    args = @inbounds all_args[1:nargs]
    all(typeof.(args) .== fieldtypes(P)) || return nothing
    dargs = @inbounds all_args[end-nargs+1:end]  # skip dself, as constructors are never functors
    y = P(args...)
    dy = Composite{P}(; zip(fieldnames(P), dargs)...)
    return y, dy
end
But this may well need to be written as a generated function if it doesn't constant fold well. Talked to @jrevels about this a few months back, and he was fine with it.
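The fieldnames/zip keyword construction in the sketches above amounts to building a NamedTuple keyed by the field names. With a hypothetical two-field struct:

```julia
struct ToyPoint
    x::Float64
    y::Float64
end

dargs = (1.0, 2.0)
# zip pairs each field name with its tangent, and the kwarg splat
# assembles them into a NamedTuple: (x = 1.0, y = 2.0)
nt = (; zip(fieldnames(ToyPoint), dargs)...)
```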
Right now the code is roughly in the intersection of what YASGuide and BlueStyle allow.
Which is pretty large as they overlap significantly.
But given most people maintaining it now are used to BlueStyle, we should use that.
#80 doesn't follow BlueStyle indenting, and that's fine, because the readme says YASGuide. But we should change that.
A PkgEval run for a PR which changes the generated numbers for randn!
indicates that the tests of this package might fail in Julia 1.5 (and on Julia current master). Apologies if this is a false positive.
cf.
https://github.com/JuliaCI/NanosoldierReports/blob/7de24e455342298cbef56826b5827f0d7640d2c1/pkgeval/by_hash/b89e35c_vs_098ef24/logs/ChainRules/1.5.0-DEV-71a4a114c2.log
We have a lot of scalar rules. A few I know are still missing:
- ldexp (WolframAlpha has the partials)
- clamp (see https://discuss.pytorch.org/t/torch-round-gradient/28628/3)
- polygamma
- round: Zero everywhere it is differentiable. Idk if we want to close over the point gaps (we do that for abs). See https://discuss.pytorch.org/t/torch-round-gradient/28628/3
- floor
- ceil
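For the piecewise functions on that list, the usual AD convention (as in the linked PyTorch thread) is a subgradient-style rule. A sketch of the derivative values, not the ChainRules rule API:

```julia
# clamp passes the gradient through inside [lo, hi] and blocks it outside.
dclamp(x, lo, hi) = lo <= x <= hi ? one(x) : zero(x)

# round/floor/ceil are flat almost everywhere, so their derivative is zero
# (ignoring the jump points, where it is undefined).
dround(x) = zero(x)
```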
Composite was made for this case, but we haven't changed it over, and it still uses a hack.
We have U, S, and V, but not Vt.
See
I cannot reproduce it locally. @oxinabox do you mind taking a look?