diff --git a/crates/ruff_linter/resources/test/fixtures/pycodestyle/E721.py b/crates/ruff_linter/resources/test/fixtures/pycodestyle/E721.py index 872fa7042aded1..9ce183b6adad82 100644 --- a/crates/ruff_linter/resources/test/fixtures/pycodestyle/E721.py +++ b/crates/ruff_linter/resources/test/fixtures/pycodestyle/E721.py @@ -126,3 +126,15 @@ def type(): # Okay if type(value) is str: ... + + +import numpy as np + +#: Okay +x.dtype == float + +#: Okay +np.dtype(int) == float + +#: E721 +dtype == float diff --git a/crates/ruff_linter/src/rules/pycodestyle/rules/type_comparison.rs b/crates/ruff_linter/src/rules/pycodestyle/rules/type_comparison.rs index 598b9a9c119722..57de972c231771 100644 --- a/crates/ruff_linter/src/rules/pycodestyle/rules/type_comparison.rs +++ b/crates/ruff_linter/src/rules/pycodestyle/rules/type_comparison.rs @@ -162,7 +162,14 @@ pub(crate) fn preview_type_comparison(checker: &mut Checker, compare: &ast::Expr .filter(|(_, op)| matches!(op, CmpOp::Eq | CmpOp::NotEq)) .map(|((left, right), _)| (left, right)) { + // If either expression is a type... if is_type(left, checker.semantic()) || is_type(right, checker.semantic()) { + // And neither is a `dtype`... + if is_dtype(left, checker.semantic()) || is_dtype(right, checker.semantic()) { + continue; + } + + // Disallow the comparison. checker.diagnostics.push(Diagnostic::new( TypeComparison { preview: PreviewMode::Enabled, @@ -295,3 +302,23 @@ fn is_type(expr: &Expr, semantic: &SemanticModel) -> bool { _ => false, } } + +/// Returns `true` if the [`Expr`] appears to be a reference to a NumPy dtype, since: +/// > `dtype` are a bit of a strange beast, but definitely best thought of as instances, not +/// > classes, and they are meant to be comparable not just to their own class, but also to the +/// corresponding scalar types (e.g., `x.dtype == np.float32`) and strings (e.g., +/// `x.dtype == ['i1,i4']`; basically, __eq__ always tries to do `dtype(other)`). +fn is_dtype(expr: &Expr, semantic: &SemanticModel) -> bool { + match expr { + // Ex) `np.dtype(obj)` + Expr::Call(ast::ExprCall { func, .. }) => semantic + .resolve_call_path(func) + .is_some_and(|call_path| matches!(call_path.as_slice(), ["numpy", "dtype"])), + // Ex) `obj.dtype` + Expr::Attribute(ast::ExprAttribute { attr, .. }) => { + // Ex) `obj.dtype` + attr.as_str() == "dtype" + } + _ => false, + } +} diff --git a/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__preview__E721_E721.py.snap b/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__preview__E721_E721.py.snap index 8971d3f3ccf065..fc1d13521fdae6 100644 --- a/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__preview__E721_E721.py.snap +++ b/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__preview__E721_E721.py.snap @@ -129,4 +129,11 @@ E721.py:59:4: E721 Use `is` and `is not` for type comparisons, or `isinstance()` 61 | #: Okay | +E721.py:140:1: E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks + | +139 | #: E721 +140 | dtype == float + | ^^^^^^^^^^^^^^ E721 + | +