From ba0733516d3542d8cebbdf8ecfb178ce440127bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Sun, 10 Jan 2021 14:49:01 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Improve=20GMSD=20run-time?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- piqa/__init__.py | 2 +- piqa/gmsd.py | 9 ++++++--- piqa/ssim.py | 12 ++++++------ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/piqa/__init__.py b/piqa/__init__.py index 6bba34a..75114ad 100644 --- a/piqa/__init__.py +++ b/piqa/__init__.py @@ -5,4 +5,4 @@ specific image quality assessement metric. """ -__version__ = '1.0.8' +__version__ = '1.0.9' diff --git a/piqa/gmsd.py b/piqa/gmsd.py index 0a0c7c1..dc0d86b 100644 --- a/piqa/gmsd.py +++ b/piqa/gmsd.py @@ -69,12 +69,15 @@ def _gmsd( # Gradient magnitude similarity gms_num = (2. - alpha) * gm_xy + c - gms_den = gm_x ** 2 + gm_y ** 2 - alpha * gm_xy + c + gms_den = gm_x ** 2 + gm_y ** 2 + c + + if alpha > 0.: + gms_den = gms_den - alpha * gm_xy + gms = gms_num / gms_den # Gradient magnitude similarity deviation - gmsd = (gms - gms.mean((-1, -2), keepdim=True)) ** 2 - gmsd = torch.sqrt(gmsd.mean((-1, -2))) + gmsd = torch.std(gms, dim=(-1, -2)) return gmsd diff --git a/piqa/ssim.py b/piqa/ssim.py index 7a3467a..c2e3c7e 100644 --- a/piqa/ssim.py +++ b/piqa/ssim.py @@ -89,20 +89,20 @@ def ssim_per_channel( mu_x = filter2d(x, window) mu_y = filter2d(y, window) - mu_x_sq = mu_x ** 2 - mu_y_sq = mu_y ** 2 + mu_xx = mu_x ** 2 + mu_yy = mu_y ** 2 mu_xy = mu_x * mu_y # Variance (sigma) - sigma_x_sq = filter2d(x ** 2, window) - mu_x_sq - sigma_y_sq = filter2d(y ** 2, window) - mu_y_sq + sigma_xx = filter2d(x ** 2, window) - mu_xx + sigma_yy = filter2d(y ** 2, window) - mu_yy sigma_xy = filter2d(x * y, window) - mu_xy # Contrast sensitivity - cs = (2. * sigma_xy + c2) / (sigma_x_sq + sigma_y_sq + c2) + cs = (2. * sigma_xy + c2) / (sigma_xx + sigma_yy + c2) # Structural similarity - ss = (2. * mu_x * mu_y + c1) / (mu_x_sq + mu_y_sq + c1) * cs + ss = (2. * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs return ss.mean((-1, -2)), cs.mean((-1, -2))