Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Fix more edge cases for intersection simplification with LiteralString and AlwaysTruthy/AlwaysFalsy #15496

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ simplified, due to the fact that a `LiteralString` inhabitant is known to have `
exactly `str` (and not a subclass of `str`):

```py
from knot_extensions import Intersection, Not, AlwaysTruthy, AlwaysFalsy
from knot_extensions import Intersection, Not, AlwaysTruthy, AlwaysFalsy, Unknown
from typing_extensions import LiteralString

def f(
Expand All @@ -690,13 +690,21 @@ def f(
d: Intersection[LiteralString, Not[AlwaysFalsy]],
e: Intersection[AlwaysFalsy, LiteralString],
f: Intersection[Not[AlwaysTruthy], LiteralString],
g: Intersection[AlwaysTruthy, LiteralString],
h: Intersection[Not[AlwaysFalsy], LiteralString],
i: Intersection[Unknown, LiteralString, AlwaysFalsy],
j: Intersection[Not[AlwaysTruthy], Unknown, LiteralString],
):
reveal_type(a) # revealed: LiteralString & ~Literal[""]
reveal_type(b) # revealed: Literal[""]
reveal_type(c) # revealed: Literal[""]
reveal_type(d) # revealed: LiteralString & ~Literal[""]
reveal_type(e) # revealed: Literal[""]
reveal_type(f) # revealed: Literal[""]
reveal_type(g) # revealed: LiteralString & ~Literal[""]
reveal_type(h) # revealed: LiteralString & ~Literal[""]
reveal_type(i) # revealed: Unknown & Literal[""]
reveal_type(j) # revealed: Unknown & Literal[""]
```

## Addition of a type to an intersection with many non-disjoint types
Expand Down
227 changes: 120 additions & 107 deletions crates/red_knot_python_semantic/src/types/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,131 +247,144 @@ struct InnerIntersectionBuilder<'db> {
impl<'db> InnerIntersectionBuilder<'db> {
/// Adds a positive type to this intersection.
fn add_positive(&mut self, db: &'db dyn Db, mut new_positive: Type<'db>) {
if new_positive == Type::AlwaysTruthy && self.positive.contains(&Type::LiteralString) {
self.add_negative(db, Type::string_literal(db, ""));
return;
}

if let Type::Intersection(other) = new_positive {
for pos in other.positive(db) {
self.add_positive(db, *pos);
match new_positive {
// `LiteralString & AlwaysTruthy` -> `LiteralString & ~Literal[""]`
Type::AlwaysTruthy if self.positive.contains(&Type::LiteralString) => {
self.add_negative(db, Type::string_literal(db, ""));
}
for neg in other.negative(db) {
self.add_negative(db, *neg);
// `LiteralString & AlwaysFalsy` -> `Literal[""]`
Type::AlwaysFalsy if self.positive.swap_remove(&Type::LiteralString) => {
self.add_positive(db, Type::string_literal(db, ""));
}
} else {
let addition_is_bool_instance = new_positive
.into_instance()
.and_then(|instance| instance.class.known(db))
.is_some_and(KnownClass::is_bool);

for (index, existing_positive) in self.positive.iter().enumerate() {
match existing_positive {
// `AlwaysTruthy & bool` -> `Literal[True]`
Type::AlwaysTruthy if addition_is_bool_instance => {
new_positive = Type::BooleanLiteral(true);
}
// `AlwaysFalsy & bool` -> `Literal[False]`
Type::AlwaysFalsy if addition_is_bool_instance => {
new_positive = Type::BooleanLiteral(false);
}
// `AlwaysFalsy & LiteralString` -> `Literal[""]`
Type::AlwaysFalsy if new_positive.is_literal_string() => {
new_positive = Type::string_literal(db, "");
}
Type::Instance(InstanceType { class })
if class.is_known(db, KnownClass::Bool) =>
{
match new_positive {
// `bool & AlwaysTruthy` -> `Literal[True]`
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(true);
}
// `bool & AlwaysFalsy` -> `Literal[False]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(false);
}
_ => continue,
}
}
// `LiteralString & AlwaysFalsy` -> `Literal[""]`
Type::LiteralString if new_positive == Type::AlwaysFalsy => {
new_positive = Type::string_literal(db, "");
}
_ => continue,
// `AlwaysTruthy & LiteralString` -> `LiteralString & ~Literal[""]`
Type::LiteralString if self.positive.swap_remove(&Type::AlwaysTruthy) => {
self.add_positive(db, Type::LiteralString);
self.add_negative(db, Type::string_literal(db, ""));
}
// `AlwaysFalsy & LiteralString` -> `Literal[""]`
Type::LiteralString if self.positive.swap_remove(&Type::AlwaysFalsy) => {
self.add_positive(db, Type::string_literal(db, ""));
}
// `LiteralString & ~AlwaysTruthy` -> `LiteralString & AlwaysFalsy` -> `Literal[""]`
Type::LiteralString if self.negative.swap_remove(&Type::AlwaysTruthy) => {
self.add_positive(db, Type::string_literal(db, ""));
}
// `LiteralString & ~AlwaysFalsy` -> `LiteralString & ~Literal[""]`
Type::LiteralString if self.negative.swap_remove(&Type::AlwaysFalsy) => {
self.add_positive(db, Type::LiteralString);
self.add_negative(db, Type::string_literal(db, ""));
}
Comment on lines +259 to +276
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's tempting to think that we could avoid some indirection here, e.g.

