-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathcoupled_tv_huber_crlb_2.py
143 lines (105 loc) · 4.39 KB
/
coupled_tv_huber_crlb_2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
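"""This formulation solves the model

    min_x ||Ax - b||_W^2 + 2.5 * Huber(|Lam grad(x)|_F)

where A is the ray transform (applied to each of the two channels),
|| . ||_W is the weighted l2 norm (W is built from the inverse square root
of the CRLB), grad is the gradient (applied to each channel), Lam is a
linear mixing of the two channels, | . |_F is the pointwise Frobenius norm
of the resulting gradient field, and Huber is the Moreau envelope
(sigma = 0.01) of the l1 norm, i.e. a smoothed l1 norm.
"""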
import odl
import numpy as np
import scipy.linalg as spl
from util import cov_matrix, load_data, load_fan_data, inverse_sqrt_matrix
# NOTE: The triple-quoted string below is a bare expression statement, so the
# two classes inside it (MyOperator, MyOperatorTransp) are never defined; the
# script builds its regularizer from MyOperatorTrace instead. They appear to be
# earlier, disabled variants of the pointwise norm operator kept for reference.
# NOTE(review): As written, MyOperator._call adds x[k][i] * x[k][j] over all
# (i, j), i.e. (sum_i x[k][i]) ** 2 rather than a sum of squares, and
# MyOperatorTransp._call adds each x[j][k] ** 2 once per (unused) index i.
# Verify both before re-enabling either class.
"""
class MyOperator(odl.Operator):
def _call(self, x):
result = self.range.zero()
for k in range(self.domain.size):
for i in range(self.domain[0].size):
for j in range(self.domain[0].size):
result += x[k][i] * x[k][j]
return result.ufunc.sqrt()
def derivative(self, x):
result = sum(x, x[0].space.zero())
result /= self(x)
return odl.PointwiseSum(self.domain[0]) * odl.PointwiseSum(self.domain) * self.domain.element([result] * self.domain.size)
class MyOperatorTransp(odl.Operator):
def _call(self, x):
result = self.range.zero()
for k in range(self.domain[0].size):
for i in range(self.domain.size):
for j in range(self.domain.size):
result += x[j][k] * x[j][k]
return result.ufunc.sqrt()
def derivative(self, x):
result = self.domain.zero()
sx = self(x)
for i in range(self.domain[0].size):
result[0] += x[0][i] / sx
for i in range(1, self.domain.size):
result[i].assign(result[0])
return odl.PointwiseSum(self.domain[0]) * odl.PointwiseSum(self.domain) * result
"""
class MyOperatorTrace(odl.Operator):
    """Pointwise Frobenius norm of a field of small matrices.

    The domain is expected to be a nested product space (such as the range
    of a per-channel gradient), so that ``x[i][j]`` is an element of
    ``self.range``. The operator maps ``x`` to the function
    ``sqrt(sum_{i, j} x[i][j] ** 2)``, evaluated point by point.
    """

    def _call(self, x):
        """Return ``sqrt(sum_{i, j} x[i][j] ** 2)`` as an element of the range."""
        result = self.range.zero()
        for i in range(self.domain.size):
            for j in range(self.domain[0].size):
                result += x[i][j] * x[i][j]
        return result.ufunc.sqrt()

    def derivative(self, x):
        """Return the derivative of the pointwise norm at ``x``.

        For ``f(x) = sqrt(sum_{i, j} x_ij ** 2)`` the derivative is the linear
        map ``h -> sum_{i, j} (x_ij / f(x)) * h_ij`` (pointwise), built here as
        a pointwise sum composed with multiplication by ``x / f(x)``.

        Where ``f(x) == 0`` (all components zero) the norm is not
        differentiable; there the weights are set to 0, which is a valid
        subgradient, instead of producing ``0 / 0 = nan``.
        """
        norm = self(x).asarray()
        # Dividing by inf where the norm vanishes yields 0 instead of NaN; a
        # single NaN here would otherwise spread through grad^* into the whole
        # BFGS gradient. Elsewhere the division is exactly x / f(x) as before.
        denom = self.range.element(np.where(norm > 0, norm, np.inf))
        result = x.copy()
        for xi in result:
            for xii in xi:
                xii /= denom
        return (odl.PointwiseSum(self.domain[0]) *
                odl.PointwiseSum(self.domain) * result)
class LamOp(odl.Operator):
    """Linear operator that mixes the channels of a nested product-space element.

    For an input ``x`` whose channels ``x[k]`` are themselves product-space
    elements with components ``x[k][j]``, output channel ``i`` is
    ``sum_k arr[k][i] * x[k]``. In other words, ``arr`` is applied transposed
    across channels, component by component. Domain and range are the same
    space.
    """

    def __init__(self, space, arr):
        """Build the operator on ``space``, used as both domain and range.

        ``arr`` is indexed as ``arr[k][i]`` with ``k, i < space.size``. It
        must also provide ``.T`` for :attr:`adjoint` to work.
        """
        self.arr = arr
        odl.Operator.__init__(self, space, space, linear=True)

    def _call(self, x):
        """Return ``y`` with ``y[i][j] = sum_k arr[k][i] * x[k][j]``."""
        out = self.range.zero()
        n_channels = self.domain.size
        n_components = self.domain[0].size
        for i in range(n_channels):
            for j in range(n_components):
                # Accumulate over k in increasing order for each (i, j).
                for k in range(n_channels):
                    out[i][j] += self.arr[k][i] * x[k][j]
        return out

    @property
    def adjoint(self):
        """Adjoint operator: the same channel mixing with ``arr`` transposed."""
        return LamOp(self.domain, self.arr.T)
# --- Data and forward model --------------------------------------------------
# `load_fan_data` comes from the local `util` module. Below, `crlb` is indexed
# as crlb[i, j, ...] with i, j in {0, 1} (presumably a 2x2 CRLB matrix per
# data point -- confirm in util).
data, geometry, crlb = load_fan_data(return_crlb=True)
# 400 x 400 pixel reconstruction space on [-129, 129]^2.
space = odl.uniform_discr([-129, -129], [129, 129], [400, 400])
# Uses the ASTRA CUDA backend.
ray_trafo = odl.tomo.RayTransform(space, geometry, impl='astra_cuda')
# Two-channel forward operator: the same ray transform applied to each channel.
A = odl.DiagonalOperator(ray_trafo, 2)
grad = odl.Gradient(space)
# Per-channel gradient; its range (a product of two gradient fields) is the
# domain of both LamOp and MyOperatorTrace below.
grad_vec = odl.DiagonalOperator(grad, 2)

# --- Model parameters ----------------------------------------------------------
# If False, the off-diagonal CRLB entries are zeroed, so W (and hence the data
# term) presumably no longer couples the two channels.
cross_terms = True
# Off-diagonal coupling strength of the channel matrix used in the regularizer.
c = 0.5
if not cross_terms:
    # Modifies the loaded `crlb` array in place.
    crlb[1, 0, ...] = 0
    crlb[0, 1, ...] = 0

# --- Data discrepancy ||W (A x - data)||_2^2 -----------------------------------
# W multiplies the channels by the entries of inverse_sqrt_matrix(crlb)
# (presumably a pointwise inverse matrix square root -- see util).
mat_sqrt_inv = inverse_sqrt_matrix(crlb)
re = ray_trafo.range.element
W = odl.ProductSpaceOperator([[odl.MultiplyOperator(re(mat_sqrt_inv[0, 0])), odl.MultiplyOperator(re(mat_sqrt_inv[0, 1]))],
                              [odl.MultiplyOperator(re(mat_sqrt_inv[1, 0])), odl.MultiplyOperator(re(mat_sqrt_inv[1, 1]))]])
op = W * A
rhs = W(data)
data_discrepancy = odl.solvers.L2NormSquared(A.range).translated(rhs)

# --- Regularizer: 2.5 * Huber(|Lam grad x|_F) ----------------------------------
# The Moreau envelope of the l1 norm is the Huber function (a smoothed l1 norm).
l1_norm = odl.solvers.L1Norm(space)
huber = 2.5 * odl.solvers.MoreauEnvelope(l1_norm, sigma=0.01)
# Pointwise Frobenius norm of the (channel x direction) gradient field.
my_op = MyOperatorTrace(domain=grad_vec.range, range=space, linear=False)
spc_cov_matrix = [[1, -c],
                  [-c, 1]]
spc_cov_matrix_inv_sqrt = inverse_sqrt_matrix(spc_cov_matrix)
# Mixes the two channels' gradients using spc_cov_matrix_inv_sqrt.
Lam = LamOp(grad_vec.range, arr=spc_cov_matrix_inv_sqrt)

# Full differentiable objective, minimized with BFGS below.
func = data_discrepancy * op + huber * my_op * Lam * grad_vec
# --- Initial guess: FBP (Hann filter, frequency scaling 0.7) of each channel ---
fbp_op = odl.tomo.fbp_op(ray_trafo,
filter_type='Hann', frequency_scaling=0.7)
x = A.domain.element([fbp_op(data[0]), fbp_op(data[1])])
x.show(clim=[0.9, 1.1])
# Every 5 iterations, show the iterate with default, [-0.1, 0.1] and
# [0.9, 1.1] color limits; also print the iteration number.
callback = (odl.solvers.CallbackShow(step=5) &
odl.solvers.CallbackShow(step=5, clim=[-0.1, 0.1]) &
odl.solvers.CallbackShow(step=5, clim=[0.9, 1.1]) &
odl.solvers.CallbackPrintIteration())
# Initial inverse-Hessian guess for BFGS: identity scaled by 1 / ||op||^2,
# with ||op|| estimated by the power method.
# NOTE(review): the Hessian of the data term ||op x - rhs||^2 is 2 op^* op, so
# 1 / (2 * opnorm ** 2) would match its largest eigenvalue; confirm that the
# factor of 2 was left out on purpose.
opnorm = odl.power_method_opnorm(op)
hessinv_estimate = odl.ScalingOperator(func.domain, 1 / opnorm ** 2)
# Minimize `func` with limited-memory BFGS (10 stored correction pairs) and a
# fixed step length of 1.0; `x` is updated in place.
odl.solvers.bfgs_method(func, x, line_search=1.0, maxiter=1000, num_store=10,
callback=callback, hessinv_estimate=hessinv_estimate)