Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

allow Paths in save() #83

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions altair_saver/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from http import client
import io
import os
import pathlib
import socket
import subprocess
import sys
Expand Down Expand Up @@ -133,9 +134,10 @@ def temporary_filename(


@contextlib.contextmanager
def maybe_open(fp: Union[IO, str], mode: str = "w") -> Iterator[IO]:
"""Context manager to write to a file specified by filename or file-like object"""
if isinstance(fp, str):
def maybe_open(fp: Union[IO, str, pathlib.PurePath], mode: str = "w") -> Iterator[IO]:
"""Context manager to write to a file specified by filename, Path or
file-like object"""
if isinstance(fp, str) or isinstance(fp, pathlib.PurePath):
with open(fp, mode) as f:
yield f
elif isinstance(fp, io.TextIOBase) and "b" in mode:
Expand All @@ -150,11 +152,13 @@ def maybe_open(fp: Union[IO, str], mode: str = "w") -> Iterator[IO]:
yield fp


def extract_format(fp: Union[IO, str]) -> str:
"""Extract the altair_saver output format from a file or filename."""
def extract_format(fp: Union[IO, str, pathlib.PurePath]) -> str:
"""Extract the altair_saver output format from a file, filename or Path."""
filename: Optional[str]
if isinstance(fp, str):
filename = fp
elif isinstance(fp, pathlib.PurePath):
filename = str(fp)
else:
filename = getattr(fp, "name", None)
if filename is None:
Expand Down
9 changes: 7 additions & 2 deletions altair_saver/savers/_selenium.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,16 @@ def _serve(cls, content: str, js_resources: Dict[str, str]) -> str:
if cls._provider is None:
cls._provider = Provider()
resource = cls._provider.create(
content=content, route="", headers={"Access-Control-Allow-Origin": "*"},
content=content,
route="",
headers={"Access-Control-Allow-Origin": "*"},
)
cls._resources[resource.url] = resource
for route, content in js_resources.items():
cls._resources[route] = cls._provider.create(content=content, route=route,)
cls._resources[route] = cls._provider.create(
content=content,
route=route,
)
return resource.url

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion altair_saver/savers/tests/test_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def test_html_save(
@pytest.mark.parametrize("embed_options", [None, {"theme": "dark"}])
@pytest.mark.parametrize("case, data", get_testcases())
def test_html_mimebundle(
case: str, data: Dict[str, Any], embed_options: Optional[dict],
case: str,
data: Dict[str, Any],
embed_options: Optional[dict],
) -> None:
saver = HTMLSaver(data["vega-lite"], embed_options=embed_options)
bundle = saver.mimebundle("html")
Expand Down
4 changes: 3 additions & 1 deletion altair_saver/savers/tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def exec_path(name: str) -> str:

@pytest.mark.parametrize("suppress_warnings", [True, False])
def test_stderr_suppression(
interactive_spec: JSONDict, suppress_warnings: bool, capsys: SysCapture,
interactive_spec: JSONDict,
suppress_warnings: bool,
capsys: SysCapture,
) -> None:
message = "WARN Can not resolve event source: window"

Expand Down
38 changes: 23 additions & 15 deletions altair_saver/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import http
import io
import pathlib
import socket
import subprocess
import tempfile
Expand Down Expand Up @@ -43,32 +44,39 @@ def request(*args: Any, **kwargs: Any) -> None:
("vl.json", "vega-lite"),
],
)
@pytest.mark.parametrize("use_filename", [True, False])
def test_extract_format(ext: str, fmt: str, use_filename: bool) -> None:
if use_filename:
@pytest.mark.parametrize("fp_type", ["string", "path", "pointer", "stream"])
def test_extract_format(ext: str, fmt: str, fp_type: str) -> None:
if fp_type == "string":
filename = f"chart.{ext}"
assert extract_format(filename) == fmt
else:
elif fp_type == "path":
filepath = pathlib.Path(f"chart.{ext}")
assert extract_format(filepath) == fmt
elif fp_type == "pointer":
with tempfile.NamedTemporaryFile(suffix=f".{ext}") as fp:
assert extract_format(fp) == fmt


def test_extract_format_failure() -> None:
fp = io.StringIO()
with pytest.raises(ValueError) as err:
extract_format(fp)
assert f"Cannot infer format from {fp}" in str(err.value)
elif fp_type == "stream":
string_io = io.StringIO()
with pytest.raises(ValueError) as err:
extract_format(string_io)
assert f"Cannot infer format from {string_io}" in str(err.value)


@pytest.mark.parametrize("mode", ["w", "wb"])
def test_maybe_open_filename(mode: str) -> None:
content_raw = "testing maybe_open with filename\n"
@pytest.mark.parametrize("fp_type", ["string", "path"])
def test_maybe_open_filename(mode: str, fp_type: str) -> None:
content_raw = "testing maybe_open with filename or path\n"
content = content_raw.encode() if "b" in mode else content_raw

with temporary_filename() as filename:
with maybe_open(filename, mode) as f:
if fp_type == "path":
fp = pathlib.Path(filename)
elif fp_type == "string":
fp = filename

with maybe_open(fp, mode) as f:
f.write(content)
with open(filename, "rb" if "b" in mode else "r") as f:
with open(fp, "rb" if "b" in mode else "r") as f:
assert f.read() == content


Expand Down