From a49de788b69a5453049dd7de5637886f4fa7b850 Mon Sep 17 00:00:00 2001
From: Ivan Levkivskyi <levkivskyi@gmail.com>
Date: Sat, 21 Jan 2023 21:00:49 +0000
Subject: [PATCH] Fix crash in astdiff and clean it up

---
 mypy/server/astdiff.py           | 30 ++++++++++++++++++------------
 mypy/server/update.py            | 10 +++++++---
 test-data/unit/fine-grained.test | 28 ++++++++++++++++++++++++++++
 3 files changed, 53 insertions(+), 15 deletions(-)

diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py
index 97f811384d37..e58174d49ff5 100644
--- a/mypy/server/astdiff.py
+++ b/mypy/server/astdiff.py
@@ -52,7 +52,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'
 
 from __future__ import annotations
 
-from typing import Sequence, Tuple, cast
+from typing import Sequence, Tuple, Union, cast
 from typing_extensions import TypeAlias as _TypeAlias
 
 from mypy.expandtype import expand_type
@@ -109,11 +109,17 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'
 # snapshots are immutable).
 #
 # For example, the snapshot of the 'int' type is ('Instance', 'builtins.int', ()).
-SnapshotItem: _TypeAlias = Tuple[object, ...]
+
+# Type snapshots are strict, they must be hashable and ordered (e.g. for Unions).
+Primitive: _TypeAlias = Union[str, float, int, bool]  # float is for Literal[3.14] support.
+SnapshotItem: _TypeAlias = Tuple[Union[Primitive, "SnapshotItem"], ...]
+
+# Symbol snapshots can be more lenient.
+SymbolSnapshot: _TypeAlias = Tuple[object, ...]
 
 
 def compare_symbol_table_snapshots(
-    name_prefix: str, snapshot1: dict[str, SnapshotItem], snapshot2: dict[str, SnapshotItem]
+    name_prefix: str, snapshot1: dict[str, SymbolSnapshot], snapshot2: dict[str, SymbolSnapshot]
 ) -> set[str]:
     """Return names that are different in two snapshots of a symbol table.
 
@@ -155,7 +161,7 @@ def compare_symbol_table_snapshots(
     return triggers
 
 
-def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, SnapshotItem]:
+def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, SymbolSnapshot]:
     """Create a snapshot description that represents the state of a symbol table.
 
     The snapshot has a representation based on nested tuples and dicts
@@ -165,7 +171,7 @@ def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, Sna
     things defined in other modules are represented just by the names of
     the targets.
     """
-    result: dict[str, SnapshotItem] = {}
+    result: dict[str, SymbolSnapshot] = {}
     for name, symbol in table.items():
         node = symbol.node
         # TODO: cross_ref?
@@ -206,7 +212,7 @@ def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, Sna
     return result
 
 
-def snapshot_definition(node: SymbolNode | None, common: tuple[object, ...]) -> tuple[object, ...]:
+def snapshot_definition(node: SymbolNode | None, common: SymbolSnapshot) -> SymbolSnapshot:
     """Create a snapshot description of a symbol table node.
 
     The representation is nested tuples and dicts. Only externally
@@ -289,11 +295,11 @@ def snapshot_type(typ: Type) -> SnapshotItem:
     return typ.accept(SnapshotTypeVisitor())
 
 
-def snapshot_optional_type(typ: Type | None) -> SnapshotItem | None:
+def snapshot_optional_type(typ: Type | None) -> SnapshotItem:
     if typ:
         return snapshot_type(typ)
     else:
-        return None
+        return ("<not set>",)
 
 
 def snapshot_types(types: Sequence[Type]) -> SnapshotItem:
@@ -395,7 +401,7 @@ def visit_parameters(self, typ: Parameters) -> SnapshotItem:
             "Parameters",
             snapshot_types(typ.arg_types),
             tuple(encode_optional_str(name) for name in typ.arg_names),
-            tuple(typ.arg_kinds),
+            tuple(k.value for k in typ.arg_kinds),
         )
 
     def visit_callable_type(self, typ: CallableType) -> SnapshotItem:
