Comments (7)
That already exists, roughly:
julia> using Flux

julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1));
julia> st = Flux.state(model)
(layers = ((weight = Float32[0.5213037 0.35699493], bias = Float32[0.0], σ = ()), (weight = Float32[0.96851003;;], bias = Float32[0.0], σ = ())),)
julia> Flux.loadmodel!(model, st); # this is a nested copyto!
julia> using ComponentArrays
julia> ca = ComponentArray(; Flux.state(model)...)
ComponentVector{Tuple{@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}, σ::Tuple{}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}, σ::Tuple{}}}}(layers = ((weight = Float32[0.5213037 0.35699493], bias = Float32[0.0], σ = ()), (weight = Float32[0.96851003;;], bias = Float32[0.0], σ = ())))
julia> ca.layers[1].weight .= NaN
1×2 Matrix{Float32}:
NaN NaN
julia> Flux.loadmodel!(model, ca)
Chain(
  Dense(2 => 1, tanh),  # 3 parameters  (some NaN)
  Dense(1 => 1),        # 2 parameters
)                       # Total: 4 arrays, 5 parameters, 276 bytes.
The caveats are: (1) what Flux.state returns includes non-trainable parameters; (2) I've no idea what will happen to shared parameters, ComponentArrays ignores them; and (3) this is designed for loading from disk, not for use within gradients, so Zygote may hate it, but that's fixable. (Edit: (4) my use of ComponentArray does not seem to produce something backed by one big vector, e.g. getfield(ca, :data); maybe I need to read their docs. See the quick check below.)
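On (4), a quick check (a sketch, not gospel about ComponentArrays; the field names a and b are invented) suggests the Tuple under layers is the culprit, since a plain NamedTuple of NamedTuples does get one flat backing vector:

using ComponentArrays

nt = (a = (weight = rand(Float32, 1, 2), bias = zeros(Float32, 1)),
      b = (weight = rand(Float32, 1, 1), bias = zeros(Float32, 1)))
ca2 = ComponentArray(nt)
getdata(ca2)   # 5-element Vector{Float32}: one buffer behind all four arrays
ca2.a.weight   # a 1×2 view into that same buffer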
Flux.loadmodel! is for nested structures; we also have Flux.destructure, which is about flat vectors of parameters (and should respect points 1, 2, 3), roughly as sketched below.
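For instance, continuing the session above (a sketch; outputs omitted):

julia> flat, re = Flux.destructure(model);  # flat is one Vector of all 5 parameters

julia> model2 = re(2 .* flat);  # re rebuilds a new Chain from any such vector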
Possibly off-topic here, but perhaps worth opening an issue... perhaps with an example of what you wish would work?
Hi @kishore-nori, could you open a new issue and provide a specific example that we can reason on? Your case seems to be well served by destructure; if it's slow we should try to understand why.
I think we were waiting for a couple more features to land so we could have parity with some of the remaining use cases people might use implicit params for. FluxML/Optimisers.jl#57 is the main one I can think of.
I think that's the only one left
Would there be an alternative way to perform copy! between a flat vector and a Params-like object, or even directly into nn (a Flux.Chain), something like copy!(x, nn) and copy!(nn, x)?
Along these lines, I also wanted to ask whether Flux.jl would adopt ComponentArrays in a similar way to Lux.jl? And would it be optional as in Lux.jl, with NamedTuple being the default for parameters?
Hi Michael, thanks a lot for the detailed reply (and sorry for the delay in mine); I wasn't aware of Flux.state. My use case has been to use Flux.jl with Optim.jl, which requires a flat vector. With Flux.Params I could use the existing copy! provided by Zygote.jl (earlier from FluxOptTools.jl) between Flux.Params and a flat vector, and this was also useful for converting the gradient into a flat vector for Optim.jl; of course, all the usage of copy! was outside Zygote's watch. Roughly the pattern in the sketch below.
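For concreteness, the pattern was roughly the following; fg! and the toy loss are stand-ins I made up, and it leans on the copy! methods Zygote defines for Params and Grads:

using Flux, Optim, Zygote

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
ps = Flux.params(model)                  # implicit trainable parameters
x0 = zeros(Float32, sum(length, ps))
copy!(x0, ps)                            # Params -> flat vector

function fg!(F, G, x)
    copy!(ps, x)                         # flat vector -> Params (mutates the model)
    loss, back = Zygote.pullback(() -> sum(abs2, model(ones(Float32, 2))), ps)
    G === nothing || copy!(G, back(one(loss)))   # Grads -> flat vector
    return F === nothing ? nothing : loss
end

Optim.optimize(Optim.only_fg!(fg!), x0, LBFGS())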
Now, if I understand correctly, I have to write my own copy! for conversion between Flux.state and a flat vector, and this would also be useful with the object returned by Zygote's gradient under the new explicit usage Zygote.gradient(loss, model), which looks similar to st. That is not very hard, but the problem you mentioned, "(1) what Flux.state returns includes non-trainable parameters", needs to be tackled (is trainables(model) intended to solve this issue?). Something like the sketch after this paragraph is what I have in mind.
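Here flat_copy! is a name I made up, and the assumption is that trainables(model) always iterates the trainable arrays in the same order:

using Flux

# flat vector <- trainable arrays
function flat_copy!(x::AbstractVector, model)
    i = 0
    for p in Flux.trainables(model)
        copyto!(x, i + 1, vec(p), 1, length(p))
        i += length(p)
    end
    return x
end

# trainable arrays <- flat vector (in place: trainables returns the actual arrays)
function flat_copy!(model, x::AbstractVector)
    i = 0
    for p in Flux.trainables(model)
        copyto!(vec(p), 1, x, i + 1, length(p))
        i += length(p)
    end
    return model
end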
As for destructure, it makes the whole process more expensive because a new model is created every single epoch; I have observed that this hurts performance, so I have set it aside.
As for ComponentArrays, I think it works for situations where we have nested NamedTuples; for a neural network that would be a layer-wise NamedTuple of NamedTuples. But Flux.state doesn't return that; it returns a Tuple of NamedTuples, hence the discrepancy observed above. Still, this doesn't seem conceptually far from the intended usage; re-keying the Tuple, as below, seems to be enough.
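For example (a sketch; layer_1 etc. are names I invented, and non-array entries like σ = () are dropped):

using Flux, ComponentArrays

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
st = Flux.state(model)

drop_empty(nt) = NamedTuple(k => v for (k, v) in pairs(nt) if v isa AbstractArray)
layers = NamedTuple(Symbol(:layer_, i) => drop_empty(l) for (i, l) in enumerate(st.layers))

ca = ComponentArray(layers)
getdata(ca)          # now one flat Vector{Float32} backing every parameter
ca.layer_1.weight    # a view into that same flat vector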
So for now I can write a copy! between Flux.state and a flat vector, ignoring the non-trainable parameters, but I would be happy to know whether the trainables(model) and ComponentArrays solutions work! Thanks a lot!
Sure, I will come up with an MWE and open an issue, thank you. By the way, I have realized that the idea of destructure! (FluxML/Optimisers.jl#165) would be really beneficial and fit my purpose well.
from flux.jl.