diff --git a/crates/ruff_python_formatter/src/expression/mod.rs b/crates/ruff_python_formatter/src/expression/mod.rs index 97c43cf75f9641..cfe701d35a9e83 100644 --- a/crates/ruff_python_formatter/src/expression/mod.rs +++ b/crates/ruff_python_formatter/src/expression/mod.rs @@ -198,37 +198,30 @@ impl<'ast> IntoFormat> for Expr { /// /// This mimics Black's [`_maybe_split_omitting_optional_parens`](https://github.com/psf/black/blob/d1248ca9beaf0ba526d265f4108836d89cf551b7/src/black/linegen.py#L746-L820) fn can_omit_optional_parentheses(expr: &Expr, context: &PyFormatContext) -> bool { - let mut visitor = MaxOperatorPriorityVisitor::new(context.source()); - + let mut visitor = CanOmitOptionalParenthesesVisitor::new(context.source()); visitor.visit_subexpression(expr); - - let (max_operator_priority, operation_count, any_parenthesized_expression) = visitor.finish(); - - if operation_count > 1 { - false - } else if max_operator_priority == OperatorPriority::Attribute { - true - } else { - // Only use the more complex IR when there is any expression that we can possibly split by - any_parenthesized_expression - } + visitor.can_omit() } #[derive(Clone, Debug)] -struct MaxOperatorPriorityVisitor<'input> { +struct CanOmitOptionalParenthesesVisitor<'input> { max_priority: OperatorPriority, max_priority_count: u32, any_parenthesized_expressions: bool, + last: Option<&'input Expr>, + first: Option<&'input Expr>, source: &'input str, } -impl<'input> MaxOperatorPriorityVisitor<'input> { +impl<'input> CanOmitOptionalParenthesesVisitor<'input> { fn new(source: &'input str) -> Self { Self { source, max_priority: OperatorPriority::None, max_priority_count: 0, any_parenthesized_expressions: false, + last: None, + first: None, } } @@ -305,6 +298,7 @@ impl<'input> MaxOperatorPriorityVisitor<'input> { self.any_parenthesized_expressions = true; // Only walk the function, the arguments are always parenthesized self.visit_expr(func); + self.last = Some(expr); return; } Expr::Subscript(_) => { @@ -351,23 +345,41 @@ impl<'input> MaxOperatorPriorityVisitor<'input> { walk_expr(self, expr); } - fn finish(self) -> (OperatorPriority, u32, bool) { - ( - self.max_priority, - self.max_priority_count, - self.any_parenthesized_expressions, - ) + fn can_omit(self) -> bool { + if self.max_priority_count > 1 { + false + } else if self.max_priority == OperatorPriority::Attribute { + true + } else if !self.any_parenthesized_expressions { + // Only use the more complex IR when there is any expression that we can possibly split by + false + } else { + // Only use the layout if the first or last expression has parentheses of some sort. + let first_parenthesized = self + .first + .map_or(false, |first| has_parentheses(first, self.source)); + let last_parenthesized = self + .last + .map_or(false, |last| has_parentheses(last, self.source)); + first_parenthesized || last_parenthesized + } } } -impl<'input> PreorderVisitor<'input> for MaxOperatorPriorityVisitor<'input> { +impl<'input> PreorderVisitor<'input> for CanOmitOptionalParenthesesVisitor<'input> { fn visit_expr(&mut self, expr: &'input Expr) { + self.last = Some(expr); + // Rule only applies for non-parenthesized expressions. if is_expression_parenthesized(AnyNodeRef::from(expr), self.source) { self.any_parenthesized_expressions = true; } else { self.visit_subexpression(expr); } + + if self.first.is_none() { + self.first = Some(expr); + } } } diff --git a/crates/ruff_python_formatter/src/lib.rs b/crates/ruff_python_formatter/src/lib.rs index 79bc92d667f279..44ab2e8de4e7af 100644 --- a/crates/ruff_python_formatter/src/lib.rs +++ b/crates/ruff_python_formatter/src/lib.rs @@ -280,8 +280,15 @@ if True: #[test] fn quick_test() { let src = r#" -def foo() -> tuple[int, int, int,]: - return 2 +if a * [ + bbbbbbbbbbbbbbbbbbbbbb, + cccccccccccccccccccccccccccccdddddddddddddddddddddddddd, +] + a * e * [ + ffff, + gggg, + hhhhhhhhhhhhhh, +] * c: + pass "#; // Tokenize once diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__trailing_comma_optional_parens1.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__trailing_comma_optional_parens1.py.snap deleted file mode 100644 index 7fb41572f30e81..00000000000000 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@simple_cases__trailing_comma_optional_parens1.py.snap +++ /dev/null @@ -1,135 +0,0 @@ ---- -source: crates/ruff_python_formatter/tests/fixtures.rs -input_file: crates/ruff_python_formatter/resources/test/fixtures/black/simple_cases/trailing_comma_optional_parens1.py ---- -## Input - -```py -if e1234123412341234.winerror not in (_winapi.ERROR_SEM_TIMEOUT, - _winapi.ERROR_PIPE_BUSY) or _check_timeout(t): - pass - -if x: - if y: - new_id = max(Vegetable.objects.order_by('-id')[0].id, - Mineral.objects.order_by('-id')[0].id) + 1 - -class X: - def get_help_text(self): - return ngettext( - "Your password must contain at least %(min_length)d character.", - "Your password must contain at least %(min_length)d characters.", - self.min_length, - ) % {'min_length': self.min_length} - -class A: - def b(self): - if self.connection.mysql_is_mariadb and ( - 10, - 4, - 3, - ) < self.connection.mysql_version < (10, 5, 2): - pass -``` - -## Black Differences - -```diff ---- Black -+++ Ruff -@@ -6,13 +6,10 @@ - - if x: - if y: -- new_id = ( -- max( -- Vegetable.objects.order_by("-id")[0].id, -- Mineral.objects.order_by("-id")[0].id, -- ) -- + 1 -- ) -+ new_id = max( -+ Vegetable.objects.order_by("-id")[0].id, -+ Mineral.objects.order_by("-id")[0].id, -+ ) + 1 - - - class X: -``` - -## Ruff Output - -```py -if e1234123412341234.winerror not in ( - _winapi.ERROR_SEM_TIMEOUT, - _winapi.ERROR_PIPE_BUSY, -) or _check_timeout(t): - pass - -if x: - if y: - new_id = max( - Vegetable.objects.order_by("-id")[0].id, - Mineral.objects.order_by("-id")[0].id, - ) + 1 - - -class X: - def get_help_text(self): - return ngettext( - "Your password must contain at least %(min_length)d character.", - "Your password must contain at least %(min_length)d characters.", - self.min_length, - ) % {"min_length": self.min_length} - - -class A: - def b(self): - if self.connection.mysql_is_mariadb and ( - 10, - 4, - 3, - ) < self.connection.mysql_version < (10, 5, 2): - pass -``` - -## Black Output - -```py -if e1234123412341234.winerror not in ( - _winapi.ERROR_SEM_TIMEOUT, - _winapi.ERROR_PIPE_BUSY, -) or _check_timeout(t): - pass - -if x: - if y: - new_id = ( - max( - Vegetable.objects.order_by("-id")[0].id, - Mineral.objects.order_by("-id")[0].id, - ) - + 1 - ) - - -class X: - def get_help_text(self): - return ngettext( - "Your password must contain at least %(min_length)d character.", - "Your password must contain at least %(min_length)d characters.", - self.min_length, - ) % {"min_length": self.min_length} - - -class A: - def b(self): - if self.connection.mysql_is_mariadb and ( - 10, - 4, - 3, - ) < self.connection.mysql_version < (10, 5, 2): - pass -``` - -