refactor: Reorder alt.datasets module
dangotbanned committed Nov 9, 2024
1 parent e6dd27e commit 2a7bc4f
Showing 2 changed files with 72 additions and 72 deletions.
52 changes: 26 additions & 26 deletions altair/datasets/__init__.py
@@ -43,32 +43,6 @@ class Loader(Generic[IntoDataFrameT, IntoFrameT]):
 
     _reader: _Reader[IntoDataFrameT, IntoFrameT]
 
-    # TODO: docs (parameters, examples)
-    def url(
-        self,
-        name: DatasetName | LiteralString,
-        suffix: Extension | None = None,
-        /,
-        tag: VersionTag | None = None,
-    ) -> str:
-        """Return the address of a remote dataset."""
-        return self._reader.url(name, suffix, tag=tag)
-
-    # TODO: docs (parameters, examples)
-    def __call__(
-        self,
-        name: DatasetName | LiteralString,
-        suffix: Extension | None = None,
-        /,
-        tag: VersionTag | None = None,
-        **kwds: Any,
-    ) -> IntoDataFrameT:
-        """Get a remote dataset and load as tabular data."""
-        return self._reader.dataset(name, suffix, tag=tag, **kwds)
-
-    def __repr__(self) -> str:
-        return f"{type(self).__name__}[{self._reader._name}]"
-
     @overload
     @classmethod
     def with_backend(
@@ -157,6 +131,29 @@ def with_backend(cls, backend: _Backend, /) -> Loader[Any, Any]:
         obj._reader = get_backend(backend)
         return obj
 
+    # TODO: docs (parameters, examples)
+    def __call__(
+        self,
+        name: DatasetName | LiteralString,
+        suffix: Extension | None = None,
+        /,
+        tag: VersionTag | None = None,
+        **kwds: Any,
+    ) -> IntoDataFrameT:
+        """Get a remote dataset and load as tabular data."""
+        return self._reader.dataset(name, suffix, tag=tag, **kwds)
+
+    # TODO: docs (parameters, examples)
+    def url(
+        self,
+        name: DatasetName | LiteralString,
+        suffix: Extension | None = None,
+        /,
+        tag: VersionTag | None = None,
+    ) -> str:
+        """Return the address of a remote dataset."""
+        return self._reader.url(name, suffix, tag=tag)
+
     @property
     def cache_dir(self) -> Path | None:
         """
@@ -186,6 +183,9 @@ def cache_dir(self, source: StrPath, /) -> None:
 
         os.environ[self._reader._ENV_VAR] = str(source)
 
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}[{self._reader._name}]"
+
 
 def __getattr__(name):
     if name == "data":
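For orientation, here is a minimal sketch of the Loader surface in the order the methods now appear: constructor first, then __call__, url, cache configuration, and __repr__ last. The "polars" backend string, the "cars" dataset name, and the cache path are illustrative assumptions, not guaranteed by this diff.

from altair.datasets import Loader

data = Loader.with_backend("polars")  # classmethod constructor now leads the class body
cars = data("cars")                   # __call__: fetch the dataset and load it as tabular data
url = data.url("cars")                # url: resolve the remote address without downloading
data.cache_dir = "/tmp/altair-cache"  # setter writes the ALTAIR_DATASETS_DIR environment variable (hypothetical path)
print(repr(data))                     # __repr__ -> "Loader[polars]"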
92 changes: 46 additions & 46 deletions altair/datasets/_readers.py
@@ -86,24 +86,10 @@ class _Reader(Generic[IntoDataFrameT, IntoFrameT], Protocol):
     _read_fn: dict[Extension, Callable[..., IntoDataFrameT]]
     _scan_fn: dict[_ExtensionScan, Callable[..., IntoFrameT]]
     _name: LiteralString
-    _opener: ClassVar[OpenerDirector] = urllib.request.build_opener()
     _ENV_VAR: ClassVar[LiteralString] = "ALTAIR_DATASETS_DIR"
+    _opener: ClassVar[OpenerDirector] = urllib.request.build_opener()
     _metadata: Path = Path(__file__).parent / "_metadata" / "metadata.parquet"
 
-    @property
-    def _cache(self) -> Path | None:  # type: ignore[return]
-        """
-        Returns path to datasets cache, if possible.
-
-        Requires opt-in via environment variable::
-
-            Reader._ENV_VAR
-        """
-        if _dir := os.environ.get(self._ENV_VAR):
-            cache_dir = Path(_dir)
-            cache_dir.mkdir(exist_ok=True)
-            return cache_dir
-
     def reader_from(self, source: StrPath, /) -> Callable[..., IntoDataFrameT]:
         suffix = validate_suffix(source, is_ext_supported)
         return self._read_fn[suffix]
@@ -112,21 +98,6 @@ def scanner_from(self, source: StrPath, /) -> Callable[..., IntoFrameT]:
         suffix = validate_suffix(source, is_ext_scan)
         return self._scan_fn[suffix]
 
-    def url(
-        self,
-        name: DatasetName | LiteralString,
-        suffix: Extension | None = None,
-        /,
-        tag: VersionTag | None = None,
-    ) -> str:
-        df = self._query(**validate_constraints(name, suffix, tag))
-        url = df.item(0, "url_npm")
-        if isinstance(url, str):
-            return url
-        else:
-            msg = f"Expected 'str' but got {type(url).__name__!r} from {url!r}."
-            raise TypeError(msg)
-
     def dataset(
         self,
         name: DatasetName | LiteralString,
@@ -145,7 +116,7 @@ def dataset(
         **kwds
             Arguments passed to the underlying read function.
         """
-        df = self._query(**validate_constraints(name, suffix, tag))
+        df = self.query(**validate_constraints(name, suffix, tag))
         it = islice(df.iter_rows(named=True), 1)
         result = cast("Metadata", next(it))
         url = result["url_npm"]
@@ -164,7 +135,22 @@ def dataset(
         with self._opener.open(url) as f:
             return fn(f.read(), **kwds)
 
-    def _query(
+    def url(
+        self,
+        name: DatasetName | LiteralString,
+        suffix: Extension | None = None,
+        /,
+        tag: VersionTag | None = None,
+    ) -> str:
+        df = self.query(**validate_constraints(name, suffix, tag))
+        url = df.item(0, "url_npm")
+        if isinstance(url, str):
+            return url
+        else:
+            msg = f"Expected 'str' but got {type(url).__name__!r} from {url!r}."
+            raise TypeError(msg)
+
+    def query(
         self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata]
     ) -> nw.DataFrame[IntoDataFrameT]:
         r"""
@@ -192,6 +178,20 @@ def _query(
             msg = f"Found no results for:\n{terms}"
             raise NotImplementedError(msg)
 
+    @property
+    def _cache(self) -> Path | None:  # type: ignore[return]
+        """
+        Returns path to datasets cache, if possible.
+
+        Requires opt-in via environment variable::
+
+            Reader._ENV_VAR
+        """
+        if _dir := os.environ.get(self._ENV_VAR):
+            cache_dir = Path(_dir)
+            cache_dir.mkdir(exist_ok=True)
+            return cache_dir
+
     def _import(self, name: str, /) -> Any:
         if spec := find_spec(name):
             return import_module(spec.name)
@@ -205,6 +205,20 @@ def __repr__(self) -> str:
     def __init__(self, name: LiteralString, /) -> None: ...
 
 
+class _PandasReader(_Reader["pd.DataFrame", "pd.DataFrame"]):
+    def __init__(self, name: _Pandas, /) -> None:
+        self._name = _requirements(name)
+        if not TYPE_CHECKING:
+            pd = self._import(self._name)
+        self._read_fn = {
+            ".csv": pd.read_csv,
+            ".json": pd.read_json,
+            ".tsv": cast(partial["pd.DataFrame"], partial(pd.read_csv, sep="\t")),
+            ".arrow": pd.read_feather,
+        }
+        self._scan_fn = {".parquet": pd.read_parquet}
+
+
 class _PandasPyArrowReader(_Reader["pd.DataFrame", "pd.DataFrame"]):
     def __init__(self, name: Literal["pandas[pyarrow]"], /) -> None:
         _pd, _pa = _requirements(name)
@@ -229,20 +243,6 @@ def __init__(self, name: Literal["pandas[pyarrow]"], /) -> None:
         self._scan_fn = {".parquet": partial(pd.read_parquet, dtype_backend="pyarrow")}
 
 
-class _PandasReader(_Reader["pd.DataFrame", "pd.DataFrame"]):
-    def __init__(self, name: _Pandas, /) -> None:
-        self._name = _requirements(name)
-        if not TYPE_CHECKING:
-            pd = self._import(self._name)
-        self._read_fn = {
-            ".csv": pd.read_csv,
-            ".json": pd.read_json,
-            ".tsv": cast(partial["pd.DataFrame"], partial(pd.read_csv, sep="\t")),
-            ".arrow": pd.read_feather,
-        }
-        self._scan_fn = {".parquet": pd.read_parquet}
-
-
 class _PolarsReader(_Reader["pl.DataFrame", "pl.LazyFrame"]):
     def __init__(self, name: _Polars, /) -> None:
         self._name = _requirements(name)
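All of the reader classes above share the same suffix-dispatch idea: a mapping from file extension to a concrete read function, selected per dataset. Below is a standalone sketch of that pattern using the same pandas functions as _PandasReader; the helper name and the sample URL are purely illustrative, not part of this module.

from functools import partial
from pathlib import Path
from typing import Callable

import pandas as pd

# Extension -> read function, mirroring _PandasReader._read_fn above.
read_fn: dict[str, Callable[..., pd.DataFrame]] = {
    ".csv": pd.read_csv,
    ".json": pd.read_json,
    ".tsv": partial(pd.read_csv, sep="\t"),
    ".arrow": pd.read_feather,
}

def reader_from(source: str) -> Callable[..., pd.DataFrame]:
    """Pick a read function based on the source's file extension."""
    suffix = Path(source).suffix
    if suffix not in read_fn:
        raise ValueError(f"Unsupported extension: {suffix!r}")
    return read_fn[suffix]

fn = reader_from("https://example.com/cars.json")  # selects pd.read_json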
