
Create a flag to use Enzyme as the AD in training/etc. #2443

Closed
wsmoses opened this issue May 11, 2024 · 14 comments

@wsmoses
Contributor

wsmoses commented May 11, 2024

Motivation and description

Now that all the internal Flux tests pass, we should start setting up for integration. Having such a flag would make it easier for me and others to test things out, debug, etc.

Possible Implementation

No response

@wsmoses
Contributor Author

wsmoses commented May 11, 2024

cc @CarloLucibello @ToucheSir

@CarloLucibello
Member

CarloLucibello commented May 12, 2024

I think the basic interface needed is a nice gradient function.
This code is still not working, though, on both CPU and CUDA GPU:

using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics

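# Zero numeric leaves (numbers and arrays); leave everything else untouched
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

# Zygote-like wrapper around Enzyme.autodiff: numbers are marked Active,
# everything else is Duplicated with a freshly zeroed shadow
function gradient_ez(f, x...)
    args = []
    for x in x
        if x isa Number
            push!(args, Active(x))
        else
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    # ret[1] holds derivatives for Active args (nothing for Duplicated ones,
    # whose gradients were accumulated into the shadow args[i].dval)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end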

batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
device = Flux.cpu        # CPU training
# device = Flux.gpu      # GPU training

X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device

model = Chain(Dense(feature_size => 32, relu),
              Dense(32, num_classes)) |> device

opt_state = Flux.setup(Adam(1e-3), model)

loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))

function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end

report(0)
for epoch in 1:epochs
    g = gradient_ez(model -> loss(model, X, y), model)[1]     # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end

We should add tests for the loss functions. This one is failing:

gradient_ez(ŷ -> Flux.logitcrossentropy(ŷ, y), randn(Float32, num_classes, batch_size))
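One possible shape for such loss-function tests is to check gradients against a trusted reference. The sketch below uses only Base Julia and the Random stdlib: it re-implements logit cross-entropy (mean over columns of `-sum(y .* logsoftmax(ŷ))`, which is what `Flux.logitcrossentropy` computes with its default mean aggregation) and compares the closed-form gradient `(softmax(ŷ) .- y) ./ batch_size`, valid for one-hot targets, against central finite differences. The helpers `logitce`, `logitce_grad`, `fd_grad`, and `random_onehot` are hypothetical; in an actual Flux test one would presumably compare the Enzyme gradient of `Flux.logitcrossentropy` against a reference like this instead.

```julia
using Random

# Numerically stable column-wise log-softmax
function logsoftmax_cols(x::AbstractMatrix)
    m = maximum(x; dims=1)
    return x .- m .- log.(sum(exp.(x .- m); dims=1))
end

# Hypothetical reference loss: mean over columns of -sum(ytrue .* logsoftmax(yhat))
logitce(yhat, ytrue) = -sum(ytrue .* logsoftmax_cols(yhat)) / size(ytrue, 2)

# Closed-form gradient w.r.t. yhat, valid when every column of ytrue sums to 1
logitce_grad(yhat, ytrue) = (exp.(logsoftmax_cols(yhat)) .- ytrue) ./ size(ytrue, 2)

# Central finite differences, perturbing one entry at a time
function fd_grad(f, x::AbstractArray{Float64}; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

# Random one-hot targets: exactly one entry equal to 1 per column
function random_onehot(rng, nclasses, nbatch)
    y = zeros(nclasses, nbatch)
    for j in 1:nbatch
        y[rand(rng, 1:nclasses), j] = 1.0
    end
    return y
end

rng = MersenneTwister(0)
nclasses, nbatch = 10, 8
yhat = randn(rng, nclasses, nbatch)
ytrue = random_onehot(rng, nclasses, nbatch)
g_fd = fd_grad(x -> logitce(x, ytrue), yhat)
@show maximum(abs.(g_fd .- logitce_grad(yhat, ytrue)))
```

Float64 is used so that the finite-difference reference is accurate; the same comparison with Float32 inputs would need a looser tolerance.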

@wsmoses
Contributor Author

wsmoses commented May 12, 2024

Here is a modification of your code above that should be more performant, stable, etc. (closures are bad).

In any case, it still has the same issue; I will investigate.

# using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics

_make_zero!(x::AbstractArray) = x .= 0
_make_zero!(x) = x
make_zero!(model) = fmap(_make_zero!, model)

batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
device = Flux.cpu        # CPU training
# device = Flux.gpu      # GPU training

X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device

model = Chain(Dense(feature_size => 32, relu),
              Dense(32, num_classes)) |> device

opt_state = Flux.setup(Adam(1e-3), model)

loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))

function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end

report(0)
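g = deepcopy(model)   # gradient buffer with the same structure as the model
for epoch in 1:epochs
    make_zero!(g)     # Enzyme accumulates into the shadow, so reset it every step
    Enzyme.autodiff(Reverse, loss, Duplicated(model, g), Const(X), Const(y))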
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end

@wsmoses
Contributor Author

wsmoses commented May 12, 2024

Yeah, this works now with the NNlib type-stability fix FluxML/NNlib.jl#584.

@darsnack
Member

darsnack commented May 12, 2024

The previous "interface" was to import the corresponding AD package and just call e.g. Tracker.withgradient.

The most recent attempt was supposed to be DI.jl (DifferentiationInterface.jl), but its choice to focus on arrays and single inputs means we can't use it.

To me the best option would be a Flux.gradient (and Flux.withgradient) that uses ADTypes.jl (only to avoid further fragmentation). Alternatively, a small package that wraps Enzyme.autodiff + make_zero in a Zygote-like interface (similar to what's above).
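To make the "flag" idea concrete, here is a minimal sketch of choosing the AD backend by dispatching on a backend type, in the spirit of ADTypes.jl but without using it. `AbstractBackend`, `ForwardDifferences`, `DEFAULT_BACKEND`, and `my_withgradient` are all hypothetical names; a real version would add methods that call Zygote or Enzyme, whereas this one uses a finite-difference stand-in so that it runs with Base Julia alone.

```julia
# Hypothetical backend hierarchy; a real design might reuse ADTypes.jl types instead
abstract type AbstractBackend end

# Stand-in backend: forward finite differences on a vector input
struct ForwardDifferences <: AbstractBackend
    h::Float64
end
ForwardDifferences() = ForwardDifferences(1e-7)

# Global "flag" selecting the backend used when none is passed explicitly
const DEFAULT_BACKEND = Ref{AbstractBackend}(ForwardDifferences())

# Zygote-style return value: (val = f(x), grad = (dx,))
function my_withgradient(f, b::ForwardDifferences, x::AbstractVector{Float64})
    fx = f(x)
    dx = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += b.h
        dx[i] = (f(xp) - fx) / b.h
    end
    return (val = fx, grad = (dx,))
end

# Without an explicit backend, fall back to the global flag
my_withgradient(f, x) = my_withgradient(f, DEFAULT_BACKEND[], x)

res = my_withgradient(x -> sum(abs2, x), [1.0, 2.0, 3.0])
```

Under this design, switching backends would mean defining `my_withgradient` for another backend type and either passing an instance explicitly or assigning it to `DEFAULT_BACKEND[]`.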

But I suspect a dedicated doc page on using Enzyme + Flux will be easier to get through quickly.

@wsmoses
Contributor Author

wsmoses commented May 12, 2024

Sure, I think docs would be a great first start. I don't really know how to use Flux or where that would go best, so I'll leave that to you.

At the same time, if we're already doing API design: for training, it would be nice not to have to constantly reallocate the gradient buffer (with make_zero). I don't know whether you have an in-place zeroing function for models, but one would be highly beneficial here.
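A minimal sketch of such an in-place zeroing function, assuming the gradient buffer is a nested structure of `Tuple`s and `NamedTuple`s holding arrays (a stand-in for Flux layers). `zero_arrays!` is a hypothetical name, not an existing Flux or Enzyme function. Array leaves are overwritten with `fill!`, so no new memory is allocated; scalars are immutable values and are simply returned unchanged.

```julia
# Zero every numeric array leaf in place; recurse into tuples and named tuples
zero_arrays!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
zero_arrays!(x::Union{Tuple,NamedTuple}) = (foreach(zero_arrays!, x); x)
zero_arrays!(x) = x  # scalars, functions, etc. are left as they are

# A toy gradient buffer shaped like a two-layer Chain
grad = (layers = ((weight = ones(2, 3), bias = ones(2), σ = identity),
                  (weight = ones(1, 2), bias = ones(1), scale = 0.5)),)

w = grad.layers[1].weight   # keep a reference to check nothing was reallocated
zero_arrays!(grad)
```

Note that the scalar field `scale` keeps its value of `0.5`, since only array leaves are reset.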

@CarloLucibello
Member

it would be nice to not have to constantly reallocate the gradient buffer

I edited the code in your post to zero the gradient in-place. A slight problem in make_zero! is that it sets to zero the arrays but not the scalar field, so those are going to be accumulated. That can be fixed later, and in principle it is not even a problem since scalars are not updated by the optimizer.

@CarloLucibello
Member

CarloLucibello commented May 13, 2024

On GPU I get the following error:

error ┌ Warning: active variables passed by value to jl_new_task are not yet supported └ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59 ERROR: Enzyme compilation failed due to illegal type analysis. Current scope: ; Function Attrs: mustprogress willreturn define internal fastcc void @preprocess_julia_fill__33038({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Pointer, [-1,0,0,0]:Pointer, [-1,0,0,0,0]:Pointer, [-1,0,0,0,8]:Integer, [-1,0,0,0,16]:Pointer, [-1,0,0,16]:Integer, [-1,0,0,17]:Integer, [-1,0,0,18]:Integer, [-1,0,0,19]:Integer, [-1,0,0,20]:Integer, [-1,0,0,21]:Integer, [-1,0,0,22]:Integer, [-1,0,0,23]:Integer, [-1,0,0,24]:Integer, [-1,0,0,32]:Pointer, [-1,0,0,40]:Pointer, [-1,0,0,40,-1]:Integer, [-1,0,8]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="139959628162192" "enzymejl_parmtype_ref"="2" %0, float "enzyme_type"="{[-1]:Float@float}" "enzymejl_parmtype"="139978039813152" "enzymejl_parmtype_ref"="0" %1) unnamed_addr #657 !dbg !47671 { top: %2 = call {}*** @julia.get_pgcstack() %3 = call {}*** @julia.get_pgcstack() %4 = bitcast {}*** %2 to {}** %5 = getelementptr inbounds {}*, {}** %4, i64 -14 %6 = getelementptr inbounds {}*, {}** %5, i64 16 %7 = bitcast {}** %6 to i8** %8 = load i8*, i8** %7, align 8 %9 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) {} addrspace(10)* @julia.gc_alloc_obj({}** %5, i64 8, {} 
addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !615 call void @zeroType.457({} addrspace(10)* %9, i8 0, i64 8), !enzyme_zerostack !590 %phic1 = bitcast {} addrspace(10)* %9 to {} addrspace(10)* addrspace(10)*, !enzyme_caststack !590 %10 = bitcast {}*** %3 to {}** %11 = getelementptr inbounds {}*, {}** %10, i64 -14 %12 = getelementptr inbounds {}*, {}** %11, i64 16 %13 = bitcast {}** %12 to i8** %14 = load i8*, i8** %13, align 8 %15 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) {} addrspace(10)* @julia.gc_alloc_obj({}** %11, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}*) to {} addrspace(10)*)), !enzyme_fromstack !615 call void @zeroType.456({} addrspace(10)* %15, i8 0, i64 8), !enzyme_zerostack !590 %phic = bitcast {} addrspace(10)* %15 to {} addrspace(10)* addrspace(10)*, !enzyme_caststack !590 %phic19 = call noalias nonnull dereferenceable(1) dereferenceable_or_null(1) i8* @malloc(i64 1), !enzyme_fromstack !4822 %16 = call {}*** @julia.get_pgcstack() #658 store {} addrspace(10)* null, {} addrspace(10)* addrspace(10)* %phic1, align 8, !noalias !47672 call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %9, {} addrspace(10)* null) store {} addrspace(10)* null, {} addrspace(10)* addrspace(10)* %phic, align 8, !noalias !47672 call void ({} addrspace(10)*, ...) 
@julia.write_barrier({} addrspace(10)* %15, {} addrspace(10)* null) %current_task329 = getelementptr inbounds {}**, {}*** %16, i64 -14 %current_task3 = bitcast {}*** %current_task329 to {}** %ptls_field30 = getelementptr inbounds {}**, {}*** %16, i64 2 %17 = bitcast {}*** %ptls_field30 to i64*** %ptls_load3132 = load i64**, i64*** %17, align 8, !tbaa !591 %18 = getelementptr inbounds i64*, i64** %ptls_load3132, i64 2 %safepoint = load i64*, i64** %18, align 8, !tbaa !595 fence syncscope("singlethread") seq_cst call void @julia.safepoint(i64* %safepoint) #658, !dbg !47675 fence syncscope("singlethread") seq_cst %bitcast_coercion = bitcast float %1 to i32, !dbg !47676 %19 = addrspacecast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(11)*, !dbg !47678 %getfield = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %19 unordered, align 8, !dbg !47678, !tbaa !602, !alias.scope !606, !noalias !609, !nonnull !590, !dereferenceable !614, !align !615 %20 = addrspacecast {} addrspace(10)* %getfield to i8 addrspace(11)*, !dbg !47681 %21 = getelementptr inbounds i8, i8 addrspace(11)* %20, i64 8, !dbg !47681 %22 = load i8, i8 addrspace(11)* %21, align 8, !dbg !47681, !tbaa !602, !alias.scope !606, !noalias !609 %23 = and i8 %22, 1, !dbg !47681 %.not = icmp eq i8 %23, 0, !dbg !47681 br i1 %.not, label %L8, label %L5, !dbg !47682

L5: ; preds = %top
%24 = call fastcc [1 x {} addrspace(10)] @julia_ArgumentError_31098({} addrspace(10) nofree noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 139965165787312 to {}) to {} addrspace(10))) #659, !dbg !47683
%box = call noalias nonnull dereferenceable(8) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task3, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10))) #660, !dbg !47683
%25 = bitcast {} addrspace(10)* %box to [1 x {} addrspace(10)] addrspace(10), !dbg !47683
%26 = extractvalue [1 x {} addrspace(10)] %24, 0, !dbg !47683
%27 = getelementptr [1 x {} addrspace(10)
], [1 x {} addrspace(10)] addrspace(10) %25, i64 0, i64 0, !dbg !47683
store {} addrspace(10)* %26, {} addrspace(10)* addrspace(10)* %27, align 8, !dbg !47683, !tbaa !621, !alias.scope !606, !noalias !47684
%28 = addrspacecast {} addrspace(10)* %box to {} addrspace(12), !dbg !47683
call void @ijl_throw({} addrspace(12)
%28) #661, !dbg !47683
unreachable, !dbg !47683

L8: ; preds = %top
%29 = addrspacecast {} addrspace(10)* %getfield to {} addrspace(10)* addrspace(11), !dbg !47685
%getfield6 = load atomic {} addrspace(10)
, {} addrspace(10)* addrspace(11)* %29 unordered, align 8, !dbg !47685, !tbaa !602, !alias.scope !606, !noalias !609, !nonnull !590, !dereferenceable !628, !align !615
%30 = addrspacecast {} addrspace(10)* %getfield6 to i8 addrspace(11), !dbg !47687
%getfield_addr7 = getelementptr inbounds i8, i8 addrspace(11)
%30, i64 40, !dbg !47687
%31 = bitcast i8 addrspace(11)* %getfield_addr7 to {} addrspace(10)* addrspace(11), !dbg !47687
%getfield8 = load atomic {} addrspace(10)
, {} addrspace(10)* addrspace(11)* %31 unordered, align 8, !dbg !47687, !tbaa !602, !alias.scope !606, !noalias !609, !nonnull !590, !dereferenceable !615, !align !615
%32 = call token (...) @llvm.julia.gc_preserve_begin({} addrspace(10)* nonnull %getfield8) #658, !dbg !47689
%33 = addrspacecast {} addrspace(10)* %getfield8 to {} addrspace(11), !dbg !47690
%34 = call nonnull {}
@julia.pointer_from_objref({} addrspace(11)* noundef %33) #662, !dbg !47690
%ptr.i = bitcast {}* %34 to i64*, !dbg !47689
%rv.i = load atomic i64, i64* %ptr.i acquire, align 16, !dbg !47689
call void @llvm.julia.gc_preserve_end(token %32) #658, !dbg !47689
%.not33 = icmp eq i64 %rv.i, 0, !dbg !47692
br i1 %.not33, label %L17, label %L20, !dbg !47688

L17: ; preds = %L8
%35 = call fastcc [1 x {} addrspace(10)] @julia_ArgumentError_31098({} addrspace(10) nofree noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 139965165788400 to {}) to {} addrspace(10))) #658, !dbg !47693
%box11 = call noalias nonnull dereferenceable(8) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task3, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10))) #660, !dbg !47693
%36 = bitcast {} addrspace(10)* %box11 to [1 x {} addrspace(10)] addrspace(10), !dbg !47693
%37 = extractvalue [1 x {} addrspace(10)] %35, 0, !dbg !47693
%38 = getelementptr [1 x {} addrspace(10)
], [1 x {} addrspace(10)] addrspace(10) %36, i64 0, i64 0, !dbg !47693
store {} addrspace(10)* %37, {} addrspace(10)* addrspace(10)* %38, align 8, !dbg !47693, !tbaa !621, !alias.scope !606, !noalias !47684
%39 = addrspacecast {} addrspace(10)* %box11 to {} addrspace(12), !dbg !47693
call void @ijl_throw({} addrspace(12)
%39) #661, !dbg !47693
unreachable, !dbg !47693

L20: ; preds = %L8
%40 = addrspacecast {} addrspace(10)* %getfield6 to { {} addrspace(10), i64, i64, i8 } addrspace(11), !dbg !47694
%41 = getelementptr inbounds { {} addrspace(10), i64, i64, i8 }, { {} addrspace(10), i64, i64, i8 } addrspace(11)* %40, i64 0, i32 0, !dbg !47694
%42 = load {} addrspace(10), {} addrspace(10) addrspace(11)* %41, align 8, !dbg !47694, !tbaa !602, !alias.scope !606, !noalias !609
%43 = addrspacecast {} addrspace(10)* %42 to i8 addrspace(11), !dbg !47696
%44 = getelementptr inbounds i8, i8 addrspace(11)
%43, i64 8, !dbg !47696
%45 = load i8, i8 addrspace(11)* %44, align 8, !dbg !47696, !tbaa !602, !alias.scope !606, !noalias !609
%46 = and i8 %45, 1, !dbg !47696
%.not34 = icmp eq i8 %46, 0, !dbg !47696
br i1 %.not34, label %L73, label %L27, !dbg !47698

L27: ; preds = %L20
%47 = call fastcc nonnull align 8 {} addrspace(10)* @julia_context__32398({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %42) #658, !dbg !47700
store volatile {} addrspace(10)* %42, {} addrspace(10)* addrspace(10)* %phic, align 8, !dbg !47701, !noalias !47672
call void ({} addrspace(10), ...) @julia.write_barrier({} addrspace(10) %15, {} addrspace(10)* %42), !dbg !47701
store volatile {} addrspace(10)* %47, {} addrspace(10)* addrspace(10)* %phic1, align 8, !dbg !47701, !noalias !47672
call void ({} addrspace(10), ...) @julia.write_barrier({} addrspace(10) %9, {} addrspace(10)* %47), !dbg !47701
store volatile i8 0, i8* %phic19, align 1, !dbg !47701, !tbaa !774, !alias.scope !776, !noalias !47702
%48 = call i64 @ijl_excstack_state() #658, !dbg !47701
%49 = call i32 @julia.except_enter() #663, !dbg !47701
%50 = icmp eq i32 %49, 0, !dbg !47701
br i1 %50, label %try, label %L46, !dbg !47701

L46: ; preds = %L27
%phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0. = load volatile {} addrspace(10), {} addrspace(10) addrspace(10)* %phic, align 8, !dbg !47703, !nonnull !590
%phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0. = load volatile {} addrspace(10), {} addrspace(10) addrspace(10)* %phic1, align 8, !dbg !47703, !nonnull !590
%phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0. = load volatile i8, i8* %phic19, align 1, !dbg !47703
call void @ijl_pop_handler(i32 noundef 1) #658, !dbg !47703
%51 = and i8 %phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0.phic19.0., 1, !dbg !47703
%phi.cast = icmp ne i8 %51, 0, !dbg !47703
br label %L51, !dbg !47703

L51: ; preds = %try, %L46
%value_phi = phi {} addrspace(10)* [ %42, %try ], [ %phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0.phic.0., %L46 ]
%value_phi15 = phi {} addrspace(10)* [ %47, %try ], [ %phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0., %L46 ]
%value_phi17 = phi i1 [ true, %try ], [ %phi.cast, %L46 ]
%52 = addrspacecast {} addrspace(10)* %value_phi15 to {} addrspace(11), !dbg !47704
%53 = icmp eq {} addrspace(11)
%52, addrspacecast ({}* inttoptr (i64 139978194116616 to {}) to {} addrspace(11)), !dbg !47704
%54 = addrspacecast {} addrspace(10)* %value_phi to {} addrspace(11)*
%55 = icmp eq {} addrspace(11)* %52, %54
%or.cond = select i1 %53, i1 true, i1 %55, !dbg !47704
br i1 %or.cond, label %L67, label %L62, !dbg !47704

L62: ; preds = %L51
%56 = addrspacecast {} addrspace(10)* %value_phi15 to i8 addrspace(11), !dbg !47705
%57 = getelementptr inbounds i8, i8 addrspace(11)
%56, i64 8, !dbg !47705
%58 = load i8, i8 addrspace(11)* %57, align 8, !dbg !47705, !tbaa !846, !alias.scope !606, !noalias !609
%59 = and i8 %58, 1, !dbg !47705
%.not35 = icmp eq i8 %59, 0, !dbg !47705
br i1 %.not35, label %L67, label %L65, !dbg !47704

L65: ; preds = %L62
%60 = call fastcc nonnull {} addrspace(10)* @julia_context__32398({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %value_phi15) #658, !dbg !47707
br label %L67, !dbg !47707

L67: ; preds = %L65, %L62, %L51
br i1 %50, label %L71, label %L69, !dbg !47707

L69: ; preds = %L67
call fastcc void @julia_rethrow_31152() #661, !dbg !47707
unreachable, !dbg !47707

L71: ; preds = %L67
br i1 %value_phi17, label %ok, label %err, !dbg !47707

L73: ; preds = %L20
call fastcc void @julia_error_31187({} addrspace(10)* nofree noundef nonnull align 32 addrspacecast ({}* inttoptr (i64 139962719163168 to {}) to {} addrspace(10))) #661, !dbg !47708
unreachable, !dbg !47708

try: ; preds = %L27
%61 = call fastcc i64 @julia_unsafe_convert_32014({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(40) %0) #658, !dbg !47709
%62 = addrspacecast {} addrspace(10)* %0 to i8 addrspace(11), !dbg !47713
%63 = getelementptr inbounds i8, i8 addrspace(11)
%62, i64 24, !dbg !47713
%aggregate_load_box.sroa.0.0..sroa_idx = bitcast i8 addrspace(11)* %63 to i64 addrspace(11), !dbg !47713
%aggregate_load_box.sroa.0.0.copyload = load i64, i64 addrspace(11)
%aggregate_load_box.sroa.0.0..sroa_idx, align 8, !dbg !47713, !tbaa !710, !alias.scope !711, !noalias !47716
%aggregate_load_box.sroa.2.0..sroa_idx25 = getelementptr inbounds i8, i8 addrspace(11)* %62, i64 32, !dbg !47713
%64 = bitcast i8 addrspace(11)* %aggregate_load_box.sroa.2.0..sroa_idx25 to i64 addrspace(11), !dbg !47713
%aggregate_load_box.sroa.2.0.copyload = load i64, i64 addrspace(11)
%64, align 8, !dbg !47713, !tbaa !710, !alias.scope !711, !noalias !47716
%65 = mul i64 %aggregate_load_box.sroa.2.0.copyload, %aggregate_load_box.sroa.0.0.copyload, !dbg !47717
call fastcc void @julia_set__33047(i64 zeroext %61, i32 zeroext %bitcast_coercion, i64 signext %65) #658, !dbg !47712
store volatile i8 1, i8* %phic19, align 1, !dbg !47703, !tbaa !774, !alias.scope !776, !noalias !47702
call void @ijl_pop_handler(i32 noundef 1) #658, !dbg !47703
br label %L51, !dbg !47703

err: ; preds = %L71
call void @ijl_undefined_var_error({} addrspace(12)* noundef addrspacecast ({}* inttoptr (i64 139978194630336 to {}) to {} addrspace(12))) #661, !dbg !47707
unreachable, !dbg !47707

ok: ; preds = %L71
ret void, !dbg !47699
}

Type analysis state:

%current_task3 = bitcast {}*** %current_task329 to {}: {}, intvals: {}
%bitcast_coercion = bitcast float %1 to i32, !dbg !603: {[-1]:Integer}, intvals: {}
{} addrspace(10)* addrspacecast ({}* inttoptr (i64 139965165787312 to {}) to {} addrspace(10)): {[-1]:Anything}, intvals: {}
{}* inttoptr (i64 139965165787312 to {}): {[-1]:Anything}, intvals: {}
%value_phi15 = phi {} addrspace(10)
[ %47, %try ], [ %phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0.phic1.0., %L46 ]: {[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Integer, [-1,16]:Pointer}, intvals: {}
%24 = call fastcc [1 x {} addrspace(10)] @julia_ArgumentError_31098({} addrspace(10) nofree noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 139965165787312 to {}) to {} addrspace(10))) #659, !dbg !630: {[-1]:Pointer}, intvals: {}
%box11 = call noalias nonnull dereferenceable(8) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({}
nonnull %current_task3, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10))) #660, !dbg !650: {[-1,-1]:Pointer}, intvals: {}
%phic1 = bitcast {} addrspace(10)* %9 to {} addrspace(10)* addrspace(10), !enzyme_caststack !590: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
%13 = bitcast {}** %12 to i8**: {[-1]:Pointer}, intvals: {}
%17 = bitcast {}
** %ptls_field30 to i64***: {[-1]:Pointer}, intvals: {}
%ptls_load3132 = load i64**, i64*** %17, align 8, !tbaa !596: {}, intvals: {}
%15 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) {} addrspace(10)* @julia.gc_alloc_obj({}** %11, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}) to {} addrspace(10))), !enzyme_fromstack !591: {}, intvals: {}
%60 = call fastcc nonnull {} addrspace(10)* @julia_context__32398({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %value_phi15) #658, !dbg !673: {}, intvals: {}
%11 = getelementptr inbounds {}, {}** %10, i64 -14: {}, intvals: {}
{}
inttoptr (i64 139962719163168 to {}): {[-1]:Anything}, intvals: {}
{} addrspace(10)
addrspacecast ({}* inttoptr (i64 139962719163168 to {}) to {} addrspace(10)): {[-1]:Anything}, intvals: {}
i64 8: {[-1]:Integer}, intvals: {8,}
%9 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) {} addrspace(10)* @julia.gc_alloc_obj({}** %5, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}) to {} addrspace(10))), !enzyme_fromstack !591: {}, intvals: {}
{}* inttoptr (i64 139961738084176 to {}): {[-1]:Anything}, intvals: {}
%4 = bitcast {}
** %2 to {}: {}, intvals: {}
%61 = call fastcc i64 @julia_unsafe_convert_32014({} addrspace(10)* nocapture noundef nonnull readonly align 8 dereferenceable(40) %0) #658, !dbg !675: {}, intvals: {}
%2 = call {}
* @julia.get_pgcstack(): {}, intvals: {}
%5 = getelementptr inbounds {}, {}** %4, i64 -14: {}, intvals: {}
%6 = getelementptr inbounds {}
, {}** %5, i64 16: {}, intvals: {}
%12 = getelementptr inbounds {}, {}** %11, i64 16: {}, intvals: {}
%14 = load i8
, i8** %13, align 8: {}, intvals: {}
%box = call noalias nonnull dereferenceable(8) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task3, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10))) #660, !dbg !630: {[-1,-1]:Pointer}, intvals: {}
%safepoint = load i64*, i64** %18, align 8, !tbaa !600: {}, intvals: {}
{} addrspace(10)* addrspacecast ({}* inttoptr (i64 139978038671616 to {}) to {} addrspace(10)): {[-1]:Anything}, intvals: {}
{}* inttoptr (i64 139978038671616 to {}): {[-1]:Anything}, intvals: {}
%35 = call fastcc [1 x {} addrspace(10)
] @julia_ArgumentError_31098({} addrspace(10)* nofree noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 139965165788400 to {}) to {} addrspace(10))) #658, !dbg !650: {[-1]:Pointer}, intvals: {}
%phic19 = call noalias nonnull dereferenceable(1) dereferenceable_or_null(1) i8* @malloc(i64 1), !enzyme_fromstack !592: {[-1]:Pointer}, intvals: {}
%7 = bitcast {}** %6 to i8**: {[-1]:Pointer}, intvals: {}
%8 = load i8*, i8** %7, align 8: {}, intvals: {}
%16 = call {}*** @julia.get_pgcstack() #658: {}, intvals: {}
%phic = bitcast {} addrspace(10)* %15 to {} addrspace(10)* addrspace(10), !enzyme_caststack !590: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
%3 = call {}
** @julia.get_pgcstack(): {}, intvals: {}
%42 = load {} addrspace(10), {} addrspace(10) addrspace(11)* %41, align 8, !dbg !651, !tbaa !613, !alias.scope !617, !noalias !620: {[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Integer, [-1,16]:Pointer}, intvals: {}
{} addrspace(10)* addrspacecast ({}* inttoptr (i64 139965165788400 to {}) to {} addrspace(10)): {[-1]:Anything}, intvals: {}
{}* inttoptr (i64 139965165788400 to {}): {[-1]:Anything}, intvals: {}
%18 = getelementptr inbounds i64
, i64** %ptls_load3132, i64 2: {[-1]:Pointer}, intvals: {}
%ptls_field30 = getelementptr inbounds {}, {}* %16, i64 2: {}, intvals: {}
{} addrspace(10)* null: {[-1]:Pointer, [-1,-1]:Anything}, intvals: {0,}
%65 = mul i64 %aggregate_load_box.sroa.2.0.copyload, %aggregate_load_box.sroa.0.0.copyload, !dbg !691: {[-1]:Integer}, intvals: {}
%10 = bitcast {}*** %3 to {}: {}, intvals: {}
{} addrspace(10)* addrspacecast ({}* inttoptr (i64 139961738084176 to {}) to {} addrspace(10)): {[-1]:Anything}, intvals: {}
{} addrspace(10)* %0: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Pointer, [-1,0,0,0]:Pointer, [-1,0,0,0,0]:Pointer, [-1,0,0,0,8]:Integer, [-1,0,0,0,16]:Pointer, [-1,0,0,16]:Integer, [-1,0,0,17]:Integer, [-1,0,0,18]:Integer, [-1,0,0,19]:Integer, [-1,0,0,20]:Integer, [-1,0,0,21]:Integer, [-1,0,0,22]:Integer, [-1,0,0,23]:Integer, [-1,0,0,24]:Integer, [-1,0,0,32]:Pointer, [-1,0,0,40]:Pointer, [-1,0,0,40,-1]:Integer, [-1,0,8]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}, intvals: {}
float %1: {[-1]:Float@float}, intvals: {}
%47 = call fastcc nonnull align 8 {} addrspace(10)* @julia_context__32398({} addrspace(10)* noundef nonnull align 8 dereferenceable(24) %42) #658, !dbg !662: {}, intvals: {}
%current_task329 = getelementptr inbounds {}
, {}*** %16, i64 -14: {}, intvals: {}

Illegal updateAnalysis prev:{[-1]:Integer} new: {[-1]:Float@float}
val: %bitcast_coercion = bitcast float %1 to i32, !dbg !603 origin= %bitcast_coercion = bitcast float %1 to i32, !dbg !603
MethodInstance for fill!(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float32)

Caused by:
Stacktrace:
[1] reinterpret
@ ./essentials.jl:581
[2] fill!
@ ~/.julia/packages/CUDA/jdJ7Z/src/array.jl:829

Stacktrace:
[1] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:1690
[2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/2FwRI/src/api.jl:154
[3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:3177
[4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5070
[5] codegen
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:4477 [inlined]
[6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755
[7] _thunk
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5755 [inlined]
[8] cached_compilation
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5793 [inlined]
[9] (::Enzyme.Compiler.var"#554#555"{…})(ctx::LLVM.Context)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5859
[10] JuliaContext(f::Enzyme.Compiler.var"#554#555"{…}; kwargs::@kwargs{})
@ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
[11] JuliaContext(f::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
[12] #s2027#553
@ ~/.julia/packages/Enzyme/2FwRI/src/compiler.jl:5811 [inlined]
[13]
@ Enzyme.Compiler ./none:0
[14] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[15] autodiff
@ ~/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:286 [inlined]
[16] autodiff
@ ~/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:315 [inlined]
[17] autodiff(::ReverseMode{…}, ::typeof(loss), ::Duplicated{…}, ::Const{…}, ::Const{…})
@ Enzyme ~/.julia/packages/Enzyme/2FwRI/src/Enzyme.jl:300
[18] top-level scope
@ ~/juliadev/Flux/mlp.jl:37
Some type information was truncated. Use show(err) to see complete types.

@wsmoses
Contributor Author

wsmoses commented May 13, 2024

You'll need JuliaGPU/CUDA.jl#2371 and then JuliaPackaging/Yggdrasil#8666. It then hits a cublasscal issue, which I stopped investigating to go get dinner.

@mcabbott
Member

I think the basic interface needed is a nice gradient function.

Enzyme's own gradient should now do this, as make_zero understands nested structures:

julia> sh = [1f0, 2f0]; nt = (a=sh, b=sh, c=copy(sh));

julia> Enzyme.gradient(Reverse, x -> sum(map(sum, x)), nt)
(a = Float32[2.0, 2.0], b = Float32[2.0, 2.0], c = Float32[1.0, 1.0])

(jl_o1ZBlk) pkg> st Enzyme
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_o1ZBlk/Project.toml`
  [7da242da] Enzyme v0.12.4

The above example doesn't work for me, but I believe function gradient_ez(f, x...) can be deleted, leaving just this:

for epoch in 1:epochs
    g = Enzyme.gradient(Reverse, m -> loss(m, X, y), model) # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end
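In the `Enzyme.gradient` example earlier in this comment, `a` and `b` both come back as `Float32[2.0, 2.0]` because `nt.a` and `nt.b` are the same array: its shadow is a single buffer that receives both contributions. The pure-Julia sketch below imitates that behaviour with an `IdDict` keyed on object identity. `shadow_of`, `make_shadow`, and `accumulate_sum_of_sums!` are hypothetical helpers for illustration only, not Enzyme's implementation.

```julia
# Build a zero "shadow" for a NamedTuple, allocating one buffer per *distinct* array
shadow_of(x::AbstractArray, seen::IdDict) = get!(() -> zero(x), seen, x)
shadow_of(x::NamedTuple, seen::IdDict) = map(v -> shadow_of(v, seen), x)
make_shadow(x) = shadow_of(x, IdDict())

# Hand-written reverse pass for f(x) = sum(map(sum, x)):
# each field contributes a gradient of ones, accumulated into its shadow
function accumulate_sum_of_sums!(dx::NamedTuple)
    for d in dx
        d .+= 1
    end
    return dx
end

sh = [1f0, 2f0]; nt = (a=sh, b=sh, c=copy(sh))
dx = accumulate_sum_of_sums!(make_shadow(nt))     # a and b share one buffer

# A naive map(zero, nt) gives every field its own buffer,
# so a and b would each receive only one contribution
naive = accumulate_sum_of_sums!(map(zero, nt))
```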

A slight problem in make_zero! is that it sets to zero the arrays but not the scalar field, so those are going to be accumulated. That can be fixed later, and in principle it is not even a problem since scalars are not updated by the optimizer.

Right. For those coming from Zygote, it's slightly odd that the gradient contains numbers for non-differentiable things. But I believe Optimisers.jl's idea of what parameters can be updated is narrow enough that it will only use true gradient numbers from Enzyme.jl.
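A rough illustration of why those extra numbers can be harmless: if the update rule only touches floating-point arrays, whatever ends up in the scalar or integer fields of the gradient is never read. `is_trainable_leaf` and `sgd_update!` are hypothetical and only loosely modelled on the idea above; Optimisers.jl's actual rules are more involved.

```julia
# Only floating-point arrays count as trainable leaves in this sketch
is_trainable_leaf(x::AbstractArray{<:AbstractFloat}) = true
is_trainable_leaf(x) = false

# Plain gradient descent over a nested NamedTuple "model"
function sgd_update!(model::NamedTuple, grad::NamedTuple; lr=0.1)
    for k in keys(model)
        m, g = model[k], grad[k]
        if m isa NamedTuple
            sgd_update!(m, g; lr=lr)      # recurse into sub-layers
        elseif is_trainable_leaf(m)
            m .-= lr .* g                 # in-place update of a parameter array
        end
        # anything else (scalars, integer arrays, functions) is skipped,
        # so whatever value sits in grad[k] is ignored
    end
    return model
end

model = (weight = [1.0, 2.0], scale = 3.0, inner = (bias = [0.5], idx = [1, 2]))
grad  = (weight = [1.0, 1.0], scale = 100.0, inner = (bias = [1.0], idx = [7, 7]))
sgd_update!(model, grad)
```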

@wsmoses
Contributor Author

wsmoses commented May 14, 2024

This should be resolved by #2446

Like I say in that PR

"""
I have no opinions on the design/API and I will give this PR to you all to make it however you feel (and I will go back to staring at CUDA).

I will note that perf atm is unclear and is worth investigating. However, before we do that, having a good way to run/test things is critical, hence this PR.
"""

@wsmoses
Contributor Author

wsmoses commented May 14, 2024

Edit: I accidentally reran on CPU; please ignore the output below.

CUDA works on the simple example now. It requires either CUDA#master (where the needed branches are already merged) or, hopefully, a backport release of CUDA.jl via JuliaGPU/CUDA.jl#2375, plus an Enzyme_jll bump.

wmoses@beast:~/git/Flux.jl ((HEAD detached at origin/master)) $ cat orig.jl 
using CUDA # for GPU training
using Flux, Enzyme
using Random, Statistics

_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

function gradient_ez(f, x...)
    args = []
    for xi in x
        if xi isa Number
            push!(args, Active(xi))
        else
            push!(args, Duplicated(xi, make_zero(xi)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end

batch_size = 128
feature_size = 784
num_classes = 10
epochs = 100
# device = Flux.cpu        # CPU training
device = Flux.gpu      # GPU training

X = randn(Float32, feature_size, batch_size) |> device
y = Flux.onehotbatch(rand(1:num_classes, batch_size), 1:num_classes) |> device

model = Chain(Dense(feature_size => 32, relu),
              Dense(32 => num_classes)) |> device

opt_state = Flux.setup(Adam(1e-3), model)

loss(model, x, y) = Flux.logitcrossentropy(model(x), y)
accuracy(model, x, y) = mean(Flux.onecold(model(x)) .== Flux.onecold(y))

function report(epoch)
    @info "Epoch: $epoch" loss=loss(model, X, y) accuracy=accuracy(model, X, y)
end

report(0)
for epoch in 1:epochs
    g = gradient_ez(model -> loss(model, X, y), model)[1]     # Enzyme gradient
    # g = Flux.gradient(model -> loss(model, X, y), model)[1] # Zygote gradient
    Flux.update!(opt_state, model, g)
    report(epoch)
end
wmoses@beast:~/git/Flux.jl ((HEAD detached at origin/master)) $ ~/git/Enzyme.jl/julia-1.10.2/bin/julia --project orig.jl 
┌ Warning: Package cuDNN not found in current path.
│ - Run `import Pkg; Pkg.add("cuDNN")` to install the cuDNN package, then restart julia.
│ - If cuDNN is not installed, some Flux functionalities will not be available when running on the GPU.
└ @ FluxCUDAExt ~/git/Flux.jl/ext/FluxCUDAExt/FluxCUDAExt.jl:57
┌ Info: Epoch: 0
│   loss = 2.7904227f0
└   accuracy = 0.125
┌ Info: Epoch: 1
│   loss = 2.5142982f0
└   accuracy = 0.15625
┌ Info: Epoch: 2
│   loss = 2.2610319f0
└   accuracy = 0.203125
┌ Info: Epoch: 3
│   loss = 2.029134f0
└   accuracy = 0.28125
┌ Info: Epoch: 4
│   loss = 1.8172197f0
└   accuracy = 0.3515625
┌ Info: Epoch: 5
│   loss = 1.6268556f0
└   accuracy = 0.4375
┌ Info: Epoch: 6
│   loss = 1.4554112f0
└   accuracy = 0.546875
┌ Info: Epoch: 7
│   loss = 1.3014916f0
└   accuracy = 0.6640625
┌ Info: Epoch: 8
│   loss = 1.163165f0
└   accuracy = 0.7890625
┌ Info: Epoch: 9
│   loss = 1.0413302f0
└   accuracy = 0.8515625
┌ Info: Epoch: 10
│   loss = 0.93555194f0
└   accuracy = 0.8515625
┌ Info: Epoch: 11
│   loss = 0.84206563f0
└   accuracy = 0.8828125
┌ Info: Epoch: 12
│   loss = 0.7600569f0
└   accuracy = 0.90625
┌ Info: Epoch: 13
│   loss = 0.6874082f0
└   accuracy = 0.921875
┌ Info: Epoch: 14
│   loss = 0.6230737f0
└   accuracy = 0.9296875
┌ Info: Epoch: 15
│   loss = 0.5663827f0
└   accuracy = 0.9609375
┌ Info: Epoch: 16
│   loss = 0.5165455f0
└   accuracy = 0.96875
┌ Info: Epoch: 17
│   loss = 0.4719535f0
└   accuracy = 0.96875
┌ Info: Epoch: 18
│   loss = 0.4319139f0
└   accuracy = 0.9765625
┌ Info: Epoch: 19
│   loss = 0.39577293f0
└   accuracy = 0.984375
┌ Info: Epoch: 20
│   loss = 0.36347917f0
└   accuracy = 0.984375
┌ Info: Epoch: 21
│   loss = 0.33449084f0
└   accuracy = 0.9921875
┌ Info: Epoch: 22
│   loss = 0.30846184f0
└   accuracy = 0.9921875
┌ Info: Epoch: 23
│   loss = 0.28476223f0
└   accuracy = 0.9921875
┌ Info: Epoch: 24
│   loss = 0.26318714f0
└   accuracy = 1.0
┌ Info: Epoch: 25
│   loss = 0.24353352f0
└   accuracy = 1.0
┌ Info: Epoch: 26
│   loss = 0.22557218f0
└   accuracy = 1.0
┌ Info: Epoch: 27
│   loss = 0.20921068f0
└   accuracy = 1.0
┌ Info: Epoch: 28
│   loss = 0.19429381f0
└   accuracy = 1.0
┌ Info: Epoch: 29
│   loss = 0.18054952f0
└   accuracy = 1.0
┌ Info: Epoch: 30
│   loss = 0.16796987f0
└   accuracy = 1.0
┌ Info: Epoch: 31
│   loss = 0.1563463f0
└   accuracy = 1.0
┌ Info: Epoch: 32
│   loss = 0.14567412f0
└   accuracy = 1.0
┌ Info: Epoch: 33
│   loss = 0.13588753f0
└   accuracy = 1.0
┌ Info: Epoch: 34
│   loss = 0.12687433f0
└   accuracy = 1.0
┌ Info: Epoch: 35
│   loss = 0.11857266f0
└   accuracy = 1.0
┌ Info: Epoch: 36
│   loss = 0.11093213f0
└   accuracy = 1.0
┌ Info: Epoch: 37
│   loss = 0.103871785f0
└   accuracy = 1.0
┌ Info: Epoch: 38
│   loss = 0.09736837f0
└   accuracy = 1.0
┌ Info: Epoch: 39
│   loss = 0.09138645f0
└   accuracy = 1.0
┌ Info: Epoch: 40
│   loss = 0.08586908f0
└   accuracy = 1.0
┌ Info: Epoch: 41
│   loss = 0.080786735f0
└   accuracy = 1.0
┌ Info: Epoch: 42
│   loss = 0.07610354f0
└   accuracy = 1.0
┌ Info: Epoch: 43
│   loss = 0.07179588f0
└   accuracy = 1.0
┌ Info: Epoch: 44
│   loss = 0.06783663f0
└   accuracy = 1.0
┌ Info: Epoch: 45
│   loss = 0.06419177f0
└   accuracy = 1.0
┌ Info: Epoch: 46
│   loss = 0.060845155f0
└   accuracy = 1.0
┌ Info: Epoch: 47
│   loss = 0.057761367f0
└   accuracy = 1.0
┌ Info: Epoch: 48
│   loss = 0.0549154f0
└   accuracy = 1.0
┌ Info: Epoch: 49
│   loss = 0.05228231f0
└   accuracy = 1.0
┌ Info: Epoch: 50
│   loss = 0.049845647f0
└   accuracy = 1.0
┌ Info: Epoch: 51
│   loss = 0.047589153f0
└   accuracy = 1.0
┌ Info: Epoch: 52
│   loss = 0.045498513f0
└   accuracy = 1.0
┌ Info: Epoch: 53
│   loss = 0.04355742f0
└   accuracy = 1.0
┌ Info: Epoch: 54
│   loss = 0.04175187f0
└   accuracy = 1.0
┌ Info: Epoch: 55
│   loss = 0.04007356f0
└   accuracy = 1.0
┌ Info: Epoch: 56
│   loss = 0.038507923f0
└   accuracy = 1.0
┌ Info: Epoch: 57
│   loss = 0.037045095f0
└   accuracy = 1.0
┌ Info: Epoch: 58
│   loss = 0.035674226f0
└   accuracy = 1.0
┌ Info: Epoch: 59
│   loss = 0.034392048f0
└   accuracy = 1.0
┌ Info: Epoch: 60
│   loss = 0.033194654f0
└   accuracy = 1.0
┌ Info: Epoch: 61
│   loss = 0.032058075f0
└   accuracy = 1.0
┌ Info: Epoch: 62
│   loss = 0.030996136f0
└   accuracy = 1.0
┌ Info: Epoch: 63
│   loss = 0.02999451f0
└   accuracy = 1.0
┌ Info: Epoch: 64
│   loss = 0.029050402f0
└   accuracy = 1.0
┌ Info: Epoch: 65
│   loss = 0.02815985f0
└   accuracy = 1.0
┌ Info: Epoch: 66
│   loss = 0.027319008f0
└   accuracy = 1.0
┌ Info: Epoch: 67
│   loss = 0.02652272f0
└   accuracy = 1.0
┌ Info: Epoch: 68
│   loss = 0.025767544f0
└   accuracy = 1.0
┌ Info: Epoch: 69
│   loss = 0.025051065f0
└   accuracy = 1.0
┌ Info: Epoch: 70
│   loss = 0.024369944f0
└   accuracy = 1.0
┌ Info: Epoch: 71
│   loss = 0.023721226f0
└   accuracy = 1.0
┌ Info: Epoch: 72
│   loss = 0.023103705f0
└   accuracy = 1.0
┌ Info: Epoch: 73
│   loss = 0.022514593f0
└   accuracy = 1.0
┌ Info: Epoch: 74
│   loss = 0.021952922f0
└   accuracy = 1.0
┌ Info: Epoch: 75
│   loss = 0.021417053f0
└   accuracy = 1.0
┌ Info: Epoch: 76
│   loss = 0.020906389f0
└   accuracy = 1.0
┌ Info: Epoch: 77
│   loss = 0.0204159f0
└   accuracy = 1.0
┌ Info: Epoch: 78
│   loss = 0.01994732f0
└   accuracy = 1.0
┌ Info: Epoch: 79
│   loss = 0.01949887f0
└   accuracy = 1.0
┌ Info: Epoch: 80
│   loss = 0.01906871f0
└   accuracy = 1.0
┌ Info: Epoch: 81
│   loss = 0.018656129f0
└   accuracy = 1.0
┌ Info: Epoch: 82
│   loss = 0.018260362f0
└   accuracy = 1.0
┌ Info: Epoch: 83
│   loss = 0.017879806f0
└   accuracy = 1.0
┌ Info: Epoch: 84
│   loss = 0.017513612f0
└   accuracy = 1.0
┌ Info: Epoch: 85
│   loss = 0.017161498f0
└   accuracy = 1.0
┌ Info: Epoch: 86
│   loss = 0.01682241f0
└   accuracy = 1.0
┌ Info: Epoch: 87
│   loss = 0.016495718f0
└   accuracy = 1.0
┌ Info: Epoch: 88
│   loss = 0.016181245f0
└   accuracy = 1.0
┌ Info: Epoch: 89
│   loss = 0.015877243f0
└   accuracy = 1.0
┌ Info: Epoch: 90
│   loss = 0.0155781405f0
└   accuracy = 1.0
┌ Info: Epoch: 91
│   loss = 0.01528422f0
└   accuracy = 1.0
┌ Info: Epoch: 92
│   loss = 0.014997441f0
└   accuracy = 1.0
┌ Info: Epoch: 93
│   loss = 0.014718127f0
└   accuracy = 1.0
┌ Info: Epoch: 94
│   loss = 0.014446221f0
└   accuracy = 1.0
┌ Info: Epoch: 95
│   loss = 0.014181806f0
└   accuracy = 1.0
┌ Info: Epoch: 96
│   loss = 0.013925277f0
└   accuracy = 1.0
┌ Info: Epoch: 97
│   loss = 0.013677116f0
└   accuracy = 1.0
┌ Info: Epoch: 98
│   loss = 0.013437184f0
└   accuracy = 1.0
┌ Info: Epoch: 99
│   loss = 0.013204632f0
└   accuracy = 1.0
┌ Info: Epoch: 100
│   loss = 0.012979296f0
└   accuracy = 1.0
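The indexing in gradient_ez (ret[1][i] for numbers, args[i].dval for everything else) follows from what autodiff returns in ReverseWithPrimal mode. A minimal sketch with one scalar and one array argument:

```julia
using Enzyme

f(a, x) = a * sum(x)
x = [1.0, 2.0, 3.0]
dx = zero(x)

# Active scalar: its derivative is returned. Duplicated array: its derivative
# is accumulated into the shadow dx instead.
ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, Active(2.0), Duplicated(x, dx))

ret[1]   # (6.0, nothing): ∂f/∂a = sum(x); nothing for the Duplicated argument
ret[2]   # 12.0: the primal value f(2.0, x)
dx       # [2.0, 2.0, 2.0]: ∂f/∂x = a for every element
```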

@mashu

mashu commented May 16, 2024

@CarloLucibello this gradient_ez is very useful, thanks! Would it also be possible to have an option to run Enzyme from within Zygote? Or an example, similar to gradient_ez, showing how to add a Zygote.@adjoint so that Enzyme is used for one custom Flux layer while the rest is still differentiated by Zygote?

I am thinking of some way we could transition smoothly without switching completely to one or the other.
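One way to mix the two is a ChainRulesCore.rrule whose pullback calls Enzyme; Zygote picks up rrules defined this way (the mechanism it recommends over Zygote.@adjoint). A minimal sketch, where inner is a hypothetical stand-in for a custom layer's forward computation:

```julia
using ChainRulesCore, Enzyme, Zygote

# Hypothetical stand-in for the part we want Enzyme to differentiate
inner(x::Vector{Float64}) = sum(abs2, x) / 2

function ChainRulesCore.rrule(::typeof(inner), x::Vector{Float64})
    y = inner(x)
    function inner_pullback(ȳ)
        dx = zero(x)
        # Enzyme computes ∂inner/∂x into dx; Zygote handles everything outside `inner`
        Enzyme.autodiff(Reverse, inner, Active, Duplicated(x, dx))
        return NoTangent(), unthunk(ȳ) .* dx
    end
    return y, inner_pullback
end

outer(x) = 3 * inner(x) + sum(x)
Zygote.gradient(outer, [1.0, 2.0])[1]   # 3 .* x .+ 1 = [4.0, 7.0]
```

This only covers functions of plain arrays; wrapping a whole layer struct would need the struct's fields passed as Duplicated as well.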

@gdalle

gdalle commented Jun 18, 2024

The most recent attempt was supposed to be DI.jl, but the choice to focus on arrays and single inputs means we can't use it.

@darsnack I'd actually love to revisit the dream of DI + Flux one of these days.

  • For multiple inputs, I think I see a way to support additional constant inputs without too much pain (Multiple arguments / activities? JuliaDiff/DifferentiationInterface.jl#311). Apparently it's what you need for e.g. X and y in training.
  • For array-only, the trouble is not supporting general structs, it's testing them. We've had this discussion together, and I don't want to commit to something that a) doesn't work for every backend and b) will probably be undertested because arbitrary structs can be, well, arbitrary. In my view, non-arrays cannot be in the DI API because there will be plenty of cases that fail, and it's very hard to say which ones ahead of time.

To me the best option would be a Flux.gradient (and Flux.withgradient) that uses ADTypes.jl (only to avoid further fragmentation). Alternatively, a small package that wraps Enzyme.autodiff + make_zero in a Zygote-like interface (similar to what's above).

Why not create a package named DifferentiationInterfaceForFlux or something, which relies on DI but tests compatibility with Flux layers and makes it part of its API? In other words, if I change something in DI that removes compatibility with Flux layers, the glue package could still be frozen to its current version until it gets resolved.
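A rough sketch of the "Zygote-like gradient selected by an ADTypes object" idea, assuming only ADTypes' AutoZygote/AutoEnzyme types plus Enzyme.autodiff and make_zero. The ad_grad helper is hypothetical and is not DI's or Flux's API:

```julia
using ADTypes: AutoEnzyme, AutoZygote
import Enzyme, Zygote

# Hypothetical helper: one Zygote-like entry point, backend chosen by dispatch
ad_grad(f, ::AutoZygote, x) = Zygote.gradient(f, x)[1]

function ad_grad(f, ::AutoEnzyme, x)
    dx = Enzyme.make_zero(x)   # zeroed shadow with the same nested structure as x
    Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, Enzyme.Duplicated(x, dx))
    return dx
end

params = (w = [1.0, 2.0], b = [0.5])
loss(p) = sum(abs2, p.w) + 2 * sum(p.b)

ad_grad(loss, AutoZygote(), params)   # (w = [2.0, 4.0], b = [2.0])
ad_grad(loss, AutoEnzyme(), params)   # same values, as a NamedTuple of arrays
```

Dispatching on the ADTypes structs keeps the backend choice a plain argument, which is the kind of flag a training loop could pass through.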

@wsmoses wsmoses closed this as completed Jun 24, 2024