Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk): support qualified urn names in simple urn constructors #12426

Merged
merged 4 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 57 additions & 21 deletions metadata-ingestion/scripts/avro_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def annotate_aspects(aspects: List[dict], schema_class_file: Path) -> None:
"""

for aspect in aspects:
className = f'{aspect["name"]}Class'
className = f"{aspect['name']}Class"
aspectName = aspect["Aspect"]["name"]
class_def_original = f"class {className}(DictWrapper):"

Expand All @@ -299,9 +299,9 @@ def annotate_aspects(aspects: List[dict], schema_class_file: Path) -> None:
schema_classes_lines[empty_line] = "\n"
schema_classes_lines[empty_line] += f"\n ASPECT_NAME = '{aspectName}'"
if "type" in aspect["Aspect"]:
schema_classes_lines[
empty_line
] += f"\n ASPECT_TYPE = '{aspect['Aspect']['type']}'"
schema_classes_lines[empty_line] += (
f"\n ASPECT_TYPE = '{aspect['Aspect']['type']}'"
)

aspect_info = {
k: v for k, v in aspect["Aspect"].items() if k not in {"name", "type"}
Expand All @@ -315,7 +315,7 @@ def annotate_aspects(aspects: List[dict], schema_class_file: Path) -> None:
schema_classes_lines.append(
f"""
ASPECT_CLASSES: List[Type[_Aspect]] = [
{f',{newline} '.join(f"{aspect['name']}Class" for aspect in aspects)}
{f",{newline} ".join(f"{aspect['name']}Class" for aspect in aspects)}
]

