import datetime as dt
import os
import time
import uuid
import duckdb
import pandas as pd
import polars.selectors as cs
import pyarrow as pa
import pyarrow.dataset as pds
import pyarrow.parquet as pq
from fsspec import AbstractFileSystem
from fsspec import filesystem as fsspec_filesystem
from .filesystem import clear_cache
from .helpers.datetime import get_timestamp_column
from .helpers.polars import pl
from .schema import convert_timestamp, replace_schema, shrink_large_string
from .table import PydalaTable
def write_table(
table: pa.Table,
path: str,
filesystem: AbstractFileSystem | None = None,
row_group_size: int | None = None,
compression: str = "zstd",
**kwargs,
) -> tuple[str, pq.FileMetaData]:
"""
Writes a PyArrow table to Parquet format.
Args:
table (pa.Table): The PyArrow table to write.
path (str): The path to write the Parquet file to.
filesystem (AbstractFileSystem | None, optional): The filesystem to use for writing the file. Defaults to None.
row_group_size (int | None, optional): The size of each row group in the Parquet file. Defaults to None.
compression (str, optional): The compression algorithm to use. Defaults to "zstd".
**kwargs: Additional keyword arguments to pass to `pq.write_table`.
Returns:
tuple[str, pq.FileMetaData]: A tuple containing the file path and the metadata of the written Parquet file.
"""
    # Fall back to the local filesystem before using it.
    if filesystem is None:
        filesystem = fsspec_filesystem("file")
    # Ensure the target directory exists.
    if not filesystem.exists(os.path.dirname(path)):
        try:
            filesystem.makedirs(os.path.dirname(path), exist_ok=True)
        except Exception:
            pass
    filesystem.invalidate_cache()
metadata = []
pq.write_table(
table,
path,
filesystem=filesystem,
row_group_size=row_group_size,
compression=compression,
metadata_collector=metadata,
allow_truncated_timestamps=True,
**kwargs,
)
metadata = metadata[0]
# metadata.set_file_path(path)
return path, metadata
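# A minimal usage sketch for write_table (hypothetical data and path; assumes a local
# fsspec filesystem), kept as comments so nothing runs on import:
#
#   fs = fsspec_filesystem("file")
#   tbl = pa.table({"id": [1, 2, 3]})
#   out_path, file_metadata = write_table(tbl, "/tmp/pydala_example/data.parquet", filesystem=fs)
#   print(file_metadata.num_rows)  # -> 3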
class Writer:
def __init__(
self,
data: (
pa.Table
| pa.RecordBatch
| pl.DataFrame
| pl.LazyFrame
| pd.DataFrame
| duckdb.DuckDBPyRelation
),
path: str,
schema: pa.Schema | None,
filesystem: AbstractFileSystem | None = None,
):
"""
Initialize the object with the given data, path, schema, and filesystem.
Parameters:
            data (pa.Table | pa.RecordBatch | pl.DataFrame | pl.LazyFrame | pd.DataFrame | duckdb.DuckDBPyRelation):
                The input data, which can be one of the following types: pa.Table, pa.RecordBatch,
                pl.DataFrame, pl.LazyFrame, pd.DataFrame, duckdb.DuckDBPyRelation.
path (str): The path of the data.
schema (pa.Schema | None): The schema of the data, if available.
filesystem (AbstractFileSystem | None, optional): The filesystem to use, defaults to None.
Returns:
None
"""
self.schema = schema
self.data = (
data
if not isinstance(data, pa.RecordBatch)
else pa.Table.from_batches([data])
)
self.base_path = path
self.path = None
self._filesystem = filesystem
def _to_polars(self):
"""
Convert the data attribute to a Polars DataFrame.
This function checks the type of the data attribute and converts it to a Polars DataFrame if it is not
already one.
It supports conversion from Arrow tables, Pandas DataFrames, and DuckDBPyRelations.
"""
if isinstance(self.data, pa.Table):
self.data = pl.from_arrow(self.data)
elif isinstance(self.data, pd.DataFrame):
self.data = pl.from_pandas(self.data)
elif isinstance(self.data, duckdb.DuckDBPyRelation):
self.data = self.data.pl()
def _to_arrow(self):
"""
Convert the data in the DataFrame to Arrow format.
This method checks the type of the data and converts it to Arrow format accordingly.
It supports conversion from Polars DataFrames, Polars LazyFrames, Pandas DataFrames, and DuckDBPyRelations.
"""
if isinstance(self.data, pl.DataFrame):
self.data = self.data.to_arrow()
elif isinstance(self.data, pl.LazyFrame):
self.data = self.data.collect().to_arrow()
elif isinstance(self.data, pd.DataFrame):
self.data = pa.Table.from_pandas(self.data)
elif isinstance(self.data, duckdb.DuckDBPyRelation):
self.data = self.data.arrow()
def _set_schema(self):
"""
Sets the schema of the DataFrame.
This private method is called internally to set the schema of the DataFrame. It first converts the DataFrame
to an Arrow table using the `_to_arrow()` method. Then, it checks if a schema has already been set for the
DataFrame.
If not, it assigns the schema of the DataFrame's underlying data to the `schema` attribute.
"""
self._to_arrow()
self.schema = self.schema or self.data.schema
def sort_data(self, by: str | list[str] | list[tuple[str, str]] | None = None):
"""
        Sorts the data held by this Writer based on the specified column(s).
Args:
by (str | list[str] | list[tuple[str, str]] | None): The column(s) to sort by.
If a single string is provided, the data will be sorted in ascending order based on that column.
If a list of strings is provided, the data will be sorted in ascending order based on each
column in the list.
If a list of tuples is provided, each tuple should contain a column name and a sort order
("ascending" or "descending").
If None is provided, the data will not be sorted.
Returns:
None
"""
if by is not None:
self._to_arrow()
by = PydalaTable._get_sort_by(by, type_="pyarrow")
self.data = self.data.sort_by(**by)
def unique(self, columns: bool | str | list[str] = False):
"""
Generates a unique subset of the DataFrame based on the specified columns.
Args:
columns (bool | str | list[str], optional): The columns to use for determining uniqueness.
If set to False, uniqueness is determined based on all columns.
If set to a string, uniqueness is determined based on the specified column.
If set to a list of strings, uniqueness is determined based on the specified columns.
Defaults to False.
"""
if columns is not None:
self._to_polars()
self.data = self.data.with_columns(cs.by_dtype(pl.Null()).cast(pl.Int64()))
if isinstance(columns, bool):
columns = None
self.data = self.data.unique(columns, maintain_order=True)
def add_datepart_columns(
self, columns: list[str], timestamp_column: str | None = None
):
"""
Adds datepart columns to the data.
Args:
            columns (list[str]): The datepart columns to add. The available options are: "year",
                "month", "week", "yearday", "monthday", "weekday", "day", "hour" and "minute".
            timestamp_column (str | None, optional): The name of the timestamp column. If None, it is
                inferred from the data. Defaults to None.
Returns:
None
"""
if columns is None:
columns = []
if isinstance(columns, str):
columns = [columns]
if timestamp_column is None:
timestamp_column = get_timestamp_column(self.data)
timestamp_column = timestamp_column[0] if len(timestamp_column) else None
if timestamp_column is not None:
self._set_schema()
self._to_polars()
datepart_columns = {
col: True
for col in self.schema.names + columns
if col
in [
"year",
"month",
"week",
"yearday",
"monthday",
"weekday",
"day",
"hour",
"minute",
]
}
self.data = self.data.with_datepart_columns(
timestamp_column=timestamp_column, **datepart_columns
)
self._to_arrow()
            for col in datepart_columns:
                if col not in self.schema.names:
                    # "weekday" is a string column (day names); all other datepart
                    # columns are integers.
                    if col == "weekday":
                        self.schema = self.schema.append(pa.field(col, pa.string()))
                    else:
                        self.schema = self.schema.append(pa.field(col, pa.int32()))
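    # A hypothetical sketch of add_datepart_columns (comments only, not executed): for data
    # with a "timestamp" column, the call below would derive "year" and "month" columns from
    # that timestamp and append matching fields to the schema:
    #
    #   writer.add_datepart_columns(["year", "month"], timestamp_column="timestamp")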
def cast_schema(
self,
use_large_string: bool = False,
        tz: str | None = None,
        ts_unit: str | None = None,
remove_tz: bool = False,
alter_schema: bool = False,
):
"""
Casts the schema of the data object based on the specified parameters.
Args:
use_large_string (bool, optional): Whether to use large string type. Defaults to False.
tz (str, optional): Timezone to convert timestamps to. Defaults to None.
ts_unit (str, optional): Unit to convert timestamps to. Defaults to None.
remove_tz (bool, optional): Whether to remove timezone from timestamps. Defaults to False.
alter_schema (bool, optional): Whether to alter the schema. Defaults to False.
"""
self._to_arrow()
self._set_schema()
self._use_large_string = use_large_string
if not use_large_string:
self.schema = shrink_large_string(self.schema)
if tz is not None or ts_unit is not None or remove_tz:
self.schema = convert_timestamp(
self.schema,
tz=tz,
unit=ts_unit,
remove_tz=remove_tz,
)
self.data = replace_schema(
self.data,
self.schema,
alter_schema=alter_schema,
)
self.schema = self.data.schema
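    # A hypothetical sketch of cast_schema (comments only, not executed): for data with a
    # tz-aware nanosecond timestamp column, the call below would shrink large_string columns
    # to string and convert timestamps to microsecond precision without a timezone:
    #
    #   writer.cast_schema(ts_unit="us", remove_tz=True)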
def delta(
self,
other: pl.DataFrame | pl.LazyFrame,
subset: str | list[str] | None = None,
):
"""
Computes the difference between the current DataFrame and another DataFrame or LazyFrame.
Parameters:
other (DataFrame | LazyFrame): The DataFrame or LazyFrame to compute the difference with.
subset (str | list[str] | None, optional): The column(s) to compute the difference on. If `None`,
the difference is computed on all columns. Defaults to `None`.
"""
self._to_polars()
self.data = self.data.delta(other, subset=subset)
@property
def shape(self):
if self.data is None:
return 0
if isinstance(self.data, pl.LazyFrame):
self.data = self.data.collect()
return self.data.shape
def write_to_dataset(
self,
row_group_size: int | None = None,
compression: str = "zstd",
partitioning_columns: list[str] | None = None,
partitioning_flavor: str = "hive",
max_rows_per_file: int | None = None,
create_dir: bool = False,
basename: str | None = None,
**kwargs,
):
"""
Writes the data to a dataset in the Parquet format.
Args:
row_group_size (int | None, optional): The number of rows per row group. Defaults to None.
compression (str, optional): The compression algorithm to use. Defaults to "zstd".
partitioning_columns (list[str] | None, optional): The columns to use for partitioning the dataset.
Defaults to None.
partitioning_flavor (str, optional): The partitioning flavor to use. Defaults to "hive".
max_rows_per_file (int | None, optional): The maximum number of rows per file. Defaults to None.
create_dir (bool, optional): Whether to create directories for the dataset. Defaults to False.
basename (str | None, optional): The base name for the output files. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
None
"""
self._to_arrow()
if basename is None:
basename_template = (
"data-"
f"{dt.datetime.now().strftime('%Y%m%d_%H%M%S%f')[:-3]}-{uuid.uuid4().hex[:16]}-{{i}}.parquet"
)
else:
basename_template = f"{basename}-{{i}}.parquet"
file_options = pds.ParquetFileFormat().make_write_options(
compression=compression
)
if hasattr(self._filesystem, "fs"):
if "local" in self._filesystem.fs.protocol:
create_dir = True
else:
if "local" in self._filesystem.protocol:
create_dir = True
retries = 0
while retries < 2:
try:
pds.write_dataset(
self.data,
base_dir=self.base_path,
filesystem=self._filesystem,
file_options=file_options,
partitioning=partitioning_columns,
partitioning_flavor=partitioning_flavor,
basename_template=basename_template,
min_rows_per_group=row_group_size,
max_rows_per_group=row_group_size,
max_rows_per_file=max_rows_per_file,
existing_data_behavior="overwrite_or_ignore",
create_dir=create_dir,
format="parquet",
**kwargs,
)
break
except Exception as e:
retries += 1
if retries == 2:
raise e
self.clear_cache()
time.sleep(0.5)
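    # A hypothetical sketch of a hive-partitioned write (comments only, not executed); assumes
    # the data has a "year" column, e.g. one added via add_datepart_columns:
    #
    #   writer.write_to_dataset(partitioning_columns=["year"], partitioning_flavor="hive")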
def clear_cache(self) -> None:
"""
Clears the cache for the dataset's filesystem and base filesystem.
This method clears the cache for the dataset's filesystem and base filesystem,
which can be useful if the dataset has been modified and the cache needs to be
updated accordingly.
Returns:
None
"""
if hasattr(self._filesystem, "fs"):
clear_cache(self._filesystem.fs)
clear_cache(self._filesystem)
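# A minimal end-to-end sketch of the Writer flow, guarded so it only runs when this module
# is executed directly (e.g. `python -m <package>.io`). The data, output path, and parameter
# values below are hypothetical.
if __name__ == "__main__":
    example_df = pl.DataFrame(
        {
            "id": [1, 2, 2, 3],
            "timestamp": [dt.datetime(2024, 1, day + 1) for day in range(4)],
        }
    )
    writer = Writer(
        data=example_df,
        path="/tmp/pydala_example_dataset",
        schema=None,
        filesystem=fsspec_filesystem("file"),
    )
    writer.sort_data(by="timestamp")
    writer.unique(["id"])
    writer.cast_schema(ts_unit="us")
    writer.write_to_dataset(row_group_size=1024, compression="zstd", create_dir=True)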