diff --git a/compiler/noirc_frontend/src/elaborator/lints.rs b/compiler/noirc_frontend/src/elaborator/lints.rs index c0a18d219b7..57db2359772 100644 --- a/compiler/noirc_frontend/src/elaborator/lints.rs +++ b/compiler/noirc_frontend/src/elaborator/lints.rs @@ -6,11 +6,13 @@ use crate::{ type_check::TypeCheckError, }, hir_def::{ - expr::{HirExpression, HirIdent, HirLiteral}, + expr::{HirBlockExpression, HirExpression, HirIdent, HirLiteral}, function::FuncMeta, + stmt::HirStatement, + }, + node_interner::{ + DefinitionId, DefinitionKind, ExprId, FuncId, FunctionModifiers, NodeInterner, }, - node_interner::NodeInterner, - node_interner::{DefinitionKind, ExprId, FuncId, FunctionModifiers}, Type, }; @@ -264,3 +266,77 @@ pub(crate) fn overflowing_int( fn func_meta_name_ident(func: &FuncMeta, modifiers: &FunctionModifiers) -> Ident { Ident(Spanned::from(func.name.location.span, modifiers.name.clone())) } + +/// Check that a recursive function *can* return without endlessly calling itself. +pub(crate) fn unbounded_recursion<'a>( + interner: &'a NodeInterner, + func_id: FuncId, + func_name: impl FnOnce() -> &'a str, + func_span: Span, + body_id: ExprId, +) -> Option { + if !can_return_without_recursing(interner, func_id, body_id) { + Some(ResolverError::UnconditionalRecursion { + name: func_name().to_string(), + span: func_span, + }) + } else { + None + } +} + +/// Check if an expression will end up calling a specific function. +fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_id: ExprId) -> bool { + let check = |e| can_return_without_recursing(interner, func_id, e); + + let check_block = |block: HirBlockExpression| { + block.statements.iter().all(|stmt_id| match interner.statement(stmt_id) { + HirStatement::Let(s) => check(s.expression), + HirStatement::Assign(s) => check(s.expression), + HirStatement::Expression(e) => check(e), + HirStatement::Semi(e) => check(e), + // Rust doesn't seem to check the for loop body (it's bounds might mean it's never called). + HirStatement::For(e) => check(e.start_range) && check(e.end_range), + HirStatement::Constrain(_) + | HirStatement::Comptime(_) + | HirStatement::Break + | HirStatement::Continue + | HirStatement::Error => true, + }) + }; + + match interner.expression(&expr_id) { + HirExpression::Ident(ident, _) => { + if ident.id == DefinitionId::dummy_id() { + return true; + } + let definition = interner.definition(ident.id); + if let DefinitionKind::Function(id) = definition.kind { + func_id != id + } else { + true + } + } + HirExpression::Block(b) => check_block(b), + HirExpression::Prefix(e) => check(e.rhs), + HirExpression::Infix(e) => check(e.lhs) && check(e.rhs), + HirExpression::Index(e) => check(e.collection) && check(e.index), + HirExpression::MemberAccess(e) => check(e.lhs), + HirExpression::Call(e) => check(e.func) && e.arguments.iter().cloned().all(check), + HirExpression::MethodCall(e) => check(e.object) && e.arguments.iter().cloned().all(check), + HirExpression::Cast(e) => check(e.lhs), + HirExpression::If(e) => { + check(e.condition) && (check(e.consequence) || e.alternative.map(check).unwrap_or(true)) + } + HirExpression::Tuple(e) => e.iter().cloned().all(check), + HirExpression::Unsafe(b) => check_block(b), + // Rust doesn't check the lambda body (it might not be called). + HirExpression::Lambda(_) + | HirExpression::Literal(_) + | HirExpression::Constructor(_) + | HirExpression::Quote(_) + | HirExpression::Unquote(_) + | HirExpression::Comptime(_) + | HirExpression::Error => true, + } +} diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 971054e0334..bd385dc455f 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -4,14 +4,7 @@ use std::{ }; use crate::{ - ast::ItemVisibility, - hir_def::{ - expr::{HirBlockExpression, HirExpression}, - stmt::HirStatement, - traits::ResolvedTraitBound, - }, - node_interner::DefinitionId, - StructField, StructType, TypeBindings, + ast::ItemVisibility, hir_def::traits::ResolvedTraitBound, StructField, StructType, TypeBindings, }; use crate::{ ast::{ @@ -479,12 +472,16 @@ impl<'context> Elaborator<'context> { // Check that the body can return without calling the function. if let FunctionKind::Normal | FunctionKind::Recursive = kind { - self.check_for_unbounded_recursion( - id, - self.interner.definition_name(func_meta.name.id).to_string(), - func_meta.name.location.span, - hir_func.as_expr(), - ); + self.run_lint(|e| { + lints::unbounded_recursion( + &e.interner, + id, + || e.interner.definition_name(func_meta.name.id), + func_meta.name.location.span, + hir_func.as_expr(), + ) + .map(Into::into) + }); } let meta = self @@ -1709,79 +1706,4 @@ impl<'context> Elaborator<'context> { _ => true, }) } - - /// Check that a recursive function *can* return without endlessly calling itself. - fn check_for_unbounded_recursion( - &mut self, - func_id: FuncId, - func_name: String, - func_span: Span, - body_id: ExprId, - ) { - if !self.can_return_without_recursing(func_id, body_id) { - self.push_err(CompilationError::ResolverError(ResolverError::UnconditionalRecursion { - name: func_name, - span: func_span, - })); - } - } - - /// Check if an expression will end up calling a specific function. - fn can_return_without_recursing(&self, func_id: FuncId, expr_id: ExprId) -> bool { - let check = |e| self.can_return_without_recursing(func_id, e); - - let check_block = |block: HirBlockExpression| { - block.statements.iter().all(|stmt_id| match self.interner.statement(stmt_id) { - HirStatement::Let(s) => check(s.expression), - HirStatement::Assign(s) => check(s.expression), - HirStatement::Expression(e) => check(e), - HirStatement::Semi(e) => check(e), - // Rust doesn't seem to check the for loop body (it's bounds might mean it's never called). - HirStatement::For(e) => check(e.start_range) && check(e.end_range), - HirStatement::Constrain(_) - | HirStatement::Comptime(_) - | HirStatement::Break - | HirStatement::Continue - | HirStatement::Error => true, - }) - }; - - match self.interner.expression(&expr_id) { - HirExpression::Ident(ident, _) => { - if ident.id == DefinitionId::dummy_id() { - return true; - } - let definition = self.interner.definition(ident.id); - if let DefinitionKind::Function(id) = definition.kind { - func_id != id - } else { - true - } - } - HirExpression::Block(b) => check_block(b), - HirExpression::Prefix(e) => check(e.rhs), - HirExpression::Infix(e) => check(e.lhs) && check(e.rhs), - HirExpression::Index(e) => check(e.collection) && check(e.index), - HirExpression::MemberAccess(e) => check(e.lhs), - HirExpression::Call(e) => check(e.func) && e.arguments.iter().cloned().all(check), - HirExpression::MethodCall(e) => { - check(e.object) && e.arguments.iter().cloned().all(check) - } - HirExpression::Cast(e) => check(e.lhs), - HirExpression::If(e) => { - check(e.condition) - && (check(e.consequence) || e.alternative.map(check).unwrap_or(true)) - } - HirExpression::Tuple(e) => e.iter().cloned().all(check), - HirExpression::Unsafe(b) => check_block(b), - // Rust doesn't check the lambda body (it might not be called). - HirExpression::Lambda(_) - | HirExpression::Literal(_) - | HirExpression::Constructor(_) - | HirExpression::Quote(_) - | HirExpression::Unquote(_) - | HirExpression::Comptime(_) - | HirExpression::Error => true, - } - } }