Comments (13)
That's not the intended use for Flux.train!; this function is meant to iterate over an entire epoch, not a single batch. Try writing your loop as:
function train_loop(model, optimizer, train_loader, test_loader; epochs=5)
    for epoch ∈ 1:epochs
        iter = tqdm(train_loader)
        total = 0
        corrects = 0
        for (X, Y) ∈ iter
            grads = Flux.gradient(model) do m
                predicted = m(X)
                # Accuracy bookkeeping is excluded from differentiation.
                ignore() do
                    b_size = size(X)[end]
                    corrects += sum(onecold(predicted, 0:9) .== onecold(Y, 0:9)) # edit, labels is Y
                    total += b_size
                end
                logitcrossentropy(predicted, Y)
            end
            optimizer, model = Flux.Optimise.update!(optimizer, model, grads[1]) # edit, fixed [0]
            set_postfix(iter, accuracy=corrects / total)
        end
        val_accuracy = accuracy(model, test_loader)
        @info "Epoch $epoch/$epochs | Accuracy : $val_accuracy"
    end
end
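The loop above calls an `accuracy` helper that is not shown in the thread. A minimal sketch of what it might look like, assuming the loader yields `(X, Y)` batches with one-hot labels over the classes `0:9` (the name and call signature come from the loop above; the body is an assumption):

```julia
using Flux
using Flux: onecold

# Hypothetical helper: fraction of correctly classified samples over a loader
# that yields (X, Y) batches, where Y holds one-hot labels for classes 0:9.
function accuracy(model, loader)
    correct = 0
    total = 0
    for (X, Y) in loader
        predicted = model(X)
        # Compare predicted and true class labels column by column.
        correct += sum(onecold(predicted, 0:9) .== onecold(Y, 0:9))
        # The last dimension of X is the batch dimension.
        total += size(X)[end]
    end
    return correct / total
end
```

Any iterable of `(X, Y)` tuples works as the loader here, including a plain vector of batches.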
from flux.jl.
I did that already; it's the same speed, even a little slower.
My guess is that this is NNlib's CPU implementations of Conv etc. being sub-optimal. That's the target of e.g. FluxML/NNlib.jl#540, and seeing whether that PR speeds up this example might be helpful (and if it does, finding a way to push that PR forwards).
Otherwise, isolating exactly which operations are slower would be more helpful than overall times. Xref the earlier issue about the same thing, #2300.
will there be any updates?
Have you seen the linked PR at FluxML/NNlib.jl#540? Other than contributing performance improvements to NNlib itself, best thing would be to do some benchmarking of what the bottlenecks in the Julia code are with a profiler. Ideally you could narrow it down to 1-2 types of layers which could be compared directly against their equivalents in PyTorch.
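One way to do that profiling, as a rough sketch: the small model, input sizes, and iteration count below are assumptions for illustration, not the network from this issue. It wraps repeated `Flux.gradient` calls in Julia's built-in `Profile` standard library.

```julia
using Flux
using Profile

# Stand-in model (assumed): one strided Conv followed by a Dense classifier.
model = Chain(
    Conv((3, 3), 1 => 16, relu; stride=(2, 2), pad=1),
    Flux.flatten,
    Dense(14 * 14 * 16 => 10),
)
x = randn(Float32, 28, 28, 1, 100)
y = Flux.onehotbatch(rand(0:9, 100), 0:9)

loss(m) = Flux.logitcrossentropy(m(x), y)

# Run once first so compilation does not dominate the profile.
grads = Flux.gradient(loss, model)

Profile.clear()
@profile for _ in 1:50
    Flux.gradient(loss, model)
end

# Flat listing sorted by sample count; rows with few samples are hidden.
Profile.print(format=:flat, sortedby=:count, mincount=20)
```

Rows pointing into NNlib versus rows pointing into Zygote would give a first hint about which layer's backward pass is worth isolating.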
Whatever it is, it's related to the backward pass; the forward pass in Flux is already faster than PyTorch, or at least the same speed.
That's why I asked to narrow it down. If you can find which specific layers are slower on the backwards path and provide a MWE demonstrating that, then we have something to work with.
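A backward-pass MWE of that kind might look like the sketch below. It times a single Conv layer with assumed MNIST-like sizes; the `sum` loss is an arbitrary stand-in chosen only to get a scalar output, so the numbers say nothing about a full training step.

```julia
using Flux
using BenchmarkTools

# Single layer under test and a batch of 100 single-channel 28x28 inputs (assumed sizes).
m = Conv((3, 3), 1 => 16; stride=(2, 2), pad=1)
A = randn(Float32, 28, 28, 1, 100)

# Scalar loss so that a gradient is defined.
loss(layer, x) = sum(layer(x))

# Warm up (compile) both paths before timing.
m(A)
grads = Flux.gradient(loss, m, A)

# Interpolate globals with `$` so BenchmarkTools times only the call itself.
t_fwd = @belapsed $m($A)
t_bwd = @belapsed Flux.gradient($loss, $m, $A)

println("forward:            ", round(t_fwd * 1e6; digits=1), " μs")
println("forward + backward: ", round(t_bwd * 1e6; digits=1), " μs")
```

Swapping `Conv` for another layer type while keeping the rest fixed gives per-layer numbers that can be compared against the equivalent PyTorch layer.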
MWE
Here are CPU tests; I have not tested on GPU.
Feedforward, Flux:
using Flux
using BenchmarkTools
m = Conv((3, 3), 1 => 16; stride=(2, 2), pad=1)
A = Float32.(randn(28, 28, 1, 100))
# compile for the first time
m(A)
@btime m(A)
753.000 μs (76 allocations: 2.44 MiB)
Feedforward, PyTorch:
import torch
import torch.nn as nn
m = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding=1)
A = torch.randn((100, 1, 28, 28))
%timeit m(A)
172 µs ± 7.32 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Flux is significantly slower (almost 6 times) than PyTorch on CPU!
With the same approach for a Dense layer, PyTorch is 1.7 times faster than Flux. RNNs in Flux are also significantly slower than in PyTorch (6 times slower), just like the CNN, as we need to loop over sequences.
@aminaqi that's a different issue, namely FluxML/NNlib.jl#234. As mentioned in that issue and the linked Discourse discussion, make sure you're starting Julia with multiple threads and using MKL for a proper apples-to-apples comparison with PyTorch.
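A quick way to check both settings from inside a session, using only the standard library. MKL is a separate package (MKL.jl) that has to be installed and loaded before it shows up in the BLAS config; that step is only indicated in a comment here, and matching the BLAS pool to the Julia thread count is an assumption, not a recommendation from this thread.

```julia
using LinearAlgebra

# Julia threads are fixed at startup, e.g. `julia --threads=6` or JULIA_NUM_THREADS=6.
println("Julia threads: ", Threads.nthreads())

# BLAS keeps its own thread pool, separate from Julia's threads.
println("BLAS threads:  ", BLAS.get_num_threads())

# Shows the BLAS/LAPACK backend in use (OpenBLAS by default). After `using MKL`
# (requires the MKL.jl package), MKL should be listed here instead.
println(BLAS.get_config())

# Optionally match the BLAS pool to the Julia thread count.
BLAS.set_num_threads(Threads.nthreads())
```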
For this issue, it's not clear where the exact slowdown(s) come from. What I'm sure of is that it can't be solely the conv forward pass, which is what you're benchmarking.
PS. it looks like the formatting on your comments got messed up? Every one quotes the entirety of the one before it and it probably shouldn't.
I've started Julia with 6 threads. Even with multiple threads, it's still significantly slower than PyTorch, and that's only the forward pass; there is a slowdown on the backward pass too, which makes Flux 10 times slower than PyTorch or TensorFlow.
This affects not only Conv but RNNs as well.
Are you seeing Julia be 10x slower on the forward and backwards pass, for CNNs and RNNs, against PyTorch and TensorFlow? I'm pretty sure we are slower on all of those, but 10x for all of them would not be expected. If that's really what you're seeing, I'd recommend starting a Discourse thread with some MWEs for the various benchmarks and linking back to that here. It's possible that Flux itself is only a small part of the issue there, and Discourse will allow more folks to weigh in on what other parts of your code may be contributing (only Flux maintainers really follow this issue tracker).
Either way, the performance gap being discussed in this issue already has a reasonable benchmark. It just needs to be narrowed down to a couple of layers and/or profiled so we can see what the bottlenecks are to take action on them. If nobody has bandwidth to do that, then I'm not sure there's much else to discuss here.