Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Account for the expected kind when resolving turbofish generics #5448

Merged
merged 12 commits into from
Jul 9, 2024
Merged
49 changes: 26 additions & 23 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::{
},
node_interner::{DefinitionKind, ExprId, FuncId, ReferenceId},
token::Tokens,
Kind, QuotedType, Shared, StructType, Type,
QuotedType, Shared, StructType, Type,
};

use super::Elaborator;
Expand All @@ -51,23 +51,7 @@ impl<'context> Elaborator<'context> {
ExpressionKind::Infix(infix) => return self.elaborate_infix(*infix, expr.span),
ExpressionKind::If(if_) => self.elaborate_if(*if_),
ExpressionKind::Variable(variable, generics) => {
let generics = generics.map(|option_inner| {
option_inner
.into_iter()
.map(|generic| {
// All type expressions should resolve to a `Type::Constant`
if generic.is_type_expression() {
self.resolve_type_inner(
generic,
&Kind::Numeric(Box::new(Type::default_int_type())),
)
} else {
self.resolve_type(generic)
}
})
.collect()
});
return self.elaborate_variable(variable, generics);
return self.elaborate_variable(variable, generics)
}
ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple),
ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda),
Expand Down Expand Up @@ -342,14 +326,36 @@ impl<'context> Elaborator<'context> {
}
};

if func_id != FuncId::dummy_id() {
// Perform any check that required information from the interned function.
let generics = if func_id != FuncId::dummy_id() {
let function_type = self.interner.function_meta(&func_id).typ.clone();
self.try_add_mutable_reference_to_object(
&function_type,
&mut object_type,
&mut object,
);
}

// Resolve generics using the expected kinds of the function we are calling
let direct_generics =
self.interner.function_meta(&func_id).direct_generics.clone();

method_call.generics.map(|option_inner| {
if option_inner.len() != direct_generics.len() {
let type_check_err = TypeCheckError::IncorrectTurbofishGenericCount {
expected_count: direct_generics.len(),
actual_count: option_inner.len(),
span,
};
self.push_err(type_check_err);
}
let generics_with_types = direct_generics.iter().zip(option_inner);
vecmap(generics_with_types, |(generic, unresolved_type)| {
self.resolve_type_inner(unresolved_type, &generic.kind)
})
})
} else {
None
};