Suggested change
// `AlwaysTruthy & LiteralString` -> `LiteralString & ~Literal[""]`
Type::LiteralString if self.positive.swap_remove(&Type::AlwaysTruthy) => {
self.add_positive(db, Type::LiteralString);
self.add_negative(db, Type::string_literal(db, ""));
}
// `AlwaysFalsy & LiteralString` -> `Literal[""]`
Type::LiteralString if self.positive.swap_remove(&Type::AlwaysFalsy) => {
self.add_positive(db, Type::string_literal(db, ""));
}
// `LiteralString & ~AlwaysTruthy` -> `LiteralString & AlwaysFalsy` -> `Literal[""]`
Type::LiteralString if self.negative.swap_remove(&Type::AlwaysTruthy) => {
self.add_positive(db, Type::string_literal(db, ""));
}
// `LiteralString & ~AlwaysFalsy` -> `LiteralString & ~Literal[""]`
Type::LiteralString if self.negative.swap_remove(&Type::AlwaysFalsy) => {
self.add_positive(db, Type::LiteralString);
self.add_negative(db, Type::string_literal(db, ""));
}
// `AlwaysTruthy & LiteralString` -> `LiteralString & ~Literal[""]`
Type::LiteralString if self.positive.swap_remove(&Type::AlwaysTruthy) => {
self.positive.insert(Type::LiteralString);
self.negative.insert(Type::string_literal(db, ""));
}
// `AlwaysFalsy & LiteralString` -> `Literal[""]`
Type::LiteralString if self.positive.swap_remove(&Type::AlwaysFalsy) => {
self.positive.insert(Type::string_literal(db, ""));
}
// `LiteralString & ~AlwaysTruthy` -> `LiteralString & AlwaysFalsy` -> `Literal[""]`
Type::LiteralString if self.negative.swap_remove(&Type::AlwaysTruthy) => {
self.positive.insert(Type::string_literal(db, ""));
}
// `LiteralString & ~AlwaysFalsy` -> `LiteralString & ~Literal[""]`
Type::LiteralString if self.negative.swap_remove(&Type::AlwaysFalsy) => {
self.positive.insert(Type::LiteralString);
self.negative.insert(Type::string_literal(db, ""));
}

but this causes many property test failures, because we skip the simplifications done in the fallback branch of this function where redundant supertypes are removed if subtypes exist in the intersection, and where the entire intersection simplifies to Never if it contains a pair of disjoint types.

// `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F`
Type::Intersection(other) => {
for pos in other.positive(db) {
self.add_positive(db, *pos);
}
for neg in other.negative(db) {
self.add_negative(db, *neg);
}
self.positive.swap_remove_index(index);
break;
}

