Skip to content

Commit

Permalink
Fix PYI030 bug with non-literal unions
Browse files Browse the repository at this point in the history
  • Loading branch information
diceroll123 committed Oct 14, 2023
1 parent 8061894 commit ecaa860
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 12 deletions.
51 changes: 51 additions & 0 deletions crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI030.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,54 @@ def func2() -> Literal[1] | Literal[2]: # Error

# Should emit for union in generic parent type.
field11: dict[Literal[1] | Literal[2], str] # Error

# Should emit for unions with more than two cases
field12: Literal[1] | Literal[2] | Literal[3] # Error
field13: Literal[1] | Literal[2] | Literal[3] | Literal[4] # Error

# Should emit for unions with more than two cases, even if not directly adjacent
field14: Literal[1] | Literal[2] | str | Literal[3] # Error

# Should emit for unions with mixed literal internal types
field15: Literal[1] | Literal["foo"] | Literal[True] # Error

# Shouldn't emit for duplicate field types with same value; covered by Y016
field16: Literal[1] | Literal[1] # OK

# Shouldn't emit if in new parent type
field17: Literal[1] | dict[Literal[2], str] # OK

# Shouldn't emit if not in a union parent
field18: dict[Literal[1], Literal[2]] # OK

# Should respect name of literal type used
field19: typing.Literal[1] | typing.Literal[2] # Error

# Should emit in cases with newlines
field20: typing.Union[
Literal[
1 # test
],
Literal[2],
] # Error, newline and comment will not be emitted in message

# Should handle multiple unions with multiple members
field21: Literal[1, 2] | Literal[3, 4] # Error

# Should emit in cases with `typing.Union` instead of `|`
field22: typing.Union[Literal[1], Literal[2]] # Error

# Should emit in cases with `typing_extensions.Literal`
field23: typing_extensions.Literal[1] | typing_extensions.Literal[2] # Error

# Should emit in cases with nested `typing.Union`
field24: typing.Union[Literal[1], typing.Union[Literal[2], str]] # Error

# Should emit in cases with mixed `typing.Union` and `|`
field25: typing.Union[Literal[1], Literal[2] | str] # Error

# Should emit only once in cases with multiple nested `typing.Union`
field24: typing.Union[Literal[1], typing.Union[Literal[2], typing.Union[Literal[3], Literal[4]]]] # Error

# Should use the first literal subscript attribute when fixing
field25: typing.Union[typing_extensions.Literal[1], typing.Union[Literal[2], typing.Union[Literal[3], Literal[4]]], str] # Error
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,6 @@ field25: typing.Union[Literal[1], Literal[2] | str] # Error

# Should emit only once in cases with multiple nested `typing.Union`
field24: typing.Union[Literal[1], typing.Union[Literal[2], typing.Union[Literal[3], Literal[4]]]] # Error

# Should use the first literal subscript attribute when fixing
field25: typing.Union[typing_extensions.Literal[1], typing.Union[Literal[2], typing.Union[Literal[3], Literal[4]]], str] # Error
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use ruff_diagnostics::{Diagnostic, Violation};
use ast::{ExprSubscript, Operator};
use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::{self as ast, Expr};
use ruff_text_size::Ranged;
use ruff_text_size::{Ranged, TextRange};

use crate::checkers::ast::Checker;

use crate::rules::flake8_pyi::helpers::traverse_union;

/// ## What it does
Expand Down Expand Up @@ -31,44 +33,156 @@ pub struct UnnecessaryLiteralUnion {
members: Vec<String>,
}

