From f03e5812439bdf9d1aedc69debdc50ba5dba2049 Mon Sep 17 00:00:00 2001 From: jfecher Date: Thu, 7 Dec 2023 09:49:52 -0600 Subject: [PATCH] fix: `try_unify` no longer binds types on failure (#3697) # Description ## Problem\* Resolves #3089. This issue also affected impl search which uses `try_unify` to compare against every impl candidate. In some programs this could cause invalid type bindings, eventually leading to various mismatched/unexpected types panics during SSA. ## Summary\* I've refactored `try_unify` to instead return a set of type bindings which must be applied afterward if the function was successful. Each caller of `try_unify` must do this manually now. Additionally, `lookup_trait_implementation` must now juggle several sets of type bindings and only commit to those that are not instantiation bindings (otherwise the generic trait would be permanently bound to another type), and even then only commit to them on success. ## Additional Context ## Documentation\* Check one: - [ ] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[Exceptional Case]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: kevaundray --- .../noirc_evaluator/src/ssa/ssa_gen/mod.rs | 1 - .../src/hir/def_collector/dc_crate.rs | 20 +- .../src/hir/resolution/functions.rs | 4 +- .../src/hir/resolution/resolver.rs | 9 +- .../src/hir/resolution/traits.rs | 23 +- .../noirc_frontend/src/hir/type_check/expr.rs | 27 +- .../noirc_frontend/src/hir/type_check/stmt.rs | 5 +- compiler/noirc_frontend/src/hir_def/traits.rs | 41 +- compiler/noirc_frontend/src/hir_def/types.rs | 396 +++++++++++------- .../src/monomorphization/mod.rs | 33 +- compiler/noirc_frontend/src/node_interner.rs | 69 ++- .../method_call_regression/Nargo.toml | 7 + .../method_call_regression/src/main.nr | 25 ++ 13 files changed, 414 insertions(+), 246 deletions(-) create mode 100644 test_programs/compile_success_empty/method_call_regression/Nargo.toml create mode 100644 test_programs/compile_success_empty/method_call_regression/src/main.nr diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs index 41327c988d2..d7e6b8b0a3d 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -591,7 +591,6 @@ impl<'a> FunctionContext<'a> { } self.codegen_intrinsic_call_checks(function, &arguments, call.location); - Ok(self.insert_call(function, arguments, &call.return_type, call.location)) } diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 86122530cde..0806a8eb757 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -415,7 +415,7 @@ pub(crate) fn check_methods_signatures( let self_type = resolver.get_self_type().expect("trait impl must have a Self type"); // Temporarily bind the trait's Self type to self_type so we can type check - let _ = the_trait.self_type_typevar.borrow_mut().bind_to(self_type.clone(), the_trait.span); + the_trait.self_type_typevar.bind(self_type.clone()); for (file_id, func_id) in impl_methods { let impl_method = resolver.interner.function_meta(func_id); @@ -433,7 +433,10 @@ pub(crate) fn check_methods_signatures( let impl_method_generic_count = impl_method.typ.generic_count() - trait_impl_generic_count; - let trait_method_generic_count = trait_method.generics.len(); + + // We subtract 1 here to account for the implicit generic `Self` type that is on all + // traits (and thus trait methods) but is not required (or allowed) for users to specify. + let trait_method_generic_count = trait_method.generics().len() - 1; if impl_method_generic_count != trait_method_generic_count { let error = DefCollectorErrorKind::MismatchTraitImplementationNumGenerics { @@ -447,9 +450,9 @@ pub(crate) fn check_methods_signatures( } if let Type::Function(impl_params, _, _) = impl_function_type.0 { - if trait_method.arguments.len() == impl_params.len() { + if trait_method.arguments().len() == impl_params.len() { // Check the parameters of the impl method against the parameters of the trait method - let args = trait_method.arguments.iter(); + let args = trait_method.arguments().iter(); let args_and_params = args.zip(&impl_params).zip(&impl_method.parameters.0); for (parameter_index, ((expected, actual), (hir_pattern, _, _))) in @@ -468,7 +471,7 @@ pub(crate) fn check_methods_signatures( } else { let error = DefCollectorErrorKind::MismatchTraitImplementationNumParameters { actual_num_parameters: impl_method.parameters.0.len(), - expected_num_parameters: trait_method.arguments.len(), + expected_num_parameters: trait_method.arguments().len(), trait_name: the_trait.name.to_string(), method_name: func_name.to_string(), span: impl_method.location.span, @@ -481,11 +484,12 @@ pub(crate) fn check_methods_signatures( let resolved_return_type = resolver.resolve_type(impl_method.return_type.get_type().into_owned()); - trait_method.return_type.unify(&resolved_return_type, &mut typecheck_errors, || { + // TODO: This is not right since it may bind generic return types + trait_method.return_type().unify(&resolved_return_type, &mut typecheck_errors, || { let ret_type_span = impl_method.return_type.get_type().span; let expr_span = ret_type_span.expect("return type must always have a span"); - let expected_typ = trait_method.return_type.to_string(); + let expected_typ = trait_method.return_type().to_string(); let expr_typ = impl_method.return_type().to_string(); TypeCheckError::TypeMismatch { expr_typ, expected_typ, expr_span } }); @@ -494,5 +498,5 @@ pub(crate) fn check_methods_signatures( } } - the_trait.self_type_typevar.borrow_mut().unbind(the_trait.self_type_typevar_id); + the_trait.self_type_typevar.unbind(the_trait.self_type_typevar_id); } diff --git a/compiler/noirc_frontend/src/hir/resolution/functions.rs b/compiler/noirc_frontend/src/hir/resolution/functions.rs index 387f94e129c..e63de9b9173 100644 --- a/compiler/noirc_frontend/src/hir/resolution/functions.rs +++ b/compiler/noirc_frontend/src/hir/resolution/functions.rs @@ -11,7 +11,7 @@ use crate::{ def_map::{CrateDefMap, ModuleId}, }, node_interner::{FuncId, NodeInterner, TraitImplId}, - Shared, Type, TypeBinding, + Type, TypeVariable, }; use super::{path_resolver::StandardPathResolver, resolver::Resolver}; @@ -24,7 +24,7 @@ pub(crate) fn resolve_function_set( mut unresolved_functions: UnresolvedFunctions, self_type: Option, trait_impl_id: Option, - impl_generics: Vec<(Rc, Shared, Span)>, + impl_generics: Vec<(Rc, TypeVariable, Span)>, errors: &mut Vec<(CompilationError, FileId)>, ) -> Vec<(FileId, FuncId)> { let file_id = unresolved_functions.file_id; diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 3a3e082bd5e..544e9618856 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -572,7 +572,7 @@ impl<'a> Resolver<'a> { match length { None => { let id = self.interner.next_type_variable_id(); - let typevar = Shared::new(TypeBinding::Unbound(id)); + let typevar = TypeVariable::unbound(id); new_variables.push((id, typevar.clone())); // 'Named'Generic is a bit of a misnomer here, we want a type variable that @@ -682,7 +682,7 @@ impl<'a> Resolver<'a> { vecmap(generics, |generic| { // Map the generic to a fresh type variable let id = self.interner.next_type_variable_id(); - let typevar = Shared::new(TypeBinding::Unbound(id)); + let typevar = TypeVariable::unbound(id); let span = generic.0.span(); // Check for name collisions of this generic @@ -925,10 +925,7 @@ impl<'a> Resolver<'a> { found.into_iter().collect() } - fn find_numeric_generics_in_type( - typ: &Type, - found: &mut BTreeMap>, - ) { + fn find_numeric_generics_in_type(typ: &Type, found: &mut BTreeMap) { match typ { Type::FieldElement | Type::Integer(_, _) diff --git a/compiler/noirc_frontend/src/hir/resolution/traits.rs b/compiler/noirc_frontend/src/hir/resolution/traits.rs index 702e96362a6..7a6cbccb081 100644 --- a/compiler/noirc_frontend/src/hir/resolution/traits.rs +++ b/compiler/noirc_frontend/src/hir/resolution/traits.rs @@ -18,7 +18,7 @@ use crate::{ }, hir_def::traits::{Trait, TraitConstant, TraitFunction, TraitImpl, TraitType}, node_interner::{FuncId, NodeInterner, TraitId}, - Path, Shared, TraitItem, Type, TypeVariableKind, + Path, Shared, TraitItem, Type, TypeBinding, TypeVariableKind, }; use super::{ @@ -111,8 +111,17 @@ fn resolve_trait_methods( resolver.set_self_type(Some(self_type)); let arguments = vecmap(parameters, |param| resolver.resolve_type(param.1.clone())); - let resolved_return_type = resolver.resolve_type(return_type.get_type().into_owned()); - let generics = resolver.get_generics().to_vec(); + let return_type = resolver.resolve_type(return_type.get_type().into_owned()); + + let mut generics = vecmap(resolver.get_generics(), |(_, type_var, _)| match &*type_var + .borrow() + { + TypeBinding::Unbound(id) => (*id, type_var.clone()), + TypeBinding::Bound(binding) => unreachable!("Trait generic was bound to {binding}"), + }); + + // Ensure the trait is generic over the Self type as well + generics.push((the_trait.self_type_typevar_id, the_trait.self_type_typevar)); let name = name.clone(); let span: Span = name.span(); @@ -128,11 +137,13 @@ fn resolve_trait_methods( None }; + let no_environment = Box::new(Type::Unit); + let function_type = Type::Function(arguments, Box::new(return_type), no_environment); + let typ = Type::Forall(generics, Box::new(function_type)); + let f = TraitFunction { name, - generics, - arguments, - return_type: resolved_return_type, + typ, span, default_impl, default_impl_file_id: unresolved_trait.file_id, diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index c720b87043c..5263434a358 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -11,7 +11,7 @@ use crate::{ types::Type, }, node_interner::{DefinitionKind, ExprId, FuncId, TraitId, TraitMethodId}, - BinaryOpKind, Signedness, TypeBinding, TypeVariableKind, UnaryOp, + BinaryOpKind, Signedness, TypeBinding, TypeBindings, TypeVariableKind, UnaryOp, }; use super::{errors::TypeCheckError, TypeChecker}; @@ -289,14 +289,7 @@ impl<'interner> TypeChecker<'interner> { } HirExpression::TraitMethodReference(method) => { let the_trait = self.interner.get_trait(method.trait_id); - let method = &the_trait.methods[method.method_index]; - - let typ = Type::Function( - method.arguments.clone(), - Box::new(method.return_type.clone()), - Box::new(Type::Unit), - ); - + let typ = &the_trait.methods[method.method_index].typ; let (typ, bindings) = typ.instantiate(self.interner); self.interner.store_instantiation_bindings(*expr_id, bindings); typ @@ -546,7 +539,7 @@ impl<'interner> TypeChecker<'interner> { HirMethodReference::TraitMethodId(method) => { let the_trait = self.interner.get_trait(method.trait_id); let method = &the_trait.methods[method.method_index]; - (method.get_type(), method.arguments.len()) + (method.typ.clone(), method.arguments().len()) } }; @@ -778,7 +771,11 @@ impl<'interner> TypeChecker<'interner> { })); } - if other.try_bind_to_polymorphic_int(int).is_ok() || other == &Type::Error { + let mut bindings = TypeBindings::new(); + if other.try_bind_to_polymorphic_int(int, &mut bindings).is_ok() + || other == &Type::Error + { + Type::apply_type_bindings(bindings); Ok(Bool) } else { Err(TypeCheckError::TypeMismatchWithSource { @@ -1009,7 +1006,7 @@ impl<'interner> TypeChecker<'interner> { let env_type = self.interner.next_type_variable(); let expected = Type::Function(args, Box::new(ret.clone()), Box::new(env_type)); - if let Err(error) = binding.borrow_mut().bind_to(expected, span) { + if let Err(error) = binding.try_bind(expected, span) { self.errors.push(error); } ret @@ -1077,7 +1074,11 @@ impl<'interner> TypeChecker<'interner> { })); } - if other.try_bind_to_polymorphic_int(int).is_ok() || other == &Type::Error { + let mut bindings = TypeBindings::new(); + if other.try_bind_to_polymorphic_int(int, &mut bindings).is_ok() + || other == &Type::Error + { + Type::apply_type_bindings(bindings); Ok(other.clone()) } else { Err(TypeCheckError::TypeMismatchWithSource { diff --git a/compiler/noirc_frontend/src/hir/type_check/stmt.rs b/compiler/noirc_frontend/src/hir/type_check/stmt.rs index aa2a947c961..0cbfafb4c79 100644 --- a/compiler/noirc_frontend/src/hir/type_check/stmt.rs +++ b/compiler/noirc_frontend/src/hir/type_check/stmt.rs @@ -8,7 +8,6 @@ use crate::hir_def::stmt::{ }; use crate::hir_def::types::Type; use crate::node_interner::{DefinitionId, ExprId, StmtId}; -use crate::{Shared, TypeBinding, TypeVariableKind}; use super::errors::{Source, TypeCheckError}; use super::TypeChecker; @@ -71,9 +70,7 @@ impl<'interner> TypeChecker<'interner> { expr_span: range_span, }); - let fresh_id = self.interner.next_type_variable_id(); - let type_variable = Shared::new(TypeBinding::Unbound(fresh_id)); - let expected_type = Type::TypeVariable(type_variable, TypeVariableKind::IntegerOrField); + let expected_type = Type::polymorphic_integer(self.interner); self.unify(&start_range_type, &expected_type, || { TypeCheckError::TypeCannotBeUsed { diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index 5f0bf49ca0f..e6c46a46073 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use crate::{ graph::CrateId, node_interner::{FuncId, TraitId, TraitMethodId}, @@ -11,9 +9,7 @@ use noirc_errors::Span; #[derive(Clone, Debug, PartialEq, Eq)] pub struct TraitFunction { pub name: Ident, - pub generics: Vec<(Rc, TypeVariable, Span)>, - pub arguments: Vec, - pub return_type: Type, + pub typ: Type, pub span: Span, pub default_impl: Option>, pub default_impl_file_id: fm::FileId, @@ -145,12 +141,33 @@ impl std::fmt::Display for Trait { } impl TraitFunction { - pub fn get_type(&self) -> Type { - Type::Function( - self.arguments.clone(), - Box::new(self.return_type.clone()), - Box::new(Type::Unit), - ) - .generalize() + pub fn arguments(&self) -> &[Type] { + match &self.typ { + Type::Function(args, _, _) => args, + Type::Forall(_, typ) => match typ.as_ref() { + Type::Function(args, _, _) => args, + _ => unreachable!("Trait function does not have a function type"), + }, + _ => unreachable!("Trait function does not have a function type"), + } + } + + pub fn generics(&self) -> &[(TypeVariableId, TypeVariable)] { + match &self.typ { + Type::Function(..) => &[], + Type::Forall(generics, _) => generics, + _ => unreachable!("Trait function does not have a function type"), + } + } + + pub fn return_type(&self) -> &Type { + match &self.typ { + Type::Function(_, return_type, _) => return_type, + Type::Forall(_, typ) => match typ.as_ref() { + Type::Function(_, return_type, _) => return_type, + _ => unreachable!("Trait function does not have a function type"), + }, + _ => unreachable!("Trait function does not have a function type"), + } } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 46818626a16..9634a1e9d88 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -75,7 +75,7 @@ pub enum Type { /// the environment should be `Unit` by default, /// for closures it should contain a `Tuple` type with the captured /// variable types. - Function(Vec, Box, Box), + Function(Vec, /*return_type:*/ Box, /*environment:*/ Box), /// &mut T MutableReference(Box), @@ -298,7 +298,7 @@ impl std::fmt::Display for TypeAliasType { write!(f, "{}", self.name)?; if !self.generics.is_empty() { - let generics = vecmap(&self.generics, |(_, binding)| binding.borrow().to_string()); + let generics = vecmap(&self.generics, |(_, binding)| binding.0.borrow().to_string()); write!(f, "{}", generics.join(", "))?; } @@ -413,7 +413,66 @@ pub enum TypeVariableKind { /// A TypeVariable is a mutable reference that is either /// bound to some type, or unbound with a given TypeVariableId. -pub type TypeVariable = Shared; +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct TypeVariable(Shared); + +impl TypeVariable { + pub fn unbound(id: TypeVariableId) -> Self { + TypeVariable(Shared::new(TypeBinding::Unbound(id))) + } + + /// Bind this type variable to a value. + /// + /// Panics if this TypeVariable is already Bound. + /// Also Panics if the ID of this TypeVariable occurs within the given + /// binding, as that would cause an infinitely recursive type. + pub fn bind(&self, typ: Type) { + let id = match &*self.0.borrow() { + TypeBinding::Bound(binding) => { + unreachable!("TypeVariable::bind, cannot bind bound var {} to {}", binding, typ) + } + TypeBinding::Unbound(id) => *id, + }; + + assert!(!typ.occurs(id)); + *self.0.borrow_mut() = TypeBinding::Bound(typ); + } + + pub fn try_bind(&self, binding: Type, span: Span) -> Result<(), TypeCheckError> { + let id = match &*self.0.borrow() { + TypeBinding::Bound(binding) => { + unreachable!("Expected unbound, found bound to {binding}") + } + TypeBinding::Unbound(id) => *id, + }; + + if binding.occurs(id) { + Err(TypeCheckError::TypeAnnotationsNeeded { span }) + } else { + *self.0.borrow_mut() = TypeBinding::Bound(binding); + Ok(()) + } + } + + /// Borrows this TypeVariable to (e.g.) manually match on the inner TypeBinding. + pub fn borrow(&self) -> std::cell::Ref { + self.0.borrow() + } + + /// Unbind this type variable, setting it to Unbound(id). + /// + /// This is generally a logic error to use outside of monomorphization. + pub fn unbind(&self, id: TypeVariableId) { + *self.0.borrow_mut() = TypeBinding::Unbound(id); + } + + /// Forcibly bind a type variable to a new type - even if the type + /// variable is already bound to a different type. This generally + /// a logic error to use outside of monomorphization. + pub fn force_bind(&self, typ: Type) { + *self.0.borrow_mut() = TypeBinding::Bound(typ); + } +} /// TypeBindings are the mutable insides of a TypeVariable. /// They are either bound to some type, or are unbound. @@ -427,24 +486,6 @@ impl TypeBinding { pub fn is_unbound(&self) -> bool { matches!(self, TypeBinding::Unbound(_)) } - - pub fn bind_to(&mut self, binding: Type, span: Span) -> Result<(), TypeCheckError> { - match self { - TypeBinding::Bound(_) => panic!("Tried to bind an already bound type variable!"), - TypeBinding::Unbound(id) => { - if binding.occurs(*id) { - Err(TypeCheckError::TypeAnnotationsNeeded { span }) - } else { - *self = TypeBinding::Bound(binding); - Ok(()) - } - } - } - } - - pub fn unbind(&mut self, id: TypeVariableId) { - *self = TypeBinding::Unbound(id); - } } /// A unique ID used to differentiate different type variables @@ -461,7 +502,8 @@ impl Type { } pub fn type_variable(id: TypeVariableId) -> Type { - Type::TypeVariable(Shared::new(TypeBinding::Unbound(id)), TypeVariableKind::Normal) + let var = TypeVariable(Shared::new(TypeBinding::Unbound(id))); + Type::TypeVariable(var, TypeVariableKind::Normal) } /// Returns a TypeVariable(_, TypeVariableKind::Constant(length)) to bind to @@ -469,13 +511,15 @@ impl Type { pub fn constant_variable(length: u64, interner: &mut NodeInterner) -> Type { let id = interner.next_type_variable_id(); let kind = TypeVariableKind::Constant(length); - Type::TypeVariable(Shared::new(TypeBinding::Unbound(id)), kind) + let var = TypeVariable(Shared::new(TypeBinding::Unbound(id))); + Type::TypeVariable(var, kind) } pub fn polymorphic_integer(interner: &mut NodeInterner) -> Type { let id = interner.next_type_variable_id(); let kind = TypeVariableKind::IntegerOrField; - Type::TypeVariable(Shared::new(TypeBinding::Unbound(id)), kind) + let var = TypeVariable(Shared::new(TypeBinding::Unbound(id))); + Type::TypeVariable(var, kind) } /// A bit of an awkward name for this function - this function returns @@ -484,7 +528,7 @@ impl Type { /// they shouldn't be bound over until monomorphization. pub fn is_bindable(&self) -> bool { match self { - Type::TypeVariable(binding, _) => match &*binding.borrow() { + Type::TypeVariable(binding, _) => match &*binding.0.borrow() { TypeBinding::Bound(binding) => binding.is_bindable(), TypeBinding::Unbound(_) => true, }, @@ -508,7 +552,7 @@ impl Type { // True if the given type is a NamedGeneric with the target_id let named_generic_id_matches_target = |typ: &Type| { if let Type::NamedGeneric(type_variable, _) = typ { - match &*type_variable.borrow() { + match &*type_variable.0.borrow() { TypeBinding::Bound(_) => { unreachable!("Named generics should not be bound until monomorphization") } @@ -608,7 +652,7 @@ impl Type { match self { Type::Forall(generics, _) => generics.len(), Type::TypeVariable(type_variable, _) | Type::NamedGeneric(type_variable, _) => { - match &*type_variable.borrow() { + match &*type_variable.0.borrow() { TypeBinding::Bound(binding) => binding.generic_count(), TypeBinding::Unbound(_) => 0, } @@ -617,31 +661,12 @@ impl Type { } } - /// Takes a monomorphic type and generalizes it over each of the given type variables. - pub(crate) fn generalize_from_variables( - self, - type_vars: HashMap, - ) -> Type { - let polymorphic_type_vars = vecmap(type_vars, |type_var| type_var); - Type::Forall(polymorphic_type_vars, Box::new(self)) - } - /// Takes a monomorphic type and generalizes it over each of the type variables in the /// given type bindings, ignoring what each type variable is bound to in the TypeBindings. pub(crate) fn generalize_from_substitutions(self, type_bindings: TypeBindings) -> Type { let polymorphic_type_vars = vecmap(type_bindings, |(id, (type_var, _))| (id, type_var)); Type::Forall(polymorphic_type_vars, Box::new(self)) } - - /// Takes a monomorphic type and generalizes it over each type variable found within. - /// - /// Note that Noir's type system assumes any Type::Forall are only present at top-level, - /// and thus all type variable's within a type are free. - pub(crate) fn generalize(self) -> Type { - let mut type_variables = HashMap::new(); - self.find_all_unbound_type_variables(&mut type_variables); - self.generalize_from_variables(type_variables) - } } impl std::fmt::Display for Type { @@ -661,23 +686,23 @@ impl std::fmt::Display for Type { Signedness::Signed => write!(f, "i{num_bits}"), Signedness::Unsigned => write!(f, "u{num_bits}"), }, - Type::TypeVariable(id, TypeVariableKind::Normal) => write!(f, "{}", id.borrow()), + Type::TypeVariable(var, TypeVariableKind::Normal) => write!(f, "{}", var.0.borrow()), Type::TypeVariable(binding, TypeVariableKind::IntegerOrField) => { - if let TypeBinding::Unbound(_) = &*binding.borrow() { + if let TypeBinding::Unbound(_) = &*binding.0.borrow() { // Show a Field by default if this TypeVariableKind::IntegerOrField is unbound, since that is // what they bind to by default anyway. It is less confusing than displaying it // as a generic. write!(f, "Field") } else { - write!(f, "{}", binding.borrow()) + write!(f, "{}", binding.0.borrow()) } } Type::TypeVariable(binding, TypeVariableKind::Constant(n)) => { - if let TypeBinding::Unbound(_) = &*binding.borrow() { + if let TypeBinding::Unbound(_) = &*binding.0.borrow() { // TypeVariableKind::Constant(n) binds to Type::Constant(n) by default, so just show that. write!(f, "{n}") } else { - write!(f, "{}", binding.borrow()) + write!(f, "{}", binding.0.borrow()) } } Type::Struct(s, args) => { @@ -702,7 +727,7 @@ impl std::fmt::Display for Type { } Type::Unit => write!(f, "()"), Type::Error => write!(f, "error"), - Type::NamedGeneric(binding, name) => match &*binding.borrow() { + Type::NamedGeneric(binding, name) => match &*binding.0.borrow() { TypeBinding::Bound(binding) => binding.fmt(f), TypeBinding::Unbound(_) if name.is_empty() => write!(f, "_"), TypeBinding::Unbound(_) => write!(f, "{name}"), @@ -761,58 +786,63 @@ pub struct UnificationError; impl Type { /// Try to bind a MaybeConstant variable to self, succeeding if self is a Constant, - /// MaybeConstant, or type variable. - pub fn try_bind_to_maybe_constant( + /// MaybeConstant, or type variable. If successful, the binding is placed in the + /// given TypeBindings map rather than linked immediately. + fn try_bind_to_maybe_constant( &self, var: &TypeVariable, target_length: u64, + bindings: &mut TypeBindings, ) -> Result<(), UnificationError> { - let target_id = match &*var.borrow() { + let target_id = match &*var.0.borrow() { TypeBinding::Bound(_) => unreachable!(), TypeBinding::Unbound(id) => *id, }; - match self { + let this = self.substitute(bindings); + + match &this { Type::Constant(length) if *length == target_length => { - *var.borrow_mut() = TypeBinding::Bound(self.clone()); + bindings.insert(target_id, (var.clone(), this)); Ok(()) } Type::NotConstant => { - *var.borrow_mut() = TypeBinding::Bound(Type::NotConstant); + bindings.insert(target_id, (var.clone(), Type::NotConstant)); Ok(()) } - Type::TypeVariable(binding, kind) => { - let borrow = binding.borrow(); + // A TypeVariable is less specific than a MaybeConstant, so we bind + // to the other type variable instead. + Type::TypeVariable(new_var, kind) => { + let borrow = new_var.0.borrow(); match &*borrow { - TypeBinding::Bound(typ) => typ.try_bind_to_maybe_constant(var, target_length), + TypeBinding::Bound(typ) => { + typ.try_bind_to_maybe_constant(var, target_length, bindings) + } // Avoid infinitely recursive bindings TypeBinding::Unbound(id) if *id == target_id => Ok(()), - TypeBinding::Unbound(_) => match kind { + TypeBinding::Unbound(new_target_id) => match kind { TypeVariableKind::Normal => { - drop(borrow); let clone = Type::TypeVariable( var.clone(), TypeVariableKind::Constant(target_length), ); - *binding.borrow_mut() = TypeBinding::Bound(clone); + bindings.insert(*new_target_id, (new_var.clone(), clone)); Ok(()) } TypeVariableKind::Constant(length) if *length == target_length => { - drop(borrow); let clone = Type::TypeVariable( var.clone(), TypeVariableKind::Constant(target_length), ); - *binding.borrow_mut() = TypeBinding::Bound(clone); + bindings.insert(*new_target_id, (new_var.clone(), clone)); Ok(()) } // The lengths don't match, but neither are set in stone so we can // just set them both to NotConstant. See issue 2370 TypeVariableKind::Constant(_) => { // *length != target_length - drop(borrow); - *var.borrow_mut() = TypeBinding::Bound(Type::NotConstant); - *binding.borrow_mut() = TypeBinding::Bound(Type::NotConstant); + bindings.insert(target_id, (var.clone(), Type::NotConstant)); + bindings.insert(*new_target_id, (new_var.clone(), Type::NotConstant)); Ok(()) } TypeVariableKind::IntegerOrField => Err(UnificationError), @@ -824,44 +854,49 @@ impl Type { } /// Try to bind a PolymorphicInt variable to self, succeeding if self is an integer, field, - /// other PolymorphicInt type, or type variable. - pub fn try_bind_to_polymorphic_int(&self, var: &TypeVariable) -> Result<(), UnificationError> { - let target_id = match &*var.borrow() { + /// other PolymorphicInt type, or type variable. If successful, the binding is placed in the + /// given TypeBindings map rather than linked immediately. + pub fn try_bind_to_polymorphic_int( + &self, + var: &TypeVariable, + bindings: &mut TypeBindings, + ) -> Result<(), UnificationError> { + let target_id = match &*var.0.borrow() { TypeBinding::Bound(_) => unreachable!(), TypeBinding::Unbound(id) => *id, }; - match self { + let this = self.substitute(bindings); + + match &this { Type::FieldElement | Type::Integer(..) => { - *var.borrow_mut() = TypeBinding::Bound(self.clone()); + bindings.insert(target_id, (var.clone(), this)); Ok(()) } Type::TypeVariable(self_var, TypeVariableKind::IntegerOrField) => { - let borrow = self_var.borrow(); + let borrow = self_var.0.borrow(); match &*borrow { - TypeBinding::Bound(typ) => typ.try_bind_to_polymorphic_int(var), + TypeBinding::Bound(typ) => typ.try_bind_to_polymorphic_int(var, bindings), // Avoid infinitely recursive bindings TypeBinding::Unbound(id) if *id == target_id => Ok(()), TypeBinding::Unbound(_) => { - drop(borrow); - *var.borrow_mut() = TypeBinding::Bound(self.clone()); + bindings.insert(target_id, (var.clone(), this.clone())); Ok(()) } } } Type::TypeVariable(binding, TypeVariableKind::Normal) => { - let borrow = binding.borrow(); + let borrow = binding.0.borrow(); match &*borrow { - TypeBinding::Bound(typ) => typ.try_bind_to_polymorphic_int(var), + TypeBinding::Bound(typ) => typ.try_bind_to_polymorphic_int(var, bindings), // Avoid infinitely recursive bindings TypeBinding::Unbound(id) if *id == target_id => Ok(()), - TypeBinding::Unbound(_) => { - drop(borrow); - // PolymorphicInt is more specific than TypeVariable so we bind the type - // variable to PolymorphicInt instead. + TypeBinding::Unbound(new_target_id) => { + // IntegerOrField is more specific than TypeVariable so we bind the type + // variable to IntegerOrField instead. let clone = Type::TypeVariable(var.clone(), TypeVariableKind::IntegerOrField); - *binding.borrow_mut() = TypeBinding::Bound(clone); + bindings.insert(*new_target_id, (binding.clone(), clone)); Ok(()) } } @@ -870,102 +905,113 @@ impl Type { } } - pub fn try_bind_to(&self, var: &TypeVariable) -> Result<(), UnificationError> { - let target_id = match &*var.borrow() { + /// Try to bind the given type variable to self. Although the given type variable + /// is expected to be of TypeVariableKind::Normal, this binding can still fail + /// if the given type variable occurs within `self` as that would create a recursive type. + /// + /// If successful, the binding is placed in the + /// given TypeBindings map rather than linked immediately. + fn try_bind_to( + &self, + var: &TypeVariable, + bindings: &mut TypeBindings, + ) -> Result<(), UnificationError> { + let target_id = match &*var.0.borrow() { TypeBinding::Bound(_) => unreachable!(), TypeBinding::Unbound(id) => *id, }; - if let Some(binding) = self.get_inner_type_variable() { + let this = self.substitute(bindings); + + if let Some(binding) = this.get_inner_type_variable() { match &*binding.borrow() { - TypeBinding::Bound(typ) => return typ.try_bind_to(var), + TypeBinding::Bound(typ) => return typ.try_bind_to(var, bindings), // Don't recursively bind the same id to itself TypeBinding::Unbound(id) if *id == target_id => return Ok(()), _ => (), } } - // Check if the target id occurs within self before binding. Otherwise this could + // Check if the target id occurs within `this` before binding. Otherwise this could // cause infinitely recursive types - if self.occurs(target_id) { + if this.occurs(target_id) { Err(UnificationError) } else { - *var.borrow_mut() = TypeBinding::Bound(self.clone()); + bindings.insert(target_id, (var.clone(), this.clone())); Ok(()) } } fn get_inner_type_variable(&self) -> Option> { match self { - Type::TypeVariable(var, _) | Type::NamedGeneric(var, _) => Some(var.clone()), + Type::TypeVariable(var, _) | Type::NamedGeneric(var, _) => Some(var.0.clone()), _ => None, } } /// Try to unify this type with another, setting any type variables found - /// equal to the other type in the process. Unification is more strict - /// than sub-typing but less strict than Eq. Returns true if the unification - /// succeeded. Note that any bindings performed in a failed unification are - /// not undone. This may cause further type errors later on. + /// equal to the other type in the process. When comparing types, unification + /// (including try_unify) are almost always preferred over Type::eq as unification + /// will correctly handle generic types. pub fn unify( &self, expected: &Type, errors: &mut Vec, make_error: impl FnOnce() -> TypeCheckError, ) { - if let Err(UnificationError) = self.try_unify(expected) { - errors.push(make_error()); + let mut bindings = TypeBindings::new(); + + match self.try_unify(expected, &mut bindings) { + Ok(()) => { + // Commit any type bindings on success + Self::apply_type_bindings(bindings); + } + Err(UnificationError) => errors.push(make_error()), } } /// `try_unify` is a bit of a misnomer since although errors are not committed, /// any unified bindings are on success. - pub fn try_unify(&self, other: &Type) -> Result<(), UnificationError> { + pub fn try_unify( + &self, + other: &Type, + bindings: &mut TypeBindings, + ) -> Result<(), UnificationError> { use Type::*; use TypeVariableKind as Kind; match (self, other) { (Error, _) | (_, Error) => Ok(()), - (TypeVariable(binding, Kind::IntegerOrField), other) - | (other, TypeVariable(binding, Kind::IntegerOrField)) => { - // If it is already bound, unify against what it is bound to - if let TypeBinding::Bound(link) = &*binding.borrow() { - return link.try_unify(other); - } - - // Otherwise, check it is unified against an integer and bind it - other.try_bind_to_polymorphic_int(binding) + (TypeVariable(var, Kind::IntegerOrField), other) + | (other, TypeVariable(var, Kind::IntegerOrField)) => { + other.try_unify_to_type_variable(var, bindings, |bindings| { + other.try_bind_to_polymorphic_int(var, bindings) + }) } - (TypeVariable(binding, Kind::Normal), other) - | (other, TypeVariable(binding, Kind::Normal)) => { - if let TypeBinding::Bound(link) = &*binding.borrow() { - return link.try_unify(other); - } - - other.try_bind_to(binding) + (TypeVariable(var, Kind::Normal), other) | (other, TypeVariable(var, Kind::Normal)) => { + other.try_unify_to_type_variable(var, bindings, |bindings| { + other.try_bind_to(var, bindings) + }) } - (TypeVariable(binding, Kind::Constant(length)), other) - | (other, TypeVariable(binding, Kind::Constant(length))) => { - if let TypeBinding::Bound(link) = &*binding.borrow() { - return link.try_unify(other); - } - - other.try_bind_to_maybe_constant(binding, *length) - } + (TypeVariable(var, Kind::Constant(length)), other) + | (other, TypeVariable(var, Kind::Constant(length))) => other + .try_unify_to_type_variable(var, bindings, |bindings| { + other.try_bind_to_maybe_constant(var, *length, bindings) + }), (Array(len_a, elem_a), Array(len_b, elem_b)) => { - len_a.try_unify(len_b)?; - elem_a.try_unify(elem_b) + len_a.try_unify(len_b, bindings)?; + elem_a.try_unify(elem_b, bindings) } - (String(len_a), String(len_b)) => len_a.try_unify(len_b), + (String(len_a), String(len_b)) => len_a.try_unify(len_b, bindings), (FmtString(len_a, elements_a), FmtString(len_b, elements_b)) => { - len_a.try_unify(len_b)?; - elements_a.try_unify(elements_b) + len_a.try_unify(len_b, bindings)?; + elements_a.try_unify(elements_b, bindings) } (Tuple(elements_a), Tuple(elements_b)) => { @@ -973,7 +1019,7 @@ impl Type { Err(UnificationError) } else { for (a, b) in elements_a.iter().zip(elements_b) { - a.try_unify(b)?; + a.try_unify(b, bindings)?; } Ok(()) } @@ -985,7 +1031,7 @@ impl Type { (Struct(id_a, args_a), Struct(id_b, args_b)) => { if id_a == id_b && args_a.len() == args_b.len() { for (a, b) in args_a.iter().zip(args_b) { - a.try_unify(b)?; + a.try_unify(b, bindings)?; } Ok(()) } else { @@ -993,26 +1039,20 @@ impl Type { } } - (NamedGeneric(binding, _), other) if !binding.borrow().is_unbound() => { - if let TypeBinding::Bound(link) = &*binding.borrow() { - link.try_unify(other) - } else { - unreachable!("If guard ensures binding is bound") - } - } - - (other, NamedGeneric(binding, _)) if !binding.borrow().is_unbound() => { - if let TypeBinding::Bound(link) = &*binding.borrow() { - other.try_unify(link) + (NamedGeneric(binding, _), other) | (other, NamedGeneric(binding, _)) + if !binding.0.borrow().is_unbound() => + { + if let TypeBinding::Bound(link) = &*binding.0.borrow() { + link.try_unify(other, bindings) } else { unreachable!("If guard ensures binding is bound") } } (NamedGeneric(binding_a, name_a), NamedGeneric(binding_b, name_b)) => { - // Unbound NamedGenerics are caught by the checks above - assert!(binding_a.borrow().is_unbound()); - assert!(binding_b.borrow().is_unbound()); + // Bound NamedGenerics are caught by the check above + assert!(binding_a.0.borrow().is_unbound()); + assert!(binding_b.0.borrow().is_unbound()); if name_a == name_b { Ok(()) @@ -1024,17 +1064,19 @@ impl Type { (Function(params_a, ret_a, env_a), Function(params_b, ret_b, env_b)) => { if params_a.len() == params_b.len() { for (a, b) in params_a.iter().zip(params_b.iter()) { - a.try_unify(b)?; + a.try_unify(b, bindings)?; } - env_a.try_unify(env_b)?; - ret_b.try_unify(ret_a) + env_a.try_unify(env_b, bindings)?; + ret_b.try_unify(ret_a, bindings) } else { Err(UnificationError) } } - (MutableReference(elem_a), MutableReference(elem_b)) => elem_a.try_unify(elem_b), + (MutableReference(elem_a), MutableReference(elem_b)) => { + elem_a.try_unify(elem_b, bindings) + } (other_a, other_b) => { if other_a == other_b { @@ -1046,6 +1088,34 @@ impl Type { } } + /// Try to unify a type variable to `self`. + /// This is a helper function factored out from try_unify. + fn try_unify_to_type_variable( + &self, + type_variable: &TypeVariable, + bindings: &mut TypeBindings, + + // Bind the type variable to a type. This is factored out since depending on the + // TypeVariableKind, there are different methods to check whether the variable can + // bind to the given type or not. + bind_variable: impl FnOnce(&mut TypeBindings) -> Result<(), UnificationError>, + ) -> Result<(), UnificationError> { + match &*type_variable.0.borrow() { + // If it is already bound, unify against what it is bound to + TypeBinding::Bound(link) => link.try_unify(self, bindings), + TypeBinding::Unbound(id) => { + // We may have already "bound" this type variable in this call to + // try_unify, so check those bindings as well. + match bindings.get(id) { + Some((_, binding)) => binding.clone().try_unify(self, bindings), + + // Otherwise, bind it + None => bind_variable(bindings), + } + } + } + } + /// Similar to `unify` but if the check fails this will attempt to coerce the /// argument to the target type. When this happens, the given expression is wrapped in /// a new expression to convert its type. E.g. `array` -> `array.as_slice()` @@ -1059,10 +1129,14 @@ impl Type { errors: &mut Vec, make_error: impl FnOnce() -> TypeCheckError, ) { - if let Err(UnificationError) = self.try_unify(expected) { + let mut bindings = TypeBindings::new(); + + if let Err(UnificationError) = self.try_unify(expected, &mut bindings) { if !self.try_array_to_slice_coercion(expected, expression, interner) { errors.push(make_error()); } + } else { + Type::apply_type_bindings(bindings); } } @@ -1085,8 +1159,10 @@ impl Type { if matches!(size1, Type::Constant(_)) && matches!(size2, Type::NotConstant) { // Still have to ensure the element types match. // Don't need to issue an error here if not, it will be done in unify_with_coercions - if element1.try_unify(element2).is_ok() { + let mut bindings = TypeBindings::new(); + if element1.try_unify(element2, &mut bindings).is_ok() { convert_array_expression_to_slice(expression, this, target, interner); + Self::apply_type_bindings(bindings); return true; } } @@ -1094,6 +1170,14 @@ impl Type { false } + /// Apply the given type bindings, making them permanently visible for each + /// clone of each type variable bound. + pub fn apply_type_bindings(bindings: TypeBindings) { + for (type_variable, binding) in bindings.values() { + type_variable.bind(binding.clone()); + } + } + /// If this type is a Type::Constant (used in array lengths), or is bound /// to a Type::Constant, return the constant as a u64. pub fn evaluate_to_u64(&self) -> Option { @@ -1229,7 +1313,7 @@ impl Type { } Type::Forall(_, typ) => typ.find_all_unbound_type_variables(type_variables), Type::TypeVariable(type_variable, _) | Type::NamedGeneric(type_variable, _) => { - match &*type_variable.borrow() { + match &*type_variable.0.borrow() { TypeBinding::Bound(binding) => { binding.find_all_unbound_type_variables(type_variables); } @@ -1251,7 +1335,7 @@ impl Type { return self.clone(); } - let substitute_binding = |binding: &TypeVariable| match &*binding.borrow() { + let substitute_binding = |binding: &TypeVariable| match &*binding.0.borrow() { TypeBinding::Bound(binding) => binding.substitute(type_bindings), TypeBinding::Unbound(id) => match type_bindings.get(id) { Some((_, binding)) => binding.clone(), @@ -1331,7 +1415,7 @@ impl Type { Type::Struct(_, generic_args) => generic_args.iter().any(|arg| arg.occurs(target_id)), Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)), Type::NamedGeneric(binding, _) | Type::TypeVariable(binding, _) => { - match &*binding.borrow() { + match &*binding.0.borrow() { TypeBinding::Bound(binding) => binding.occurs(target_id), TypeBinding::Unbound(id) => *id == target_id, } @@ -1380,7 +1464,7 @@ impl Type { } Tuple(args) => Tuple(vecmap(args, |arg| arg.follow_bindings())), TypeVariable(var, _) | NamedGeneric(var, _) => { - if let TypeBinding::Bound(typ) = &*var.borrow() { + if let TypeBinding::Bound(typ) = &*var.0.borrow() { return typ.follow_bindings(); } self.clone() @@ -1485,7 +1569,7 @@ impl From<&Type> for PrintableType { Signedness::Signed => PrintableType::SignedInteger { width: *bit_width }, }, Type::TypeVariable(binding, TypeVariableKind::IntegerOrField) => { - match &*binding.borrow() { + match &*binding.0.borrow() { TypeBinding::Bound(typ) => typ.into(), TypeBinding::Unbound(_) => Type::default_int_type().into(), } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 2abb939e448..52ed0c746e1 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -221,7 +221,7 @@ impl<'interner> Monomorphizer<'interner> { fn function(&mut self, f: node_interner::FuncId, id: FuncId) { if let Some((self_type, trait_id)) = self.interner.get_function_trait(&f) { let the_trait = self.interner.get_trait(trait_id); - *the_trait.self_type_typevar.borrow_mut() = TypeBinding::Bound(self_type); + the_trait.self_type_typevar.force_bind(self_type); } let meta = self.interner.function_meta(&f); @@ -731,11 +731,7 @@ impl<'interner> Monomorphizer<'interner> { // Default any remaining unbound type variables. // This should only happen if the variable in question is unused // and within a larger generic type. - // NOTE: Make sure to review this if there is ever type-directed dispatch, - // like automatic solving of traits. It should be fine since it is strictly - // after type checking, but care should be taken that it doesn't change which - // impls are chosen. - *binding.borrow_mut() = TypeBinding::Bound(HirType::default_int_type()); + binding.bind(HirType::default_int_type()); ast::Type::Field } @@ -747,10 +743,6 @@ impl<'interner> Monomorphizer<'interner> { // Default any remaining unbound type variables. // This should only happen if the variable in question is unused // and within a larger generic type. - // NOTE: Make sure to review this if there is ever type-directed dispatch, - // like automatic solving of traits. It should be fine since it is strictly - // after type checking, but care should be taken that it doesn't change which - // impls are chosen. let default = if self.is_range_loop && matches!(kind, TypeVariableKind::IntegerOrField) { Type::default_range_loop_type() @@ -759,7 +751,7 @@ impl<'interner> Monomorphizer<'interner> { }; let monomorphized_default = self.convert_type(&default); - *binding.borrow_mut() = TypeBinding::Bound(default); + binding.bind(default); monomorphized_default } @@ -894,7 +886,6 @@ impl<'interner> Monomorphizer<'interner> { let original_func = Box::new(self.expr(call.func)); let mut arguments = vecmap(&call.arguments, |id| self.expr(*id)); let hir_arguments = vecmap(&call.arguments, |id| self.interner.expression(id)); - let func: Box; let return_type = self.interner.id_type(id); let return_type = self.convert_type(&return_type); @@ -915,7 +906,8 @@ impl<'interner> Monomorphizer<'interner> { let func_type = self.interner.id_type(call.func); let func_type = self.convert_type(&func_type); let is_closure = self.is_function_closure(func_type); - if is_closure { + + let func = if is_closure { let local_id = self.next_local_id(); // store the function in a temporary variable before calling it @@ -937,14 +929,13 @@ impl<'interner> Monomorphizer<'interner> { typ: self.convert_type(&self.interner.id_type(call.func)), }); - func = Box::new(ast::Expression::ExtractTupleField( - Box::new(extracted_func.clone()), - 1usize, - )); - let env_argument = ast::Expression::ExtractTupleField(Box::new(extracted_func), 0usize); + let env_argument = + ast::Expression::ExtractTupleField(Box::new(extracted_func.clone()), 0usize); arguments.insert(0, env_argument); + + Box::new(ast::Expression::ExtractTupleField(Box::new(extracted_func), 1usize)) } else { - func = original_func.clone(); + original_func.clone() }; let call = self @@ -1450,12 +1441,12 @@ fn unwrap_struct_type(typ: &HirType) -> Vec<(String, HirType)> { fn perform_instantiation_bindings(bindings: &TypeBindings) { for (var, binding) in bindings.values() { - *var.borrow_mut() = TypeBinding::Bound(binding.clone()); + var.force_bind(binding.clone()); } } fn undo_instantiation_bindings(bindings: TypeBindings) { for (id, (var, _)) in bindings { - *var.borrow_mut() = TypeBinding::Unbound(id); + var.unbind(id); } } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 72bed209715..0be1c93f478 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -22,7 +22,7 @@ use crate::hir_def::{ use crate::token::{Attributes, SecondaryAttribute}; use crate::{ ContractFunctionType, FunctionDefinition, FunctionVisibility, Generics, Shared, TypeAliasType, - TypeBinding, TypeBindings, TypeVariable, TypeVariableId, TypeVariableKind, + TypeBindings, TypeVariable, TypeVariableId, TypeVariableKind, }; /// An arbitrary number to limit the recursion depth when searching for trait impls. @@ -490,7 +490,6 @@ impl NodeInterner { pub fn push_empty_trait(&mut self, type_id: TraitId, typ: &UnresolvedTrait) { let self_type_typevar_id = self.next_type_variable_id(); - let self_type_typevar = Shared::new(TypeBinding::Unbound(self_type_typevar_id)); self.traits.insert( type_id, @@ -505,10 +504,10 @@ impl NodeInterner { // can refer to it with generic arguments before the generic parameters themselves // are resolved. let id = TypeVariableId(0); - (id, Shared::new(TypeBinding::Unbound(id))) + (id, TypeVariable::unbound(id)) }), self_type_typevar_id, - self_type_typevar, + TypeVariable::unbound(self_type_typevar_id), ), ); } @@ -530,7 +529,7 @@ impl NodeInterner { // can refer to it with generic arguments before the generic parameters themselves // are resolved. let id = TypeVariableId(0); - (id, Shared::new(TypeBinding::Unbound(id))) + (id, TypeVariable::unbound(id)) }); let new_struct = StructType::new(struct_id, name, typ.struct_def.span, no_fields, generics); @@ -549,7 +548,7 @@ impl NodeInterner { Type::Error, vecmap(&typ.type_alias_def.generics, |_| { let id = TypeVariableId(0); - (id, Shared::new(TypeBinding::Unbound(id))) + (id, TypeVariable::unbound(id)) }), )); @@ -1000,13 +999,32 @@ impl NodeInterner { object_type: &Type, trait_id: TraitId, ) -> Result> { - self.lookup_trait_implementation_helper(object_type, trait_id, IMPL_SEARCH_RECURSION_LIMIT) + let (impl_kind, bindings) = self.try_lookup_trait_implementation(object_type, trait_id)?; + Type::apply_type_bindings(bindings); + Ok(impl_kind) + } + + /// Similar to `lookup_trait_implementation` but does not apply any type bindings on success. + pub fn try_lookup_trait_implementation( + &self, + object_type: &Type, + trait_id: TraitId, + ) -> Result<(TraitImplKind, TypeBindings), Vec> { + let mut bindings = TypeBindings::new(); + let impl_kind = self.lookup_trait_implementation_helper( + object_type, + trait_id, + &mut bindings, + IMPL_SEARCH_RECURSION_LIMIT, + )?; + Ok((impl_kind, bindings)) } fn lookup_trait_implementation_helper( &self, object_type: &Type, trait_id: TraitId, + type_bindings: &mut TypeBindings, recursion_limit: u32, ) -> Result> { let make_constraint = || TraitConstraint::new(object_type.clone(), trait_id); @@ -1016,20 +1034,29 @@ impl NodeInterner { return Err(vec![make_constraint()]); } + let object_type = object_type.substitute(type_bindings); + let impls = self.trait_implementation_map.get(&trait_id).ok_or_else(|| vec![make_constraint()])?; for (existing_object_type, impl_kind) in impls { - let (existing_object_type, type_bindings) = existing_object_type.instantiate(self); + let (existing_object_type, instantiation_bindings) = + existing_object_type.instantiate(self); + + let mut fresh_bindings = TypeBindings::new(); + + if object_type.try_unify(&existing_object_type, &mut fresh_bindings).is_ok() { + // The unification was successful so we can append fresh_bindings to our bindings list + type_bindings.extend(fresh_bindings); - if object_type.try_unify(&existing_object_type).is_ok() { if let TraitImplKind::Normal(impl_id) = impl_kind { let trait_impl = self.get_trait_implementation(*impl_id); let trait_impl = trait_impl.borrow(); if let Err(mut errors) = self.validate_where_clause( &trait_impl.where_clause, - &type_bindings, + type_bindings, + &instantiation_bindings, recursion_limit, ) { errors.push(make_constraint()); @@ -1049,14 +1076,20 @@ impl NodeInterner { fn validate_where_clause( &self, where_clause: &[TraitConstraint], - type_bindings: &TypeBindings, + type_bindings: &mut TypeBindings, + instantiation_bindings: &TypeBindings, recursion_limit: u32, ) -> Result<(), Vec> { for constraint in where_clause { - let constraint_type = constraint.typ.substitute(type_bindings); + let constraint_type = constraint.typ.substitute(instantiation_bindings); + let constraint_type = constraint_type.substitute(type_bindings); + self.lookup_trait_implementation_helper( &constraint_type, constraint.trait_id, + // Use a fresh set of type bindings here since the constraint_type originates from + // our impl list, which we don't want to bind to. + &mut TypeBindings::new(), recursion_limit - 1, )?; } @@ -1077,7 +1110,7 @@ impl NodeInterner { trait_id: TraitId, ) -> bool { // Make sure there are no overlapping impls - if self.lookup_trait_implementation(&object_type, trait_id).is_ok() { + if self.try_lookup_trait_implementation(&object_type, trait_id).is_ok() { return false; } @@ -1105,8 +1138,8 @@ impl NodeInterner { let (instantiated_object_type, substitutions) = object_type.instantiate_type_variables(self); - if let Ok(TraitImplKind::Normal(existing)) = - self.lookup_trait_implementation(&instantiated_object_type, trait_id) + if let Ok((TraitImplKind::Normal(existing), _)) = + self.try_lookup_trait_implementation(&instantiated_object_type, trait_id) { let existing_impl = self.get_trait_implementation(existing); let existing_impl = existing_impl.borrow(); @@ -1285,8 +1318,10 @@ impl Methods { match interner.function_meta(&method).typ.instantiate(interner).0 { Type::Function(args, _, _) => { if let Some(object) = args.get(0) { - // TODO #3089: This is dangerous! try_unify may commit type bindings even on failure - if object.try_unify(typ).is_ok() { + let mut bindings = TypeBindings::new(); + + if object.try_unify(typ, &mut bindings).is_ok() { + Type::apply_type_bindings(bindings); return Some(method); } } diff --git a/test_programs/compile_success_empty/method_call_regression/Nargo.toml b/test_programs/compile_success_empty/method_call_regression/Nargo.toml new file mode 100644 index 00000000000..92c9b942008 --- /dev/null +++ b/test_programs/compile_success_empty/method_call_regression/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "short" +type = "bin" +authors = [""] +compiler_version = ">=0.19.4" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_success_empty/method_call_regression/src/main.nr b/test_programs/compile_success_empty/method_call_regression/src/main.nr new file mode 100644 index 00000000000..8bb7ebcac45 --- /dev/null +++ b/test_programs/compile_success_empty/method_call_regression/src/main.nr @@ -0,0 +1,25 @@ +use dep::std; + +fn main() { + // s: Struct + let s = Struct { b: () }; + // Regression for #3089 + s.foo(); +} + +struct Struct { b: B } + +// Before the fix, this candidate is searched first, binding ? to `u8` permanently. +impl Struct { + fn foo(self) {} +} + +// Then this candidate would be searched next but would not be a valid +// candidate since `Struct` != `Struct`. +// +// With the fix, the type of `s` correctly no longer changes until a +// method is actually selected. So this candidate is now valid since +// `Struct` unifies with `Struct` with `? = u32`. +impl Struct { + fn foo(self) {} +}