lux.jl's Issues

`getindex` for `Chain`

I'm expecting the same feature as in Flux:

julia> using Flux

julia> m = Chain(Dense(3, 4), Dense(4, 4))
Chain(
  Dense(3 => 4),                        # 16 parameters
  Dense(4 => 4),                        # 20 parameters
)                   # Total: 4 arrays, 36 parameters, 400 bytes.

julia> m[1]
Dense(3 => 4)       # 16 parameters
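
For reference, a minimal user-side sketch of what this could look like; it assumes Chain keeps its sub-layers in a NamedTuple field named layers, which is an assumption on my part rather than documented API.

using Lux

# Hedged sketch: index into the (assumed) `layers` NamedTuple of a Chain.
Base.getindex(c::Chain, i::Integer) = c.layers[i]
Base.getindex(c::Chain, r::AbstractRange) = Chain(values(c.layers)[r]...)

# Hypothetical usage:
# m = Chain(Dense(3 => 4), Dense(4 => 4))
# m[1]   # would return the first Dense layer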

Available architectures

Hi there, and congrats on what looks like a serious challenger to Flux!
I just saw the release of the model weights and I was wondering where to find predefined architectures, not necessarily with their weights. A quick scan of the docs didn't give me the answer, but I might have missed it.

Make it easier to pass empty state `st = (;)`

Hi there!
I am writing code with stateless neural networks, and I would like to be able to pass `st = (;)` (an empty NamedTuple) everywhere instead of carrying around the state from the `Lux.setup` function.
Judging by the code in layers/basic.jl, it should work with every layer... except `Chain`, because of this line, where `st` is annotated as a `NamedTuple{fields}` instead of a `NamedTuple`. Is there a specific reason for this?
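
In the meantime, a hedged workaround sketch: reuse only the layer names returned by Lux.setup so the NamedTuple{fields} annotation on Chain is still satisfied, while every per-layer state is an empty NamedTuple.

using Lux, Random

model = Chain(Dense(3 => 4, tanh), Dense(4 => 1))
ps, st = Lux.setup(Random.default_rng(), model)

# Keep the layer names, drop the state values.
st_empty = NamedTuple{keys(st)}(ntuple(_ -> (;), length(st)))

y, _ = model(randn(Float32, 3, 8), ps, st_empty)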

`No method matching` with argument `IRTools.Inner.Undefined` in gradient computation.

This code, when added to the SimpleRNN example, fails.

s = SpiralClassifier(10, 20, 30)
ps, st = Lux.setup(Random.default_rng(), s)
x = rand(10, 20, 16)

gradient(ps) do ps
    out, st = s(x, ps, st)
    return sum(out)
end

I couldn't find similar issues online but I believe the above code should work?

The issue seems not to stem from this specific example but is more general as I had the same problem with a custom layer.
When the state variable is ignored, there's no error.

gradient(ps) do ps
    out, _ = s(x, ps, st)
    return sum(out)
end

Stacktrace:

ERROR: MethodError: no method matching (::SpiralClassifier{LSTMCell{true, false, false, Tuple{typeof(Lux.zeros32), typeof(Lux.zeros32), typeof(Lux.ones32), typeof(Lux.zeros32)}, NTuple{4, typeof(Lux.glorot_uniform)}, typeof(Lux.zeros32), typeof(Lux.zeros32)}, Dense{true, typeof(sigmoid_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}})(::Array{Float64, 3}, ::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}, ::IRTools.Inner.Undefined)
Closest candidates are:
  (::SpiralClassifier)(::AbstractArray{T, 3}, ::NamedTuple, ::NamedTuple) where T at ~/transformer/Lux.jl/examples/SimpleRNN/main.jl:60
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(::Zygote.Context, ::SpiralClassifier{LSTMCell{true, false, false, Tuple{typeof(Lux.zeros32), typeof(Lux.zeros32), typeof(Lux.ones32), typeof(Lux.zeros32)}, NTuple{4, typeof(Lux.glorot_uniform)}, typeof(Lux.zeros32), typeof(Lux.zeros32)}, Dense{true, typeof(sigmoid_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}, ::Array{Float64, 3}, ::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}, ::IRTools.Inner.Undefined)
   @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:9
 [3] _pullback
   @ ~/transformer/Lux.jl/examples/SimpleRNN/main.jl:149 [inlined]
 [4] _pullback(ctx::Zygote.Context, f::var"#38#39", args::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}})
   @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [5] _pullback(f::Function, args::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}})
   @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:34
 [6] pullback(f::Function, args::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}})
   @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:40
 [7] gradient(f::Function, args::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}})
   @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:75
 [8] top-level scope
   @ ~/transformer/Lux.jl/examples/SimpleRNN/main.jl:148
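
A hedged workaround that also keeps the updated state: give the returned state a fresh name inside the closure, so st is only read and never reassigned there. The reassignment seems to be what Zygote ends up passing through as IRTools.Inner.Undefined.

gradient(ps) do ps
    out, st_new = s(x, ps, st)   # fresh name instead of reassigning the captured `st`
    return sum(out)
end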

RNN and LSTM break when using GPU

Below you can find an MWE with RNNCell; the same happens with LSTMCell.

using Lux, Random, CUDA

rnn = RNNCell(2 => 8)
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, rnn) .|> gpu
x = rand(Float32, 2, 4, 10) |> gpu
rnn(view(x, :, 1, :), ps, st)
ERROR: ArgumentError: cannot take the CPU address of a CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
Stacktrace:
  [1] unsafe_convert(#unused#::Type{Ptr{Float32}}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ CUDA ~/.julia/packages/CUDA/DfvRa/src/array.jl:319
  [2] gemm!(transA::Char, transB::Char, alpha::Float32, A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, B::Matrix{Float32}, beta::Float32, C::Matrix{Float32})
    @ LinearAlgebra.BLAS /network/scratch/a/abrevayg/julia-1.8.0-rc3/share/julia/stdlib/v1.8/LinearAlgebra/src/blas.jl:1514
  [3] gemm_wrapper!(C::Matrix{Float32}, tA::Char, tB::Char, A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, B::Matrix{Float32}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra /network/scratch/a/abrevayg/julia-1.8.0-rc3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:674
  [4] mul!
    @ /network/scratch/a/abrevayg/julia-1.8.0-rc3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:161 [inlined]
  [5] mul!
    @ /network/scratch/a/abrevayg/julia-1.8.0-rc3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
  [6] *
    @ /network/scratch/a/abrevayg/julia-1.8.0-rc3/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:148 [inlined]
  [7] (::RNNCell{true, typeof(tanh), typeof(Lux.zeros32), typeof(Lux.glorot_uniform), typeof(Lux.ones32)})(::Tuple{SubArray{Float32, 2, CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}}, false}, Matrix{Float32}}, ps::NamedTuple{(:weight_ih, :weight_hh, :bias), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, st::NamedTuple{(:rng,), Tuple{Xoshiro}})
    @ Lux ~/.julia/packages/Lux/lEqCI/src/layers/recurrent.jl:81
  [8] (::RNNCell{true, typeof(tanh), typeof(Lux.zeros32), typeof(Lux.glorot_uniform), typeof(Lux.ones32)})(x::SubArray{Float32, 2, CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}}, false}, ps::NamedTuple{(:weight_ih, :weight_hh, :bias), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, st::NamedTuple{(:rng,), Tuple{Xoshiro}})
    @ Lux ~/.julia/packages/Lux/lEqCI/src/layers/recurrent.jl:76
  [9] top-level scope
    @ REPL[9]:1
 [10] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
(rnn_gpu_issue) pkg> st
Status `/network/scratch/a/abrevayg/rnn_gpu_issue/Project.toml`
  [052768ef] CUDA v3.12.0
  [b2108857] Lux v0.4.9
julia> VERSION
v"1.8.0-rc3"

Train examples/NeuralODE error

Training the examples/NeuralODE example fails with the following error:

(base) xx@VM-1-8-ubuntu:~/codes/julia_learn$ julia --project=cv/lux/Project.toml cv/lux/mnist/test_2.jl
ERROR: LoadError: ArgumentError: tuple must be non-empty
Stacktrace:
  [1] first(#unused#::Tuple{})
    @ Base ./tuple.jl:140
  [2] _unapply(t::Nothing, xs::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:176
  [3] _unapply(t::Tuple{Nothing}, xs::Tuple{}) (repeats 2 times)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:180
  [4] _unapply(t::Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, xs::Tuple{Nothing, Nothing, Nothing, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:181
  [5] unapply(t::Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, xs::Tuple{Nothing, Nothing, Nothing, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:190
  [6] (::Zygote.var"#208#209"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.var"#kw_zpullback#37"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#177"{Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol), Tuple{Bool, Float32, Float32}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10), NamedTuple())), bias = ViewAxis(201:220, ShapedAxis((20, 1), NamedTuple())))))}}}, Tuple{}, Colon, NamedTuple{(:save_everystep, :reltol, :abstol), Tuple{Bool, Float32, Float32}}}}})(Δ::CuArray{Float32, 3, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:208
  [7] (::Zygote.var"#1750#back#210"{Zygote.var"#208#209"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.var"#kw_zpullback#37"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#177"{Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:save_everystep, :reltol, :abstol), Tuple{Bool, Float32, Float32}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10), NamedTuple())), bias = ViewAxis(201:220, ShapedAxis((20, 1), NamedTuple())))))}}}, Tuple{}, Colon, NamedTuple{(:save_everystep, :reltol, :abstol), Tuple{Bool, Float32, Float32}}}}}})(Δ::CuArray{Float32, 3, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [8] Pullback
    @ ~/.julia/packages/DiffEqBase/S7V8q/src/solve.jl:234 [inlined]
  [9] (::typeof(∂(#solve#40)))(Δ::CuArray{Float32, 3, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#208#209"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#40))})(Δ::CuArray{Float32, 3, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:207
 [11] (::Zygote.var"#1750#back#210"{Zygote.var"#208#209"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#40))}})(Δ::CuArray{Float32, 3, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [12] Pullback
    @ ~/.julia/packages/DiffEqBase/S7V8q/src/solve.jl:228 [inlined]
 [13] (::typeof(∂(solve##kw)))(Δ::CuArray{Float32, 3, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/codes/julia_learn/cv/lux/mnist/test_2.jl:56 [inlined]
 [15] (::typeof(∂(λ)))(Δ::Tuple{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [16] macro expansion
    @ ~/.julia/packages/Lux/27p0k/src/layers/basic.jl:0 [inlined]
 [17] Pullback
    @ ~/.julia/packages/Lux/27p0k/src/layers/basic.jl:328 [inlined]
 [18] (::typeof(∂(applychain)))(Δ::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [19] Pullback
    @ ~/.julia/packages/Lux/27p0k/src/layers/basic.jl:326 [inlined]
 [20] (::typeof(∂(λ)))(Δ::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [21] Pullback
    @ ~/codes/julia_learn/cv/lux/mnist/test_2.jl:93 [inlined]
 [22] Pullback
    @ ~/codes/julia_learn/cv/lux/mnist/test_2.jl:123 [inlined]
 [23] (::typeof(∂(λ)))(Δ::Tuple{Float32, Nothing})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [24] (::Zygote.var"#52#53"{typeof(∂(λ))})(Δ::Tuple{Float32, Nothing})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41
 [25] train()
    @ Main ~/codes/julia_learn/cv/lux/mnist/test_2.jl:124
 [26] top-level scope
    @ ~/codes/julia_learn/cv/lux/mnist/test_2.jl:149
in expression starting at /home/zhangyong/codes/julia_learn/cv/lux/mnist/test_2.jl:149

julia> versioninfo()
Julia Version 1.8.0-beta3
Commit 3e092a2521 (2022-03-29 15:42 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: 96 × Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, cascadelake)
  Threads: 96 on 96 virtual cores

(lux) pkg> st
Status `~/codes/julia_learn/cv/lux/Project.toml`
  [c7e460c6] ArgParse v1.1.4
  [02898b10] Augmentor v0.6.5
  [052768ef] CUDA v3.10.0
  [b0b7db55] ComponentArrays v0.11.15
  [2e981812] DataLoaders v0.1.3
⌃ [41bf760c] DiffEqSensitivity v6.49.1
  [587475ba] Flux v0.13.1
  [acf642fa] FluxMPI v0.4.2
  [59287772] Formatting v0.4.2
  [d9f16b24] Functors v0.2.8
  [6218d12a] ImageMagick v1.2.2
⌅ [916415d5] Images v0.24.1
  [b835a17e] JpegTurbo v0.1.1
  [b2108857] Lux v0.4.0 `https://github.com/avik-pal/Lux.jl.git#main`
  [cc2ba9b6] MLDataUtils v0.5.4
⌅ [eb30cadb] MLDatasets v0.5.16
  [f1d291b0] MLUtils v0.2.5
  [da04e1cc] MPI v0.19.2
  [dbeba491] Metalhead v0.7.1
  [872c559c] NNlib v0.8.5
  [3bd65402] Optimisers v0.2.4
  [1dea7af3] OrdinaryDiffEq v6.10.0
  [d7d3b36b] ParameterSchedulers v0.3.3
  [efcf1570] Setfield v0.8.2
  [e88e6eb3] Zygote v0.6.40
  [de0858da] Printf
  [9a3f8284] Random
  [10745b16] Statistics

Remove Requires.jl

  • Remove the Requires.jl-based optional dependency on Flux.jl (in v0.5 this goes away)
  • Introduce a new package Flux2Lux.jl and store the transformation there.

CUDNNError during backpropagation in simple CNN

Hello, I get a CUDNNError during backpropagation of a simple CNN on CUDA. The model apply (forward pass) works.

Lux v0.4.14
Julia v1.7.3

Code

using Lux, NNlib, Optimisers, Plots, Random, Statistics, Zygote,CUDA

#variables
dim_x,dim_y,dim_z = 34, 34, 34
rng = Random.default_rng()
#setting up convolutions
conv1 = (in, out) -> Lux.Conv((3, 3, 3), in => out, NNlib.tanh; stride=1, dilation=0)
function getConvModel()
    return Lux.Chain(conv1(1,4),conv1(4,8),conv1(8,4),conv1(4,2),conv1(2,1))
end#getConvModel
#defining model, states, parameters,Optimisers
model = getConvModel()
ps, st = Lux.setup(rng, model)
opt = Optimisers.Adam()
#loss
function loss_function(model, ps, st, x)
    y_pred, st = Lux.apply(model, x, ps, st)
    return -1*(sum(y_pred)), st, ()
end
#Lux objects
tstate = Lux.Training.TrainState(rng, model, opt; transform_variables=Lux.gpu)
vjp_rule = Lux.Training.ZygoteVJP()
#main iteration loop
function main(tstate::Lux.Training.TrainState, vjp::Lux.Training.AbstractVJP, data,
    epochs::Int)
   # data = data .|> Lux.gpu
    for epoch in 1:epochs
        grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp, loss_function,
                                                                data, tstate)
        @info epoch=epoch loss=loss
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate
end
# dummy data
x = randn(rng, Float32, dim_x, dim_y, dim_z)
x = reshape(x, (dim_x, dim_y, dim_z, 1, 1))
# execute
# works
y_pred, st = Lux.apply(model, x, ps, st)
# breaks during backpropagation
tstate = main(tstate, vjp_rule, CuArray(x), 1)

I get the error

CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)

Full error

ERROR: CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)
Stacktrace:
  [1] throw_api_error(res::CUDA.CUDNN.cudnnStatus_t)
    @ CUDA.CUDNN ~/.julia/packages/CUDA/DfvRa/lib/cudnn/error.jl:22
  [2] macro expansion
    @ ~/.julia/packages/CUDA/DfvRa/lib/cudnn/error.jl:35 [inlined]
  [3] cudnnSetConvolutionNdDescriptor(convDesc::Ptr{Nothing}, arrayLength::Int32, padA::Vector{Int32}, filterStrideA::Vector{Int32}, dilationA::Vector{Int32}, mode::CUDA.CUDNN.cudnnConvolutionMode_t, computeType::CUDA.CUDNN.cudnnDataType_t)
    @ CUDA.CUDNN ~/.julia/packages/CUDA/DfvRa/lib/utils/call.jl:26
  [4] cudnnSetConvolutionDescriptor(ptr::Ptr{Nothing}, padding::Vector{Int32}, stride::Vector{Int32}, dilation::Vector{Int32}, mode::CUDA.CUDNN.cudnnConvolutionMode_t, dataType::CUDA.CUDNN.cudnnDataType_t, mathType::CUDA.CUDNN.cudnnMathType_t, reorderType::CUDA.CUDNN.cudnnReorderType_t, groupCount::Int32)
    @ CUDA.CUDNN ~/.julia/packages/CUDA/DfvRa/lib/cudnn/convolution.jl:135
  [5] CUDA.CUDNN.cudnnConvolutionDescriptor(::Vector{Int32}, ::Vararg{Any})
    @ CUDA.CUDNN ~/.julia/packages/CUDA/DfvRa/lib/cudnn/descriptors.jl:39
  [6] CUDA.CUDNN.cudnnConvolutionDescriptor(cdims::DenseConvDims{3, 3, 3, 6, 3}, x::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, pad::Tuple{Int64, Int64, Int64})
    @ NNlibCUDA ~/.julia/packages/NNlibCUDA/kCpTE/src/cudnn/conv.jl:48
  [7] CUDA.CUDNN.cudnnConvolutionDescriptor(cdims::DenseConvDims{3, 3, 3, 6, 3}, x::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer})
    @ NNlibCUDA ~/.julia/packages/NNlibCUDA/kCpTE/src/cudnn/conv.jl:47
  [8] cudnnConvolutionDescriptorAndPaddedInput
    @ ~/.julia/packages/NNlibCUDA/kCpTE/src/cudnn/conv.jl:19 [inlined]
  [9] conv!(y::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, x::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, w::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}; alpha::Int64, beta::Int64, algo::Int64)
    @ NNlibCUDA ~/.julia/packages/NNlibCUDA/kCpTE/src/cudnn/conv.jl:66
 [10] conv!
    @ ~/.julia/packages/NNlibCUDA/kCpTE/src/cudnn/conv.jl:60 [inlined]
 [11] #conv#196
    @ ~/.julia/packages/NNlib/0QnJJ/src/conv.jl:88 [inlined]
 [12] conv
    @ ~/.julia/packages/NNlib/0QnJJ/src/conv.jl:86 [inlined]
 [13] #rrule#312
    @ ~/.julia/packages/NNlib/0QnJJ/src/conv.jl:313 [inlined]
 [14] rrule
    @ ~/.julia/packages/NNlib/0QnJJ/src/conv.jl:304 [inlined]
 [15] rrule
    @ ~/.julia/packages/ChainRulesCore/ctmSK/src/rules.jl:134 [inlined]
 [16] chain_rrule
    @ ~/.julia/packages/Zygote/DRjAT/src/compiler/chainrules.jl:218 [inlined]
 [17] macro expansion
    @ ~/.julia/packages/Zygote/DRjAT/src/compiler/interface2.jl:0 [inlined]
 [18] _pullback
    @ ~/.julia/packages/Zygote/DRjAT/src/compiler/interface2.jl:9 [inlined]
 [19] _pullback
    @ ~/.julia/packages/Lux/wsZ6r/src/nnlib.jl:102 [inlined]
 [20] _pullback
    @ ~/.julia/packages/Lux/wsZ6r/src/layers/conv.jl:133 [inlined]
 [21] macro expansion
    @ ~/.julia/packages/Lux/wsZ6r/src/layers/basic.jl:0 [inlined]
 [22] _pullback
    @ ~/.julia/packages/Lux/wsZ6r/src/layers/basic.jl:507 [inlined]
 [23] _pullback(::Zygote.Context{false}, ::typeof(Lux.applychain), ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, Conv{3, true, 6, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}, ::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(:weight, :bias), Tuple{CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}}}}}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(), Tuple{}}}})
    @ Zygote ~/.julia/packages/Zygote/DRjAT/src/compiler/interface2.jl:0
 [24] _pullback
    @ ~/.julia/packages/Lux/wsZ6r/src/layers/basic.jl:504 [inlined]
 [25] _pullback
    @ ~/.julia/packages/Lux/wsZ6r/src/core.jl:87 [inlined]
 [26] _pullback
    @ /media/jakub/NewVolume/projects/superVoxelJuliaCode/fullWithLoss/testClusterError.jl:18 [inlined]
 [27] _pullback(::Zygote.Context{false}, ::typeof(loss_function), ::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, Conv{3, true, 6, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(:weight, :bias), Tuple{CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}}}}}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(), Tuple{}}}}, ::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/DRjAT/src/compiler/interface2.jl:0
 [28] _pullback
    @ ~/.julia/packages/Lux/wsZ6r/src/contrib/training.jl:129 [inlined]
 [29] _pullback(ctx::Zygote.Context{false}, f::Lux.Training.var"#2#3"{typeof(loss_function), CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}}, args::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(:weight, :bias), Tuple{CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}}}}})
    @ Zygote ~/.julia/packages/Zygote/DRjAT/src/compiler/interface2.jl:0
 [30] pullback(f::Function, cx::Zygote.Context{false}, args::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(:weight, :bias), Tuple{CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}}}}})
    @ Zygote ~/.julia/packages/Zygote/DRjAT/src/compiler/interface.jl:44
 [31] pullback(f::Function, args::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(:weight, :bias), Tuple{CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}}}}})
    @ Zygote ~/.julia/packages/Zygote/DRjAT/src/compiler/interface.jl:42
 [32] compute_gradients(#unused#::Lux.Training.ZygoteVJP, objective_function::typeof(loss_function), data::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, ts::Lux.Training.TrainState{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(:weight, :bias), Tuple{CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(), Tuple{}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(:weight, :bias), Tuple{Optimisers.Leaf{Adam{Float32}, Tuple{CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}, Optimisers.Leaf{Adam{Float32}, Tuple{CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}}}}}, Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, Conv{3, true, 6, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}})
    @ Lux.Training ~/.julia/packages/Lux/wsZ6r/src/contrib/training.jl:129
 [33] main(tstate::Lux.Training.TrainState{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(:weight, :bias), Tuple{CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(), Tuple{}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, NamedTuple{(:weight, :bias), Tuple{Optimisers.Leaf{Adam{Float32}, Tuple{CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}, Optimisers.Leaf{Adam{Float32}, Tuple{CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}}}}}, Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), NTuple{5, Conv{3, true, 6, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}}, vjp::Lux.Training.ZygoteVJP, data::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, epochs::Int64)
    @ Main /media/jakub/NewVolume/projects/superVoxelJuliaCode/fullWithLoss/testClusterError.jl:29
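
A hedged guess based on the stack trace: the failure happens inside cudnnSetConvolutionNdDescriptor, and cuDNN requires every dilation value to be at least 1, so the dilation=0 in the convolution definition above may be the rejected parameter (the CPU forward pass apparently tolerates it). Sketch of the changed line:

conv1 = (in, out) -> Lux.Conv((3, 3, 3), in => out, NNlib.tanh; stride=1, dilation=1)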

Support for multidimensional data?

Hello, I have multidimensional (3- and 4-dimensional) data and flattening it is not an option, but layers like Dense seem to accept only 1D or 2D input. Am I correct, or am I missing something?

Thanks for the response!
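
A hedged sketch of one way to push higher-dimensional data through a Dense layer without permanently flattening it: fold the trailing dimensions into the batch dimension, apply the layer, and restore the shape afterwards. The layer sizes here are made up for illustration.

using Lux, Random

model = Dense(16 => 8, tanh)
ps, st = Lux.setup(Random.default_rng(), model)

x = randn(Float32, 16, 10, 32)            # features × two trailing dims
xmat = reshape(x, size(x, 1), :)          # 16 × 320 matrix
ymat, st = model(xmat, ps, st)            # 8 × 320
y = reshape(ymat, 8, size(x)[2:end]...)   # back to 8 × 10 × 32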

Remove `ActivationFunction`?

The original idea was to perform some kinds of primitive layer fusion, but that is not implemented currently. Instead, it has led to quite a few unexpected issues like SciML/DiffEqFlux.jl#737. I will be deprecating ActivationFunction in v0.4.* and writing a fallback to WrappedFunction. In v0.5 I will remove it, unless someone has strong opposition to this.

Deprecated in #80
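
For anyone migrating, a hedged sketch of the fallback path mentioned above, assuming WrappedFunction simply wraps an x -> f(x) callable as a parameter-free layer (check the current docs before relying on this):

# Hypothetical replacement for ActivationFunction(tanh): broadcast the activation explicitly.
act_layer = WrappedFunction(Base.Fix1(broadcast, tanh))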

Performance regressions with ComponentArrays

using Lux, ComponentArrays, ReverseDiff, Random, Zygote

c = Chain(Dense(3, 128), Dense(128, 1024))

x = randn(Float32, 3, 1)
ps, st = Lux.setup(Random.default_rng(), c)

ps_c = ps |> Lux.ComponentArray

@benchmark ReverseDiff.gradient(ps -> sum(first(Lux.apply(c, x, ps, st))), ps_c)

#=
BenchmarkTools.Trial: 5530 samples with 1 evaluation.
 Range (min … max):  611.340 μs …   8.200 ms  ┊ GC (min … max): 0.00% … 84.66%
 Time  (median):     762.516 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   887.038 μs ± 585.500 μs  ┊ GC (mean ± σ):  7.27% ±  9.95%

  ▇██▆▅▄▃▃▃▂▂▁▁                                                 ▂
  ████████████████▇▇▆▆▄▃▅▁▃▃▁▁▁▁▁▁▃▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▅▅▄▅▅▅ █
  611 μs        Histogram: log(frequency) by time       4.68 ms <

 Memory estimate: 2.04 MiB, allocs estimate: 43.
=#

@benchmark Zygote.gradient(ps -> sum(first(Lux.apply(c, x, ps, st))), ps_c)

#=
BenchmarkTools.Trial: 3598 samples with 1 evaluation.
 Range (min … max):  907.267 μs …   9.921 ms  ┊ GC (min … max): 0.00% … 85.00%
 Time  (median):       1.275 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.371 ms ± 873.184 μs  ┊ GC (mean ± σ):  9.51% ± 12.47%

  ▇█▅▇▇▅▄▂▁▁                                                    ▁
  █████████████▇▅▆▄▁▁▃▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▅█▇▄▃▅▄▃▄▁▁▃▃▁▁▁▁▃▁▄▆▄▄ █
  907 μs        Histogram: log(frequency) by time       7.11 ms <

 Memory estimate: 3.61 MiB, allocs estimate: 192.
=#

@benchmark Zygote.gradient(ps -> sum(first(Lux.apply(c, x, ps, st))), ps)

#= 
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  344.190 μs …   7.707 ms  ┊ GC (min … max): 0.00% … 89.44%
 Time  (median):     380.648 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   436.060 μs ± 353.445 μs  ┊ GC (mean ± σ):  4.65% ±  5.52%

   ██▃▂▁                                                         
  ▃█████▇▆▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  344 μs           Histogram: frequency by time          774 μs <

 Memory estimate: 571.28 KiB, allocs estimate: 84.
=#

I am assuming the major regression comes from re-constructing the ComponentArray in the backward pass (which is expected to be slow).

Improve `Julia & Lux for the uninitiated`

Hi, congrats on a very interesting package, I look forward to trying it out! I'm going through the docs and noticed some typos. I also recommend small potential improvements. I couldn't easily identify the original files to do a PR, so here they are:

In http://lux.csail.mit.edu/dev/examples/generated/beginner/Basics/main/:

  • we don't enfore it -> we don't enforce it
  • We relu on the Julia StdLib -> We rely on the Julia StdLib
  • we create an PRNG and seed it -> we create a PRNG (pseudorandom number generator) and seed (initialize) it
  • we should use Lux.replicate on PRNG before using them -> we should use Lux.replicate on PRNGs before using them
  • provides an uniform API -> provides a uniform API
  • Note that AD.gradient will only work for scalar valued outputs -> Note that AD.gradient will only work for scalar valued outputs. (period at the end.)
  • to demonstrate Lux let us use the Dense layer. -> to demonstrate Lux, let's use the Dense layer. (Equivalent to Pytorch's nn.Linear)

On the same page, I recommend adding a line to make the following a bit more "user-friendly", e.g. for PyTorch users curious about Julia+Lux:

  • ∇f(x) = x:
    • add underneath: "∇" can be typed with \del<tab> in the Julia REPL or in a Julia-compatible editor. You can press ? in the REPL to enter Julia *help* mode and then paste the ∇ to find out how to type any unicode character in Julia.
  • For updating our parameters let's use [Optimisers.jl](https://github.com/FluxML/Optimisers.jl) -> To update our parameters, let's use from [Optimisers.jl](https://github.com/FluxML/Optimisers.jl) an SGD (Stochastic Gradient Descent) with learning rate set to 0.01:
  • Initialize the initial state of the optimiser -> Setup the initial state of the optimiser:
  • Define the loss function -> Define the loss function:
  • println("Loss Value with ground true W & b: ", mse(W, b, x_samples, y_samples)) -> println("Loss value evaluated with true parameters (weights and biases): ", mse(W, b, x_samples, y_samples))
  • # Perform parameter update -> # Update model's parameters:

IMHO the Jacobian-Vector Product and the Vector-Jacobian Product sections are technical details that are unlikely to be of interest to most people first looking at the docs... I recommend moving those sections to the bottom of that page, or at least prefacing them with a "side note:" so people can skip them.

Can one compose Lux layers with graph neural networks?

Can one compose Lux layers with a graph neural network?

For Flux there is GeometricFlux, but composing Lux and Flux is something to be avoided.

Alternatively, can one build a model in Lux and incorporate it as part of a Flux chain, to compose it with GeometricFlux?

Scalar indexing problem for the NeuralODE example

Hi, firstly, thank you very much for this great package with super complete and didactic documentation! :)

While going through the documentation I realized that the NeuralODE example is not working properly on GPU. It throws a scalar indexing error, and I think it is because the parameters are a ComponentArray, but I don't know how to fix it.

Error log
ERROR: LoadError: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/rSIl2/src/GPUArraysCore.jl:78
  [3] getindex(xs::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, I::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/gok9K/src/host/indexing.jl:9
  [4] setindex!
    @ ./array.jl:979 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/ComponentArrays/NEqmD/src/array_interface.jl:0 [inlined]
  [6] _setindex!(x::ComponentVector{Float32}, v::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, idx::Val{:bias})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/NEqmD/src/array_interface.jl:129
  [7] setproperty!
    @ ~/.julia/packages/ComponentArrays/NEqmD/src/namedtuple_interface.jl:17 [inlined]
  [8] (::ComponentArrays.var"#getproperty_adjoint#88"{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))}}}, Symbol})(Δ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ ComponentArrays ~/.julia/packages/ComponentArrays/NEqmD/src/compat/chainrulescore.jl:4
  [9] ZBack
    @ ~/.julia/packages/Zygote/IoW2g/src/compiler/chainrules.jl:205 [inlined]
 [10] Pullback
    @ ~/.julia/packages/Lux/lEqCI/src/layers/basic.jl:639 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Lux/lEqCI/src/layers/basic.jl:0 [inlined]
 [12] Pullback
    @ ~/.julia/packages/Lux/lEqCI/src/layers/basic.jl:507 [inlined]
 [13] (::typeof((applychain)))(Δ::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/Lux/lEqCI/src/layers/basic.jl:504 [inlined]
 [15] (::typeof((λ)))(Δ::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [16] Pullback
    @ /network/scratch/a/abrevayg/.julia/packages/Lux/SApdg/examples/NeuralODE/main.jl:103 [inlined]
 [17] Pullback
    @ /network/scratch/a/abrevayg/.julia/packages/Lux/SApdg/examples/NeuralODE/main.jl:134 [inlined]
 [18] (::typeof((λ)))(Δ::Tuple{Float32, Nothing})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#60#61"{typeof((λ))})(Δ::Tuple{Float32, Nothing})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:41
 [20] train()
    @ Main /network/scratch/a/abrevayg/.julia/packages/Lux/SApdg/examples/NeuralODE/main.jl:135
 [21] top-level scope
    @ /network/scratch/a/abrevayg/.julia/packages/Lux/SApdg/examples/NeuralODE/main.jl:155
 [22] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [23] top-level scope
    @ REPL[6]:1
 [24] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
in expression starting at /network/scratch/a/abrevayg/.julia/packages/Lux/SApdg/examples/NeuralODE/main.jl:155
(examples) pkg> st
Status `~/.julia/packages/Lux/lEqCI/examples/Project.toml`
  [c29ec348] AbstractDifferentiation v0.4.3
  [c7e460c6] ArgParse v1.1.4
  [02898b10] Augmentor v0.6.6
  [052768ef] CUDA v3.12.0
⌅ [b0b7db55] ComponentArrays v0.11.17
  [2e981812] DataLoaders v0.1.3
  [41bf760c] DiffEqSensitivity v6.79.0
  [587475ba] Flux v0.13.4
⌅ [acf642fa] FluxMPI v0.5.3
  [59287772] Formatting v0.4.2
  [f6369f11] ForwardDiff v0.10.30
⌅ [d9f16b24] Functors v0.2.8
  [6218d12a] ImageMagick v1.2.2
⌃ [916415d5] Images v0.24.1
  [b835a17e] JpegTurbo v0.1.1
  [b2108857] Lux v0.4.9
  [cc2ba9b6] MLDataUtils v0.5.4
  [eb30cadb] MLDatasets v0.7.4
  [f1d291b0] MLUtils v0.2.9
  [dbeba491] Metalhead v0.7.3
  [872c559c] NNlib v0.8.8
  [3bd65402] Optimisers v0.2.7
  [1dea7af3] OrdinaryDiffEq v6.18.2
  [d7d3b36b] ParameterSchedulers v0.3.3
  [91a5bcdd] Plots v1.31.3
  [37e2e3b7] ReverseDiff v1.14.1
⌅ [efcf1570] Setfield v0.8.2
  [fce5fe82] Turing v0.21.9
  [e88e6eb3] Zygote v0.6.41
  [de0858da] Printf
  [9a3f8284] Random
  [10745b16] Statistics
Info Packages marked with ⌃ and ⌅ have new versions available, but those with ⌅ cannot be upgraded. To see why use `status --outdated`
julia> VERSION
v"1.8.0-rc3"

Immutable Arrays

Testing out the Immutable Arrays from JuliaLang/julia#44381 with #7

TLDR: Performance is a slight pain right now (broadcasting seems to be the culprit), but it is very straightforward to support these once the functionality is available in Base.

EDIT: Code updated to work for Lux 0.4.*

Trial 1: From the Usage Example

using Lux, Random, Functors

make_immutable(x::AbstractArray) = ImmutableArray(copy(x))
make_immutable(x) = x

# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
              Chain(Dense(256, 1, tanh), Dense(1, 10)))

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps)
st_immutable = fmap(make_immutable, st)

# Dummy Input
x = randn(Float32, 128, 1024)
x_immutable = make_immutable(x)

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 1296 samples with 1 evaluation.
 Range (min … max):  2.125 ms … 26.658 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     3.096 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.836 ms ±  2.313 ms  ┊ GC (mean ± σ):  2.58% ± 7.71%

    ▂█                                                        
  ▆▄██▇▆▄▄▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▂▂▁▂▂▂▂▁▁▂▂▂▂▂▁▂▂▂▂▁▂▂▂ ▃
  2.13 ms        Histogram: frequency by time        14.1 ms <

 Memory estimate: 3.60 MiB, allocs estimate: 144.

Immutable Arrays

BenchmarkTools.Trial: 41 samples with 1 evaluation.
 Range (min … max):  107.855 ms … 159.665 ms  ┊ GC (min … max): 3.98% … 2.64%
 Time  (median):     119.911 ms               ┊ GC (median):    3.54%
 Time  (mean ± σ):   123.706 ms ±  10.746 ms  ┊ GC (mean ± σ):  3.54% ± 0.67%

              ▂█▄                                                
  ▄▁▁▁▁▁▁▁▄▆▄█████▄▁▄▆▄▆▁▁▄▁▁▄▁▁▁▁▁▁▄▁▁▁▁▄▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
  108 ms           Histogram: frequency by time          160 ms <

 Memory estimate: 58.32 MiB, allocs estimate: 3418558.

Trial 2: Only a Dense Layer

# Construct the layer
model = Dense(128, 256)

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Dummy Input
x = randn(Float32, 128, 1024);
x_immutable = make_immutable(x);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 4469 samples with 1 evaluation.
 Range (min … max):  483.810 μs … 30.894 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     716.669 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.100 ms ±  1.501 ms  ┊ GC (mean ± σ):  5.01% ± 12.19%

  █▆▆▅▄▃▂▂▂▂▃▃▃▂▁                                              ▁
  █████████████████▇▇▇▆▇▆▅▅▃▃▄▅▅▄▃▅▁▁▆▄▅▁▃▃▃▃▅▁▃▃▃▃▁▃▁▁▃▁▁▁▁▃▅ █
  484 μs        Histogram: log(frequency) by time      7.69 ms <

 Memory estimate: 2.00 MiB, allocs estimate: 4.

Immutable Arrays

BenchmarkTools.Trial: 259 samples with 1 evaluation.
 Range (min … max):  15.392 ms … 52.229 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     17.997 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   19.327 ms ±  4.194 ms  ┊ GC (mean ± σ):  1.72% ± 4.44%

    ▃▆█ ▂                                                      
  ▃▆███▆█▇▅▇▇▄▆▃▆▄▄▅▄▄▄▄▄▃▄▄▃▂▁▃▃▂▁▃▂▁▁▂▂▁▂▂▂▁▃▁▃▂▂▁▁▁▂▂▁▂▂▁▂ ▃
  15.4 ms         Histogram: frequency by time        32.6 ms <

 Memory estimate: 7.00 MiB, allocs estimate: 262153.

Seems like there is a lot of time being spent on broadcasting the bias (seems like a problem with broadcasting in general)

julia> @benchmark $ps_immutable.weight * $x_immutable
BenchmarkTools.Trial: 4032 samples with 1 evaluation.
 Range (min … max):  346.287 μs … 51.079 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     540.489 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.224 ms ±  1.854 ms  ┊ GC (mean ± σ):  2.36% ± 8.18%

  █▆▄▄▃▁▁▁ ▂▂▁▁▁▂▂▁▁  ▁▁                                       ▁
  █████████████████████████▇▇▇▆▇▆▇▆▆▃▆▆▆▅▅▅▅▄▅▅▅▆▅▅▅▅▅▅▄▃▁▁▁▃▃ █
  346 μs        Histogram: log(frequency) by time      8.78 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 5.

julia> @benchmark $ps_immutable.weight * $x_immutable .+ $ps_immutable.bias
BenchmarkTools.Trial: 338 samples with 1 evaluation.
 Range (min … max):  11.177 ms … 33.105 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     13.699 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   14.792 ms ±  3.901 ms  ┊ GC (mean ± σ):  2.43% ± 5.87%

   █▃                                                          
  ▅██▇▇▅▅▇▅▇▅▅▄▅▅▄▃▃▄▄▃▂▃▃▁▂▃▁▃▂▃▃▃▁▃▂▂▂▁▃▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▃▂▁▂ ▃
  11.2 ms         Histogram: frequency by time        30.9 ms <

 Memory estimate: 7.00 MiB, allocs estimate: 262153.

Trial 3: No broadcasting

model = Dense(128, 256; bias=false)

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 5501 samples with 1 evaluation.
 Range (min … max):  295.161 μs … 23.801 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     451.402 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   899.925 μs ±  1.386 ms  ┊ GC (mean ± σ):  3.10% ± 8.68%

  █▆▆▄▃▂▁▂▁▁▁▂▂▂▂▁ ▁                                           ▁
  ██████████████████▇█▇█▇▇▆▆▇▇▆▆▆▆▆▆▅▅▅▆▅▅▁▆▄▆▅▃▅▄▅▄▆▄▅▁▄▆▅▅▃▅ █
  295 μs        Histogram: log(frequency) by time      6.98 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 2.

Immutable Arrays

BenchmarkTools.Trial: 5303 samples with 1 evaluation.
 Range (min … max):  311.574 μs … 26.953 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     436.316 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   930.509 μs ±  1.488 ms  ┊ GC (mean ± σ):  3.23% ± 8.75%

  █▆▅▃▂▁   ▁▁▂▁▁                                               ▁
  █████████████████▆█▇▇▆▆▆▆▆▆▆▆▆▅▅▅▅▅▅▄▄▅▅▅▅▂▅▂▄▅▄▅▄▄▃▂▃▄▄▂▃▂▃ █
  312 μs        Histogram: log(frequency) by time      7.61 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 5.

Trial 4

model = Chain(Dense(128, 256; bias=false), Chain(Dense(256, 512; bias=false),
                                                 Dense(512, 10; bias=false)))

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 1372 samples with 1 evaluation.
 Range (min … max):  1.380 ms … 49.871 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.918 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.615 ms ±  3.116 ms  ┊ GC (mean ± σ):  2.42% ± 7.94%

  ▅█    ▃                                                     
  ███▇▆▇██▇▆▅▄▄▄▃▃▃▃▂▃▃▃▂▃▂▃▂▂▂▂▁▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂ ▃
  1.38 ms        Histogram: frequency by time        15.8 ms <

 Memory estimate: 3.04 MiB, allocs estimate: 6.

Immutable Arrays

BenchmarkTools.Trial: 894 samples with 1 evaluation.
 Range (min … max):  1.505 ms … 66.104 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.153 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.561 ms ±  5.432 ms  ┊ GC (mean ± σ):  1.87% ± 7.54%

  █▆▅▅▅▄▅▆▆▅▄▄▂▂▂▂▁     ▁  ▁     ▁                            
  █████████████████▇█▆███▆▇█▅▆▇███▆▇█▄▇▇▇▅▄▆▅▅▁▄▁▆▄▁▅▇▅▄▄▆▁▅ █
  1.5 ms       Histogram: log(frequency) by time     23.1 ms <

 Memory estimate: 3.04 MiB, allocs estimate: 17.

cc @ChrisRackauckas @ianatol @aviatesk

Basic example from Migrating from Flux to Lux is broken || normalization issue

The Lux version of the first example in Migrating from Flux to Lux from the docs is broken. The reason is that the input x is a Matrix{Float64}, but when setting up the model with ps, st = Lux.setup(rng, model) the parameters are Matrix{Float32} by default, and the normalization function requires the eltypes of all its arguments to be the same. A quick fix to make the example in the docs work is just to initialize x with Float32s: x = randn(rng, Float32, 2, 4). However, I think it would be good to fix the normalization issue eventually. Also, it would be nice to keep that line of code the same between Flux and Lux; alternatively, we could use x = randn(rng, Float32, 2, 4) for both Lux and Flux.

For fixing the normalization issue, some options could be to add a method with a different parametric type for the first argument (which is x), or to promote/convert some of the types if they are not the same. What do you think would be the best way to handle this?
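
For completeness, a minimal sketch of the quick fix described above: make the input Float32 so its eltype matches the Float32 parameters produced by Lux.setup.

using Lux, Random, NNlib

model = Chain(Dense(2 => 4), BatchNorm(4, relu), Dense(4 => 2))
rng = Random.default_rng()
x = randn(rng, Float32, 2, 4)   # Float32 input instead of the default Float64
ps, st = Lux.setup(rng, model)
model(x, ps, st)                # now hits the eltype-matching normalization method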

Code and error message
using Lux, Random, NNlib, Zygote

model = Chain(Dense(2 => 4), BatchNorm(4, relu), Dense(4 => 2))
rng = Random.default_rng()
x = randn(rng, 2, 4)
ps, st = Lux.setup(rng, model)
model(x, ps, st)
ERROR: MethodError: no method matching normalization(::Matrix{Float64}, ::Vector{Float32}, ::Vector{Float32}, ::Vector{Float32}, ::Vector{Float32}, ::typeof(relu), ::Vector{Int64}, ::Val{true}, ::Float32, ::Float32)
Closest candidates are:
  normalization(::AbstractArray{T, N}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Any, ::Any, ::Val) where {T, N} at ~/.julia/packages/Lux/lEqCI/src/nnlib.jl:31
  normalization(::AbstractArray{T, N}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Any, ::Any, ::Val, ::T) where {T, N} at ~/.julia/packages/Lux/lEqCI/src/nnlib.jl:31
  normalization(::AbstractArray{T, N}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Union{Nothing, AbstractVector{T}}, ::Any, ::Any, ::Val, ::T, ::T) where {T, N} at ~/.julia/packages/Lux/lEqCI/src/nnlib.jl:31
Stacktrace:
 [1] (::BatchNorm{true, true, typeof(relu), typeof(Lux.zeros32), typeof(Lux.ones32), Float32})(x::Matrix{Float64}, ps::NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, st::NamedTuple{(:running_mean, :running_var, :training), Tuple{Vector{Float32}, Vector{Float32}, Val{true}}})
   @ Lux ~/.julia/packages/Lux/lEqCI/src/layers/normalize.jl:120
 [2] macro expansion
   @ ~/.julia/packages/Lux/lEqCI/src/layers/basic.jl:0 [inlined]
 [3] applychain(layers::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, BatchNorm{true, true, typeof(relu), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}, x::Matrix{Float64}, ps::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}, st::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(:running_mean, :running_var, :training), Tuple{Vector{Float32}, Vector{Float32}, Val{true}}}, NamedTuple{(), Tuple{}}}})
   @ Lux ~/.julia/packages/Lux/lEqCI/src/layers/basic.jl:507
 [4] (::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, BatchNorm{true, true, typeof(relu), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}})(x::Matrix{Float64}, ps::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}, st::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(:running_mean, :running_var, :training), Tuple{Vector{Float32}, Vector{Float32}, Val{true}}}, NamedTuple{(), Tuple{}}}})
   @ Lux ~/.julia/packages/Lux/lEqCI/src/layers/basic.jl:504
 [5] top-level scope
   @ REPL[15]:1
 [6] top-level scope
   @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
(examples) pkg> st
Status `~/.julia/packages/Lux/lEqCI/examples/Project.toml`
  [c29ec348] AbstractDifferentiation v0.4.3
  [c7e460c6] ArgParse v1.1.4
  [02898b10] Augmentor v0.6.6
  [052768ef] CUDA v3.12.0
⌅ [b0b7db55] ComponentArrays v0.11.17
  [2e981812] DataLoaders v0.1.3
  [41bf760c] DiffEqSensitivity v6.79.0
  [587475ba] Flux v0.13.4
⌅ [acf642fa] FluxMPI v0.5.3
  [59287772] Formatting v0.4.2
  [f6369f11] ForwardDiff v0.10.30
⌅ [d9f16b24] Functors v0.2.8
  [6218d12a] ImageMagick v1.2.2
⌃ [916415d5] Images v0.24.1
  [b835a17e] JpegTurbo v0.1.1
  [b2108857] Lux v0.4.9
  [cc2ba9b6] MLDataUtils v0.5.4
  [eb30cadb] MLDatasets v0.7.4
  [f1d291b0] MLUtils v0.2.9
  [dbeba491] Metalhead v0.7.3
  [872c559c] NNlib v0.8.8
  [3bd65402] Optimisers v0.2.7
  [1dea7af3] OrdinaryDiffEq v6.18.2
  [d7d3b36b] ParameterSchedulers v0.3.3
  [91a5bcdd] Plots v1.31.3
  [37e2e3b7] ReverseDiff v1.14.1
⌅ [efcf1570] Setfield v0.8.2
  [fce5fe82] Turing v0.21.9
  [e88e6eb3] Zygote v0.6.41
  [de0858da] Printf
  [9a3f8284] Random
  [10745b16] Statistics
Info Packages marked with ⌃ and ⌅ have new versions available, but those with ⌅ cannot be upgraded. To see why use `status --outdated`
julia> VERSION
v"1.8.0-rc3"

Proposal of a Lux + Enzyme + CUDA differentiable programming example

Hello, I have a working toy example of a custom function with backpropagation backed by Enzyme.jl and Lux as a tooling base. Maybe it would be useful for somebody.

using ChainRulesCore, Zygote, CUDA, Enzyme
using CUDAKernels
using KernelAbstractions
using KernelGradients
using Lux, Random
import NNlib, Optimisers, Plots, Statistics
using FillArrays

#### test data
Nx, Ny, Nz = 8, 8, 8
oneSidePad = 1
totalPad=oneSidePad*2
A = CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 
dA= CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 

Aoutout = CUDA.zeros(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 
dAoutout= CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 

p = CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 
dp= CUDA.ones(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 

threads = (4, 4, 4)
blocks = (2, 2, 2)
rng = Random.default_rng()

#### main kernel
function testKern(A, p, Aout,Nx)
    # adding one because of padding
    x = (threadIdx().x + ((blockIdx().x - 1) * CUDA.blockDim_x())) + 1
    y = (threadIdx().y + ((blockIdx().y - 1) * CUDA.blockDim_y())) + 1
    z = (threadIdx().z + ((blockIdx().z - 1) * CUDA.blockDim_z())) + 1
    Aout[x, y, z] = A[x, y, z] *p[x, y, z] *p[x, y, z] *p[x, y, z] 
    
    return nothing
end

function testKernDeff( A, dA, p
    , dp, Aout
    , dAout,Nx)
    Enzyme.autodiff_deferred(testKern, Const, Duplicated(A, dA), Duplicated(p, dp), Duplicated(Aout, dAout),Const(Nx)
    )
    return nothing
end

function calltestKern(A, p,Nx)
    Aout = CUDA.zeros(Nx+totalPad, Ny+totalPad, Nz+totalPad ) 
    @cuda threads = threads blocks = blocks testKern( A, p,  Aout,Nx)
    return Aout
end



# rrule for ChainRules.
function ChainRulesCore.rrule(::typeof(calltestKern), A, p,Nx)
    
    Aout = calltestKern(A, p,Nx)#CUDA.zeros(Nx+totalPad, Ny+totalPad, Nz+totalPad )
    function call_test_kernel1_pullback(dAout)
        threads = (4, 4, 4)
        blocks = (2, 2, 2)
        dp = CUDA.ones(size(p))
        dA = CUDA.ones(size(A))
        #@device_code_warntype @cuda threads = threads blocks = blocks testKernDeff( A, dA, p, dp, Aout, CuArray(collect(dAout)),Nx)
        @cuda threads = threads blocks = blocks testKernDeff( A, dA, p, dp, Aout, CuArray(collect(dAout)),Nx)

        f̄ = NoTangent()
        x̄ = dA
        ȳ = dp
        
        return f̄, x̄, ȳ,NoTangent()
    end   
    return Aout, call_test_kernel1_pullback

end


#first testing does custom backpropagation compiles
ress=Zygote.jacobian(calltestKern,A,p,Nx )


#lux layers from http://lux.csail.mit.edu/dev/manual/interface/
struct KernelAstr<: Lux.AbstractExplicitLayer
    confA::Int
end

function KernelA(confA)
    return KernelAstr(confA)
end

function Lux.initialparameters(rng::AbstractRNG, l::KernelAstr)
    return (paramsA=CuArray(rand(rng,Float32, l.confA, l.confA, l.confA))
    ,Nx =l.confA )
end
"""
https://stackoverflow.com/questions/52035775/in-julia-1-0-how-to-set-a-named-tuple-with-only-one-key-value-pair
in order to get named tuple with single element put comma after
"""
function Lux.initialstates(::AbstractRNG, l::KernelAstr)::NamedTuple
    return (NxSt=l.confA , )
end

function (l::KernelAstr)(x, ps, st::NamedTuple)
    return calltestKern(x, ps.paramsA,ps.Nx),st
end



l = KernelA(Nx)
ps, st = Lux.setup(rng, l)
println("Parameter Length: ", Lux.parameterlength(l), "; State Length: ",
        Lux.statelength(l))

x = randn(rng, Float32, Nx, Ny,Nz)
x= CuArray(x)
# testing weather forward pass runs
y_pred, st =Lux.apply(l, x, ps, st)



model = Lux.Chain(KernelA(Nx),KernelA(Nx)) 
opt = Optimisers.Adam(0.0003)

"""
extremely simple loss function we just want to get the result to be as close to 100 as possible
"""
function loss_function(model, ps, st, x)
    y_pred, st = Lux.apply(model, x, ps, st)
    return (100-sum(y_pred))^2, st, ()
end

tstate = Lux.Training.TrainState(rng, model, opt; transform_variables=Lux.gpu)
vjp_rule = Lux.Training.ZygoteVJP()


function main(tstate::Lux.Training.TrainState, vjp::Lux.Training.AbstractVJP, data,
    epochs::Int)
   # data = data .|> Lux.gpu
    for epoch in 1:epochs
        grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp, loss_function,
                                                                data, tstate)
        @info epoch=epoch loss=loss
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate
end
# one epoch just to check if it runs
tstate = main(tstate, vjp_rule, x,1)
#training 
tstate = main(tstate, vjp_rule, x,1000)

Taking PRNGs seriously

Currently, we have very rudimentary handling of stochastic layers. Initialization of RNGs for stochastic layers is done as:

    randn(rng, 1)
    return (rng=replicate(rng), training=true)

This makes stochastic layers start from different RNGs. We need to look at how JAX frameworks do it.
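
For context, a hedged sketch of what the snippet above looks like inside a stochastic layer's state initialization; MyDropout is a hypothetical layer, not a Lux type.

using Lux, Random

struct MyDropout <: Lux.AbstractExplicitLayer
    p::Float64
end

function Lux.initialstates(rng::AbstractRNG, ::MyDropout)
    randn(rng, 1)                                   # advance the incoming RNG
    return (rng=Lux.replicate(rng), training=true)  # each layer keeps its own replica
end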

Distributed Data Parallel Training on examples/ImageNet error

Distributed data parallel training on examples/ImageNet fails with the following error:

(base) xx@VM-1-8-ubuntu:~/codes/julia_learn/cv/lux/ImageNet$ mpiexecjl -n 2 julia --project=.. -t 8 main.jl --arch ResNet18 --batch-size=256 /train_tmp    
┌ Warning: MPI Implementation is not CUDA Aware
└ @ FluxMPI.MPIExtensions ~/.julia/packages/FluxMPI/nB5FP/src/mpi_extensions.jl:38 
┌ Warning: MPI Implementation is not CUDA Aware
└ @ FluxMPI.MPIExtensions ~/.julia/packages/FluxMPI/nB5FP/src/mpi_extensions.jl:38 
2022-05-15T02:21:28.548 [0 / 2] Using GPU 1
2022-05-15T02:21:28.594 [1 / 2] Using GPU 2 
2022-05-15T02:21:41.372 => creating model `ResNet18`  
2022-05-15T02:21:54.928 ==> staring `ResNet18` warmup...
2022-05-15T02:22:22.044 ==> forward pass warmup completed  
2022-05-15T02:24:32.537 ==> backward pass warmup completed   
signal (11): Segmentation fault   
in expression starting at /home/zhangyong/codes/julia_learn/cv/lux/ImageNet/main.jl:537
unknown function (ip: 0x7f0f43b0da5f)
unknown function (ip: 0x7f0b84ba9f3d)
unknown function (ip: 0x7f0b84b945ff)
unknown function (ip: 0x7f0b84b9e526)
unknown function (ip: 0x7f0b84b3076a)
unknown function (ip: 0x7f0b84a6d0b6)
unknown function (ip: 0x7f0b84a6efff)
unknown function (ip: 0x7f0b84a6f895) 
MPI_Bcast at /usr/lib/libmpi.so (unknown line) 
Bcast! at /home/zhangyong/.julia/packages/MPI/08SPr/src/collective.jl:53 [inlined]
Bcast! at /home/zhangyong/.julia/packages/MPI/08SPr/src/collective.jl:59 [inlined] 
#synchronize!#19 at /home/zhangyong/.julia/packages/FluxMPI/nB5FP/src/synchronize.jl:18 [inlined]
synchronize!##kw at /home/zhangyong/.julia/packages/FluxMPI/nB5FP/src/synchronize.jl:18 [inlined]
#17 at /home/zhangyong/.julia/packages/FluxMPI/nB5FP/src/synchronize.jl:14 [inlined] 
#fmap#17 at /home/zhangyong/.julia/packages/Functors/qBIlC/src/functor.jl:50
fmap##kw at /home/zhangyong/.julia/packages/Functors/qBIlC/src/functor.jl:49 [inlined] 
#18 at /home/zhangyong/.julia/packages/Functors/qBIlC/src/functor.jl:50
unknown function (ip: 0x7f0acc2b2732)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined] 
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
map at ./tuple.jl:221 [inlined] 
map at ./namedtuple.jl:218
_default_walk at /home/zhangyong/.julia/packages/Functors/qBIlC/src/functor.jl:43 
[inlined]                                                                       
#fmap#17 at /home/zhangyong/.julia/packages/Functors/qBIlC/src/functor.jl:50 
unknown function (ip: 0x7f0acc2b2430)
unknown function (ip: 0x7f0acc2b2009)
unknown function (ip: 0x7f0acc2b1fd0)
fmap##kw at /home/zhangyong/.julia/packages/Functors/qBIlC/src/functor.jl:49 [inlined]
#18 at /home/zhangyong/.julia/packages/Functors/qBIlC/src/functor.jl:50
unknown function (ip: 0x7f0acc2b1e8d)

......

__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line) 
_start at /home/zhangyong/.julia/juliaup/julia-1.8.0-beta3+0~x64/bin/julia (unknown line)
Allocations: 281319919 (Pool: 281207097; Big: 112822); GC: 86   
ERROR: failed process: Process(`mpiexec -n 2 julia -t 8 main.jl --arch ResNet18 --batch-size=256 /train_tmp`, ProcessExited(1)) [1]

Stacktrace:
 [1] pipeline_error
   @ ./process.jl:561 [inlined]
 [2] run(::Cmd; wait::Bool)
   @ Base ./process.jl:476
 [3] run(::Cmd)
   @ Base process.jl:474
 [4] (::var"#1#2")(exe::Cmd)
   @ Main none:4
 [5] (::MPI.var"#28#29"{var"#1#2"})(cmd::Cmd)
   @ MPI ~/.julia/packages/MPI/08SPr/src/environment.jl:25
 [6] _mpiexec(fn::MPI.var"#28#29"{var"#1#2"})
   @ MPI ~/.julia/packages/MPI/08SPr/deps/deps.jl:6
 [7] mpiexec(fn::var"#1#2")
   @ MPI ~/.julia/packages/MPI/08SPr/src/environment.jl:25
 [8] top-level scope
   @ none:4

julia> versioninfo()
Julia Version 1.8.0-beta3
Commit 3e092a2521 (2022-03-29 15:42 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: 96 × Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, cascadelake)
  Threads: 1 on 96 virtual cores
Environment:
  JULIA_CUDA_MEMORY_POOL = none

Lighter syntax for stateless networks?

When the neural network I'm working with has no evolving state, I find it quite cumbersome to always call model(x, ps, st). This becomes especially painful when the st needs to be passed down as an argument, even though it is essentially useless.
How hard would it be to have a default model(x, ps) for these cases? Or a very cheap function empty_state(model) to avoid passing down st in lower level calls?
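As a user-side workaround in the meantime, the (empty) state can be captured once at setup time in a small wrapper, so callers only pass x and ps. This is only a sketch (StatelessWrapper is a hypothetical name, not a Lux feature):

using Lux, Random

# Illustrative wrapper: stores the state returned by `Lux.setup` so that
# callers never have to thread it through manually.
struct StatelessWrapper{M, S}
    model::M
    st::S
end

# Discard the returned (empty) state and return only the prediction.
(w::StatelessWrapper)(x, ps) = first(Lux.apply(w.model, x, ps, w.st))

model = Chain(Dense(2 => 4, tanh), Dense(4 => 1))
ps, st = Lux.setup(Random.default_rng(), model)
smodel = StatelessWrapper(model, st)
y = smodel(randn(Float32, 2, 8), ps)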

How to avoid the activation function conversion

If I initialize a model with tanh activations, they get auto-converted to tanh_fast, which causes problems for my application.

Chain(
  Dense(2, layer_size, tanh; init_weight=Lux.glorot_normal),
  Dense(layer_size, 1, identity; init_weight=Lux.glorot_normal)
)
Chain(
    layer_1 = Dense(2 => 4, tanh_fast),  # 12 parameters
    layer_2 = Dense(4 => 1),            # 5 parameters
)         # Total: 17 parameters,
          #        plus 0 states, summarysize 32 bytes.

Can I avoid the automatic conversion? For now I use an anonymous function, which seems to work.
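For reference, a minimal sketch of that anonymous-function workaround; wrapping tanh appears to avoid the substitution because only the tanh function itself is swapped for tanh_fast:

using Lux

layer_size = 4
model = Chain(Dense(2, layer_size, x -> tanh(x); init_weight=Lux.glorot_normal),
              Dense(layer_size, 1, identity; init_weight=Lux.glorot_normal))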

Named Layers for Container Types

Currently, containers can only take tuples, and they set the names to layer_1, layer_2, and so on. #44 was an initial prototype for this, but we need to do it for all container layers implemented in Lux.

Support for non-CUDNN data types

Hi, firstly, thanks for the great package!

I've been experimenting with using this for a project which a) requires complex numbers and b) requires both jvps (using ForwardDiff) and vjps (using Zygote for now). While this works great on the CPU, currently with GPU arrays this ends up trying to dispatch complex / dual numbers to the CUDNN kernels.

This seems like it should be fixable with some extra type constraints in a few places, e.g. changing

function elementwise_add(x::CuArray, y::CuArray)

to

function elementwise_add(x::CuArray{T}, y::CuArray{T}) where T <: CUDNNDataType

where CUDNNDataType is some suitable union.

Adding such constraints to elementwise_add, elementwise_mul and applyactivation (and applyactivations chain rule?) should get this working at least for purely dense NNs. I'm not familiar enough with the rest of the library to know what else would need changing for other layers, but I'm happy to take a stab at getting dense NNs working with non-CUDNN types if you're interested.

Edit: a related issue that will also require a fix to get this working FluxML/NNlibCUDA.jl#47.
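To make the suggestion concrete, here is a rough sketch of the kind of dispatch constraint meant above. The union name and its members are assumptions, and the method bodies are placeholder broadcasting, not Lux's actual implementation:

using CUDA

# Hypothetical union of element types the cuDNN kernels accept.
const CUDNNDataType = Union{Float16, Float32, Float64}

# Supported element types keep hitting the fast path...
elementwise_add(x::CuArray{T}, y::CuArray{T}) where {T <: CUDNNDataType} = x .+ y  # placeholder for the cuDNN-backed method

# ...while complex / dual-number CuArrays fall back to plain broadcasting.
elementwise_add(x::CuArray, y::CuArray) = x .+ y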

Tracking support for Enzyme.jl

Opening this issue mostly to track how much of Lux (v0.4.7) is supported by Enzyme (v0.10.4)

  • Lux.Dense is supported
using Lux, Random, Enzyme
rng = Random.default_rng()

function loss_function(model, x, ps, st)
    return sum(Lux.apply(model, x, ps, st)[1])
end

model = Chain(Dense(2 => 4), Dense(4 => 2))
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 2, 1)

dps = Lux.fmap(zero, ps)
Enzyme.autodiff(Reverse, loss_function, Const(model), Const(x), Duplicated(ps, dps), Const(st))
println(dps)
  • Lux.BatchNorm works
using Lux, Random, Enzyme
rng = Random.default_rng()

function loss_function(model, x, ps, st)
    return sum(Lux.apply(model, x, ps, st)[1])
end

model = Chain(Dense(2 => 4), BatchNorm(4), Dense(4 => 2))
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 2, 4)

dps = Lux.fmap(zero, ps)
Enzyme.autodiff(Reverse, loss_function, Const(model), Const(x), Duplicated(ps, dps), Const(st))
println(dps)
Click to expand!
warning: Linking two modules of different target triples: 'bcloader' is 'x86_64-unknown-linux-gnu' whereas 'text' is 'x86_64-pc-linux-gnu'

warning: Linking two modules of different target triples: 'bcloader' is 'x86_64-unknown-linux-gnu' whereas 'text' is 'x86_64-pc-linux-gnu'

warning: Linking two modules of different target triples: 'bcloader' is 'x86_64-unknown-linux-gnu' whereas 'text' is 'x86_64-pc-linux-gnu'

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %109 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %2) #33, !dbg !321"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %255 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %3) #33, !dbg !504"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %478 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %4) #33, !dbg !700"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %553 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %5) #33, !dbg !814"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %109 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %2) #34, !dbg !323"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %255 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %3) #34, !dbg !506"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %478 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %4) #34, !dbg !702"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %553 = call nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %5) #34, !dbg !816"
└ @ Enzyme.Compiler /mnt/julia/packages/GPUCompiler/iaKrd/src/utils.jl:35

signal (11): Segmentation fault
in expression starting at REPL[33]:1
_ZNK4llvm10AllocaInst14isStaticAllocaEv at /mnt/softwares/julia-nightly/bin/../lib/julia/libLLVM-13jl.so (unknown line)
runOnFunction at /buildworker/worker/package_linux64/build/src/llvm-late-gc-lowering.cpp:2689
_ZN4llvm13FPPassManager13runOnFunctionERNS_8FunctionE at /mnt/softwares/julia-nightly/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm13FPPassManager11runOnModuleERNS_6ModuleE at /mnt/softwares/julia-nightly/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm6legacy15PassManagerImpl3runERNS_6ModuleE at /mnt/softwares/julia-nightly/bin/../lib/julia/libLLVM-13jl.so (unknown line)
LLVMRunPassManager at /mnt/softwares/julia-nightly/bin/../lib/julia/libLLVM-13jl.so (unknown line)
LLVMRunPassManager at /mnt/julia/packages/LLVM/WjSQG/lib/13/libLLVM_h.jl:4898 [inlined]
run! at /mnt/julia/packages/LLVM/WjSQG/src/passmanager.jl:39 [inlined]
#55 at /mnt/julia/packages/Enzyme/di3zM/src/compiler/optimize.jl:230
#ModulePassManager#64 at /mnt/julia/packages/LLVM/WjSQG/src/passmanager.jl:33
unknown function (ip: 0x7f4adebb715e)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
ModulePassManager at /mnt/julia/packages/LLVM/WjSQG/src/passmanager.jl:31
post_optimze! at /mnt/julia/packages/Enzyme/di3zM/src/compiler/optimize.jl:227 [inlined]
post_optimze! at /mnt/julia/packages/Enzyme/di3zM/src/compiler/optimize.jl:221 [inlined]
_thunk at /mnt/julia/packages/Enzyme/di3zM/src/compiler.jl:4617
unknown function (ip: 0x7f4ade76325d)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
cached_compilation at /mnt/julia/packages/Enzyme/di3zM/src/compiler.jl:4637
unknown function (ip: 0x7f4af61f3885)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
#s565#115 at /mnt/julia/packages/Enzyme/di3zM/src/compiler.jl:4697 [inlined]
#s565#115 at ./none:0
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
GeneratedFunctionStub at ./boot.jl:582
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
jl_call_staged at /buildworker/worker/package_linux64/build/src/method.c:520
ijl_code_for_staged at /buildworker/worker/package_linux64/build/src/method.c:571
get_staged at ./compiler/utilities.jl:114
retrieve_code_info at ./compiler/utilities.jl:126 [inlined]
InferenceState at ./compiler/inferencestate.jl:280
typeinf_edge at ./compiler/typeinfer.jl:867
abstract_call_method at ./compiler/abstractinterpretation.jl:632
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:156
abstract_call_known at ./compiler/abstractinterpretation.jl:1666
abstract_call at ./compiler/abstractinterpretation.jl:1724
abstract_call at ./compiler/abstractinterpretation.jl:1703
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1845
typeinf_local at ./compiler/abstractinterpretation.jl:2310
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2406
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:876
abstract_call_method at ./compiler/abstractinterpretation.jl:632
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:156
abstract_call_known at ./compiler/abstractinterpretation.jl:1666
abstract_call at ./compiler/abstractinterpretation.jl:1724
abstract_call at ./compiler/abstractinterpretation.jl:1703
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1845
typeinf_local at ./compiler/abstractinterpretation.jl:2310
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2406
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_edge at ./compiler/typeinfer.jl:876
abstract_call_method at ./compiler/abstractinterpretation.jl:632
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:156
abstract_call_known at ./compiler/abstractinterpretation.jl:1666
abstract_call at ./compiler/abstractinterpretation.jl:1724
abstract_call at ./compiler/abstractinterpretation.jl:1703
abstract_eval_statement at ./compiler/abstractinterpretation.jl:1845
typeinf_local at ./compiler/abstractinterpretation.jl:2310
typeinf_nocycle at ./compiler/abstractinterpretation.jl:2406
_typeinf at ./compiler/typeinfer.jl:230
typeinf at ./compiler/typeinfer.jl:213
typeinf_ext at ./compiler/typeinfer.jl:957
typeinf_ext_toplevel at ./compiler/typeinfer.jl:990
typeinf_ext_toplevel at ./compiler/typeinfer.jl:986
jfptr_typeinf_ext_toplevel_16088.clone_1 at /mnt/softwares/julia-nightly/lib/julia/sys.so (unknown line)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
jl_type_infer at /buildworker/worker/package_linux64/build/src/gf.c:319
jl_generate_fptr_impl at /buildworker/worker/package_linux64/build/src/jitlayers.cpp:314
jl_compile_method_internal at /buildworker/worker/package_linux64/build/src/gf.c:2072
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2350 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
autodiff at /mnt/julia/packages/Enzyme/di3zM/src/Enzyme.jl:285
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
do_apply at /buildworker/worker/package_linux64/build/src/builtins.c:725
autodiff at /mnt/julia/packages/Enzyme/di3zM/src/Enzyme.jl:319
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
do_apply at /buildworker/worker/package_linux64/build/src/builtins.c:725
autodiff at /mnt/julia/packages/Enzyme/di3zM/src/Enzyme.jl:412
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
do_call at /buildworker/worker/package_linux64/build/src/interpreter.c:126
eval_value at /buildworker/worker/package_linux64/build/src/interpreter.c:215
eval_stmt_value at /buildworker/worker/package_linux64/build/src/interpreter.c:166 [inlined]
eval_body at /buildworker/worker/package_linux64/build/src/interpreter.c:612
jl_interpret_toplevel_thunk at /buildworker/worker/package_linux64/build/src/interpreter.c:750
jl_toplevel_eval_flex at /buildworker/worker/package_linux64/build/src/toplevel.c:906
jl_toplevel_eval_flex at /buildworker/worker/package_linux64/build/src/toplevel.c:850
jl_toplevel_eval_flex at /buildworker/worker/package_linux64/build/src/toplevel.c:850
eval_body at /buildworker/worker/package_linux64/build/src/interpreter.c:556
eval_body at /buildworker/worker/package_linux64/build/src/interpreter.c:522
jl_interpret_toplevel_thunk at /buildworker/worker/package_linux64/build/src/interpreter.c:750
jl_toplevel_eval_flex at /buildworker/worker/package_linux64/build/src/toplevel.c:906
ijl_toplevel_eval_in at /buildworker/worker/package_linux64/build/src/toplevel.c:965
eval at ./boot.jl:368 [inlined]
eval_user_input at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:151
repl_backend_loop at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:247
start_repl_backend at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:232
#run_repl#47 at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:369
run_repl at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:356
jfptr_run_repl_63590.clone_1 at /mnt/softwares/julia-nightly/lib/julia/sys.so (unknown line)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
#964 at ./client.jl:419
jfptr_YY.964_53574.clone_1 at /mnt/softwares/julia-nightly/lib/julia/sys.so (unknown line)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
jl_f__call_latest at /buildworker/worker/package_linux64/build/src/builtins.c:769
#invokelatest#2 at ./essentials.jl:729 [inlined]
invokelatest at ./essentials.jl:727 [inlined]
run_main_repl at ./client.jl:404
exec_options at ./client.jl:318
_start at ./client.jl:522
jfptr__start_58493.clone_1 at /mnt/softwares/julia-nightly/lib/julia/sys.so (unknown line)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2358 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2540
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1831 [inlined]
true_main at /buildworker/worker/package_linux64/build/src/jlapi.c:567
jl_repl_entrypoint at /buildworker/worker/package_linux64/build/src/jlapi.c:711
main at /buildworker/worker/package_linux64/build/cli/loader_exe.c:59
unknown function (ip: 0x7f4af79d9d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
_start at /mnt/softwares/julia-nightly/bin/julia (unknown line)
Allocations: 133145644 (Pool: 133050263; Big: 95381); GC: 61

TagBot trigger issue

This issue is used to trigger TagBot; feel free to unsubscribe.

If you haven't already, you should update your TagBot.yml to include issue comment triggers.
Please see this post on Discourse for instructions and more details.

If you'd like for me to do this for you, comment TagBot fix on this issue.
I'll open a PR within a few hours, please be patient!

Thoughts on docs & tutorials

Hi! Yesterday I read through the docs and here are some comments. For the early stage of the package you already have quite a lot of documentation, which is really nice :)

  • In "training a simple lstm" there's the paragraph: "We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates." - To me this is not very clear. How do I know that I have to parametrize with :lstm_cell and :classifier? What is the mechanism saving me from defining initialparameters and initalstates?
  • The beginner tutorial should comprise an example of very basic MNIST training using CNNs (chain of convs/pool/dense).
  • The "NeuralODE" example could either be moved to advanced or link to https://book.sciml.ai/notes/11/ for more background.
  • It should be explained why pullback is used, what it takes as args and what it returns, although this would probably repeat the Zygote docs. What are all these nothings for and how many do I need? :D Why do I only need the first element of the gradient? New users especially may not be perfectly aware of the split between the DL library and the AD system.
  • Advanced example: ImageNet training
    • Lux.transform hasn't made it to the docs yet
    • Seems like there is some mixup between MLUtils and LearnBase, due to the transition from one to the other.
    • The meter code is a lot of boilerplate for an example; it's nice to see, but it looks like it belongs in some package
    • The learning rate should be scaled linearly with the number of nodes, but should it increase or decrease? I think increase, right? I don't remember why.
    • An explanation of why full GC sweeps and CUDA.reclaim() are needed in each epoch (twice?) would be nice.

How to freeze layers?

Thanks for the awesome project! I really enjoyed your talk too.

Flux.jl has the option of marking parameters as trainable, as well as deleting things from params.
Certain tasks may require some layers to be frozen for a few epochs and made trainable later, or vice versa.

What is the recommended way in these scenarios?

  1. Freeze layers of a particular types say Conv2 (The way we do this in PyTorch is going through the params and filtering by Type)
  2. Freeze a particular layer by index/name
  3. Freeze only parts of a layer say first w of W in Dense
  4. Freeze Series of layers say all except last 2 layers

Thanks!
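Until Lux provides a dedicated API, one possible workaround (a minimal sketch, assuming a plain Zygote + Optimisers training loop, with layer_1 only as an example of what to freeze) is to zero the gradient subtree of the frozen part before the optimiser update:

using Lux, Random, Zygote, Optimisers

rng = Random.default_rng()
model = Chain(Dense(2 => 4, tanh), Dense(4 => 1))
ps, st = Lux.setup(rng, model)
opt_state = Optimisers.setup(Optimisers.Adam(0.001f0), ps)

x, y = randn(rng, Float32, 2, 8), randn(rng, Float32, 1, 8)
gs = gradient(p -> sum(abs2, first(Lux.apply(model, x, p, st)) .- y), ps)[1]

# Freeze layer_1 by replacing its gradients with zeros before the update.
gs_frozen = merge(gs, (layer_1=Lux.fmap(zero, gs.layer_1),))
opt_state, ps = Optimisers.update(opt_state, ps, gs_frozen)

The same idea extends to cases 1, 3 and 4 above by zeroing the matching sub-trees (or parts of individual arrays) of the gradient NamedTuple.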

Handle parameters as plain arrays

It would be useful to be able to handle the plain (flat) version of the parameters.
Otherwise it is necessary to manually keep track of the structure of a chain and reshape the plain parameters accordingly before evaluating it.

using Lux, Random, Optimisers, Zygote, ComponentArrays

# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)

# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
              Chain(Dense(256, 1, tanh),Dense(1, 10)))

# Parameter and State Variables
ps, st = Lux.setup(rng, model)

# Dummy Input
x = rand(rng, Float32, 128, 2)

# Run the model
y, st = Lux.apply(model, x, ps, st)  # works

flat = ComponentArray(ps) |> collect
y, st = Lux.apply(model, x, flat, st)  # fails

ax = getaxes(ComponentArray(ps))
ps_new = ComponentArray(flat, ax)
y, st = Lux.apply(model, x, ps_new, st)  # works but feels like a workaround
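A small helper sketch that keeps that workaround in one place, reusing the names from the snippet above: capture the axes once, then rebuild the structured parameters from any flat vector on demand.

# Sketch only: reconstruct structured parameters from a flat vector.
ax = getaxes(ComponentArray(ps))
apply_flat(model, x, flat, st) = Lux.apply(model, x, ComponentArray(flat, ax), st)

y, st = apply_flat(model, x, flat, st)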

WeightNorm causes NaN for Conv layer gradients

When normalizing the bias of a conv layer, Zygote returns NaNs for the gradient of bias_v. This also happens with 2d conv layers. The gradient works as expected without normalizing the bias.

using Lux, Random, Zygote

function test_weightnorm()
    Random.seed!(12345)
    rng = Random.default_rng()
    x = randn(Float32, 300, 72, 32)

    model = WeightNorm(Conv((9,), 72=>72, stride=1, pad=1, dilation=1), (:weight, :bias))
    ps, st = Lux.setup(rng, model)

    ∇params, _ = gradient(ps, x) do p, x
        pred, _ = Lux.apply(model, x, p, st)
        sum(pred)
    end
    println(∇params[:normalized][:bias_v])
end

test_weightnorm()

prints

[NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN;;;]

with

Julia v1.7.3
Lux v0.4.9
NNlib v0.8.8
Zygote v0.6.41

is there transposed convolution

Hello, thanks for creating this fantastic library!

I am trying to create a U-Net architecture and an image autoencoder. The necessary convolutions and poolings are there, but I do not see a transposed convolution

like ConvTranspose in https://github.com/FluxML/Flux.jl/blob/d26ebdb45c6e7af98c562cbc8c32c3492acfee8d/src/layers/conv.jl

below Flux code for reference

struct ConvTranspose{N,M,F,A,V}
  σ::F
  weight::A
  bias::V
  stride::NTuple{N,Int}
  pad::NTuple{M,Int}
  dilation::NTuple{N,Int}
  groups::Int
end

_channels_in(l::ConvTranspose)  = size(l.weight)[end]
_channels_out(l::ConvTranspose) = size(l.weight)[end-1]*l.groups

"""
    ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation, groups])
Constructs a ConvTranspose layer with the given weight and bias.
Accepts the same keywords and has the same defaults as
[`ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ; ...)`](@ref ConvTranspose).
# Examples
```jldoctest
julia> weight = rand(3, 4, 5);
julia> bias = zeros(4);
julia> layer = ConvTranspose(weight, bias, sigmoid)
ConvTranspose((3,), 5 => 4, σ)  # 64 parameters
julia> layer(randn(100, 5, 64)) |> size  # transposed convolution will increase the dimension size (upsampling)
(102, 4, 64)
julia> Flux.params(layer) |> length
2
```
"""
function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
                      stride = 1, pad = 0, dilation = 1, groups=1) where {T,N}
  stride = expand(Val(N-2), stride)
  dilation = expand(Val(N-2), dilation)
  pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
  b = create_bias(w, bias, size(w, N-1) * groups)
  return ConvTranspose(σ, w, b, stride, pad, dilation, groups)
end

function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                      init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
                      groups = 1,
                      bias = true,
                      ) where N

  weight = convfilter(k, reverse(ch); init, groups)                    
  ConvTranspose(weight, bias, σ; stride, pad, dilation, groups)
end

@functor ConvTranspose

function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
  # Calculate size of "input", from ∇conv_data()'s perspective...
  combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
  I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
  C_in = size(c.weight)[end-1] * c.groups
  batch_size = size(x)[end]
  # Create DenseConvDims() that looks like the corresponding conv()
  w_size = size(c.weight)
  return DenseConvDims((I..., C_in, batch_size), w_size;
                      stride=c.stride,
                      padding=c.pad,
                      dilation=c.dilation,
                      groups=c.groups,
  )
end

ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)

function (c::ConvTranspose)(x::AbstractArray)
  σ = NNlib.fast_act(c.σ, x)
  cdims = conv_transpose_dims(c, x)
  σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
end

Front page example broken

Hi! Thanks for this interesting work! I just tried the front page example and it turned out not to work for me. Taking the gradient fails with:

julia> gs = gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps)[1]
ERROR: Compiling Tuple{NNlibCUDA.var"##cudnnBNForward!#87", Nothing, Float32, Float32, Float32, Bool, Bool, Bool, typeof(NNlibCUDA.cudnnBNForward!), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32}: try/catch is not supported.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/reverse.jl:121
  [3] #Primal#19
    @ ~/.julia/packages/Zygote/Y6SC4/src/compiler/reverse.jl:202 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/reverse.jl:315
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/emit.jl:101
  [6] #s3043#1206
    @ ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:28 [inlined]
  [7] var"#s3043#1206"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
  [9] _pullback
    @ ~/.julia/packages/NNlibCUDA/i1IW9/src/cudnn/batchnorm.jl:48 [inlined]
 [10] _pullback(::Zygote.Context, ::NNlibCUDA.var"#cudnnBNForward!##kw", ::NamedTuple{(:eps, :training), Tuple{Float32, Bool}}, ::typeof(NNlibCUDA.cudnnBNForward!), ::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Float32)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [11] _pullback (repeats 2 times)
    @ ~/.julia/packages/NNlibCUDA/i1IW9/src/cudnn/batchnorm.jl:37 [inlined]
 [12] _pullback
    @ ~/.julia/packages/NNlibCUDA/i1IW9/src/cudnn/batchnorm.jl:31 [inlined]
 [13] _pullback(::Zygote.Context, ::NNlibCUDA.var"##batchnorm#85", ::Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol}, NamedTuple{(:eps, :training), Tuple{Float32, Bool}}}, ::typeof(NNlibCUDA.batchnorm), ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Float32)
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [14] _pullback
    @ ~/.julia/packages/NNlibCUDA/i1IW9/src/cudnn/batchnorm.jl:30 [inlined]
 [15] _pullback
    @ ~/.julia/packages/Lux/HkXlk/src/layers/normalize.jl:114 [inlined]
 [16] _pullback(::Zygote.Context, ::BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [17] macro expansion
    @ ~/.julia/packages/Lux/HkXlk/src/layers/basic.jl:0 [inlined]
 [18] _pullback
    @ ~/.julia/packages/Lux/HkXlk/src/layers/basic.jl:330 [inlined]
 [19] _pullback(::Zygote.Context, ::typeof(Lux.applychain), ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [20] _pullback
    @ ~/.julia/packages/Lux/HkXlk/src/layers/basic.jl:328 [inlined]
 [21] _pullback(::Zygote.Context, ::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [22] _pullback
    @ ~/.julia/packages/Lux/HkXlk/src/core.jl:61 [inlined]
 [23] _pullback(::Zygote.Context, ::typeof(Lux.apply), ::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, BatchNorm{true, true, typeof(identity), typeof(Lux.zeros32), typeof(Lux.ones32), Float32}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:γ, :β), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(:μ, :σ², :training), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Bool}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [24] _pullback
    @ ./REPL[10]:1 [inlined]
 [25] _pullback(ctx::Zygote.Context, f::var"#1#2", args::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(, ), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(, ), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface2.jl:0
 [26] _pullback(f::Function, args::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(, ), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(, ), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:34
 [27] pullback(f::Function, args::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(, ), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(, ), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:40
 [28] gradient(f::Function, args::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(, ), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(, ), Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}})
    @ Zygote ~/.julia/packages/Zygote/Y6SC4/src/compiler/interface.jl:75
 [29] top-level scope
    @ REPL[10]:1
 [30] top-level scope
    @ ~/.julia/packages/CUDA/qAl31/src/initialization.jl:52

Package status output & version:

  [6e4b80f9] BenchmarkTools v1.3.1
  [587475ba] Flux v0.13.0
  [bdcacae8] LoopVectorization v0.12.108
  [b2108857] Lux v0.3.0 `[email protected]:avik-pal/Lux.jl.git#main`
  [356022a1] NamedDims v0.2.47
  [3bd65402] Optimisers v0.2.3
  [c46f51b8] ProfileView v1.5.1
  [94979ff8] RSPointMatching v0.1.0 `~/projects/RSPointMatching`
  [90137ffa] StaticArrays v1.4.4
  [e88e6eb3] Zygote v0.6.39
  [9a3f8284] Random

Julia Version 1.7.2
Commit bf53498635 (2022-02-06 15:21 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: AMD Ryzen 7 1700X Eight-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, znver1)
Environment:
  JULIA_PKG_DEVDIR = projects/

optimising parameters with Optimization.jl

I did try to use a ComponentArray as @ChrisRackauckas suggested on Slack, but no luck. I was also asking for a way to destructure and re-structure in Lux.jl, similar to Flux.jl 😅.

using Lux, Optim, Optimization, OptimizationOptimJL
using Random, ComponentArrays, StatsBase
model = Lux.Chain(
  Dense(2, 8, tanh),
  Dense(8, 1,x->x^2), x -> reshape(x, :)
  )

rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)
ŷ, st = Lux.apply(model, rand(Float32,2,10), ps, st)
y = rand(Float32, 10)
function loss(y, model, ps, st)
  ŷ, st = Lux.apply(model, rand(Float32,2,10), ps, st) # ps = ps_new[2]
  #ŷ += ps_new[1]
  mean(abs2.(y .- ŷ))
end
loss_one(ps) = loss(y, model, ps, st)
optim_cost = (p, tmp=nothing) -> loss_one(p)
optim_cost(ps) # up to here it works.
optim_prob = OptimizationProblem(optim_cost, ps) # what about having ps_new = [1.0f0, ps] ?
optim_para = solve(optim_prob, Optim.BFGS(initial_stepnorm=0.1)) # fails
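A hedged, untested sketch of one possible fix: BFGS needs gradients, so Optimization.jl likely has to be told which AD backend to use through an OptimizationFunction (e.g. AutoZygote), reusing loss_one and the ComponentArray ps from above:

using Zygote  # AutoZygote needs Zygote available

optf = OptimizationFunction((p, _) -> loss_one(p), Optimization.AutoZygote())
optim_prob = OptimizationProblem(optf, ps)
optim_para = solve(optim_prob, Optim.BFGS(initial_stepnorm=0.1))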

concat input and output of a layer

Hello, I want to have a layer that is built from a chain of convolutions and returns the concatenation of the chain's output with its input.
The code below, based on SkipConnection, works during the forward pass but breaks during backpropagation in the myCatt function. How do I do this properly?

import Lux
import NNlib, Optimisers, Plots, Random, Statistics, Zygote, HDF5

Nx, Ny, Nz = 32, 32, 32
oneSidePad = 1
totalPad = oneSidePad*2
dim_x,dim_y,dim_z= Nx+totalPad, Ny+totalPad, Nz+totalPad
featureNumb=3
conv1 = (in, out) -> Lux.Conv((3,3,3),  in => out , NNlib.tanh, stride=1, pad=Lux.SamePad())
rng = Random.default_rng()

function myCatt(a,b)
    cat(a,b;dims=4)
end    
modelConv=Lux.Chain(conv1(featureNumb,4),conv1(4,16),conv1(16,4),conv1(4,3))
modelConv=Lux.SkipConnection(modelConv,myCatt)
# modelConv=Lux.BranchLayer(modelConv,Lux.NoOpLayer)
ps, st = Lux.setup(rng, modelConv)
x = ones(Float32, dim_x,dim_y,dim_z,featureNumb)
x =reshape(x, (dim_x,dim_y,dim_z,featureNumb,1))
y_pred, st =Lux.apply(modelConv, x, ps, st) 
size(y_pred)
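As a first debugging step (sketch only, reusing the names defined above), it can help to check whether the concatenation alone differentiates, separately from the convolution chain:

# Differentiate just `myCatt` on arrays of the same rank used in the model.
a = rand(Float32, dim_x, dim_y, dim_z, featureNumb, 1)
b = rand(Float32, dim_x, dim_y, dim_z, featureNumb, 1)
grads = Zygote.gradient((a, b) -> sum(myCatt(a, b)), a, b)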

Quickstart Example: `using Optimisers, Zygote` do not work unless we explicitly add those to current environment.

Hi @avik-pal, firstly, thanks for the great work that you are putting into this library :)

Issue:

I'm new to Lux.jl, so I tried the Quickstart example in a Colab notebook and found out that even though Lux installs the Optimisers and Zygote packages, we cannot directly use them (as per the Quickstart example) -->

using Optimisers, Zygote

we get a "package not found" error.

Workaround:

First we have to add both packages explicitly -->

Pkg.add(["Optimisers", "Zygote"])

though this won't install the packages again; instead it will reuse the packages already installed during the Lux installation, and we can then load them directly like below -->

using Optimisers, Zygote

Or else we would have to namespace them like this (I suppose) -->

using Lux.Optimisers, Lux.Zygote

As this example is the first thing most people new to Lux try to run, I thought I would bring up this issue.

P.S.: I only tried this on Colab, not locally.
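Putting the workaround together, a fresh environment (e.g. a new Colab notebook) would roughly need:

import Pkg
Pkg.add(["Lux", "Optimisers", "Zygote"])

using Lux, Optimisers, Zygote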

`PairwiseFusion` takes more inputs than documented

Currently, the documentation for PairwiseFusion reports that the layer takes N-input tuples for N layers:

https://github.com/avik-pal/Lux.jl/blob/0ca5a265b7ef8c6d3da2b2dfacfefc56a39f5163/src/layers/basic.jl#L358

This, however, is not true. PairwiseFusion seems to be taking (N+1)-input tuples for N layers:

julia> layer = PairwiseFusion(+,  Dense(1, 30),  Dense(30, 10))
PairwiseFusion(
    +
    layer_1 = Dense(1 => 30),           # 60 parameters
    layer_2 = Dense(30 => 10),          # 310 parameters
)         # Total: 370 parameters,
          #        plus 0 states, summarysize 32 bytes.

julia> ps, st = Lux.setup(rng, layer);

julia> layer(x, ps, st)
ERROR: BoundsError: attempt to access Tuple{Matrix{Float64}, Matrix{Float64}} at index [3]
Stacktrace:
 [1] getindex(t::Tuple, i::Int64)
   @ Base ./tuple.jl:29
 [2] macro expansion
   @ ~/.julia/packages/Lux/0jReB/src/layers/basic.jl:416 [inlined]
 [3] applypairwisefusion(layers::NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}, connection::typeof(+), x::Tuple{Matrix{Float64}, Matrix{Float64}}, ps::NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}, st::NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
   @ Lux ~/.julia/packages/Lux/0jReB/src/layers/basic.jl:404
 [4] (::PairwiseFusion{typeof(+), NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}})(x::Tuple{Matrix{Float64}, Matrix{Float64}}, ps::NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}, st::NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
   @ Lux ~/.julia/packages/Lux/0jReB/src/layers/basic.jl:401
 [5] top-level scope
   @ REPL[9]:1
 [6] top-level scope
   @ ~/.julia/packages/CUDA/fAEDi/src/initialization.jl:52

julia> x = (rand(1, 10), rand(30, 10), rand(10, 10));

julia> layer(x, ps, st)[1] |> size
(10, 10)

This seems to be because each iteration ends only after the current layer's output is combined with the next input, i.e. for two layers we still need $x_3$ because the layer stops after $y_2$ combines with $x_3$. I would've filed a docs PR, but I wasn't sure if this was intended behaviour...

Recurrent Neural Networks

Get a prototype example of RNNs in. Given the way we handle parameters and states, it should be trivial to handle RNNs; it is only a matter of providing a simple example.

Remaining Deprecations

Making this issue to remember to deprecate / fix some features before v0.5.

  • bias should become use_bias
  • Add initialstates and initialparameters for the nothing layer. (I am slightly on the fence about this one) -- adding it is probably going to create more confusion.
  • Remove elementwise_* and apply_activation functions
