Comments (13)

mashu commented on September 22, 2024

The same code run on CUDA

@btime CUDA.@sync Flux.gradient(mha) do m
           sum(first(m(x, x, x)))
       end
 11.983 ms (2583 allocations: 137.55 KiB)

whereas

@btime CUDA.@sync gradient_ez(mha) do m
           sum(first(m($x, $x, $x)))
       end
       
 ....
   [2] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/srACB/src/api.jl:190
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:3141
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5074
  [5] codegen
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:4481 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771
  [7] _thunk
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5809 [inlined]
  [9] (::Enzyme.Compiler.var"#560#561"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{4, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5875
 [10] JuliaContext(f::Enzyme.Compiler.var"#560#561"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{}, Int64, Bool, Bool, UInt64, DataType}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
 [11] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
 [12] #s2027#559
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5827 [inlined]

So it can be reproduced with the following packages on Julia 1.10.3:

  [6e4b80f9] BenchmarkTools v1.5.0
  [052768ef] CUDA v5.3.4
  [082447d4] ChainRules v1.66.0
  [d360d2e6] ChainRulesCore v1.23.0
  [7da242da] Enzyme v0.12.6
  [587475ba] Flux v0.14.15
  [e88e6eb3] Zygote v0.6.70
  [02a925ec] cuDNN v1.3.1

Thanks!

wsmoses commented on September 22, 2024

@mashu can you post the whole log?

mashu commented on September 22, 2024

I was convinced I attached it earlier, but apparently I didn't, so here it is:
MWA.log.gz
The following code was run as:

julia --project=@. src/MWA.jl 2> MWA.log

using Enzyme
using Flux
using CUDA

# Zero out numeric leaves; fmap recurses over the model structure
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)  # fmap is re-exported by Flux (from Functors)

# Minimal Enzyme-based replacement for Flux.gradient
function gradient_ez(f, x...)
    args = []
    for arg in x
        if arg isa Number
            push!(args, Active(arg))
        else
            push!(args, Duplicated(arg, make_zero(arg)))  # shadow accumulates the gradient
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end

x = CUDA.rand(Float32, 64, 100, 512)
mha = MultiHeadAttention(64 => 64 => 64) |> gpu

# Zygote (Flux's default AD) works:
Flux.gradient(mha) do m
    sum(first(m(x, x, x)))
end

# Enzyme path fails on GPU:
Δ = gradient_ez(mha) do m
    sum(first(m(x, x, x)))
end
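
As an aside (this assumes the installed Enzyme version already provides it, which the report does not confirm), Enzyme ships a recursive Enzyme.make_zero, so the hand-rolled fmap-based helper above could in principle be replaced by something like:

# Sketch under the assumption that Enzyme.make_zero is available in this version;
# it should behave like the fmap-based make_zero defined above.
dmha = Enzyme.make_zero(mha)  # shadow model with zeroed arrays to accumulate gradients
Enzyme.autodiff(ReverseWithPrimal, m -> sum(first(m(x, x, x))), Active, Duplicated(mha, dmha))
# after the call, the gradient with respect to the model lives in dmha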

wsmoses commented on September 22, 2024

Also, does this work on CPU?

mashu commented on September 22, 2024

@wsmoses Initially I got a compilation error with the CPU version, but after moving it to a separate project (MWE) it only fails on GPU. Having said that, I still can't figure out why it fails in my main project, as the packages are up to date and essentially the same versions. But this GPU failure is at least reproducible.

wsmoses commented on September 22, 2024

GPU support is in progress, so the report is super helpful, but the failure is also presently expected.

Maybe check the current versions of the packages in your project and see if something is forcing an older Enzyme?
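
As a generic way to check that (standard Pkg queries, not something specific to this issue), you can ask Pkg which Enzyme version is actually resolved, including through indirect dependencies:

using Pkg
Pkg.status("Enzyme")                               # direct dependency, as listed in Project.toml
Pkg.status("Enzyme"; mode = Pkg.PKGMODE_MANIFEST)  # resolved version from the manifest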

mashu commented on September 22, 2024

It's the same version of Enzyme_jll (⌅ [7cc45869] Enzyme_jll v0.0.109+0) in both the working and the non-working project, so it must be some indirect dependency that I can't figure out.
As for the GPU part, my impression is that the CPU paths in Flux are sometimes slow and not well optimized, probably because most people use the GPU paths for any real work.

wsmoses commented on September 22, 2024

Ah, but what's your Enzyme version (rather than Enzyme_jll, which is a dependency)?

mashu commented on September 22, 2024

Looks the same, v0.12.6.

Working MWE (]st):

  [6e4b80f9] BenchmarkTools v1.5.0
  [052768ef] CUDA v5.3.4
  [082447d4] ChainRules v1.66.0
  [d360d2e6] ChainRulesCore v1.23.0
  [7da242da] Enzyme v0.12.6
  [587475ba] Flux v0.14.15
  [e88e6eb3] Zygote v0.6.70
  [02a925ec] cuDNN v1.3.1

Broken project (]st):

  [6e4b80f9] BenchmarkTools v1.5.0
  [336ed68f] CSV v0.10.14
  [052768ef] CUDA v5.3.4
  [082447d4] ChainRules v1.66.0
  [d360d2e6] ChainRulesCore v1.23.0
  [a93c6f00] DataFrames v1.6.1
  [864edb3b] DataStructures v0.18.20
  [31c24e10] Distributions v0.25.108
  [7da242da] Enzyme v0.12.6
  [c2308a5c] FASTX v2.1.5
  [587475ba] Flux v0.14.15
  [41a02a25] Folds v0.2.10
  [033835bb] JLD2 v0.4.47
  [682c06a0] JSON v0.21.4
  [e6f89c97] LoggingExtras v1.0.3
  [12afc1b8] NeuralAttentionlib v0.2.13
  [0b1bfda6] OneHotArrays v0.2.5
  [3bd65402] Optimisers v0.3.3
  [d7d3b36b] ParameterSchedulers v0.4.1
  [92933f4c] ProgressMeter v1.10.0
  [2913bbd2] StatsBase v0.34.3
  [b8865327] UnicodePlots v3.6.4
  [02a925ec] cuDNN v1.3.1
  [56ddb016] Logging

mashu commented on September 22, 2024

Also including a log with the error that happens on the CPU side in the broken project; not sure if that helps though.
CPU.log

wsmoses commented on September 22, 2024

From the log, I think the simplest answer here is that we should just add a custom derivative for attention in NNlib. I assume there's already a ChainRules (CR) rule for it?

If so, you can try our macro for importing a CR rule into Enzyme as a test, to see if anything else fails; in the interim we can look at writing a fast native rule for it (CR rules will be slower and come with caveats).
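
For reference, Enzyme exposes Enzyme.@import_rrule for pulling an existing ChainRules reverse rule into Enzyme. A minimal sketch of that kind of test might look like the following; the target function and argument types are illustrative assumptions (an rrule must actually be defined for the imported method for this to do anything useful):

using Enzyme
import NNlib

# Hypothetical: import an existing ChainRules rrule into Enzyme.
# NNlib.dot_product_attention is what Flux's MultiHeadAttention calls internally;
# the argument types here are assumptions for illustration only.
Enzyme.@import_rrule(typeof(NNlib.dot_product_attention), AbstractArray, AbstractArray, AbstractArray)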

mashu commented on September 22, 2024

@wsmoses Long story short, I wanted to use Enzyme because I often lack the skills to write an rrule, and there is none for MultiHeadAttention in NNlib. The longer answer is that I am currently using NeuralAttentionlib.jl (part of Transformers.jl), which has the layer customization I need and an rrule that makes that variant of MHA a couple of times faster on GPU. My hope was that Enzyme might do a better job than Zygote when it comes to the performance of the code it produces when no rrule is provided.

wsmoses commented on September 22, 2024

If you can wait a short bit (it's currently unregistered and there are a bunch of small things we should add), Reactant.jl is an execution engine (e.g. it does tons of fancy optimizations and kernel fusion), is both Enzyme- and GPU-compatible out of the box, and might be what you're looking for.

In the interim I'll also push on GPU support for native Enzyme here, but just throwing that out there in case it's helpful.

https://github.com/EnzymeAD/Reactant.jl
