diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 254d8f94a6f..de8f54f7b39 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -58,7 +58,7 @@ impl<'context> Elaborator<'context> { } ExpressionKind::Cast(cast) => self.elaborate_cast(*cast, expr.span), ExpressionKind::Infix(infix) => return self.elaborate_infix(*infix, expr.span), - ExpressionKind::If(if_) => self.elaborate_if(*if_), + ExpressionKind::If(if_) => self.elaborate_if(*if_, target_type), ExpressionKind::Match(match_) => self.elaborate_match(*match_), ExpressionKind::Variable(variable) => return self.elaborate_variable(variable), ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple, target_type), @@ -911,10 +911,15 @@ impl<'context> Elaborator<'context> { } } - fn elaborate_if(&mut self, if_expr: IfExpression) -> (HirExpression, Type) { + fn elaborate_if( + &mut self, + if_expr: IfExpression, + target_type: Option<&Type>, + ) -> (HirExpression, Type) { let expr_span = if_expr.condition.span; let (condition, cond_type) = self.elaborate_expression(if_expr.condition); - let (consequence, mut ret_type) = self.elaborate_expression(if_expr.consequence); + let (consequence, mut ret_type) = + self.elaborate_expression_with_target_type(if_expr.consequence, target_type); self.unify(&cond_type, &Type::Bool, || TypeCheckError::TypeMismatch { expected_typ: Type::Bool.to_string(), @@ -924,7 +929,8 @@ impl<'context> Elaborator<'context> { let alternative = if_expr.alternative.map(|alternative| { let expr_span = alternative.span; - let (else_, else_type) = self.elaborate_expression(alternative); + let (else_, else_type) = + self.elaborate_expression_with_target_type(alternative, target_type); self.unify(&ret_type, &else_type, || { let err = TypeCheckError::TypeMismatch { diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index e7147f582e9..ed6321dbe50 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -4108,6 +4108,27 @@ fn infers_lambda_argument_from_function_return_type_multiple_statements() { assert_no_errors(src); } +#[test] +fn infers_lambda_argument_from_function_return_type_when_inside_if() { + let src = r#" + pub struct Foo { + value: Field, + } + + pub fn func() -> fn(Foo) -> Field { + if true { + |foo| foo.value + } else { + |foo| foo.value + } + } + + fn main() { + } + "#; + assert_no_errors(src); +} + #[test] fn infers_lambda_argument_from_variable_type() { let src = r#"