From daf9286581e36bcdd35934add3358714d69711a8 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Tue, 28 Jan 2025 15:33:13 +0800 Subject: [PATCH] Use `broadcast_arrays` in rewrite --- pytensor/tensor/rewriting/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 8416ebc8b2..e10df6e877 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1918,7 +1918,7 @@ def local_pow_canonicalize(fgraph, node): node.inputs[0], only_process_constants=True, raise_not_constant=False ) if cst_base == 1: - return [alloc_like(1, node.outputs[0], fgraph)] + return [broadcast_arrays(*node.inputs)[0].astype(node.outputs[0].dtype)] cst_exponent = get_underlying_scalar_constant_value( node.inputs[1], only_process_constants=True, raise_not_constant=False