Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper Circular Reference Handling #416

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

willtebbutt
Copy link
Member

@willtebbutt willtebbutt commented Dec 12, 2024

This PR targets finishing off #204

This seems to be blocked by JuliaLang/julia#56775 currently -- I make use of deepcopy inside of TestUtils.test_rule_correctness, and it falls over for circular references involving Arrays currently. Per the linked issue, I'm not 100% certain whether it's a bug in Julia 1.11.2, or expected behaviour (it feels like a bug though).

I'm not 100% sure who's attention we should attempt to get over JuliaLang/julia#56775 -- any idea @yebai @mhauru @penelopeysm ?

I'll also have a think about whether I can avoid the use of deepcopy -- I'm reasonably confident that I need it, but I might be wrong.

edit: definitely need deepcopy, because I need to be able to tell whether a rule has successfully returned inputs to their initial state. In order to make this comparison, I need to keep the initial state around, which requires the use of deepcopy. The only workaround would be to roll my own version of deepcopy, which I'm not keen on.

edit2: deepcopy seems to be fine on 1.10 though, so I could just do development on 1.10, and only run tests involving circular references on 1.10 for now.

@willtebbutt willtebbutt marked this pull request as draft December 12, 2024 11:46
@mhauru
Copy link
Contributor

mhauru commented Dec 12, 2024

If the proper fix will take some time, do you think it makes sense to build a workaround, or should we disable Mooncake tests in Turing.jl for now? Currently CI is failing on Turing.jl master due to the stack overflows.

@yebai
Copy link
Contributor

yebai commented Dec 12, 2024

How about using Serialization (i.e. serialize then deserialize) as an alternative for deepcopy? Serialization might be less efficient but could be a reasonable workaround until deepcopy issue is fixed on Julia 1.11

@penelopeysm
Copy link
Contributor

penelopeysm commented Dec 13, 2024

Currently CI is failing on Turing.jl master due to the stack overflows.

Do any of you know what changed that made the tests start to fail? Comparing the last success with the first failure I see very little change:

(last success -> first failure)
DifferentiationInterface 0.6.23 -> 0.6.24
SciMLBase 2.65.0 -> 2.65.1
UnsafeAtomicsLLVM 0.2.1 -> 0.2.2

Both are running Julia 1.11.1.

I've just made a PR pinning the old deps just DifferentiationInterface which seems to work: TuringLang/Turing.jl#2437 so this seems like a sensible workaround to decouple Turing.jl CI from this issue.

@willtebbutt
Copy link
Member Author

So, what's going on is previously DI wasn't exposed to the Mooncake.set_to_zero!! function. Since I exposed the prepare_pullback_cache!! functionality, this has changed -- I set_to_zero!! existing memory, rather than calling zero_tangent and allocating new memory. In principle this ought to do exactly the same thing, but due to un-finished implementation, it happens to be the case that set_to_zero!! doesn't properly support circular referencing, while zero_tangent does. It turns out that in the test suite of Turing.jl, you wind up with a circular reference in one of the args to AD, I think due to boxing associated to constructing a Turing.jl model inside of a testset, and some weirdness there (usually, we don't have circular references in Turing.jl models, thankfully).

Anyway, I'm going to make a tiny version of this PR in which set_to_zero!! is fixed, and get that merged in the next hour or two -- I'd rather we didn't have to pin the version of DI in Turing.jl.

This does raise a separate question though: why are things getting boxed, and generating circular references, in the Turing.jl tests? It's possible that this kind of thing will have performance implications for the test suite, so it might actually be good to look into what's going on there if we want to get the test suite to run more quickly 🤷

@willtebbutt
Copy link
Member Author

willtebbutt commented Dec 13, 2024

@devmotion kindly pointed me towards this discussion from slack. TLDR, I should try and find a way to avoid using deepcopy here. I'll look into this.

edit: another reason that I need to be able to produce an independent version of the inputs to a function is to ensure that the state of everything after running the function is the same as what is produced by running the forwards-pass of AD. I really don't know how I would do this without something deepcopy-like in nature. I think we might have to roll our own, but I want to get my head around the issues discussed in the julia issue linked above before doing so.

