
chainrules.jl's People

Contributors

alexrobson, andreasnoack, ararslan, carlolucibello, chrisrackauckas, cossio, dependabot[bot], devmotion, dfdx, eloceanografo, frankschae, github-actions[bot], gxyd, jrevels, keno, masonprotter, mattbrzezinski, mcabbott, moelf, mzgubic, nickrobinson251, niklasschmitz, nmheim, oxinabox, rainerheintzmann, rofinn, sethaxen, simeonschaub, willtebbutt, yingboma


chainrules.jl's Issues

Customizing in-place accumulation

In porting Nabla's reverse mode sensitivity definitions to ChainRules, I've found that it's not particularly straightforward to port those which directly update a value on the tape. The equivalent operation in ChainRules appears to be accumulate!(value, rule, args...), which accumulates the result of rule(args...) to value. However, most rules are defined in terms of anonymous functions which are themselves defined within frule/rrule, which means we can't specialize accumulate! on particular rules in the outer scope.

I see a few options here:

  1. Move rule definitions which have updating counterparts to separate functions, so that one can write

    dfdx(x) = something
    rrule(f, x) = (f(x), Rule(dfdx))
    accumulate!(value, ::Rule{typeof(dfdx)}, x) = whatever

    One downside to this is that evaluation of the rule and accumulation of the result into an existing value can't share any intermediate steps, which could mean using more memory. Another downside is that it would require moving all of the current definitions out of frule and rrule into a bunch of top-level named functions.

    On the other hand, an upside to this is that we could precompile the functions defined at the top level, which would likely help with #35.

  2. Have Rules store a second function alongside the sensitivity propagation function which can be used for in-place updating. Then we would have

    struct Rule{F<:Function,U<:Union{Function,Nothing}} <: AbstractRule
        f::F
        u::U
    end
    Rule(f) = Rule{typeof(f),Nothing}(f, nothing)

    accumulate! would then use the u function for updating if it isn't nothing, otherwise it would fall back to the current definition in terms of materialize!. The function u would have a signature like u(value, args...) which accumulates f(args...) to value.

  3. Something else entirely.

I kind of like option 2, though 1 seems fine as well, if a bit suboptimal in terms of the work required to implement that change up front. Thoughts?
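To make option 2 concrete, here is a hedged sketch (not part of the issue) of how accumulate! could dispatch on the extra field; the Nothing case would keep the current materialize!-based fallback unchanged:

function accumulate!(value, rule::Rule{F,U}, args...) where {F,U<:Function}
    return rule.u(value, args...)   # in-place update supplied by the rule author
end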

FastMath operations

It would be good to have support for @fastmath, especially the functions that ccall libm.

Not sure if it's the right approach here, but in Zygote we did this by looping over the usual DiffRules.
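For what it's worth, here is a hedged sketch (not from the issue) of what such a loop could look like on the ChainRules side; the `_fast` naming, the restriction to unary Base rules, and the current frule signature are all assumptions:

using DiffRules
import ChainRulesCore

for (M, f, arity) in DiffRules.diffrules()
    (M === :Base && arity == 1) || continue        # keep the sketch to unary Base rules
    fast = Symbol(f, :_fast)                       # e.g. :sin -> :sin_fast
    isdefined(Base.FastMath, fast) || continue     # skip functions with no fastmath counterpart
    ∂f = DiffRules.diffrule(M, f, :x)              # derivative expression in terms of `x`
    @eval function ChainRulesCore.frule((_, Δx), ::typeof(Base.FastMath.$fast), x::Number)
        Ω = Base.FastMath.$fast(x)
        return Ω, $∂f * Δx
    end
end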

This is a bit of a blocker for FluxML/Zygote.jl#366, because we want to remove DiffRules as part of that, but currently doing so would cause a regression on this.

Test concrete type constructors in addition to UnionAlls

In #23, we realized that while it's straightforward to test sensitivities for one-argument UnionAll constructors, e.g. Symmetric(X) and Diagonal(X), things can get more complicated when attempting to test constructors for concrete subtypes. The example in the linked PR was for Symmetric{T,M}, which requires a second argument when used as a constructor. We should find a way to make it easier to test such things, perhaps by refactoring rrule_test?
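For concreteness, a hedged example (not from the issue) of the two constructor forms; the concrete constructor needs the extra uplo argument:

using LinearAlgebra
X = [1.0 2.0; 2.0 3.0]
Symmetric(X)                                 # one-argument UnionAll constructor
Symmetric{Float64,Matrix{Float64}}(X, 'U')   # concrete constructor requires a second argument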

Statically determine whether rule exists

Is there a more reasonable way to statically determine whether a function has a rule (with a certain number of arguments) than the following:

@generated function hasfrule(f::F, ::Val{N}) where {F, N}
    args = fill(zero(Real), N)
    ChainRules.frule(Core.Compiler.singleton_type(f), args...) === nothing ? :false : :true
end

Because currently the fallback returns nothing, I see no other possibility than actually calling the function. I imagined there might be something generated by the rule-definition macro, but found nothing.

Or is there an important reason not to (be able to) perform such a check?

FillArrays

There are currently cases where we want arrays of zeros that retain size information. This is not something supported by the Zero differential, but it is supported by FillArrays.jl's Zeros type. It would be nice to use that here, but are there perhaps issues with it not defining setindex!?
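A brief illustration (not from the issue), assuming FillArrays is available:

using FillArrays
z = Zeros(3, 4)      # lazy 3×4 array of zeros; size information is retained
z .+ ones(3, 4)      # broadcasts like any other array
# z[1, 1] = 5.0      # would error: Zeros is immutable and does not define setindex!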

Related to #15

Make sensitivities for structured matrix arguments structured

Quoting Will in #29:

FWIW, the other thing to think about is what is actually happening computationally under the hood. Ultimately the Diagonal matrix type doesn't use any off-diagonal elements when used in e.g. a matrix-matrix multiply - the Diagonal type simply doesn't allow you to have non-zero off-diagonal elements, so it's a slightly odd question to ask what happens if you perturb the off-diagonals by an infinitesimal amount (i.e. compute the gradient w.r.t. them).

It's this slightly weird situation in which thinking about a Diagonal matrix as a regular dense matrix that happens to contain zeros on its off-diagonals isn't really faithful to the semantics of the type (not sure if I've really phrased that correctly, but hopefully the gist is clear)

Broadcasting with differentiable functions (Remove Cast?)

The current implementation of broadcast assumes that the function being broadcasted doesn't contain any differentiable bits, and that we can therefore safely assume that there is no gradient information to be associated with it. It also assumes that the forwards-inside-reverse-mode trick is the correct choice for implementing the adjoint, which isn't necessarily the case.

Presumably this implementation is a placeholder, however, it will definitely be necessary to relax the above assumptions before e.g. Zygote is able to adopt ChainRules, so I believe it should be addressed as a priority.

Rules for turning lazy broadcasted to eager one?

Zygote.jl has some definitions for broadcasted such as this:

@adjoint function broadcasted(::typeof(tanh), x::Numeric)
  y = tanh.(x)
  y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2))
end

Note the ed: this is broadcasted, not broadcast. It converts a lazy broadcast into an eager one so that the temporary values can be stored.

I'm wondering if this kind of rule should be ported to ChainRules.jl. I think the way ChainRulesCore.jl expresses rules cannot be used to auto-generate this kind of specialization (both before and after JuliaDiff/ChainRulesCore.jl#30). However, implementing this rule in ChainRules.jl means that AD engines cannot choose whether or not to use it.
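For reference, a hedged sketch (not from the issue) of what a ChainRules-style counterpart might look like, written against the current pullback-style rrule API rather than the Rule-based API of the time:

using ChainRulesCore
import Base.Broadcast: broadcasted

function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(tanh), x::AbstractArray{<:Number})
    y = tanh.(x)
    # one cotangent each for `broadcasted`, `tanh`, and `x`
    broadcasted_tanh_pullback(ȳ) = (NoTangent(), NoTangent(), ȳ .* conj.(1 .- y .^ 2))
    return y, broadcasted_tanh_pullback
end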

Does it make sense to have it in ChainRules.jl? Or should it be done for each AD implementation?

(Maybe related to @willtebbutt's question here #12 (comment) ?)

Warning: Error requiring NaNMath/SpecialFunctions from ChainRules

Julia 1.1.0 - is this expected?

(v1.1) pkg> add ChainRules.jl
  Updating registry at `C:\Users\awf\.julia\registries\General`
  Updating git-repo `https://github.com/JuliaRegistries/General.git`
 Resolving package versions...
 Installed ChainRulesCore ─ v0.2.0
 Installed FFTW ─────────── v1.0.1
 Installed FillArrays ───── v0.7.2
 Installed ChainRules ───── v0.1.1
 Installed DataStructures ─ v0.17.1
  Updating `C:\Users\awf\.julia\environments\v1.1\Project.toml`
  [082447d4] + ChainRules v0.1.1
  Updating `C:\Users\awf\.julia\environments\v1.1\Manifest.toml`
  [082447d4] + ChainRules v0.1.1
  [d360d2e6] + ChainRulesCore v0.2.0
  [864edb3b] ↑ DataStructures v0.17.0 ⇒ v0.17.1
  [7a1cc6ca] ↑ FFTW v0.3.0 ⇒ v1.0.1
  [1a297f60] ↑ FillArrays v0.7.0 ⇒ v0.7.2
  Building FFTW → `C:\Users\awf\.julia\packages\FFTW\MJ7kl\deps\build.log`

julia> using ChainRules
[ Info: Precompiling ChainRules [082447d4-558c-5d27-93f4-14fc19e9eca2]
┌ Warning: Error requiring NaNMath from ChainRules:
│ LoadError: ArgumentError: Package ChainRules does not have NaNMath in its dependencies:
│ - If you have ChainRules checked out for development and have
│   added NaNMath as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ChainRules
│ Stacktrace:
│  [1] require(::Module, ::Symbol) at .\loading.jl:836
│  [2] include at .\boot.jl:326 [inlined]

Fixup SpecialFunctions `lgamma` deprecation

This probably just means renaming lgamma to loggamma, but I'm not 100% sure whether that will get things right for both Complex and Real values.

@scalar_rule(SpecialFunctions.lgamma(x), SpecialFunctions.digamma(x))

See JuliaMath/SpecialFunctions.jl#156
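A possible fix (a sketch, not from the issue): define the rule for the non-deprecated replacement, whose derivative is also digamma. Whether the Real case should instead go through logabsgamma is part of the open question above.

@scalar_rule(SpecialFunctions.loggamma(x), SpecialFunctions.digamma(x))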


Here's the noisy warning I currently see (when running tests):

 ┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = rrule(::typeof(lgamma), ::Int64) at rule_definition_tools.jl:104
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:104
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = #test_scalar#3(::Float64, ::Float64, ::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(test_scalar), ::typeof(lgamma), ::Int64) at test_util.jl:27
└ @ Main ~/projects/autodiff/dev/ChainRules/test/test_util.jl:27
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = fdm(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::typeof(lgamma), ::Int64, ::Val{true}) at methods.jl:222
└ @ FiniteDifferences ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:222
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Int64,Array{Int64,1},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Int64,UnitRange{Int64},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = frule(::typeof(lgamma), ::Int64) at rule_definition_tools.jl:99
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:99
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = rrule(::typeof(lgamma), ::Float64) at rule_definition_tools.jl:104
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:104
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = #test_scalar#3(::Float64, ::Float64, ::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(test_scalar), ::typeof(lgamma), ::Float64) at test_util.jl:27
└ @ Main ~/projects/autodiff/dev/ChainRules/test/test_util.jl:27
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = fdm(::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::typeof(lgamma), ::Float64, ::Val{true}) at methods.jl:222
└ @ FiniteDifferences ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:222
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Float64,Array{Int64,1},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = #20 at methods.jl:263 [inlined]
└ @ Core ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Float64,UnitRange{Int64},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Float64,UnitRange{Int64},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = _mapreduce(::FiniteDifferences.var"#20#22"{typeof(lgamma),Float64,UnitRange{Int64},Array{Float64,1},Float64}, ::typeof(Base.add_sum), ::IndexLinear, ::Base.OneTo{Int64}) at methods.jl:263
└ @ Base ~/.julia/packages/FiniteDifferences/kgtFk/src/methods.jl:263
┌ Warning: `lgamma(x::Real)` is deprecated, use `(logabsgamma(x))[1]` instead.
│   caller = frule(::typeof(lgamma), ::Float64) at rule_definition_tools.jl:99
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:99
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = rrule(::typeof(lgamma), ::Complex{Float64}) at rule_definition_tools.jl:104
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:104
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #test_scalar#3(::Float64, ::Float64, ::FiniteDifferences.Central{UnitRange{Int64},Array{Float64,1}}, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(test_scalar), ::typeof(lgamma), ::Complex{Float64}) at test_util.jl:27
└ @ Main ~/projects/autodiff/dev/ChainRules/test/test_util.jl:27
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #4 at test_util.jl:37 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:37
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = #5 at test_util.jl:38 [inlined]
└ @ Core ~/projects/autodiff/dev/ChainRules/test/test_util.jl:38
┌ Warning: `lgamma(x::Number)` is deprecated, use `loggamma(x)` instead.
│   caller = frule(::typeof(lgamma), ::Complex{Float64}) at rule_definition_tools.jl:99
└ @ ChainRules.SpecialFunctionsGlue ~/.julia/packages/ChainRulesCore/ZpveH/src/rule_definition_tools.jl:99

Roadmap to v2

Many repos have a roadmap to v1.
This plan goes a little further.

The idea is:

ChainRules/ChainRulesCore v1:

  • It pragmatically works.
  • It is completely usable, and is in use by multiple AutoDiff systems by the time we tag v1.
  • But it might play fast and loose with the math; for example, it will probably keep something like the current definition of AbstractDifferential, which conflates the ideas of scaling and adding.
  • It has some edge cases that we just bail out on and don't attempt to represent. This may include complex derivatives, and it will probably include mutation/non-pure functions.
  • It will support Julia 1.0

ChainRules/ChainRulesCore v2:

  • The math will be tighter
  • After a while of public use, we will have identified the edge cases that actually matter among the things we can't represent, and it will be enhanced to support those.
  • It may not support Julia 1.0, e.g. it might depend on features added in 1.3.

This roadmap needs some refining, but this was the big picture @willtebbutt and I were talking about.

Logo?

It would brighten up the docs to have a nice logo.

This is what I have so far.

using Luxor
using Random

const bridge_len = 50

function chain(jiggle=0)
    shaky_rotate(θ) = rotate(θ + jiggle*(rand()-0.5))
    
    ### 1
    shaky_rotate(0)
    sethue(Luxor.julia_red)
    translate(0,-150);
    link()
    m1 = getmatrix()
    
    
    ### 2
    sethue(Luxor.julia_green)
    translate(-50, 130);
    shaky_rotate(π/3); 
    link()
    m2 = getmatrix()
    
    setmatrix(m1)
    sethue(Luxor.julia_red)
    sector(Point(0, bridge_len), 50, 90, 0, -1.3π, :fill)
    setmatrix(m2)
    
    ### 3
    shaky_rotate(-π/3);
    translate(-120,80);
    sethue(Luxor.julia_purple)
    link()
    
    setmatrix(m2)
    setcolor(Luxor.julia_green)
    sector(Point(0, bridge_len), 50, 90, -0., -1.5π, :fill)
end


function link()
    sector(50, 90, π, 0, :fill)
    sector(Point(0, bridge_len), 50, 90, 0, -π, :fill)
    
    sector(54, 90, π, 0, :fill)
    sector(Point(0, bridge_len), 50, 90, 0, -π, :fill)
    
    rect(50,-3,40, bridge_len+6, :fill)
    rect(-50-40,-3,40, bridge_len+6, :fill)
end

Random.seed!(1)
@png begin
    background("black")
    chain(0.5)
end

logo showing 3 links in a chain

Changes to Docs

Lots has changed since the docs were first written. #152 addresses a number of things, but there are a few more things that we might want to consider:

  • changing all references to autodiff / automatic differentiation to AD / algorithmic differentiation, with a terminology box in the docs somewhere, explaining what we're on about.
  • In the "On writing good rrule and frule " bit, we should consider modifying the recommendations for Zero() or One(). In particular, removing any mention of One because it's not generally appropriate or helpful. Additionally, the "Write Tests" section should probably be entitled "Write Tests using FiniteDifferences". Similarly, the "CAS Systems are your friends" should have some reference to them not being helpful when writing tests for derivatives.

Additionally, the section of the docs on Differentials needs to be expanded / modified to make the following points:

  • + is the only operation guaranteed to be defined on differentials and, moreover, it's only guaranteed to be defined between differentials that are valid for the same primal type.
  • * is only guaranteed to be defined between scalars and differentials. You can't generally multiply a differential by another differential.
  • + between primals and differentials isn't always defined, but it is for a lot of interesting cases, e.g. Real, Matrix{<:Real}, etc.
  • stuff that I've missed that we've figured out since this section was last written.

Rules for derivative of asec and acsc with complex input incorrect?

Found when doing #76

asec

x = 1.6 - 0.8im
(central_fdm(5, 1))(asec, x) = 0.16078185380606091 + 0.29837305540037473im
(last(frule(asec, x)))(1) = 0.27724414876919296 + 0.19496914288010564im

WolframAlpha says 0.160782 + 0.298373 i
Agreeing with the central_fdm result.

acsc

x = 1.6 - 0.8im
(central_fdm(5, 1))(acsc, x) = -0.16078185380578297 - 0.29837305540024617im
(last(frule(acsc, x)))(1) = -0.27724414876919296 - 0.19496914288010564im

WolframAlpha says -0.160782 - 0.298373 i
Agreeing with the central_fdm result.
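A hedged check (not from the issue): the complex-analytic derivatives are d/dx asec(x) = 1 / (x^2 * sqrt(1 - x^-2)) and d/dx acsc(x) = -1 / (x^2 * sqrt(1 - x^-2)); evaluating the former at the point above reproduces the finite-difference value, which suggests the existing rules' use of abs(x) (valid only for real |x| > 1) is the culprit.

x = 1.6 - 0.8im
1 / (x^2 * sqrt(1 - x^-2))    # ≈ 0.160782 + 0.298373im, matching central_fdm and WolframAlpha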

@simeonschaub do you know what is going on here?
It has been years since I did complex calculus

generic `rrule` for all `getfield`

If I am correct, the rrule for getfield is always:

function rrule(::typeof(getfield), val::P, field::Symbol) where P
    function getfield_pullback(dy)
        # one cotangent each for getfield itself, `val`, and `field`
        return NO_FIELDS, Composite{P}(; (field => dy,)...), DoesNotExist()
    end
    return getfield(val, field), getfield_pullback
end
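A hedged usage sketch (not from the issue), using the same pre-1.0 ChainRulesCore vocabulary as above:

struct Point
    x::Float64
    y::Float64
end

p = Point(1.0, 2.0)
val, pullback = rrule(getfield, p, :x)
pullback(1.0)   # (NO_FIELDS, Composite{Point}(x = 1.0,), DoesNotExist())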

Audit performance of captured variables in closures

One of the benefits of ChainRules' design is that rules for multiple arguments can share intermediate computations, by defining variables outside of the individual Rules and then capturing them in the wrapped closures. However, this approach likely runs into the infamous JuliaLang/julia#15276. To get around this, we could potentially define a macro that does a Rule definition but scans the closure expression for uses of variables not defined therein and wraps those in a let. That would be heinously hacky but might buy us some performance improvements.
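A hedged illustration (not from the issue) of the boxing problem and the usual let-wrapping workaround; `shared` here is a hypothetical intermediate computation:

function make_rules(x)
    shared = x^2
    shared = shared + 1            # a reassigned capture like this gets boxed (JuliaLang/julia#15276)...
    rule_a = let shared = shared   # ...unless it is rebound in a `let` before being captured
        dy -> dy * shared
    end
    rule_b = let shared = shared
        dy -> dy + shared
    end
    return rule_a, rule_b
end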

Wrong derivative for `conj`

Shouldn't the derivative of conj with a complex argument be Wirtinger(Zero(), One())? Instead I get this:

julia> w, dw = frule(conj, 1+im)
(1 - 1im, ChainRules.WirtingerRule{ChainRules.Rule{getfield(ChainRules, Symbol("##304#308"))},ChainRules.Rule{getfield(ChainRules, Symbol("##305#309"))}}(ChainRules.Rule{getfield(ChainRules, Symbol("##304#308"))}(getfield(ChainRules, Symbol("##304#308"))(Core.Box(One()))), ChainRules.Rule{getfield(ChainRules, Symbol("##305#309"))}(getfield(ChainRules, Symbol("##305#309"))(Core.Box(One())))))

julia> dw(One())
One()

julia> dw(1 + 0im)
ChainRules.Wirtinger{Complex{Int64},ChainRules.Zero}(1 + 0im, ChainRules.Zero())

Audit cache friendliness of Cholesky rrule

In #44, I directly ported over code from Nabla, which contains the following comment:

See [1] for implementation details: pages 5-9 in particular. The derivations presented in
[1] assume column-major layout, whereas Julia primarily uses row-major. We therefore
implement both the derivations in [1] and their transpose, which is more appropriate to
Julia.

[1] - "Differentiation of the Cholesky decomposition", Murray 2016

Julia actually uses column-major layout for its arrays. I haven't dug into the code to determine whether the implementation matches the comment, i.e. assumes row-major, but it would be worthwhile to do so to ensure we're not being unnecessarily inefficient.

Write initial Getting Started documentation

We currently have some deployed docs at http://www.juliadiff.org/ChainRules.jl/latest/
crying out for "getting started" documentation.

Currently 4 sections are suggested

  • Forward-Mode vs. Reverse-Mode Chain Rule Evaluation
  • Real Scalar Differentiation Rules
  • Complex Scalar Differentiation Rules
  • Non-Scalar Differentiation Rules
    ...but some other categorisation might work fine.

Adding even one section would be a great start.

Lots of things have docstrings now... but adding more would also be helpful!

Make `test_utils` available as a standalone package

  • The test utils that use finite differences to check correctness are nice, and they could probably be useful for people adding sensitivities in their own packages.
  • We could make this a submodule, like ChainRules.TestUtils, so people can use this functionality in their own tests.
  • These utils depend on FDM/FiniteDifferences.jl, which is probably not a dependency we want to add to ChainRulesCore.jl, but it is less bad here, where we already have heavier dependencies (and it would be a test-only dependency for users).
  • If this functionality does turn out to be useful and widely used, then we can easily make this its own package in future

generic frule for constructors?

As with #153, there could be a generic frule for default constructors.
It relies on the arguments being the same as the fields, so it would have to detect that.

It's basically:

function frule(::Type{P}, args..., _, dargs...) where P   # schematic: real Julia allows only one trailing slurp
     y  = P(args...)
     dy = Composite{P}(; zip(fieldnames(P), dargs)...)
     return y, dy
end

the more serious implementation might be:

function frule(::Type{P}, all_args...) where P
     nargs = fieldcount(P)
     length(all_args) == 2nargs + 1 || return nothing
     args = @inbounds all_args[1:nargs]
     all(typeof.(args) .== fieldtypes(P)) || return nothing
     dargs = @inbounds all_args[end-nargs+1:end]  # skip dself, as constructors are never functors
     y  = P(args...)
     dy = Composite{P}(; zip(fieldnames(P), dargs)...)
     return y, dy
end

But this may well need to be written as a generated function if it doesn't constant fold well.

Change to BlueStyle

I talked to @jrevels about this a few months back, and he was fine with it.
Right now the code is roughly in the intersection of what YASGuide and BlueStyle allow,
which is pretty large, as they overlap significantly.

But given most people maintaining it now are used to BlueStyle, we should use that.

#80 doesn't follow BlueStyle indenting.
And that's fine, because the readme says YASGuide.
But we should change that.
