From 6add19272d65b8d1cf32f37dc8702bf7d728ea8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Esteban=20K=C3=BCber?= Date: Tue, 6 Feb 2024 03:30:16 +0000 Subject: [PATCH] Properly handle `async` blocks and `fn`s in `if` exprs without `else` When encountering a tail expression in the then arm of an `if` expression without an `else` arm, account for `async fn` and `async` blocks to suggest `return`ing the value and pointing at the return type of the `async fn`. We now also account for AFIT when looking for the return type to point at. Fix #115405. --- compiler/rustc_hir_typeck/src/coercion.rs | 38 +++++-- .../rustc_hir_typeck/src/fn_ctxt/_impl.rs | 37 +++++-- .../rustc_hir_typeck/src/fn_ctxt/checks.rs | 5 +- .../src/fn_ctxt/suggestions.rs | 102 +++++++++++++----- compiler/rustc_middle/src/hir/map/mod.rs | 2 +- .../rustc_parse/src/parser/diagnostics.rs | 2 +- .../missing-return-in-async-block.fixed | 22 ++++ .../missing-return-in-async-block.rs | 22 ++++ .../missing-return-in-async-block.stderr | 35 ++++++ .../in-trait/default-body-type-err-2.stderr | 2 + .../ui/loops/dont-suggest-break-thru-item.rs | 2 + .../loops/dont-suggest-break-thru-item.stderr | 16 ++- 12 files changed, 237 insertions(+), 48 deletions(-) create mode 100644 tests/ui/async-await/missing-return-in-async-block.fixed create mode 100644 tests/ui/async-await/missing-return-in-async-block.rs create mode 100644 tests/ui/async-await/missing-return-in-async-block.stderr diff --git a/compiler/rustc_hir_typeck/src/coercion.rs b/compiler/rustc_hir_typeck/src/coercion.rs index ca636ebcade04..1493abbde765e 100644 --- a/compiler/rustc_hir_typeck/src/coercion.rs +++ b/compiler/rustc_hir_typeck/src/coercion.rs @@ -92,14 +92,16 @@ impl<'a, 'tcx> Deref for Coerce<'a, 'tcx> { type CoerceResult<'tcx> = InferResult<'tcx, (Vec>, Ty<'tcx>)>; -struct CollectRetsVisitor<'tcx> { - ret_exprs: Vec<&'tcx hir::Expr<'tcx>>, +pub struct CollectRetsVisitor<'tcx> { + pub ret_exprs: Vec<&'tcx hir::Expr<'tcx>>, } impl<'tcx> Visitor<'tcx> for CollectRetsVisitor<'tcx> { fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { - if let hir::ExprKind::Ret(_) = expr.kind { - self.ret_exprs.push(expr); + match expr.kind { + hir::ExprKind::Ret(_) => self.ret_exprs.push(expr), + hir::ExprKind::Closure(_) => return, + _ => {} } intravisit::walk_expr(self, expr); } @@ -1856,13 +1858,31 @@ impl<'tcx, 'exprs, E: AsCoercionSite> CoerceMany<'tcx, 'exprs, E> { } let parent_id = fcx.tcx.hir().get_parent_item(id); - let parent_item = fcx.tcx.hir_node_by_def_id(parent_id.def_id); + let mut parent_item = fcx.tcx.hir_node_by_def_id(parent_id.def_id); + // When suggesting return, we need to account for closures and async blocks, not just items. + for (_, node) in fcx.tcx.hir().parent_iter(id) { + match node { + hir::Node::Expr(&hir::Expr { + kind: hir::ExprKind::Closure(hir::Closure { .. }), + .. + }) => { + parent_item = node; + break; + } + hir::Node::Item(_) | hir::Node::TraitItem(_) | hir::Node::ImplItem(_) => break, + _ => {} + } + } - if let (Some(expr), Some(_), Some((fn_id, fn_decl, _, _))) = - (expression, blk_id, fcx.get_node_fn_decl(parent_item)) - { + if let (Some(expr), Some(_), Some(fn_decl)) = (expression, blk_id, parent_item.fn_decl()) { fcx.suggest_missing_break_or_return_expr( - &mut err, expr, fn_decl, expected, found, id, fn_id, + &mut err, + expr, + fn_decl, + expected, + found, + id, + parent_id.into(), ); } diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs index 60eb40bd8fe66..2ad97c34562cf 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs @@ -955,14 +955,35 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { owner_id, .. }) => Some((hir::HirId::make_owner(owner_id.def_id), &sig.decl, ident, false)), - Node::Expr(&hir::Expr { hir_id, kind: hir::ExprKind::Closure(..), .. }) - if let Some(Node::Item(&hir::Item { - ident, - kind: hir::ItemKind::Fn(ref sig, ..), - owner_id, - .. - })) = self.tcx.hir().find_parent(hir_id) => - { + Node::Expr(&hir::Expr { + hir_id, + kind: + hir::ExprKind::Closure(hir::Closure { + kind: hir::ClosureKind::Coroutine(..), .. + }), + .. + }) => { + let (ident, sig, owner_id) = match self.tcx.hir().find_parent(hir_id) { + Some(Node::Item(&hir::Item { + ident, + kind: hir::ItemKind::Fn(ref sig, ..), + owner_id, + .. + })) => (ident, sig, owner_id), + Some(Node::TraitItem(&hir::TraitItem { + ident, + kind: hir::TraitItemKind::Fn(ref sig, ..), + owner_id, + .. + })) => (ident, sig, owner_id), + Some(Node::ImplItem(&hir::ImplItem { + ident, + kind: hir::ImplItemKind::Fn(ref sig, ..), + owner_id, + .. + })) => (ident, sig, owner_id), + _ => return None, + }; Some(( hir::HirId::make_owner(owner_id.def_id), &sig.decl, diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs index d30c7a4fb3899..94dbab9526ece 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs @@ -1726,7 +1726,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { } /// Given a function block's `HirId`, returns its `FnDecl` if it exists, or `None` otherwise. - fn get_parent_fn_decl(&self, blk_id: hir::HirId) -> Option<(&'tcx hir::FnDecl<'tcx>, Ident)> { + pub(crate) fn get_parent_fn_decl( + &self, + blk_id: hir::HirId, + ) -> Option<(&'tcx hir::FnDecl<'tcx>, Ident)> { let parent = self.tcx.hir_node_by_def_id(self.tcx.hir().get_parent_item(blk_id).def_id); self.get_node_fn_decl(parent).map(|(_, fn_decl, ident, _)| (fn_decl, ident)) } diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs index 95c1139e43e44..afe5d0d488026 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs @@ -1,5 +1,6 @@ use super::FnCtxt; +use crate::coercion::CollectRetsVisitor; use crate::errors; use crate::fluent_generated as fluent; use crate::fn_ctxt::rustc_span::BytePos; @@ -16,6 +17,7 @@ use rustc_errors::{Applicability, Diagnostic, MultiSpan}; use rustc_hir as hir; use rustc_hir::def::Res; use rustc_hir::def::{CtorKind, CtorOf, DefKind}; +use rustc_hir::intravisit::{Map, Visitor}; use rustc_hir::lang_items::LangItem; use rustc_hir::{ CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, ExprKind, GenericBound, HirId, Node, @@ -826,6 +828,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { } hir::FnRetTy::Return(hir_ty) => { if let hir::TyKind::OpaqueDef(item_id, ..) = hir_ty.kind + // FIXME: account for RPITIT. && let hir::Node::Item(hir::Item { kind: hir::ItemKind::OpaqueTy(op_ty), .. }) = self.tcx.hir_node(item_id.hir_id()) @@ -1037,33 +1040,82 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { return; } - if let hir::FnRetTy::Return(ty) = fn_decl.output { - let ty = self.astconv().ast_ty_to_ty(ty); - let bound_vars = self.tcx.late_bound_vars(fn_id); - let ty = self - .tcx - .instantiate_bound_regions_with_erased(Binder::bind_with_vars(ty, bound_vars)); - let ty = match self.tcx.asyncness(fn_id.owner) { - ty::Asyncness::Yes => self.get_impl_future_output_ty(ty).unwrap_or_else(|| { - span_bug!(fn_decl.output.span(), "failed to get output type of async function") - }), - ty::Asyncness::No => ty, - }; - let ty = self.normalize(expr.span, ty); - if self.can_coerce(found, ty) { - if let Some(node) = self.tcx.opt_hir_node(fn_id) - && let Some(owner_node) = node.as_owner() - && let Some(span) = expr.span.find_ancestor_inside(owner_node.span()) + let in_closure = matches!( + self.tcx + .hir() + .parent_iter(id) + .filter(|(_, node)| { + matches!( + node, + Node::Expr(Expr { kind: ExprKind::Closure(..), .. }) + | Node::Item(_) + | Node::TraitItem(_) + | Node::ImplItem(_) + ) + }) + .next(), + Some((_, Node::Expr(Expr { kind: ExprKind::Closure(..), .. }))) + ); + + let can_return = match fn_decl.output { + hir::FnRetTy::Return(ty) => { + let ty = self.astconv().ast_ty_to_ty(ty); + let bound_vars = self.tcx.late_bound_vars(fn_id); + let ty = self + .tcx + .instantiate_bound_regions_with_erased(Binder::bind_with_vars(ty, bound_vars)); + let ty = match self.tcx.asyncness(fn_id.owner) { + ty::Asyncness::Yes => self.get_impl_future_output_ty(ty).unwrap_or_else(|| { + span_bug!( + fn_decl.output.span(), + "failed to get output type of async function" + ) + }), + ty::Asyncness::No => ty, + }; + let ty = self.normalize(expr.span, ty); + self.can_coerce(found, ty) + } + hir::FnRetTy::DefaultReturn(_) if in_closure => { + let mut rets = vec![]; + if let Some(ret_coercion) = self.ret_coercion.as_ref() { + let ret_ty = ret_coercion.borrow().expected_ty(); + rets.push(ret_ty); + } + let mut visitor = CollectRetsVisitor { ret_exprs: vec![] }; + if let Some(item) = self.tcx.hir().find(id) + && let Node::Expr(expr) = item { - err.multipart_suggestion( - "you might have meant to return this value", - vec![ - (span.shrink_to_lo(), "return ".to_string()), - (span.shrink_to_hi(), ";".to_string()), - ], - Applicability::MaybeIncorrect, - ); + visitor.visit_expr(expr); + for expr in visitor.ret_exprs { + if let Some(ty) = self.typeck_results.borrow().node_type_opt(expr.hir_id) { + rets.push(ty); + } + } + if let hir::ExprKind::Block(hir::Block { expr: Some(expr), .. }, _) = expr.kind + { + if let Some(ty) = self.typeck_results.borrow().node_type_opt(expr.hir_id) { + rets.push(ty); + } + } } + rets.into_iter().all(|ty| self.can_coerce(found, ty)) + } + _ => false, + }; + if can_return { + if let Some(node) = self.tcx.opt_hir_node(fn_id) + && let Some(owner_node) = node.as_owner() + && let Some(span) = expr.span.find_ancestor_inside(owner_node.span()) + { + err.multipart_suggestion( + "you might have meant to return this value", + vec![ + (span.shrink_to_lo(), "return ".to_string()), + (span.shrink_to_hi(), ";".to_string()), + ], + Applicability::MaybeIncorrect, + ); } } } diff --git a/compiler/rustc_middle/src/hir/map/mod.rs b/compiler/rustc_middle/src/hir/map/mod.rs index ba1ae46626b22..7dfb222a11708 100644 --- a/compiler/rustc_middle/src/hir/map/mod.rs +++ b/compiler/rustc_middle/src/hir/map/mod.rs @@ -644,7 +644,7 @@ impl<'hir> Map<'hir> { Node::Item(_) | Node::ForeignItem(_) | Node::TraitItem(_) - | Node::Expr(Expr { kind: ExprKind::Closure { .. }, .. }) + | Node::Expr(Expr { kind: ExprKind::Closure(_), .. }) | Node::ImplItem(_) // The input node `id` must be enclosed in the method's body as opposed // to some other place such as its return type (fixes #114918). diff --git a/compiler/rustc_parse/src/parser/diagnostics.rs b/compiler/rustc_parse/src/parser/diagnostics.rs index 7a24b819b5f56..445d5b2ce790c 100644 --- a/compiler/rustc_parse/src/parser/diagnostics.rs +++ b/compiler/rustc_parse/src/parser/diagnostics.rs @@ -900,7 +900,7 @@ impl<'a> Parser<'a> { // fn foo() -> Foo { // field: value, // } - info!(?maybe_struct_name, ?self.token); + debug!(?maybe_struct_name, ?self.token); let mut snapshot = self.create_snapshot_for_diagnostic(); let path = Path { segments: ThinVec::new(), diff --git a/tests/ui/async-await/missing-return-in-async-block.fixed b/tests/ui/async-await/missing-return-in-async-block.fixed new file mode 100644 index 0000000000000..3dbac7945b6e1 --- /dev/null +++ b/tests/ui/async-await/missing-return-in-async-block.fixed @@ -0,0 +1,22 @@ +// run-rustfix +// edition:2021 +use std::future::Future; +use std::pin::Pin; +pub struct S; +pub fn foo() { + let _ = Box::pin(async move { + if true { + return Ok(S); //~ ERROR mismatched types + } + Err(()) + }); +} +pub fn bar() -> Pin> + 'static>> { + Box::pin(async move { + if true { + return Ok(S); //~ ERROR mismatched types + } + Err(()) + }) +} +fn main() {} diff --git a/tests/ui/async-await/missing-return-in-async-block.rs b/tests/ui/async-await/missing-return-in-async-block.rs new file mode 100644 index 0000000000000..7d04e0e0fad14 --- /dev/null +++ b/tests/ui/async-await/missing-return-in-async-block.rs @@ -0,0 +1,22 @@ +// run-rustfix +// edition:2021 +use std::future::Future; +use std::pin::Pin; +pub struct S; +pub fn foo() { + let _ = Box::pin(async move { + if true { + Ok(S) //~ ERROR mismatched types + } + Err(()) + }); +} +pub fn bar() -> Pin> + 'static>> { + Box::pin(async move { + if true { + Ok(S) //~ ERROR mismatched types + } + Err(()) + }) +} +fn main() {} diff --git a/tests/ui/async-await/missing-return-in-async-block.stderr b/tests/ui/async-await/missing-return-in-async-block.stderr new file mode 100644 index 0000000000000..5ea76e5f7bf93 --- /dev/null +++ b/tests/ui/async-await/missing-return-in-async-block.stderr @@ -0,0 +1,35 @@ +error[E0308]: mismatched types + --> $DIR/missing-return-in-async-block.rs:9:13 + | +LL | / if true { +LL | | Ok(S) + | | ^^^^^ expected `()`, found `Result` +LL | | } + | |_________- expected this to be `()` + | + = note: expected unit type `()` + found enum `Result` +help: you might have meant to return this value + | +LL | return Ok(S); + | ++++++ + + +error[E0308]: mismatched types + --> $DIR/missing-return-in-async-block.rs:17:13 + | +LL | / if true { +LL | | Ok(S) + | | ^^^^^ expected `()`, found `Result` +LL | | } + | |_________- expected this to be `()` + | + = note: expected unit type `()` + found enum `Result` +help: you might have meant to return this value + | +LL | return Ok(S); + | ++++++ + + +error: aborting due to 2 previous errors + +For more information about this error, try `rustc --explain E0308`. diff --git a/tests/ui/impl-trait/in-trait/default-body-type-err-2.stderr b/tests/ui/impl-trait/in-trait/default-body-type-err-2.stderr index 77f6945f064cc..9fa73d817ca91 100644 --- a/tests/ui/impl-trait/in-trait/default-body-type-err-2.stderr +++ b/tests/ui/impl-trait/in-trait/default-body-type-err-2.stderr @@ -1,6 +1,8 @@ error[E0308]: mismatched types --> $DIR/default-body-type-err-2.rs:7:9 | +LL | async fn woopsie_async(&self) -> String { + | ------ expected `String` because of return type LL | 42 | ^^- help: try using a conversion method: `.to_string()` | | diff --git a/tests/ui/loops/dont-suggest-break-thru-item.rs b/tests/ui/loops/dont-suggest-break-thru-item.rs index b46ba89e81d7f..308101115e521 100644 --- a/tests/ui/loops/dont-suggest-break-thru-item.rs +++ b/tests/ui/loops/dont-suggest-break-thru-item.rs @@ -8,6 +8,7 @@ fn closure() { if true { Err(1) //~^ ERROR mismatched types + //~| HELP you might have meant to return this value } Ok(()) @@ -21,6 +22,7 @@ fn async_block() { if true { Err(1) //~^ ERROR mismatched types + //~| HELP you might have meant to return this value } Ok(()) diff --git a/tests/ui/loops/dont-suggest-break-thru-item.stderr b/tests/ui/loops/dont-suggest-break-thru-item.stderr index 4fce471511904..c84a98198f55a 100644 --- a/tests/ui/loops/dont-suggest-break-thru-item.stderr +++ b/tests/ui/loops/dont-suggest-break-thru-item.stderr @@ -5,27 +5,37 @@ LL | / if true { LL | | Err(1) | | ^^^^^^ expected `()`, found `Result<_, {integer}>` LL | | +LL | | LL | | } | |_____________- expected this to be `()` | = note: expected unit type `()` found enum `Result<_, {integer}>` +help: you might have meant to return this value + | +LL | return Err(1); + | ++++++ + error[E0308]: mismatched types - --> $DIR/dont-suggest-break-thru-item.rs:22:17 + --> $DIR/dont-suggest-break-thru-item.rs:23:17 | LL | / if true { LL | | Err(1) | | ^^^^^^ expected `()`, found `Result<_, {integer}>` LL | | +LL | | LL | | } | |_____________- expected this to be `()` | = note: expected unit type `()` found enum `Result<_, {integer}>` +help: you might have meant to return this value + | +LL | return Err(1); + | ++++++ + error[E0308]: mismatched types - --> $DIR/dont-suggest-break-thru-item.rs:35:17 + --> $DIR/dont-suggest-break-thru-item.rs:37:17 | LL | / if true { LL | | Err(1) @@ -38,7 +48,7 @@ LL | | } found enum `Result<_, {integer}>` error[E0308]: mismatched types - --> $DIR/dont-suggest-break-thru-item.rs:47:17 + --> $DIR/dont-suggest-break-thru-item.rs:49:17 | LL | / if true { LL | | Err(1)