diff --git a/pin-project-internal/src/project.rs b/pin-project-internal/src/project.rs index 8bb07d54..81cd3527 100644 --- a/pin-project-internal/src/project.rs +++ b/pin-project-internal/src/project.rs @@ -13,17 +13,33 @@ pub(crate) fn attribute(args: &TokenStream, input: Stmt, mutability: Mutability) .unwrap_or_else(|e| e.to_compile_error()) } -fn parse(mut stmt: Stmt, mutability: Mutability) -> Result { - match &mut stmt { +fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result { + match stmt { Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { - Context::new(mutability).replace_expr_match(expr) + Context::new(mutability).replace_expr_match(expr); + return Ok(true) + } + Stmt::Expr(Expr::If(expr_if)) => { + if let Expr::Let(ref mut expr) = &mut *expr_if.cond { + Context::new(mutability).replace_expr_let(expr); + return Ok(true); + } } Stmt::Local(local) => Context::new(mutability).replace_local(local)?, - Stmt::Item(Item::Fn(item)) => replace_item_fn(item, mutability)?, - Stmt::Item(Item::Impl(item)) => replace_item_impl(item, mutability), - Stmt::Item(Item::Use(item)) => replace_item_use(item, mutability)?, _ => {} } + Ok(false) +} + +fn parse(mut stmt: Stmt, mutability: Mutability) -> Result { + if !replace_stmt(&mut stmt, mutability)? { + match &mut stmt { + Stmt::Item(Item::Fn(item)) => replace_item_fn(item, mutability)?, + Stmt::Item(Item::Impl(item)) => replace_item_impl(item, mutability), + Stmt::Item(Item::Use(item)) => replace_item_use(item, mutability)?, + _ => {} + } + } Ok(stmt.into_token_stream()) } @@ -73,6 +89,10 @@ impl Context { Ok(()) } + fn replace_expr_let(&mut self, expr: &mut ExprLet) { + self.replace_pat(&mut expr.pat, true) + } + fn replace_expr_match(&mut self, expr: &mut ExprMatch) { expr.arms.iter_mut().for_each(|Arm { pat, .. }| self.replace_pat(pat, true)) } @@ -195,17 +215,18 @@ impl FnVisitor { expr.attrs.find_remove(self.name())? } Stmt::Local(local) => local.attrs.find_remove(self.name())?, + Stmt::Expr(Expr::If(expr_if)) => { + if let Expr::Let(_) = &*expr_if.cond { + expr_if.attrs.find_remove(self.name())? + } else { + None + } + } _ => return Ok(()), }; if let Some(attr) = attr { parse_as_empty(&attr.tokens)?; - match node { - Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { - Context::new(self.mutability).replace_expr_match(expr) - } - Stmt::Local(local) => Context::new(self.mutability).replace_local(local)?, - _ => unreachable!(), - } + replace_stmt(node, self.mutability)?; } Ok(()) }