impl Violation for UnnecessaryLiteralUnion {
impl AlwaysFixableViolation for UnnecessaryLiteralUnion {
#[derive_message_formats]
fn message(&self) -> String {
format!(
"Multiple literal members in a union. Use a single literal, e.g. `Literal[{}]`",
self.members.join(", ")
)
}

fn fix_title(&self) -> String {
format!("Replace with a single `Literal`")
}
}

fn concatenate_bin_ors(exprs: Vec<&Expr>) -> Expr {
let mut exprs = exprs.into_iter();
let first = exprs.next().unwrap();
exprs.fold((*first).clone(), |acc, expr| {
Expr::BinOp(ast::ExprBinOp {
left: Box::new(acc),
op: Operator::BitOr,
right: Box::new((*expr).clone()),
range: TextRange::default(),
})
})
}

fn make_union(subscript: &ExprSubscript, exprs: Vec<&Expr>) -> Expr {
Expr::Subscript(ast::ExprSubscript {
value: subscript.value.clone(),
slice: Box::new(Expr::Tuple(ast::ExprTuple {
elts: exprs.into_iter().map(|expr| (*expr).clone()).collect(),
range: TextRange::default(),
ctx: ast::ExprContext::Load,
})),
range: TextRange::default(),
ctx: ast::ExprContext::Load,
})
}

fn make_literal_expr(subscript: Option<Expr>, exprs: Vec<&Expr>) -> Expr {
let use_subscript = if let subscript @ Some(_) = subscript {
subscript.unwrap().clone()
} else {
Expr::Name(ast::ExprName {
id: "Literal".to_string(),
range: TextRange::default(),
ctx: ast::ExprContext::Load,
})
};
Expr::Subscript(ast::ExprSubscript {
value: Box::new(use_subscript),
slice: Box::new(Expr::Tuple(ast::ExprTuple {
elts: exprs.into_iter().map(|expr| (*expr).clone()).collect(),
range: TextRange::default(),
ctx: ast::ExprContext::Load,
})),
range: TextRange::default(),
ctx: ast::ExprContext::Load,
})
}

/// PYI030
pub(crate) fn unnecessary_literal_union<'a>(checker: &mut Checker, expr: &'a Expr) {
let mut literal_exprs = Vec::new();
let mut other_exprs = Vec::new();

// Adds a member to `literal_exprs` if it is a `Literal` annotation
// for the sake of consistency and correctness, we'll use the first Literal subscript attribute
let mut literal_subscript = None;

// Adds a member to `literal_exprs` if it is a `Literal` annotation.
let mut collect_literal_expr = |expr: &'a Expr, _| {
if let Expr::Subscript(ast::ExprSubscript { value, slice, .. }) = expr {
if checker.semantic().match_typing_expr(value, "Literal") {
literal_exprs.push(slice);
// flatten already-unioned literals to later union again
if let Expr::Tuple(ast::ExprTuple {
elts,
range: _,
ctx: _,
}) = slice.as_ref()
{
for expr in elts {
if literal_subscript.is_none() {
literal_subscript = Some(*value.clone());
}
literal_exprs.push(expr);
}
} else {
if literal_subscript.is_none() {
literal_subscript = Some(*value.clone());
}
literal_exprs.push(slice.as_ref());
}
}
} else {
other_exprs.push(expr);
}
};

// Traverse the union, collect all literal members
// Traverse the union, collect all members, split out the literals from the rest.
traverse_union(&mut collect_literal_expr, checker.semantic(), expr, None);

// Raise a violation if more than one
let union_subscript = expr.as_subscript_expr();
if union_subscript.is_some_and(|subscript| {
!checker
.semantic()
.match_typing_expr(&subscript.value, "Union")
}) {
return;
}

// Raise a violation if more than one.
if literal_exprs.len() > 1 {
let diagnostic = Diagnostic::new(
let literal_members: Vec<String> = literal_exprs
.clone()
.into_iter()
.map(|expr| checker.locator().slice(expr).to_string())
.collect();

let mut diagnostic = Diagnostic::new(
UnnecessaryLiteralUnion {
members: literal_exprs
.into_iter()
.map(|expr| checker.locator().slice(expr.as_ref()).to_string())
.collect(),
members: literal_members.clone(),
},
expr.range(),
);

let literals = make_literal_expr(literal_subscript, literal_exprs.into_iter().collect());

if other_exprs.is_empty() {
// if the union is only literals, we just replace the whole thing with a single literal
diagnostic.set_fix(Fix::safe_edit(Edit::range_replacement(
checker.generator().expr(&literals),
expr.range(),
)));
} else {
let mut expr_vec: Vec<&Expr> = other_exprs.clone().into_iter().collect();
expr_vec.insert(0, &literals);

let content = if let Some(subscript) = union_subscript {
checker.generator().expr(&make_union(subscript, expr_vec))
} else {
checker.generator().expr(&concatenate_bin_ors(expr_vec))
};

diagnostic.set_fix(Fix::safe_edit(Edit::range_replacement(
content,
expr.range(),
)));
}

checker.diagnostics.push(diagnostic);
}
}

0 comments on commit ecaa860

Please sign in to comment.