ASPECT_NAME_MAP: Dict[str, Type[_Aspect]] = {{
Expand All @@ -326,11 +326,11 @@ def annotate_aspects(aspects: List[dict], schema_class_file: Path) -> None:
from typing_extensions import TypedDict

class AspectBag(TypedDict, total=False):
{f'{newline} '.join(f"{aspect['Aspect']['name']}: {aspect['name']}Class" for aspect in aspects)}
{f"{newline} ".join(f"{aspect['Aspect']['name']}: {aspect['name']}Class" for aspect in aspects)}


KEY_ASPECTS: Dict[str, Type[_Aspect]] = {{
{f',{newline} '.join(f"'{aspect['Aspect']['keyForEntity']}': {aspect['name']}Class" for aspect in aspects if aspect['Aspect'].get('keyForEntity'))}
{f",{newline} ".join(f"'{aspect['Aspect']['keyForEntity']}': {aspect['name']}Class" for aspect in aspects if aspect["Aspect"].get("keyForEntity"))}
}}
"""
)
Expand Down Expand Up @@ -546,6 +546,7 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
assert field_name(fields[0]) == "guid"
assert fields[0]["type"] == ["null", "string"]
fields[0]["type"] = "string"
arg_count = len(fields)

field_urn_type_classes = {}
for field in fields:
Expand All @@ -561,6 +562,12 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
field_urn_type_class = "Urn"

field_urn_type_classes[field_name(field)] = field_urn_type_class
if arg_count == 1:
field = fields[0]

if field_urn_type_classes[field_name(field)] is None:
# All single-arg urn types should accept themselves.
field_urn_type_classes[field_name(field)] = class_name

_init_arg_parts: List[str] = []
for field in fields:
Expand All @@ -579,7 +586,6 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:

super_init_args = ", ".join(field_name(field) for field in fields)

arg_count = len(fields)
parse_ids_mapping = ", ".join(
f"{field_name(field)}=entity_ids[{i}]" for i, field in enumerate(fields)
)
Expand All @@ -601,8 +607,26 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:

# Generalized mechanism for validating embedded urns.
field_urn_type_class = field_urn_type_classes[field_name(field)]
if field_urn_type_class:
init_validation += f"{field_name(field)} = str({field_name(field)})\n"
if field_urn_type_class and field_urn_type_class == class_name:
# First, we need to extract the main piece from the urn type.
init_validation += (
f"if isinstance({field_name(field)}, {field_urn_type_class}):\n"
f" {field_name(field)} = {field_name(field)}.{field_name(field)}\n"
)

# If it's still an urn type, that's a problem.
init_validation += (
f"elif isinstance({field_name(field)}, Urn):\n"
f" raise InvalidUrnError(f'Expecting a {field_urn_type_class} but got {{{field_name(field)}}}')\n"
)

# Then, we do character validation as normal.
init_validation += (
f"if UrnEncoder.contains_reserved_char({field_name(field)}):\n"
f" raise InvalidUrnError(f'{class_name} {field_name(field)} contains reserved characters')\n"
)
elif field_urn_type_class:
init_validation += f"{field_name(field)} = str({field_name(field)}) # convert urn type to str\n"
init_validation += (
f"assert {field_urn_type_class}.from_string({field_name(field)})\n"
)
Expand All @@ -611,17 +635,29 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
f"if UrnEncoder.contains_reserved_char({field_name(field)}):\n"
f" raise InvalidUrnError(f'{class_name} {field_name(field)} contains reserved characters')\n"
)
# TODO add ALL_ENV_TYPES validation

# Field coercion logic.
if field_name(field) == "env":
init_coercion += "env = env.upper()\n"
# TODO add ALL_ENV_TYPES validation
elif entity_type == "dataPlatform" and field_name(field) == "platform_name":
init_coercion += 'if platform_name.startswith("urn:li:dataPlatform:"):\n'
init_coercion += " platform_name = DataPlatformUrn.from_string(platform_name).platform_name\n"

if field_name(field) == "platform":
init_coercion += "platform = platform.urn() if isinstance(platform, DataPlatformUrn) else DataPlatformUrn(platform).urn()\n"
elif field_urn_type_class is None:
elif field_name(field) == "platform":
# For platform names in particular, we also qualify them when they don't have the prefix.
# We can rely on the DataPlatformUrn constructor to do this prefixing.
init_coercion += "platform = DataPlatformUrn(platform).urn()\n"
elif field_urn_type_class is not None:
# For urn types, we need to parse them into urn types where appropriate.
# Otherwise, we just need to encode special characters.
init_coercion += (
f"if isinstance({field_name(field)}, str):\n"
f" if {field_name(field)}.startswith('urn:li:'):\n"
f" try:\n"
f" {field_name(field)} = {field_urn_type_class}.from_string({field_name(field)})\n"
f" except InvalidUrnError:\n"
f" raise InvalidUrnError(f'Expecting a {field_urn_type_class} but got {{{field_name(field)}}}')\n"
f" else:\n"
f" {field_name(field)} = UrnEncoder.encode_string({field_name(field)})\n"
)
else:
# For all non-urns, run the value through the UrnEncoder.
init_coercion += (
f"{field_name(field)} = UrnEncoder.encode_string({field_name(field)})\n"
Expand All @@ -642,10 +678,10 @@ class {class_name}(_SpecificUrn):
def __init__(self, {init_args}, *, _allow_coercion: bool = True) -> None:
if _allow_coercion:
# Field coercion logic (if any is required).
{textwrap.indent(init_coercion.strip(), prefix=" "*4*3)}
{textwrap.indent(init_coercion.strip(), prefix=" " * 4 * 3)}

# Validation logic.
{textwrap.indent(init_validation.strip(), prefix=" "*4*2)}
{textwrap.indent(init_validation.strip(), prefix=" " * 4 * 2)}

super().__init__(self.ENTITY_TYPE, [{super_init_args}])

Expand Down Expand Up @@ -729,7 +765,7 @@ def generate(
and aspect["Aspect"]["keyForEntity"] != entity.name
):
raise ValueError(
f'Entity key {entity.keyAspect} is used by {aspect["Aspect"]["keyForEntity"]} and {entity.name}'
f"Entity key {entity.keyAspect} is used by {aspect['Aspect']['keyForEntity']} and {entity.name}"
)

# Also require that the aspect list is deduplicated.
Expand Down
14 changes: 14 additions & 0 deletions metadata-ingestion/tests/unit/urns/test_urn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,18 @@ def test_urn_coercion() -> None:
def test_urns_in_init() -> None:
platform = DataPlatformUrn("abc")
assert platform.urn() == "urn:li:dataPlatform:abc"
assert platform == DataPlatformUrn(platform)
assert platform == DataPlatformUrn(platform.urn())

dataset_urn = DatasetUrn(platform, "def", "PROD")
assert dataset_urn.urn() == "urn:li:dataset:(urn:li:dataPlatform:abc,def,PROD)"
assert dataset_urn == DatasetUrn(platform.urn(), "def", "PROD")
assert dataset_urn == DatasetUrn(platform.platform_name, "def", "PROD")

with pytest.raises(
InvalidUrnError, match="Expecting a DataPlatformUrn but got .*dataset.*"
):
assert dataset_urn == DatasetUrn(dataset_urn, "def", "PROD") # type: ignore

schema_field = SchemaFieldUrn(dataset_urn, "foo")
assert (
Expand Down Expand Up @@ -101,6 +110,11 @@ def test_urn_type_dispatch_2() -> None:
with pytest.raises(InvalidUrnError, match="Passed an urn of type dataJob"):
CorpUserUrn.from_string(urn)

with pytest.raises(
InvalidUrnError, match="Expecting a CorpUserUrn but got.*dataJob.*"
# [review thread on the line above: marked as resolved by hsheth2]
):
CorpUserUrn(urn) # type: ignore


def test_urn_type_dispatch_3() -> None:
# Creating a "generic" Urn.
Expand Down
Loading