Skip to content

Commit

Permalink
update schema
Browse files Browse the repository at this point in the history
  • Loading branch information
cody-scott committed Sep 19, 2024
1 parent 1fbf8be commit 27c0af8
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 6 deletions.
42 changes: 39 additions & 3 deletions dagster_mssql_bcp_tests/bcp_core/test_asset_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@ def test_resolve_name(self):
def test_rename_columns(self):
schema = asset_schema.AssetSchema(
[
{"name": "a", "alias": "b"},
{"name": "c", "alias": "d"},
{"name": "e"},
{"name": "a", "alias": "b", "type": "NVARCHAR"},
{"name": "c", "alias": "d", "type": "NVARCHAR"},
{"name": "e", "type": "NVARCHAR"},
]
)
result = schema.get_rename_dict()
Expand Down Expand Up @@ -312,3 +312,39 @@ def test_get_identity_columns(self):
)

assert schema.get_identity_columns() == ["b"]

def test_validate_schema(self):
    """AssetSchema accepts a well-formed schema and rejects malformed ones.

    Each malformed schema must raise ``ValueError`` whose message contains
    the expected fragment.
    """
    # A well-formed schema must construct without raising.
    asset_schema.AssetSchema(
        [
            {"name": "a", "type": "BIGINT"},
            {"name": "b", "type": "BIGINT", "identity": True},
        ]
    )

    # (malformed schema, fragment expected in the ValueError message)
    rejected = (
        (
            [
                {"name": "a", "type": "BIGINT"},
                {"name": "a", "type": "BIGINT"},
            ],
            "Duplicate column name: a",
        ),
        (
            [{"type": "BIGINT"}],
            "Column name not provided for column: {'type': 'BIGINT'}",
        ),
        (
            [{"name": "a"}],
            "Column type not provided for column: a",
        ),
        (
            [{"name": "a", "type": "INVALID"}],
            "Invalid data type: INVALID",
        ),
    )

    for bad_schema, fragment in rejected:
        with pytest.raises(ValueError) as raised:
            asset_schema.AssetSchema(bad_schema)
        # Checked after the ``with`` exits so the assertion actually runs.
        assert fragment in str(raised.value)
50 changes: 49 additions & 1 deletion dagster_mssql_bcp_tests/bcp_pandas/test_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,4 +345,52 @@ def my_asset(context):
materialize(
assets=[my_asset],
resources={"io_manager": io_manager},
)
)

# def test_geo(self):

# schema = 'dbo'
# table = 'geo_table'
# drop = f"""DROP TABLE IF EXISTS {schema}.{table}"""

# with self.connect_mssql() as connection:
# connection.execute(text(drop))

# io_manager = self.io()

# asset_schema = [
# {"name": "a", "alias": "a", "type": "INT", "identity": True},
# {"name": "b", "type": "VARBINARY"},
# ]

# @asset(
# name=table,
# metadata={
# "asset_schema": asset_schema,
# 'add_row_hash': False,
# 'add_load_datetime': False,
# 'add_load_uuid': False
# },
# )
# def my_asset(context):
# import geopandas as gpd
# from shapely.geometry import Point
# d = {'geo': ['name1', 'name2'], 'geometry': [Point(1, 2), Point(2, 1)]}
# gdf = gpd.GeoDataFrame(d, crs="EPSG:4326")
# gdf['b'] = gdf['geometry'].to_wkb(True)
# return gdf

# @asset(
# deps=[my_asset]
# )
# def convert_geo(context):
# with self.connect_mssql() as conn:
# sql = f'SELECT b, geography::STGeomFromWKB(b, 4326) FROM {schema}.{table}'
# print(sql)
# conn.exec_driver_sql(sql)


# materialize(
# assets=[my_asset, convert_geo],
# resources={"io_manager": io_manager},
# )
32 changes: 30 additions & 2 deletions src/dagster_mssql_bcp/bcp_core/asset_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class AssetSchema:
decimal_column_types = ["NUMERIC", "DECIMAL"]
money_column_types = ["MONEY"]
xml_column_types = ["XML"]
binary_column_types = ['BINARY', 'VARBINARY']

number_column_types = (
int_column_types
Expand All @@ -79,10 +80,12 @@ class AssetSchema:
+ decimal_column_types
+ money_column_types
+ xml_column_types
+ binary_column_types
)

def __init__(self, schema: list[dict]) -> None:
    """Store the column definitions and validate them immediately.

    Args:
        schema: List of column-definition dicts, kept by reference on
            ``self.schema`` (no copy is made).

    Any exception raised by ``validate_asset_schema`` propagates to the
    caller, so an invalid schema never yields a constructed instance.
    """
    self.schema = schema
    # Validate eagerly so a malformed schema fails at construction time.
    self.validate_asset_schema()

def __eq__(self, value: "AssetSchema") -> bool:
for _ in self.schema:
Expand All @@ -94,6 +97,27 @@ def __eq__(self, value: "AssetSchema") -> bool:
return False
return True

def validate_asset_schema(self) -> None:
    """Validate every column definition in ``self.schema``.

    Columns are checked in order and the first problem found is raised.
    For each column the checks run in this sequence:

    1. a ``name`` key is present (and not ``None``),
    2. the name was not already used by an earlier column,
    3. a ``type`` key is present (and not ``None``),
    4. the type is listed in ``self.allowed_types``.

    Raises:
        ValueError: On the first column that fails any of the checks.
    """
    # A set is enough to detect repeats; no per-name count is needed.
    seen_names = set()

    for column in self.schema:
        column_name = column.get("name")
        if column_name is None:
            raise ValueError(f"Column name not provided for column: {column}")

        if column_name in seen_names:
            raise ValueError(f"Duplicate column name: {column_name}")
        seen_names.add(column_name)

        column_type = column.get("type")
        if column_type is None:
            raise ValueError(
                f"Column type not provided for column: {column_name}"
            )

        if column_type not in self.allowed_types:
            raise ValueError(f"Invalid data type: {column_type}")

@staticmethod
def _resolve_name(column: dict):
return column.get("alias", column.get("name"))
Expand Down Expand Up @@ -158,7 +182,6 @@ def get_identity_columns(self) -> list[str]:
if column.get("identity", False) is True
]

def validate_asset_schema(self): ...

def get_sql_columns(self) -> list[str]:
columns = []
Expand All @@ -172,7 +195,7 @@ def get_sql_columns(self) -> list[str]:
if data_type not in self.allowed_types:
raise ValueError(f"Invalid data type: {data_type}")

if data_type in self.text_column_types:
if data_type in (self.text_column_types + self.binary_column_types):
length = data.get("length", "MAX")
to_add = f"{column_name} {data_type}({length})"
elif data_type in self.decimal_column_types:
Expand Down Expand Up @@ -235,6 +258,11 @@ def get_asset_schema_from_db(
base_result["precision"] = precision
if scale is not None:
base_result["scale"] = scale
elif data_type in AssetSchema.binary_column_types:
if str_length is not None:
base_result["length"] = str_length
else:
base_result["length"] = 'MAX'

result_schema.append(base_result)

Expand Down

0 comments on commit 27c0af8

Please sign in to comment.