From 1f35d9d68e8bd11deb4cb9c74762c00d99d4a5dc Mon Sep 17 00:00:00 2001 From: mrava87 Date: Tue, 19 Jul 2022 16:49:36 +0300 Subject: [PATCH] bug: fix ista/fista for cupy arrays --- pylops/optimization/sparsity.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pylops/optimization/sparsity.py b/pylops/optimization/sparsity.py index 82a34f8e..70ddb04a 100644 --- a/pylops/optimization/sparsity.py +++ b/pylops/optimization/sparsity.py @@ -906,7 +906,7 @@ def ISTA( Op1, niter=eigsiter, tol=eigstol, dtype=Op1.dtype, backend="cupy" )[0] ) - alpha = 1.0 / maxeig + alpha = 1.0 / float(maxeig) # define threshold thresh = eps * alpha * 0.5 @@ -959,8 +959,8 @@ def ISTA( normresold = normres # compute gradient - grad = alpha * Op.H @ res - + grad = alpha * (Op.H @ res) + # update inverted model xinv_unthesh = xinv + grad if SOp is not None: @@ -1211,7 +1211,7 @@ def FISTA( Op1, niter=eigsiter, tol=eigstol, dtype=Op1.dtype, backend="cupy" )[0] ) - alpha = 1.0 / maxeig + alpha = 1.0 / float(maxeig) # define threshold thresh = eps * alpha * 0.5 @@ -1254,7 +1254,7 @@ def FISTA( resz = data - Op @ zinv # compute gradient - grad = alpha * Op.H @ resz + grad = alpha * (Op.H @ resz) # update inverted model xinv_unthesh = zinv + grad