From 45b56d1152d0b68f7f8faf6b3d04df65f8dcb74b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mikko=20Lepp=C3=A4nen?= <mleppan23@gmail.com>
Date: Mon, 29 Jan 2024 19:29:05 +0200
Subject: [PATCH] [flake8-return] Consider exception suppress for unnecessary
 assignment (#9673)

## Summary

This PR fixes
[RET504](https://docs.astral.sh/ruff/rules/unnecessary-assign/)
(`unnecessary-assign`).

The problem is that Ruff suggests merging a trailing `return` statement into
the body of a `contextlib.suppress` block. Even though the fix is marked as
unsafe, it can produce code that is not equivalent to the original.

See: https://github.com/astral-sh/ruff/issues/5909

## Test Plan

```bash
cargo test
```
---
 .../test/fixtures/flake8_return/RET504.py     | 43 +++++++++
 .../src/rules/flake8_return/rules/function.rs |  2 +-
 ...lake8_return__tests__RET504_RET504.py.snap | 23 +++++
 .../src/rules/flake8_return/visitor.rs        | 96 +++++++++++++++----
 4 files changed, 147 insertions(+), 17 deletions(-)

diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_return/RET504.py b/crates/ruff_linter/resources/test/fixtures/flake8_return/RET504.py
index 9899fa15ed6bc3..8480aafaa1c013 100644
--- a/crates/ruff_linter/resources/test/fixtures/flake8_return/RET504.py
+++ b/crates/ruff_linter/resources/test/fixtures/flake8_return/RET504.py
@@ -363,3 +363,46 @@ def foo():
 def mavko_debari(P_kbar):
     D=0.4853881 + 3.6006116*P - 0.0117368*(P-1.3822)**2
     return D
+
+
+# contextlib suppress in with statement
+import contextlib
+
+
+def foo():
+    x = 2
+    with contextlib.suppress(Exception):
+        x = x + 1
+    return x
+
+
+def foo(data):
+    with open("in.txt") as file_out, contextlib.suppress(IOError):
+        file_out.write(data)
+        data = 10
+    return data
+
+
+def foo(data):
+    with open("in.txt") as file_out:
+        file_out.write(data)
+        with contextlib.suppress(IOError):
+            data = 10
+    return data
+
+
+def foo():
+    y = 1
+    x = 2
+    with contextlib.suppress(Exception):
+        x = 1
+    y = y + 2
+    return y  # RET504
+
+
+def foo():
+    y = 1
+    if y > 0:
+        with contextlib.suppress(Exception):
+            y = 2
+        return y
diff --git a/crates/ruff_linter/src/rules/flake8_return/rules/function.rs b/crates/ruff_linter/src/rules/flake8_return/rules/function.rs
index 7e016ecba70c77..4604663b3ece92 100644
--- a/crates/ruff_linter/src/rules/flake8_return/rules/function.rs
+++ b/crates/ruff_linter/src/rules/flake8_return/rules/function.rs
@@ -737,7 +737,7 @@ pub(crate) fn function(checker: &mut Checker, body: &[Stmt], returns: Option<&Ex
 
     // Traverse the function body, to collect the stack.
     let stack = {
-        let mut visitor = ReturnVisitor::default();
+        let mut visitor = ReturnVisitor::new(checker.semantic());
         for stmt in body {
             visitor.visit_stmt(stmt);
         }
diff --git a/crates/ruff_linter/src/rules/flake8_return/snapshots/ruff_linter__rules__flake8_return__tests__RET504_RET504.py.snap b/crates/ruff_linter/src/rules/flake8_return/snapshots/ruff_linter__rules__flake8_return__tests__RET504_RET504.py.snap
index 9427cd7a79d662..3656e71afbd2d7 100644
--- a/crates/ruff_linter/src/rules/flake8_return/snapshots/ruff_linter__rules__flake8_return__tests__RET504_RET504.py.snap
+++ b/crates/ruff_linter/src/rules/flake8_return/snapshots/ruff_linter__rules__flake8_return__tests__RET504_RET504.py.snap
@@ -217,5 +217,28 @@ RET504.py:365:12: RET504 [*] Unnecessary assignment to `D` before `return` state
 364     |-    D=0.4853881 + 3.6006116*P - 0.0117368*(P-1.3822)**2
 365     |-    return D
     364 |+    return 0.4853881 + 3.6006116*P - 0.0117368*(P-1.3822)**2
+366 365 | 
+367 366 | 
+368 367 | # contextlib suppress in with statement
+
+RET504.py:400:12: RET504 [*] Unnecessary assignment to `y` before `return` statement
+    |
+398 |         x = 1
+399 |     y = y + 2
+400 |     return y  # RET504
+    |            ^ RET504
+    |
+    = help: Remove unnecessary assignment
+
+ℹ Unsafe fix
+396 396 |     x = 2
+397 397 |     with contextlib.suppress(Exception):
+398 398 |         x = 1
+399     |-    y = y + 2
+400     |-    return y  # RET504
+    399 |+    return y + 2
+401 400 | 
+402 401 | 
+403 402 | def foo():
 
 
diff --git a/crates/ruff_linter/src/rules/flake8_return/visitor.rs b/crates/ruff_linter/src/rules/flake8_return/visitor.rs
index 775c3356e5f72a..653364ff1ab782 100644
--- a/crates/ruff_linter/src/rules/flake8_return/visitor.rs
+++ b/crates/ruff_linter/src/rules/flake8_return/visitor.rs
@@ -3,34 +3,48 @@ use rustc_hash::FxHashSet;
 
 use ruff_python_ast::visitor;
 use ruff_python_ast::visitor::Visitor;
+use ruff_python_semantic::SemanticModel;
 
 #[derive(Default)]
-pub(super) struct Stack<'a> {
+pub(super) struct Stack<'data> {
     /// The `return` statements in the current function.
-    pub(super) returns: Vec<&'a ast::StmtReturn>,
+    pub(super) returns: Vec<&'data ast::StmtReturn>,
     /// The `elif` or `else` statements in the current function.
-    pub(super) elifs_elses: Vec<(&'a [Stmt], &'a ElifElseClause)>,
+    pub(super) elifs_elses: Vec<(&'data [Stmt], &'data ElifElseClause)>,
     /// The non-local variables in the current function.
-    pub(super) non_locals: FxHashSet<&'a str>,
+    pub(super) non_locals: FxHashSet<&'data str>,
     /// Whether the current function is a generator.
     pub(super) is_generator: bool,
     /// The `assignment`-to-`return` statement pairs in the current function.
     /// TODO(charlie): Remove the extra [`Stmt`] here, which is necessary to support statement
     /// removal for the `return` statement.
-    pub(super) assignment_return: Vec<(&'a ast::StmtAssign, &'a ast::StmtReturn, &'a Stmt)>,
+    pub(super) assignment_return:
+        Vec<(&'data ast::StmtAssign, &'data ast::StmtReturn, &'data Stmt)>,
 }
 
-#[derive(Default)]
-pub(super) struct ReturnVisitor<'a> {
+pub(super) struct ReturnVisitor<'semantic, 'data> {
+    /// The semantic model of the current file.
+    semantic: &'semantic SemanticModel<'data>,
     /// The current stack of nodes.
-    pub(super) stack: Stack<'a>,
+    pub(super) stack: Stack<'data>,
     /// The preceding sibling of the current node.
-    sibling: Option<&'a Stmt>,
+    sibling: Option<&'data Stmt>,
     /// The parent nodes of the current node.
-    parents: Vec<&'a Stmt>,
+    parents: Vec<&'data Stmt>,
+}
+
+impl<'semantic, 'data> ReturnVisitor<'semantic, 'data> {
+    pub(super) fn new(semantic: &'semantic SemanticModel<'data>) -> Self {
+        Self {
+            semantic,
+            stack: Stack::default(),
+            sibling: None,
+            parents: Vec::new(),
+        }
+    }
 }
 
-impl<'a> Visitor<'a> for ReturnVisitor<'a> {
+impl<'semantic, 'a> Visitor<'a> for ReturnVisitor<'semantic, 'a> {
     fn visit_stmt(&mut self, stmt: &'a Stmt) {
         match stmt {
             Stmt::ClassDef(ast::StmtClassDef { decorator_list, .. }) => {
@@ -95,11 +109,17 @@ impl<'a> Visitor<'a> for ReturnVisitor<'a> {
                         //         x = f.read()
                         //     return x
                         // ```
-                        Stmt::With(ast::StmtWith { body, .. }) => {
-                            if let Some(stmt_assign) = body.last().and_then(Stmt::as_assign_stmt) {
-                                self.stack
-                                    .assignment_return
-                                    .push((stmt_assign, stmt_return, stmt));
+                        Stmt::With(with) => {
+                            if let Some(stmt_assign) =
+                                with.body.last().and_then(Stmt::as_assign_stmt)
+                            {
+                                if !has_conditional_body(with, self.semantic) {
+                                    self.stack.assignment_return.push((
+                                        stmt_assign,
+                                        stmt_return,
+                                        stmt,
+                                    ));
+                                }
                             }
                         }
                         _ => {}
@@ -142,3 +162,47 @@ impl<'a> Visitor<'a> for ReturnVisitor<'a> {
         self.sibling = sibling;
     }
 }
+
+/// RET504: If the last statement is a `return` statement, and the second-to-last statement is a
+/// `with` statement that suppresses an exception, then we should not analyze the `return`
+/// statement for unnecessary assignments. Otherwise we will suggest removing the assignment
+/// and the `with` statement, which would change the behavior of the code.
+///
+/// Example:
+/// ```python
+/// def foo(data):
+///     with suppress(JSONDecodeError):
+///         data = data.decode()
+///     return data
+/// ```
+
+/// Returns `true` if the [`ast::StmtWith`] statement is known to have a conditional body. In other
+/// words: if the [`ast::StmtWith`] statement's body may or may not run.
+///
+/// For example, in the following, it's unsafe to inline the `return` into the `with`, since if
+/// `data.decode()` fails, the behavior of the program will differ. (As-is, the function will return
+/// the input `data`; if we inline the `return`, the function will return `None`.)
+///
+/// ```python
+/// def func(data):
+///     with suppress(JSONDecodeError):
+///         data = data.decode()
+///     return data
+/// ```
+fn has_conditional_body(with: &ast::StmtWith, semantic: &SemanticModel) -> bool {
+    with.items.iter().any(|item| {
+        let ast::WithItem {
+            context_expr: Expr::Call(ast::ExprCall { func, .. }),
+            ..
+        } = item
+        else {
+            return false;
+        };
+        if let Some(call_path) = semantic.resolve_call_path(func) {
+            if call_path.as_slice() == ["contextlib", "suppress"] {
+                return true;
+            }
+        }
+        false
+    })
+}