Improve ParquetDatasetMetadata and PydalaDatasetMetadata init. Refactor imports and update dependencies in pyproject.toml and requirements.lock
Volker Lorrmann committed Nov 22, 2024
1 parent e81ee52 commit 479e81d
Showing 10 changed files with 435 additions and 232 deletions.
3 changes: 2 additions & 1 deletion pydala/catalog.py
@@ -11,7 +11,8 @@

from .dataset import CsvDataset, JsonDataset, ParquetDataset, PyarrowDataset
from .filesystem import FileSystem
from .helpers.misc import delattr_rec, get_nested_keys, getattr_rec, setattr_rec
from .helpers.misc import (delattr_rec, get_nested_keys, getattr_rec,
setattr_rec)
from .helpers.sql import get_table_names
from .table import PydalaTable

102 changes: 39 additions & 63 deletions pydala/dataset.py
@@ -18,10 +18,8 @@
from .helpers.polars import pl as _pl
from .io import Writer
from .metadata import ParquetDatasetMetadata, PydalaDatasetMetadata
from .schema import (
replace_schema, # from .optimize import Optimize
shrink_large_string,
)
from .schema import replace_schema # from .optimize import Optimize
from .schema import shrink_large_string
from .table import PydalaTable


@@ -62,21 +60,31 @@ def __init__(
# enable object caching for e.g. parquet metadata
self.ddb_con.execute(
f"""PRAGMA enable_object_cache;
SET THREADS={psutil.cpu_count()*2};"""
SET THREADS={psutil.cpu_count() * 2};"""
)
self._timestamp_column = timestamp_column

#self.load_files()
# self.load_files()

if self.has_files:
if partitioning == "ignore":
self._partitioning = None
elif partitioning is None and "=" in self._files[0]:
self._partitioning = "hive"
else:
self._partitioning = partitioning
# NOTE: Set partitioning manually, if not set, try to infer it
if partitioning is None:
# try to infer partitioning
if any(["=" in obj for obj in self.fs.ls(self._path)]):
partitioning = "hive"
else:
self._partitioning = partitioning
if partitioning == "ignore":
partitioning = None
self._partitioning = partitioning

# if self.has_files:
# if partitioning == "ignore":
# self._partitioning = None
# elif partitioning is None and "=" in self._files[0]:
# self._partitioning = "hive"
# else:
# self._partitioning = partitioning
# else:
# self._partitioning = partitioning

try:
self.load()
@@ -99,9 +107,7 @@ def load_files(self) -> None:
self._files = [
fn.replace(self._path, "").lstrip("/")
for fn in sorted(
self._filesystem.glob(
os.path.join(self._path, f"**/*.{self._format}")
)
self._filesystem.glob(os.path.join(self._path, f"**/*.{self._format}"))
)
]

@@ -155,15 +161,11 @@ def load(self):
format=self._format,
partitioning=self._partitioning,
)
self.table = PydalaTable(
result=self._arrow_dataset, ddb_con=self.ddb_con
)
self.table = PydalaTable(result=self._arrow_dataset, ddb_con=self.ddb_con)
# self.ddb_con.register("arrow__dataset", self._arrow_parquet_dataset)

if self._timestamp_column is None:
self._timestamp_columns = get_timestamp_column(
self.table.pl.head(10)
)
self._timestamp_columns = get_timestamp_column(self.table.pl.head(10))
if len(self._timestamp_columns) > 0:
self._timestamp_column = self._timestamp_columns[0]

@@ -345,9 +347,7 @@ def partition_names(self) -> list:
if not hasattr(self, "_partition_names") and hasattr(
self._arrow_dataset, "partitioning"
):
self._partition_names = (
self._arrow_dataset.partitioning.schema.names
)
self._partition_names = self._arrow_dataset.partitioning.schema.names

return self._partition_names

@@ -472,12 +472,7 @@ def filter(
the method will automatically use DuckDB for filtering.
"""
if any(
[
s in filter_expr
for s in ["%", "like", "similar to", "*", "(", ")"]
]
):
if any([s in filter_expr for s in ["%", "like", "similar to", "*", "(", ")"]]):
use = "duckdb"

if use == "auto":
@@ -505,9 +500,7 @@ def registered_tables(self) -> list[str]:
Returns:
list[str]: A list of table names.
"""
return (
self.ddb_con.sql("SHOW TABLES").arrow().column("name").to_pylist()
)
return self.ddb_con.sql("SHOW TABLES").arrow().column("name").to_pylist()

def interrupt_duckdb(self):
"""
@@ -569,9 +562,7 @@ def _get_delta_other_df(
# _pl.first(col).alias("max"), _pl.last(col).alias("min")
# )
# else:
max_min = df.select(
_pl.max(col).alias("max"), _pl.min(col).alias("min")
)
max_min = df.select(_pl.max(col).alias("max"), _pl.min(col).alias("min"))

if collect:
max_min = max_min.collect()
@@ -732,7 +723,7 @@ def write_to_dataset(
self.delete_files(del_files)

self.clear_cache()
#self.load_files()
# self.load_files()


class ParquetDataset(PydalaDatasetMetadata, BaseDataset):
@@ -855,14 +846,10 @@ def load(
filesystem=self._filesystem,
)

self.table = PydalaTable(
result=self._arrow_dataset, ddb_con=self.ddb_con
)
self.table = PydalaTable(result=self._arrow_dataset, ddb_con=self.ddb_con)

if self._timestamp_column is None:
self._timestamp_columns = get_timestamp_column(
self.table.pl.head(10)
)
self._timestamp_columns = get_timestamp_column(self.table.pl.head(10))
if len(self._timestamp_columns) > 0:
self._timestamp_column = self._timestamp_columns[0]
if self._timestamp_column is not None:
@@ -1187,17 +1174,13 @@ def load(self):
.opt_dtype(strict=False)
.to_arrow()
)
self.table = PydalaTable(
result=self._arrow_dataset, ddb_con=self.ddb_con
)
self.table = PydalaTable(result=self._arrow_dataset, ddb_con=self.ddb_con)

self.ddb_con.register(f"{self.name}", self._arrow_dataset)
# self.ddb_con.register("arrow__dataset", self._arrow_parquet_dataset)

if self._timestamp_column is None:
self._timestamp_columns = get_timestamp_column(
self.table.pl.head(10)
)
self._timestamp_columns = get_timestamp_column(self.table.pl.head(10))
if len(self._timestamp_columns) > 1:
self._timestamp_column = self._timestamp_columns[0]

@@ -1255,9 +1238,7 @@ def _compact_partition(
# else:
# num_rows = 0

batches = scan.to_batch_reader(
sort_by=sort_by, batch_size=max_rows_per_file
)
batches = scan.to_batch_reader(sort_by=sort_by, batch_size=max_rows_per_file)
for batch in batches:
self.write_to_dataset(
pa.table(batch),
@@ -1331,7 +1312,7 @@ def _compact_by_timeperiod(
if len(self.scan_files) == 1:
date_diff = (
self.metadata_table.filter(
f"file_path='{self.scan_files[0].replace(self._path,'').lstrip('/')}'"
f"file_path='{self.scan_files[0].replace(self._path, '').lstrip('/')}'"
)
.aggregate("max(AE_DATUM.max) - min(AE_DATUM.min)")
.fetchone()[0]
@@ -1403,9 +1384,7 @@ def compact_by_timeperiod(
end_dates = dates[1:]

files_to_delete = []
for start_date, end_date in tqdm.tqdm(
list(zip(start_dates, end_dates))
):
for start_date, end_date in tqdm.tqdm(list(zip(start_dates, end_dates))):
files_to_delete_ = self._compact_by_timeperiod(
start_date=start_date,
end_date=end_date,
@@ -1515,16 +1494,13 @@ def _optimize_dtypes(
[
field
for field in scan.arrow_dataset.schema
if field.name
not in scan.arrow_dataset.partitioning.schema.names
if field.name not in scan.arrow_dataset.partitioning.schema.names
]
)

if schema != optimized_schema:
table = replace_schema(
scan.pl.opt_dtype(
strict=strict, exclude=exclude, include=include
)
scan.pl.opt_dtype(strict=strict, exclude=exclude, include=include)
.collect(streaming=True)
.to_arrow(),
schema=optimized_schema,
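The most significant behavioral change in dataset.py is the partitioning setup in `BaseDataset.__init__`: instead of branching on `self.has_files`, the constructor now infers hive partitioning by checking whether any object name under the dataset path contains `=`, and maps the explicit `"ignore"` value to no partitioning. A minimal standalone sketch of that logic, assuming an fsspec-style filesystem with an `ls()` method (the helper name and free-function form are illustrative, not part of the commit):

```python
from fsspec import AbstractFileSystem


def infer_partitioning(
    fs: AbstractFileSystem, path: str, partitioning: str | None
) -> str | None:
    """Sketch of the partitioning inference added to BaseDataset.__init__."""
    if partitioning is None:
        # Hive-style layouts encode partition values as "column=value" directory names.
        if any("=" in obj for obj in fs.ls(path)):
            partitioning = "hive"
    if partitioning == "ignore":
        # "ignore" is an explicit opt-out and is translated to no partitioning.
        partitioning = None
    return partitioning
```

For a layout such as `data/year=2024/part-0.parquet`, `ls()` returns entries containing `=`, so the dataset is treated as hive-partitioned unless the caller passed `partitioning="ignore"`.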
62 changes: 54 additions & 8 deletions pydala/filesystem.py
@@ -1,7 +1,6 @@
import datetime as dt
import inspect
import os
import asyncio
from datetime import datetime, timedelta
from functools import wraps
from pathlib import Path
@@ -13,13 +12,12 @@
import psutil
import pyarrow as pa
import pyarrow.dataset as pds
import pyarrow.parquet as pq
import pyarrow.fs as pfs
import pyarrow.parquet as pq
import s3fs
from fsspec import AbstractFileSystem, filesystem
from fsspec.implementations.cache_mapper import AbstractCacheMapper
from fsspec.implementations.cached import SimpleCacheFileSystem

# from fsspec.implementations import cached as cachedfs
from fsspec.implementations.dirfs import DirFileSystem
from loguru import logger
@@ -746,10 +744,58 @@ def sync_folder(
self.cp(new_src, dst)


def list_files_recursive(self, path:str, format:str=""):
bucket, prefix = path.split("/", maxsplit=1)
return [f["Key"] for f in asyncio.run(self.s3.list_objects_v2(Bucket=bucket, Prefix=prefix))["Contents"] if f["Key"].endswith(format)]

# NOTE: This is not working properly due to some event loop issues

# def list_files_recursive(self, path: str, format: str = ""):
# bucket, prefix = path.split("/", maxsplit=1)
# return [
# f["Key"]
# for f in asyncio.run(self.s3.list_objects_v2(Bucket=bucket, Prefix=prefix))[
# "Contents"
# ]
# if f["Key"].endswith(format)
# ]


# async def _list_files_recursive(
# self, path: str, format: str = "", max_items: int = 10000
# ):
# bucket, prefix = path.split("/", maxsplit=1)
# continuation_token = None
# files = []

# while True:
# if continuation_token:
# response = await self.s3.list_objects_v2(
# Bucket=bucket,
# Prefix=prefix,
# ContinuationToken=continuation_token,
# MaxKeys=max_items,
# )
# else:
# response = await self.s3.list_objects_v2(
# Bucket=bucket, Prefix=prefix, MaxKeys=max_items
# )

# if "Contents" in response:
# files.extend(
# [f["Key"] for f in response["Contents"] if f["Key"].endswith(format)]
# )

# if response.get("IsTruncated"): # Check if there are more objects to retrieve
# continuation_token = response.get("NextContinuationToken")
# else:
# break

# return files


# def list_files_recursive(self, path: str, format: str = "", max_items: int = 10000):
# loop = asyncio.get_event_loop()
# if loop.is_closed():
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# return loop.run_until_complete(_list_files_recursive(self, path, format, max_items))


AbstractFileSystem.read_parquet = read_parquet
@@ -782,7 +828,7 @@ def list_files_recursive(self, path:str, format:str=""):
# AbstractFileSystem.parallel_mv = parallel_mv
# AbstractFileSystem.parallel_rm = parallel_rm
AbstractFileSystem.sync_folder = sync_folder
AbstractFileSystem.list_files_recursive = list_files_recursive
# AbstractFileSystem.list_files_recursive = list_files_recursive


def FileSystem(
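The asyncio-based `list_files_recursive` helper (and its monkey-patch onto `AbstractFileSystem`) is commented out here because of the event-loop issue noted in the diff. As a point of comparison only, not part of this commit, a synchronous equivalent could lean on fsspec's built-in `find()`; the function name below is hypothetical:

```python
from fsspec import AbstractFileSystem


def list_files_recursive_sync(
    fs: AbstractFileSystem, path: str, format: str = ""
) -> list[str]:
    # find() walks the tree synchronously, so no event loop is involved;
    # the suffix filter mirrors the `format` argument of the disabled helper.
    return [f for f in fs.find(path) if f.endswith(format)]
```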
2 changes: 1 addition & 1 deletion pydala/helpers/datetime.py
@@ -1,11 +1,11 @@
import datetime as dt
import re
from functools import lru_cache

import pendulum as pdl
import polars as pl
import polars.selectors as cs
import pyarrow as pa
from functools import lru_cache


def get_timestamp_column(df: pl.DataFrame | pl.LazyFrame | pa.Table) -> str | list[str]:
1 change: 0 additions & 1 deletion pydala/helpers/misc.py
@@ -2,7 +2,6 @@
import re
from typing import Any

import pendulum as pdl
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
1 change: 0 additions & 1 deletion pydala/io.py
@@ -8,7 +8,6 @@
import polars.selectors as cs
import pyarrow as pa
import pyarrow.dataset as pds

# import pyarrow.dataset as pds
import pyarrow.parquet as pq
from fsspec import AbstractFileSystem
