-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathduckpond.py
93 lines (75 loc) · 2.84 KB
/
duckpond.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from duckdb import connect
from dagster import IOManager
import pandas as pd
from sqlescapy import sqlescape
from string import Template
from typing import Mapping
class SQL:
    """A SQL template together with the values bound to its placeholders.

    ``sql`` is a ``string.Template``-style string whose ``$name`` placeholders
    are filled from ``bindings`` when the query is rendered. Bound values may
    be DataFrames, nested ``SQL`` objects, strings, numbers, booleans or None.
    """

    def __init__(self, sql, **bindings):
        # Keep the template and the keyword bindings exactly as given; no
        # validation or substitution happens at construction time.
        self.sql, self.bindings = sql, bindings
def sql_to_string(s: "SQL") -> str:
    """Render an ``SQL`` object into a DuckDB query string.

    Every ``$name`` placeholder in ``s.sql`` is replaced according to the type
    of the value bound to ``name``:

    * ``pd.DataFrame`` -> the relation name ``df_<id(frame)>`` (the frame must
      be registered on the connection under that same name);
    * ``str`` -> a single-quoted SQL string literal;
    * ``bool`` -> ``true`` / ``false``;
    * ``int`` / ``float`` -> the number itself; NaN and +/-infinity become
      ``'nan'::DOUBLE`` / ``'inf'::DOUBLE`` / ``'-inf'::DOUBLE`` because a bare
      ``nan`` or ``inf`` would be parsed as a column name;
    * ``None`` -> ``null``;
    * nested ``SQL`` -> its own rendering, wrapped in parentheses so it can be
      used as a sub-query.

    Placeholders with no binding are left as-is (``safe_substitute``).

    Raises:
        ValueError: if a binding has any other type.
    """
    import math

    replacements = {}
    for key, value in s.bindings.items():
        if isinstance(value, pd.DataFrame):
            replacements[key] = f"df_{id(value)}"
        elif isinstance(value, str):
            # DuckDB follows standard SQL: inside '...' the only escape is a
            # doubled single quote and backslashes are literal characters.
            # Generic/MySQL-style backslash escaping is therefore wrong here.
            escaped = value.replace("'", "''")
            replacements[key] = f"'{escaped}'"
        elif isinstance(value, bool):
            # Checked before int because bool is a subclass of int.
            replacements[key] = "true" if value else "false"
        elif isinstance(value, (int, float)):
            if isinstance(value, float) and not math.isfinite(value):
                replacements[key] = f"'{value}'::DOUBLE"
            else:
                replacements[key] = str(value)
        elif value is None:
            replacements[key] = "null"
        elif isinstance(value, SQL):
            replacements[key] = f"({sql_to_string(value)})"
        else:
            raise ValueError(f"Invalid type for {key}")
    return Template(s.sql).safe_substitute(replacements)
def collect_dataframes(s: SQL) -> Mapping[str, pd.DataFrame]:
    """Gather every DataFrame referenced by ``s``, including nested SQL.

    The keys are ``df_<id(frame)>``, the same names substituted for DataFrame
    bindings when the query is rendered, so registering each frame under its
    key makes those names resolvable. Bindings of other types are ignored.
    """
    found = {}
    for bound in s.bindings.values():
        if isinstance(bound, pd.DataFrame):
            found["df_" + str(id(bound))] = bound
        elif isinstance(bound, SQL):
            # Recurse so frames used inside sub-queries are collected too.
            found.update(collect_dataframes(bound))
    return found
class DuckDB:
    """Execute ``SQL`` objects on a fresh in-memory DuckDB connection."""

    def __init__(self, options=""):
        # Extra SQL run on every new connection before the main statement.
        # Empty means "nothing extra". Presumably settings such as S3
        # credentials; confirm with callers.
        self.options = options

    def query(self, select_statement: "SQL"):
        """Run ``select_statement`` and return its result.

        A new ``:memory:`` connection is opened per call. The ``httpfs``
        extension is installed and loaded on it, and ``self.options`` is
        applied. Every DataFrame bound anywhere in the statement is then
        registered under its ``df_<id>`` name, and the rendered SQL is
        executed. The connection is always closed before returning.

        Returns:
            A pandas DataFrame with the result, or None when the statement
            produced no result relation.
        """
        db = connect(":memory:")
        try:
            db.query("install httpfs; load httpfs;")
            if self.options:
                db.query(self.options)
            for name, frame in collect_dataframes(select_statement).items():
                db.register(name, frame)
            result = db.query(sql_to_string(select_statement))
            if result is None:
                return None
            # Materialize before the connection is closed in ``finally``.
            return result.df()
        finally:
            db.close()
class DuckPondIOManager(IOManager):
    """Dagster IO manager that stores ``SQL`` outputs as Parquet files in S3.

    ``handle_output`` executes ``copy (<select>) to '<url>' (format parquet)``
    through the configured ``DuckDB`` runner. ``load_input`` returns a lazy
    ``SQL`` that reads that Parquet file back with ``read_parquet``.
    """

    def __init__(self, bucket_name: str, duckdb: DuckDB, prefix=""):
        self.bucket_name = bucket_name
        self.duckdb = duckdb
        # Prepended verbatim to the object key: include a trailing "/" if the
        # prefix is meant to act as a directory.
        self.prefix = prefix

    def _get_s3_url(self, context) -> str:
        """Return ``s3://<bucket>/<prefix><identifier parts joined by '/'>.parquet``.

        The asset identifier is used when the context has an asset key;
        otherwise the context's generic identifier is used.
        """
        if context.has_asset_key:
            path_parts = context.get_asset_identifier()
        else:
            path_parts = context.get_identifier()
        key = "/".join(path_parts)
        return f"s3://{self.bucket_name}/{self.prefix}{key}.parquet"

    def handle_output(self, context, select_statement: SQL):
        """Write the rows of ``select_statement`` to this output's Parquet URL.

        A ``None`` output is skipped and nothing is written.

        Raises:
            ValueError: if the output is neither None nor an ``SQL`` object.
        """
        if select_statement is None:
            return
        if not isinstance(select_statement, SQL):
            raise ValueError(
                f"Expected asset to return a SQL; got {select_statement!r}"
            )
        self.duckdb.query(
            SQL(
                "copy $select_statement to $url (format parquet)",
                select_statement=select_statement,
                url=self._get_s3_url(context),
            )
        )

    def load_input(self, context) -> SQL:
        """Return an ``SQL`` selecting every row of the stored Parquet file.

        Nothing is read here; the query runs only when the returned ``SQL`` is
        executed or embedded in another statement.
        """
        return SQL("select * from read_parquet($url)", url=self._get_s3_url(context))