Skip to content

Commit

Permalink
Fix sample_without_replacement using guvectorize
Browse files Browse the repository at this point in the history
  • Loading branch information
oyamad committed Mar 16, 2019
1 parent 2402749 commit d3c8b3c
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions quantecon/random/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import numpy as np
from numba import jit, guvectorize, generated_jit, types
from numba import guvectorize, generated_jit, types

from ..util import check_random_state, searchsorted

Expand Down Expand Up @@ -98,7 +98,6 @@ def _probvec(r, out):
)(_probvec)


@jit
def sample_without_replacement(n, k, num_trials=None, random_state=None):
"""
Randomly choose k integers without replacement from 0, ..., n-1.
Expand Down Expand Up @@ -144,26 +143,30 @@ def sample_without_replacement(n, k, num_trials=None, random_state=None):
if k > n:
raise ValueError('k must be smaller than or equal to n')

m = 1 if num_trials is None else num_trials
size = k if num_trials is None else (num_trials, k)

random_state = check_random_state(random_state)
r = random_state.random_sample(size=(m, k))
r = random_state.random_sample(size=size)
result = _sample_without_replacement(n, r)

return result


@guvectorize(['(i8, f8[:], i8[:])'], '(),(k)->(k)', nopython=True, cache=True)
def _sample_without_replacement(n, r, out):
"""
Main body of `sample_without_replacement`. To be complied as a ufunc
by guvectorize of Numba.
"""
k = r.shape[0]

# Logic taken from random.sample in the standard library
result = np.empty((m, k), dtype=int)
pool = np.empty(n, dtype=int)
for i in range(m):
for j in range(n):
pool[j] = j
for j in range(k):
idx = int(np.floor(r[i, j] * (n-j))) # np.floor returns a float
result[i, j] = pool[idx]
pool[idx] = pool[n-j-1]

if num_trials is None:
return result[0]
else:
return result
pool = np.arange(n)
for j in range(k):
idx = int(np.floor(r[j] * (n-j))) # np.floor returns a float
out[j] = pool[idx]
pool[idx] = pool[n-j-1]


@generated_jit(nopython=True)
Expand Down

0 comments on commit d3c8b3c

Please sign in to comment.