diff --git a/crates/red_knot_python_semantic/resources/mdtest/mro.md b/crates/red_knot_python_semantic/resources/mdtest/mro.md index 4ca16760ca..dfeb30b5eb 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/mro.md +++ b/crates/red_knot_python_semantic/resources/mdtest/mro.md @@ -396,11 +396,10 @@ class Foo: ... class BarCycle(FooCycle): ... # error: [cyclic-class-definition] class Bar(Foo): ... -# TODO: can we avoid emitting the errors for these? -# The classes have cyclic superclasses, +# Avoid emitting the errors for these. The classes have cyclic superclasses, # but are not themselves cyclic... -class Baz(Bar, BarCycle): ... # error: [cyclic-class-definition] -class Spam(Baz): ... # error: [cyclic-class-definition] +class Baz(Bar, BarCycle): ... +class Spam(Baz): ... reveal_type(FooCycle.__mro__) # revealed: tuple[Literal[FooCycle], Unknown, Literal[object]] reveal_type(BarCycle.__mro__) # revealed: tuple[Literal[BarCycle], Unknown, Literal[object]] diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 49747b1056..94ff773413 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -3729,6 +3729,22 @@ pub struct Class<'db> { known: Option, } +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +enum InheritanceCycle { + /// The class is cyclically defined and is a participant in the cycle. + /// i.e., it inherits either directly or indirectly from itself. + Participant, + /// The class inherits from a class that is a `Participant` in an inheritance cycle, + /// but is not itself a participant. + Inherited, +} + +impl InheritanceCycle { + const fn is_participant(self) -> bool { + matches!(self, InheritanceCycle::Participant) + } +} + #[salsa::tracked] impl<'db> Class<'db> { /// Return `true` if this class represents `known_class` @@ -3886,7 +3902,7 @@ impl<'db> Class<'db> { // Identify the class's own metaclass (or take the first base class's metaclass). let mut base_classes = self.fully_static_explicit_bases(db).peekable(); - if base_classes.peek().is_some() && self.is_cyclically_defined(db) { + if base_classes.peek().is_some() && self.inheritance_cycle(db).is_some() { // We emit diagnostics for cyclic class definitions elsewhere. // Avoid attempting to infer the metaclass if the class is cyclically defined: // it would be easy to enter an infinite loop. @@ -4122,37 +4138,50 @@ impl<'db> Class<'db> { } } - /// Return `true` if this class appears to be a cyclic definition, - /// i.e., it inherits either directly or indirectly from itself. + /// Return this class' involvement in an inheritance cycle, if any. /// /// A class definition like this will fail at runtime, /// but we must be resilient to it or we could panic. #[salsa::tracked] - fn is_cyclically_defined(self, db: &'db dyn Db) -> bool { + fn inheritance_cycle(self, db: &'db dyn Db) -> Option { + /// Return `true` if the class is cyclically defined. + /// + /// Also, populates `visited_classes` with all base classes of `self`. fn is_cyclically_defined_recursive<'db>( db: &'db dyn Db, class: Class<'db>, - classes_to_watch: &mut IndexSet>, + classes_on_stack: &mut IndexSet>, + visited_classes: &mut IndexSet>, ) -> bool { - if !classes_to_watch.insert(class) { - return true; - } + let mut result = false; for explicit_base_class in class.fully_static_explicit_bases(db) { - // Each base must be considered in isolation. - // This is due to the fact that if a class uses multiple inheritance, - // there could easily be a situation where two bases have the same class in their MROs; - // that isn't enough to constitute the class being cyclically defined. - let classes_to_watch_len = classes_to_watch.len(); - if is_cyclically_defined_recursive(db, explicit_base_class, classes_to_watch) { + if !classes_on_stack.insert(explicit_base_class) { return true; } - classes_to_watch.truncate(classes_to_watch_len); + + if visited_classes.insert(explicit_base_class) { + // If we find a cycle, keep searching to check if we can reach the starting class. + result |= is_cyclically_defined_recursive( + db, + explicit_base_class, + classes_on_stack, + visited_classes, + ); + } + + classes_on_stack.pop(); } - false + result } - self.fully_static_explicit_bases(db) - .any(|base_class| is_cyclically_defined_recursive(db, base_class, &mut IndexSet::new())) + let visited_classes = &mut IndexSet::new(); + if !is_cyclically_defined_recursive(db, self, &mut IndexSet::new(), visited_classes) { + None + } else if visited_classes.contains(&self) { + Some(InheritanceCycle::Participant) + } else { + Some(InheritanceCycle::Inherited) + } } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index e34790f111..576cde3832 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -576,16 +576,17 @@ impl<'db> TypeInferenceBuilder<'db> { // Iterate through all class definitions in this scope. for (class, class_node) in class_definitions { // (1) Check that the class does not have a cyclic definition - if class.is_cyclically_defined(self.db()) { - self.context.report_lint( - &CYCLIC_CLASS_DEFINITION, - class_node.into(), - format_args!( - "Cyclic definition of `{}` or bases of `{}` (class cannot inherit from itself)", - class.name(self.db()), - class.name(self.db()) - ), - ); + if let Some(inheritance_cycle) = class.inheritance_cycle(self.db()) { + if inheritance_cycle.is_participant() { + self.context.report_lint( + &CYCLIC_CLASS_DEFINITION, + class_node.into(), + format_args!( + "Cyclic definition of `{}` (class cannot inherit from itself)", + class.name(self.db()) + ), + ); + } // Attempting to determine the MRO of a class or if the class has a metaclass conflict // is impossible if the class is cyclically defined; there's nothing more to do here. continue; diff --git a/crates/red_knot_python_semantic/src/types/mro.rs b/crates/red_knot_python_semantic/src/types/mro.rs index a20c591656..9469da76f8 100644 --- a/crates/red_knot_python_semantic/src/types/mro.rs +++ b/crates/red_knot_python_semantic/src/types/mro.rs @@ -42,7 +42,7 @@ impl<'db> Mro<'db> { fn of_class_impl(db: &'db dyn Db, class: Class<'db>) -> Result> { let class_bases = class.explicit_bases(db); - if !class_bases.is_empty() && class.is_cyclically_defined(db) { + if !class_bases.is_empty() && class.inheritance_cycle(db).is_some() { // We emit errors for cyclically defined classes elsewhere. // It's important that we don't even try to infer the MRO for a cyclically defined class, // or we'll end up in an infinite loop.