diff --git a/quantecon/random/utilities.py b/quantecon/random/utilities.py index c2d68d23d..3bc12b87e 100644 --- a/quantecon/random/utilities.py +++ b/quantecon/random/utilities.py @@ -4,7 +4,7 @@ """ import numpy as np -from numba import jit, guvectorize, generated_jit, types +from numba import guvectorize, generated_jit, types from ..util import check_random_state, searchsorted @@ -98,7 +98,6 @@ def _probvec(r, out): )(_probvec) -@jit def sample_without_replacement(n, k, num_trials=None, random_state=None): """ Randomly choose k integers without replacement from 0, ..., n-1. @@ -144,26 +143,30 @@ def sample_without_replacement(n, k, num_trials=None, random_state=None): if k > n: raise ValueError('k must be smaller than or equal to n') - m = 1 if num_trials is None else num_trials + size = k if num_trials is None else (num_trials, k) random_state = check_random_state(random_state) - r = random_state.random_sample(size=(m, k)) + r = random_state.random_sample(size=size) + result = _sample_without_replacement(n, r) + + return result + + +@guvectorize(['(i8, f8[:], i8[:])'], '(),(k)->(k)', nopython=True, cache=True) +def _sample_without_replacement(n, r, out): + """ + Main body of `sample_without_replacement`. To be complied as a ufunc + by guvectorize of Numba. + + """ + k = r.shape[0] # Logic taken from random.sample in the standard library - result = np.empty((m, k), dtype=int) - pool = np.empty(n, dtype=int) - for i in range(m): - for j in range(n): - pool[j] = j - for j in range(k): - idx = int(np.floor(r[i, j] * (n-j))) # np.floor returns a float - result[i, j] = pool[idx] - pool[idx] = pool[n-j-1] - - if num_trials is None: - return result[0] - else: - return result + pool = np.arange(n) + for j in range(k): + idx = int(np.floor(r[j] * (n-j))) # np.floor returns a float + out[j] = pool[idx] + pool[idx] = pool[n-j-1] @generated_jit(nopython=True)