diff --git a/dagster_mssql_bcp_tests/bcp_polars/test_bcp.py b/dagster_mssql_bcp_tests/bcp_polars/test_bcp.py index 5664167..ad08f53 100644 --- a/dagster_mssql_bcp_tests/bcp_polars/test_bcp.py +++ b/dagster_mssql_bcp_tests/bcp_polars/test_bcp.py @@ -249,7 +249,7 @@ def test_replace_values(self, polars_io): {"a": ["1,000", "2", "3"], "b": [4000, 5, 6], "c": ["a", "b", "c"]} ) expected = df = pl.DataFrame( - {"a": ["1000", "2", "3"], "b": [4000, 5, 6], "c": ["a", "b", "c"]} + {"a": ["1000", "2", "3"], "b": [4000, 5, 6], "c": ["a", "b", "c"], 'should_process_replacements': [0, 0, 0]} ) schema = polars_mssql_bcp.AssetSchema( [ @@ -264,7 +264,7 @@ def test_replace_values(self, polars_io): df = pl.DataFrame( {"c": ["nan", "NAN", "c", "abc\tdef", "abc\t\ndef", "abc\ndef", "nan", "somenanthing"]} ) - expected = df = pl.DataFrame( + expected = pl.DataFrame( { "c": [ "", @@ -275,6 +275,9 @@ def test_replace_values(self, polars_io): "abc__NEWLINE__def", "", "somenanthing" + ], + 'should_process_replacements': [ + 0, 0, 0, 1, 1, 1, 0, 0 ] } ) @@ -304,6 +307,7 @@ def test_replace_values(self, polars_io): # "2021-01-01 00:00:00-05:00", ], "d": ["2021-01-01 00:00:00-05:00", "2021-01-01 00:00:00-05:00"], + "should_process_replacements": [0, 0] } ) diff --git a/src/dagster_mssql_bcp/bcp_core/bcp_core.py b/src/dagster_mssql_bcp/bcp_core/bcp_core.py index 63ec336..0904d53 100644 --- a/src/dagster_mssql_bcp/bcp_core/bcp_core.py +++ b/src/dagster_mssql_bcp/bcp_core/bcp_core.py @@ -303,11 +303,12 @@ def _pre_bcp_stage( frame_columns, asset_schema.get_columns(), sql_structure ) - # Filter columns that are not in the json schema (evolution) + # Filter columns that are not in the json schema (evolution) data = self._filter_columns(data, asset_schema.get_columns(True)) sql_structure = sql_structure or frame_columns data = self._reorder_columns(data, sql_structure) + data = self._add_replacement_flag_column(data) if process_replacements: data = self._replace_values(data, asset_schema) if 
process_datetime: @@ -466,7 +467,7 @@ def _create_target_tables( connection, schema, staging_table, - asset_schema.get_sql_columns(True), + asset_schema.get_sql_columns(True) + ["should_process_replacements BIT"], ) @abstractmethod @@ -922,6 +923,8 @@ def _replace_temporary_tab_newline( UPDATE {schema}.{table} SET {set_columns} + WHERE + should_process_replacements = 1 """ update_sql_str = update_sql.format( @@ -1006,3 +1009,11 @@ def _calculate_row_hash( connection.execute(text(update_sql)) # endregion + + @abstractmethod + def _add_replacement_flag_column(self, data): + """ + Adds a bit column, `should_process_replacements`, to indicate if that row should have the REPLACE applied. + Replace is applied for tabs and new lines only + """ + raise NotImplementedError \ No newline at end of file diff --git a/src/dagster_mssql_bcp/bcp_pandas/pandas_mssql_bcp.py b/src/dagster_mssql_bcp/bcp_pandas/pandas_mssql_bcp.py index 96228ba..4789b29 100644 --- a/src/dagster_mssql_bcp/bcp_pandas/pandas_mssql_bcp.py +++ b/src/dagster_mssql_bcp/bcp_pandas/pandas_mssql_bcp.py @@ -12,7 +12,6 @@ from dagster_mssql_bcp.bcp_core import AssetSchema, BCPCore - class PandasBCP(BCPCore): def _add_meta_columns( self, @@ -122,13 +121,17 @@ def _filter_columns(self, data: pd.DataFrame, columns: list[str]): def _rename_columns(self, data: pd.DataFrame, columns: dict) -> pd.DataFrame: return data.rename(columns=columns) - - def _add_identity_columns(self, data: pd.DataFrame, asset_schema: AssetSchema) -> pd.DataFrame: + def _add_identity_columns( + self, data: pd.DataFrame, asset_schema: AssetSchema + ) -> pd.DataFrame: ident_cols = asset_schema.get_identity_columns() - missing_idents = [ - _ for _ in ident_cols if _ not in data.columns - ] + missing_idents = [_ for _ in ident_cols if _ not in data.columns] for _ in missing_idents: data[_] = None - - return data \ No newline at end of file + + return data + + def _add_replacement_flag_column(self, data: pd.DataFrame): + # we just set this to 1 to force all rows
to participate + data["should_process_replacements"] = 1 + return data diff --git a/src/dagster_mssql_bcp/bcp_polars/polars_mssql_bcp.py b/src/dagster_mssql_bcp/bcp_polars/polars_mssql_bcp.py index 787857a..b84550d 100644 --- a/src/dagster_mssql_bcp/bcp_polars/polars_mssql_bcp.py +++ b/src/dagster_mssql_bcp/bcp_polars/polars_mssql_bcp.py @@ -1,6 +1,5 @@ from pathlib import Path - import pendulum try: @@ -52,6 +51,27 @@ def _replace_values(self, data: pl.LazyFrame, asset_schema: AssetSchema): if _ in asset_schema.get_numeric_columns() ] + string_cols = data.select(cs.by_dtype(pl.String)).collect_schema().names() + + if len(string_cols) > 0: + # calculates only the rows that have replacements + data = data.with_columns( + [ + pl.col(_) + .str.contains("(\t)|(\n)") + .alias(f"{_}__bcp__has_replacement_values") + for _ in string_cols + ] + ) + + data = data.with_columns( + pl.any_horizontal( + [f"{_}__bcp__has_replacement_values" for _ in string_cols] + ).alias("should_process_replacements") + ) + + data = data.drop([f"{_}__bcp__has_replacement_values" for _ in string_cols]) + data = data.with_columns( [ pl.col(_) @@ -59,7 +79,7 @@ def _replace_values(self, data: pl.LazyFrame, asset_schema: AssetSchema): .str.replace_all("\n", "__NEWLINE__") .str.replace_all("^nan$", "") .str.replace_all("^NAN$", "") - for _ in data.select(cs.by_dtype(pl.String)).collect_schema().names() + for _ in string_cols if _ not in number_columns_that_are_strings ] + [ @@ -69,7 +89,10 @@ def _replace_values(self, data: pl.LazyFrame, asset_schema: AssetSchema): .str.replace_all("^NAN$", "") for _ in number_columns_that_are_strings ] - + [pl.col(_).cast(pl.Int64) for _ in data.select(cs.boolean()).collect_schema().names()] + + [ + pl.col(_).cast(pl.Int64) + for _ in data.select(cs.boolean()).collect_schema().names() + ] ) return data @@ -129,7 +152,9 @@ def _process_datetime( def _reorder_columns(self, data: pl.LazyFrame, column_list: list[str]): """Reorder the data frame to match the order of the 
columns in the SQL table.""" - column_list = [column for column in column_list if column in data.collect_schema().names()] + column_list = [ + column for column in column_list if column in data.collect_schema().names() + ] return data.select(column_list) def _save_csv(self, data: pl.LazyFrame, path: Path, file_name: str): @@ -155,9 +180,15 @@ def _add_identity_columns( self, data: pl.LazyFrame, asset_schema: AssetSchema ) -> pl.LazyFrame: ident_cols = asset_schema.get_identity_columns() - missing_idents = [_ for _ in ident_cols if _ not in data.collect_schema().names()] + missing_idents = [ + _ for _ in ident_cols if _ not in data.collect_schema().names() + ] data = data.with_columns([pl.lit(None).alias(_) for _ in missing_idents]) return data def _pre_start_hook(self, data: pl.DataFrame): return data.lazy() + + def _add_replacement_flag_column(self, data: pl.DataFrame): + data = data.with_columns(pl.lit(0).alias("should_process_replacements")) + return data