Comments (11)
Commenting so I can keep track of this. If you think something deserves to be in DI, let me know!
cc @adrhill

---

@gdalle what do you think about the last part in https://discourse.julialang.org/t/ann-lux-jl-explicitly-parameterized-neural-networks-in-julia/81689/65?u=avikpal? If that exists in DI, I can just unwrap `StatefulLuxLayer` into that DI struct and forward the call.

---

Seems like overloading the calls for custom functions won't really work because of method ambiguity issues.
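A minimal, self-contained sketch of the kind of ambiguity in question (the types here are illustrative stand-ins, not DI's actual ones):

```julia
abstract type AbstractBackend end
struct ZygoteBackend <: AbstractBackend end
struct ZygoteExtras end   # stand-in for a backend-specific extras type
struct LuxLayer end       # stand-in for Lux.StatefulLuxLayer

# The overload Lux would want: dispatch on the function argument.
op(f::LuxLayer, backend::AbstractBackend, extras) = "Lux overload"
# The method DI effectively defines: dispatch on backend and extras.
op(f, backend::ZygoteBackend, extras::ZygoteExtras) = "backend-specific method"

op(LuxLayer(), ZygoteBackend(), ZygoteExtras())
# ERROR: MethodError: op(::LuxLayer, ::ZygoteBackend, ::ZygoteExtras) is ambiguous
```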

---

Can't we just specialize on `Lux.StatefulLuxLayer` and pass the arguments to `Lux.vector_jacobian_product` for `DifferentiationInterface.pullback`?
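A minimal sketch of that proposal, assuming `pullback(f, backend, x, dy)`-style signatures (DI's actual operators also thread an extras argument, which is exactly the complication raised in the next comment):

```julia
import DifferentiationInterface as DI
using ADTypes: AbstractADType
using Lux

# Hypothetical overload: forward DI's pullback to Lux's nested-AD-aware VJP.
function DI.pullback(model::Lux.StatefulLuxLayer, backend::AbstractADType, x, dy)
    return Lux.vector_jacobian_product(model, backend, x, dy)
end
```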

---

I was trying that for the gradient calls, but DI specializes on the extras type, which means we would also have to specialize on each backend's extras type.

---

To support second order for Enzyme, I introduced `DifferentiationInterface.nested(::AbstractADType)` in gdalle/DifferentiationInterface.jl#285. The idea is that it returns a possibly different version of the backend object, one which is aware that it is being differentiated. At the moment it doesn't do anything, except for `AutoEnzyme`, which is turned into a homemade `AutoDeferredEnzyme`.

Would this be useful functionality for Lux.jl and friends? Should I make it public / work on it some more? One could imagine an extension where `nested` tells the inner backend which outer backend is trying to differentiate through it.
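A minimal sketch of that mechanism as described above (`AutoDeferredEnzyme` is internal to DI, so its definition here is purely illustrative):

```julia
using ADTypes: AbstractADType, AutoEnzyme

# Illustrative stand-in for DI's internal wrapper type.
struct AutoDeferredEnzyme{M} <: AbstractADType
    mode::M
end

# By default, a backend does not need to know it is being differentiated.
nested(backend::AbstractADType) = backend

# Enzyme must defer nested autodiff calls, hence the wrapper.
nested(backend::AutoEnzyme) = AutoDeferredEnzyme(backend.mode)
```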

---

If I understand correctly, `Lux` handles nested AD implicitly by replacing the calls (#598) and explicitly with `vector_jacobian_product` and `jacobian_vector_product`.

@gdalle Can `DifferentiationInterface.nested` resolve the need for them? (assuming that everyone uses only `DI`, not the APIs of each package)
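For context, a hedged usage sketch of those two explicit Lux operators (assuming an argument order of function, backend, input, seed):

```julia
using Lux, Zygote, ForwardDiff
using ADTypes: AutoZygote, AutoForwardDiff

f(x) = abs2.(x)
x = randn(Float32, 3)
v = ones(Float32, 3)

# VJP (J'v) computed in reverse mode, JVP (Jv) computed in forward mode.
vjp = Lux.vector_jacobian_product(f, AutoZygote(), x, v)
jvp = Lux.jacobian_vector_product(f, AutoForwardDiff(), x, v)
```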

---

I'm not sure, because there are several things one might want to do with nested backends, and depending on the situation this Lux replacement trick may not always be appropriate.

---

Just putting it out there in case Avik is inspired. Essentially, modifying the backend is the cleanest approach I could think of for this type of problem.

---

To clarify how nested AD works in Lux: it doesn't simply switch the backends, i.e. we don't take a `Zygote.gradient(Zygote.gradient(...))` call and make it `ForwardDiff.gradient(Zygote.gradient(...))`. You could do that in principle, but you shouldn't, because it would be computationally terrible. Instead, it changes the operations to a JVP over a gradient. Now just extend that to Jacobians, JVPs, VJPs, etc.

The only case where replacement is not ideal is `ForwardDiff.gradient(ForwardDiff.gradient(...))` where the problem size is extremely small, but we don't replace that anyway. All the other forms, Zygote over ForwardDiff or Zygote over Zygote (or any reverse mode over X mode), have no computational benefit and will error in most cases, so it does make sense to switch. Even doing an `Enzyme.Reverse` over `Enzyme.Reverse` would be a bad idea just because of the overhead of reverse mode[^1]. Basically, for second order (not general nested higher-order AD), it is almost certainly beneficial to switch the operations.
[^1]: Okay, it might be faster if the reverse mode is calling into one of the vendor-specific codes and the forward mode isn't, but that is mostly because we got lazy.
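As a concrete illustration of this replacement (a minimal sketch, not Lux's actual implementation): a reverse-over-reverse Hessian-vector product can instead be computed as a forward-mode JVP of a reverse-mode gradient:

```julia
using ForwardDiff, Zygote

f(x) = sum(abs2, x) + prod(x)

# Hessian-vector product H(x) * v via forward-over-reverse:
# the directional derivative (JVP) of the reverse-mode gradient.
function hvp(f, x, v)
    g(y) = Zygote.gradient(f, y)[1]            # inner reverse-mode gradient
    duals = ForwardDiff.Dual.(x, v)            # seed the direction v
    return ForwardDiff.partials.(g(duals), 1)  # extract the JVP of g at x
end

x, v = randn(3), randn(3)
hvp(f, x, v)  # ≈ ForwardDiff.hessian(f, x) * v, without forming the Hessian
```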

---

Oh right, my `nested` trick works because I needed to change the behavior of the inner backend, but here you change the behavior of the outer backend when a gradient is already happening inside. I honestly don't know if there is a nice way to integrate this into DI, especially because we don't handle multiple parameters at the moment.