Create a flag to use Enzyme as the AD in training/etc. #2443
Comments
I think the basic interface needed is a nice gradient helper, something like this:

using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)
function gradient_ez(f, x...)
args = []
for x in x
if x isa Number
push!(args, Active(x))
else
push!(args, Duplicated(x, make_zero(x)))
end
end
ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
return g
end
batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
device = Flux.cpu # CPU training
# device = Flux.gpu # GPU training
X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device
model = Chain(Dense(feature_size => 32, relu),
Dense(32, num_classes)) |> device
opt_state = Flux.setup(Adam(1e-3), model)
loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))
function report(epoch)
@info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end
report(0)
for epoch in 1:epochs
g = gradient_ez(model -> loss(model, X, y), model)[1] # Enzyme gradient
# g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
Flux.update!(opt_state, model, g)
report(epoch)
end

We should add tests for the loss functions. This one is failing:

gradient_ez(ŷ -> Flux.logitcrossentropy(ŷ, y), randn(Float32, num_classes, batch_size))
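A sketch of what such a test could look like, using Zygote (via Flux.gradient) as the reference; it assumes the gradient_ez helper and the y, num_classes, batch_size values defined above:

using Test

ŷ0 = randn(Float32, num_classes, batch_size)
g_enzyme = gradient_ez(ŷ -> Flux.logitcrossentropy(ŷ, y), ŷ0)[1]   # Enzyme gradient (currently failing)
g_zygote = Flux.gradient(ŷ -> Flux.logitcrossentropy(ŷ, y), ŷ0)[1] # Zygote reference
@test g_enzyme ≈ g_zygote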
Here is a modification to your code above which will be more performant/stable/etc. (closures are bad). In any case it still has the same issue, and I will investigate.

# using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics
_make_zero!(x::AbstractArray) = x .= 0
_make_zero!(x) = x
make_zero!(model) = fmap(_make_zero!, model)
batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
device = Flux.cpu # CPU training
# device = Flux.gpu # GPU training
X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device
model = Chain(Dense(feature_size => 32, relu),
Dense(32, num_classes)) |> device
opt_state = Flux.setup(Adam(1e-3), model)
loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))
function report(epoch)
@info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end
report(0)
g = deepcopy(model)
for epoch in 1:epochs
make_zero!(g)
Enzyme.autodiff(Reverse, loss, Duplicated(model, g), Const(X), Const(y))
# g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
Flux.update!(opt_state, model, g)
report(epoch)
end
Yeah, this works now with the NNlib type-stability fix FluxML/NNlib.jl#584.
The previous "interface" was to import the corresponding AD package and just call e.g. The most recent attempt was supposed to be DI.jl, but the choice to focus on arrays and single inputs means we can't use it. To me the best option would be a But I suggest a dedicated doc page on using Enzyme + Flux will be easier to get through quickly. |
Sure, I think docs would be a great first step. I don't really know how to use Flux or where that would go best, so I'll leave that to you. At the same time, if we're already doing API design, for training it would be nice not to have to constantly reallocate the gradient buffer (with make_zero). I don't know if there's an in-place zeroing function you have for models, but that would be highly beneficial here.
I edited the code in your post to zero the gradient in-place. A slight problem in …
On GPU I get the following error:

┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59
ERROR: Enzyme compilation failed due to illegal type analysis.
Current scope:
; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia_fill__33038(...)
...
Type analysis state:
Illegal updateAnalysis prev:{[-1]:Integer} new: {[-1]:Float@float}
Caused by:
Stacktrace:
You'll need JuliaGPU/CUDA.jl#2371 and then JuliaPackaging/Yggdrasil#8666. It then hits a cuBLAS scal issue, which I stopped investigating to go get dinner.
Enzyme's own gradient handles shared (tied) arrays like this:

julia> sh = [1f0, 2f0]; nt = (a=sh, b=sh, c=copy(sh));
julia> Enzyme.gradient(Reverse, x -> sum(map(sum, x)), nt)
(a = Float32[2.0, 2.0], b = Float32[2.0, 2.0], c = Float32[1.0, 1.0])
(jl_o1ZBlk) pkg> st Enzyme
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_o1ZBlk/Project.toml`
  [7da242da] Enzyme v0.12.4

The above example doesn't work for me, but I believe something like this should eventually work:

for epoch in 1:epochs
g = Enzyme.gradient(Reverse, m -> loss(m, X, y), model) # Enzyme gradient
# g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
Flux.update!(opt_state, model, g)
report(epoch)
end
Right. For those coming from Zygote, it's slightly odd that the gradient contains numbers for non-diff things. But I believe Optimisers.jl's idea of what parameters can be updated is narrow enough that it will only use true gradient numbers from Enzyme.jl.
This should be resolved by #2446. As I say in that PR: "I will note that perf atm is unclear and is worth investigating. However, before we do that, having a good way to run/test things is critical, hence this PR."
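For anyone who wants to look at the perf question, a rough comparison on the example above might be (a sketch only; loss, X, y, model and gradient_ez are the names defined earlier in the thread, and results will depend on hardware and package versions):

using BenchmarkTools

@btime Flux.gradient(m -> loss(m, $X, $y), $model);   # Zygote (Flux's default AD)
@btime gradient_ez(m -> loss(m, $X, $y), $model);     # Enzyme, via the helper above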
edit: accidentally reran cpu, please ignore below.

CUDA works on the simple example now. It does require either CUDA#master (on already-merged branches) or hopefully a backport release from CUDA.jl via JuliaGPU/CUDA.jl#2375, as well as an Enzyme_jll bump.

wmoses@beast:~/git/Flux.jl ((HEAD detached at origin/master)) $ cat orig.jl
using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)
function gradient_ez(f, x...)
args = []
for x in x
if x isa Number
push!(args, Active(x))
else
push!(args, Duplicated(x, make_zero(x)))
end
end
ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
return g
end
batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
# device = Flux.cpu # CPU training
device = Flux.gpu # GPU training
X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device
model = Chain(Dense(feature_size => 32, relu),
Dense(32, num_classes)) |> device
opt_state = Flux.setup(Adam(1e-3), model)
loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))
function report(epoch)
@info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end
report(0)
for epoch in 1:epochs
g = gradient_ez(model -> loss(model, X, y), model)[1] # Enzyme gradient
# g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
Flux.update!(opt_state, model, g)
report(epoch)
end
wmoses@beast:~/git/Flux.jl ((HEAD detached at origin/master)) $ ~/git/Enzyme.jl/julia-1.10.2/bin/julia --project orig.jl
┌ Warning: Package cuDNN not found in current path.
│ - Run `import Pkg; Pkg.add("cuDNN")` to install the cuDNN package, then restart julia.
│ - If cuDNN is not installed, some Flux functionalities will not be available when running on the GPU.
└ @ FluxCUDAExt ~/git/Flux.jl/ext/FluxCUDAExt/FluxCUDAExt.jl:57
┌ Info: Epoch: 0
│ loss = 2.7904227f0
└ accuracy = 0.125
┌ Info: Epoch: 1
│ loss = 2.5142982f0
└ accuracy = 0.15625
┌ Info: Epoch: 2
│ loss = 2.2610319f0
└ accuracy = 0.203125
┌ Info: Epoch: 3
│ loss = 2.029134f0
└ accuracy = 0.28125
┌ Info: Epoch: 4
│ loss = 1.8172197f0
└ accuracy = 0.3515625
┌ Info: Epoch: 5
│ loss = 1.6268556f0
└ accuracy = 0.4375
┌ Info: Epoch: 6
│ loss = 1.4554112f0
└ accuracy = 0.546875
┌ Info: Epoch: 7
│ loss = 1.3014916f0
└ accuracy = 0.6640625
┌ Info: Epoch: 8
│ loss = 1.163165f0
└ accuracy = 0.7890625
┌ Info: Epoch: 9
│ loss = 1.0413302f0
└ accuracy = 0.8515625
┌ Info: Epoch: 10
│ loss = 0.93555194f0
└ accuracy = 0.8515625
┌ Info: Epoch: 11
│ loss = 0.84206563f0
└ accuracy = 0.8828125
┌ Info: Epoch: 12
│ loss = 0.7600569f0
└ accuracy = 0.90625
┌ Info: Epoch: 13
│ loss = 0.6874082f0
└ accuracy = 0.921875
┌ Info: Epoch: 14
│ loss = 0.6230737f0
└ accuracy = 0.9296875
┌ Info: Epoch: 15
│ loss = 0.5663827f0
└ accuracy = 0.9609375
┌ Info: Epoch: 16
│ loss = 0.5165455f0
└ accuracy = 0.96875
┌ Info: Epoch: 17
│ loss = 0.4719535f0
└ accuracy = 0.96875
┌ Info: Epoch: 18
│ loss = 0.4319139f0
└ accuracy = 0.9765625
┌ Info: Epoch: 19
│ loss = 0.39577293f0
└ accuracy = 0.984375
┌ Info: Epoch: 20
│ loss = 0.36347917f0
└ accuracy = 0.984375
┌ Info: Epoch: 21
│ loss = 0.33449084f0
└ accuracy = 0.9921875
┌ Info: Epoch: 22
│ loss = 0.30846184f0
└ accuracy = 0.9921875
┌ Info: Epoch: 23
│ loss = 0.28476223f0
└ accuracy = 0.9921875
┌ Info: Epoch: 24
│ loss = 0.26318714f0
└ accuracy = 1.0
┌ Info: Epoch: 25
│ loss = 0.24353352f0
└ accuracy = 1.0
┌ Info: Epoch: 26
│ loss = 0.22557218f0
└ accuracy = 1.0
┌ Info: Epoch: 27
│ loss = 0.20921068f0
└ accuracy = 1.0
┌ Info: Epoch: 28
│ loss = 0.19429381f0
└ accuracy = 1.0
┌ Info: Epoch: 29
│ loss = 0.18054952f0
└ accuracy = 1.0
┌ Info: Epoch: 30
│ loss = 0.16796987f0
└ accuracy = 1.0
┌ Info: Epoch: 31
│ loss = 0.1563463f0
└ accuracy = 1.0
┌ Info: Epoch: 32
│ loss = 0.14567412f0
└ accuracy = 1.0
┌ Info: Epoch: 33
│ loss = 0.13588753f0
└ accuracy = 1.0
┌ Info: Epoch: 34
│ loss = 0.12687433f0
└ accuracy = 1.0
┌ Info: Epoch: 35
│ loss = 0.11857266f0
└ accuracy = 1.0
┌ Info: Epoch: 36
│ loss = 0.11093213f0
└ accuracy = 1.0
┌ Info: Epoch: 37
│ loss = 0.103871785f0
└ accuracy = 1.0
┌ Info: Epoch: 38
│ loss = 0.09736837f0
└ accuracy = 1.0
┌ Info: Epoch: 39
│ loss = 0.09138645f0
└ accuracy = 1.0
┌ Info: Epoch: 40
│ loss = 0.08586908f0
└ accuracy = 1.0
┌ Info: Epoch: 41
│ loss = 0.080786735f0
└ accuracy = 1.0
┌ Info: Epoch: 42
│ loss = 0.07610354f0
└ accuracy = 1.0
┌ Info: Epoch: 43
│ loss = 0.07179588f0
└ accuracy = 1.0
┌ Info: Epoch: 44
│ loss = 0.06783663f0
└ accuracy = 1.0
┌ Info: Epoch: 45
│ loss = 0.06419177f0
└ accuracy = 1.0
┌ Info: Epoch: 46
│ loss = 0.060845155f0
└ accuracy = 1.0
┌ Info: Epoch: 47
│ loss = 0.057761367f0
└ accuracy = 1.0
┌ Info: Epoch: 48
│ loss = 0.0549154f0
└ accuracy = 1.0
┌ Info: Epoch: 49
│ loss = 0.05228231f0
└ accuracy = 1.0
┌ Info: Epoch: 50
│ loss = 0.049845647f0
└ accuracy = 1.0
┌ Info: Epoch: 51
│ loss = 0.047589153f0
└ accuracy = 1.0
┌ Info: Epoch: 52
│ loss = 0.045498513f0
└ accuracy = 1.0
┌ Info: Epoch: 53
│ loss = 0.04355742f0
└ accuracy = 1.0
┌ Info: Epoch: 54
│ loss = 0.04175187f0
└ accuracy = 1.0
┌ Info: Epoch: 55
│ loss = 0.04007356f0
└ accuracy = 1.0
┌ Info: Epoch: 56
│ loss = 0.038507923f0
└ accuracy = 1.0
┌ Info: Epoch: 57
│ loss = 0.037045095f0
└ accuracy = 1.0
┌ Info: Epoch: 58
│ loss = 0.035674226f0
└ accuracy = 1.0
┌ Info: Epoch: 59
│ loss = 0.034392048f0
└ accuracy = 1.0
┌ Info: Epoch: 60
│ loss = 0.033194654f0
└ accuracy = 1.0
┌ Info: Epoch: 61
│ loss = 0.032058075f0
└ accuracy = 1.0
┌ Info: Epoch: 62
│ loss = 0.030996136f0
└ accuracy = 1.0
┌ Info: Epoch: 63
│ loss = 0.02999451f0
└ accuracy = 1.0
┌ Info: Epoch: 64
│ loss = 0.029050402f0
└ accuracy = 1.0
┌ Info: Epoch: 65
│ loss = 0.02815985f0
└ accuracy = 1.0
┌ Info: Epoch: 66
│ loss = 0.027319008f0
└ accuracy = 1.0
┌ Info: Epoch: 67
│ loss = 0.02652272f0
└ accuracy = 1.0
┌ Info: Epoch: 68
│ loss = 0.025767544f0
└ accuracy = 1.0
┌ Info: Epoch: 69
│ loss = 0.025051065f0
└ accuracy = 1.0
┌ Info: Epoch: 70
│ loss = 0.024369944f0
└ accuracy = 1.0
┌ Info: Epoch: 71
│ loss = 0.023721226f0
└ accuracy = 1.0
┌ Info: Epoch: 72
│ loss = 0.023103705f0
└ accuracy = 1.0
┌ Info: Epoch: 73
│ loss = 0.022514593f0
└ accuracy = 1.0
┌ Info: Epoch: 74
│ loss = 0.021952922f0
└ accuracy = 1.0
┌ Info: Epoch: 75
│ loss = 0.021417053f0
└ accuracy = 1.0
┌ Info: Epoch: 76
│ loss = 0.020906389f0
└ accuracy = 1.0
┌ Info: Epoch: 77
│ loss = 0.0204159f0
└ accuracy = 1.0
┌ Info: Epoch: 78
│ loss = 0.01994732f0
└ accuracy = 1.0
┌ Info: Epoch: 79
│ loss = 0.01949887f0
└ accuracy = 1.0
┌ Info: Epoch: 80
│ loss = 0.01906871f0
└ accuracy = 1.0
┌ Info: Epoch: 81
│ loss = 0.018656129f0
└ accuracy = 1.0
┌ Info: Epoch: 82
│ loss = 0.018260362f0
└ accuracy = 1.0
┌ Info: Epoch: 83
│ loss = 0.017879806f0
└ accuracy = 1.0
┌ Info: Epoch: 84
│ loss = 0.017513612f0
└ accuracy = 1.0
┌ Info: Epoch: 85
│ loss = 0.017161498f0
└ accuracy = 1.0
┌ Info: Epoch: 86
│ loss = 0.01682241f0
└ accuracy = 1.0
┌ Info: Epoch: 87
│ loss = 0.016495718f0
└ accuracy = 1.0
┌ Info: Epoch: 88
│ loss = 0.016181245f0
└ accuracy = 1.0
┌ Info: Epoch: 89
│ loss = 0.015877243f0
└ accuracy = 1.0
┌ Info: Epoch: 90
│ loss = 0.0155781405f0
└ accuracy = 1.0
┌ Info: Epoch: 91
│ loss = 0.01528422f0
└ accuracy = 1.0
┌ Info: Epoch: 92
│ loss = 0.014997441f0
└ accuracy = 1.0
┌ Info: Epoch: 93
│ loss = 0.014718127f0
└ accuracy = 1.0
┌ Info: Epoch: 94
│ loss = 0.014446221f0
└ accuracy = 1.0
┌ Info: Epoch: 95
│ loss = 0.014181806f0
└ accuracy = 1.0
┌ Info: Epoch: 96
│ loss = 0.013925277f0
└ accuracy = 1.0
┌ Info: Epoch: 97
│ loss = 0.013677116f0
└ accuracy = 1.0
┌ Info: Epoch: 98
│ loss = 0.013437184f0
└ accuracy = 1.0
┌ Info: Epoch: 99
│ loss = 0.013204632f0
└ accuracy = 1.0
┌ Info: Epoch: 100
│ loss = 0.012979296f0
└ accuracy = 1.0
@CarloLucibello I am thinking of some way we could smoothly transition, without switching completely to one AD?
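Purely as a hypothetical sketch of such a transition (train_step! and the ad keyword are made up for illustration and are not existing Flux API; it reuses the loss and make_zero helpers from earlier in the thread):

using Flux, Enzyme

# Pick the AD backend per step, keeping Zygote as the default.
function train_step!(opt_state, model, x, y; ad = :zygote)
    if ad === :enzyme
        dmodel = make_zero(model)   # zero shadow for the gradient
        Enzyme.autodiff(Reverse, loss, Duplicated(model, dmodel), Const(x), Const(y))
        g = dmodel
    else
        g = Flux.gradient(m -> loss(m, x, y), model)[1]
    end
    Flux.update!(opt_state, model, g)
    return model
end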
@darsnack I'd actually love to revisit the dream of DI + Flux one of these days.
Why not create a package named DifferentiationInterfaceForFlux or something, which relies on DI but tests compatibility with Flux layers and makes it part of its API? In other words, if I change something in DI that removes compatibility with Flux layers, the glue package could still be frozen to its current version until it gets resolved.
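If such a glue package existed, usage might look roughly like this (a sketch; it assumes DI could take a nested Flux model as the differentiated argument, which the comments above note is exactly what doesn't work today):

using DifferentiationInterface
using ADTypes: AutoEnzyme
import Enzyme   # backend package must be loaded

# Hypothetical: differentiate the loss w.r.t. the whole model via DI.
grad = DifferentiationInterface.gradient(m -> loss(m, X, y), AutoEnzyme(), model)
Flux.update!(opt_state, model, grad)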
Motivation and description
Now that all the internal Flux tests pass, we should start setting up for integration. Having such a flag would make it easier for myself and others to test things out, debug, etc.
Possible Implementation
No response