Comments (14)
I think the basic interface needed is a nice gradient function.
Enzyme's own gradient should now do this, as make_zero understands nested structures:
julia> sh = [1f0, 2f0]; nt = (a=sh, b=sh, c=copy(sh));
julia> Enzyme.gradient(Reverse, x -> sum(map(sum, x)), nt)
(a = Float32[2.0, 2.0], b = Float32[2.0, 2.0], c = Float32[1.0, 1.0])
(jl_o1ZBlk) pkg> st Enzyme
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_o1ZBlk/Project.toml`
[7da242da] Enzyme v0.12.4
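The doubled values for a and b in the output above are consistent with aliasing: both fields hold the same array sh, so both contributions land in one shared gradient buffer. Below is a minimal sketch of the same computation with the lower-level Enzyme.autodiff call. It assumes an Enzyme version that provides Enzyme.make_zero and that make_zero preserves aliasing in the shadow; the name dnt is only illustrative.

```julia
using Enzyme

sh = [1f0, 2f0]
nt = (a = sh, b = sh, c = copy(sh))   # a and b alias the same array

# Zero-initialised shadow with the same nested structure as nt.
dnt = Enzyme.make_zero(nt)

# Reverse mode accumulates the gradient of the scalar output into dnt.
Enzyme.autodiff(Reverse, x -> sum(map(sum, x)), Active, Duplicated(nt, dnt))

# If aliasing is preserved (an assumption here), the shadows of a and b
# are the same array and receive both contributions.
shared = dnt.a === dnt.b
```

A Zygote user would typically expect one gradient per field, so the shared buffer is worth keeping in mind when comparing results.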
The above example doesn't work for me, but I believe function gradient_ez(f, x...) can be deleted to have just this:
for epoch in 1:epochs
    g = Enzyme.gradient(Reverse, m -> loss(m, X, y), model) # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end
A slight problem in make_zero! is that it zeroes the arrays but not the scalar fields, so the gradients for those scalars keep accumulating. That can be fixed later, and in principle it is not even a problem, since scalars are not updated by the optimizer.
Right. For those coming from Zygote, it's slightly odd that the gradient contains numbers for non-diff things. But I believe Optimisers.jl's idea of what parameters can be updated is narrow enough that it will only use true gradient numbers from Enzyme.jl.
from flux.jl.
You'll need JuliaGPU/CUDA.jl#2371 and then JuliaPackaging/Yggdrasil#8666. It then hits a cublasscal issue, which I stopped investigating to go get dinner.
from flux.jl.
The most recent attempt was supposed to be DI.jl, but the choice to focus on arrays and single inputs means we can't use it.
@darsnack I'd actually love to revisit the dream of DI + Flux one of these days.
- For multiple inputs, I think I see a way to support additional constant inputs without too much pain (gdalle/DifferentiationInterface.jl#311). Apparently it's what you need for e.g. X and y in training.
- For array-only, the trouble is not supporting general structs, it's testing them. We've had this discussion together, and I don't want to commit to something that a) doesn't work for every backend and b) will probably be undertested because arbitrary structs can be, well, arbitrary. In my view, non-arrays cannot be in the DI API because there will be plenty of cases that fail, and it's very hard to say which ones ahead of time.
To me the best option would be a Flux.gradient (and Flux.withgradient) that uses ADTypes.jl (only to avoid further fragmentation). Alternatively, a small package that wraps Enzyme.autodiff + make_zero in a Zygote-like interface (similar to what's above).
Why not create a package named DifferentiationInterfaceForFlux or something, which relies on DI but tests compatibility with Flux layers and makes it part of its API? In other words, if I change something in DI that removes compatibility with Flux layers, the glue package could still be frozen to its current version until it gets resolved.
from flux.jl.
I think the basic interface needed is a nice gradient function.
This code is still not working though, on both CPU and CUDA GPU:
using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)
function gradient_ez(f, x...)
    args = []
    for x in x
        if x isa Number
            push!(args, Active(x))
        else
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end
batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
device = Flux.cpu # CPU training
# device = Flux.gpu # GPU training
X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device
model = Chain(Dense(feature_size => 32, relu),
              Dense(32, num_classes)) |> device
opt_state = Flux.setup(Adam(1e-3), model)
loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))
function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end
report(0)
for epoch in 1:epochs
    g = gradient_ez(model -> loss(model, X, y), model)[1] # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end
We should add tests for the loss functions. This one is failing:
gradient_ez(ŷ -> Flux.logitcrossentropy(ŷ, y), randn(Float32, num_classes, batch_size))
from flux.jl.
Here is a modification of your code above that should be more performant, stable, etc. (closures are bad).
In any case, it still hits the same issue, which I will investigate.
# using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics
_make_zero!(x::AbstractArray) = x .= 0
_make_zero!(x) = x
make_zero!(model) = fmap(_make_zero!, model)
batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
device = Flux.cpu # CPU training
# device = Flux.gpu # GPU training
X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device
model = Chain(Dense(feature_size => 32, relu),
              Dense(32, num_classes)) |> device
opt_state = Flux.setup(Adam(1e-3), model)
loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))
function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end
report(0)
g = deepcopy(model)
for epoch in 1:epochs
    make_zero!(g)
    Enzyme.autodiff(Reverse, loss, Duplicated(model, g), Const(X), Const(y))
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end
from flux.jl.
Yeah this works now with the NNlib type stability fix FluxML/NNlib.jl#584
from flux.jl.
The previous "interface" was to import the corresponding AD package and just call e.g. Tracker.withgradient.
The most recent attempt was supposed to be DI.jl, but the choice to focus on arrays and single inputs means we can't use it.
To me the best option would be a Flux.gradient (and Flux.withgradient) that uses ADTypes.jl (only to avoid further fragmentation). Alternatively, a small package that wraps Enzyme.autodiff + make_zero in a Zygote-like interface (similar to what's above).
But I suggest a dedicated doc page on using Enzyme + Flux will be easier to get through quickly.
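To make the "wrap Enzyme.autodiff + make_zero in a Zygote-like interface" idea concrete, here is a minimal sketch for a single structured input. The name withgradient_sketch is hypothetical and not part of Flux or Enzyme, and the sketch assumes an Enzyme version that provides Enzyme.make_zero and returns (derivatives, primal) from ReverseWithPrimal. It does not handle scalar (Active) inputs or multiple arguments.

```julia
using Enzyme

# Hypothetical helper: returns the primal value and the gradient of `f`
# with respect to one mutable, array-containing input `x`.
function withgradient_sketch(f, x)
    dx = Enzyme.make_zero(x)   # fresh zero shadow with the same structure as x
    _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx))
    return (val = val, grad = (dx,))
end

# Usage on a small NamedTuple of arrays, standing in for a model.
params = (w = [1.0, 2.0], b = [3.0])
out = withgradient_sketch(p -> sum(abs2, p.w) + sum(p.b), params)
```

Returning a one-element tuple for grad mirrors the shape Zygote users expect from withgradient, so out.grad[1] could be handed to Flux.update! in the same place as a Zygote gradient.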
from flux.jl.
Sure, I think docs would be a great first start. I don't really know how to use Flux or where that would go best, so I'll leave that to you.
At the same time, if we're already doing API design, for training it would be nice to not have to constantly reallocate the gradient buffer (with make_zero). I don't know if there's an in-place zeroing function you have for models, but that would be highly beneficial here.
from flux.jl.
it would be nice to not have to constantly reallocate the gradient buffer
I edited the code in your post to zero the gradient in-place. A slight problem in make_zero! is that it zeroes the arrays but not the scalar fields, so the gradients for those scalars keep accumulating. That can be fixed later, and in principle it is not even a problem, since scalars are not updated by the optimizer.
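The reason the buffer has to be zeroed at all is that reverse-mode Enzyme adds into the shadow of a Duplicated argument rather than overwriting it. A minimal sketch on a plain array, outside Flux, with illustrative names f, x and dx:

```julia
using Enzyme

f(x) = sum(abs2, x)

x  = [1.0, 2.0, 3.0]
dx = zero(x)                      # gradient buffer (shadow), allocated once

Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
first_call = copy(dx)             # gradient of f, i.e. 2x

Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
second_call = copy(dx)            # second gradient added on top of the first

fill!(dx, 0)                      # reset in place, no new allocation
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
after_reset = copy(dx)            # back to a single gradient
```

Array fields can be reset this way, whereas a scalar stored in an immutable struct cannot be mutated in place, which is presumably why the make_zero! above leaves scalars alone.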
from flux.jl.
On GPU I get the following error:
┌ Warning: active variables passed by value to jl_new_task are not yet supported └ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59 ERROR: Enzyme compilation failed due to illegal type analysis. Current scope: ; Function Attrs: mustprogress willreturn define internal fastcc void @preprocess_julia_fill__33038({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Pointer, [-1,0,0,0]:Pointer, [-1,0,0,0,0]:Pointer, [-1,0,0,0,8]:Integer, [-1,0,0,0,16]:Pointer, [-1,0,0,16]:Integer, [-1,0,0,17]:Integer, [-1,0,0,18]:Integer, [-1,0,0,19]:Integer, [-1,0,0,20]:Integer, [-1,0,0,21]:Integer, [-1,0,0,22]:Integer, [-1,0,0,23]:Integer, [-1,0,0,24]:Integer, [-1,0,0,32]:Pointer, [-1,0,0,40]:Pointer, [-1,0,0,40,-1]:Integer, [-1,0,8]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="139959628162192" "enzymejl_parmtype_ref"="2" %0, float "enzyme_type"="{[-1]:Float@float}" "enzymejl_parmtype"="139978039813152" "enzymejl_parmtype_ref"="0" %1) unnamed_addr #657 !dbg !47671 { top: %2 = call {}*** @julia.get_pgcstack() %3 = call {}*** @julia.get_pgcstack() %4 = bitcast {}*** %2 to {}** %5 = getelementptr inbounds {}*, {}** %4, i64 -14 %6 = getelementptr inbounds {}*, {}** %5, i64 16 %7 = bitcast {}** %6 to i8** %8 = load i8*, i8** %7, align 8 %9 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) {} addrspace(10)* @julia.gc_alloc_obj({}** %5, i64 8, {} 
addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !615 call void @zeroType.457({} addrspace(10)* %9, i8 0, i64 8), !enzyme_zerostack !590 %phic1 = bitcast {} addrspace(10)* %9 to {} addrspace(10)* addrspace(10)*, !enzyme_caststack !590 %10 = bitcast {}*** %3 to {}** %11 = getelementptr inbounds {}*, {}** %10, i64 -14 %12 = getelementptr inbounds {}*, {}** %11, i64 16 %13 = bitcast {}** %12 to i8** %14 = load i8*, i8** %13, align 8 %15 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) {} addrspace(10)* @julia.gc_alloc_obj({}** %11, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !615 call void @zeroType.456({} addrspace(10)* %15, i8 0, i64 8), !enzyme_zerostack !590 %phic = bitcast {} addrspace(10)* %15 to {} addrspace(10)* addrspace(10)*, !enzyme_caststack !590 %phic19 = call noalias nonnull dereferenceable(1) dereferenceable_or_null(1) i8* @malloc(i64 1), !enzyme_fromstack !4822 %16 = call {}*** @julia.get_pgcstack() #658 store {} addrspace(10)* null, {} addrspace(10)* addrspace(10)* %phic1, align 8, !noalias !47672 call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %9, {} addrspace(10)* null) store {} addrspace(10)* null, {} addrspace(10)* addrspace(10)* %phic, align 8, !noalias !47672 call void ({} addrspace(10)*, ...) 
@julia.write_barrier({} addrspace(10)* %15, {} addrspace(10)* null) %current_task329 = getelementptr inbounds {}**, {}*** %16, i64 -14 %current_task3 = bitcast {}*** %current_task329 to {}** %ptls_field30 = getelementptr inbounds {}**, {}*** %16, i64 2 %17 = bitcast {}*** %ptls_field30 to i64*** %ptls_load3132 = load i64**, i64*** %17, align 8, !tbaa !591 %18 = getelementptr inbounds i64*, i64** %ptls_load3132, i64 2 %safepoint = load i64*, i64** %18, align 8, !tbaa !595 fence syncscope("singlethread") seq_cst call void @julia.safepoint(i64* %safepoint) #658, !dbg !47675 fence syncscope("singlethread") seq_cst %bitcast_coercion = bitcast float %1 to i32, !dbg !47676 %19 = addrspacecast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(11)*, !dbg !47678 %getfield = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %19 unordered, align 8, !dbg !47678, !tbaa !602, !alias.scope !606, !noalias !609, !nonnull !590, !dereferenceable !614, !align !615 %20 = addrspacecast {} addrspace(10)* %getfield to i8 addrspace(11)*, !dbg !47681 %21 = getelementptr inbounds i8, i8 addrspace(11)* %20, i64 8, !dbg !47681 %22 = load i8, i8 addrspace(11)* %21, align 8, !dbg !47681, !tbaa !602, !alias.scope !606, !noalias !609 %23 = and i8 %22, 1, !dbg !47681 %.not = icmp eq i8 %23, 0, !dbg !47681 br i1 %.not, label %L8, label %L5, !dbg !47682L5: ; preds = %top
%24 = call fastcc [1 x {} addrspace(10)] @julia_ArgumentError_31098({} addrspace(10) nofree noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 139965165787312 to {}) to {} addrspace(10))) #659, !dbg !47683
%box = call noalias nonnull dereferenceable(8) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task3, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10))) #660, !dbg !47683
%25 = bitcast {} addrspace(10)* %box to [1 x {} addrspace(10)] addrspace(10), !dbg !47683
%26 = extractvalue [1 x {} addrspace(10)] %24, 0, !dbg !47683
%27 = getelementptr [1 x {} addrspace(10)], [1 x {} addrspace(10)] addrspace(10) %25, i64 0, i64 0, !dbg !47683
store {} addrspace(10)* %26, {} addrspace(10)* addrspace(10)* %27, align 8, !dbg !47683, !tbaa !621, !alias.scope !606, !noalias !47684
%28 = addrspacecast {} addrspace(10)* %box to {} addrspace(12), !dbg !47683
call void @ijl_throw({} addrspace(12) %28) #661, !dbg !47683
unreachable, !dbg !47683
L8: ; preds = %top
%29 = addrspacecast {} addrspace(10)* %getfield to {} addrspace(10)* addrspace(11), !dbg !47685
%getfield6 = load atomic {} addrspace(10), {} addrspace(10)* addrspace(11)* %29 unordered, align 8, !dbg !47685, !tbaa !602, !alias.scope !606, !noalias !609, !nonnull !590, !dereferenceable !628, !align !615
%30 = addrspacecast {} addrspace(10)* %getfield6 to i8 addrspace(11), !dbg !47687
%getfield_addr7 = getelementptr inbounds i8, i8 addrspace(11) %30, i64 40, !dbg !47687
%31 = bitcast i8 addrspace(11)* %getfield_addr7 to {} addrspace(10)* addrspace(11), !dbg !47687
%getfield8 = load atomic {} addrspace(10), {} addrspace(10)* addrspace(11)* %31 unordered, align 8, !dbg !47687, !tbaa !602, !alias.scope !606, !noalias !609, !nonnull !590, !dereferenceable !615, !align !615
%32 = call token (...) @llvm.julia.gc_preserve_begin({} addrspace(10)* nonnull %getfield8) #658, !dbg !47689
%33 = addrspacecast {} addrspace(10)* %getfield8 to {} addrspace(11), !dbg !47690
%34 = call nonnull {} @julia.pointer_from_objref({} addrspace(11)* noundef %33) #662, !dbg !47690
%ptr.i = bitcast {}* %34 to i64*, !dbg !47689
%rv.i = load atomic i64, i64* %ptr.i acquire, align 16, !dbg !47689
call void @llvm.julia.gc_preserve_end(token %32) #658, !dbg !47689
%.not33 = icmp eq i64 %rv.i, 0, !dbg !47692
br i1 %.not33, label %L17, label %L20, !dbg !47688
L17: ; preds = %L8
%35 = call fastcc [1 x {} addrspace(10)] @julia_ArgumentError_31098({} addrspace(10) nofree noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 139965165788400 to {}) to {} addrspace(10))) #658, !dbg !47693
%box11 = call noalias nonnull dereferenceable(8) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task3, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10))) #660, !dbg !47693
%36 = bitcast {} addrspace(10)* %box11 to [1 x {} addrspace(10)] addrspace(10), !dbg !47693
%37 = extractvalue [1 x {} addrspace(10)] %35, 0, !dbg !47693
%38 = getelementptr [1 x {} addrspace(10)], [1 x {} addrspace(10)] addrspace(10) %36, i64 0, i64 0, !dbg !47693
store {} addrspace(10)* %37, {} addrspace(10)* addrspace(10)* %38, align 8, !dbg !47693, !tbaa !621, !alias.scope !606, !noalias !47684
%39 = addrspacecast {} addrspace(10)* %box11 to {} addrspace(12), !dbg !47693
call void @ijl_throw({} addrspace(12) %39) #661, !dbg !47693
unreachable, !dbg !47693
L20: ; preds = %L8
%40 = addrspacecast {} addrspace(10)* %getfield6 to { {} addrspace(10), i64, i64, i8 } addrspace(11), !dbg !47694
%41 = getelementptr inbounds { {} addrspace(10), i64, i64, i8 }, { {} addrspace(10), i64, i64, i8 } addrspace(11)* %40, i64 0, i32 0, !dbg !47694
%42 = load {} addrspace(10), {} addrspace(10) addrspace(11)* %41, align 8, !dbg !47694, !tbaa !602, !alias.scope !606, !noalias !609
%43 = addrspacecast {} addrspace(10)* %42 to i8 addrspace(11), !dbg !47696
%44 = getelementptr inbounds i8, i8 addrspace(11) %43, i64 8, !dbg !47696
%45 = load i8, i8 addrspace(11)* %44, align 8, !dbg !47696, !tbaa !602, !alias.scope !606, !noalias !609
%46 = and i8 %45, 1, !dbg !47696
%.not34 = icmp eq i8 %46, 0, !dbg !47696
br i1 %.not34, label %L73, label %L27, !dbg !47698
L27: ; preds = %L20
%47 = call fastcc nonnull align 8 {} addrspace(10)* @julia_context__32398({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %42) #658, !dbg !47700
store volatile {} addrspace(10)* %42, {} addrspace(10)* addrspace(10)* %phic, align 8, !dbg !47701, !noalias !47672
call void ({} addrspace(10), ...) @julia.write_barrier({} addrspace(10) %15, {} addrspace(10)* %42), !dbg !47701
store volatile {} addrspace(10)* %47, {} addrspace(10)* addrspace(10)* %phic1, align 8, !dbg !47701, !noalias !47672
call void ({} addrspace(10), ...) @julia.write_barrier({} addrspace(10) %9, {} addrspace(10)* %47), !dbg !47701
store volatile i8 0, i8* %phic19, align 1, !dbg !47701, !tbaa !774, !alias.scope !776, !noalias !47702
%48 = call i64 @ijl_excstack_state() #658, !dbg !47701
%49 = call i32 @julia.except_enter() #663, !dbg !47701
%50 = icmp eq i32 %49, 0, !dbg !47701
br i1 %50, label %try, label %L46, !dbg !47701
L46: ; preds = %L27
%phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0. = load volatile {} addrspace(10), {} addrspace(10) addrspace(10)* %phic, align 8, !dbg !47703, !nonnull !590
%phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0. = load volatile {} addrspace(10), {} addrspace(10) addrspace(10)* %phic1, align 8, !dbg !47703, !nonnull !590
%phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0. = load volatile i8, i8* %phic19, align 1, !dbg !47703
call void @ijl_pop_handler(i32 noundef 1) #658, !dbg !47703
%51 = and i8 %phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0., 1, !dbg !47703
%phi.cast = icmp ne i8 %51, 0, !dbg !47703
br label %L51, !dbg !47703
L51: ; preds = %try, %L46
%value_phi = phi {} addrspace(10)* [ %42, %try ], [ %phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0., %L46 ]
%value_phi15 = phi {} addrspace(10)* [ %47, %try ], [ %phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0., %L46 ]
%value_phi17 = phi i1 [ true, %try ], [ %phi.cast, %L46 ]
%52 = addrspacecast {} addrspace(10)* %value_phi15 to {} addrspace(11), !dbg !47704
%53 = icmp eq {} addrspace(11) %52, addrspacecast ({}* inttoptr (i64 139978194116616 to {}) to {} addrspace(11)), !dbg !47704
%54 = addrspacecast {} addrspace(10)* %value_phi to {} addrspace(11)*
%55 = icmp eq {} addrspace(11)* %52, %54
%or.cond = select i1 %53, i1 true, i1 %55, !dbg !47704
br i1 %or.cond, label %L67, label %L62, !dbg !47704
L62: ; preds = %L51
%56 = addrspacecast {} addrspace(10)* %value_phi15 to i8 addrspace(11), !dbg !47705
%57 = getelementptr inbounds i8, i8 addrspace(11) %56, i64 8, !dbg !47705
%58 = load i8, i8 addrspace(11)* %57, align 8, !dbg !47705, !tbaa !846, !alias.scope !606, !noalias !609
%59 = and i8 %58, 1, !dbg !47705
%.not35 = icmp eq i8 %59, 0, !dbg !47705
br i1 %.not35, label %L67, label %L65, !dbg !47704
L65: ; preds = %L62
%60 = call fastcc nonnull {} addrspace(10)* @julia_context__32398({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %value_phi15) #658, !dbg !47707
br label %L67, !dbg !47707
L67: ; preds = %L65, %L62, %L51
br i1 %50, label %L71, label %L69, !dbg !47707
L69: ; preds = %L67
call fastcc void @julia_rethrow_31152() #661, !dbg !47707
unreachable, !dbg !47707
L71: ; preds = %L67
br i1 %value_phi17, label %ok, label %err, !dbg !47707
L73: ; preds = %L20
call fastcc void @julia_error_31187({} addrspace(10)* nofree noundef nonnull align 32 addrspacecast ({}* inttoptr (i64 139962719163168 to {}) to {} addrspace(10))) #661, !dbg !47708
unreachable, !dbg !47708
try: ; preds = %L27
%61 = call fastcc i64 @julia_unsafe_convert_32014({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(40) %0) #658, !dbg !47709
%62 = addrspacecast {} addrspace(10)* %0 to i8 addrspace(11), !dbg !47713
%63 = getelementptr inbounds i8, i8 addrspace(11) %62, i64 24, !dbg !47713
%aggregate_load_box.sroa.0.0..sroa_idx = bitcast i8 addrspace(11)* %63 to i64 addrspace(11), !dbg !47713
%aggregate_load_box.sroa.0.0.copyload = load i64, i64 addrspace(11) %aggregate_load_box.sroa.0.0..sroa_idx, align 8, !dbg !47713, !tbaa !710, !alias.scope !711, !noalias !47716
%aggregate_load_box.sroa.2.0..sroa_idx25 = getelementptr inbounds i8, i8 addrspace(11)* %62, i64 32, !dbg !47713
%64 = bitcast i8 addrspace(11)* %aggregate_load_box.sroa.2.0..sroa_idx25 to i64 addrspace(11), !dbg !47713
%aggregate_load_box.sroa.2.0.copyload = load i64, i64 addrspace(11) %64, align 8, !dbg !47713, !tbaa !710, !alias.scope !711, !noalias !47716
%65 = mul i64 %aggregate_load_box.sroa.2.0.copyload, %aggregate_load_box.sroa.0.0.copyload, !dbg !47717
call fastcc void @julia_set__33047(i64 zeroext %61, i32 zeroext %bitcast_coercion, i64 signext %65) #658, !dbg !47712
store volatile i8 1, i8* %phic19, align 1, !dbg !47703, !tbaa !774, !alias.scope !776, !noalias !47702
call void @ijl_pop_handler(i32 noundef 1) #658, !dbg !47703
br label %L51, !dbg !47703
err: ; preds = %L71
call void @ijl_undefined_var_error({} addrspace(12)* noundef addrspacecast ({}* inttoptr (i64 139978194630336 to {}) to {} addrspace(12))) #661, !dbg !47707
unreachable, !dbg !47707
ok: ; preds = %L71
ret void, !dbg !47699
}
Type analysis state:
%current_task3 = bitcast {}*** %current_task329 to {}: {}, intvals: {}
%bitcast_coercion = bitcast float %1 to i32, !dbg !603: {[-1]:Integer}, intvals: {}
{} addrspace(10)* addrspacecast ({}* inttoptr (i64 139965165787312 to {}) to {} addrspace(10)): {[-1]:Anything}, intvals: {}
{}* inttoptr (i64 139965165787312 to {}): {[-1]:Anything}, intvals: {}
%value_phi15 = phi {} addrspace(10) [ %47, %try ], [ %phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0., %L46 ]: {[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Integer, [-1,16]:Pointer}, intvals: {}
%24 = call fastcc [1 x {} addrspace(10)] @julia_ArgumentError_31098({} addrspace(10) nofree noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 139965165787312 to {}) to {} addrspace(10))) #659, !dbg !630: {[-1]:Pointer}, intvals: {}
%box11 = call noalias nonnull dereferenceable(8) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({} nonnull %current_task3, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10))) #660, !dbg !650: {[-1,-1]:Pointer}, intvals: {}
%phic1 = bitcast {} addrspace(10)* %9 to {} addrspace(10)* addrspace(10), !enzyme_caststack !590: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
%13 = bitcast {}** %12 to i8**: {[-1]:Pointer}, intvals: {}
%17 = bitcast {}** %ptls_field30 to i64***: {[-1]:Pointer}, intvals: {}
%ptls_load3132 = load i64**, i64*** %17, align 8, !tbaa !596: {}, intvals: {}
%15 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) {} addrspace(10)* @julia.gc_alloc_obj({}** %11, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}) to {} addrspace(10))), !enzyme_fromstack !591: {}, intvals: {}
%60 = call fastcc nonnull {} addrspace(10)* @julia_context__32398({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %value_phi15) #658, !dbg !673: {}, intvals: {}
%11 = getelementptr inbounds {}, {}** %10, i64 -14: {}, intvals: {}
{} inttoptr (i64 139962719163168 to {}): {[-1]:Anything}, intvals: {}
{} addrspace(10) addrspacecast ({}* inttoptr (i64 139962719163168 to {}) to {} addrspace(10)): {[-1]:Anything}, intvals: {}
i64 8: {[-1]:Integer}, intvals: {8,}
%9 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) {} addrspace(10)* @julia.gc_alloc_obj({}** %5, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}) to {} addrspace(10))), !enzyme_fromstack !591: {}, intvals: {}
{}* inttoptr (i64 139961738084176 to {}): {[-1]:Anything}, intvals: {}
%4 = bitcast {}** %2 to {}: {}, intvals: {}
%61 = call fastcc i64 @julia_unsafe_convert_32014({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(40) %0) #658, !dbg !675: {}, intvals: {}
%2 = call {}* @julia.get_pgcstack(): {}, intvals: {}
%5 = getelementptr inbounds {}, {}** %4, i64 -14: {}, intvals: {}
%6 = getelementptr inbounds {}, {}** %5, i64 16: {}, intvals: {}
%12 = getelementptr inbounds {}, {}** %11, i64 16: {}, intvals: {}
%14 = load i8, i8** %13, align 8: {}, intvals: {}
%box = call noalias nonnull dereferenceable(8) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task3, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10))) #660, !dbg !630: {[-1,-1]:Pointer}, intvals: {}
%safepoint = load i64*, i64** %18, align 8, !tbaa !600: {}, intvals: {}
{} addrspace(10)* addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10)): {[-1]:Anything}, intvals: {}
{}* inttoptr (i64 139978038671616 to {}): {[-1]:Anything}, intvals: {}
%35 = call fastcc [1 x {} addrspace(10)] @julia_ArgumentError_31098({} addrspace(10)* nofree noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 139965165788400 to {}) to {} addrspace(10))) #658, !dbg !650: {[-1]:Pointer}, intvals: {}
%phic19 = call noalias nonnull dereferenceable(1) dereferenceable_or_null(1) i8* @malloc(i64 1), !enzyme_fromstack !592: {[-1]:Pointer}, intvals: {}
%7 = bitcast {}** %6 to i8**: {[-1]:Pointer}, intvals: {}
%8 = load i8*, i8** %7, align 8: {}, intvals: {}
%16 = call {}*** @julia.get_pgcstack() #658: {}, intvals: {}
%phic = bitcast {} addrspace(10)* %15 to {} addrspace(10)* addrspace(10), !enzyme_caststack !590: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
%3 = call {}** @julia.get_pgcstack(): {}, intvals: {}
%42 = load {} addrspace(10), {} addrspace(10) addrspace(11)* %41, align 8, !dbg !651, !tbaa !613, !alias.scope !617, !noalias !620: {[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Integer, [-1,16]:Pointer}, intvals: {}
{} addrspace(10)* addrspacecast ({}* inttoptr (i64 139965165788400 to {}) to {} addrspace(10)): {[-1]:Anything}, intvals: {}
{}* inttoptr (i64 139965165788400 to {}): {[-1]:Anything}, intvals: {}
%18 = getelementptr inbounds i64, i64** %ptls_load3132, i64 2: {[-1]:Pointer}, intvals: {}
%ptls_field30 = getelementptr inbounds {}, {}* %16, i64 2: {}, intvals: {}
{} addrspace(10)* null: {[-1]:Pointer, [-1,-1]:Anything}, intvals: {0,}
%65 = mul i64 %aggregate_load_box.sroa.2.0.copyload, %aggregate_load_box.sroa.0.0.copyload, !dbg !691: {[-1]:Integer}, intvals: {}
%10 = bitcast {}*** %3 to {}: {}, intvals: {}
{} addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}) to {} addrspace(10)): {[-1]:Anything}, intvals: {}
{} addrspace(10)* %0: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Pointer, [-1,0,0,0]:Pointer, [-1,0,0,0,0]:Pointer, [-1,0,0,0,8]:Integer, [-1,0,0,0,16]:Pointer, [-1,0,0,16]:Integer, [-1,0,0,17]:Integer, [-1,0,0,18]:Integer, [-1,0,0,19]:Integer, [-1,0,0,20]:Integer, [-1,0,0,21]:Integer, [-1,0,0,22]:Integer, [-1,0,0,23]:Integer, [-1,0,0,24]:Integer, [-1,0,0,32]:Pointer, [-1,0,0,40]:Pointer, [-1,0,0,40,-1]:Integer, [-1,0,8]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}, intvals: {}
float %1: {[-1]:Float@float}, intvals: {}
%47 = call fastcc nonnull align 8 {} addrspace(10)* @julia_context__32398({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %42) #658, !dbg !662: {}, intvals: {}
%current_task329 = getelementptr inbounds {}, {}*** %16, i64 -14: {}, intvals: {}
Illegal updateAnalysis prev:{[-1]:Integer} new: {[-1]:Float@float}
val: %bitcast_coercion = bitcast float %1 to i32, !dbg !603 origin= %bitcast_coercion = bitcast float %1 to i32, !dbg !603
MethodInstance for fill!(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float32)
Caused by:
Stacktrace:
[1] reinterpret
@ ./essentials.jl:581
[2] fill!
@ ~/.julia/packages/CUDA/jdJ7Z/src/array.jl:829
Stacktrace:
[1] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:1690
[2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/2FwRI/src/api.jl:154
[3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:3177
[4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5070
[5] codegen
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:4477 [inlined]
[6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755
[7] _thunk
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755 [inlined]
[8] cached_compilation
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5793 [inlined]
[9] (::Enzyme.Compiler.var"#554#555"{…})(ctx::LLVM.Context)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5859
[10] JuliaContext(f::Enzyme.Compiler.var"#554#555"{…}; kwargs::@kwargs{})
@ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
[11] JuliaContext(f::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
[12] #s2027#553
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5811 [inlined]
[13]
@ Enzyme.Compiler ./none:0
[14] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[15] autodiff
@ ~/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:286 [inlined]
[16] autodiff
@ ~/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:315 [inlined]
[17] autodiff(::ReverseMode{…}, ::typeof(loss), ::Duplicated{…}, ::Const{…}, ::Const{…})
@ Enzyme ~/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:300
[18] top-level scope
@ ~/juliadev/Flux/mlp.jl:37
Some type information was truncated. Use show(err) to see complete types.
from flux.jl.
This should be resolved by #2446
Like I say in that PR
"""
I have no opinions on the design/API and I will give this PR to you all to make it however you feel (and I will go back to staring at CUDA).
I will note that perf atm is unclear and is worth investigating. However, before we do that, having a good way to run/test things is critical, hence this PR.
"""
from flux.jl.
edit: accidentally reran cpu, please ignore below.
CUDA works on the simple example now. It does require either CUDA#master, with the already merged branches, or hopefully a backport release of CUDA.jl via JuliaGPU/CUDA.jl#2375, as well as an Enzyme_jll bump.
wmoses@beast:~/git/Flux.jl ((HEAD detached at origin/master)) $ cat orig.jl
using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)
function gradient_ez(f, x...)
    args = []
    for x in x
        if x isa Number
            push!(args, Active(x))
        else
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end
batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
# device = Flux.cpu # CPU training
device = Flux.gpu # GPU training
X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device
model = Chain(Dense(feature_size => 32, relu),
              Dense(32, num_classes)) |> device
opt_state = Flux.setup(Adam(1e-3), model)
loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))
function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end
report(0)
for epoch in 1:epochs
    g = gradient_ez(model -> loss(model, X, y), model)[1] # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end
wmoses@beast:~/git/Flux.jl ((HEAD detached at origin/master)) $ ~/git/Enzyme.jl/julia-1.10.2/bin/julia --project orig.jl
┌ Warning: Package cuDNN not found in current path.
│ - Run `import Pkg; Pkg.add("cuDNN")` to install the cuDNN package, then restart julia.
│ - If cuDNN is not installed, some Flux functionalities will not be available when running on the GPU.
└ @ FluxCUDAExt ~/git/Flux.jl/ext/FluxCUDAExt/FluxCUDAExt.jl:57
┌ Info: Epoch: 0
│ loss = 2.7904227f0
└ accuracy = 0.125
┌ Info: Epoch: 1
│ loss = 2.5142982f0
└ accuracy = 0.15625
┌ Info: Epoch: 2
│ loss = 2.2610319f0
└ accuracy = 0.203125
┌ Info: Epoch: 3
│ loss = 2.029134f0
└ accuracy = 0.28125
┌ Info: Epoch: 4
│ loss = 1.8172197f0
└ accuracy = 0.3515625
┌ Info: Epoch: 5
│ loss = 1.6268556f0
└ accuracy = 0.4375
┌ Info: Epoch: 6
│ loss = 1.4554112f0
└ accuracy = 0.546875
┌ Info: Epoch: 7
│ loss = 1.3014916f0
└ accuracy = 0.6640625
┌ Info: Epoch: 8
│ loss = 1.163165f0
└ accuracy = 0.7890625
┌ Info: Epoch: 9
│ loss = 1.0413302f0
└ accuracy = 0.8515625
┌ Info: Epoch: 10
│ loss = 0.93555194f0
└ accuracy = 0.8515625
┌ Info: Epoch: 11
│ loss = 0.84206563f0
└ accuracy = 0.8828125
┌ Info: Epoch: 12
│ loss = 0.7600569f0
└ accuracy = 0.90625
┌ Info: Epoch: 13
│ loss = 0.6874082f0
└ accuracy = 0.921875
┌ Info: Epoch: 14
│ loss = 0.6230737f0
└ accuracy = 0.9296875
┌ Info: Epoch: 15
│ loss = 0.5663827f0
└ accuracy = 0.9609375
┌ Info: Epoch: 16
│ loss = 0.5165455f0
└ accuracy = 0.96875
┌ Info: Epoch: 17
│ loss = 0.4719535f0
└ accuracy = 0.96875
┌ Info: Epoch: 18
│ loss = 0.4319139f0
└ accuracy = 0.9765625
┌ Info: Epoch: 19
│ loss = 0.39577293f0
└ accuracy = 0.984375
┌ Info: Epoch: 20
│ loss = 0.36347917f0
└ accuracy = 0.984375
┌ Info: Epoch: 21
│ loss = 0.33449084f0
└ accuracy = 0.9921875
┌ Info: Epoch: 22
│ loss = 0.30846184f0
└ accuracy = 0.9921875
┌ Info: Epoch: 23
│ loss = 0.28476223f0
└ accuracy = 0.9921875
┌ Info: Epoch: 24
│ loss = 0.26318714f0
└ accuracy = 1.0
┌ Info: Epoch: 25
│ loss = 0.24353352f0
└ accuracy = 1.0
┌ Info: Epoch: 26
│ loss = 0.22557218f0
└ accuracy = 1.0
┌ Info: Epoch: 27
│ loss = 0.20921068f0
└ accuracy = 1.0
┌ Info: Epoch: 28
│ loss = 0.19429381f0
└ accuracy = 1.0
┌ Info: Epoch: 29
│ loss = 0.18054952f0
└ accuracy = 1.0
┌ Info: Epoch: 30
│ loss = 0.16796987f0
└ accuracy = 1.0
┌ Info: Epoch: 31
│ loss = 0.1563463f0
└ accuracy = 1.0
┌ Info: Epoch: 32
│ loss = 0.14567412f0
└ accuracy = 1.0
┌ Info: Epoch: 33
│ loss = 0.13588753f0
└ accuracy = 1.0
┌ Info: Epoch: 34
│ loss = 0.12687433f0
└ accuracy = 1.0
┌ Info: Epoch: 35
│ loss = 0.11857266f0
└ accuracy = 1.0
┌ Info: Epoch: 36
│ loss = 0.11093213f0
└ accuracy = 1.0
┌ Info: Epoch: 37
│ loss = 0.103871785f0
└ accuracy = 1.0
┌ Info: Epoch: 38
│ loss = 0.09736837f0
└ accuracy = 1.0
┌ Info: Epoch: 39
│ loss = 0.09138645f0
└ accuracy = 1.0
┌ Info: Epoch: 40
│ loss = 0.08586908f0
└ accuracy = 1.0
┌ Info: Epoch: 41
│ loss = 0.080786735f0
└ accuracy = 1.0
┌ Info: Epoch: 42
│ loss = 0.07610354f0
└ accuracy = 1.0
┌ Info: Epoch: 43
│ loss = 0.07179588f0
└ accuracy = 1.0
┌ Info: Epoch: 44
│ loss = 0.06783663f0
└ accuracy = 1.0
┌ Info: Epoch: 45
│ loss = 0.06419177f0
└ accuracy = 1.0
┌ Info: Epoch: 46
│ loss = 0.060845155f0
└ accuracy = 1.0
┌ Info: Epoch: 47
│ loss = 0.057761367f0
└ accuracy = 1.0
┌ Info: Epoch: 48
│ loss = 0.0549154f0
└ accuracy = 1.0
┌ Info: Epoch: 49
│ loss = 0.05228231f0
└ accuracy = 1.0
┌ Info: Epoch: 50
│ loss = 0.049845647f0
└ accuracy = 1.0
┌ Info: Epoch: 51
│ loss = 0.047589153f0
└ accuracy = 1.0
┌ Info: Epoch: 52
│ loss = 0.045498513f0
└ accuracy = 1.0
┌ Info: Epoch: 53
│ loss = 0.04355742f0
└ accuracy = 1.0
┌ Info: Epoch: 54
│ loss = 0.04175187f0
└ accuracy = 1.0
┌ Info: Epoch: 55
│ loss = 0.04007356f0
└ accuracy = 1.0
┌ Info: Epoch: 56
│ loss = 0.038507923f0
└ accuracy = 1.0
┌ Info: Epoch: 57
│ loss = 0.037045095f0
└ accuracy = 1.0
┌ Info: Epoch: 58
│ loss = 0.035674226f0
└ accuracy = 1.0
┌ Info: Epoch: 59
│ loss = 0.034392048f0
└ accuracy = 1.0
┌ Info: Epoch: 60
│ loss = 0.033194654f0
└ accuracy = 1.0
┌ Info: Epoch: 61
│ loss = 0.032058075f0
└ accuracy = 1.0
┌ Info: Epoch: 62
│ loss = 0.030996136f0
└ accuracy = 1.0
┌ Info: Epoch: 63
│ loss = 0.02999451f0
└ accuracy = 1.0
┌ Info: Epoch: 64
│ loss = 0.029050402f0
└ accuracy = 1.0
┌ Info: Epoch: 65
│ loss = 0.02815985f0
└ accuracy = 1.0
┌ Info: Epoch: 66
│ loss = 0.027319008f0
└ accuracy = 1.0
┌ Info: Epoch: 67
│ loss = 0.02652272f0
└ accuracy = 1.0
┌ Info: Epoch: 68
│ loss = 0.025767544f0
└ accuracy = 1.0
┌ Info: Epoch: 69
│ loss = 0.025051065f0
└ accuracy = 1.0
┌ Info: Epoch: 70
│ loss = 0.024369944f0
└ accuracy = 1.0
┌ Info: Epoch: 71
│ loss = 0.023721226f0
└ accuracy = 1.0
┌ Info: Epoch: 72
│ loss = 0.023103705f0
└ accuracy = 1.0
┌ Info: Epoch: 73
│ loss = 0.022514593f0
└ accuracy = 1.0
┌ Info: Epoch: 74
│ loss = 0.021952922f0
└ accuracy = 1.0
┌ Info: Epoch: 75
│ loss = 0.021417053f0
└ accuracy = 1.0
┌ Info: Epoch: 76
│ loss = 0.020906389f0
└ accuracy = 1.0
┌ Info: Epoch: 77
│ loss = 0.0204159f0
└ accuracy = 1.0
┌ Info: Epoch: 78
│ loss = 0.01994732f0
└ accuracy = 1.0
┌ Info: Epoch: 79
│ loss = 0.01949887f0
└ accuracy = 1.0
┌ Info: Epoch: 80
│ loss = 0.01906871f0
└ accuracy = 1.0
┌ Info: Epoch: 81
│ loss = 0.018656129f0
└ accuracy = 1.0
┌ Info: Epoch: 82
│ loss = 0.018260362f0
└ accuracy = 1.0
┌ Info: Epoch: 83
│ loss = 0.017879806f0
└ accuracy = 1.0
┌ Info: Epoch: 84
│ loss = 0.017513612f0
└ accuracy = 1.0
┌ Info: Epoch: 85
│ loss = 0.017161498f0
└ accuracy = 1.0
┌ Info: Epoch: 86
│ loss = 0.01682241f0
└ accuracy = 1.0
┌ Info: Epoch: 87
│ loss = 0.016495718f0
└ accuracy = 1.0
┌ Info: Epoch: 88
│ loss = 0.016181245f0
└ accuracy = 1.0
┌ Info: Epoch: 89
│ loss = 0.015877243f0
└ accuracy = 1.0
┌ Info: Epoch: 90
│ loss = 0.0155781405f0
└ accuracy = 1.0
┌ Info: Epoch: 91
│ loss = 0.01528422f0
└ accuracy = 1.0
┌ Info: Epoch: 92
│ loss = 0.014997441f0
└ accuracy = 1.0
┌ Info: Epoch: 93
│ loss = 0.014718127f0
└ accuracy = 1.0
┌ Info: Epoch: 94
│ loss = 0.014446221f0
└ accuracy = 1.0
┌ Info: Epoch: 95
│ loss = 0.014181806f0
└ accuracy = 1.0
┌ Info: Epoch: 96
│ loss = 0.013925277f0
└ accuracy = 1.0
┌ Info: Epoch: 97
│ loss = 0.013677116f0
└ accuracy = 1.0
┌ Info: Epoch: 98
│ loss = 0.013437184f0
└ accuracy = 1.0
┌ Info: Epoch: 99
│ loss = 0.013204632f0
└ accuracy = 1.0
┌ Info: Epoch: 100
│ loss = 0.012979296f0
└ accuracy = 1.0
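For readers following the script above: `gradient_ez` wraps each number in `Active` and everything else in `Duplicated` with a zeroed shadow, then reads scalar gradients from the return value and array gradients from the shadow. A minimal sketch of that same pattern on a plain CPU function, without Flux or CUDA (the function `f` and its inputs are made up for illustration):

```julia
using Enzyme

# Hypothetical scalar-valued function of an array and a number.
f(w, s) = s * sum(abs2, w)

w  = [1.0, 2.0, 3.0]
dw = zero(w)   # shadow buffer; Enzyme accumulates the gradient w.r.t. w into it
s  = 2.0

# Arrays go in as Duplicated(value, shadow), scalars as Active(value).
ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, Duplicated(w, dw), Active(s))

# ret[1] holds one entry per argument: `nothing` for Duplicated arguments,
# the derivative for Active ones. ret[2] is the primal value f(w, s).
ds = ret[1][2]

# Analytically, dw should equal 2 .* s .* w and ds should equal sum(abs2, w).
```

The shadow has to start at zero because Enzyme accumulates into it, which is why the script rebuilds it on every call.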
from flux.jl.
@CarloLucibello this gradient_ez is very useful. Thanks! Would it also be possible to have an option to run Enzyme from Zygote? Or an example, similar to the one with gradient_ez, of how to add a Zygote.@adjoint such that Enzyme is used for one custom Flux layer while the rest still uses Zygote?
I am thinking of some way we could transition smoothly, without switching completely to one or the other.
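One possible route for the question above, sketched under assumptions rather than offered as Flux's supported approach: Zygote picks up rules defined through `ChainRulesCore.rrule`, so a rule whose pullback calls Enzyme makes that one function differentiate with Enzyme while the surrounding code still goes through Zygote. `mylayer` is a hypothetical scalar-valued function standing in for a custom layer; an array-valued layer would additionally need the output cotangent seeded into an Enzyme shadow, which this sketch does not cover.

```julia
using ChainRulesCore, Enzyme, Zygote

# Hypothetical stand-in for a custom layer; returns a scalar.
mylayer(w, x) = sum(abs2, w .* x)

# Rule that Zygote will use for `mylayer`; the pullback delegates to Enzyme.
function ChainRulesCore.rrule(::typeof(mylayer), w, x)
    y = mylayer(w, x)
    function mylayer_pullback(ybar)
        scale = unthunk(ybar)        # incoming scalar cotangent
        dw, dx = zero(w), zero(x)    # zeroed shadows for Enzyme to fill
        Enzyme.autodiff(Reverse, mylayer, Active,
                        Duplicated(w, dw), Duplicated(x, dx))
        return NoTangent(), scale .* dw, scale .* dx
    end
    return y, mylayer_pullback
end

w = [1.0, 2.0]
x = [3.0, 4.0]

# The outer multiplication is differentiated by Zygote, `mylayer` by Enzyme.
gw, gx = Zygote.gradient((w, x) -> 3 * mylayer(w, x), w, x)
```

Note that this runs the forward pass twice (once for `y`, once inside Enzyme), so it is a starting point for experiments, not a tuned implementation.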
from flux.jl.