Skip to content

Commit

Permalink
latest changes
Browse files Browse the repository at this point in the history
  • Loading branch information
KarenUllrich committed Jun 4, 2019
1 parent 93c80de commit 7c2568e
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 117 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# pytorch-backprojection
This code accompanies "Differentiable probabilistic models of scientific imaging with the Fourier slice theorem", UAI 2019
#

5 changes: 3 additions & 2 deletions observation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch.nn as nn
from torch import distributions

from operators import base_grid_generator3d, translate, SliceExctractor
from operators import base_grid_generator3d, Translate, SliceExctractor


class ScientificImagingObservationModel(nn.Module):
Expand Down Expand Up @@ -43,7 +43,8 @@ def __init__(self, D=128, batch_size=1, std_noise=0.):
self.base_slice = nn.Parameter(base_grid_generator3d((self.batch_size, 2, self.D, self.D, self.D)),
requires_grad=False)
self.extract_slice = SliceExctractor(limit=self.D, batch_size=self.batch_size)
self.translate = translate(self.batch_size, self.D)
self.translate = Translate(self.batch_size, self.D)


def forward(self, protein_samples, rotation_samples, translation_samples):
r"""Module forward method.
Expand Down
243 changes: 243 additions & 0 deletions operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
""" Geometric operators for the Fourier Domain.
Karen Ullrich, May 2019
"""

import numpy as np

import torch
import torch.nn as nn


def base_grid_generator3d(size):
"""Compute grid for the center slice
"""
N, C, H, W, D = size
x = np.linspace(-H / 2, H / 2 - 1, H)
y = np.linspace(-H / 2, H / 2 - 1, H)
base_grid = np.vstack(np.meshgrid(x, y)).reshape(2, -1).T
base_grid = np.hstack([base_grid, np.zeros((H * W, 1))])
base_grid = np.expand_dims(base_grid.reshape(H, W, 1, 3), 0)
base_grid = base_grid.repeat(N, 0)
return nn.Parameter(torch.Tensor(base_grid), requires_grad=False)


def base_grid_generator2d(size):
"""Compute grid for the center slice
"""
N, C, H, W = size
x = np.linspace(-H / 2, H / 2 - 1, H) / (H / 2)
y = np.linspace(-H / 2, H / 2 - 1, H) / (H / 2)
base_grid = np.vstack(np.meshgrid(x, y)).reshape(2, -1).T
base_grid = np.expand_dims(base_grid.reshape(H, W, 2), 0)
base_grid = base_grid.repeat(N, 0)
return nn.Parameter(torch.Tensor(base_grid), requires_grad=False)


class Translate(nn.Module):
def __init__(self, batch_size, N):
super(Translate, self).__init__()

self.image_base_grid = base_grid_generator2d((batch_size, 2, N, N))

self.realidx = nn.Parameter(
torch.LongTensor([0]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).repeat(batch_size, 1, N, N),
requires_grad=False)
self.imagidx = nn.Parameter(
torch.LongTensor([1]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).repeat(batch_size, 1, N, N),
requires_grad=False)

def __call__(self, projection, x):
device = projection.device
preal = torch.gather(projection, dim=1, index=self.realidx)
pimag = torch.gather(projection, dim=1, index=self.imagidx)

x = x.unsqueeze(1).unsqueeze(1)
kx = torch.sum(self.image_base_grid * x, dim=-1).unsqueeze(1)

coskx = torch.cos(2 * np.pi * kx).to(device)
sinkx = torch.sin(2 * np.pi * kx).to(device)

outreal = coskx * preal + sinkx * pimag
outimag = coskx * pimag - sinkx * preal

return torch.cat([outreal, outimag], dim=1)


class SliceExctractor(nn.Module):
def __init__(self, limit, batch_size):
super(SliceExctractor, self).__init__()

self.limit = limit
self.batch_size = batch_size
batch_idx = torch.Tensor(np.arange(self.batch_size)).repeat(self.limit ** 2, 1, 2, 1).permute(0, 3, 2, 1).long()
c0_idx = torch.Tensor(np.zeros(self.batch_size)).repeat(self.limit ** 2, 1, 1, 1).permute(0, 3, 2, 1).long()
c1_idx = torch.Tensor(np.ones(self.batch_size)).repeat(self.limit ** 2, 1, 1, 1).permute(0, 3, 2, 1).long()
self.idxer = nn.Parameter(torch.cat([batch_idx, torch.cat([c0_idx, c1_idx], dim=-2)], dim=-1),
requires_grad=False)

def save_get(self, volume, idx, boundary_mode="periodic"):

if boundary_mode == "periodic":
idx = (idx % (self.limit - 1))
elif boundary_mode == "continious":
idx = torch.clamp(idx, 0, self.limit - 1)

idx = idx.permute(1, 2, 0, 3).view(self.limit * self.limit, self.batch_size, 3)
idx = idx.unsqueeze(-2).repeat(1, 1, 2, 1)
idx = torch.cat([self.idxer, idx.long()], dim=-1).view(self.limit * self.limit * self.batch_size * 2, 5)

return volume[torch.unbind(idx, dim=-1)].view(self.limit, self.limit, self.batch_size, 2).permute(2, 3, 0, 1)

def forward(self, volume, grid):
ix = grid[:, :, :, 0]
iy = grid[:, :, :, 1]
iz = grid[:, :, :, 2]

px_0 = torch.floor(ix)
py_0 = torch.floor(iy)
pz_0 = torch.floor(iz)
px_1 = torch.ceil(ix)
py_1 = torch.ceil(iy)
pz_1 = torch.ceil(iz)

dx = (ix - px_0).unsqueeze(1)
dy = (iy - py_0).unsqueeze(1)
dz = (iz - pz_0).unsqueeze(1)

c_000 = self.save_get(volume, idx=torch.stack([py_0, px_0, pz_0], dim=-1))
c_100 = self.save_get(volume, idx=torch.stack([py_0, px_1, pz_0], dim=-1))
c_00 = c_000 * (1. - dx) + c_100 * (dx)
del c_000, c_100

c_010 = self.save_get(volume, idx=torch.stack([py_1, px_0, pz_0], dim=-1))
c_110 = self.save_get(volume, idx=torch.stack([py_1, px_1, pz_0], dim=-1))
c_10 = c_010 * (1. - dx) + c_110 * (dx)
del c_010, c_110

c_0 = c_00 * (1. - dy) + c_10 * (dy)
del c_00, c_10

c_001 = self.save_get(volume, idx=torch.stack([py_0, px_0, pz_1], dim=-1))
c_101 = self.save_get(volume, idx=torch.stack([py_0, px_1, pz_1], dim=-1))
c_01 = c_001 * (1. - dx) + c_101 * (dx)
del c_001, c_101

c_011 = self.save_get(volume, idx=torch.stack([py_1, px_0, pz_1], dim=-1))
c_111 = self.save_get(volume, idx=torch.stack([py_1, px_1, pz_1], dim=-1))
c_11 = c_011 * (1. - dx) + c_111 * (dx)
del c_011, c_111

c_1 = c_01 * (1. - dy) + c_11 * (dy)
del c_11, c_01

return c_0 * (1. - dz) + c_1 * (dz)


# compute Euler Angles based rotation matrix

component_1_x = torch.FloatTensor([[1, 0, 0], [0, 0, 0], [0, 0, 0]])
component_cos_x = torch.FloatTensor([[0, 0, 0, 0, 1, 0, 0, 0, 1]])
component_sin_x = torch.FloatTensor([[0, 0, 0, 0, 0, -1, 0, 1, 0]])

component_1_z = torch.FloatTensor([[0, 0, 0], [0, 0, 0], [0, 0, 1]])
component_cos_z = torch.FloatTensor([[1, 0, 0, 0, 1, 0, 0, 0, 0]])
component_sin_z = torch.FloatTensor([[0, -1, 0, 1, 0, 0, 0, 0, 0]])

component_1_y = torch.FloatTensor([[0, 0, 0], [0, 1, 0], [0, 0, 0]])
component_cos_y = torch.FloatTensor([[1, 0, 0, 0, 0, 0, 0, 0, 1]])
component_sin_y = torch.FloatTensor([[0, 0, 1, 0, 0, 0, -1, 0, 0]])


def cosinefy_x(x, device):
batch_size = len(x)
y = torch.mm(x, component_cos_x.to(device))
y = y.resize(batch_size * 3, 3)
return y


def sinefy_x(x, device):
batch_size = len(x)
y = torch.mm(x, component_sin_x.to(device))
y = y.resize(batch_size * 3, 3)
return y


def cosinefy_z(x, device):
batch_size = len(x)
y = torch.mm(x, component_cos_z.to(device))
y = y.resize(batch_size * 3, 3)
return y


def sinefy_z(x, device):
batch_size = len(x)
y = torch.mm(x, component_sin_z.to(device))
y = y.resize(batch_size * 3, 3)
return y


def cosinefy_y(x, device):
batch_size = len(x)
y = torch.mm(x, component_cos_y.to(device))
y = y.resize(batch_size * 3, 3)
return y


def sinefy_y(x, device):
batch_size = len(x)
y = torch.mm(x, component_sin_y.to(device))
y = y.resize(batch_size * 3, 3)
return y


def R_x(g, device):
"""
Compute the Euler Angles R_x
"""
cos_angles = cosinefy_x(torch.cos(g), device)
sin_angles = sinefy_x(torch.sin(g), device)

out = sin_angles + cos_angles
return out.resize(len(g), 3, 3) + component_1_x.to(device)


def R_z(g, device):
"""
Compute the Euler Angles R_z
"""
cos_angles = cosinefy_z(torch.cos(g), device)
sin_angles = sinefy_z(torch.sin(g), device)

out = sin_angles + cos_angles
return out.resize(len(g), 3, 3) + component_1_z.to(device)


def R_y(g, device):
"""
Compute the Euler Angles R_y
"""
cos_angles = cosinefy_y(torch.cos(g), device)
sin_angles = sinefy_y(torch.sin(g), device)

out = sin_angles + cos_angles
return out.resize(len(g), 3, 3) + component_1_y.to(device)


def rotmat3D_EA(g):
"""
Generates a rotation matrix from Z-Y-Z Euler angles. This rotation matrix
maps from image coordinates (x,y,0) to view coordinates.
"""
device = g.device
R_phi = R_z(g[:, 0].view(-1, 1), device)
R_theta = R_y(g[:, 1].view(-1, 1), device)
R_psi = R_z(g[:, 2].view(-1, 1), device)

R = torch.bmm(R_phi, R_theta)
R = torch.bmm(R, R_psi)

return R.resize(len(g), 3, 3)
Loading

0 comments on commit 7c2568e

Please sign in to comment.