From 1408876abed0c71406a9705e7c6bf2404fedbe92 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 15 Jan 2025 12:12:43 +0000 Subject: [PATCH] fix edge cases and move `LiteralString` logic into the `match` at the top --- .../resources/mdtest/intersection_types.md | 10 +- .../src/types/builder.rs | 227 +++++++++--------- 2 files changed, 129 insertions(+), 108 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/intersection_types.md b/crates/red_knot_python_semantic/resources/mdtest/intersection_types.md index 659f8110b57a0..a807c10a5939c 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/intersection_types.md +++ b/crates/red_knot_python_semantic/resources/mdtest/intersection_types.md @@ -680,7 +680,7 @@ simplified, due to the fact that a `LiteralString` inhabitant is known to have ` exactly `str` (and not a subclass of `str`): ```py -from knot_extensions import Intersection, Not, AlwaysTruthy, AlwaysFalsy +from knot_extensions import Intersection, Not, AlwaysTruthy, AlwaysFalsy, Unknown from typing_extensions import LiteralString def f( @@ -690,6 +690,10 @@ def f( d: Intersection[LiteralString, Not[AlwaysFalsy]], e: Intersection[AlwaysFalsy, LiteralString], f: Intersection[Not[AlwaysTruthy], LiteralString], + g: Intersection[AlwaysTruthy, LiteralString], + h: Intersection[Not[AlwaysFalsy], LiteralString], + i: Intersection[Unknown, LiteralString, AlwaysFalsy], + j: Intersection[Not[AlwaysTruthy], Unknown, LiteralString], ): reveal_type(a) # revealed: LiteralString & ~Literal[""] reveal_type(b) # revealed: Literal[""] @@ -697,6 +701,10 @@ def f( reveal_type(d) # revealed: LiteralString & ~Literal[""] reveal_type(e) # revealed: Literal[""] reveal_type(f) # revealed: Literal[""] + reveal_type(g) # revealed: LiteralString & ~Literal[""] + reveal_type(h) # revealed: LiteralString & ~Literal[""] + reveal_type(i) # revealed: Unknown & Literal[""] + reveal_type(j) # revealed: Unknown & Literal[""] ``` ## Addition of a type to an intersection with many non-disjoint types diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 4be528787d2b9..67d5891b4a338 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -247,131 +247,144 @@ struct InnerIntersectionBuilder<'db> { impl<'db> InnerIntersectionBuilder<'db> { /// Adds a positive type to this intersection. fn add_positive(&mut self, db: &'db dyn Db, mut new_positive: Type<'db>) { - if new_positive == Type::AlwaysTruthy && self.positive.contains(&Type::LiteralString) { - self.add_negative(db, Type::string_literal(db, "")); - return; - } - - if let Type::Intersection(other) = new_positive { - for pos in other.positive(db) { - self.add_positive(db, *pos); + match new_positive { + // `LiteralString & AlwaysTruthy` -> `LiteralString & ~Literal[""]` + Type::AlwaysTruthy if self.positive.contains(&Type::LiteralString) => { + self.add_negative(db, Type::string_literal(db, "")); } - for neg in other.negative(db) { - self.add_negative(db, *neg); + // `LiteralString & AlwaysFalsy` -> `Literal[""]` + Type::AlwaysFalsy if self.positive.swap_remove(&Type::LiteralString) => { + self.add_positive(db, Type::string_literal(db, "")); } - } else { - let addition_is_bool_instance = new_positive - .into_instance() - .and_then(|instance| instance.class.known(db)) - .is_some_and(KnownClass::is_bool); - - for (index, existing_positive) in self.positive.iter().enumerate() { - match existing_positive { - // `AlwaysTruthy & bool` -> `Literal[True]` - Type::AlwaysTruthy if addition_is_bool_instance => { - new_positive = Type::BooleanLiteral(true); - } - // `AlwaysFalsy & bool` -> `Literal[False]` - Type::AlwaysFalsy if addition_is_bool_instance => { - new_positive = Type::BooleanLiteral(false); - } - // `AlwaysFalsy & LiteralString` -> `Literal[""]` - Type::AlwaysFalsy if new_positive.is_literal_string() => { - new_positive = Type::string_literal(db, ""); - } - Type::Instance(InstanceType { class }) - if class.is_known(db, KnownClass::Bool) => - { - match new_positive { - // `bool & AlwaysTruthy` -> `Literal[True]` - Type::AlwaysTruthy => { - new_positive = Type::BooleanLiteral(true); - } - // `bool & AlwaysFalsy` -> `Literal[False]` - Type::AlwaysFalsy => { - new_positive = Type::BooleanLiteral(false); - } - _ => continue, - } - } - // `LiteralString & AlwaysFalsy` -> `Literal[""]` - Type::LiteralString if new_positive == Type::AlwaysFalsy => { - new_positive = Type::string_literal(db, ""); - } - _ => continue, + // `AlwaysTruthy & LiteralString` -> `LiteralString & ~Literal[""]` + Type::LiteralString if self.positive.swap_remove(&Type::AlwaysTruthy) => { + self.add_positive(db, Type::LiteralString); + self.add_negative(db, Type::string_literal(db, "")); + } + // `AlwaysFalsy & LiteralString` -> `Literal[""]` + Type::LiteralString if self.positive.swap_remove(&Type::AlwaysFalsy) => { + self.add_positive(db, Type::string_literal(db, "")); + } + // `LiteralString & ~AlwaysTruthy` -> `LiteralString & AlwaysFalsy` -> `Literal[""]` + Type::LiteralString if self.negative.swap_remove(&Type::AlwaysTruthy) => { + self.add_positive(db, Type::string_literal(db, "")); + } + // `LiteralString & ~AlwaysFalsy` -> `LiteralString & ~Literal[""]` + Type::LiteralString if self.negative.swap_remove(&Type::AlwaysFalsy) => { + self.add_positive(db, Type::LiteralString); + self.add_negative(db, Type::string_literal(db, "")); + } + // `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F` + Type::Intersection(other) => { + for pos in other.positive(db) { + self.add_positive(db, *pos); + } + for neg in other.negative(db) { + self.add_negative(db, *neg); } - self.positive.swap_remove_index(index); - break; } - - if addition_is_bool_instance { - for (index, existing_negative) in self.negative.iter().enumerate() { - match existing_negative { - // `bool & ~Literal[False]` -> `Literal[True]` - // `bool & ~Literal[True]` -> `Literal[False]` - Type::BooleanLiteral(bool_value) => { - new_positive = Type::BooleanLiteral(!bool_value); + _ => { + let addition_is_bool_instance = new_positive + .into_instance() + .and_then(|instance| instance.class.known(db)) + .is_some_and(KnownClass::is_bool); + + for (index, existing_positive) in self.positive.iter().enumerate() { + match existing_positive { + // `AlwaysTruthy & bool` -> `Literal[True]` + Type::AlwaysTruthy if addition_is_bool_instance => { + new_positive = Type::BooleanLiteral(true); } - // `bool & ~AlwaysTruthy` -> `Literal[False]` - Type::AlwaysTruthy => { + // `AlwaysFalsy & bool` -> `Literal[False]` + Type::AlwaysFalsy if addition_is_bool_instance => { new_positive = Type::BooleanLiteral(false); } - // `bool & ~AlwaysFalsy` -> `Literal[True]` - Type::AlwaysFalsy => { - new_positive = Type::BooleanLiteral(true); + Type::Instance(InstanceType { class }) + if class.is_known(db, KnownClass::Bool) => + { + match new_positive { + // `bool & AlwaysTruthy` -> `Literal[True]` + Type::AlwaysTruthy => { + new_positive = Type::BooleanLiteral(true); + } + // `bool & AlwaysFalsy` -> `Literal[False]` + Type::AlwaysFalsy => { + new_positive = Type::BooleanLiteral(false); + } + _ => continue, + } } _ => continue, } - self.negative.swap_remove_index(index); + self.positive.swap_remove_index(index); break; } - } else if new_positive.is_literal_string() { - if self.negative.swap_remove(&Type::AlwaysTruthy) { - new_positive = Type::string_literal(db, ""); - } - } - let mut to_remove = SmallVec::<[usize; 1]>::new(); - for (index, existing_positive) in self.positive.iter().enumerate() { - // S & T = S if S <: T - if existing_positive.is_subtype_of(db, new_positive) - || existing_positive.is_same_gradual_form(new_positive) - { - return; + if addition_is_bool_instance { + for (index, existing_negative) in self.negative.iter().enumerate() { + match existing_negative { + // `bool & ~Literal[False]` -> `Literal[True]` + // `bool & ~Literal[True]` -> `Literal[False]` + Type::BooleanLiteral(bool_value) => { + new_positive = Type::BooleanLiteral(!bool_value); + } + // `bool & ~AlwaysTruthy` -> `Literal[False]` + Type::AlwaysTruthy => { + new_positive = Type::BooleanLiteral(false); + } + // `bool & ~AlwaysFalsy` -> `Literal[True]` + Type::AlwaysFalsy => { + new_positive = Type::BooleanLiteral(true); + } + _ => continue, + } + self.negative.swap_remove_index(index); + break; + } } - // same rule, reverse order - if new_positive.is_subtype_of(db, *existing_positive) { - to_remove.push(index); + + let mut to_remove = SmallVec::<[usize; 1]>::new(); + for (index, existing_positive) in self.positive.iter().enumerate() { + // S & T = S if S <: T + if existing_positive.is_subtype_of(db, new_positive) + || existing_positive.is_same_gradual_form(new_positive) + { + return; + } + // same rule, reverse order + if new_positive.is_subtype_of(db, *existing_positive) { + to_remove.push(index); + } + // A & B = Never if A and B are disjoint + if new_positive.is_disjoint_from(db, *existing_positive) { + *self = Self::default(); + self.positive.insert(Type::Never); + return; + } } - // A & B = Never if A and B are disjoint - if new_positive.is_disjoint_from(db, *existing_positive) { - *self = Self::default(); - self.positive.insert(Type::Never); - return; + for index in to_remove.into_iter().rev() { + self.positive.swap_remove_index(index); } - } - for index in to_remove.into_iter().rev() { - self.positive.swap_remove_index(index); - } - let mut to_remove = SmallVec::<[usize; 1]>::new(); - for (index, existing_negative) in self.negative.iter().enumerate() { - // S & ~T = Never if S <: T - if new_positive.is_subtype_of(db, *existing_negative) { - *self = Self::default(); - self.positive.insert(Type::Never); - return; + let mut to_remove = SmallVec::<[usize; 1]>::new(); + for (index, existing_negative) in self.negative.iter().enumerate() { + // S & ~T = Never if S <: T + if new_positive.is_subtype_of(db, *existing_negative) { + *self = Self::default(); + self.positive.insert(Type::Never); + return; + } + // A & ~B = A if A and B are disjoint + if existing_negative.is_disjoint_from(db, new_positive) { + to_remove.push(index); + } } - // A & ~B = A if A and B are disjoint - if existing_negative.is_disjoint_from(db, new_positive) { - to_remove.push(index); + for index in to_remove.into_iter().rev() { + self.negative.swap_remove_index(index); } - } - for index in to_remove.into_iter().rev() { - self.negative.swap_remove_index(index); - } - self.positive.insert(new_positive); + self.positive.insert(new_positive); + } } } @@ -438,8 +451,8 @@ impl<'db> InnerIntersectionBuilder<'db> { return; } } - for index in to_remove.iter().rev() { - self.negative.swap_remove_index(*index); + for index in to_remove.into_iter().rev() { + self.negative.swap_remove_index(index); } for existing_positive in &self.positive {