Comments (13)
The same code run on CUDA:

@btime CUDA.@sync Flux.gradient(mha) do m
    sum(first(m(x, x, x)))
end
  11.983 ms (2583 allocations: 137.55 KiB)
whereas

@btime CUDA.@sync gradient_ez(mha) do m
    sum(first(m($x, $x, $x)))
end
....
[2] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, width::Int64, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/srACB/src/api.jl:190
[3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:3141
[4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5074
[5] codegen
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:4481 [inlined]
[6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771
[7] _thunk
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771 [inlined]
[8] cached_compilation
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5809 [inlined]
[9] (::Enzyme.Compiler.var"#560#561"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{4, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5875
[10] JuliaContext(f::Enzyme.Compiler.var"#560#561"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{…}, Int64, Bool, Bool, UInt64, DataType}; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
[11] JuliaContext(f::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
[12] #s2027#559
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5827 [inlined]
So it can be reproduced with the following packages and Julia 1.10.3:
[6e4b80f9] BenchmarkTools v1.5.0
[052768ef] CUDA v5.3.4
[082447d4] ChainRules v1.66.0
[d360d2e6] ChainRulesCore v1.23.0
[7da242da] Enzyme v0.12.6
[587475ba] Flux v0.14.15
[e88e6eb3] Zygote v0.6.70
[02a925ec] cuDNN v1.3.1
Thanks!
@mashu can you post the whole log?
I was convinced I attached it earlier, but apparently I didn't, so here it is:
MWA.log.gz
The following code was run as:
julia --project=@. src/MWA.jl 2> MWA.log
using Enzyme
using Flux
using CUDA

# Build a zero-initialized shadow of a model: zero out numbers and arrays,
# leave every other leaf untouched.
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

# Minimal Enzyme counterpart of Flux.gradient: wrap scalars as Active and
# everything else as Duplicated, then run reverse mode with the primal.
function gradient_ez(f, xs...)
    args = []
    for x in xs
        if x isa Number
            push!(args, Active(x))
        else
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    # Active arguments return their gradient in ret[1]; Duplicated arguments
    # accumulate theirs into the shadow (dval).
    g = ntuple(i -> xs[i] isa Number ? ret[1][i] : args[i].dval, length(xs))
    return g
end

x = CUDA.rand(Float32, 64, 100, 512)
mha = MultiHeadAttention(64 => 64 => 64) |> gpu

# Zygote-based gradient (succeeds):
Flux.gradient(mha) do m
    sum(first(m(x, x, x)))
end

# Enzyme-based gradient (fails to compile on GPU):
Δ = gradient_ez(mha) do m
    sum(first(m(x, x, x)))
end
Also, does this work on CPU?
@wsmoses Initially I got a compilation error with the CPU version, but after moving to a separate project (the MWE) it only fails on GPU. That said, I still can't figure out why it fails in my main project, since the packages are up to date and basically the same versions. But at least this GPU failure is reproducible.
GPU support is in progress, so the report is super helpful, but the failure is presently expected.
Maybe check the current versions of the packages in your project and see if something is forcing an older Enzyme?
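For example, a quick way to see what is being held back (a sketch; the --outdated flag needs a reasonably recent Pkg, and -m includes manifest-only dependencies):

pkg> status --outdated -m

That lists each held-back package together with the newest compatible version and which packages' compat entries are pinning it.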
It's the same version of ⌅ [7cc45869] Enzyme_jll v0.0.109+0 in both the working and the non-working project, so it must be some indirect dependency that I can't figure out.
As for the GPU part, my impression is that the CPU paths in Flux are sometimes slow and not well optimized, probably because most people use the GPU paths for any real work.
Ah, but what's your Enzyme version (rather than Enzyme_jll, which is a dependency)?
Looks the same, v0.12.6.
Working MWE (]st):
[6e4b80f9] BenchmarkTools v1.5.0
[052768ef] CUDA v5.3.4
[082447d4] ChainRules v1.66.0
[d360d2e6] ChainRulesCore v1.23.0
[7da242da] Enzyme v0.12.6
[587475ba] Flux v0.14.15
[e88e6eb3] Zygote v0.6.70
[02a925ec] cuDNN v1.3.1
Broken one (]st):
[6e4b80f9] BenchmarkTools v1.5.0
[336ed68f] CSV v0.10.14
[052768ef] CUDA v5.3.4
[082447d4] ChainRules v1.66.0
[d360d2e6] ChainRulesCore v1.23.0
[a93c6f00] DataFrames v1.6.1
[864edb3b] DataStructures v0.18.20
[31c24e10] Distributions v0.25.108
[7da242da] Enzyme v0.12.6
[c2308a5c] FASTX v2.1.5
[587475ba] Flux v0.14.15
[41a02a25] Folds v0.2.10
[033835bb] JLD2 v0.4.47
[682c06a0] JSON v0.21.4
[e6f89c97] LoggingExtras v1.0.3
[12afc1b8] NeuralAttentionlib v0.2.13
[0b1bfda6] OneHotArrays v0.2.5
[3bd65402] Optimisers v0.3.3
[d7d3b36b] ParameterSchedulers v0.4.1
[92933f4c] ProgressMeter v1.10.0
[2913bbd2] StatsBase v0.34.3
[b8865327] UnicodePlots v3.6.4
[02a925ec] cuDNN v1.3.1
[56ddb016] Logging
Also including a log with the error that happens on the CPU side in the broken project; not sure if that helps though.
CPU.log
From the log, I think the simplest answer here is that we should just add a custom derivative for attention in NNlib. I assume there's already a ChainRules rule for it?
If so, you can try our macro for importing a ChainRules rule into Enzyme as a test, to see if anything else fails, while in the interim we look at making a fast native rule (imported CR rules will be slower and come with caveats).
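A minimal sketch of that import via Enzyme.@import_rrule, assuming the attention rrule is attached to NNlib.dot_product_attention (the function and argument types here are illustrative, not verified):

using Enzyme, NNlib

# Reuse an existing ChainRules rrule as an Enzyme custom reverse rule.
# The signature must cover the argument types the function is called with.
Enzyme.@import_rrule(typeof(NNlib.dot_product_attention),
                     AbstractArray, AbstractArray, AbstractArray)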
@wsmoses Long story short, I wanted to use Enzyme because I often lack the skills to write an rrule, and there is none for MultiHeadAttention in NNlib. The longer answer is that I'm currently using NeuralAttentionlib.jl (part of Transformers.jl), which has the customization of the layer I need and an rrule that makes that variant of MHA a couple of times faster on GPU. My hope was that Enzyme might do a better job than Zygote on the performance of the generated code when no rrule is provided.
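For what it's worth, once compilation succeeds, that hope is easy to test on CPU with the gradient_ez helper above (a sketch; sizes match the MWE):

using BenchmarkTools

mha_cpu = MultiHeadAttention(64 => 64 => 64)
x_cpu = rand(Float32, 64, 100, 512)

# Zygote-based gradient:
@btime Flux.gradient(m -> sum(first(m($x_cpu, $x_cpu, $x_cpu))), $mha_cpu)
# Enzyme-based gradient via gradient_ez:
@btime gradient_ez(m -> sum(first(m($x_cpu, $x_cpu, $x_cpu))), $mha_cpu)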
If you can wait a short bit (it's currently unregistered and there are a bunch of small things we should add), Reactant.jl is an execution engine (e.g. it does tons of fancy optimizations/kernel fusion), is both Enzyme- and GPU-compatible out of the box, and might be what you're looking for.
In the interim I'll push on the GPU support for native Enzyme here too, but just throwing that out there in case it's helpful.
https://github.com/EnzymeAD/Reactant.jl