From ecaa860a2ce65a454544c5fba5ff92a515d7b9a6 Mon Sep 17 00:00:00 2001 From: Steve C Date: Fri, 13 Oct 2023 01:10:07 -0400 Subject: [PATCH] Fix PYI030 bug with non-literal unions --- .../test/fixtures/flake8_pyi/PYI030.py | 51 +++++++ .../test/fixtures/flake8_pyi/PYI030.pyi | 3 + .../rules/unnecessary_literal_union.rs | 138 ++++++++++++++++-- 3 files changed, 180 insertions(+), 12 deletions(-) diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI030.py b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI030.py index 3c9d0ac15c9c37..c6cd2b453707a2 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI030.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI030.py @@ -36,3 +36,54 @@ def func2() -> Literal[1] | Literal[2]: # Error # Should emit for union in generic parent type. field11: dict[Literal[1] | Literal[2], str] # Error + +# Should emit for unions with more than two cases +field12: Literal[1] | Literal[2] | Literal[3] # Error +field13: Literal[1] | Literal[2] | Literal[3] | Literal[4] # Error + +# Should emit for unions with more than two cases, even if not directly adjacent +field14: Literal[1] | Literal[2] | str | Literal[3] # Error + +# Should emit for unions with mixed literal internal types +field15: Literal[1] | Literal["foo"] | Literal[True] # Error + +# Shouldn't emit for duplicate field types with same value; covered by Y016 +field16: Literal[1] | Literal[1] # OK + +# Shouldn't emit if in new parent type +field17: Literal[1] | dict[Literal[2], str] # OK + +# Shouldn't emit if not in a union parent +field18: dict[Literal[1], Literal[2]] # OK + +# Should respect name of literal type used +field19: typing.Literal[1] | typing.Literal[2] # Error + +# Should emit in cases with newlines +field20: typing.Union[ + Literal[ + 1 # test + ], + Literal[2], +] # Error, newline and comment will not be emitted in message + +# Should handle multiple unions with multiple members +field21: Literal[1, 2] | Literal[3, 4] # Error + +# Should emit in cases with `typing.Union` instead of `|` +field22: typing.Union[Literal[1], Literal[2]] # Error + +# Should emit in cases with `typing_extensions.Literal` +field23: typing_extensions.Literal[1] | typing_extensions.Literal[2] # Error + +# Should emit in cases with nested `typing.Union` +field24: typing.Union[Literal[1], typing.Union[Literal[2], str]] # Error + +# Should emit in cases with mixed `typing.Union` and `|` +field25: typing.Union[Literal[1], Literal[2] | str] # Error + +# Should emit only once in cases with multiple nested `typing.Union` +field24: typing.Union[Literal[1], typing.Union[Literal[2], typing.Union[Literal[3], Literal[4]]]] # Error + +# Should use the first literal subscript attribute when fixing +field25: typing.Union[typing_extensions.Literal[1], typing.Union[Literal[2], typing.Union[Literal[3], Literal[4]]], str] # Error diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI030.pyi b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI030.pyi index e92af925df67d7..c6cd2b453707a2 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI030.pyi +++ b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI030.pyi @@ -84,3 +84,6 @@ field25: typing.Union[Literal[1], Literal[2] | str] # Error # Should emit only once in cases with multiple nested `typing.Union` field24: typing.Union[Literal[1], typing.Union[Literal[2], typing.Union[Literal[3], Literal[4]]]] # Error + +# Should use the first literal subscript attribute when fixing +field25: typing.Union[typing_extensions.Literal[1], typing.Union[Literal[2], typing.Union[Literal[3], Literal[4]]], str] # Error diff --git a/crates/ruff_linter/src/rules/flake8_pyi/rules/unnecessary_literal_union.rs b/crates/ruff_linter/src/rules/flake8_pyi/rules/unnecessary_literal_union.rs index 187837905fe60a..b46e5c76a2f07c 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/rules/unnecessary_literal_union.rs +++ b/crates/ruff_linter/src/rules/flake8_pyi/rules/unnecessary_literal_union.rs @@ -1,9 +1,11 @@ -use ruff_diagnostics::{Diagnostic, Violation}; +use ast::{ExprSubscript, Operator}; +use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::{self as ast, Expr}; -use ruff_text_size::Ranged; +use ruff_text_size::{Ranged, TextRange}; use crate::checkers::ast::Checker; + use crate::rules::flake8_pyi::helpers::traverse_union; /// ## What it does @@ -31,7 +33,7 @@ pub struct UnnecessaryLiteralUnion { members: Vec, } -impl Violation for UnnecessaryLiteralUnion { +impl AlwaysFixableViolation for UnnecessaryLiteralUnion { #[derive_message_formats] fn message(&self) -> String { format!( @@ -39,36 +41,148 @@ impl Violation for UnnecessaryLiteralUnion { self.members.join(", ") ) } + + fn fix_title(&self) -> String { + format!("Replace with a single `Literal`") + } +} + +fn concatenate_bin_ors(exprs: Vec<&Expr>) -> Expr { + let mut exprs = exprs.into_iter(); + let first = exprs.next().unwrap(); + exprs.fold((*first).clone(), |acc, expr| { + Expr::BinOp(ast::ExprBinOp { + left: Box::new(acc), + op: Operator::BitOr, + right: Box::new((*expr).clone()), + range: TextRange::default(), + }) + }) +} + +fn make_union(subscript: &ExprSubscript, exprs: Vec<&Expr>) -> Expr { + Expr::Subscript(ast::ExprSubscript { + value: subscript.value.clone(), + slice: Box::new(Expr::Tuple(ast::ExprTuple { + elts: exprs.into_iter().map(|expr| (*expr).clone()).collect(), + range: TextRange::default(), + ctx: ast::ExprContext::Load, + })), + range: TextRange::default(), + ctx: ast::ExprContext::Load, + }) +} + +fn make_literal_expr(subscript: Option, exprs: Vec<&Expr>) -> Expr { + let use_subscript = if let subscript @ Some(_) = subscript { + subscript.unwrap().clone() + } else { + Expr::Name(ast::ExprName { + id: "Literal".to_string(), + range: TextRange::default(), + ctx: ast::ExprContext::Load, + }) + }; + Expr::Subscript(ast::ExprSubscript { + value: Box::new(use_subscript), + slice: Box::new(Expr::Tuple(ast::ExprTuple { + elts: exprs.into_iter().map(|expr| (*expr).clone()).collect(), + range: TextRange::default(), + ctx: ast::ExprContext::Load, + })), + range: TextRange::default(), + ctx: ast::ExprContext::Load, + }) } /// PYI030 pub(crate) fn unnecessary_literal_union<'a>(checker: &mut Checker, expr: &'a Expr) { let mut literal_exprs = Vec::new(); + let mut other_exprs = Vec::new(); - // Adds a member to `literal_exprs` if it is a `Literal` annotation + // for the sake of consistency and correctness, we'll use the first Literal subscript attribute + let mut literal_subscript = None; + + // Adds a member to `literal_exprs` if it is a `Literal` annotation. let mut collect_literal_expr = |expr: &'a Expr, _| { if let Expr::Subscript(ast::ExprSubscript { value, slice, .. }) = expr { if checker.semantic().match_typing_expr(value, "Literal") { - literal_exprs.push(slice); + // flatten already-unioned literals to later union again + if let Expr::Tuple(ast::ExprTuple { + elts, + range: _, + ctx: _, + }) = slice.as_ref() + { + for expr in elts { + if literal_subscript.is_none() { + literal_subscript = Some(*value.clone()); + } + literal_exprs.push(expr); + } + } else { + if literal_subscript.is_none() { + literal_subscript = Some(*value.clone()); + } + literal_exprs.push(slice.as_ref()); + } } + } else { + other_exprs.push(expr); } }; - // Traverse the union, collect all literal members + // Traverse the union, collect all members, split out the literals from the rest. traverse_union(&mut collect_literal_expr, checker.semantic(), expr, None); - // Raise a violation if more than one + let union_subscript = expr.as_subscript_expr(); + if union_subscript.is_some_and(|subscript| { + !checker + .semantic() + .match_typing_expr(&subscript.value, "Union") + }) { + return; + } + + // Raise a violation if more than one. if literal_exprs.len() > 1 { - let diagnostic = Diagnostic::new( + let literal_members: Vec = literal_exprs + .clone() + .into_iter() + .map(|expr| checker.locator().slice(expr).to_string()) + .collect(); + + let mut diagnostic = Diagnostic::new( UnnecessaryLiteralUnion { - members: literal_exprs - .into_iter() - .map(|expr| checker.locator().slice(expr.as_ref()).to_string()) - .collect(), + members: literal_members.clone(), }, expr.range(), ); + let literals = make_literal_expr(literal_subscript, literal_exprs.into_iter().collect()); + + if other_exprs.is_empty() { + // if the union is only literals, we just replace the whole thing with a single literal + diagnostic.set_fix(Fix::safe_edit(Edit::range_replacement( + checker.generator().expr(&literals), + expr.range(), + ))); + } else { + let mut expr_vec: Vec<&Expr> = other_exprs.clone().into_iter().collect(); + expr_vec.insert(0, &literals); + + let content = if let Some(subscript) = union_subscript { + checker.generator().expr(&make_union(subscript, expr_vec)) + } else { + checker.generator().expr(&concatenate_bin_ors(expr_vec)) + }; + + diagnostic.set_fix(Fix::safe_edit(Edit::range_replacement( + content, + expr.range(), + ))); + } + checker.diagnostics.push(diagnostic); } }