RFC: Allow structural recursion without triggering edge cycle limiting #29294
Conversation
    return 0
end

function argtypes_strictly_less_comples(@nospecialize(patypes), @nospecialize(catypes))
argtypes_strictly_less_comples --> argtypes_strictly_less_complex?
No, it's the second-person singular present active indicative of the Latin "compleō".
This isn't the right way to handle this, but it does seem to be an interesting idea: conceptually, it amounts to giving a proof that the tail-recursion is convergent, thus permitting a whole new class of problems to be inferred. There's just some remaining effort needed to demonstrate the expected convergence rates and fix the implementation.
@@ -207,6 +207,40 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector
     return result
 end

+# Returns -1 when patype is less complex than catype, 0 when they are equal, 1 when patype is more complex
+# May overapproximate and return 0 (e.g. when the result is indeterminate because the types are not comparable)
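The contract described in that comment can be modeled outside the compiler. A minimal Python sketch, purely for illustration: the tuple-based type encoding, `node_count`, and `compare_complexity` are invented stand-ins for Julia's lattice elements and internal routines, not the actual implementation.

```python
def node_count(t):
    # Toy encoding: a plain string is a non-parametric type (a leaf);
    # a tuple ("Head", params...) is a parametric type, tuple, or union.
    if isinstance(t, tuple):
        return 1 + sum(node_count(p) for p in t[1:])
    return 1

def compare_complexity(patype, catype):
    """Returns -1 when patype is less complex than catype, 0 when they
    are equal or not comparable (the indeterminate case from the diff
    comment), and 1 when patype is more complex."""
    a, b = node_count(patype), node_count(catype)
    return (a > b) - (a < b)

print(compare_complexity("Int", ("Tuple", "Int", "Int")))  # -1
print(compare_complexity("Int", "Float64"))                # 0
```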
0 isn't an overapproximation.
Heh, you caught that, did you ;). You're right of course. That was just a quick way to fix the case I was interested in. -1 is the correct approximate fallback.
@@ -261,8 +295,10 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
     parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
     parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
     if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
-        topmost = infstate
-        edgecycle = true
+        if !argtypes_strictly_less_comples(get_argtypes(parent.result), get_argtypes(sv.result))
get_argtypes is very slow, and thus is not permitted to be used. It's also only an approximation, so again, not permitted to be used.
IIUC the result of actually computing get_argtypes should already be cached here, yeah? I don't think we ever construct an InferenceResult instance without immediately computing the argtypes. So it should basically just be a field access by this point. I did some refactoring in #28955 to make things like this less confusing - after that gets merged, get_argtypes(parent.result) here would just be parent.result.argtypes.
Correct, basically the simple version of this: http://adam.chlipala.net/cpdt/html/GeneralRec.html (for one particular fixed relation, but for the same reasons).
Force-pushed from 0a7dba3 to 03fbdf4.
For my own future reference, a simple testcase that requires this:
This attempts to fix inference for the case in #29293 (the one returning `Any`). It does not fix the cache poisoning part of that issue, which is a separate concern. The idea here is that we avoid applying limiting if the argtypes of the frame become strictly simpler (thus guaranteeing eventual termination). It is important that the complexity relation be transitive and anti-reflexive.
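The termination argument in that description can be sketched concretely: if recursion is only permitted while the argtypes become strictly simpler under a transitive, irreflexive relation backed by a non-negative integer measure, any chain of recursive calls is finite. A toy Python model (the type encoding and function names are hypothetical, not the compiler's):

```python
def node_count(t):
    # Toy type encoding: strings are leaves, tuples ("Head", params...)
    # are parametric types.
    if isinstance(t, tuple):
        return 1 + sum(node_count(p) for p in t[1:])
    return 1

def strictly_simpler(a, b):
    # Stand-in for the strict complexity relation: transitive and
    # irreflexive because it compares a non-negative integer measure.
    return node_count(a) < node_count(b)

# Simulate a recursion that peels one layer of the signature per call.
sig = ("Tuple", "Int", ("Tuple", "Int", ("Tuple", "Int", "Int")))
steps = 0
while isinstance(sig, tuple):
    child = sig[2]                       # the recursive call's argument
    assert strictly_simpler(child, sig)  # so limiting may be skipped
    sig, steps = child, steps + 1
# node_count strictly decreases on every step, so at most
# node_count(sig0) iterations are possible: termination is guaranteed.
print(steps)  # 3
```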
Force-pushed from 03fbdf4 to 5a1eb24.
That result is only different because this PR is incorrect. That particular pattern just runs into intentional heuristics that decide when we should stop running constant-propagation (for inference performance, we disable all constant-propagation when we encounter recursion).
Well, I would like it not to do that, because I need that function to infer ;)
Also, this was with aggressive const prop enabled, so that may or may not have made a difference. In general, I'd like an inference mode that infers as much as possible as long as termination is guaranteed. If that needs to be a flag, that'd be ok with me.
The inference enhancements in #29294 work quite well to prevent limiting on many kinds of code. However, targeting TPUs, one code pattern it struggled with was a fairly large broadcast fusion in Flux: λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...)). The reason #29294 struggled is that the make_makeargs function used by the implementation of Broadcast.flatten (which the TPU backend uses) had a non-decreasing first argument (passing the return value of a previous invocation of make_makeargs back in as the first argument). However, that's not a fundamental limitation of the operation, but rather an implementation choice. This PR switches that function's recursion pattern to be purely structural, allowing inference to infer through it (with the changes in #29294). As a result, ResNet50 infers properly.
I wonder if what we need to do is assign each type a complexity score (e.g. recursively sum up the number of parameter leaves) and then simply assert that the total complexity score of the signature must decrease. That would capture various shifting around of complexity between arguments.
That seems like a pretty solid approach. I guess the trick is designing the complexity score.
Maybe something like this: non-parametric types: 1. Not sure about the 1 + parts, maybe it should just be a bare sum.
Yes, I think we call that size and/or depth. I wrote a couple of blog posts about it.
Sure, but I'm saying combine them into one number for the whole signature.
It's basically the number of nodes in the "signature tree": top-level nodes are the call arguments, parametric types, tuples and unions are internal nodes, and non-parametric types are leaves.
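That node count is easy to state as code. A Python sketch of the proposed score under a toy encoding (strings as leaves, tuples as internal nodes); this illustrates the idea being discussed, not Julia's implementation:

```python
def tree_size(t):
    # Internal nodes (parametric types, tuples, unions) count themselves
    # plus all of their parameters; leaves (non-parametric types and
    # non-type values) count 1.
    if isinstance(t, tuple):
        return 1 + sum(tree_size(p) for p in t[1:])
    return 1

def signature_score(argtypes):
    # Combine per-argument scores into one number for the whole
    # signature, as suggested above.
    return sum(tree_size(t) for t in argtypes)

print(signature_score(["Int", ("Pair", "Int", ("Tuple", "Int"))]))  # 1 + 4 = 5
```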
Non-parametric types and non-types, but yes.
Ideally, we'd also cache this value in the type object, so we don't have to recursively walk it every time.
Alright, great. We could even call it size and/or depth: see lines 1017 to 1036 in b89e88e.
Yes, but that's not quite the correct notion yet. We can probably extend it.
The term "depth" for a tree has a well-established meaning, and this is not it. The "size" of a tree is closer to the correct term, but calling this the "size" of a type would be extremely confusing, so please don't use either of those terms for this. Calling this the "tree size" of a type would be ok.
The depth of a type's tree also does not have the right property at all: if you take a bunch of types with depth d and put them in a tuple, the result has depth d + 1 no matter how many of them you combine, so the depth need not decrease.
Correct.
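That objection can be made concrete: widening a tuple grows its node count but not its depth, so depth can stay flat along an arbitrarily long chain and cannot serve as a strictly decreasing measure. A Python sketch under the same toy encoding (strings as leaves, tuples as internal nodes):

```python
def depth(t):
    # Depth in the usual tree sense: longest root-to-leaf path.
    if isinstance(t, tuple):
        return 1 + max((depth(p) for p in t[1:]), default=0)
    return 1

def tree_size(t):
    # Total node count, which does distinguish the signatures below.
    if isinstance(t, tuple):
        return 1 + sum(tree_size(p) for p in t[1:])
    return 1

# Tuples of ever more depth-1 elements: depth is stuck at 2 forever,
# while the node count grows without bound.
sigs = [("Tuple",) + ("Int",) * n for n in (1, 2, 3, 100)]
print([depth(s) for s in sigs])      # [2, 2, 2, 2]
print([tree_size(s) for s in sigs])  # [2, 3, 4, 101]
```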
Unclear to me if we want it to be super-additive or exactly additive. Super-additive would make it legal to take a structure apart but not construct a structure from individual pieces whereas exactly additive would make both legal. It seems like what one wants is some leeway to take …
Alternatively, count total argument complexity at the start and bail out if it ever gets larger than a fixed amount above that OR if it takes more than some upper limit on the number of steps. |
Weight sounds nicely additive.
Pure sum of tuple elements makes …

What about: store the lowest weight seen so far, and how many steps back this was. When we get something lighter, update lowest_seen and the last_seen counter. Bail if we ever see something that is …

Maybe count positive integers with …
Good point: tuples need to have weight one more than their parts. Probably similar for others.
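The bail-out heuristic sketched in the last few comments (a fixed complexity budget from #29418's discussion above: bail if the weight exceeds the starting weight by a fixed slack, or if too many steps pass without a new minimum) can be modeled as a small state machine. A hypothetical Python sketch; the class name and the threshold values are invented for illustration, and the reference to a PR number in the lead sentence should be disregarded if it does not match:

```python
class RecursionGuard:
    """Toy model of the discussed bail-out heuristic. Names and
    default thresholds are invented for illustration."""

    def __init__(self, start_weight, slack=4, patience=8):
        # Bail if weight ever exceeds the starting weight by `slack`,
        # or if `patience` steps pass without a new minimum weight.
        self.budget = start_weight + slack
        self.patience = patience
        self.lowest_seen = start_weight
        self.since_lowest = 0

    def step(self, weight):
        """Return True to keep recursing, False to apply limiting."""
        if weight > self.budget:
            return False               # grew too far above the start
        if weight < self.lowest_seen:
            self.lowest_seen = weight  # progress: remember new minimum
            self.since_lowest = 0      # ...and reset the step counter
            return True
        self.since_lowest += 1
        return self.since_lowest <= self.patience

guard = RecursionGuard(start_weight=10, slack=2, patience=3)
print([guard.step(w) for w in (9, 11, 11, 11, 11)])
# [True, True, True, True, False]
```

The two cut-offs are complementary: the budget bounds how much heavier any signature may get, while the patience bound guarantees the recursion cannot hover forever at a constant weight.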