// These arguments will be given to the desugared function call.
// Compared to the method arguments, they also contain the object.
Expand All @@ -367,9 +373,6 @@ impl<'context> Elaborator<'context> {

let location = Location::new(span, self.file);
let method = method_call.method_name;
let generics = method_call.generics.map(|option_inner| {
option_inner.into_iter().map(|generic| self.resolve_type(generic)).collect()
});
let turbofish_generics = generics.clone();
let method_call =
HirMethodCallExpression { method, object, arguments, location, generics };
Expand Down
3 changes: 1 addition & 2 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,8 +702,7 @@ impl<'context> Elaborator<'context> {

let direct_generics = func.def.generics.iter();
let direct_generics = direct_generics
.filter_map(|generic| self.find_generic(&generic.ident().0.contents))
.map(|ResolvedGeneric { name, type_var, .. }| (name.clone(), type_var.clone()))
.filter_map(|generic| self.find_generic(&generic.ident().0.contents).cloned())
.collect();

let statements = std::mem::take(&mut func.def.body.statements);
Expand Down
37 changes: 35 additions & 2 deletions compiler/noirc_frontend/src/elaborator/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use noirc_errors::{Location, Span};
use rustc_hash::FxHashSet as HashSet;

use crate::{
ast::ERROR_IDENT,
ast::{UnresolvedType, ERROR_IDENT},
hir::{
comptime::Interpreter,
def_collector::dc_crate::CompilationError,
Expand Down Expand Up @@ -401,12 +401,45 @@ impl<'context> Elaborator<'context> {
pub(super) fn elaborate_variable(
&mut self,
variable: Path,
generics: Option<Vec<Type>>,
unresolved_turbofish: Option<Vec<UnresolvedType>>,
) -> (ExprId, Type) {
let span = variable.span;
let expr = self.resolve_variable(variable);
let definition_id = expr.id;

let definition = self.interner.try_definition(definition_id);

// Resolve any generics if we the variable we have resolved is a function
// and if the turbofish operator was used.
let generics = if let Some(definition) = definition {
match &definition.kind {
DefinitionKind::Function(function) => {
// Resolve generics using the expected kinds of the function we are calling
let direct_generics =
self.interner.function_meta(function).direct_generics.clone();

unresolved_turbofish.map(|option_inner| {
if option_inner.len() != direct_generics.len() {
let type_check_err = TypeCheckError::IncorrectTurbofishGenericCount {
expected_count: direct_generics.len(),
actual_count: option_inner.len(),
span,
};
self.push_err(type_check_err);
}

let generics_with_types = direct_generics.iter().zip(option_inner);
vecmap(generics_with_types, |(generic, unresolved_type)| {
self.resolve_type_inner(unresolved_type, &generic.kind)
})
})
}
_ => None,
}
} else {
None
};

let id = self.interner.push_expr(HirExpression::Ident(expr.clone(), generics.clone()));

self.interner.push_expr_location(id, span, self.file);
Expand Down
3 changes: 1 addition & 2 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1096,8 +1096,7 @@ impl<'a> Resolver<'a> {

let direct_generics = func.def.generics.iter();
let direct_generics = direct_generics
.filter_map(|generic| self.find_generic(&generic.ident().0.contents))
.map(|ResolvedGeneric { name, type_var, .. }| (name.clone(), type_var.clone()))
.filter_map(|generic| self.find_generic(&generic.ident().0.contents).cloned())
.collect();

FuncMeta {
Expand Down
8 changes: 5 additions & 3 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
traits::TraitConstraint,
},
node_interner::{ExprId, FuncId, GlobalId, NodeInterner},
Kind, Type, TypeBindings,
Kind, ResolvedGeneric, Type, TypeBindings,
};

pub use self::errors::Source;
Expand Down Expand Up @@ -281,8 +281,10 @@ pub(crate) fn check_trait_impl_method_matches_declaration(
}

// Substitute each generic on the trait function with the corresponding generic on the impl function
for ((_, trait_fn_generic), (name, impl_fn_generic)) in
trait_fn_meta.direct_generics.iter().zip(&meta.direct_generics)
for (
ResolvedGeneric { type_var: trait_fn_generic, .. },
ResolvedGeneric { name, type_var: impl_fn_generic, .. },
) in trait_fn_meta.direct_generics.iter().zip(&meta.direct_generics)
{
let arg = Type::NamedGeneric(impl_fn_generic.clone(), name.clone(), Kind::Normal);
bindings.insert(trait_fn_generic.id(), (trait_fn_generic.clone(), arg));
Expand Down
6 changes: 2 additions & 4 deletions compiler/noirc_frontend/src/hir_def/function.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use iter_extended::vecmap;
use noirc_errors::{Location, Span};

use std::rc::Rc;

use super::expr::{HirBlockExpression, HirExpression, HirIdent};
use super::stmt::HirPattern;
use super::traits::TraitConstraint;
use crate::ast::{FunctionKind, FunctionReturnType, Visibility};
use crate::graph::CrateId;
use crate::macros_api::BlockExpression;
use crate::node_interner::{ExprId, NodeInterner, TraitImplId};
use crate::{ResolvedGeneric, Type, TypeVariable};
use crate::{ResolvedGeneric, Type};

/// A Hir function is a block expression
/// with a list of statements
Expand Down Expand Up @@ -113,7 +111,7 @@ pub struct FuncMeta {
/// This does not include generics from an outer scope, like those introduced by
/// an `impl<T>` block. This also does not include implicit generics added by the compiler
/// such as a trait's `Self` type variable.
pub direct_generics: Vec<(Rc<String>, TypeVariable)>,
pub direct_generics: Vec<ResolvedGeneric>,

/// All the generics used by this function, which includes any implicit generics or generics
/// from outer scopes, such as those introduced by an impl.
Expand Down
127 changes: 127 additions & 0 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,79 @@ fn specify_method_types_with_turbofish() {
assert_eq!(errors.len(), 0);
}

#[test]
fn incorrect_turbofish_count_function_call() {
let src = r#"
trait Default {
fn default() -> Self;
}

impl Default for Field {
fn default() -> Self { 0 }
}

impl Default for u64 {
fn default() -> Self { 0 }
}

// Need the above as we don't have access to the stdlib here.
// We also need to construct a concrete value of `U` without giving away its type
// as otherwise the unspecified type is ignored.

fn generic_func<T, U>() -> (T, U) where T: Default, U: Default {
(T::default(), U::default())
}

fn main() {
let _ = generic_func::<u64, Field, Field>();
}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);
assert!(matches!(
errors[0].0,
CompilationError::TypeError(TypeCheckError::IncorrectTurbofishGenericCount { .. }),
));
}

#[test]
fn incorrect_turbofish_count_method_call() {
let src = r#"
trait Default {
fn default() -> Self;
}

impl Default for Field {
fn default() -> Self { 0 }
}

// Need the above as we don't have access to the stdlib here.
// We also need to construct a concrete value of `U` without giving away its type
// as otherwise the unspecified type is ignored.

struct Foo<T> {
inner: T
}

impl<T> Foo<T> {
fn generic_method<U>(_self: Self) -> U where U: Default {
U::default()
}
}

fn main() {
let foo: Foo<Field> = Foo { inner: 1 };
let _ = foo.generic_method::<Field, u32>();
}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);
assert!(matches!(
errors[0].0,
CompilationError::TypeError(TypeCheckError::IncorrectTurbofishGenericCount { .. }),
));
}

#[test]
fn struct_numeric_generic_in_function() {
let src = r#"
Expand Down Expand Up @@ -1983,3 +2056,57 @@ fn underflowing_i8() {
panic!("Expected OverflowingAssignment error, got {:?}", errors[0].0);
}
}

#[test]
fn turbofish_numeric_generic_nested_call() {
// Check for turbofish numeric generics used with function calls
let src = r#"
fn foo<let N: u32>() -> [u8; N] {
[0; N]
}

fn bar<let N: u32>() -> [u8; N] {
foo::<N>()
}

global M: u32 = 3;

fn main() {
let _ = bar::<M>();
}
"#;
let errors = get_program_errors(src);
assert!(errors.is_empty());

// Check for turbofish numeric generics used with method calls
let src = r#"
struct Foo<T> {
a: T
}

impl<T> Foo<T> {
fn static_method<let N: u32>() -> [u8; N] {
[0; N]
}

fn impl_method<let N: u32>(self) -> [T; N] {
[self.a; N]
}
}

fn bar<let N: u32>() -> [u8; N] {
let _ = Foo::static_method::<N>();
let x: Foo<u8> = Foo { a: 0 };
x.impl_method::<N>()
}

global M: u32 = 3;

fn main() {
let _ = bar::<M>();
}
"#;
let errors = get_program_errors(src);
dbg!(errors.clone());
assert!(errors.is_empty());
}
Loading