Skip to content

Commit

Permalink
TST: use with where possible instead of manual close (pandas-dev#48931)
Browse files Browse the repository at this point in the history

Coincidentally fixes some StataReaders being left open in tests.
  • Loading branch information
akx authored and noatamir committed Nov 9, 2022
1 parent 531fded commit 8d29f57
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 151 deletions.
10 changes: 3 additions & 7 deletions pandas/tests/io/parser/test_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,9 @@ def test_utf16_bom_skiprows(all_parsers, sep, encoding):
with open(path, "wb") as f:
f.write(bytes_data)

bytes_buffer = BytesIO(data.encode(utf8))
bytes_buffer = TextIOWrapper(bytes_buffer, encoding=utf8)

result = parser.read_csv(path, encoding=encoding, **kwargs)
expected = parser.read_csv(bytes_buffer, encoding=utf8, **kwargs)

bytes_buffer.close()
with TextIOWrapper(BytesIO(data.encode(utf8)), encoding=utf8) as bytes_buffer:
result = parser.read_csv(path, encoding=encoding, **kwargs)
expected = parser.read_csv(bytes_buffer, encoding=utf8, **kwargs)
tm.assert_frame_equal(result, expected)


Expand Down
23 changes: 10 additions & 13 deletions pandas/tests/io/pytables/test_file_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ def test_mode(setup_path, tmp_path, mode):
HDFStore(path, mode=mode)

else:
store = HDFStore(path, mode=mode)
assert store._handle.mode == mode
store.close()
with HDFStore(path, mode=mode) as store:
assert store._handle.mode == mode

path = tmp_path / setup_path

Expand Down Expand Up @@ -253,16 +252,14 @@ def test_complibs(tmp_path, setup_path):
result = read_hdf(tmpfile, gname)
tm.assert_frame_equal(result, df)

# Open file and check metadata
# for correct amount of compression
h5table = tables.open_file(tmpfile, mode="r")
for node in h5table.walk_nodes(where="/" + gname, classname="Leaf"):
assert node.filters.complevel == lvl
if lvl == 0:
assert node.filters.complib is None
else:
assert node.filters.complib == lib
h5table.close()
# Open file and check metadata for correct amount of compression
with tables.open_file(tmpfile, mode="r") as h5table:
for node in h5table.walk_nodes(where="/" + gname, classname="Leaf"):
assert node.filters.complevel == lvl
if lvl == 0:
assert node.filters.complib is None
else:
assert node.filters.complib == lib


@pytest.mark.skipif(
Expand Down
18 changes: 9 additions & 9 deletions pandas/tests/io/pytables/test_read.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import closing
from pathlib import Path
import re

Expand Down Expand Up @@ -207,11 +208,10 @@ def test_read_hdf_open_store(tmp_path, setup_path):
path = tmp_path / setup_path
df.to_hdf(path, "df", mode="w")
direct = read_hdf(path, "df")
store = HDFStore(path, mode="r")
indirect = read_hdf(store, "df")
tm.assert_frame_equal(direct, indirect)
assert store.is_open
store.close()
with HDFStore(path, mode="r") as store:
indirect = read_hdf(store, "df")
tm.assert_frame_equal(direct, indirect)
assert store.is_open


def test_read_hdf_iterator(tmp_path, setup_path):
Expand All @@ -223,10 +223,10 @@ def test_read_hdf_iterator(tmp_path, setup_path):
df.to_hdf(path, "df", mode="w", format="t")
direct = read_hdf(path, "df")
iterator = read_hdf(path, "df", iterator=True)
assert isinstance(iterator, TableIterator)
indirect = next(iterator.__iter__())
tm.assert_frame_equal(direct, indirect)
iterator.store.close()
with closing(iterator.store):
assert isinstance(iterator, TableIterator)
indirect = next(iterator.__iter__())
tm.assert_frame_equal(direct, indirect)


def test_read_nokey(tmp_path, setup_path):
Expand Down
28 changes: 12 additions & 16 deletions pandas/tests/io/pytables/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,10 +682,9 @@ def test_frame_select_complex2(tmp_path):

# scope with list like
l0 = selection.index.tolist() # noqa:F841
store = HDFStore(hh)
result = store.select("df", where="l1=l0")
tm.assert_frame_equal(result, expected)
store.close()
with HDFStore(hh) as store:
result = store.select("df", where="l1=l0")
tm.assert_frame_equal(result, expected)

result = read_hdf(hh, "df", where="l1=l0")
tm.assert_frame_equal(result, expected)
Expand All @@ -705,21 +704,18 @@ def test_frame_select_complex2(tmp_path):
tm.assert_frame_equal(result, expected)

# scope with index
store = HDFStore(hh)

result = store.select("df", where="l1=index")
tm.assert_frame_equal(result, expected)

result = store.select("df", where="l1=selection.index")
tm.assert_frame_equal(result, expected)
with HDFStore(hh) as store:
result = store.select("df", where="l1=index")
tm.assert_frame_equal(result, expected)

result = store.select("df", where="l1=selection.index.tolist()")
tm.assert_frame_equal(result, expected)
result = store.select("df", where="l1=selection.index")
tm.assert_frame_equal(result, expected)

result = store.select("df", where="l1=list(selection.index)")
tm.assert_frame_equal(result, expected)
result = store.select("df", where="l1=selection.index.tolist()")
tm.assert_frame_equal(result, expected)

store.close()
result = store.select("df", where="l1=list(selection.index)")
tm.assert_frame_equal(result, expected)


def test_invalid_filtering(setup_path):
Expand Down
5 changes: 2 additions & 3 deletions pandas/tests/io/pytables/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,9 +917,8 @@ def do_copy(f, new_f=None, keys=None, propindexes=True, **kwargs):
df = tm.makeDataFrame()

with tm.ensure_clean() as path:
st = HDFStore(path)
st.append("df", df, data_columns=["A"])
st.close()
with HDFStore(path) as st:
st.append("df", df, data_columns=["A"])
do_copy(f=path)
do_copy(f=path, propindexes=False)

Expand Down
10 changes: 5 additions & 5 deletions pandas/tests/io/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ def test_get_handle_with_path(self, path_type):
assert os.path.expanduser(filename) == handles.handle.name

def test_get_handle_with_buffer(self):
input_buffer = StringIO()
with icom.get_handle(input_buffer, "r") as handles:
assert handles.handle == input_buffer
assert not input_buffer.closed
input_buffer.close()
with StringIO() as input_buffer:
with icom.get_handle(input_buffer, "r") as handles:
assert handles.handle == input_buffer
assert not input_buffer.closed
assert input_buffer.closed

# Test that BytesIOWrapper(get_handle) returns correct amount of bytes every time
def test_bytesiowrapper_returns_correct_bytes(self):
Expand Down
11 changes: 5 additions & 6 deletions pandas/tests/io/test_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,18 +282,17 @@ def test_bzip_compression_level(obj, method):
)
def test_empty_archive_zip(suffix, archive):
with tm.ensure_clean(filename=suffix) as path:
file = archive(path, "w")
file.close()
with archive(path, "w"):
pass
with pytest.raises(ValueError, match="Zero files found"):
pd.read_csv(path)


def test_ambiguous_archive_zip():
with tm.ensure_clean(filename=".zip") as path:
file = zipfile.ZipFile(path, "w")
file.writestr("a.csv", "foo,bar")
file.writestr("b.csv", "foo,bar")
file.close()
with zipfile.ZipFile(path, "w") as file:
file.writestr("a.csv", "foo,bar")
file.writestr("b.csv", "foo,bar")
with pytest.raises(ValueError, match="Multiple files found in ZIP file"):
pd.read_csv(path)

Expand Down
24 changes: 10 additions & 14 deletions pandas/tests/io/test_fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,22 +95,18 @@ def test_to_csv_fsspec_object(cleared_fs, binary_mode, df1):

path = "memory://test/test.csv"
mode = "wb" if binary_mode else "w"
fsspec_object = fsspec.open(path, mode=mode).open()

df1.to_csv(fsspec_object, index=True)
assert not fsspec_object.closed
fsspec_object.close()
with fsspec.open(path, mode=mode).open() as fsspec_object:
df1.to_csv(fsspec_object, index=True)
assert not fsspec_object.closed

mode = mode.replace("w", "r")
fsspec_object = fsspec.open(path, mode=mode).open()

df2 = read_csv(
fsspec_object,
parse_dates=["dt"],
index_col=0,
)
assert not fsspec_object.closed
fsspec_object.close()
with fsspec.open(path, mode=mode) as fsspec_object:
df2 = read_csv(
fsspec_object,
parse_dates=["dt"],
index_col=0,
)
assert not fsspec_object.closed

tm.assert_frame_equal(df1, df2)

Expand Down
40 changes: 19 additions & 21 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""
from __future__ import annotations

from contextlib import closing
import csv
from datetime import (
date,
Expand Down Expand Up @@ -455,9 +456,8 @@ def sqlite_iris_conn(sqlite_iris_engine):

@pytest.fixture
def sqlite_buildin():
conn = sqlite3.connect(":memory:")
yield conn
conn.close()
with sqlite3.connect(":memory:") as conn:
yield conn


@pytest.fixture
Expand Down Expand Up @@ -1532,13 +1532,14 @@ def test_sql_open_close(self, test_frame3):

with tm.ensure_clean() as name:

conn = self.connect(name)
assert sql.to_sql(test_frame3, "test_frame3_legacy", conn, index=False) == 4
conn.close()
with closing(self.connect(name)) as conn:
assert (
sql.to_sql(test_frame3, "test_frame3_legacy", conn, index=False)
== 4
)

conn = self.connect(name)
result = sql.read_sql_query("SELECT * FROM test_frame3_legacy;", conn)
conn.close()
with closing(self.connect(name)) as conn:
result = sql.read_sql_query("SELECT * FROM test_frame3_legacy;", conn)

tm.assert_frame_equal(test_frame3, result)

Expand Down Expand Up @@ -2371,18 +2372,15 @@ class Test(BaseModel):

BaseModel.metadata.create_all(self.conn)
Session = sessionmaker(bind=self.conn)
session = Session()

df = DataFrame({"id": [0, 1], "foo": ["hello", "world"]})
assert (
df.to_sql("test_frame", con=self.conn, index=False, if_exists="replace")
== 2
)

session.commit()
foo = session.query(Test.id, Test.foo)
df = DataFrame(foo)
session.close()
with Session() as session:
df = DataFrame({"id": [0, 1], "foo": ["hello", "world"]})
assert (
df.to_sql("test_frame", con=self.conn, index=False, if_exists="replace")
== 2
)
session.commit()
foo = session.query(Test.id, Test.foo)
df = DataFrame(foo)

assert list(df.columns) == ["id", "foo"]

Expand Down
Loading

0 comments on commit 8d29f57

Please sign in to comment.