Skip to content

Commit

Permalink
plot
Browse files Browse the repository at this point in the history
  • Loading branch information
rfgzuid committed Jun 13, 2024
1 parent 8d71568 commit 8bfea3f
Show file tree
Hide file tree
Showing 6 changed files with 308 additions and 559 deletions.
711 changes: 235 additions & 476 deletions experiments/PIV Test.ipynb

Large diffs are not rendered by default.

Binary file added experiments/SAD.npy
Binary file not shown.
Binary file added experiments/optical.npy
Binary file not shown.
Binary file added experiments/refine.npy
Binary file not shown.
112 changes: 49 additions & 63 deletions src/SIV_library/advanced.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import torch
from torch.nn.functional import grid_sample, interpolate
from torchvision.transforms import Resize, InterpolationMode
from torch.utils.data import DataLoader

from src.SIV_library.lib import OpticalFlow, SIV
from src.SIV_library.optical_flow import optical_flow

from collections.abc import Generator
from tqdm import tqdm


class Warp(torch.nn.Module):
Expand All @@ -14,8 +17,6 @@ class Warp(torch.nn.Module):
def __init__(self, x, y, u, v):
super().__init__()
self.x, self.y, self.u, self.v = x, y, u, v
self.idx = 0

self.apply_to = ['a'] # apply this transform to ONLY the first image of the pair

def forward(self, image: torch.Tensor) -> torch.Tensor:
Expand All @@ -24,22 +25,20 @@ def forward(self, image: torch.Tensor) -> torch.Tensor:

x, y = self.x / ((cols - 1) / 2) - 1, self.y / ((rows - 1) / 2) - 1

grid = torch.stack((x[self.idx], y[self.idx]), dim=-1).to(x.device)
v_grid = grid + torch.stack((-self.u[self.idx] / (cols / 2), self.v[self.idx] / (rows / 2)), dim=-1)
grid = torch.stack((x, y), dim=-1).to(x.device)
v_grid = grid + torch.stack((-self.u / (cols / 2), self.v / (rows / 2)), dim=-1)

img_new = grid_sample(image.float(), v_grid[None, :, :, :], mode='bicubic').to(torch.uint8)
self.idx += 1
img_new = grid_sample(image.float(), v_grid, mode='bicubic').to(torch.uint8)
return img_new

def interpolate_field(self, img_shape) -> None:
if self.u.shape[-2:] == img_shape:
return

self.u = interpolate(self.u[None, :, :, :], img_shape, mode='bicubic').squeeze(dim=0)
self.v = interpolate(self.v[None, :, :, :], img_shape, mode='bicubic').squeeze(dim=0)
self.u = interpolate(self.u[None, None, :, :], img_shape, mode='bicubic').squeeze(dim=0)
self.v = interpolate(self.v[None, None, :, :], img_shape, mode='bicubic').squeeze(dim=0)

y, x = torch.meshgrid(torch.arange(0, img_shape[0], 1), torch.arange(0, img_shape[1], 1))
x, y = x.expand(self.x.shape[0], -1, -1), y.expand(self.y.shape[0], -1, -1)
self.x, self.y = x.to(self.x.device), y.to(self.y.device)


