diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index bace90b0653c91..9c8c4a8cb27eca 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -292,6 +292,16 @@ impl<'a> PreorderVisitor<'a> for SymbolTableBuilder { ast::visitor::preorder::walk_stmt(self, stmt); self.pop_scope(); } + ast::Stmt::Import(ast::StmtImport { names, .. }) => { + for alias in names { + self.add_symbol(alias.name.id.split('.').next().unwrap()); + } + } + ast::Stmt::ImportFrom(ast::StmtImportFrom { names, .. }) => { + for alias in names { + self.add_symbol(&alias.name.id); + } + } _ => { ast::visitor::preorder::walk_stmt(self, stmt); } @@ -338,6 +348,24 @@ mod tests { assert_eq!(names(table.root_symbols()), vec!["int", "x"]); } + #[test] + fn import() { + let table = build("import foo"); + assert_eq!(names(table.root_symbols()), vec!["foo"]); + } + + #[test] + fn import_sub() { + let table = build("import foo.bar"); + assert_eq!(names(table.root_symbols()), vec!["foo"]); + } + + #[test] + fn import_from() { + let table = build("from bar import foo"); + assert_eq!(names(table.root_symbols()), vec!["foo"]); + } + #[test] fn class_scope() { let table = build(