diff --git a/altair_saver/_utils.py b/altair_saver/_utils.py index be5bb37..0ff76af 100644 --- a/altair_saver/_utils.py +++ b/altair_saver/_utils.py @@ -2,6 +2,7 @@ from http import client import io import os +import pathlib import socket import subprocess import sys @@ -133,9 +134,10 @@ def temporary_filename( @contextlib.contextmanager -def maybe_open(fp: Union[IO, str], mode: str = "w") -> Iterator[IO]: - """Context manager to write to a file specified by filename or file-like object""" - if isinstance(fp, str): +def maybe_open(fp: Union[IO, str, pathlib.PurePath], mode: str = "w") -> Iterator[IO]: + """Context manager to write to a file specified by filename, Path or + file-like object""" + if isinstance(fp, str) or isinstance(fp, pathlib.PurePath): with open(fp, mode) as f: yield f elif isinstance(fp, io.TextIOBase) and "b" in mode: @@ -150,11 +152,13 @@ def maybe_open(fp: Union[IO, str], mode: str = "w") -> Iterator[IO]: yield fp -def extract_format(fp: Union[IO, str]) -> str: - """Extract the altair_saver output format from a file or filename.""" +def extract_format(fp: Union[IO, str, pathlib.PurePath]) -> str: + """Extract the altair_saver output format from a file, filename or Path.""" filename: Optional[str] if isinstance(fp, str): filename = fp + elif isinstance(fp, pathlib.PurePath): + filename = str(fp) else: filename = getattr(fp, "name", None) if filename is None: diff --git a/altair_saver/savers/_selenium.py b/altair_saver/savers/_selenium.py index 0ac3208..6839000 100644 --- a/altair_saver/savers/_selenium.py +++ b/altair_saver/savers/_selenium.py @@ -215,11 +215,16 @@ def _serve(cls, content: str, js_resources: Dict[str, str]) -> str: if cls._provider is None: cls._provider = Provider() resource = cls._provider.create( - content=content, route="", headers={"Access-Control-Allow-Origin": "*"}, + content=content, + route="", + headers={"Access-Control-Allow-Origin": "*"}, ) cls._resources[resource.url] = resource for route, content in js_resources.items(): - cls._resources[route] = cls._provider.create(content=content, route=route,) + cls._resources[route] = cls._provider.create( + content=content, + route=route, + ) return resource.url @classmethod diff --git a/altair_saver/savers/tests/test_html.py b/altair_saver/savers/tests/test_html.py index 6899fec..b3c2183 100644 --- a/altair_saver/savers/tests/test_html.py +++ b/altair_saver/savers/tests/test_html.py @@ -75,7 +75,9 @@ def test_html_save( @pytest.mark.parametrize("embed_options", [None, {"theme": "dark"}]) @pytest.mark.parametrize("case, data", get_testcases()) def test_html_mimebundle( - case: str, data: Dict[str, Any], embed_options: Optional[dict], + case: str, + data: Dict[str, Any], + embed_options: Optional[dict], ) -> None: saver = HTMLSaver(data["vega-lite"], embed_options=embed_options) bundle = saver.mimebundle("html") diff --git a/altair_saver/savers/tests/test_node.py b/altair_saver/savers/tests/test_node.py index 568d4bc..997794d 100644 --- a/altair_saver/savers/tests/test_node.py +++ b/altair_saver/savers/tests/test_node.py @@ -112,7 +112,9 @@ def exec_path(name: str) -> str: @pytest.mark.parametrize("suppress_warnings", [True, False]) def test_stderr_suppression( - interactive_spec: JSONDict, suppress_warnings: bool, capsys: SysCapture, + interactive_spec: JSONDict, + suppress_warnings: bool, + capsys: SysCapture, ) -> None: message = "WARN Can not resolve event source: window" diff --git a/altair_saver/tests/test_utils.py b/altair_saver/tests/test_utils.py index ae5840b..1c79083 100644 --- a/altair_saver/tests/test_utils.py +++ b/altair_saver/tests/test_utils.py @@ -1,5 +1,6 @@ import http import io +import pathlib import socket import subprocess import tempfile @@ -43,32 +44,39 @@ def request(*args: Any, **kwargs: Any) -> None: ("vl.json", "vega-lite"), ], ) -@pytest.mark.parametrize("use_filename", [True, False]) -def test_extract_format(ext: str, fmt: str, use_filename: bool) -> None: - if use_filename: +@pytest.mark.parametrize("fp_type", ["string", "path", "pointer", "stream"]) +def test_extract_format(ext: str, fmt: str, fp_type: str) -> None: + if fp_type == "string": filename = f"chart.{ext}" assert extract_format(filename) == fmt - else: + elif fp_type == "path": + filepath = pathlib.Path(f"chart.{ext}") + assert extract_format(filepath) == fmt + elif fp_type == "pointer": with tempfile.NamedTemporaryFile(suffix=f".{ext}") as fp: assert extract_format(fp) == fmt - - -def test_extract_format_failure() -> None: - fp = io.StringIO() - with pytest.raises(ValueError) as err: - extract_format(fp) - assert f"Cannot infer format from {fp}" in str(err.value) + elif fp_type == "stream": + string_io = io.StringIO() + with pytest.raises(ValueError) as err: + extract_format(string_io) + assert f"Cannot infer format from {string_io}" in str(err.value) @pytest.mark.parametrize("mode", ["w", "wb"]) -def test_maybe_open_filename(mode: str) -> None: - content_raw = "testing maybe_open with filename\n" +@pytest.mark.parametrize("fp_type", ["string", "path"]) +def test_maybe_open_filename(mode: str, fp_type: str) -> None: + content_raw = "testing maybe_open with filename or path\n" content = content_raw.encode() if "b" in mode else content_raw with temporary_filename() as filename: - with maybe_open(filename, mode) as f: + if fp_type == "path": + fp = pathlib.Path(filename) + elif fp_type == "string": + fp = filename + + with maybe_open(fp, mode) as f: f.write(content) - with open(filename, "rb" if "b" in mode else "r") as f: + with open(fp, "rb" if "b" in mode else "r") as f: assert f.read() == content