Comments (8)

maxwindiff commented on May 25, 2024

In-place is slow because it's hitting the init === nothing code path: https://github.com/JuliaGPU/Metal.jl/blob/main/src/mapreduce.jl#L230-L237

If GPUArrays.neutral_element() returned nothing by default, we might be able to do something like:

-Base.mapreducedim!(f, op, R::AnyGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
+Base.mapreducedim!(f, op, R::AnyGPUArray{T}, A::AbstractArray) where {T} =
+  mapreducedim!(f, op, R, A; init=neutral_element(op, T))

With my limited knowledge of Julia fundamentals, I don't know how to extend neutral_element without breaking compatibility. Let me try other ways of initializing the partial reduction array...
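
A minimal CPU-only sketch of the fallback idea above, in plain Julia. The names maybe_neutral, reduction_path and custom_op are hypothetical and exist only for illustration; they are not part of GPUArrays.jl or Metal.jl, and this does not change how GPUArrays.neutral_element behaves.

```julia
# Hypothetical helper (NOT the GPUArrays.jl API): return the identity
# element for a few well-known operators, and `nothing` for anything else.
maybe_neutral(op, ::Type{T}) where {T} = nothing
maybe_neutral(::typeof(+), ::Type{T}) where {T} = zero(T)
maybe_neutral(::typeof(*), ::Type{T}) where {T} = one(T)
maybe_neutral(::typeof(max), ::Type{T}) where {T} = typemin(T)
maybe_neutral(::typeof(min), ::Type{T}) where {T} = typemax(T)

# Hypothetical wrapper: report which code path an in-place reduction into
# `R` would take, based on whether an identity element is known.
function reduction_path(op, R::AbstractArray{T}) where {T}
    init = maybe_neutral(op, T)
    return init === nothing ? :no_init : :with_init
end

# An operator with no registered identity element (assumption for the demo).
custom_op(a, b) = a + 2b

reduction_path(+, zeros(Float32, 1))          # :with_init
reduction_path(custom_op, zeros(Float32, 1))  # :no_init
```

Using a separate function with a generic fallback method keeps existing callers of neutral_element untouched, which is one possible way around the compatibility concern.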

from metal.jl.

maxwindiff commented on May 25, 2024

I tried writing a reduction kernel that only supports 1D arrays, and it's about 4x as fast as the current implementation. I'll see whether the generic implementation can be further improved.

maxwindiff commented on May 25, 2024

Reductions are generally faster now; however, the in-place version is still very slow:

julia> @btime sum($a)
  760.000 μs (0 allocations: 0 bytes)
5.001241f6

julia> @btime sum($Ma)
  708.083 μs (1197 allocations: 27.76 KiB)
5.001241f6

julia> @btime Metal.@sync sum!($r, $Ma)
  376.325 ms (101199 allocations: 2.00 MiB)
1-element MtlVector{Float32}:
 5.001241f6

mchitre commented on May 25, 2024

Similar results on Ventura as well, so that's not the cause.

maxwindiff commented on May 25, 2024

On my computer:

julia> a = fill(Float32(1.0), 10*1024*1024);
julia> da = MtlArray(a);
julia> @btime sum(a)
  844.500 μs (1 allocation: 16 bytes)
1.048576f7
julia> @btime sum(da)
  2.707 ms (857 allocations: 23.66 KiB)
1.048576f7

Now, if we do this:

diff --git a/src/mapreduce.jl b/src/mapreduce.jl
index 1d84d78..900f21d 100644
--- a/src/mapreduce.jl
+++ b/src/mapreduce.jl
@@ -123,7 +123,7 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, Rreduce, Rother, s
             ireduce += localDim_reduce * groupDim_reduce
         end
 
-        val = reduce_group(op, val, neutral, shuffle, maxthreads)
+        val = 1 # reduce_group(op, val, neutral, shuffle, maxthreads)
 
         # write back to memory
         if localIdx_reduce == 1

It still takes 2 ms to simply loop over the input/output arrays!

julia> @btime sum(da)
  2.015 ms (857 allocations: 23.66 KiB)
1.0f0

My guess is that the slowdown comes from all the indexing calculations (same as #41). But it's even harder to eliminate the Cartesian indexing here because the reduction process itself can add additional dimensions...
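
To illustrate what those indexing calculations involve, here is a small CPU-only sketch of turning a linear index into Cartesian coordinates for a column-major array. It is not the Metal.jl kernel code; linear_to_cartesian is a hypothetical helper written for this example, and it only shows that each dimension costs one integer division with remainder per element.

```julia
# Hypothetical helper (not from Metal.jl): convert a 1-based linear index
# into 1-based Cartesian coordinates for a column-major array of size `dims`.
function linear_to_cartesian(i::Int, dims::NTuple{N,Int}) where {N}
    rest = i - 1                       # switch to 0-based arithmetic
    coords = Vector{Int}(undef, N)
    for d in 1:N
        # one integer division + remainder per dimension
        rest, r = divrem(rest, dims[d])
        coords[d] = r + 1
    end
    return Tuple(coords)
end

A = reshape(collect(Float32, 1:24), 2, 3, 4)
I = linear_to_cartesian(17, size(A))          # (1, 3, 3)

# Agrees with Base's own linear-to-Cartesian conversion.
@assert I == Tuple(CartesianIndices(A)[17])
@assert A[I...] == A[17]
```

A kernel that walks a flat 1D array can skip this conversion entirely, which is consistent with the guess above but does not prove it is the dominant cost.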

maleadt commented on May 25, 2024

Good article: https://betterprogramming.pub/optimizing-parallel-reduction-in-metal-for-apple-m1-8e8677b49b01

rveltz commented on May 25, 2024

It would be good to write a similar blog post using Metal.jl.

maleadt commented on May 25, 2024

Another good source: https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/reduce.metal + https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/reduce.cpp
