Skip to content

Commit

Permalink
FIX: random.draw: Add =None
Browse files Browse the repository at this point in the history
Fix #704
  • Loading branch information
oyamad committed May 23, 2023
1 parent 535a82f commit beec748
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
11 changes: 11 additions & 0 deletions quantecon/random/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,14 @@ def test_lln(self):
pmf_computed = hist * np.diff(bin_edges)
atol = 1e-2
assert_allclose(pmf_computed, self.pmf, atol=atol)


@njit
def draw_jitted_w_o_size(n):
cdf = np.linspace(1/n, 1, n)
return draw(cdf)


def test_draw_jitted_w_o_size():
n = 3
assert_(draw_jitted_w_o_size(n) in range(n))
6 changes: 3 additions & 3 deletions quantecon/random/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,16 @@ def draw(cdf, size=None):

# Overload for the `draw` function
@overload(draw)
def ol_draw(cdf, size):
def ol_draw(cdf, size=None):
if isinstance(size, types.Integer):
def draw_impl(cdf, size):
def draw_impl(cdf, size=None):
rs = np.random.random(size)
out = np.empty(size, dtype=np.int_)
for i in range(size):
out[i] = searchsorted(cdf, rs[i])
return out
else:
def draw_impl(cdf, size):
def draw_impl(cdf, size=None):
r = np.random.random()
return searchsorted(cdf, r)
return draw_impl

0 comments on commit beec748

Please sign in to comment.