Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: infer lambda parameter types from return type and let type #7267

Merged
merged 11 commits into from
Feb 4, 2025
105 changes: 81 additions & 24 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,17 @@ use super::{Elaborator, LambdaContext, UnsafeBlockStatus};

impl<'context> Elaborator<'context> {
pub(crate) fn elaborate_expression(&mut self, expr: Expression) -> (ExprId, Type) {
self.elaborate_expression_with_target_type(expr, None)
}

pub(crate) fn elaborate_expression_with_target_type(
&mut self,
expr: Expression,
target_type: Option<&Type>,
) -> (ExprId, Type) {
let (hir_expr, typ) = match expr.kind {
ExpressionKind::Literal(literal) => self.elaborate_literal(literal, expr.span),
ExpressionKind::Block(block) => self.elaborate_block(block),
ExpressionKind::Block(block) => self.elaborate_block(block, target_type),
ExpressionKind::Prefix(prefix) => return self.elaborate_prefix(*prefix, expr.span),
ExpressionKind::Index(index) => self.elaborate_index(*index),
ExpressionKind::Call(call) => self.elaborate_call(*call, expr.span),
Expand All @@ -50,18 +58,22 @@ impl<'context> Elaborator<'context> {
}
ExpressionKind::Cast(cast) => self.elaborate_cast(*cast, expr.span),
ExpressionKind::Infix(infix) => return self.elaborate_infix(*infix, expr.span),
ExpressionKind::If(if_) => self.elaborate_if(*if_),
ExpressionKind::If(if_) => self.elaborate_if(*if_, target_type),
ExpressionKind::Match(match_) => self.elaborate_match(*match_),
ExpressionKind::Variable(variable) => return self.elaborate_variable(variable),
ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple),
ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda, None),
ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr),
ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple, target_type),
ExpressionKind::Lambda(lambda) => {
self.elaborate_lambda_with_target_type(*lambda, target_type)
}
ExpressionKind::Parenthesized(expr) => {
return self.elaborate_expression_with_target_type(*expr, target_type)
}
ExpressionKind::Quote(quote) => self.elaborate_quote(quote, expr.span),
ExpressionKind::Comptime(comptime, _) => {
return self.elaborate_comptime_block(comptime, expr.span)
return self.elaborate_comptime_block(comptime, expr.span, target_type)
}
ExpressionKind::Unsafe(block_expression, span) => {
self.elaborate_unsafe_block(block_expression, span)
self.elaborate_unsafe_block(block_expression, span, target_type)
}
ExpressionKind::Resolved(id) => return (id, self.interner.id_type(id)),
ExpressionKind::Interned(id) => {
Expand Down Expand Up @@ -112,18 +124,29 @@ impl<'context> Elaborator<'context> {
}
}

