Comments (8)
In-place is slow because it's hitting the init === nothing
code path: https://github.com/JuliaGPU/Metal.jl/blob/main/src/mapreduce.jl#L230-L237
If GPUArrays.neutral_element() returned nothing by default, we might be able to do something like:
-Base.mapreducedim!(f, op, R::AnyGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
+Base.mapreducedim!(f, op, R::AnyGPUArray{T}, A::AbstractArray) where {T} =
+ mapreducedim!(f, op, R, A; init=neutral_element(op, T))
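For reference, the neutral element of op over T is a value v with op(v, x) == x for every x. A minimal sketch of what such a lookup looks like (a hypothetical neutral helper for illustration, not GPUArrays' actual neutral_element implementation):

```julia
# Hypothetical sketch: map a reduction operator to its neutral element.
# op(neutral, x) == x must hold for every x of type T.
neutral(::typeof(+), ::Type{T}) where {T} = zero(T)
neutral(::typeof(*), ::Type{T}) where {T} = one(T)
neutral(::typeof(max), ::Type{T}) where {T<:Real} = typemin(T)
neutral(::typeof(min), ::Type{T}) where {T<:Real} = typemax(T)
```

With something along these lines, the init= keyword in the diff above could pick the right starting value per operator and avoid falling into the slow init === nothing path.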
With my limited knowledge of Julia fundamentals, I don't know how to extend neutral_element without breaking compatibility. Let me try other ways of initializing the partial reduction array...
from metal.jl.
I tried writing a reduction kernel that only supports 1-D arrays, and it's about 4x as fast as the current implementation. I'll see whether the generic implementation can be improved further.
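For context, a 1-D kernel like this typically follows the standard two-pass pattern: each threadgroup reduces a strided slice of the input to one partial value, then the partials are reduced in a second pass. A CPU sketch of the access pattern (illustrative only, not the actual Metal kernel):

```julia
# CPU sketch of a two-pass 1-D reduction.
# Pass 1: each "group" g walks the input with stride ngroups (a grid-stride
# loop) and folds its elements into a single partial result.
function partial_reduce(op, a, ngroups)
    partials = similar(a, ngroups)
    for g in 1:ngroups
        acc = a[g]                 # first element owned by this group
        i = g + ngroups
        while i <= length(a)
            acc = op(acc, a[i])
            i += ngroups
        end
        partials[g] = acc
    end
    return partials
end

# Pass 2: reduce the per-group partials (on the GPU this would be a
# single-threadgroup tree reduction; here plain reduce suffices).
reduce_1d(op, a; ngroups=256) =
    reduce(op, partial_reduce(op, a, min(ngroups, length(a))))
```

The grid-stride access keeps neighbouring "threads" reading neighbouring elements, which is the coalescing-friendly layout a real GPU kernel wants.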
from metal.jl.
Reductions are generally faster now; however, in-place reduction is still very slow:
julia> @btime sum($a)
760.000 μs (0 allocations: 0 bytes)
5.001241f6
julia> @btime sum($Ma)
708.083 μs (1197 allocations: 27.76 KiB)
5.001241f6
julia> @btime Metal.@sync sum!($r, $Ma)
376.325 ms (101199 allocations: 2.00 MiB)
1-element MtlVector{Float32}:
5.001241f6
from metal.jl.
Similar results on Ventura as well, so that's not the cause.
from metal.jl.
On my computer:
julia> a = fill(Float32(1.0), 10*1024*1024);
julia> da = MtlArray(a);
julia> @btime sum(a)
844.500 μs (1 allocation: 16 bytes)
1.048576f7
julia> @btime sum(da)
2.707 ms (857 allocations: 23.66 KiB)
1.048576f7
Now, if we do this:
diff --git a/src/mapreduce.jl b/src/mapreduce.jl
index 1d84d78..900f21d 100644
--- a/src/mapreduce.jl
+++ b/src/mapreduce.jl
@@ -123,7 +123,7 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, Rreduce, Rother, s
ireduce += localDim_reduce * groupDim_reduce
end
- val = reduce_group(op, val, neutral, shuffle, maxthreads)
+ val = 1 # reduce_group(op, val, neutral, shuffle, maxthreads)
# write back to memory
if localIdx_reduce == 1
It still takes 2 ms simply to loop over the input/output arrays!
julia> @btime sum(da)
2.015 ms (857 allocations: 23.66 KiB)
1.0f0
My guess is that the slowdown comes from all the indexing calculations (same as #41). But the Cartesian indexing is even harder to eliminate here, because the reduction process itself can add additional dimensions...
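To illustrate why the index math adds up: with Cartesian indexing, every element access has to split a linear thread index into per-dimension coordinates via div/rem chains, and integer division is comparatively expensive on GPUs. A toy sketch for a column-major (m, n) array (illustrative, not the generated kernel code):

```julia
# Converting a linear index into a Cartesian (row, col) pair for an m×n
# column-major array costs one rem and one div per access; with more
# dimensions (which the reduction machinery can introduce), the chain grows.
lin_to_cart(i, m) = ((i - 1) % m + 1, (i - 1) ÷ m + 1)
```

For contiguous arrays a purely linear loop avoids this arithmetic entirely, which is presumably part of why the 1-D kernel above is so much faster.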
from metal.jl.
Good article: https://betterprogramming.pub/optimizing-parallel-reduction-in-metal-for-apple-m1-8e8677b49b01
from metal.jl.
It would be good to write a similar blog post using Metal.jl.
from metal.jl.
Another good source: https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/reduce.metal + https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/reduce.cpp
from metal.jl.