
Comments (7)

mcabbott commented on May 23, 2024

That already exists, roughly:

julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1));

julia> st = Flux.state(model)
(layers = ((weight = Float32[0.5213037 0.35699493], bias = Float32[0.0], σ = ()), (weight = Float32[0.96851003;;], bias = Float32[0.0], σ = ())),)

julia> Flux.loadmodel!(model, st);  # this is a nested copyto!

julia> using ComponentArrays

julia> ca = ComponentArray(; Flux.state(model)...)
ComponentVector{Tuple{@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}, σ::Tuple{}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}, σ::Tuple{}}}}(layers = ((weight = Float32[0.5213037 0.35699493], bias = Float32[0.0], σ = ()), (weight = Float32[0.96851003;;], bias = Float32[0.0], σ = ())))

julia> ca.layers[1].weight .= NaN
1×2 Matrix{Float32}:
 NaN  NaN

julia> Flux.loadmodel!(model, ca)
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters  (some NaN)
  Dense(1 => 1),                        # 2 parameters
)                   # Total: 4 arrays, 5 parameters, 276 bytes.

The caveats are:

(1) What Flux.state returns includes non-trainable parameters.
(2) I've no idea what'll happen to shared parameters; ComponentArrays ignores them.
(3) This is designed for loading from disk, not for use within gradients, so Zygote may hate it, but that's fixable.
(4) (Edit) My use of ComponentArray does not seem to produce something backed by one big vector, e.g. getfield(ca, :data); maybe I need to read their docs.

Flux.loadmodel! is for nested structures; we also have Flux.destructure, which is about flat vectors of parameters (and should respect points 1, 2 and 3).
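
A minimal sketch of that, continuing the REPL session above:

julia> v, re = Flux.destructure(model);  # v is a flat Vector{Float32} of the 5 trainable parameters

julia> model2 = re(v);  # the reconstructor re rebuilds a Chain, reading parameters back out of v

julia> model0 = re(zero(v));  # e.g. the same architecture with every parameter zeroed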

Possibly OT here. But perhaps worth opening an issue... perhaps with an example of what you wish would work?


CarloLucibello commented on May 23, 2024

Hi @kishore-nori, could you open a new issue and provide a specific example that we can reason about? Your case seems to be well served by destructure; if it's slow, we should try to understand why.


ToucheSir commented on May 23, 2024

I think we were waiting for a couple more features to land so we could have parity with some of the remaining use cases people might use implicit params for. FluxML/Optimisers.jl#57 is the main one I can think of.


darsnack commented on May 23, 2024

I think that's the only one left


kishore-nori commented on May 23, 2024

Would there be an alternative way to perform copy! between a flat vector and a Params-like object, or perhaps even directly into nn (a Flux.Chain), something like copy!(x, nn) and copy!(nn, x)?

Along these lines, I also wanted to ask: will Flux.jl support ComponentArrays similarly to Lux.jl? And would that be optional, as in Lux.jl, with NamedTuple being the default for parameters?


kishore-nori commented on May 23, 2024

Hi Michael, thanks a lot for the detailed reply (and sorry for the delay in mine); I wasn't aware of Flux.state. My use case has been using Flux.jl with Optim.jl, which requires a flat vector. With Flux.Params I could use the existing copy! provided by Zygote.jl (earlier from FluxOptTools.jl) between Flux.Params and a flat vector, and it was also useful for converting the gradient into a flat vector for Optim.jl. Of course, all of this use of copy! was outside Zygote's over-watch.
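
For context, that workflow looks roughly like this (a sketch, assuming model, x and y are defined; the copy! methods between Params/Grads and flat vectors are the Zygote.jl ones mentioned above):

julia> using Flux, Zygote

julia> ps = Flux.params(model);  # implicit-style parameter collection

julia> v = zeros(Float32, sum(length, ps));

julia> copy!(v, ps);  # Params -> flat vector, outside Zygote's over-watch

julia> gs = Zygote.gradient(() -> Flux.mse(model(x), y), ps);

julia> g = similar(v);

julia> copy!(g, gs);  # Grads -> flat vector, ready for Optim.jl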

Now, if I understand correctly, I have to write my own copy! to convert between Flux.state and a flat vector. It would also be useful with the object returned by Zygote's gradient in the new explicit Flux usage, Zygote.gradient(loss, model), which looks similar to st above. That is not very hard to write, but the problem you mentioned, "(1) what Flux.state returns includes non-trainable parameters", needs to be tackled (is trainables(model) intended to solve this issue?).
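
Such a copy! could look something like the sketch below, which uses trainables from Optimisers.jl and assumes it returns the trainable arrays in a stable order; the helper names flat_copy! and unflat_copy! are made up for illustration:

using Optimisers  # trainables skips non-trainable state

function flat_copy!(x::AbstractVector, model)  # model -> flat vector
    offset = 0
    for p in Optimisers.trainables(model)
        copyto!(x, offset + 1, vec(p), 1, length(p))
        offset += length(p)
    end
    return x
end

function unflat_copy!(model, x::AbstractVector)  # flat vector -> model, in place
    offset = 0
    for p in Optimisers.trainables(model)
        vec(p) .= view(x, offset+1:offset+length(p))  # vec(p) shares memory with p
        offset += length(p)
    end
    return model
end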

As for destructure, it makes the whole process more expensive because a new model is created on every single epoch; I have observed that this hurts performance, so I have set it aside.
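
For concreteness, the pattern in question looks roughly like this (a sketch, again assuming model, x and y are defined), where every objective and gradient evaluation goes through re and so allocates a fresh model:

using Flux, Optim, Zygote

v0, re = Flux.destructure(model)
loss(v) = Flux.mse(re(v)(x), y)          # re(v) rebuilds the whole model on each call
grad(v) = Zygote.gradient(loss, v)[1]
res = Optim.optimize(loss, grad, v0, LBFGS(); inplace = false)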

As for ComponentArrays, I think it works for situations where we have nested NamedTuples; for a neural network that would be a layer-wise NamedTuple of NamedTuples. But Flux.state doesn't return that, it returns a Tuple of NamedTuples, hence the discrepancy observed above. Still, this doesn't seem conceptually far from the intended usage.
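
One possible workaround, sketched under the assumption that each layer's state holds only weight, bias and σ as in the Chain above: re-key the layer Tuple as a NamedTuple (dropping the empty σ = () entries), after which the ComponentArray should be backed by a single flat vector:

using Flux, ComponentArrays

st = Flux.state(model)
layers = map(l -> (; weight = l.weight, bias = l.bias), st.layers)  # drop σ = ()
keys_ = ntuple(i -> Symbol(:layer_, i), length(layers))             # (:layer_1, :layer_2, ...)
ca = ComponentArray(NamedTuple{keys_}(layers))
getdata(ca)  # one flat Vector{Float32} holding all the parameters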

So for now I can write a copy! between Flux.state and a flat vector that ignores the non-trainable parameters, but I would be happy to know whether the trainables(model) and ComponentArrays solutions work! Thanks a lot!


kishore-nori commented on May 23, 2024

Sure, I will come up with an MWE and open an issue, thank you. By the way, I have realized that the idea of destructure! (FluxML/Optimisers.jl#165) would be really beneficial and would fit my purpose well.

