Skip to content

Commit

Permalink
Use broadcast_arrays in rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Jan 28, 2025
1 parent d873bff commit daf9286
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,7 +1918,7 @@ def local_pow_canonicalize(fgraph, node):
node.inputs[0], only_process_constants=True, raise_not_constant=False
)
if cst_base == 1:
return [alloc_like(1, node.outputs[0], fgraph)]
return [broadcast_arrays(*node.inputs)[0].astype(node.outputs[0].dtype)]

cst_exponent = get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True, raise_not_constant=False
Expand Down

0 comments on commit daf9286

Please sign in to comment.