From 8abd2612f1306c4d28839eb1d535e921a7a7e58d Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 10 Jan 2025 12:06:05 +0400 Subject: [PATCH] feat(python): Support writing to file objects from `write_excel` (#20638) --- py-polars/polars/dataframe/frame.py | 4 +-- .../polars/io/spreadsheet/_write_utils.py | 22 +++++++++---- py-polars/tests/unit/io/test_spreadsheet.py | 33 ++++++++++++++++++- 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 7014b6e74554..66a7f7d1472f 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -3106,8 +3106,8 @@ def write_excel( Parameters ---------- workbook : {str, Workbook} - String name or path of the workbook to create, BytesIO object to write - into, or an open `xlsxwriter.Workbook` object that has not been closed. + String name or path of the workbook to create, BytesIO object, file opened + in binary-mode, or an `xlsxwriter.Workbook` object that has not been closed. If None, writes to a `dataframe.xlsx` workbook in the working directory. worksheet : {str, Worksheet} Name of target worksheet or an `xlsxwriter.Worksheet` object (in which diff --git a/py-polars/polars/io/spreadsheet/_write_utils.py b/py-polars/polars/io/spreadsheet/_write_utils.py index 8489b359a416..a471c340b192 100644 --- a/py-polars/polars/io/spreadsheet/_write_utils.py +++ b/py-polars/polars/io/spreadsheet/_write_utils.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from io import BytesIO +from os import PathLike from pathlib import Path from typing import TYPE_CHECKING, Any, overload @@ -594,13 +595,20 @@ def _xl_setup_workbook( if isinstance(workbook, BytesIO): wb, ws, can_close = Workbook(workbook, workbook_options), None, True else: - file = Path("dataframe.xlsx" if workbook is None else workbook) - wb = Workbook( - (file if file.suffix else file.with_suffix(".xlsx")) - .expanduser() - .resolve(strict=False), - workbook_options, - ) + if workbook is None: + file = Path("dataframe.xlsx") + elif isinstance(workbook, str): + file = Path(workbook) + else: + file = workbook + + if isinstance(file, PathLike): + file = ( + (file if file.suffix else file.with_suffix(".xlsx")) + .expanduser() + .resolve(strict=False) + ) + wb = Workbook(file, workbook_options) ws, can_close = None, True if ws is None: diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index c1838351d86d..4d0fc511f76c 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -4,6 +4,7 @@ from collections import OrderedDict from datetime import date, datetime from io import BytesIO +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable import pytest @@ -16,7 +17,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from pathlib import Path from polars._typing import ExcelSpreadsheetEngine, SchemaDict, SelectorType @@ -859,6 +859,37 @@ def test_excel_write_compound_types(engine: ExcelSpreadsheetEngine) -> None: ] +@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"]) +def test_excel_write_to_file_object( + engine: ExcelSpreadsheetEngine, tmp_path: Path +) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"x": ["aaa", "bbb", "ccc"], "y": [123, 456, 789]}) + + # write to bytesio + xls = BytesIO() + df.write_excel(xls, worksheet="data") + assert_frame_equal(df, pl.read_excel(xls, engine=engine)) + + # write to file path + path = Path(tmp_path).joinpath("test_write_path.xlsx") + df.write_excel(path, worksheet="data") + assert_frame_equal(df, pl.read_excel(xls, engine=engine)) + + # write to file path (as string) + path = Path(tmp_path).joinpath("test_write_path_str.xlsx") + df.write_excel(str(path), worksheet="data") + assert_frame_equal(df, pl.read_excel(xls, engine=engine)) + + # write to file object + path = Path(tmp_path).joinpath("test_write_file_object.xlsx") + with path.open("wb") as tgt: + df.write_excel(tgt, worksheet="data") + with path.open("rb") as src: + assert_frame_equal(df, pl.read_excel(src, engine=engine)) + + @pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"]) def test_excel_read_no_headers(engine: ExcelSpreadsheetEngine) -> None: df = pl.DataFrame(