pub(super) fn elaborate_block(&mut self, block: BlockExpression) -> (HirExpression, Type) {
let (block, typ) = self.elaborate_block_expression(block);
pub(super) fn elaborate_block(
&mut self,
block: BlockExpression,
target_type: Option<&Type>,
) -> (HirExpression, Type) {
let (block, typ) = self.elaborate_block_expression(block, target_type);
(HirExpression::Block(block), typ)
}

fn elaborate_block_expression(&mut self, block: BlockExpression) -> (HirBlockExpression, Type) {
fn elaborate_block_expression(
&mut self,
block: BlockExpression,
target_type: Option<&Type>,
) -> (HirBlockExpression, Type) {
self.push_scope();
let mut block_type = Type::Unit;
let mut statements = Vec::with_capacity(block.statements.len());
let statements_len = block.statements.len();
let mut statements = Vec::with_capacity(statements_len);

for (i, statement) in block.statements.into_iter().enumerate() {
let (id, stmt_type) = self.elaborate_statement(statement);
let statement_target_type = if i == statements_len - 1 { target_type } else { None };
let (id, stmt_type) =
self.elaborate_statement_with_target_type(statement, statement_target_type);
statements.push(id);

if let HirStatement::Semi(expr) = self.interner.statement(&id) {
Expand All @@ -149,6 +172,7 @@ impl<'context> Elaborator<'context> {
&mut self,
block: BlockExpression,
span: Span,
target_type: Option<&Type>,
) -> (HirExpression, Type) {
// Before entering the block we cache the old value of `in_unsafe_block` so it can be restored.
let old_in_unsafe_block = self.unsafe_block_status;
Expand All @@ -161,7 +185,7 @@ impl<'context> Elaborator<'context> {

self.unsafe_block_status = UnsafeBlockStatus::InUnsafeBlockWithoutUnconstrainedCalls;

let (hir_block_expression, typ) = self.elaborate_block_expression(block);
let (hir_block_expression, typ) = self.elaborate_block_expression(block, target_type);

if let UnsafeBlockStatus::InUnsafeBlockWithoutUnconstrainedCalls = self.unsafe_block_status
{
Expand Down Expand Up @@ -572,7 +596,7 @@ impl<'context> Elaborator<'context> {
let span = arg.span;
let type_hint =
if let Some(Type::Function(func_args, _, _, _)) = typ { Some(func_args) } else { None };
let (hir_expr, typ) = self.elaborate_lambda(*lambda, type_hint);
let (hir_expr, typ) = self.elaborate_lambda_with_parameter_type_hints(*lambda, type_hint);
let id = self.interner.push_expr(hir_expr);
self.interner.push_expr_location(id, span, self.file);
self.interner.push_expr_type(id, typ.clone());
Expand Down Expand Up @@ -884,10 +908,15 @@ impl<'context> Elaborator<'context> {
}
}

fn elaborate_if(&mut self, if_expr: IfExpression) -> (HirExpression, Type) {
fn elaborate_if(
&mut self,
if_expr: IfExpression,
target_type: Option<&Type>,
) -> (HirExpression, Type) {
let expr_span = if_expr.condition.span;
let (condition, cond_type) = self.elaborate_expression(if_expr.condition);
let (consequence, mut ret_type) = self.elaborate_expression(if_expr.consequence);
let (consequence, mut ret_type) =
self.elaborate_expression_with_target_type(if_expr.consequence, target_type);

self.unify(&cond_type, &Type::Bool, || TypeCheckError::TypeMismatch {
expected_typ: Type::Bool.to_string(),
Expand All @@ -897,7 +926,8 @@ impl<'context> Elaborator<'context> {

let alternative = if_expr.alternative.map(|alternative| {
let expr_span = alternative.span;
let (else_, else_type) = self.elaborate_expression(alternative);
let (else_, else_type) =
self.elaborate_expression_with_target_type(alternative, target_type);

self.unify(&ret_type, &else_type, || {
let err = TypeCheckError::TypeMismatch {
Expand Down Expand Up @@ -931,23 +961,44 @@ impl<'context> Elaborator<'context> {
(HirExpression::Error, Type::Error)
}

fn elaborate_tuple(&mut self, tuple: Vec<Expression>) -> (HirExpression, Type) {
fn elaborate_tuple(
&mut self,
tuple: Vec<Expression>,
target_type: Option<&Type>,
) -> (HirExpression, Type) {
let mut element_ids = Vec::with_capacity(tuple.len());
let mut element_types = Vec::with_capacity(tuple.len());

for element in tuple {
let (id, typ) = self.elaborate_expression(element);
for (index, element) in tuple.into_iter().enumerate() {
let target_type = target_type.map(|typ| typ.follow_bindings());
let expr_target_type =
if let Some(Type::Tuple(types)) = &target_type { types.get(index) } else { None };
let (id, typ) = self.elaborate_expression_with_target_type(element, expr_target_type);
element_ids.push(id);
element_types.push(typ);
}

(HirExpression::Tuple(element_ids), Type::Tuple(element_types))
}

fn elaborate_lambda_with_target_type(
&mut self,
lambda: Lambda,
target_type: Option<&Type>,
) -> (HirExpression, Type) {
let target_type = target_type.map(|typ| typ.follow_bindings());

if let Some(Type::Function(args, _, _, _)) = target_type {
return self.elaborate_lambda_with_parameter_type_hints(lambda, Some(&args));
}

self.elaborate_lambda_with_parameter_type_hints(lambda, None)
}

/// For elaborating a lambda we might get `parameters_type_hints`. These come from a potential
/// call that has this lambda as the argument.
/// The parameter type hints will be the types of the function type corresponding to the lambda argument.
fn elaborate_lambda(
fn elaborate_lambda_with_parameter_type_hints(
&mut self,
lambda: Lambda,
parameters_type_hints: Option<&Vec<Type>>,
Expand Down Expand Up @@ -1013,9 +1064,15 @@ impl<'context> Elaborator<'context> {
}
}

fn elaborate_comptime_block(&mut self, block: BlockExpression, span: Span) -> (ExprId, Type) {
let (block, _typ) =
self.elaborate_in_comptime_context(|this| this.elaborate_block_expression(block));
fn elaborate_comptime_block(
&mut self,
block: BlockExpression,
span: Span,
target_type: Option<&Type>,
) -> (ExprId, Type) {
let (block, _typ) = self.elaborate_in_comptime_context(|this| {
this.elaborate_block_expression(block, target_type)
});

let mut interpreter = self.setup_interpreter();
let value = interpreter.evaluate_block(block);
Expand Down
3 changes: 2 additions & 1 deletion compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,8 @@ impl<'context> Elaborator<'context> {
| FunctionKind::Oracle
| FunctionKind::TraitFunctionWithoutBody => (HirFunction::empty(), Type::Error),
FunctionKind::Normal => {
let (block, body_type) = self.elaborate_block(body);
let return_type = func_meta.return_type();
let (block, body_type) = self.elaborate_block(body, Some(return_type));
let expr_id = self.intern_expr(block, body_span);
self.interner.push_expr_type(expr_id, body_type.clone());
(HirFunction::unchecked_from_expr(expr_id), body_type)
Expand Down
30 changes: 24 additions & 6 deletions compiler/noirc_frontend/src/elaborator/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ use super::{lints, Elaborator, Loop};

impl<'context> Elaborator<'context> {
fn elaborate_statement_value(&mut self, statement: Statement) -> (HirStatement, Type) {
self.elaborate_statement_value_with_target_type(statement, None)
}

fn elaborate_statement_value_with_target_type(
&mut self,
statement: Statement,
target_type: Option<&Type>,
) -> (HirStatement, Type) {
match statement.kind {
StatementKind::Let(let_stmt) => self.elaborate_local_let(let_stmt),
StatementKind::Constrain(constrain) => self.elaborate_constrain(constrain),
Expand All @@ -38,7 +46,7 @@ impl<'context> Elaborator<'context> {
StatementKind::Continue => self.elaborate_jump(false, statement.span),
StatementKind::Comptime(statement) => self.elaborate_comptime_statement(*statement),
StatementKind::Expression(expr) => {
let (expr, typ) = self.elaborate_expression(expr);
let (expr, typ) = self.elaborate_expression_with_target_type(expr, target_type);
(HirStatement::Expression(expr), typ)
}
StatementKind::Semi(expr) => {
Expand All @@ -48,15 +56,24 @@ impl<'context> Elaborator<'context> {
StatementKind::Interned(id) => {
let kind = self.interner.get_statement_kind(id);
let statement = Statement { kind: kind.clone(), span: statement.span };
self.elaborate_statement_value(statement)
self.elaborate_statement_value_with_target_type(statement, target_type)
}
StatementKind::Error => (HirStatement::Error, Type::Error),
}
}

pub(crate) fn elaborate_statement(&mut self, statement: Statement) -> (StmtId, Type) {
self.elaborate_statement_with_target_type(statement, None)
}

pub(crate) fn elaborate_statement_with_target_type(
&mut self,
statement: Statement,
target_type: Option<&Type>,
) -> (StmtId, Type) {
let span = statement.span;
let (hir_statement, typ) = self.elaborate_statement_value(statement);
let (hir_statement, typ) =
self.elaborate_statement_value_with_target_type(statement, target_type);
let id = self.interner.push_stmt(hir_statement);
self.interner.push_stmt_location(id, span, self.file);
(id, typ)
Expand All @@ -75,12 +92,13 @@ impl<'context> Elaborator<'context> {
let_stmt: LetStatement,
global_id: Option<GlobalId>,
) -> (HirStatement, Type) {
let expr_span = let_stmt.expression.span;
let (expression, expr_type) = self.elaborate_expression(let_stmt.expression);

let type_contains_unspecified = let_stmt.r#type.contains_unspecified();
let annotated_type = self.resolve_inferred_type(let_stmt.r#type);

let expr_span = let_stmt.expression.span;
let (expression, expr_type) =
self.elaborate_expression_with_target_type(let_stmt.expression, Some(&annotated_type));

// Require the top-level of a global's type to be fully-specified
if type_contains_unspecified && global_id.is_some() {
let span = expr_span;
Expand Down
Loading
Loading