forked from dstndstn/tractor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathctest.py
82 lines (62 loc) · 2.07 KB
/
ctest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Smoke test for the tractor's Ceres optimizer.

Renders a synthetic 10x10 image containing a single point source, perturbs
the source's position and flux, runs ``CeresOptimizer`` and prints the
thawed parameters and log-probability before and after the fit.
"""
from __future__ import print_function

import numpy as np

if __name__ == '__main__':
    # Pick the non-interactive backend before pyplot is imported, so plots
    # can be written without a display when this file is run directly.
    import matplotlib
    matplotlib.use('Agg')

# matplotlib.pyplot rather than the discouraged ``pylab`` namespace; only
# clf / subplot / imshow are needed by this file.
import matplotlib.pyplot as plt

# Explicit imports (instead of ``import *``) so the origin of every name used
# below is visible.
from tractor import (Tractor, Image, NCircularGaussianPSF, NullWCS,
                     LinearPhotoCal, ConstantSky, PointSource, PixPos, Flux)
from astrometry.util.plotutils import PlotSequence

# Plot sequence used by MyTractor.setParams; 'ctest' is presumably the output
# filename prefix -- confirm against astrometry.util.plotutils.
ps = PlotSequence('ctest')
class MyTractor(Tractor):
    """Tractor subclass that saves a diagnostic figure on every parameter update.

    Each ``setParams`` call prints the incoming parameter vector, applies it,
    and then, for image 0, draws the model (top-left panel), the data
    (top-right panel) and the residual ``(model - data)`` scaled by the
    image's inverse error (bottom-right panel), saving the figure through the
    module-level ``ps`` plot sequence.

    NOTE(review): the script in this file constructs a plain ``Tractor``, so
    this subclass is not exercised unless it is swapped in there.
    """

    def setParams(self, p):
        print('MyTractor.setParams', p)
        super(MyTractor, self).setParams(p)

        first = self.getImage(0)
        observed = first.getImage()
        model = self.getModelImage(0)

        # One shared color stretch for the model and data panels so they can
        # be compared by eye.
        hi = max(observed.max(), model.max())
        lo = min(observed.min(), model.min())
        shared = dict(interpolation='nearest', origin='lower', vmin=lo, vmax=hi)

        plt.clf()
        for panel, pixels in ((1, model), (2, observed)):
            plt.subplot(2, 2, panel)
            plt.imshow(pixels, **shared)

        # Residual scaled by inverse error (a chi image if that is 1/sigma),
        # shown on a diverging map with the color range fixed to [-5, 5].
        plt.subplot(2, 2, 4)
        chi = (model - observed) * first.getInvError()
        plt.imshow(chi, interpolation='nearest', origin='lower',
                   vmin=-5, vmax=5, cmap='RdBu')
        ps.savefig()
# --- Build a synthetic single-source image -----------------------------------
H, W = 10, 10
img = np.zeros((H, W), np.float32)
sig1 = 0.1  # per-pixel noise sigma; the inverse variance below is 1/sig1**2
tim = Image(data=img, invvar=np.zeros_like(img) + 1. / sig1**2,
            psf=NCircularGaussianPSF([1.], [1.]),
            wcs=NullWCS(), photocal=LinearPhotoCal(1.),
            sky=ConstantSky(0.),
            name='Test', domask=False)

# Render the "true" source, then replace the image data with model + noise.
src = PointSource(PixPos(W/2, H/2), Flux(100.))
tractor = Tractor([tim], [src])
mod = tractor.getModelImage(0)
# Seeded local generator so the noise realization (and therefore the fit
# results printed below) is reproducible from run to run.
rng = np.random.RandomState(42)
tim.data = mod + rng.normal(scale=sig1, size=mod.shape)

# Move the source away from the truth so the optimizer has work to do.
src.brightness = Flux(10.)
src.pos = PixPos(W/2 - 1, H/2 - 1)

print('All params:')
tractor.printThawedParams()
# Hold the PSF and photometric calibration fixed; everything still thawed is
# fit by the optimizer.
tim.freezeParams('psf', 'photocal')
print('Thawed param:')
tractor.printThawedParams()
print('Params:', tractor.getParams())
lnp0 = tractor.getLogProb()
print('Logprob:', lnp0)

# --- Fit with the Ceres optimizer --------------------------------------------
print('Calling ceres optimizer...')
# Deferred import: presumably the Ceres bindings are an optional dependency,
# so they are only required at this point.
from tractor.ceres_optimizer import CeresOptimizer
tractor.optimizer = CeresOptimizer()
tractor.optimize()
print('Ceres opt finished')
print('Params:', tractor.getParams())
lnp1 = tractor.getLogProb()
print('Logprob:', lnp1)
print('Thawed param:')
tractor.printThawedParams()