From a7b8cc08f09dfad0bd30ded79852ae45cc24a6c1 Mon Sep 17 00:00:00 2001
From: Alex Waygood <Alex.Waygood@Gmail.com>
Date: Tue, 10 Sep 2024 18:41:45 -0400
Subject: [PATCH] [red-knot] Fix `.to_instance()` for union types (#13319)

---
 crates/red_knot_python_semantic/src/types.rs  | 20 ++++++++++++++--
 .../src/types/infer.rs                        | 24 ++++++++++---------
 2 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs
index e61a0f4843fee0..093dd205ecb0d2 100644
--- a/crates/red_knot_python_semantic/src/types.rs
+++ b/crates/red_knot_python_semantic/src/types.rs
@@ -445,12 +445,28 @@ impl<'db> Type<'db> {
     }
 
     #[must_use]
-    pub fn to_instance(&self) -> Type<'db> {
+    pub fn to_instance(&self, db: &'db dyn Db) -> Type<'db> {
         match self {
             Type::Any => Type::Any,
             Type::Unknown => Type::Unknown,
+            Type::Unbound => Type::Unknown,
+            Type::Never => Type::Never,
             Type::Class(class) => Type::Instance(*class),
-            _ => Type::Unknown, // TODO type errors
+            Type::Union(union) => union.map(db, |element| element.to_instance(db)),
+            // TODO: we can probably do better here: --Alex
+            Type::Intersection(_) => Type::Unknown,
+            // TODO: calling `.to_instance()` on any of these should result in a diagnostic,
+            // since they already indicate that the object is an instance of some kind:
+            Type::BooleanLiteral(_)
+            | Type::BytesLiteral(_)
+            | Type::Function(_)
+            | Type::Instance(_)
+            | Type::Module(_)
+            | Type::IntLiteral(_)
+            | Type::StringLiteral(_)
+            | Type::Tuple(_)
+            | Type::LiteralString
+            | Type::None => Type::Unknown,
         }
     }
 
diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs
index 3dd1a4d1ff757c..30afbafc3c06bd 100644
--- a/crates/red_knot_python_semantic/src/types/infer.rs
+++ b/crates/red_knot_python_semantic/src/types/infer.rs
@@ -1457,9 +1457,11 @@ impl<'db> TypeInferenceBuilder<'db> {
             ast::Number::Int(n) => n
                 .as_i64()
                 .map(Type::IntLiteral)
-                .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()),
-            ast::Number::Float(_) => builtins_symbol_ty(self.db, "float").to_instance(),
-            ast::Number::Complex { .. } => builtins_symbol_ty(self.db, "complex").to_instance(),
+                .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance(self.db)),
+            ast::Number::Float(_) => builtins_symbol_ty(self.db, "float").to_instance(self.db),
+            ast::Number::Complex { .. } => {
+                builtins_symbol_ty(self.db, "complex").to_instance(self.db)
+            }
         }
     }
 
@@ -1573,7 +1575,7 @@ impl<'db> TypeInferenceBuilder<'db> {
         }
 
         // TODO generic
-        builtins_symbol_ty(self.db, "list").to_instance()
+        builtins_symbol_ty(self.db, "list").to_instance(self.db)
     }
 
     fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> {
@@ -1584,7 +1586,7 @@ impl<'db> TypeInferenceBuilder<'db> {
         }
 
         // TODO generic
-        builtins_symbol_ty(self.db, "set").to_instance()
+        builtins_symbol_ty(self.db, "set").to_instance(self.db)
     }
 
     fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> {
@@ -1596,7 +1598,7 @@ impl<'db> TypeInferenceBuilder<'db> {
         }
 
         // TODO generic
-        builtins_symbol_ty(self.db, "dict").to_instance()
+        builtins_symbol_ty(self.db, "dict").to_instance(self.db)
     }
 
     /// Infer the type of the `iter` expression of the first comprehension.
@@ -2067,22 +2069,22 @@ impl<'db> TypeInferenceBuilder<'db> {
             (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Add) => n
                 .checked_add(m)
                 .map(Type::IntLiteral)
-                .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()),
+                .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance(self.db)),
 
             (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Sub) => n
                 .checked_sub(m)
                 .map(Type::IntLiteral)
-                .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()),
+                .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance(self.db)),
 
             (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mult) => n
                 .checked_mul(m)
                 .map(Type::IntLiteral)
-                .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()),
+                .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance(self.db)),
 
             (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Div) => n
                 .checked_div(m)
                 .map(Type::IntLiteral)
-                .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()),
+                .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance(self.db)),
 
             (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mod) => n
                 .checked_rem(m)
@@ -2311,7 +2313,7 @@ impl<'db> TypeInferenceBuilder<'db> {
                     name.ctx
                 );
 
-                self.infer_name_expression(name).to_instance()
+                self.infer_name_expression(name).to_instance(self.db)
             }
 
             ast::Expr::NoneLiteral(_literal) => Type::None,