From 966d8a607436f371079f2e7fa3774046fb9785a5 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 21 Jan 2025 12:38:14 -0300 Subject: [PATCH] fix: allow calling trait impl method from struct if multiple impls exist (#7124) --- compiler/noirc_frontend/src/ast/statement.rs | 7 ++ compiler/noirc_frontend/src/elaborator/mod.rs | 1 + .../src/elaborator/path_resolution.rs | 24 +--- .../noirc_frontend/src/elaborator/types.rs | 103 +++++++++++++++--- compiler/noirc_frontend/src/hir_def/traits.rs | 16 +-- compiler/noirc_frontend/src/tests/traits.rs | 31 +++++- 6 files changed, 133 insertions(+), 49 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index 57572e80d1e..3a70ff33a35 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -12,6 +12,7 @@ use super::{ }; use crate::ast::UnresolvedTypeData; use crate::elaborator::types::SELF_TYPE_NAME; +use crate::elaborator::Turbofish; use crate::lexer::token::SpannedToken; use crate::node_interner::{ InternedExpressionKind, InternedPattern, InternedStatementKind, NodeInterner, @@ -535,6 +536,12 @@ impl PathSegment { pub fn turbofish_span(&self) -> Span { Span::from(self.ident.span().end()..self.span.end()) } + + pub fn turbofish(&self) -> Option { + self.generics + .as_ref() + .map(|generics| Turbofish { span: self.turbofish_span(), generics: generics.clone() }) + } } impl From for PathSegment { diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 79f6be444ce..0e8850b6543 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -61,6 +61,7 @@ mod unquote; use fm::FileId; use iter_extended::vecmap; use noirc_errors::{Location, Span, Spanned}; +pub use path_resolution::Turbofish; use path_resolution::{PathResolution, PathResolutionItem}; use types::bind_ordered_generics; diff --git a/compiler/noirc_frontend/src/elaborator/path_resolution.rs b/compiler/noirc_frontend/src/elaborator/path_resolution.rs index 0d0b153b6b6..120ead52883 100644 --- a/compiler/noirc_frontend/src/elaborator/path_resolution.rs +++ b/compiler/noirc_frontend/src/elaborator/path_resolution.rs @@ -227,13 +227,7 @@ impl<'context> Elaborator<'context> { ModuleDefId::TypeId(id) => ( id.module_id(), true, - IntermediatePathResolutionItem::Struct( - id, - last_segment_generics.as_ref().map(|generics| Turbofish { - generics: generics.clone(), - span: last_segment.turbofish_span(), - }), - ), + IntermediatePathResolutionItem::Struct(id, last_segment.turbofish()), ), ModuleDefId::TypeAliasId(id) => { let type_alias = self.interner.get_type_alias(id); @@ -244,25 +238,13 @@ impl<'context> Elaborator<'context> { ( module_id, true, - IntermediatePathResolutionItem::TypeAlias( - id, - last_segment_generics.as_ref().map(|generics| Turbofish { - generics: generics.clone(), - span: last_segment.turbofish_span(), - }), - ), + IntermediatePathResolutionItem::TypeAlias(id, last_segment.turbofish()), ) } ModuleDefId::TraitId(id) => ( id.0, false, - IntermediatePathResolutionItem::Trait( - id, - last_segment_generics.as_ref().map(|generics| Turbofish { - generics: generics.clone(), - span: last_segment.turbofish_span(), - }), - ), + IntermediatePathResolutionItem::Trait(id, last_segment.turbofish()), ), ModuleDefId::FunctionId(_) => panic!("functions cannot be in the type namespace"), ModuleDefId::GlobalId(_) => panic!("globals cannot be in the type namespace"), diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index a1b63910a3e..36427997843 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -684,6 +684,60 @@ impl<'context> Elaborator<'context> { None } + /// This resolves a method in the form `Struct::method` where `method` is a trait method + fn resolve_struct_trait_method(&mut self, path: &Path) -> Option { + if path.segments.len() < 2 { + return None; + } + + let mut path = path.clone(); + let span = path.span(); + let last_segment = path.pop(); + let before_last_segment = path.last_segment(); + + let path_resolution = self.resolve_path(path).ok()?; + let PathResolutionItem::Struct(struct_id) = path_resolution.item else { + return None; + }; + + let struct_type = self.get_struct(struct_id); + let generics = struct_type.borrow().instantiate(self.interner); + let typ = Type::Struct(struct_type, generics); + let method_name = &last_segment.ident.0.contents; + + // If we can find a method on the struct, this is definitely not a trait method + if self.interner.lookup_direct_method(&typ, method_name, false).is_some() { + return None; + } + + let trait_methods = self.interner.lookup_trait_methods(&typ, method_name, false); + if trait_methods.is_empty() { + return None; + } + + let (hir_method_reference, error) = + self.get_trait_method_in_scope(&trait_methods, method_name, last_segment.span); + let hir_method_reference = hir_method_reference?; + let func_id = hir_method_reference.func_id(self.interner)?; + let HirMethodReference::TraitMethodId(trait_method_id, _, _) = hir_method_reference else { + return None; + }; + + let trait_id = trait_method_id.trait_id; + let trait_ = self.interner.get_trait(trait_id); + let mut constraint = trait_.as_constraint(span); + constraint.typ = typ; + + let method = TraitMethod { method_id: trait_method_id, constraint, assumed: false }; + let turbofish = before_last_segment.turbofish(); + let item = PathResolutionItem::TraitFunction(trait_id, turbofish, func_id); + let mut errors = path_resolution.errors; + if let Some(error) = error { + errors.push(error); + } + Some(TraitPathResolution { method, item: Some(item), errors }) + } + // Try to resolve the given trait method path. // // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not @@ -695,6 +749,7 @@ impl<'context> Elaborator<'context> { self.resolve_trait_static_method_by_self(path) .or_else(|| self.resolve_trait_static_method(path)) .or_else(|| self.resolve_trait_method_by_named_generic(path)) + .or_else(|| self.resolve_struct_trait_method(path)) } pub(super) fn unify( @@ -1456,6 +1511,19 @@ impl<'context> Elaborator<'context> { method_name: &str, span: Span, ) -> Option { + let (method, error) = self.get_trait_method_in_scope(trait_methods, method_name, span); + if let Some(error) = error { + self.push_err(error); + } + method + } + + fn get_trait_method_in_scope( + &mut self, + trait_methods: &[(FuncId, TraitId)], + method_name: &str, + span: Span, + ) -> (Option, Option) { let module_id = self.module_id(); let module_data = self.get_module(module_id); @@ -1489,28 +1557,24 @@ impl<'context> Elaborator<'context> { let trait_id = *traits.iter().next().unwrap(); let trait_ = self.interner.get_trait(trait_id); let trait_name = self.fully_qualified_trait_path(trait_); - - self.push_err(PathResolutionError::TraitMethodNotInScope { + let method = + self.trait_hir_method_reference(trait_id, trait_methods, method_name, span); + let error = PathResolutionError::TraitMethodNotInScope { ident: Ident::new(method_name.into(), span), trait_name, - }); - - return Some(self.trait_hir_method_reference( - trait_id, - trait_methods, - method_name, - span, - )); + }; + return (Some(method), Some(error)); } else { let traits = vecmap(traits, |trait_id| { let trait_ = self.interner.get_trait(trait_id); self.fully_qualified_trait_path(trait_) }); - self.push_err(PathResolutionError::UnresolvedWithPossibleTraitsToImport { + let method = None; + let error = PathResolutionError::UnresolvedWithPossibleTraitsToImport { ident: Ident::new(method_name.into(), span), traits, - }); - return None; + }; + return (method, Some(error)); } } @@ -1519,15 +1583,18 @@ impl<'context> Elaborator<'context> { let trait_ = self.interner.get_trait(trait_id); self.fully_qualified_trait_path(trait_) }); - self.push_err(PathResolutionError::MultipleTraitsInScope { + let method = None; + let error = PathResolutionError::MultipleTraitsInScope { ident: Ident::new(method_name.into(), span), traits, - }); - return None; + }; + return (method, Some(error)); } let trait_id = traits_in_scope[0].0; - Some(self.trait_hir_method_reference(trait_id, trait_methods, method_name, span)) + let method = self.trait_hir_method_reference(trait_id, trait_methods, method_name, span); + let error = None; + (Some(method), error) } fn trait_hir_method_reference( @@ -1545,7 +1612,7 @@ impl<'context> Elaborator<'context> { // Return a TraitMethodId with unbound generics. These will later be bound by the type-checker. let trait_ = self.interner.get_trait(trait_id); - let generics = trait_.as_constraint(span).trait_bound.trait_generics; + let generics = trait_.get_trait_generics(span); let trait_method_id = trait_.find_method(method_name).unwrap(); HirMethodReference::TraitMethodId(trait_method_id, generics, false) } diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index ff0cac027b1..a80c25492a3 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -186,22 +186,22 @@ impl Trait { (ordered, named) } - /// Returns a TraitConstraint for this trait using Self as the object - /// type and the uninstantiated generics for any trait generics. - pub fn as_constraint(&self, span: Span) -> TraitConstraint { + pub fn get_trait_generics(&self, span: Span) -> TraitGenerics { let ordered = vecmap(&self.generics, |generic| generic.clone().as_named_generic()); let named = vecmap(&self.associated_types, |generic| { let name = Ident::new(generic.name.to_string(), span); NamedType { name, typ: generic.clone().as_named_generic() } }); + TraitGenerics { ordered, named } + } + /// Returns a TraitConstraint for this trait using Self as the object + /// type and the uninstantiated generics for any trait generics. + pub fn as_constraint(&self, span: Span) -> TraitConstraint { + let trait_generics = self.get_trait_generics(span); TraitConstraint { typ: Type::TypeVariable(self.self_type_typevar.clone()), - trait_bound: ResolvedTraitBound { - trait_generics: TraitGenerics { ordered, named }, - trait_id: self.id, - span, - }, + trait_bound: ResolvedTraitBound { trait_generics, trait_id: self.id, span }, } } } diff --git a/compiler/noirc_frontend/src/tests/traits.rs b/compiler/noirc_frontend/src/tests/traits.rs index 11cdb95391f..eba9889a8e1 100644 --- a/compiler/noirc_frontend/src/tests/traits.rs +++ b/compiler/noirc_frontend/src/tests/traits.rs @@ -1341,9 +1341,7 @@ fn regression_6530() { assert_eq!(errors.len(), 0); } -// See https://github.com/noir-lang/noir/issues/7090 #[test] -#[should_panic] fn calls_trait_method_using_struct_name_when_multiple_impls_exist() { let src = r#" trait From2 { @@ -1367,3 +1365,32 @@ fn calls_trait_method_using_struct_name_when_multiple_impls_exist() { "#; assert_no_errors(src); } + +#[test] +fn calls_trait_method_using_struct_name_when_multiple_impls_exist_and_errors_turbofish() { + let src = r#" + trait From2 { + fn from2(input: T) -> Self; + } + struct U60Repr {} + impl From2<[Field; 3]> for U60Repr { + fn from2(_: [Field; 3]) -> Self { + U60Repr {} + } + } + impl From2 for U60Repr { + fn from2(_: Field) -> Self { + U60Repr {} + } + } + fn main() { + let _ = U60Repr::::from2([1, 2, 3]); + } + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + assert!(matches!( + errors[0].0, + CompilationError::TypeError(TypeCheckError::TypeMismatch { .. }) + )); +}