Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rewrite for 1 ** x = 1 #1179

Merged
merged 6 commits into from
Jan 28, 2025
Merged

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jan 27, 2025

Description

Small rewrite that simplifies powers of base 1, since 1 ** x = 1 for any x.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Comment on lines 1925 to 1937

Parameters
----------
fgraph: FunctionGraph
Full function graph being rewritten
node: Apply
Specific node being rewritten

Returns
-------
rewritten_output: list[TensorVariable] | None
Rewritten output of node, or None if no rewrite is possible
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should add docstrings for rewrites, just adds lines to the codebase. Nobody will be calling this function manually

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose I agree w.r.t Parameters and Returns, but there should at least be a small explainer of what the rewrite does.

Copy link

codecov bot commented Jan 27, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.27%. Comparing base (4ea4259) to head (c4f662a).
Report is 2 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1179   +/-   ##
=======================================
  Coverage   82.27%   82.27%           
=======================================
  Files         186      186           
  Lines       48000    48009    +9     
  Branches     8621     8624    +3     
=======================================
+ Hits        39490    39499    +9     
  Misses       6353     6353           
  Partials     2157     2157           
Files with missing lines Coverage Δ
pytensor/tensor/rewriting/math.py 88.98% <100.00%> (+0.06%) ⬆️

node.inputs[0], only_process_constants=True, raise_not_constant=False
)
if cst == 1:
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could make an infinite recursion if the ShapeOpt is not running, as it will default to alloc(1, *pow(1, x).shape).

You can do pt.broadcast_arrays(*node.inputs)[0] instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but does that mean we should also change local_canonicalize_pow ? Because I copied the return from there

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines to that rewrite?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

found it yes. Why don't you combine your changes with that rewrite?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check how it looks now? I had to set the dtype to the output as well, not sure if there's a better way

tests/tensor/rewriting/test_math.py Outdated Show resolved Hide resolved
tests/tensor/rewriting/test_math.py Outdated Show resolved Hide resolved
node.inputs[0], only_process_constants=True, raise_not_constant=False
)
if cst_base == 1:
return [broadcast_arrays(*node.inputs)[0].astype(node.outputs[0].dtype)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can clean up a bit by storing inp_idx in the 3 branches and doing something like:

inp_idx = None
if case1:
  inp_idx = 1
elif case 2
  inp_idx = 0
...

if inp_idx is None:
  return None

new_out = broadcast_arrays(*node.inputs)[inp_idx]

if new out.dtype != node.out.dtype:
  new_out = cast(...)

return [new_out]

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my concern about alloc_like applies to the old code as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it can be as clean as this, because in the x ** 0 case we're not using either input value in the output. But I tried to make it look more like this.

@ricardoV94 ricardoV94 added the enhancement New feature or request label Jan 28, 2025
@ricardoV94 ricardoV94 merged commit b065112 into pymc-devs:main Jan 28, 2025
64 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request graph rewriting
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Rewrite 1 ** x to 1
2 participants