diff --git a/compiler/noirc_frontend/src/elaborator/lints.rs b/compiler/noirc_frontend/src/elaborator/lints.rs index c0a18d219b7..57db2359772 100644 --- a/compiler/noirc_frontend/src/elaborator/lints.rs +++ b/compiler/noirc_frontend/src/elaborator/lints.rs @@ -6,11 +6,13 @@ use crate::{ type_check::TypeCheckError, }, hir_def::{ - expr::{HirExpression, HirIdent, HirLiteral}, + expr::{HirBlockExpression, HirExpression, HirIdent, HirLiteral}, function::FuncMeta, + stmt::HirStatement, + }, + node_interner::{ + DefinitionId, DefinitionKind, ExprId, FuncId, FunctionModifiers, NodeInterner, }, - node_interner::NodeInterner, - node_interner::{DefinitionKind, ExprId, FuncId, FunctionModifiers}, Type, }; @@ -264,3 +266,77 @@ pub(crate) fn overflowing_int( fn func_meta_name_ident(func: &FuncMeta, modifiers: &FunctionModifiers) -> Ident { Ident(Spanned::from(func.name.location.span, modifiers.name.clone())) } + +/// Check that a recursive function *can* return without endlessly calling itself. +pub(crate) fn unbounded_recursion<'a>( + interner: &'a NodeInterner, + func_id: FuncId, + func_name: impl FnOnce() -> &'a str, + func_span: Span, + body_id: ExprId, +) -> Option { + if !can_return_without_recursing(interner, func_id, body_id) { + Some(ResolverError::UnconditionalRecursion { + name: func_name().to_string(), + span: func_span, + }) + } else { + None + } +} + +/// Check if an expression will end up calling a specific function. +fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_id: ExprId) -> bool { + let check = |e| can_return_without_recursing(interner, func_id, e); + + let check_block = |block: HirBlockExpression| { + block.statements.iter().all(|stmt_id| match interner.statement(stmt_id) { + HirStatement::Let(s) => check(s.expression), + HirStatement::Assign(s) => check(s.expression), + HirStatement::Expression(e) => check(e), + HirStatement::Semi(e) => check(e), + // Rust doesn't seem to check the for loop body (it's bounds might mean it's never called). + HirStatement::For(e) => check(e.start_range) && check(e.end_range), + HirStatement::Constrain(_) + | HirStatement::Comptime(_) + | HirStatement::Break + | HirStatement::Continue + | HirStatement::Error => true, + }) + }; + + match interner.expression(&expr_id) { + HirExpression::Ident(ident, _) => { + if ident.id == DefinitionId::dummy_id() { + return true; + } + let definition = interner.definition(ident.id); + if let DefinitionKind::Function(id) = definition.kind { + func_id != id + } else { + true + } + } + HirExpression::Block(b) => check_block(b), + HirExpression::Prefix(e) => check(e.rhs), + HirExpression::Infix(e) => check(e.lhs) && check(e.rhs), + HirExpression::Index(e) => check(e.collection) && check(e.index), + HirExpression::MemberAccess(e) => check(e.lhs), + HirExpression::Call(e) => check(e.func) && e.arguments.iter().cloned().all(check), + HirExpression::MethodCall(e) => check(e.object) && e.arguments.iter().cloned().all(check), + HirExpression::Cast(e) => check(e.lhs), + HirExpression::If(e) => { + check(e.condition) && (check(e.consequence) || e.alternative.map(check).unwrap_or(true)) + } + HirExpression::Tuple(e) => e.iter().cloned().all(check), + HirExpression::Unsafe(b) => check_block(b), + // Rust doesn't check the lambda body (it might not be called). + HirExpression::Lambda(_) + | HirExpression::Literal(_) + | HirExpression::Constructor(_) + | HirExpression::Quote(_) + | HirExpression::Unquote(_) + | HirExpression::Comptime(_) + | HirExpression::Error => true, + } +} diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index aef0771c486..641e2d1d57e 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -470,6 +470,20 @@ impl<'context> Elaborator<'context> { self.check_for_unused_variables_in_scope_tree(func_scope_tree); } + // Check that the body can return without calling the function. + if let FunctionKind::Normal | FunctionKind::Recursive = kind { + self.run_lint(|elaborator| { + lints::unbounded_recursion( + elaborator.interner, + id, + || elaborator.interner.definition_name(func_meta.name.id), + func_meta.name.location.span, + hir_func.as_expr(), + ) + .map(Into::into) + }); + } + let meta = self .interner .func_meta diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index 4f9907d6a16..3c4022b58bb 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -29,6 +29,8 @@ pub enum ResolverError { UnusedVariable { ident: Ident }, #[error("Unused {}", item.item_type())] UnusedItem { ident: Ident, item: UnusedItem }, + #[error("Unconditional recursion")] + UnconditionalRecursion { name: String, span: Span }, #[error("Could not find variable in this scope")] VariableNotDeclared { name: String, span: Span }, #[error("path is not an identifier")] @@ -213,6 +215,15 @@ impl<'a> From<&'a ResolverError> for Diagnostic { diagnostic.unnecessary = true; diagnostic } + ResolverError::UnconditionalRecursion { name, span} => { + let mut diagnostic = Diagnostic::simple_warning( + format!("function `{name}` cannot return without recursing"), + "function cannot return without recursing".to_string(), + *span, + ); + diagnostic.unnecessary = true; + diagnostic + } ResolverError::VariableNotDeclared { name, span } => Diagnostic::simple_error( format!("cannot find `{name}` in this scope "), "not found in this scope".to_string(), diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 86a486a0de5..ce4ad4d1bb9 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -122,7 +122,7 @@ pub(crate) fn get_program_errors(src: &str) -> Vec<(CompilationError, FileId)> { fn assert_no_errors(src: &str) { let errors = get_program_errors(src); if !errors.is_empty() { - panic!("Expected no errors, got: {:?}", errors); + panic!("Expected no errors, got: {:?}; src = {src}", errors); } } @@ -3390,6 +3390,136 @@ fn arithmetic_generics_rounding_fail() { assert_eq!(errors.len(), 1); } +#[test] +fn unconditional_recursion_fail() { + let srcs = vec![ + r#" + fn main() { + main() + } + "#, + r#" + fn main() -> pub bool { + if main() { true } else { false } + } + "#, + r#" + fn main() -> pub bool { + if true { main() } else { main() } + } + "#, + r#" + fn main() -> pub u64 { + main() + main() + } + "#, + r#" + fn main() -> pub u64 { + 1 + main() + } + "#, + r#" + fn main() -> pub bool { + let _ = main(); + true + } + "#, + r#" + fn main(a: u64, b: u64) -> pub u64 { + main(a + b, main(a, b)) + } + "#, + r#" + fn main() -> pub u64 { + foo(1, main()) + } + fn foo(a: u64, b: u64) -> u64 { + a + b + } + "#, + r#" + fn main() -> pub u64 { + let (a, b) = (main(), main()); + a + b + } + "#, + r#" + fn main() -> pub u64 { + let mut sum = 0; + for i in 0 .. main() { + sum += i; + } + sum + } + "#, + ]; + + for src in srcs { + let errors = get_program_errors(src); + assert!( + !errors.is_empty(), + "expected 'unconditional recursion' error, got nothing; src = {src}" + ); + + for (error, _) in errors { + let CompilationError::ResolverError(ResolverError::UnconditionalRecursion { .. }) = + error + else { + panic!("Expected an 'unconditional recursion' error, got {:?}; src = {src}", error); + }; + } + } +} + +#[test] +fn unconditional_recursion_pass() { + let srcs = vec![ + r#" + fn main() { + if false { main(); } + } + "#, + r#" + fn main(i: u64) -> pub u64 { + if i == 0 { 0 } else { i + main(i-1) } + } + "#, + // Only immediate self-recursion is detected. + r#" + fn main() { + foo(); + } + fn foo() { + bar(); + } + fn bar() { + foo(); + } + "#, + // For loop bodies are not checked. + r#" + fn main() -> pub u64 { + let mut sum = 0; + for _ in 0 .. 10 { + sum += main(); + } + sum + } + "#, + // Lambda bodies are not checked. + r#" + fn main() { + let foo = || main(); + foo(); + } + "#, + ]; + + for src in srcs { + assert_no_errors(src); + } +} + #[test] fn uses_self_in_import() { let src = r#" diff --git a/noir_stdlib/src/meta/mod.nr b/noir_stdlib/src/meta/mod.nr index ff662b878ec..5102f0cf1fd 100644 --- a/noir_stdlib/src/meta/mod.nr +++ b/noir_stdlib/src/meta/mod.nr @@ -259,6 +259,8 @@ mod tests { let _: MyOtherStruct = MyOtherStruct { my_other_field: 2 }; let _ = derive_do_nothing(crate::panic::panic(f"")); let _ = derive_do_nothing_alt(crate::panic::panic(f"")); - remove_unused_warnings(); + if false { + remove_unused_warnings(); + } } }