diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index da8a88030b09e..c9b698a13ba04 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,8 +1,9 @@ use ruff_db::files::File; -use ruff_python_ast::name::Name; +use ruff_python_ast as ast; use crate::builtins::builtins_scope; -use crate::semantic_index::definition::Definition; +use crate::semantic_index::ast_ids::HasScopedAstId; +use crate::semantic_index::definition::{Definition, DefinitionKind}; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId}; use crate::semantic_index::{ global_scope, semantic_index, symbol_table, use_def_map, DefinitionWithConstraints, @@ -14,7 +15,8 @@ use crate::{Db, FxOrderSet}; pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder}; pub(crate) use self::diagnostic::TypeCheckDiagnostics; pub(crate) use self::infer::{ - infer_definition_types, infer_expression_types, infer_scope_types, TypeInference, + infer_deferred_types, infer_definition_types, infer_expression_types, infer_scope_types, + TypeInference, }; mod builder; @@ -88,6 +90,24 @@ pub(crate) fn definition_ty<'db>(db: &'db dyn Db, definition: Definition<'db>) - inference.definition_ty(definition) } +/// Infer the type of a (possibly deferred) sub-expression of a [`Definition`]. +/// +/// ## Panics +/// If the given expression is not a sub-expression of the given [`Definition`]. +pub(crate) fn definition_expression_ty<'db>( + db: &'db dyn Db, + definition: Definition<'db>, + expression: &ast::Expr, +) -> Type<'db> { + let expr_id = expression.scoped_ast_id(db, definition.scope(db)); + let inference = infer_definition_types(db, definition); + if let Some(ty) = inference.try_expression_ty(expr_id) { + ty + } else { + infer_deferred_types(db, definition).expression_ty(expr_id) + } +} + /// Infer the combined type of an array of [`Definition`]s, plus one optional "unbound type". /// /// Will return a union if there is more than one definition, or at least one plus an unbound @@ -243,7 +263,7 @@ impl<'db> Type<'db> { /// us to explicitly consider whether to handle an error or propagate /// it up the call stack. #[must_use] - pub fn member(&self, db: &'db dyn Db, name: &Name) -> Type<'db> { + pub fn member(&self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> { match self { Type::Any => Type::Any, Type::Never => { @@ -314,7 +334,7 @@ impl<'db> Type<'db> { #[salsa::interned] pub struct FunctionType<'db> { /// name of the function at definition - pub name: Name, + pub name: ast::name::Name, /// types of all decorators on this function decorators: Vec>, @@ -329,19 +349,33 @@ impl<'db> FunctionType<'db> { #[salsa::interned] pub struct ClassType<'db> { /// Name of the class at definition - pub name: Name, + pub name: ast::name::Name, - /// Types of all class bases - bases: Vec>, + definition: Definition<'db>, body_scope: ScopeId<'db>, } impl<'db> ClassType<'db> { + /// Return an iterator over the types of this class's bases. + /// + /// # Panics: + /// If `definition` is not a `DefinitionKind::Class`. + pub fn bases(&self, db: &'db dyn Db) -> impl Iterator> { + let definition = self.definition(db); + let DefinitionKind::Class(class_stmt_node) = definition.node(db) else { + panic!("Class type definition must have DefinitionKind::Class"); + }; + class_stmt_node + .bases() + .iter() + .map(move |base_expr| definition_expression_ty(db, definition, base_expr)) + } + /// Returns the class member of this class named `name`. /// /// The member resolves to a member of the class itself or any of its bases. - pub fn class_member(self, db: &'db dyn Db, name: &Name) -> Type<'db> { + pub fn class_member(self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> { let member = self.own_class_member(db, name); if !member.is_unbound() { return member; @@ -351,12 +385,12 @@ impl<'db> ClassType<'db> { } /// Returns the inferred type of the class member named `name`. - pub fn own_class_member(self, db: &'db dyn Db, name: &Name) -> Type<'db> { + pub fn own_class_member(self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> { let scope = self.body_scope(db); symbol_ty_by_name(db, scope, name) } - pub fn inherited_class_member(self, db: &'db dyn Db, name: &Name) -> Type<'db> { + pub fn inherited_class_member(self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> { for base in self.bases(db) { let member = base.member(db, name); if !member.is_unbound() { diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 6ab0ec82c9408..c0e2df24a5334 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1,4 +1,4 @@ -//! We have three Salsa queries for inferring types at three different granularities: scope-level, +//! We have Salsa queries for inferring types at three different granularities: scope-level, //! definition-level, and expression-level. //! //! Scope-level inference is for when we are actually checking a file, and need to check types for @@ -11,15 +11,21 @@ //! allows us to handle import cycles without getting into a cycle of scope-level inference //! queries. //! -//! The expression-level inference query is needed in only a few cases. Since an assignment -//! statement can have multiple targets (via `x = y = z` or unpacking `(x, y) = z`, it can be -//! associated with multiple definitions. In order to avoid inferring the type of the right-hand -//! side once per definition, we infer it as a standalone query, so its result will be cached by -//! Salsa. We also need the expression-level query for inferring types in type guard expressions -//! (e.g. the test clause of an `if` statement.) +//! The expression-level inference query is needed in only a few cases. Since some assignments can +//! have multiple targets (via `x = y = z` or unpacking `(x, y) = z`, they can be associated with +//! multiple definitions (one per assigned symbol). In order to avoid inferring the type of the +//! right-hand side once per definition, we infer it as a standalone query, so its result will be +//! cached by Salsa. We also need the expression-level query for inferring types in type guard +//! expressions (e.g. the test clause of an `if` statement.) //! //! Inferring types at any of the three region granularities returns a [`TypeInference`], which //! holds types for every [`Definition`] and expression within the inferred region. +//! +//! Some type expressions can require deferred evaluation. This includes all type expressions in +//! stub files, or annotation expressions in modules with `from __future__ import annotations`, or +//! stringified annotations. We have a fourth Salsa query for inferring the deferred types +//! associated with a particular definition. Scope-level inference infers deferred types for all +//! definitions once the rest of the types in the scope have been inferred. use std::num::NonZeroU32; use rustc_hash::FxHashMap; @@ -28,8 +34,7 @@ use salsa::plumbing::AsId; use ruff_db::files::File; use ruff_db::parsed::parsed_module; -use ruff_python_ast::{self as ast, UnaryOp}; -use ruff_python_ast::{AnyNodeRef, ExprContext}; +use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp}; use ruff_text_size::Ranged; use crate::builtins::builtins_scope; @@ -43,8 +48,8 @@ use crate::semantic_index::symbol::{FileScopeId, NodeWithScopeKind, NodeWithScop use crate::semantic_index::SemanticIndex; use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics}; use crate::types::{ - builtins_symbol_ty_by_name, definitions_ty, global_symbol_ty_by_name, BytesLiteralType, - ClassType, FunctionType, Name, StringLiteralType, Type, UnionBuilder, + builtins_symbol_ty_by_name, definitions_ty, global_symbol_ty_by_name, symbol_ty, + BytesLiteralType, ClassType, FunctionType, StringLiteralType, Type, UnionBuilder, }; use crate::Db; @@ -97,6 +102,28 @@ pub(crate) fn infer_definition_types<'db>( TypeInferenceBuilder::new(db, InferenceRegion::Definition(definition), index).finish() } +/// Infer types for all deferred type expressions in a [`Definition`]. +/// +/// Deferred expressions are type expressions (annotations, base classes, aliases...) in a stub +/// file, or in a file with `from __future__ import annotations`, or stringified annotations. +#[salsa::tracked(return_ref)] +pub(crate) fn infer_deferred_types<'db>( + db: &'db dyn Db, + definition: Definition<'db>, +) -> TypeInference<'db> { + let file = definition.file(db); + let _span = tracing::trace_span!( + "infer_deferred_types", + definition = ?definition.as_id(), + file = %file.path(db) + ) + .entered(); + + let index = semantic_index(db, file); + + TypeInferenceBuilder::new(db, InferenceRegion::Deferred(definition), index).finish() +} + /// Infer all types for an [`Expression`] (including sub-expressions). /// Use rarely; only for cases where we'd otherwise risk double-inferring an expression: RHS of an /// assignment, which might be unpacking/multi-target and thus part of multiple definitions, or a @@ -119,8 +146,13 @@ pub(crate) fn infer_expression_types<'db>( /// A region within which we can infer types. pub(crate) enum InferenceRegion<'db> { + /// infer types for a standalone [`Expression`] Expression(Expression<'db>), + /// infer types for a [`Definition`] Definition(Definition<'db>), + /// infer deferred types for a [`Definition`] + Deferred(Definition<'db>), + /// infer types for an entire [`ScopeId`] Scope(ScopeId<'db>), } @@ -135,6 +167,9 @@ pub(crate) struct TypeInference<'db> { /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, + + /// Are there deferred type expressions in this region? + has_deferred: bool, } impl<'db> TypeInference<'db> { @@ -142,6 +177,10 @@ impl<'db> TypeInference<'db> { self.expressions[&expression] } + pub(crate) fn try_expression_ty(&self, expression: ScopedExpressionId) -> Option> { + self.expressions.get(&expression).copied() + } + pub(crate) fn definition_ty(&self, definition: Definition<'db>) -> Type<'db> { self.definitions[&definition] } @@ -231,7 +270,9 @@ impl<'db> TypeInferenceBuilder<'db> { ) -> Self { let (file, scope) = match region { InferenceRegion::Expression(expression) => (expression.file(db), expression.scope(db)), - InferenceRegion::Definition(definition) => (definition.file(db), definition.scope(db)), + InferenceRegion::Definition(definition) | InferenceRegion::Deferred(definition) => { + (definition.file(db), definition.scope(db)) + } InferenceRegion::Scope(scope) => (scope.file(db), scope), }; @@ -251,6 +292,17 @@ impl<'db> TypeInferenceBuilder<'db> { self.types.definitions.extend(inference.definitions.iter()); self.types.expressions.extend(inference.expressions.iter()); self.types.diagnostics.extend(&inference.diagnostics); + self.types.has_deferred |= inference.has_deferred; + } + + /// Are we currently inferring types in a stub file? + fn is_stub(&self) -> bool { + self.file.is_stub(self.db.upcast()) + } + + /// Are we currently inferred deferred types? + fn is_deferred(&self) -> bool { + matches!(self.region, InferenceRegion::Deferred(_)) } /// Infers types in the given [`InferenceRegion`]. @@ -258,6 +310,7 @@ impl<'db> TypeInferenceBuilder<'db> { match self.region { InferenceRegion::Scope(scope) => self.infer_region_scope(scope), InferenceRegion::Definition(definition) => self.infer_region_definition(definition), + InferenceRegion::Deferred(definition) => self.infer_region_deferred(definition), InferenceRegion::Expression(expression) => self.infer_region_expression(expression), } } @@ -291,6 +344,20 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_generator_expression_scope(generator.node()); } } + + if self.types.has_deferred { + let mut deferred_expression_types: FxHashMap> = + FxHashMap::default(); + for definition in self.types.definitions.keys() { + if infer_definition_types(self.db, *definition).has_deferred { + let deferred = infer_deferred_types(self.db, *definition); + deferred_expression_types.extend(deferred.expressions.iter()); + } + } + self.types + .expressions + .extend(deferred_expression_types.iter()); + } } fn infer_region_definition(&mut self, definition: Definition<'db>) { @@ -351,6 +418,19 @@ impl<'db> TypeInferenceBuilder<'db> { } } + fn infer_region_deferred(&mut self, definition: Definition<'db>) { + match definition.node(self.db) { + DefinitionKind::Function(_function) => { + // TODO self.infer_function_deferred(function.node()); + } + DefinitionKind::Class(class) => self.infer_class_deferred(class.node()), + DefinitionKind::AnnotatedAssignment(_annotated_assignment) => { + // TODO self.infer_annotated_assignment_deferred(annotated_assignment.node()); + } + _ => {} + } + } + fn infer_region_expression(&mut self, expression: Expression<'db>) { self.infer_expression(expression.node_ref(self.db)); } @@ -555,7 +635,7 @@ impl<'db> TypeInferenceBuilder<'db> { name, type_params: _, decorator_list, - arguments, + arguments: _, body: _, } = class; @@ -563,21 +643,40 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_decorator(decorator); } - // TODO if there are type params, the bases should be inferred inside that scope (only) - - let bases = arguments - .as_deref() - .map(|arguments| self.infer_arguments(arguments)) - .unwrap_or(Vec::new()); - let body_scope = self .index .node_scope(NodeWithScopeRef::Class(class)) .to_scope_id(self.db, self.file); - let class_ty = Type::Class(ClassType::new(self.db, name.id.clone(), bases, body_scope)); + let class_ty = Type::Class(ClassType::new( + self.db, + name.id.clone(), + definition, + body_scope, + )); self.types.definitions.insert(definition, class_ty); + + for keyword in class.keywords() { + self.infer_expression(&keyword.value); + } + + // inference of bases deferred in stubs + // TODO also defer stringified generic type parameters + if !self.is_stub() { + for base in class.bases() { + self.infer_expression(base); + } + } + } + + fn infer_class_deferred(&mut self, class: &ast::StmtClassDef) { + if self.is_stub() { + self.types.has_deferred = true; + for base in class.bases() { + self.infer_expression(base); + } + } } fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) { @@ -1123,7 +1222,7 @@ impl<'db> TypeInferenceBuilder<'db> { asname: _, } = alias; - let member_ty = module_ty.member(self.db, &Name::new(&name.id)); + let member_ty = module_ty.member(self.db, &ast::name::Name::new(&name.id)); // TODO: What if it's a union where one of the elements is `Unbound`? if member_ty.is_unbound() { @@ -1668,10 +1767,19 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> { let ast::ExprName { range: _, id, ctx } = name; + let file_scope_id = self.scope.file_scope_id(self.db); + + // if we're inferring types of deferred expressions, always treat them as public symbols + if self.is_deferred() { + let symbols = self.index.symbol_table(file_scope_id); + let symbol = symbols + .symbol_id_by_name(id) + .expect("Expected the symbol table to create a symbol for every Name node"); + return symbol_ty(self.db, self.scope, symbol); + } match ctx { ExprContext::Load => { - let file_scope_id = self.scope.file_scope_id(self.db); let use_def = self.index.use_def_map(file_scope_id); let use_id = name.scoped_use_id(self.db, self.scope); let may_be_unbound = use_def.use_may_be_unbound(use_id); @@ -1721,7 +1829,7 @@ impl<'db> TypeInferenceBuilder<'db> { } = attribute; let value_ty = self.infer_expression(value); - let member_ty = value_ty.member(self.db, &Name::new(&attr.id)); + let member_ty = value_ty.member(self.db, &ast::name::Name::new(&attr.id)); match ctx { ExprContext::Load => member_ty, @@ -2259,7 +2367,6 @@ mod tests { let base_names: Vec<_> = class .bases(&db) - .iter() .map(|base_ty| format!("{}", base_ty.display(&db))) .collect(); @@ -2814,14 +2921,14 @@ mod tests { let Type::Class(c_class) = c_ty else { panic!("C is not a Class") }; - let c_bases = c_class.bases(&db); - let b_ty = c_bases.first().unwrap(); + let mut c_bases = c_class.bases(&db); + let b_ty = c_bases.next().unwrap(); let Type::Class(b_class) = b_ty else { panic!("B is not a Class") }; assert_eq!(b_class.name(&db), "B"); - let b_bases = b_class.bases(&db); - let a_ty = b_bases.first().unwrap(); + let mut b_bases = b_class.bases(&db); + let a_ty = b_bases.next().unwrap(); let Type::Class(a_class) = a_ty else { panic!("A is not a Class") }; @@ -3053,6 +3160,24 @@ mod tests { Ok(()) } + /// A class's bases can be self-referential; this looks silly but a slightly more complex + /// version of it actually occurs in typeshed: `class str(Sequence[str]): ...` + #[test] + fn cyclical_class_pyi_definition() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_file("/src/a.pyi", "class C(C): ...")?; + assert_public_ty(&db, "/src/a.pyi", "C", "Literal[C]"); + Ok(()) + } + + #[test] + fn str_builtin() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_file("/src/a.py", "x = str")?; + assert_public_ty(&db, "/src/a.py", "x", "Literal[str]"); + Ok(()) + } + #[test] fn narrow_not_none() -> anyhow::Result<()> { let mut db = setup_db(); diff --git a/crates/ruff_db/src/files.rs b/crates/ruff_db/src/files.rs index c21a26474b325..ec1f6939207c2 100644 --- a/crates/ruff_db/src/files.rs +++ b/crates/ruff_db/src/files.rs @@ -8,6 +8,7 @@ use salsa::{Durability, Setter}; pub use file_root::{FileRoot, FileRootKind}; pub use path::FilePath; use ruff_notebook::{Notebook, NotebookError}; +use ruff_python_ast::PySourceType; use crate::file_revision::FileRevision; use crate::files::file_root::FileRoots; @@ -424,6 +425,13 @@ impl File { pub fn exists(self, db: &dyn Db) -> bool { self.status(db) == FileStatus::Exists } + + /// Returns `true` if the file should be analyzed as a type stub. + pub fn is_stub(self, db: &dyn Db) -> bool { + self.path(db) + .extension() + .is_some_and(|extension| PySourceType::from_extension(extension).is_stub()) + } } /// A virtual file that doesn't exist on the file system.