-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathjanky_stuff.py
77 lines (66 loc) · 2.3 KB
/
janky_stuff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
class JankySampler:
    '''
    Draw samples without replacement from `arr`, reshuffling whenever the
    current pass cannot satisfy a request.

    `arr` is indexed with an integer index array (``arr[indices]``), so it
    must support NumPy-style fancy indexing along its first axis --
    presumably an ndarray; plain Python lists do not support this.

    Parameters
    ----------
    arr : array-like indexed along axis 0.
    seed : optional seed for this sampler's private ``np.random.RandomState``.
    '''

    def __init__(self, arr, seed=None):
        self.arr = arr
        self.nprng = np.random.RandomState(seed)
        self.reset()

    def reset(self):
        '''Start a new pass: draw a fresh permutation and rewind the cursor.'''
        self.drawn = 0
        self.idx = self.nprng.permutation(len(self.arr))

    def draw(self, n):
        '''
        Draw `n` samples without replacement from the current pass.

        If fewer than `n` items remain in the current pass, the leftover
        items are discarded and a new permutation is started. A single draw
        therefore never contains duplicates, but items at the tail of a pass
        may be skipped.

        Parameters
        ----------
        n : int
            Number of samples; must satisfy ``0 <= n <= len(arr)``.

        Returns
        -------
        ``arr[indices]`` for the next `n` indices of the current permutation.

        Raises
        ------
        ValueError
            If `n` is negative or larger than ``len(arr)``.
        '''
        if n < 0:
            # A negative n would slice backwards and move the cursor back,
            # so later draws in this pass would repeat items.
            raise ValueError("Cannot draw a negative number of samples: %r" % (n,))
        if n > len(self.arr):
            raise ValueError("It looks like you tried to draw more than there are in the list")
        if self.drawn + n > len(self.arr):
            # Not enough left in this pass: start over instead of mixing the
            # tail of this pass with the head of the next one.
            self.reset()
        end = self.drawn + n
        retval = self.arr[self.idx[self.drawn:end]]
        self.drawn = end
        return retval
class JankySubsampler:
    '''
    To be used with iter_data_mpi. This class reports it has the same number
    of examples as `arrays[0]`, but returns a mixed slice of examples from all
    `arrays`.

    For each request of `n` examples, the source of every example is drawn
    independently according to `pmf`. That many examples are then drawn
    from each source's `JankySampler` (without replacement within a pass).
    The result is grouped by source: all examples from `arrays[0]` come
    first, then those from `arrays[1]`, and so on. A source must have at
    least as many rows as the number of examples chosen from it in one
    request, otherwise `JankySampler.draw` raises ValueError.

    Every internal RNG (the source chooser and each per-source sampler) is
    seeded with the same `seed`. Two subsamplers built with the same seed,
    the same `pmf` and sources of matching lengths, and queried with the
    same sequence of request sizes, therefore make identical choices. This
    lets an auxiliary subsampler (e.g. per-example labels) stay aligned with
    a data subsampler.

    NOTE(review): a consequence of the shared seed is that sources of equal
    length also get identical permutations as each other.

    Parameters
    ----------
    arrays : sequence of arrays sharing their trailing shape (``shape[1:]``);
        each is indexed along axis 0.
    pmf : sequence of probabilities, one per entry of `arrays`, passed as
        ``p`` to ``RandomState.choice`` (so it must sum to 1).
    seed : optional seed shared by all internal RNGs.
    '''

    def __init__(self, arrays, pmf, seed=None):
        assert len(pmf) == len(arrays), "pmf must have one probability per array"
        self.pmf = pmf
        self.arrays = arrays
        self.samplers = [JankySampler(arr, seed=seed) for arr in arrays]
        self.idxs = np.arange(len(self.pmf))
        self.nprng = np.random.RandomState(seed)
        for arr in arrays[1:]:
            assert arrays[0].shape[1:] == arr.shape[1:], \
                "all arrays must share their trailing shape"

    @property
    def shape(self):
        '''Shape of `arrays[0]`, which is the size this subsampler reports.'''
        return self.arrays[0].shape

    def __len__(self):
        '''Number of examples reported, i.e. ``len(arrays[0])``.'''
        return len(self.arrays[0])

    def __getitem__(self, key):
        '''
        Return `n` freshly sampled examples, where `n` is the number of
        examples `key` selects.

        Only the size of `key` is used, never its values. An index sequence
        contributes ``len(key)``; a slice contributes the number of rows it
        would select from ``arrays[0]``.
        '''
        if isinstance(key, slice):
            n = len(range(*key.indices(len(self.arrays[0]))))
        else:
            n = len(key)
        sources = self.nprng.choice(self.idxs, size=n, p=self.pmf)
        # Number of examples to take from each source, in source order.
        counts = np.bincount(sources, minlength=len(self.idxs))
        ret = np.concatenate(
            [sampler.draw(count)
             for sampler, count in zip(self.samplers, counts)],
            axis=0)
        return ret
if __name__ == "__main__":
    # Demo 1: one sampler over 9 items, drawing 8 at a time. Each draw after
    # the first cannot be satisfied by the leftover item, so the sampler
    # reshuffles before drawing.
    demo_sampler = JankySampler(np.arange(9), seed=42)
    for _ in range(4):
        print(demo_sampler.draw(8))

    # Demo 2: mix two sources 50/50. The label subsampler uses the same
    # seed, pmf and per-source lengths (4 and 8), so its choices line up
    # with the data subsampler's: 1 marks a row from `first`, 0 a row
    # from `second`.
    mix = [0.5, 0.5]
    first = 1 + np.arange(4 * 2).reshape(4, 2)
    second = -np.arange(8 * 2).reshape(8, 2)
    data_subsampler = JankySubsampler((first, second), mix, seed=42)
    label_subsampler = JankySubsampler(
        (np.ones((4,)), np.zeros((8,))), mix, seed=42)
    # Only len() of the index array is used: each batch has 4 examples.
    batch_indices = np.arange(4)
    for _ in range(10):
        for sub in (data_subsampler, label_subsampler):
            print(sub[batch_indices])