@@ -406,7 +412,7 @@ def visit_callable_type(self, typ: CallableType) -> SnapshotItem:
             snapshot_types(typ.arg_types),
             snapshot_type(typ.ret_type),
             tuple(encode_optional_str(name) for name in typ.arg_names),
-            tuple(typ.arg_kinds),
+            tuple(k.value for k in typ.arg_kinds),
             typ.is_type_obj(),
             typ.is_ellipsis_args,
             snapshot_types(typ.variables),
@@ -463,7 +469,7 @@ def visit_type_alias_type(self, typ: TypeAliasType) -> SnapshotItem:
         return ("TypeAliasType", typ.alias.fullname, snapshot_types(typ.args))
 
 
-def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> tuple[object, ...]:
+def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> SymbolSnapshot:
     """Create a snapshot of the signature of a function that has no explicit signature.
 
     If the arguments to a function without signature change, it must be
@@ -475,7 +481,7 @@ def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> tuple[obje
     if isinstance(func, FuncItem):
         return (tuple(func.arg_names), tuple(func.arg_kinds))
     else:
-        result = []
+        result: list[SymbolSnapshot] = []
         for item in func.items:
             if isinstance(item, Decorator):
                 if item.var.type:
diff --git a/mypy/server/update.py b/mypy/server/update.py
index 83cce22873a1..00b823c99dfd 100644
--- a/mypy/server/update.py
+++ b/mypy/server/update.py
@@ -151,7 +151,11 @@
     semantic_analysis_for_scc,
     semantic_analysis_for_targets,
 )
-from mypy.server.astdiff import SnapshotItem, compare_symbol_table_snapshots, snapshot_symbol_table
+from mypy.server.astdiff import (
+    SymbolSnapshot,
+    compare_symbol_table_snapshots,
+    snapshot_symbol_table,
+)
 from mypy.server.astmerge import merge_asts
 from mypy.server.aststrip import SavedAttributes, strip_target
 from mypy.server.deps import get_dependencies_of_target, merge_dependencies
@@ -417,7 +421,7 @@ def update_module(
 
         t0 = time.time()
         # Record symbol table snapshot of old version the changed module.
-        old_snapshots: dict[str, dict[str, SnapshotItem]] = {}
+        old_snapshots: dict[str, dict[str, SymbolSnapshot]] = {}
         if module in manager.modules:
             snapshot = snapshot_symbol_table(module, manager.modules[module].names)
             old_snapshots[module] = snapshot
@@ -751,7 +755,7 @@ def get_sources(
 
 def calculate_active_triggers(
     manager: BuildManager,
-    old_snapshots: dict[str, dict[str, SnapshotItem]],
+    old_snapshots: dict[str, dict[str, SymbolSnapshot]],
     new_modules: dict[str, MypyFile | None],
 ) -> set[str]:
     """Determine activated triggers by comparing old and new symbol tables.
diff --git a/test-data/unit/fine-grained.test b/test-data/unit/fine-grained.test
index d4b2d3469871..bb3e53c6c244 100644
--- a/test-data/unit/fine-grained.test
+++ b/test-data/unit/fine-grained.test
@@ -10315,3 +10315,31 @@ a.py:3: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#varia
 a.py:4: note: Revealed type is "A?"
 ==
 a.py:4: note: Revealed type is "Union[builtins.str, builtins.int]"
+
+[case testUnionOfSimilarCallablesCrash]
+import b
+
+[file b.py]
+from a import x
+
+[file m.py]
+from typing import Union, TypeVar
+
+T = TypeVar("T")
+S = TypeVar("S")
+def foo(x: T, y: S) -> Union[T, S]: ...
+def f(x: int) -> int: ...
+def g(*x: int) -> int: ...
+
+[file a.py]
+from m import f, g, foo
+x = foo(f, g)
+
+[file a.py.2]
+from m import f, g, foo
+x = foo(f, g)
+reveal_type(x)
+[builtins fixtures/tuple.pyi]
+[out]
+==
+a.py:3: note: Revealed type is "Union[def (x: builtins.int) -> builtins.int, def (*x: builtins.int) -> builtins.int]"