|
import datetime as dt
import inspect
import os
from datetime import datetime, timedelta
from functools import wraps
from pathlib import Path
|
import pyarrow as pa
import pyarrow.dataset as pds
import pyarrow.parquet as pq
+import pyarrow.fs as pfs
import s3fs
from fsspec import AbstractFileSystem, filesystem
from fsspec.implementations.cache_mapper import AbstractCacheMapper
|
from .schema import shrink_large_string


+def get_credentials_from_fsspec(fs: AbstractFileSystem) -> dict[str, str] | None:
+    # Pull the live credentials out of an s3fs filesystem via private
+    # s3fs/aiobotocore attributes; returns None for non-S3 filesystems.
+    if "s3" in fs.protocol:
+        credentials = fs.s3._get_credentials()
+        return {
+            "access_key": credentials.access_key,
+            "secret_key": credentials.secret_key,
+            "session_token": credentials.token,
+            "endpoint_override": fs.s3._endpoint.host,
+        }
+    return None
+
+
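# Usage sketch (illustrative, not part of this commit): the returned dict is
# shaped to match pyarrow.fs.S3FileSystem's keyword arguments. The endpoint and
# keys below are placeholders, and s3fs must have established its session (any
# first call such as fs.ls() does so) before the private attributes exist.
import s3fs

fs = s3fs.S3FileSystem(key="ACCESS", secret="SECRET", endpoint_url="http://localhost:9000")
fs.ls("")  # forces session creation
creds = get_credentials_from_fsspec(fs)
arrow_fs = pfs.S3FileSystem(**creds)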
def get_total_directory_size(directory: str):
    return sum(f.stat().st_size for f in Path(directory).glob("**/*") if f.is_file())

@@ -733,6 +746,12 @@ def sync_folder(
|
    self.cp(new_src, dst)


+def list_files_recursive(self, path: str, format: str = ""):
+    # find() walks the given path recursively on any fsspec filesystem;
+    # keep only the paths ending with the requested suffix.
+    return [f for f in self.find(path) if f.endswith(format)]
+
+
AbstractFileSystem.read_parquet = read_parquet
AbstractFileSystem.read_parquet_dataset = read_parquet_dataset
AbstractFileSystem.write_parquet = write_parquet

@@ -763,6 +782,7 @@ def sync_folder(
|
# AbstractFileSystem.parallel_mv = parallel_mv
# AbstractFileSystem.parallel_rm = parallel_rm
AbstractFileSystem.sync_folder = sync_folder
+AbstractFileSystem.list_files_recursive = list_files_recursive
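# Usage sketch (illustrative, not part of this commit): once patched onto
# AbstractFileSystem, every fsspec filesystem exposes the method. The path
# and suffix below are placeholders.
from fsspec import filesystem

local_fs = filesystem("file")
parquet_files = local_fs.list_files_recursive("/tmp/data", format=".parquet")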


def FileSystem(
|
@@ -828,9 +848,104 @@ def FileSystem(
|
        same_names=same_names,
        **kwargs,
    )
+
+    return fs
+
+
+def PyArrowFileSystem(
+    bucket: str | None = None,
+    fs: AbstractFileSystem | None = None,
+    access_key: str | None = None,
+    secret_key: str | None = None,
+    session_token: str | None = None,
+    endpoint_override: str | None = None,
+    protocol: str | None = None,
+) -> pfs.FileSystem:
+    credentials = None
+    if fs is not None:
+        protocol = fs.protocol[0] if isinstance(fs.protocol, tuple) else fs.protocol
+
+        # Unwrap a DirFileSystem and treat its base path as the bucket.
+        if protocol == "dir":
+            bucket = fs.path
+            fs = fs.fs
+            protocol = fs.protocol[0] if isinstance(fs.protocol, tuple) else fs.protocol
+
+        if protocol == "s3":
+            credentials = get_credentials_from_fsspec(fs)
+
+    if credentials is None:
+        credentials = {
+            "access_key": access_key,
+            "secret_key": secret_key,
+            "session_token": session_token,
+            "endpoint_override": endpoint_override,
+        }
+
+    if protocol == "s3":
+        fs = pfs.S3FileSystem(**credentials)
+    else:
+        # "file", "local", and None all fall back to the local filesystem.
+        fs = pfs.LocalFileSystem()
+
+    if bucket is not None:
+        if protocol in ("file", "local", None):
+            bucket = os.path.abspath(bucket)
+
+        fs = pfs.SubTreeFileSystem(base_fs=fs, base_path=bucket)
+
    return fs
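# Usage sketch (illustrative, not part of this commit): derive a pyarrow
# filesystem from an existing fsspec instance and scope it to a bucket.
# The bucket name and credentials are placeholders.
import s3fs

fs = s3fs.S3FileSystem(key="ACCESS", secret="SECRET")
arrow_fs = PyArrowFileSystem(bucket="my-bucket", fs=fs)
table = pq.read_table("data/file.parquet", filesystem=arrow_fs)  # path is relative to the bucket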
|


+# class FileSystem:
+#     def __init__(
+#         self,
+#         bucket: str | None = None,
+#         fs: AbstractFileSystem | None = None,
+#         profile: str | None = None,
+#         key: str | None = None,
+#         endpoint_url: str | None = None,
+#         secret: str | None = None,
+#         token: str | None = None,
+#         protocol: str | None = None,
+#         cached: bool = False,
+#         cache_storage="~/.tmp",
+#         check_files: bool = False,
+#         cache_check: int = 120,
+#         expire_time: int = 24 * 60 * 60,
+#         same_names: bool = False,
+#         **kwargs,
+#     ):
+#         self._fsspec_fs = FsSpecFileSystem(
+#             bucket=bucket,
+#             fs=fs,
+#             profile=profile,
+#             key=key,
+#             endpoint_url=endpoint_url,
+#             secret=secret,
+#             token=token,
+#             protocol=protocol,
+#             cached=cached,
+#             cache_storage=cache_storage,
+#             check_files=check_files,
+#             cache_check=cache_check,
+#             expire_time=expire_time,
+#             same_names=same_names,
+#             **kwargs,
+#         )
+#         self._pfs_fs = PyArrowFileSystem(
+#             bucket=bucket,
+#             fs=fs,
+#             access_key=key,
+#             secret_key=secret,
+#             session_token=token,
+#             endpoint_override=endpoint_url,
+#             protocol=protocol,
+#         )
+
+
def clear_cache(fs: AbstractFileSystem | None):
    if hasattr(fs, "dir_cache"):
        if fs is not None:
|