diff --git a/pytorch_wavelets/dtcwt/lowlevel.py b/pytorch_wavelets/dtcwt/lowlevel.py index 0112719..88fa267 100644 --- a/pytorch_wavelets/dtcwt/lowlevel.py +++ b/pytorch_wavelets/dtcwt/lowlevel.py @@ -59,8 +59,7 @@ def prep_filt(h, c, transpose=False): """ Prepares an array to be of the correct format for pytorch. Can also specify whether to make it a row filter (set tranpose=True)""" h = _as_col_vector(h)[::-1] - #h = np.reshape(h, [1, 1, *h.shape]) - h = np.expand_dims(h, (0,1)) + h = h[None, None, :] h = np.repeat(h, repeats=c, axis=0) if transpose: h = h.transpose((0,1,3,2))