
RFC: Allow structural recursion without triggering edge cycle limiting #29294

Closed · Keno wants to merge 1 commit from the kf/inferencecomplexity branch

Conversation

@Keno (Member) commented Sep 21, 2018

This attempts to fix inference for the case in #29293 (the one
returning Any). It does not fix the cache poisoning part of that
issue, which is a separate concern. The idea here is that we avoid
applying limiting if the argtypes of the frame become strictly simpler
(thus guaranteeing eventual termination). It is important that the
complexity relation be transitive and anti-reflexive.
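
As a sketch of the intended shape of that check (the names here are illustrative, not the PR's actual code; `compare_complexity` is a hypothetical per-type comparator following the -1/0/1 convention discussed in the review thread below):

```julia
# Hypothetical sketch: treat the child frame's argtypes as strictly
# simpler than the parent's when no argument grew in complexity and at
# least one shrank. This relation is transitive and irreflexive, so
# recursion along it terminates as long as complexity cannot keep
# decreasing forever.
function child_strictly_simpler(parent_argtypes::Vector{Any}, child_argtypes::Vector{Any})
    length(parent_argtypes) == length(child_argtypes) || return false
    shrank = false
    for (pa, ca) in zip(parent_argtypes, child_argtypes)
        c = compare_complexity(pa, ca)  # 1: parent more complex, 0: equal, -1: parent less complex
        c == -1 && return false         # this argument grew in the child: not strictly simpler
        c == 1 && (shrank = true)
    end
    return shrank
end
```

If this returns false, inference would fall back to the usual edge-cycle limiting, so any indeterminate comparison errs on the side of limiting.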

```julia
    return 0
end

function argtypes_strictly_less_comples(@nospecialize(patypes), @nospecialize(catypes))
```

Member: `argtypes_strictly_less_comples` --> `argtypes_strictly_less_complex`?

Member: No, it's the second-person singular present active indicative of the Latin "compleō".

@vtjnash (Member) left a comment:

This isn't the right way to handle this. But it does seem to be an interesting idea: the conceptual idea here appears to be that it should give a proof that tail-recursion is convergent, thus permitting a whole new class of problems to be inferred! There's just some remaining effort to be done to demonstrate the expected convergence rates and fix the implementation.

```julia
@@ -207,6 +207,40 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector
    return result
end

# Returns -1 when patype is less complex than catype, 0 when they are equal, 1 when patype is more complex
# May overapproximate and return 0 (e.g. when the result is indeterminate because the types are not comparable)
```

Member: 0 isn't an overapproximation

@Keno (Author): Heh, you caught that, did you ;). You're right, of course. That was just a quick way to fix the case I was interested in; -1 is the correct approximate fallback.
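
A sketch of what such a conservative comparator could look like (illustrative only; a real implementation would compare structural containment much more carefully):

```julia
# Hypothetical three-way comparator following the convention above:
# -1 when patype is less complex than catype, 0 when equal, 1 when
# patype is more complex. Returning -1 for incomparable pairs reads as
# "the child frame did not get simpler", which keeps the usual
# recursion limiting in force.
function compare_complexity(@nospecialize(patype), @nospecialize(catype))
    patype === catype && return 0
    if patype isa DataType && catype isa DataType
        # A type that appears among another's parameters is strictly simpler.
        any(p === catype for p in patype.parameters) && return 1
        any(p === patype for p in catype.parameters) && return -1
    end
    return -1  # incomparable: conservative fallback
end
```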

```julia
@@ -261,8 +295,10 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::SimpleVector
    parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
    parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
    if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
        topmost = infstate
        edgecycle = true
        if !argtypes_strictly_less_comples(get_argtypes(parent.result), get_argtypes(sv.result))
```

Member: `get_argtypes` is very slow, and thus is not permitted to be used. It's also only an approximation, so again, not permitted to be used.

@jrevels commented Sep 21, 2018:

IIUC the result of actually computing get_argtypes should already be cached here, yeah? I don't think we ever construct an InferenceResult instance without immediately computing the argtypes. So it should basically just be a field access by this point.

I did some refactoring in #28955 to make things like this less confusing - after that gets merged, get_argtypes(parent.result) here would just be parent.result.argtypes.

@Keno (Author) commented Sep 21, 2018

> the conceptual idea here appears to be that it should give a proof that tail-recursion is convergent, thus permitting a whole new class of problems to be inferred!

Correct, basically the simple version of this: http://adam.chlipala.net/cpdt/html/GeneralRec.html (for one particular fixed relation, but for the same reasons).
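
For context on that link: the relevant notion is well-founded recursion. A relation admits recursion exactly when it has no infinite descending chains (a standard definition, added here for reference, not part of the thread):

```latex
\prec \text{ is well-founded} \iff
\nexists\, (a_i)_{i \ge 1} \ \text{with}\ a_1 \succ a_2 \succ a_3 \succ \cdots
```

A transitive, irreflexive complexity relation with this property guarantees that a chain of strictly-simplifying recursive inference frames must terminate.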

@Keno force-pushed the kf/inferencecomplexity branch from 0a7dba3 to 03fbdf4 on October 9, 2018 01:28
@Keno (Author) commented Oct 9, 2018

For my own future reference, a simple testcase that requires this:

```julia
Base.Broadcast._bcs($((Base.OneTo(2), Base.OneTo(14), Base.OneTo(512), Base.OneTo(1))), $((Base.OneTo(2), Base.OneTo(1), Base.OneTo(512), Base.OneTo(1))))
```
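
What makes this a good test: `_bcs` recurses on the tails of both axis tuples, so each recursive call sees strictly smaller tuple types. A simplified sketch of that recursion pattern (a toy stand-in, not the real `Base.Broadcast._bcs`):

```julia
using Base: tail

# Toy stand-in for the _bcs shape-merging recursion: every call peels
# one axis off both tuples, so the argument types shrink structurally.
merge_shapes(::Tuple{}, ::Tuple{}) = ()
merge_shapes(shape::Tuple, ::Tuple{}) = shape
merge_shapes(::Tuple{}, newshape::Tuple) = newshape
function merge_shapes(shape::Tuple, newshape::Tuple)
    # Toy broadcast rule: a length-1 axis defers to the other side.
    ax = length(shape[1]) == 1 ? newshape[1] : shape[1]
    return (ax, merge_shapes(tail(shape), tail(newshape))...)
end

# merge_shapes((Base.OneTo(2), Base.OneTo(1)), (Base.OneTo(2), Base.OneTo(5)))
# returns (Base.OneTo(2), Base.OneTo(5))
```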

@Keno force-pushed the kf/inferencecomplexity branch from 03fbdf4 to 5a1eb24 on October 9, 2018 04:27
@vtjnash commented Oct 9, 2018

That result is only different because this PR is incorrect. That particular pattern just runs into intentional heuristics that decide when we should stop running constant-propagation (for inference performance, we disable all constant-propagation when we encounter recursion).

@Keno (Author) commented Oct 9, 2018

Well, I would like it not to do that, because I need that function to infer ;)

@Keno (Author) commented Oct 9, 2018

Also, this was with aggressive const prop enabled, so that may or may not have made a difference. In general, I'd like an inference mode that infers as much as possible as long as termination is guaranteed. If that needs to be a flag, that'd be ok with me.

Keno added a commit that referenced this pull request Oct 26, 2018
The inference enhancements in #29294 work quite well to prevent limiting
on many kinds of code. However, targeting TPUs, one code pattern it
struggled with was a fairly large broadcast fusion in Flux:

     λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))

The reason #29294 didn't help here is that the make_makeargs function used by the
implementation of Broadcast.flatten (which the TPU backend uses) had
a non-decreasing first argument (passing the return value of a previous
invocation of make_makeargs back in as the first argument). However,
that's not a fundamental limitation of the operation, but rather an
implementation choice. This PR switches that function's recursion pattern
to be purely structural, allowing inference to infer through it (with
the changes in #29294). As a result, ResNet50 infers properly.
Keno added a commit that referenced this pull request Oct 28, 2018

* Make broadcast recursion in `flatten` structural

The inference enhancements in #29294 work quite well to prevent limiting
on many kinds of code. However, targeting TPUs, one code pattern it
struggled with was a fairly large broadcast fusion in Flux:

     λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))

The reason #29294 didn't help here is that the make_makeargs function used by the
implementation of Broadcast.flatten (which the TPU backend uses) had
a non-decreasing first argument (passing the return value of a previous
invocation of make_makeargs back in as the first argument). However,
that's not a fundamental limitation of the operation, but rather an
implementation choice. This PR switches that function's recursion pattern
to be purely structural, allowing inference to infer through it (with
the changes in #29294). As a result, ResNet50 infers properly.

* Comment spelling fix

Co-Authored-By: mbauman <mbauman@gmail.com>
@Keno (Author) commented Nov 9, 2018

I wonder if what we need to do is assign each type a complexity score (e.g. recursively sum up the number of parameter leaves) and then simply assert that the total complexity score of the signature must decrease. That would capture the various ways complexity can shift around between arguments.

@StefanKarpinski (Member):

That seems like a pretty solid approach. I guess the trick is designing the complexity score.

@StefanKarpinski commented Nov 9, 2018

Maybe something like this:

- Non-parametric types: 1
- Parametric types: 1 + sum of parameters
- Tuples: 1 + sum of tuple elements
- Unions: 1 + sum of union elements

Not sure about the 1 + parts; maybe it should just be a bare sum.
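
For concreteness, a sketch of a scorer along those lines (hypothetical code; the treatment of UnionAlls and non-type parameters is an assumption):

```julia
# Hypothetical complexity score following the rules above:
# non-parametric types count 1; parametric types, tuples, and unions
# count 1 plus the sum of their parts; non-type parameters (numbers,
# symbols, ...) count 1 as leaves.
function complexity_score(@nospecialize(t))
    if t isa Union
        return 1 + complexity_score(t.a) + complexity_score(t.b)
    elseif t isa UnionAll
        return 1 + complexity_score(t.body) + complexity_score(t.var.ub)
    elseif t isa DataType
        return 1 + sum(Int[complexity_score(p) for p in t.parameters])
    else
        return 1
    end
end

# The signature's score is the total over all arguments, e.g.
# complexity_score(Tuple{Int,Float64}) == 3 and
# complexity_score(Vector{Int}) == 3, so the pair scores 6.
signature_score(argtypes) = sum(complexity_score, argtypes)
```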

@vtjnash commented Nov 9, 2018

Yes, I think we call that size and/or depth. I wrote a couple of blog posts about it.

@Keno (Author) commented Nov 9, 2018

Sure, but I'm saying we should combine them into one number for the whole signature.

@StefanKarpinski (Member):

It's basically the number of nodes in the "signature tree": top-level nodes are the call arguments, parametric types, tuples and unions are internal nodes, and non-parametric types are leaves.

@Keno (Author) commented Nov 9, 2018

Non-parametric types and non-types, but yes.

@Keno (Author) commented Nov 9, 2018

Ideally, we'd also cache this value in the type object, so we don't have to recursively walk it every time.

@vtjnash commented Nov 9, 2018

Alright, great. We could even call it size and/or depth:

julia/src/jltypes.c, lines 1017 to 1036 at b89e88e:

```c
static size_t jl_type_depth(jl_value_t *dt)
{
    if (jl_is_uniontype(dt)) {
        size_t ad = jl_type_depth(((jl_uniontype_t*)dt)->a);
        size_t bd = jl_type_depth(((jl_uniontype_t*)dt)->b);
        return ad > bd ? ad : bd;
    }
    else if (jl_is_unionall(dt)) {
        jl_unionall_t *ua = (jl_unionall_t*)dt;
        size_t bd = jl_type_depth(ua->body);
        if (ua->var->ub == (jl_value_t*)jl_any_type)
            return bd;
        size_t vd = jl_type_depth(ua->var->ub);
        return vd+1 > bd ? vd+1 : bd;
    }
    else if (jl_is_datatype(dt)) {
        return ((jl_datatype_t*)dt)->depth;
    }
    return 0;
}
```

@Keno (Author) commented Nov 9, 2018

Yes, but that's not quite the correct notion yet. We can probably extend it.

@StefanKarpinski (Member):

The term "depth" for a tree has a well-established meaning—and this is not it. The "size" of a tree is closer to the correct term, but calling this the "size" of a type would be extremely confusing, so please don't use either of those terms for this. Calling this the "tree size" of a type would be ok.

@StefanKarpinski (Member):

The depth of a type's tree also does not have the right property at all: if you take a bunch of types with depth d1, ..., dn and you combine them into a single argument as type parameters or types of the slots in a tuple, then the depth is max(d1, ..., dn) + 1. We want a measure that has the property that it is additive in this case.

@Keno (Author) commented Nov 9, 2018

> We want a measure that has the property that it is additive in this case.

Correct.

@StefanKarpinski commented Nov 9, 2018

Unclear to me if we want it to be super-additive or exactly additive. Super-additive would make it legal to take a structure apart but not to construct a structure from individual pieces, whereas exactly additive would make both legal. It seems like what one wants is some leeway to take k steps at the same level of complexity, so you probably want a counter that resets to k every time the total argument complexity drops and counts down every time it stays the same or grows, bailing out when the counter reaches zero. Exactly additive seems most permissive, allowing arbitrary juggling.
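
A sketch of that counter idea (names and the constant k are illustrative assumptions, not an agreed design):

```julia
# Hypothetical recursion budget per the proposal above: reset to k
# whenever total argument complexity strictly drops, count down when it
# stays the same or grows, and bail out (apply limiting) at zero.
mutable struct RecursionBudget
    best::Int      # lowest signature score seen so far
    counter::Int   # steps remaining at or above `best`
end

const K_STEPS = 3  # illustrative constant

RecursionBudget(score::Int) = RecursionBudget(score, K_STEPS)

"Returns `true` if recursion may continue, `false` to apply limiting."
function step!(b::RecursionBudget, score::Int)
    if score < b.best
        b.best = score
        b.counter = K_STEPS  # complexity dropped: reset the budget
        return true
    end
    b.counter -= 1           # same or heavier: spend budget
    return b.counter > 0
end
```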

@StefanKarpinski (Member):

Alternatively, count total argument complexity at the start and bail out if it ever gets larger than a fixed amount above that OR if it takes more than some upper limit on the number of steps.

@chethega (Contributor) commented Nov 9, 2018

> The term "depth" for a tree has a well-established meaning—and this is not it. The "size" of a tree is closer to the correct term, but calling this the "size" of a type would be extremely confusing, so please don't use either of those terms for this. Calling this the "tree size" of a type would be ok.

Weight sounds nicely additive.

> Unclear to me if we want it to be super-additive or exactly additive.

A pure sum of tuple elements makes (), ((),), (((),),), etc. all have the same weight, right? That feels wrong and probably allows Turing-completeness within the same weight.

What about: store lowest weight seen so far, and how many steps back this was. When we get something lighter, update lowest_seen and last_seen counter. Bail if we ever see something that is C_gain heavier than lowest_seen or if last_seen is more than N_speed steps ago. In other words, permit juggling that temporarily makes things moderately worse.

Maybe count positive integer parameters with their value, and everything else with 1; this nicely generalizes NTuples, makes Vector lighter/simpler than Matrix and, very importantly, makes Val{1} et al. lighter than Val{2}.
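
A sketch of that weighting (reading "count positive integers with their value" as: integer type parameters contribute their value; hypothetical code):

```julia
# Hypothetical weight: positive integer type parameters contribute
# their value, every other leaf contributes 1, and each type node adds
# 1 so that nesting always increases the weight.
function weight(@nospecialize(t))
    t isa Int && t > 0 && return t  # e.g. the N in Val{N} or Array{T,N}
    t isa DataType && return 1 + sum(Int[weight(p) for p in t.parameters])
    return 1
end

# weight(Val{1}) == 2 < weight(Val{2}) == 3
# weight(Vector{Int}) == 3 < weight(Matrix{Int}) == 4
```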

@StefanKarpinski (Member):

> A pure sum of tuple elements makes (), ((),), (((),),), etc. all have the same weight, right? That feels wrong and probably allows Turing-completeness within the same weight.

Good point: tuples need to have weight one more than their parts. Probably similar for others.