Expand All @@ -49,80 +48,67 @@ class CTF:
https://www.ipol.im/pub/art/2013/20/article.pdf
"""
def __init__(self, optical: OpticalFlow, num_passes: int = 3, scale_factor: float = 1/2):
self.optical = optical
self.dataset = optical.dataset
self.num_passes, self.scale_factor = num_passes, scale_factor
self.alpha, self.num_iter, self.eps = optical.alpha, optical.num_iter, optical.eps

def __len__(self) -> int:
return len(self.optical.dataset)
return len(self.dataset)

def __call__(self) -> Generator:
img_shape = self.optical.dataset.img_shape
img_shape = self.dataset.img_shape
scales = [self.scale_factor ** (self.num_passes - p - 1) for p in range(self.num_passes)]
sizes = [(round(img_shape[0] * scale), round(img_shape[1] * scale)) for scale in scales]

for size in sizes:
# y, x = torch.meshgrid(torch.arange(0, size[0], 1), torch.arange(0, size[1], 1))
# resize = Resize(size, InterpolationMode.BICUBIC)
# resize.apply_to = ['a', 'b'] # apply the resize transform to both images in the pair
# warp = Warp(x, y, u, v)
#
# optical.dataset.img_shape = size
# optical.dataset.transforms = [resize, warp]
#
# for
pass


def ctf_optical(optical: OpticalFlow, num_passes: int = 3, scale_factor: float = 1/2):
"""
runs the optical flow algorithm in a coarse-to-fine pyramidal structure, allowing for larger displacements
https://www.ipol.im/pub/art/2013/20/article.pdf
"""
img_shape = optical.dataset.img_shape
scales = [scale_factor ** (num_passes - p - 1) for p in range(num_passes)]
sizes = [(round(img_shape[0] * scale), round(img_shape[1] * scale)) for scale in scales]
loader = DataLoader(self.dataset)
for a, b in tqdm(loader):
for idx, size in enumerate(sizes):
y, x = torch.meshgrid(torch.arange(0, size[0], 1), torch.arange(0, size[1], 1))
x, y = x.to(self.dataset.device), y.to(self.dataset.device)

u = torch.zeros((len(optical.dataset), *sizes[0]), device=optical.device)
v = torch.zeros((len(optical.dataset), *sizes[0]), device=optical.device)
resize = Resize(size, InterpolationMode.BICUBIC)
resize.apply_to = ['a', 'b'] # apply the resize transform to both images in the pair

for idx, size in enumerate(sizes):
y, x = torch.meshgrid(torch.arange(0, size[0], 1), torch.arange(0, size[1], 1))
x = x.expand(len(optical.dataset), -1, -1).to(optical.device)
y = y.expand(len(optical.dataset), -1, -1).to(optical.device)
transforms = [resize, Warp(x, y, u, v)] if idx != 0 else [resize]
for t in transforms:
a = t(a[None, :, :, :]).squeeze(0) if 'a' in t.apply_to else a
b = t(b[None, :, :, :]).squeeze(0) if 'b' in t.apply_to else b

resize = Resize(size, InterpolationMode.BICUBIC)
resize.apply_to = ['a', 'b'] # apply the resize transform to both images in the pair
warp = Warp(x, y, u, v)
du, dv = optical_flow(a, b, self.alpha, self.num_iter, self.eps)
u, v = (u + du, v + dv) if idx != 0 else (du, dv)

optical.dataset.img_shape = size
optical.dataset.transforms = [resize, warp]
if idx < self.num_passes - 1:
u = interpolate(u[None, None, :, :], sizes[idx + 1], mode='bicubic').squeeze()
v = interpolate(v[None, None, :, :], sizes[idx + 1], mode='bicubic').squeeze()

_, _, du, dv = optical.run()
u, v = u + du, v + dv
u, v = u / self.scale_factor, v / self.scale_factor
yield x, y, u, v

if idx < num_passes - 1:
u = interpolate(u[None, :, :, :], sizes[idx + 1], mode='bicubic').squeeze()
v = interpolate(v[None, :, :, :], sizes[idx + 1], mode='bicubic').squeeze()

u, v = u / scale_factor, v / scale_factor
return x, y, u, v


def match_refine(matching: SIV, optical: OpticalFlow):
class Refine:
"""
runs the matching algorithm and refines the result with optical flow
https://link-springer-com.tudelft.idm.oclc.org/article/10.1007/s00348-019-2820-4?fromPaywallRec=false
"""
img_shape = matching.dataset.img_shape
x, y, u, v = matching.run()
def __init__(self, match: SIV, optical: OpticalFlow):
self.match, self.optical = match, optical
self.dataset = match.dataset
self.alpha, self.num_iter, self.eps = optical.alpha, optical.num_iter, optical.eps

def __len__(self) -> int:
return len(self.dataset)

def __call__(self) -> Generator:
loader = DataLoader(self.dataset)
for a, b in tqdm(loader, total=len(loader)):
x, y, u, v = self.match.run(a, b)

warp = Warp(x, y, u, v)
warp = Warp(x, y, u, v)
a, b = warp(a[None, :, :, :]).squeeze(0), warp(b[None, :, :, :]).squeeze(0)

warp.interpolate_field(img_shape)
x, y, u, v = warp.x, warp.y, warp.u, warp.v
x, y, u, v = warp.x, warp.y, warp.u, warp.v

optical.dataset.transforms = [warp]
du, dv = self.optical.run(a, b)
u, v = u + du, v + dv

_, _, du, dv = optical.run()
u, v = u + du, v + dv
return x, y, u, v
yield x, y, u, v
44 changes: 24 additions & 20 deletions src/SIV_library/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,31 +63,32 @@ def __len__(self) -> int:
return len(self.dataset)

def __call__(self) -> Generator:
img_shape = self.dataset.img_shape
scales = [self.scale_factor ** p for p in range(self.num_passes)]

loader = DataLoader(self.dataset)
for a, b in tqdm(loader, total=len(loader), desc="SAD" if self.mode == 1 else "Correlation"):
for i, scale in enumerate(scales):
window_size, overlap = int(self.window_size * scale), int(self.overlap * scale)
search_area = tuple(int(pad * scale) for pad in self.search_area)
yield self.multipass(a, b)

n_rows, n_cols = get_field_shape(img_shape, window_size, overlap)
xp, yp = get_x_y(img_shape, window_size, overlap)
xp, yp = xp.reshape(n_rows, n_cols).to(self.device), yp.reshape(n_rows, n_cols).to(self.device)
def run(self, a, b):
scales = [self.scale_factor ** p for p in range(self.num_passes)]
for i, scale in enumerate(scales):
window_size, overlap = int(self.window_size * scale), int(self.overlap * scale)
search_area = tuple(int(pad * scale) for pad in self.search_area)

if i == 0:
window = window_array(a, window_size, overlap)
area = window_array(b, window_size, overlap, area=search_area)
else:
shift = WindowShift(img_shape, window_size, overlap, search_area, self.device)
window, area, up, vp = shift.run(a, b, xp, yp, up, vp)
n_rows, n_cols = get_field_shape(self.dataset.img_shape, window_size, overlap)
xp, yp = get_x_y(self.dataset.img_shape, window_size, overlap)
xp, yp = xp.reshape(n_rows, n_cols).to(self.device), yp.reshape(n_rows, n_cols).to(self.device)

match = block_match(window, area, self.mode)
du, dv = correlation_to_displacement(match, search_area, n_rows, n_cols, self.mode)
if i == 0:
window = window_array(a, window_size, overlap)
area = window_array(b, window_size, overlap, area=search_area)
else:
shift = WindowShift(self.dataset.img_shape, window_size, overlap, search_area, self.device)
window, area, up, vp = shift.run(a, b, xp, yp, up, vp)

up, vp = (du, dv) if i == 0 else (up + du, vp + dv)
yield xp, yp, up, -vp
match = block_match(window, area, self.mode)
du, dv = correlation_to_displacement(match, search_area, n_rows, n_cols, self.mode)

up, vp = (du, dv) if i == 0 else (up + du, vp + dv)
return xp, yp, up, -vp


class OpticalFlow:
Expand All @@ -114,5 +115,8 @@ def __call__(self) -> Generator:

loader = DataLoader(self.dataset)
for a, b in tqdm(loader, total=len(loader), desc='Optical flow'):
du, dv = optical_flow(a, b, self.alpha, self.num_iter, self.eps)
du, dv = self.run(a, b)
yield x, y, du, -dv

def run(self, a, b):
return optical_flow(a, b, self.alpha, self.num_iter, self.eps)

0 comments on commit 8bfea3f

Please sign in to comment.