-
Notifications
You must be signed in to change notification settings - Fork 117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rewrite for 1 ** x = 1
#1179
Conversation
pytensor/tensor/rewriting/math.py
Outdated
|
||
Parameters | ||
---------- | ||
fgraph: FunctionGraph | ||
Full function graph being rewritten | ||
node: Apply | ||
Specific node being rewritten | ||
|
||
Returns | ||
------- | ||
rewritten_output: list[TensorVariable] | None | ||
Rewritten output of node, or None if no rewrite is possible | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should add docstrings for rewrites, just adds lines to the codebase. Nobody will be calling this function manually
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose I agree w.r.t Parameters and Returns, but there should at least be a small explainer of what the rewrite does.
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1179 +/- ##
=======================================
Coverage 82.27% 82.27%
=======================================
Files 186 186
Lines 48000 48009 +9
Branches 8621 8624 +3
=======================================
+ Hits 39490 39499 +9
Misses 6353 6353
Partials 2157 2157
|
pytensor/tensor/rewriting/math.py
Outdated
node.inputs[0], only_process_constants=True, raise_not_constant=False | ||
) | ||
if cst == 1: | ||
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could make an infinite recursion if the ShapeOpt
is not running, as it will default to alloc(1, *pow(1, x).shape)
.
You can do pt.broadcast_arrays(*node.inputs)[0]
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, but does that mean we should also change local_canonicalize_pow
? Because I copied the return from there
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lines to that rewrite?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
found it yes. Why don't you combine your changes with that rewrite?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check how it looks now? I had to set the dtype to the output as well, not sure if there's a better way
pytensor/tensor/rewriting/math.py
Outdated
node.inputs[0], only_process_constants=True, raise_not_constant=False | ||
) | ||
if cst_base == 1: | ||
return [broadcast_arrays(*node.inputs)[0].astype(node.outputs[0].dtype)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can clean up a bit by storing inp_idx
in the 3 branches and doing something like:
inp_idx = None
if case1:
inp_idx = 1
elif case 2
inp_idx = 0
...
if inp_idx is None:
return None
new_out = broadcast_arrays(*node.inputs)[inp_idx]
if new out.dtype != node.out.dtype:
new_out = cast(...)
return [new_out]
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my concern about alloc_like applies to the old code as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it can be as clean as this, because in the x ** 0 case we're not using either input value in the output. But I tried to make it look more like this.
Description
Small rewrite that simplifies powers of base 1, since
1 ** x = 1
for any x.Related Issue
1 ** x
to1
#1177Checklist
Type of change