Skip to content

Commit

Permalink
ENH: Allow tuple of Warnings in tm.assert_produces_warning()
Browse files Browse the repository at this point in the history
  • Loading branch information
kernc committed Feb 23, 2018
1 parent cefb3c2 commit 5b7d0f9
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
9 changes: 6 additions & 3 deletions pandas/tests/sparse/frame/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,23 +449,26 @@ def test_set_value(self):

# ok, as the index gets converted to object
frame = self.frame.copy()
with tm.assert_produces_warning(FutureWarning,
with tm.assert_produces_warning((FutureWarning,
PerformanceWarning),
check_stacklevel=False):
res = frame.set_value('foobar', 'B', 1.5)
assert res.index.dtype == 'object'

res = self.frame
res.index = res.index.astype(object)

with tm.assert_produces_warning(FutureWarning,
with tm.assert_produces_warning((FutureWarning,
PerformanceWarning),
check_stacklevel=False):
res = self.frame.set_value('foobar', 'B', 1.5)
assert res.index[-1] == 'foobar'
with tm.assert_produces_warning(FutureWarning,
check_stacklevel=False):
assert res.get_value('foobar', 'B') == 1.5

with tm.assert_produces_warning(FutureWarning,
with tm.assert_produces_warning((FutureWarning,
PerformanceWarning),
check_stacklevel=False):
res2 = res.set_value('foobar', 'qux', 1.5)
tm.assert_index_equal(res2.columns,
Expand Down
6 changes: 4 additions & 2 deletions pandas/tests/sparse/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,12 +477,14 @@ def test_get_get_value(self):
def test_set_value(self):

idx = self.btseries.index[7]
with tm.assert_produces_warning(FutureWarning,
with tm.assert_produces_warning((FutureWarning,
PerformanceWarning),
check_stacklevel=False):
self.btseries.set_value(idx, 0)
assert self.btseries[idx] == 0

with tm.assert_produces_warning(FutureWarning,
with tm.assert_produces_warning((FutureWarning,
PerformanceWarning),
check_stacklevel=False):
self.iseries.set_value('foobar', 0)
assert self.iseries.index[-1] == 'foobar'
Expand Down
23 changes: 16 additions & 7 deletions pandas/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2392,9 +2392,10 @@ def assert_produces_warning(expected_warning=Warning, filter_level="always",
Parameters
----------
expected_warning : {Warning, False, None}, default Warning
expected_warning : {Warning, tuple, False, None}, default Warning
The type of Exception raised. ``exception.Warning`` is the base
class for all warnings. To check that no warning is returned,
class for all warnings. To expect mltiple warnings, pass a tuple
of warning classes. To check that no warning is returned,
specify ``False`` or ``None``.
filter_level : str, default "always"
Specifies whether warnings are ignored, displayed, or turned
Expand Down Expand Up @@ -2441,6 +2442,11 @@ class for all warnings. To check that no warning is returned,
..warn:: This is *not* thread-safe.
"""
# Ensure a tuple
if (isinstance(expected_warning, type) and
issubclass(expected_warning, Warning)):
expected_warning = (expected_warning,)

with warnings.catch_warnings(record=True) as w:

if clear is not None:
Expand All @@ -2455,15 +2461,17 @@ class for all warnings. To check that no warning is returned,
except Exception:
pass

saw_warning = False
saw_warning = set()
warnings.simplefilter(filter_level)
yield w
extra_warnings = []

for actual_warning in w:
if (expected_warning and issubclass(actual_warning.category,
expected_warning)):
saw_warning = True
saw_warning.add(
next(w for w in expected_warning
if issubclass(actual_warning.category, w)))

if check_stacklevel and issubclass(actual_warning.category,
(FutureWarning,
Expand All @@ -2480,9 +2488,10 @@ class for all warnings. To check that no warning is returned,
else:
extra_warnings.append(actual_warning.category.__name__)
if expected_warning:
msg = "Did not see expected warning of class {name!r}.".format(
name=expected_warning.__name__)
assert saw_warning, msg
unseen = set(expected_warning) - saw_warning
msg = ("Did not see expected warning(s) of class: " +
', '.join(w.__name__ for w in unseen))
assert not unseen, msg
assert not extra_warnings, ("Caused unexpected warning(s): {extra!r}."
).format(extra=extra_warnings)

Expand Down

0 comments on commit 5b7d0f9

Please sign in to comment.