@yebai
Copy link
Contributor

yebai commented Dec 16, 2024

another reason that I need to be able to produce an independent version of the inputs to a function is to ensure that the state of everything after running the function is the same as what is produced by running the forwards-pass of AD.

Here is another reason to find a good solution to this problem. The operation above is very similar to copying tapes in Libtask. During tape copying, one needs to recursively duplicate all data (e.g. input arguments to all instructions) on the tape, producing a completely independent new tape that can be continued.

@willtebbutt
Copy link
Member Author

This is blocked by JuliaLang/julia#56775

@willtebbutt
Copy link
Member Author

A fix for JuliaLang/julia#56775 has now been merged, and should appear in 1.11.3.

Copy link
Contributor

github-actions bot commented Jan 22, 2025

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬─────────┬─────────────┬─────────┐
│                      Label │ Mooncake │  Zygote │ ReverseDiff │  Enzyme │
│                     String │   String │  String │      String │  String │
├────────────────────────────┼──────────┼─────────┼─────────────┼─────────┤
│                   sum_1000 │     83.2 │     1.1 │         5.7 │    8.11 │
│                  _sum_1000 │     6.75 │  1710.0 │        34.4 │    1.09 │
│               sum_sin_1000 │     2.26 │    1.64 │        10.7 │    1.97 │
│              _sum_sin_1000 │     2.61 │   309.0 │        13.0 │    2.41 │
│                   kron_sum │     59.7 │    3.68 │       212.0 │    9.27 │
│              kron_view_sum │     62.5 │    9.14 │       227.0 │   101.0 │
│      naive_map_sin_cos_exp │     2.53 │ missing │        7.48 │    2.33 │
│            map_sin_cos_exp │     2.69 │    1.43 │         6.0 │    3.13 │
│      broadcast_sin_cos_exp │     2.57 │    3.75 │        1.48 │    2.26 │
│                 simple_mlp │     5.11 │    3.19 │        5.03 │     1.5 │
│                     gp_lml │     14.3 │    7.06 │     missing │    11.6 │
│ turing_broadcast_benchmark │     3.16 │ missing │        26.0 │ missing │
│         large_single_block │     4.42 │  4550.0 │        30.3 │    2.18 │
└────────────────────────────┴──────────┴─────────┴─────────────┴─────────┘

Copy link

codecov bot commented Jan 22, 2025

Codecov Report

Attention: Patch coverage is 93.00412% with 34 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/tangents.jl 88.43% 17 Missing ⚠️
src/rrules/twice_precision.jl 0.00% 7 Missing ⚠️
src/test_utils.jl 96.00% 3 Missing ⚠️
src/fwds_rvs_data.jl 90.90% 2 Missing ⚠️
src/rrules/iddict.jl 96.42% 2 Missing ⚠️
src/rrules/tasks.jl 90.00% 2 Missing ⚠️
src/rrules/memory.jl 98.73% 1 Missing ⚠️
Files with missing lines Coverage Δ
ext/MooncakeCUDAExt.jl 95.00% <100.00%> (+4.37%) ⬆️
src/rrules/array_legacy.jl 48.37% <100.00%> (-6.58%) ⬇️
src/rrules/function_wrappers.jl 98.98% <100.00%> (+0.02%) ⬆️
src/rrules/memory.jl 96.75% <98.73%> (+0.34%) ⬆️
src/fwds_rvs_data.jl 95.46% <90.90%> (-0.47%) ⬇️
src/rrules/iddict.jl 96.29% <96.42%> (-0.65%) ⬇️
src/rrules/tasks.jl 79.68% <90.00%> (+0.05%) ⬆️
src/test_utils.jl 92.42% <96.00%> (-0.69%) ⬇️
src/rrules/twice_precision.jl 0.00% <0.00%> (-96.93%) ⬇️
src/tangents.jl 84.36% <88.43%> (-0.25%) ⬇️

... and 5 files with indirect coverage changes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants