Skip to content

Commit

Permalink
further updates for lazy
Browse files Browse the repository at this point in the history
  • Loading branch information
cody-scott committed Nov 1, 2024
1 parent ea34202 commit a3273b9
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 28 deletions.
17 changes: 15 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ The user running the dagster pipeline must have the necessary permissions to loa

## Polars

Polars data is processed as a `LazyFrame`. Your asset may return either a `DataFrame` or a `LazyFrame`; the output is automatically converted to a `LazyFrame` (via `.lazy()`) before it is processed.

```python
from dagster import asset, Definitions
from dagster_mssql_bcp import PolarsBCPIOManager
Expand Down Expand Up @@ -69,8 +71,19 @@ def my_polars_asset(context):
return pl.DataFrame({"id": [1, 2, 3]})


@asset(
metadata={
"asset_schema": [
{"name": "id", "type": "INT"},
],
"schema": "my_schema",
}
)
def my_polars_asset_lazy(context):
return pl.LazyFrame({"id": [1, 2, 3]})

defs = Definitions(
assets=[my_polars_asset],
assets=[my_polars_asset, my_polars_asset_lazy],
io_managers={
"io_manager": io_manager,
},
Expand Down Expand Up @@ -105,7 +118,7 @@ io_manager = PandasBCPIOManager(
"schema": "my_schema",
}
)
def my_polars_asset(context):
def my_pandas_asset(context):
return pd.DataFrame({"id": [1, 2, 3]})


Expand Down
2 changes: 1 addition & 1 deletion dagster_mssql_bcp_tests/bcp_polars/test_bcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def test_save_csv(self, polars_io):
pendulum.now(tz="America/Toronto"),
],
}
),
).lazy(),
dir,
"test.csv",
)
Expand Down
5 changes: 4 additions & 1 deletion src/dagster_mssql_bcp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from dagster_mssql_bcp.bcp_polars import PolarsBCPIOManager, PolarsBCP
from dagster_mssql_bcp.bcp_pandas import PandasBCPIOManager, PandasBCP

from dagster_mssql_bcp.bcp_core import AssetSchema

__all__ = [
"PolarsBCP",
"PolarsBCPIOManager",
"PandasBCP",
"PandasBCPIOManager",
]
"AssetSchema",
]
34 changes: 29 additions & 5 deletions src/dagster_mssql_bcp/bcp_core/bcp_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,28 @@ def __init__(
self.load_uuid_column_name = load_uuid_column_name
self.load_datetime_column_name = load_datetime_column_name

@property
def config(self):
return dict(
host=self.host,
port=self.port,
database=self.database,
username=self.username,
password=self.password,
add_row_hash=self.add_row_hash,
add_load_datetime=self.add_load_datetime,
add_load_uuid=self.add_load_uuid,
driver=self.driver,
query_props=self.query_props,
bcp_arguments=self.bcp_arguments,
bcp_path=self.bcp_path,
process_datetime=self.process_datetime,
process_replacements=self.process_replacements,
row_hash_column_name=self.row_hash_column_name,
load_uuid_column_name=self.load_uuid_column_name,
load_datetime_column_name=self.load_datetime_column_name,
)

@property
def connection_config(self):
"""
Expand Down Expand Up @@ -204,9 +226,10 @@ def load_bcp(
schema, table, asset_schema, staging_table, connection
)

data = self._pre_bcp_stage_pre_start_hook(
data = self._pre_start_hook(
data
)

data, schema_deltas = self._pre_bcp_stage(
connection,
data,
Expand All @@ -220,6 +243,7 @@ def load_bcp(
process_datetime,
process_replacements,
)

data = self._pre_bcp_stage_completed_hook(
data
)
Expand Down Expand Up @@ -321,7 +345,7 @@ def _post_bcp_stage(
with connect_mssql(connection_config_dict) as con:
# Validate loads (counts of tables match)
new_line_count = self._validate_bcp_load(
con, schema, staging_table, len(data)
con, schema, staging_table, None
)

if process_replacements:
Expand All @@ -343,7 +367,7 @@ def _post_bcp_stage(
def _pre_bcp_stage_completed_hook(self, dataframe):
return dataframe

def _pre_bcp_stage_pre_start_hook(self, dataframe):
def _pre_start_hook(self, dataframe):
return dataframe

def _parse_asset_schema(self, schema, table, asset_schema):
Expand Down Expand Up @@ -840,7 +864,7 @@ def _validate_bcp_load(
connection: Connection,
schema: str,
bcp_table: str,
row_count: int,
row_count: int | None = None,
):
"""
Validates the BCP load by comparing the row count in the specified table with the expected row count. When `row_count` is None the comparison is skipped and the table's row count is simply returned.
Expand All @@ -862,7 +886,7 @@ def _validate_bcp_load(
result = cursor.fetchone()
if result is None:
raise ValueError("No result from validation")
if result[0] != row_count:
if row_count is not None and result[0] != row_count:
raise ValueError("Validation failed")
return result[0]

Expand Down
38 changes: 25 additions & 13 deletions src/dagster_mssql_bcp/bcp_core/bcp_io_manager_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from dagster import ConfigurableIOManager, InputContext, OutputContext, get_dagster_logger

from abc import abstractmethod, ABC
from .asset_schema import AssetSchema
from .mssql_connection import connect_mssql
from .utils import get_cleanup_statement, get_select_statement

from .bcp_core import BCPCore

class BCPIOManagerCore(ConfigurableIOManager):
class BCPIOManagerCore(ConfigurableIOManager, ABC):
host: str
port: str
database: str
Expand All @@ -31,32 +32,41 @@ class BCPIOManagerCore(ConfigurableIOManager):
load_uuid_column_name: str = "load_uuid"
load_datetime_column_name: str = "load_datetime"

def load_input(self, context: InputContext):
raise NotImplementedError

def handle_output(self, context: OutputContext, obj):
if obj is None:
get_dagster_logger().info("No data to load")
return

bcp_manager = self.get_bcp(
@property
def config(self):
return dict(
host=self.host,
port=self.port,
database=self.database,
username=self.username,
password=self.password,
driver=self.driver,
bcp_arguments=self.bcp_arguments,
query_props=self.query_props,
add_row_hash=self.add_row_hash,
add_load_datetime=self.add_load_datetime,
add_load_uuid=self.add_load_uuid,
driver=self.driver,
query_props=self.query_props,
bcp_arguments=self.bcp_arguments,
bcp_path=self.bcp_path,
process_datetime=self.process_datetime,
process_replacements=self.process_replacements,
row_hash_column_name=self.row_hash_column_name,
load_uuid_column_name=self.load_uuid_column_name,
load_datetime_column_name=self.load_datetime_column_name,
)

@abstractmethod
def load_input(self, context: InputContext):
raise NotImplementedError

def handle_output(self, context: OutputContext, obj):
if obj is None:
get_dagster_logger().info("No data to load")
return

bcp_manager = self.get_bcp(
**self.config
)

metadata = (
context.definition_metadata
if context.definition_metadata is not None
Expand Down Expand Up @@ -155,10 +165,12 @@ def handle_output(self, context: OutputContext, obj):
| deltas
)

@abstractmethod
def check_empty(self, obj) -> bool:
"""Checks if frame is empty"""
raise NotImplementedError

@abstractmethod
def get_bcp(self, *args, **kwargs) -> BCPCore:
"""Returns an instance of the BCP class for the given connection details."""
raise NotImplementedError
9 changes: 3 additions & 6 deletions src/dagster_mssql_bcp/bcp_polars/polars_mssql_bcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def _reorder_columns(self, data: pl.LazyFrame, column_list: list[str]):

def _save_csv(self, data: pl.LazyFrame, path: Path, file_name: str):
path = Path(path)
data.write_csv(
file=path / file_name,
data.sink_csv(
path=path / file_name,
line_terminator="\n",
separator="\t",
)
Expand All @@ -159,8 +159,5 @@ def _add_identity_columns(
data = data.with_columns([pl.lit(None).alias(_) for _ in missing_idents])
return data

def _pre_bcp_stage_completed_hook(self, data: pl.LazyFrame):
return data.collect()

def _pre_bcp_stage_pre_start_hook(self, data: pl.DataFrame):
def _pre_start_hook(self, data: pl.DataFrame):
return data.lazy()

0 comments on commit a3273b9

Please sign in to comment.