
Commit e81ee52

Author: Volker L.
Merge remote-tracking branch 'refs/remotes/origin/main'
2 parents: c95c7b5 + 1d3a430

9 files changed: +226 -25 lines

pydala/catalog2.py (+1 -2)

@@ -41,8 +41,7 @@ class Table:
     write_args: dict = field(default_factory=dict)
 
     def _load_parquet(self):
-
-
+        pass
 
 
 @dataclass

pydala/dataset.py (+2 -2)

@@ -66,7 +66,7 @@ def __init__(
         )
         self._timestamp_column = timestamp_column
 
-        self.load_files()
+        # self.load_files()
 
         if self.has_files:
            if partitioning == "ignore":
@@ -732,7 +732,7 @@ def write_to_dataset(
            self.delete_files(del_files)
 
        self.clear_cache()
-        self.load_files()
+        # self.load_files()
 
 
 class ParquetDataset(PydalaDatasetMetadata, BaseDataset):
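
With the eager self.load_files() calls commented out here, file discovery appears to move to the metadata-driven load_files() introduced in pydala/metadata.py below. A hypothetical sketch of what a caller would now do explicitly; ds stands in for an already-constructed pydala dataset, and the attribute names are taken from this diff:

    # hypothetical usage sketch; assumes ds is a pydala ParquetDataset
    ds.clear_cache()  # drop any cached directory listing
    ds.load_files()   # re-derive file paths, now from the collected metadata
    if ds.has_files:
        print(ds.files)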

pydala/filesystem.py (+115)

@@ -1,6 +1,7 @@
 import datetime as dt
 import inspect
 import os
+import asyncio
 from datetime import datetime, timedelta
 from functools import wraps
 from pathlib import Path
@@ -13,6 +14,7 @@
 import pyarrow as pa
 import pyarrow.dataset as pds
 import pyarrow.parquet as pq
+import pyarrow.fs as pfs
 import s3fs
 from fsspec import AbstractFileSystem, filesystem
 from fsspec.implementations.cache_mapper import AbstractCacheMapper
@@ -26,6 +28,17 @@
 from .schema import shrink_large_string
 
 
+def get_credentials_from_fsspec(fs: AbstractFileSystem) -> dict[str, str]:
+    if "s3" in fs.protocol:
+        credentials = fs.s3._get_credentials()
+        return {
+            "access_key": credentials.access_key,
+            "secret_key": credentials.secret_key,
+            "session_token": credentials.token,
+            "endpoint_override": fs.s3._endpoint.host,
+        }
+
+
 def get_total_directory_size(directory: str):
     return sum(f.stat().st_size for f in Path(directory).glob("**/*") if f.is_file())
 
@@ -733,6 +746,12 @@ def sync_folder(
         self.cp(new_src, dst)
 
 
+def list_files_recursive(self, path: str, format: str = ""):
+    bucket, prefix = path.split("/", maxsplit=1)
+    objects = asyncio.run(self.s3.list_objects_v2(Bucket=bucket, Prefix=prefix))
+    return [f["Key"] for f in objects["Contents"] if f["Key"].endswith(format)]
+
+
 AbstractFileSystem.read_parquet = read_parquet
 AbstractFileSystem.read_parquet_dataset = read_parquet_dataset
 AbstractFileSystem.write_parquet = write_parquet
@@ -763,6 +782,7 @@ def sync_folder(
 # AbstractFileSystem.parallel_mv = parallel_mv
 # AbstractFileSystem.parallel_rm = parallel_rm
 AbstractFileSystem.sync_folder = sync_folder
+AbstractFileSystem.list_files_recursive = list_files_recursive
 
 
 def FileSystem(
@@ -828,9 +848,104 @@ def FileSystem(
         same_names=same_names,
         **kwargs,
     )
+
+    return fs
+
+
+def PyArrowFileSystem(
+    bucket: str | None = None,
+    fs: AbstractFileSystem | None = None,
+    access_key: str | None = None,
+    secret_key: str | None = None,
+    session_token: str | None = None,
+    endpoint_override: str | None = None,
+    protocol: str | None = None,
+) -> pfs.FileSystem:
+    credentials = None
+    if fs is not None:
+        protocol = fs.protocol[0] if isinstance(fs.protocol, tuple) else fs.protocol
+
+        if protocol == "dir":
+            bucket = fs.path
+            fs = fs.fs
+            protocol = fs.protocol[0] if isinstance(fs.protocol, tuple) else fs.protocol
+
+        if protocol == "s3":
+            credentials = get_credentials_from_fsspec(fs)
+
+    if credentials is None:
+        credentials = {
+            "access_key": access_key,
+            "secret_key": secret_key,
+            "session_token": session_token,
+            "endpoint_override": endpoint_override,
+        }
+    if protocol == "s3":
+        fs = pfs.S3FileSystem(**credentials)
+    elif protocol in ("file", "local", None):
+        fs = pfs.LocalFileSystem()
+    else:
+        fs = pfs.LocalFileSystem()
+
+    if bucket is not None:
+        if protocol in ("file", "local", None):
+            bucket = os.path.abspath(bucket)
+
+        fs = pfs.SubTreeFileSystem(base_fs=fs, base_path=bucket)
+
     return fs
 
 
+# class FileSystem:
+#     def __init__(
+#         self,
+#         bucket: str | None = None,
+#         fs: AbstractFileSystem | None = None,
+#         profile: str | None = None,
+#         key: str | None = None,
+#         endpoint_url: str | None = None,
+#         secret: str | None = None,
+#         token: str | None = None,
+#         protocol: str | None = None,
+#         cached: bool = False,
+#         cache_storage="~/.tmp",
+#         check_files: bool = False,
+#         cache_check: int = 120,
+#         expire_time: int = 24 * 60 * 60,
+#         same_names: bool = False,
+#         **kwargs,
+#     ):
+#         self._fsspec_fs = FsSpecFileSystem(
+#             bucket=bucket,
+#             fs=fs,
+#             profile=profile,
+#             key=key,
+#             endpoint_url=endpoint_url,
+#             secret=secret,
+#             token=token,
+#             protocol=protocol,
+#             cached=cached,
+#             cache_storage=cache_storage,
+#             check_files=check_files,
+#             cache_check=cache_check,
+#             expire_time=expire_time,
+#             same_names=same_names,
+#             **kwargs,
+#         )
+#         self._pfs_fs = PyArrowFileSystem(
+#             bucket=bucket,
+#             fs=fs,
+#             access_key=key,
+#             secret_key=secret,
+#             session_token=token,
+#             endpoint_override=endpoint_url,
+#             protocol=protocol,
+#         )
+
+
 def clear_cache(fs: AbstractFileSystem | None):
     if hasattr(fs, "dir_cache"):
         if fs is not None:
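
A short usage sketch for the new PyArrowFileSystem helper, based only on the signature above; the bucket name and credentials are placeholders:

    import s3fs

    from pydala.filesystem import PyArrowFileSystem

    # wrap an existing fsspec S3 filesystem; its credentials are recovered
    # internally via get_credentials_from_fsspec()
    s3 = s3fs.S3FileSystem(key="<access-key>", secret="<secret-key>")
    pa_s3 = PyArrowFileSystem(bucket="my-bucket", fs=s3)

    # or build a local pyarrow filesystem rooted at a directory
    pa_local = PyArrowFileSystem(bucket="./data", protocol="local")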

pydala/helpers/_fsspec_util.py (+1 -1)

@@ -106,7 +106,7 @@ def _check_file(self, path):
            return fn
        logger.info(f"Downloading {self.protocol[0]}://{path}")
 
-    #def glob(self, path):
+    # def glob(self, path):
     #     return [self._strip_protocol(path)]
 
     def size(self, path):

pydala/helpers/datetime.py (+2 -2)

@@ -86,8 +86,8 @@ def timestamp_from_string(
         datetime.datetime: The datetime object.
     """
     # Extract the timezone from the string if not provided
-    #tz = extract_timezone(timestamp) if tz is None else tz
-    #timestamp = timestamp.replace(tz, "").strip() if tz else timestamp
+    # tz = extract_timezone(timestamp) if tz is None else tz
+    # timestamp = timestamp.replace(tz, "").strip() if tz else timestamp
 
     pdl_timestamp = pdl.parse(timestamp, exact=exact, strict=strict)

pydala/metadata.py (+46 -9)

@@ -107,6 +107,40 @@ def remove_from_metadata(
 
     return metadata
 
+
+def get_file_paths(metadata: pq.FileMetaData) -> list[str]:
+    return [
+        metadata.row_group(i).column(0).file_path
+        for i in range(metadata.num_row_groups)
+    ]
+
+
+class FileMetadata:
+    def __init__(
+        self,
+        path: str,
+        filesystem: AbstractFileSystem | pfs.FileSystem | None = None,
+        bucket: str | None = None,
+        cached: bool = False,
+        **caching_options,
+    ):
+        self._path = path
+        self._bucket = bucket
+        self._cached = cached
+        self._base_filesystem = filesystem
+        self._filesystem = FileSystem(
+            bucket=bucket, fs=filesystem, cached=cached, **caching_options
+        )
+
+        self._caching_options = caching_options
+
+        self.load_files()
+
+    def load_files(self):
+        self._files = self._filesystem.list_files_recursive(self._path)
+
+    @property
+    def fs(self):
+        return self._filesystem
 
 class ParquetDatasetMetadata:
     def __init__(
@@ -140,7 +174,7 @@ def __init__(
         )
 
         self._makedirs()
-        self.load_files()
+        # self.load_files()
 
         self._caching_options = caching_options
 
@@ -170,7 +204,10 @@ def _makedirs(self):
            self._filesystem.touch(os.path.join(self._path, "tmp.delete"))
            self._filesystem.rm(os.path.join(self._path, "tmp.delete"))
 
-    def load_files(self) -> None:
+    def load_files(self) -> None:
+        self._files = get_file_paths(self._metadata)
+
+    def _ls_files(self) -> list[str]:
         """
         Reloads the list of files in the dataset directory. This method should be called
         after adding or removing files from the directory to ensure that the dataset object
@@ -180,7 +217,7 @@ def load_files(self) -> None:
            None
         """
         self.clear_cache()
-        self._files = [
+        return [
            fn.replace(self._path, "").lstrip("/")
            for fn in sorted(
                self._filesystem.glob(os.path.join(self._path, "**/*.parquet"))
@@ -200,7 +237,7 @@ def _collect_file_metadata(self, files: list[str] | None = None, **kwargs) -> None:
            None
         """
         if files is None:
-            files = self._files
+            files = self._ls_files()
 
         file_metadata = collect_parquet_metadata(
            files=files,
@@ -211,8 +248,8 @@ def _collect_file_metadata(self, files: list[str] | None = None, **kwargs) -> None:
 
         # if file_metadata:
         for f in file_metadata:
-            file_metadata[f.replace(self._path, "")].set_file_path(
-                f.split(self._path)[-1].lstrip("/")
+            file_metadata[f].set_file_path(
+                f
            )
 
         if self.has_file_metadata:
@@ -250,14 +287,14 @@ def update_file_metadata(self, files: list[str] | None = None, **kwargs) -> None:
         """
 
         # Add new files to file_metadata
-        self.load_files()
+        all_files = self._ls_files()
 
         new_files = []
         rm_files = []
 
         if self.has_file_metadata:
-            new_files += sorted(set(self.files) - set(self._file_metadata.keys()))
-            rm_files += sorted(set(self._file_metadata.keys()) - set(self.files))
+            new_files += sorted(set(all_files) - set(self._file_metadata.keys()))
+            rm_files += sorted(set(self._file_metadata.keys()) - set(all_files))
 
         else:
            new_files += sorted(set(new_files + self._files))
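
For reference, a standalone sketch of what get_file_paths() reads out of a consolidated metadata file (the path is illustrative):

    import pyarrow.parquet as pq

    meta = pq.read_metadata("dataset/_metadata")  # illustrative path
    # each row group records the relative path of its source file on its
    # column chunks, so column(0) is enough to recover the file list
    files = [meta.row_group(i).column(0).file_path for i in range(meta.num_row_groups)]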

pyproject.toml (+2 -2)

@@ -56,8 +56,8 @@ dev-dependencies = [
     "datafusion>=42.0.0",
     "ibis>=3.3.0",
     "ibis-framework[duckdb,polars]>=9.5.0",
-    "obstore>=0.2.0",
-    "boto3>=1.35.49",
+    "ruff>=0.7.1",
+    "adlfs>=2024.7.0",
 ]
 managed = true