if addition_is_bool_instance {
for (index, existing_negative) in self.negative.iter().enumerate() {
match existing_negative {
// `bool & ~Literal[False]` -> `Literal[True]`
// `bool & ~Literal[True]` -> `Literal[False]`
Type::BooleanLiteral(bool_value) => {
new_positive = Type::BooleanLiteral(!bool_value);
_ => {
let addition_is_bool_instance = new_positive
.into_instance()
.and_then(|instance| instance.class.known(db))
.is_some_and(KnownClass::is_bool);

for (index, existing_positive) in self.positive.iter().enumerate() {
match existing_positive {
// `AlwaysTruthy & bool` -> `Literal[True]`
Type::AlwaysTruthy if addition_is_bool_instance => {
new_positive = Type::BooleanLiteral(true);
}
// `bool & ~AlwaysTruthy` -> `Literal[False]`
Type::AlwaysTruthy => {
// `AlwaysFalsy & bool` -> `Literal[False]`
Type::AlwaysFalsy if addition_is_bool_instance => {
new_positive = Type::BooleanLiteral(false);
}
// `bool & ~AlwaysFalsy` -> `Literal[True]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(true);
Type::Instance(InstanceType { class })
if class.is_known(db, KnownClass::Bool) =>
{
match new_positive {
// `bool & AlwaysTruthy` -> `Literal[True]`
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(true);
}
// `bool & AlwaysFalsy` -> `Literal[False]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(false);
}
_ => continue,
}
}
_ => continue,
}
self.negative.swap_remove_index(index);
self.positive.swap_remove_index(index);
break;
}
} else if new_positive.is_literal_string() {
if self.negative.swap_remove(&Type::AlwaysTruthy) {
new_positive = Type::string_literal(db, "");
}
}

let mut to_remove = SmallVec::<[usize; 1]>::new();
for (index, existing_positive) in self.positive.iter().enumerate() {
// S & T = S if S <: T
if existing_positive.is_subtype_of(db, new_positive)
|| existing_positive.is_same_gradual_form(new_positive)
{
return;
if addition_is_bool_instance {
for (index, existing_negative) in self.negative.iter().enumerate() {
match existing_negative {
// `bool & ~Literal[False]` -> `Literal[True]`
// `bool & ~Literal[True]` -> `Literal[False]`
Type::BooleanLiteral(bool_value) => {
new_positive = Type::BooleanLiteral(!bool_value);
}
// `bool & ~AlwaysTruthy` -> `Literal[False]`
Type::AlwaysTruthy => {
new_positive = Type::BooleanLiteral(false);
}
// `bool & ~AlwaysFalsy` -> `Literal[True]`
Type::AlwaysFalsy => {
new_positive = Type::BooleanLiteral(true);
}
_ => continue,
}
self.negative.swap_remove_index(index);
break;
}
}
// same rule, reverse order
if new_positive.is_subtype_of(db, *existing_positive) {
to_remove.push(index);

let mut to_remove = SmallVec::<[usize; 1]>::new();
for (index, existing_positive) in self.positive.iter().enumerate() {
// S & T = S if S <: T
if existing_positive.is_subtype_of(db, new_positive)
|| existing_positive.is_same_gradual_form(new_positive)
{
return;
}
// same rule, reverse order
if new_positive.is_subtype_of(db, *existing_positive) {
to_remove.push(index);
}
// A & B = Never if A and B are disjoint
if new_positive.is_disjoint_from(db, *existing_positive) {
*self = Self::default();
self.positive.insert(Type::Never);
return;
}
}
// A & B = Never if A and B are disjoint
if new_positive.is_disjoint_from(db, *existing_positive) {
*self = Self::default();
self.positive.insert(Type::Never);
return;
for index in to_remove.into_iter().rev() {
self.positive.swap_remove_index(index);
}
}
for index in to_remove.into_iter().rev() {
self.positive.swap_remove_index(index);
}

let mut to_remove = SmallVec::<[usize; 1]>::new();
for (index, existing_negative) in self.negative.iter().enumerate() {
// S & ~T = Never if S <: T
if new_positive.is_subtype_of(db, *existing_negative) {
*self = Self::default();
self.positive.insert(Type::Never);
return;
let mut to_remove = SmallVec::<[usize; 1]>::new();
for (index, existing_negative) in self.negative.iter().enumerate() {
// S & ~T = Never if S <: T
if new_positive.is_subtype_of(db, *existing_negative) {
*self = Self::default();
self.positive.insert(Type::Never);
return;
}
// A & ~B = A if A and B are disjoint
if existing_negative.is_disjoint_from(db, new_positive) {
to_remove.push(index);
}
}
// A & ~B = A if A and B are disjoint
if existing_negative.is_disjoint_from(db, new_positive) {
to_remove.push(index);
for index in to_remove.into_iter().rev() {
self.negative.swap_remove_index(index);
}
}
for index in to_remove.into_iter().rev() {
self.negative.swap_remove_index(index);
}

self.positive.insert(new_positive);
self.positive.insert(new_positive);
}
}
}

Expand Down Expand Up @@ -438,8 +451,8 @@ impl<'db> InnerIntersectionBuilder<'db> {
return;
}
}
for index in to_remove.iter().rev() {
self.negative.swap_remove_index(*index);
for index in to_remove.into_iter().rev() {
self.negative.swap_remove_index(index);
}

for existing_positive in &self.positive {
Expand Down
Loading