Skip to content

Commit

Permalink
Improve test
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Jan 28, 2025
1 parent daf9286 commit 6df788c
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4578,15 +4578,15 @@ def test_pow_1_rewrite(shape):
x = pt.tensor("x", shape=shape)
z = 1**x

f1 = pytensor.function([x], z, mode=get_default_mode().excluding("canonicalize"))
assert debugprint(f1, file="str").count("Pow") == 1

x_val = np.random.random(shape).astype(config.floatX)
z_val_1 = f1(x_val)

f2 = pytensor.function([x], z)
assert debugprint(f2, file="str").count("Pow") == 0
assert isinstance(z.owner.op, Elemwise) and isinstance(
z.owner.op.scalar_op, ps.basic.Pow
)

z_val_2 = f2(x_val)
f = pytensor.function([x], z)
assert not any(
isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.basic.Pow)
for node in f.maker.fgraph.toposort()
)

np.testing.assert_allclose(z_val_1, z_val_2)
x_val = np.random.random(shape).astype(config.floatX)
np.testing.assert_allclose(z.eval({x: x_val}), f(x_val))

0 comments on commit 6df788c

Please sign in to comment.