diff --git a/metadata-ingestion/scripts/avro_codegen.py b/metadata-ingestion/scripts/avro_codegen.py
index 0fe79a2c6a8e4..7e75cba983381 100644
--- a/metadata-ingestion/scripts/avro_codegen.py
+++ b/metadata-ingestion/scripts/avro_codegen.py
@@ -277,7 +277,7 @@ def annotate_aspects(aspects: List[dict], schema_class_file: Path) -> None:
 """
 
     for aspect in aspects:
-        className = f'{aspect["name"]}Class'
+        className = f"{aspect['name']}Class"
         aspectName = aspect["Aspect"]["name"]
         class_def_original = f"class {className}(DictWrapper):"
 
@@ -299,9 +299,9 @@ def annotate_aspects(aspects: List[dict], schema_class_file: Path) -> None:
         schema_classes_lines[empty_line] = "\n"
         schema_classes_lines[empty_line] += f"\n    ASPECT_NAME = '{aspectName}'"
         if "type" in aspect["Aspect"]:
-            schema_classes_lines[
-                empty_line
-            ] += f"\n    ASPECT_TYPE = '{aspect['Aspect']['type']}'"
+            schema_classes_lines[empty_line] += (
+                f"\n    ASPECT_TYPE = '{aspect['Aspect']['type']}'"
+            )
 
         aspect_info = {
             k: v for k, v in aspect["Aspect"].items() if k not in {"name", "type"}
@@ -315,7 +315,7 @@ def annotate_aspects(aspects: List[dict], schema_class_file: Path) -> None:
     schema_classes_lines.append(
         f"""
 ASPECT_CLASSES: List[Type[_Aspect]] = [
-    {f',{newline}    '.join(f"{aspect['name']}Class" for aspect in aspects)}
+    {f",{newline}    ".join(f"{aspect['name']}Class" for aspect in aspects)}
 ]
 
 ASPECT_NAME_MAP: Dict[str, Type[_Aspect]] = {{
@@ -326,11 +326,11 @@ def annotate_aspects(aspects: List[dict], schema_class_file: Path) -> None:
 from typing_extensions import TypedDict
 
 class AspectBag(TypedDict, total=False):
-    {f'{newline}    '.join(f"{aspect['Aspect']['name']}: {aspect['name']}Class" for aspect in aspects)}
+    {f"{newline}    ".join(f"{aspect['Aspect']['name']}: {aspect['name']}Class" for aspect in aspects)}
 
 
 KEY_ASPECTS: Dict[str, Type[_Aspect]] = {{
-    {f',{newline}    '.join(f"'{aspect['Aspect']['keyForEntity']}': {aspect['name']}Class" for aspect in aspects if aspect['Aspect'].get('keyForEntity'))}
+    {f",{newline}    ".join(f"'{aspect['Aspect']['keyForEntity']}': {aspect['name']}Class" for aspect in aspects if aspect["Aspect"].get("keyForEntity"))}
 }}
 """
     )
@@ -546,6 +546,7 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
         assert field_name(fields[0]) == "guid"
         assert fields[0]["type"] == ["null", "string"]
         fields[0]["type"] = "string"
+    arg_count = len(fields)
 
     field_urn_type_classes = {}
     for field in fields:
@@ -561,6 +562,12 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
                 field_urn_type_class = "Urn"
 
         field_urn_type_classes[field_name(field)] = field_urn_type_class
+    if arg_count == 1:
+        field = fields[0]
+
+        if field_urn_type_classes[field_name(field)] is None:
+            # All single-arg urn types should accept themselves.
+            field_urn_type_classes[field_name(field)] = class_name
 
     _init_arg_parts: List[str] = []
     for field in fields:
@@ -579,7 +586,6 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
 
     super_init_args = ", ".join(field_name(field) for field in fields)
 
-    arg_count = len(fields)
     parse_ids_mapping = ", ".join(
         f"{field_name(field)}=entity_ids[{i}]" for i, field in enumerate(fields)
     )
@@ -601,8 +607,26 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
 
         # Generalized mechanism for validating embedded urns.
         field_urn_type_class = field_urn_type_classes[field_name(field)]
-        if field_urn_type_class:
-            init_validation += f"{field_name(field)} = str({field_name(field)})\n"
+        if field_urn_type_class and field_urn_type_class == class_name:
+            # First, we need to extract the main piece from the urn type.
+            init_validation += (
+                f"if isinstance({field_name(field)}, {field_urn_type_class}):\n"
+                f"    {field_name(field)} = {field_name(field)}.{field_name(field)}\n"
+            )
+
+            # If it's still an urn type, that's a problem.
+            init_validation += (
+                f"elif isinstance({field_name(field)}, Urn):\n"
+                f"    raise InvalidUrnError(f'Expecting a {field_urn_type_class} but got {{{field_name(field)}}}')\n"
+            )
+
+            # Then, we do character validation as normal.
+            init_validation += (
+                f"if UrnEncoder.contains_reserved_char({field_name(field)}):\n"
+                f"    raise InvalidUrnError(f'{class_name} {field_name(field)} contains reserved characters')\n"
+            )
+        elif field_urn_type_class:
+            init_validation += f"{field_name(field)} = str({field_name(field)})  # convert urn type to str\n"
             init_validation += (
                 f"assert {field_urn_type_class}.from_string({field_name(field)})\n"
             )
@@ -611,17 +635,29 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str:
                 f"if UrnEncoder.contains_reserved_char({field_name(field)}):\n"
                 f"    raise InvalidUrnError(f'{class_name} {field_name(field)} contains reserved characters')\n"
             )
+        # TODO add ALL_ENV_TYPES validation
 
+        # Field coercion logic.
         if field_name(field) == "env":
             init_coercion += "env = env.upper()\n"
-        # TODO add ALL_ENV_TYPES validation
-        elif entity_type == "dataPlatform" and field_name(field) == "platform_name":
-            init_coercion += 'if platform_name.startswith("urn:li:dataPlatform:"):\n'
-            init_coercion += "    platform_name = DataPlatformUrn.from_string(platform_name).platform_name\n"
-
-        if field_name(field) == "platform":
-            init_coercion += "platform = platform.urn() if isinstance(platform, DataPlatformUrn) else DataPlatformUrn(platform).urn()\n"
-        elif field_urn_type_class is None:
+        elif field_name(field) == "platform":
+            # For platform names in particular, we also qualify them when they don't have the prefix.
+            # We can rely on the DataPlatformUrn constructor to do this prefixing.
+            init_coercion += "platform = DataPlatformUrn(platform).urn()\n"
+        elif field_urn_type_class is not None:
+            # For urn types, we need to parse them into urn types where appropriate.
+            # Otherwise, we just need to encode special characters.
+            init_coercion += (
+                f"if isinstance({field_name(field)}, str):\n"
+                f"    if {field_name(field)}.startswith('urn:li:'):\n"
+                f"        try:\n"
+                f"            {field_name(field)} = {field_urn_type_class}.from_string({field_name(field)})\n"
+                f"        except InvalidUrnError:\n"
+                f"            raise InvalidUrnError(f'Expecting a {field_urn_type_class} but got {{{field_name(field)}}}')\n"
+                f"    else:\n"
+                f"        {field_name(field)} = UrnEncoder.encode_string({field_name(field)})\n"
+            )
+        else:
             # For all non-urns, run the value through the UrnEncoder.
             init_coercion += (
                 f"{field_name(field)} = UrnEncoder.encode_string({field_name(field)})\n"
@@ -642,10 +678,10 @@ class {class_name}(_SpecificUrn):
     def __init__(self, {init_args}, *, _allow_coercion: bool = True) -> None:
         if _allow_coercion:
             # Field coercion logic (if any is required).
-{textwrap.indent(init_coercion.strip(), prefix=" "*4*3)}
+{textwrap.indent(init_coercion.strip(), prefix=" " * 4 * 3)}
 
         # Validation logic.
-{textwrap.indent(init_validation.strip(), prefix=" "*4*2)}
+{textwrap.indent(init_validation.strip(), prefix=" " * 4 * 2)}
 
         super().__init__(self.ENTITY_TYPE, [{super_init_args}])
 
@@ -729,7 +765,7 @@ def generate(
             and aspect["Aspect"]["keyForEntity"] != entity.name
         ):
             raise ValueError(
-                f'Entity key {entity.keyAspect} is used by {aspect["Aspect"]["keyForEntity"]} and {entity.name}'
+                f"Entity key {entity.keyAspect} is used by {aspect['Aspect']['keyForEntity']} and {entity.name}"
             )
 
         # Also require that the aspect list is deduplicated.
diff --git a/metadata-ingestion/tests/unit/urns/test_urn.py b/metadata-ingestion/tests/unit/urns/test_urn.py
index 8490364326d94..f9ed2b387f078 100644
--- a/metadata-ingestion/tests/unit/urns/test_urn.py
+++ b/metadata-ingestion/tests/unit/urns/test_urn.py
@@ -10,6 +10,7 @@
     DataPlatformUrn,
     DatasetUrn,
     SchemaFieldUrn,
+    TagUrn,
     Urn,
 )
 from datahub.testing.doctest import assert_doctest
@@ -71,9 +72,18 @@ def test_urn_coercion() -> None:
 def test_urns_in_init() -> None:
     platform = DataPlatformUrn("abc")
     assert platform.urn() == "urn:li:dataPlatform:abc"
+    assert platform == DataPlatformUrn(platform)
+    assert platform == DataPlatformUrn(platform.urn())
 
     dataset_urn = DatasetUrn(platform, "def", "PROD")
     assert dataset_urn.urn() == "urn:li:dataset:(urn:li:dataPlatform:abc,def,PROD)"
+    assert dataset_urn == DatasetUrn(platform.urn(), "def", "PROD")
+    assert dataset_urn == DatasetUrn(platform.platform_name, "def", "PROD")
+
+    with pytest.raises(
+        InvalidUrnError, match="Expecting a DataPlatformUrn but got .*dataset.*"
+    ):
+        assert dataset_urn == DatasetUrn(dataset_urn, "def", "PROD")  # type: ignore
 
     schema_field = SchemaFieldUrn(dataset_urn, "foo")
     assert (
@@ -101,6 +111,11 @@ def test_urn_type_dispatch_2() -> None:
     with pytest.raises(InvalidUrnError, match="Passed an urn of type dataJob"):
         CorpUserUrn.from_string(urn)
 
+    with pytest.raises(
+        InvalidUrnError, match="Expecting a CorpUserUrn but got .*dataJob.*"
+    ):
+        CorpUserUrn(urn)  # type: ignore
+
 
 def test_urn_type_dispatch_3() -> None:
     # Creating a "generic" Urn.
@@ -133,6 +148,24 @@ def test_urn_type_dispatch_4() -> None:
     assert urn2.urn() == urn_str
 
 
+def test_urn_from_urn_simple() -> None:
+    # This capability is also tested by a bunch of other tests above.
+
+    tag_str = "urn:li:tag:legacy"
+    tag = TagUrn.from_string(tag_str)
+    assert tag_str == tag.urn()
+    assert tag.name == "legacy"
+    assert tag == TagUrn(tag)
+    assert tag == TagUrn(tag.urn())
+
+
+def test_urn_from_urn_tricky() -> None:
+    tag_str = "urn:li:tag:urn:li:tag:legacy"
+    tag = TagUrn(tag_str)
+    assert tag.urn() == tag_str
+    assert tag.name == "urn:li:tag:legacy"
+
+
 def test_urn_doctest() -> None:
     assert_doctest(datahub.utilities.urns._urn_base)