fluxml / functors.jl
Parameterise all the things
Home Page: https://fluxml.ai/Functors.jl/stable/
License: MIT License
I was very pleased to discover that this is a thing that's been carved out from Flux, but was slightly surprised by the following performance:
using Functors, BenchmarkTools
using Functors: functor
struct Bar{T}
    x::T
end
@functor Bar
bar = Bar(5.0)
julia> @benchmark fmap(Float32, bar)
BenchmarkTools.Trial:
memory estimate: 608 bytes
allocs estimate: 16
--------------
minimum time: 952.304 ns (0.00% GC)
median time: 984.652 ns (0.00% GC)
mean time: 1.040 μs (1.92% GC)
maximum time: 76.549 μs (96.27% GC)
--------------
samples: 10000
evals/sample: 23
Digging down a little, functor seems to be performant:
julia> @benchmark functor($bar)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 1.455 ns (0.00% GC)
median time: 1.470 ns (0.00% GC)
mean time: 1.589 ns (0.00% GC)
maximum time: 39.906 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 1000
👍
Similarly, isleaf seems to be fine:
using Functors: isleaf
julia> @benchmark isleaf($bar)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 0.020 ns (0.00% GC)
median time: 0.032 ns (0.00% GC)
mean time: 0.030 ns (0.00% GC)
maximum time: 0.050 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 1000
👍
So there's something else going on in fmap and fmap1 that I assume has something to do with the IdDict that's being used. So I would be interested to know a) what the need for the cache is (it's kind of un-obvious to me), and b) whether there's a way to get rid of all of this overhead, as it seems kind of unnecessary in this simple case?
Edit: I realised while out on a walk that it's probably something to do with diamonds in the dependency graph for any particular data structure. Is this the case?
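For what it's worth, here is a minimal sketch of the diamond case the cache seems to exist for: when the same object is reachable through two fields, the IdDict ensures both map to one transformed object, preserving aliasing. (Diamond is a hypothetical type, purely for illustration.)
struct Diamond{T}
    left::T
    right::T
end
@functor Diamond

shared = [1.0, 2.0]
d = Diamond(shared, shared)    # both fields alias one array
d2 = fmap(x -> x .* 2, d)
d2.left === d2.right           # true: sharing is preserved via the cache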
I think this is a bug -- functor(typeof(x), y) should always use x's field names on y:
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct AnotherFoo; x; y; end
julia> x = Foo([1, 2], 3);
julia> y = AnotherFoo([4, 5], 6);
julia> z = (x = [7, 8], y = 9);
julia> functor(x)
((x = [1, 2], y = 3), var"#31#32"())
julia> functor(typeof(x), y)
((x = [4, 5], y = 6), var"#31#32"())
julia> functor(typeof(z), y) # this is wrong?
(AnotherFoo([4, 5], 6), Functors.var"#5#6"())
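For comparison, a hedged sketch of the result I'd expect in that last call - extract children from y using the keys of typeof(z), whatever y's concrete type is:
julia> NamedTuple{(:x, :y)}((getfield(y, :x), getfield(y, :y)))  # expected children
(x = [4, 5], y = 6)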
As far as I can tell, we don't use it anywhere. Removal has the bonus of wider compat and fewer potential errors on nightly.
julia> Functors.isleaf((;))
false
I guess we should consider this a bug.
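A minimal illustration of why true seems the natural answer here, assuming a leaf is any node with no children to recurse into:
julia> Functors.children((;))  # an empty NamedTuple has no children
NamedTuple()

julia> Functors.isleaf((;))    # currently false; arguably should be true
false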
Now that we have docstrings for most of the public interface, it would help to have those rendered somewhere. AFAICT we also aren't running any of the doctests in those docstrings :P
It was noticed in FluxML/Flux.jl#2107 that Functors.jl + ProtoStruct.jl doesn't work, as Functors uses fieldnames + getproperty, which is overloaded by ProtoStruct.jl.
https://github.com/FluxML/Functors.jl/blob/v0.3.0/src/functor.jl#L11-L16
This is a bug; Functors should use getfield to be consistent. (Possibly the code was written before getproperty existed?)
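A hedged sketch of the proposed change, assuming the generated accessors in the linked code read x.f (i.e. getproperty); switching them to getfield sidesteps getproperty overloads like ProtoStruct.jl's:
escargs = map(fieldnames(T)) do f
    f in fs ? :(y[$(yᵢ += 1)]) : :(getfield(x, $(QuoteNode(f))))
end
escfs = [:($f = getfield(x, $(QuoteNode(f)))) for f in fs]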
It doesn't seem right to hardcode that e.g. AbstractArray{<:Number} is a leaf by making the functor for this type never recurse into its children. This makes it harder when the user wants to, e.g., apply a function to every scalar leaf without worrying about arrays.
The exclude function already has the ability to stop higher up in the tree. So perhaps all leaf information should be encoded in the exclude function: if we want to stop at a node, simply don't call functor on it in the first place. Then, maybe:
- functor should search for children as aggressively as possible, not stopping e.g. at AbstractArray{<:Number}
- @leaf AbstractArray{<:Number} would directly affect the default isleaf functionality, which is used in exclude, so the previous behaviour would apply by default (a hedged sketch of such a macro follows below).
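Purely to make the idea concrete, here is what @leaf might expand to (hypothetical macro, not existing package API; assumes Functors is in scope at the call site):
macro leaf(T)
    esc(:(Functors.isleaf(::$T) = true))
end

@leaf AbstractArray{<:Number}  # would restore the old default stopping point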
This issue keeps track of the breaking change proposed in #74 (review) in order to avoid duplicating walks, e.g. defining SomeWalk and also SomeWalkWithPath.
See beacon-biosignals/LegolasFlux.jl#4 (comment) and the alternate implementation:
function fcollect2(x; output = [], cache = Base.IdSet(), exclude = v -> false)
    x in cache && return output   # identity-based dedup: visit each object once
    if !exclude(x)
        push!(cache, x)
        push!(output, x)
        foreach(y -> fcollect2(y; cache = cache, output = output, exclude = exclude), Functors.children(x))
    end
    return output
end
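A quick usage sketch of the identity-based behaviour: a shared child is collected only once, since the IdSet caches by object identity.
julia> shared = [1.0, 2.0];

julia> fcollect2((shared, (shared, 3)))
4-element Vector{Any}:
 ([1.0, 2.0], ([1.0, 2.0], 3))
 [1.0, 2.0]
 ([1.0, 2.0], 3)
 3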
I realized that currently, Functors doesn't support ChainRulesCore.Tangent:
using Functors, ChainRulesCore
x = (a = 2, b = 3)
dx = Tangent{typeof(x)}(a = 4, b = 9)
Functors.functor(dx)
results in
((), Functors.var"#1#2"{Tangent{NamedTuple{(:a, :b), Tuple{Int64, Int64}}, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}}(Tangent{NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(a = 4, b = 9)))
Typical use cases of Functors will usually use re of x instead of dx, of course - still, it might be nice to support functor for tangents.
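A hedged sketch of one possible method, assuming reconstruction through the tangent's backing NamedTuple is acceptable (this is not existing Functors API):
using ChainRulesCore

function Functors.functor(::Type{<:Tangent}, dx::Tangent{P}) where {P}
    # children are the stored components; rebuild as a Tangent for the same primal type
    return ChainRulesCore.backing(dx), nt -> Tangent{P}(; nt...)
end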
I've been using Functors in a case where I want to make sure it's really fast in the special case of small tuples. For a single functor, this is fine, but for functions of multiple functors...
julia> using Functors, BenchmarkTools
julia> @btime Functors.fmap(+, (3,))
0.859 ns (0 allocations: 0 bytes)
(3,)
julia> @btime Functors.fmap(+, (3,), (3,))
404.070 ns (9 allocations: 464 bytes)
(6,)
JET.jl reports the following for the second call:
═════ 7 possible errors found ═════
┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:3 Functors.:(var"#fmap#134")(tuple(Functors.isleaf, Functors.DefaultWalk(), Functors.IdDict(), Functors.NoKeyword(), #self#, f, x), ys...)
│┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:11 fmap(tuple(_walk, f, x), ys...)
││┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:1 walk(tuple(#132, x), ys...)
│││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:132 ret = walk.walk(tuple(recurse, x), ys...)
││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:92 walk.walk(tuple(recurse, x), ys...)
│││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:62 Functors._map(tuple(recurse, func), yfuncs...)
││││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:1 Functors.map(tuple(f), x...)
│││││││┌ @ tuple.jl:298 f(t[1], s[1])
││││││││┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:1 fmap(tuple(getfield(#self#, :walk), getfield(#self#, :f)), xs...)
│││││││││┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:1 Functors.fmap(::Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, ::typeof(+), ::Int64, ::Int64)
││││││││││ failed to optimize: Functors.fmap(::Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, ::typeof(+), ::Int64, ::Int64)
│││││││││└──────────────────────────────────────────────────
││││││││┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:1 (::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)})(::Int64, ::Int64)
│││││││││ failed to optimize: (::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)})(::Int64, ::Int64)
││││││││└──────────────────────────────────────────────────
│││││││┌ @ tuple.jl:298 map(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
││││││││ failed to optimize: map(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
│││││││└────────────────
││││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:1 Functors._map(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
│││││││ failed to optimize: Functors._map(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
││││││└───────────────────────────────────────────────────
│││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:59 (::DefaultWalk)(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
││││││ failed to optimize: (::DefaultWalk)(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
│││││└────────────────────────────────────────────────────
││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:92 (::ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)})(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
│││││ failed to optimize: (::ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)})(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
││││└────────────────────────────────────────────────────
│││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:127 (::Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword})(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
││││ failed to optimize: (::Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword})(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
│││└─────────────────────────────────────────────────────
The "failed to optimize" comes from an OptimizationFailureReport
which "will happen when there are (mutually) recursive calls and Julia compiler decided not to do inference in order to make sure the inference's termination". Seemingly, this only crops up when we provide multiple functors.
I'd like to tweak the Functors.jl
code so that this case is optimized, although I'm not really sure where to start in fixing this so would recommend any tips:)
This would make (re)viewing docs changes much easier.
This has been done for a couple of the larger FluxML repos already. Copying and slightly modifying the configuration from those should be sufficient.
The following (unsurprisingly) yields an error:
struct Foo
    a::Float64
    b::String
end
@functor Foo (a,)
x, re = functor(Foo(5.0, "hi"))
julia> re(x)
ERROR: MethodError: no method matching Foo(::Float64)
Closest candidates are:
Foo(::Float64, ::String) at REPL[34]:2
Foo(::Any, ::Any) at REPL[34]:2
Stacktrace:
[1] (::var"#22#23")(::NamedTuple{(:a,),Tuple{Float64}}) at /Users/willtebbutt/.julia/dev/Functors/src/functor.jl:12
[2] top-level scope at REPL[37]:1
You can imagine this kind of thing cropping up, where only a subset of the fields of an object are considered "parameters", but you do need to have access to the others to make an instance of the object.
Is the intended way to handle this to define my own method of functor as follows:
function Functors.functor(x::Foo)
    function reconstruct_Foo(xs)
        return Foo(xs.a, x.b)
    end
    return (a = x.a,), reconstruct_Foo
end
julia> x, re = functor(Foo(5.0, "hi"))
julia> re(x)
Foo(5.0, "hi")
julia> fmap(Float32, Foo(5.0, "hi"))
Foo{Float32}(5.0f0, "hi")
?
Due to the default fallback
functor(T, x) = (), _ -> x
in Functors.jl, every custom type is considered a leaf (i.e. it has no children), and we have to sprinkle @functor MyType everywhere in Flux and in user code.
We could remove all this boilerplate by making the default what @functor MyType currently does. Then 99% of people could live their lives completely unaware of @functor/functor (historically poorly documented and poorly understood) and only use the much clearer trainable(x::MyType) in case they need to customize the parameter collection.
Besides the transition, which I think could be made rather smooth, does anyone see any contraindication to changing the default?
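To make the proposal concrete, a hedged sketch of what the new default fallback could look like - assuming a positional constructor taking all fields in declaration order, which the actual change would need to handle more carefully:
function Functors.functor(::Type{T}, x) where {T}
    fs = fieldnames(T)
    isempty(fs) && return (), _ -> x  # no fields: still a leaf
    children = NamedTuple{fs}(map(f -> getfield(x, f), fs))
    return children, y -> T(y...)     # rebuild positionally
end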
Suppose I have a type
using Functors
struct Foo{T1,T2}
    x::T2
    Foo{T1}(x::T2) where {T1,T2} = new{T1,T2}(x)
end
Then using @functor Foo will not manage to capture the T1 type:
x = Foo{Real}(2.0)
y, re = Functors.functor(x)
re(y)
ERROR: MethodError: no method matching Foo(::Float64)
Stacktrace:
[1] (::var"#5#6")(y::NamedTuple{(:x,), Tuple{Float64}})
@ Main ~/.julia/packages/Functors/qBIlC/src/functor.jl:23
[2] top-level scope
@ REPL[14]:1
Using @functor Foo{T} where {T} does not help either.
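A hedged workaround sketch: define functor by hand so that the reconstructor closes over the lost type parameter T1:
function Functors.functor(::Type{<:Foo}, x::Foo{T1}) where {T1}
    return (x = x.x,), y -> Foo{T1}(y.x)
end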
This issue is used to trigger TagBot; feel free to unsubscribe.
If you haven't already, you should update your TagBot.yml to include issue comment triggers. Please see this post on Discourse for instructions and more details.
If you'd like for me to do this for you, comment TagBot fix on this issue. I'll open a PR within a few hours; please be patient!
Currently, xs, re = Functors.functor(Dict("a"=>1.0)) returns empty for xs.
Support complex models using Dictionary. Related issue: FluxML/Optimisers.jl#114
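A hedged sketch of what Dict support might look like - assuming it is acceptable to expose the values as children and rebuild from the stored keys; this is not existing package behaviour:
function Functors.functor(::Type{<:Dict}, d::Dict)
    ks = collect(keys(d))                            # freeze an iteration order
    return Tuple(d[k] for k in ks), vs -> Dict(zip(ks, vs))
end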
Given
using Functors
struct Foo <: Function
    a::Float64
end
(f::Foo)(x) = f.a * x
@functor Foo
it would IMHO be very useful to support ComposedFunction, to enable
Functors.functor(Foo(4.2) ∘ Foo(5.3)) == ((outer = (a = 4.2,), inner = (a = 5.3,)), re)
If acceptable, I'd volunteer to implement it.
Update: Fixed example above
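If it helps the discussion, a hedged sketch: Base.ComposedFunction stores its parts in outer and inner fields, so a field-based functor seems natural (not claiming this is the final design):
Functors.functor(::Type{<:ComposedFunction}, c::ComposedFunction) =
    (outer = c.outer, inner = c.inner), nt -> ComposedFunction(nt.outer, nt.inner)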
In Flux, we have trainable to designate a subset of leaves as nodes to walk when updating parameters for training. In FluxPrune.jl, I defined pruneable to designate a subset of leaves for pruning (note that these cannot be the same as the trainable nodes).
Right now this creates an unfortunate circumstance as discussed in FluxML/Flux.jl#1946. Users need to @functor their types and remember to define trainable if necessary. Potentially, to use FluxPrune.jl, they might also need to remember to define pruneable. On the developer side of things, we can use the walk keyword of fmap to walk the differently labeled leaf nodes, but this usually requires defining a separate walk function based on the subset that you are hoping to target.
An alternative would be to build this information directly into what @functor defines. Right now, each child of a functor has a name and a value. I propose adding "tags", which would be a tuple of symbols. Then we could do something like
@functor Conv trainable=(weight, bias) pruneable=(weight,)
Ideally, this mechanism should be dynamic, meaning that if Flux.jl already defines the trainable leaves of a type, then another package like FluxPrune.jl should be able to add a pruneable tag on top of that.
My hope is that we make it easier on users by only having one line for making your type Flux-compatible. And we make it easier on developers by making it easy to filter nodes when walking by tag. I haven't spent a lot of time on the implementation aspect, but I just wanted to float the notion of tags first and get some feedback.
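Purely to make the tag idea concrete, a hedged sketch of a hypothetical registry and a tag-filtered child lookup (none of this is existing API):
# Hypothetical registry mapping a type to its tag => fields table.
const TAGS = IdDict{Type,NamedTuple}()

tags(::Type{T}) where {T} = get(TAGS, T, NamedTuple())

# What `@functor Conv trainable=(weight, bias) pruneable=(weight,)` might register:
# TAGS[Conv] = (trainable = (:weight, :bias), pruneable = (:weight,))

# Restrict a node's children to those carrying a given tag:
function tagged_children(x, tag::Symbol)
    fs = get(tags(typeof(x)), tag, ())
    NamedTuple{fs}(map(f -> getfield(x, f), fs))
end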
The signature should be extended to
fleaves(f, exclude = Functors.isleaf, walk = Functors.DefaultWalk())
This package probably wants a way to write mapreduce, to replace e.g. sum(norm(p) for p in params(m)) in Flux. This seems like the minimal attempt, but it's not Zygote-friendly. Can this be fixed, and is there a better way?
julia> using Functors, Zygote
julia> const INIT = Base._InitialValue();
julia> function fmapreduce(f, op, x; init = INIT, walk = (f, x) -> foreach(f, Functors.children(x)), kw...)
           fmap(x; walk, kw...) do y
               init = init === INIT ? f(y) : op(init, f(y))
           end
           init === INIT ? Base.mapreduce_empty(f, op) : init
       end
fmapreduce (generic function with 1 method)
julia> m = ([1,2], (x=[3,4], y=5), 6);
julia> fmapreduce(sum, +, m)
21
julia> gradient(fmapreduce, sum, +, m)
(nothing, nothing, nothing)
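One hedged direction: the nothing gradients presumably come from mutating the captured init inside the closure, which Zygote does not track. Here is a mutation-free restatement (I have not verified that Zygote handles every case, but removing the captured-state mutation seems to be a prerequisite):
fmapreduce2(f, op, x) =
    Functors.isleaf(x) ? f(x) :
    reduce(op, map(y -> fmapreduce2(f, op, y), values(Functors.children(x))))

fmapreduce2(sum, +, m)  # 21, as before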
In the signature fmap(walk::AbstractWalk, f, x, ys...) in the public API, the function f is ultimately not used. Arguably, since this API does not use f, it should be deprecated and replaced by something that is not called fmap (which should be reserved for the special case of an ExcludeWalk where f is applied).
However, these changes should be made with caution due to potentially breaking effects. See the body of PR of #61 for some additional info.
Edit: it's possible #32 helps with a lot of this
julia> using Functors
julia> struct C
           a
           b
           c
       end
julia> @functor C (c, a)
Tuple{Expr, Expr, Expr}: (:(y[1]), :(x.b), :(y[2]))
Vector{Expr}: Expr[:(c = x.c), :(a = x.a)]
julia> c = C(1,2,3)
C(1, 2, 3)
julia> Functors.functor(c)
((c = 3, a = 1), var"#5#6"{C}(C(1, 2, 3)))
julia> cc, re = Functors.functor(c)
((c = 3, a = 1), var"#5#6"{C}(C(1, 2, 3)))
julia> re(map(float, cc))
C(3.0, 2, 1.0)
I think it's because makefunctor only uses y[$(yᵢ += 1)], ignoring the order of the selected fields and the keys of the NamedTuple:
function makefunctor(m::Module, T, fs = fieldnames(T))
    yᵢ = 0
    escargs = map(fieldnames(T)) do f
        f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f)
    end
    escfs = [:($f=x.$f) for f in fs]
    @eval m begin
        $Functors.functor(::Type{<:$T}, x) = (;$(escfs...)), y -> $T($(escargs...))
    end
end
Or maybe this is not a bug, because the fields are required to be passed in order - but then documentation is needed.
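A hedged sketch of a possible fix - index y by field name instead of by the order the selected fields were listed (assumes y is always the NamedTuple of children produced by the generated functor method, which re may not guarantee for plain tuples):
function makefunctor2(m::Module, T, fs = fieldnames(T))  # hypothetical variant
    escargs = map(fieldnames(T)) do f
        f in fs ? :(y.$f) : :(x.$f)
    end
    escfs = [:($f = x.$f) for f in fs]
    @eval m begin
        $Functors.functor(::Type{<:$T}, x) = (; $(escfs...)), y -> $T($(escargs...))
    end
end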
I get this warning every time:
┌ Warning: Wrapping a custom walk function as an `AnonymousWalk`. Future versions will only support custom walks that explicitly subtype `AbstractWalk`.
│ caller = fmap(f::Function, x::Chain{Tuple{var"#1#2", Conv{2, 4, typeof(tanh), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, Conv{2, 4, typeof(tanh), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, ys::NamedTuple{(:layers,), Tuple{Tuple{Tuple{}, NamedTuple{(:σ, :weight, :bias, :stride, :pad, :dilation, :groups), Tuple{Tuple{}, Int64, Int64, Tuple{Tuple{}, Tuple{}}, NTuple{4, Tuple{}}, Tuple{Tuple{}, Tuple{}}, Tuple{}}}, Tuple{}, NamedTuple{(:σ, :weight, :bias, :stride, :pad, :dilation, :groups), Tuple{Tuple{}, Int64, Int64, Tuple{Tuple{}, Tuple{}}, NTuple{4, Tuple{}}, Tuple{Tuple{}, Tuple{}}, Tuple{}}}, Tuple{}, Tuple{}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}; exclude::Function, walk::Function, cache::IdDict{Any, Any}, prune::Functors.NoKeyword) at maps.jl:7
└ @ Functors C:\Users\Hossein Pourbozorg\.julia\packages\Functors\orBYx\src\maps.jl:7
I absolutely love the functionality of destructure, however I am surprised that it lives in Flux.jl instead of Functors.jl, since it is not "Neural Network specific". Maybe it is worth moving the function to Functors?
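For reference, a hedged, simplified sketch of what a Functors-level destructure could look like (ignoring Flux concerns such as trainable; destructure_sketch is a hypothetical name):
using Functors

function destructure_sketch(m)
    arrays = AbstractArray[]
    fmap(m) do x                       # collect numeric-array leaves in walk order
        x isa AbstractArray{<:Number} && push!(arrays, x)
        x
    end
    flat = reduce(vcat, map(vec, arrays))
    function re(flat2)
        i = Ref(0)                     # fmap revisits leaves in the same order
        fmap(m) do x
            x isa AbstractArray{<:Number} || return x
            y = reshape(flat2[i[]+1:i[]+length(x)], size(x))
            i[] += length(x)
            return y
        end
    end
    return flat, re
end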
This should just work, right? https://discourse.julialang.org/t/data-science-lessons-making-10-neural-networks-run-on-gpu/74592
I am not sure whether this is intended or not, but when trying to specify the trainable parameters in the following code a "no method matching error" for the constructor is returned.
struct MyLayer{R, S, T}
    a::R
    b::S
    c::T
end
Flux.@functor MyLayer (a,b)
m = MyLayer(Dense(1,10,tanh), Dense(1,10,tanh), zeros(5))
Flux.destructure(m) # returns no method matching error for the constructor
In the last sentence of the corresponding section in the documentation it is mentioned that a corresponding constructor has to exist. However, it is not clear to me what this constructor should look like.
I was studying the Flux source code to learn how to create my own layer types. I came across the macro @functor, whose purpose remained unclear to me. After asking in Slack, it was pointed out to me that this question comes up a lot and that it should be answered here.
I could imagine the documentation taking these forms:
- Document @functor in the docs and explain when you need it and what for.
- A tutorial that introduces @functor, as opposed to functor being out of context and vague at the end of the page. The tutorial should enable the user to develop a first-class Flux layer.