[ty] Detect invalid @total_ordering applications in non-decorator contexts (#22486)

## Summary

E.g., `ValidOrderedClass = total_ordering(HasOrderingMethod)`.
This commit is contained in:
Charlie Marsh
2026-01-09 19:37:58 -05:00
committed by GitHub
parent c88e1a0663
commit 11cc324449
5 changed files with 113 additions and 37 deletions

View File

@@ -244,3 +244,36 @@ reveal_type(n1 > n2) # revealed: bool
n1 <= n2 # error: [unsupported-operator]
n1 >= n2 # error: [unsupported-operator]
```
## Function call form
When `total_ordering` is called as a function (not as a decorator), the same validation is
performed:
```py
from functools import total_ordering
class NoOrderingMethod:
def __eq__(self, other: object) -> bool:
return True
# error: [invalid-total-ordering]
InvalidOrderedClass = total_ordering(NoOrderingMethod)
```
When the class does define an ordering method, no error is emitted:
```py
from functools import total_ordering
class HasOrderingMethod:
def __eq__(self, other: object) -> bool:
return True
def __lt__(self, other: "HasOrderingMethod") -> bool:
return True
# No error (class defines `__lt__`).
ValidOrderedClass = total_ordering(HasOrderingMethod)
reveal_type(ValidOrderedClass) # revealed: type[HasOrderingMethod]
```

View File

@@ -604,6 +604,13 @@ impl<'db> ClassType<'db> {
class_literal.is_final(db)
}
/// Returns `true` if any class in this class's MRO (excluding `object`) defines an ordering
/// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation.
pub(super) fn has_ordering_method_in_mro(self, db: &'db dyn Db) -> bool {
let (class_literal, specialization) = self.class_literal(db);
class_literal.has_ordering_method_in_mro(db, specialization)
}
/// Return `true` if `other` is present in this class's MRO.
pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: ClassType<'db>) -> bool {
self.when_subclass_of(db, other, InferableTypeVars::None)
@@ -1559,6 +1566,20 @@ impl<'db> ClassLiteral<'db> {
.any(|method| !class_member(db, body_scope, method).is_undefined())
}
/// Returns `true` if any class in this class's MRO (excluding `object`) defines an ordering
/// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation and
/// for synthesizing comparison methods.
pub(super) fn has_ordering_method_in_mro(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
) -> bool {
self.iter_mro(db, specialization)
.filter_map(ClassBase::into_class)
.filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object))
.any(|class| class.class_literal(db).0.has_own_ordering_method(db))
}
pub(crate) fn generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
// Several typeshed definitions examine `sys.version_info`. To break cycles, we hard-code
// the knowledge that this class is not generic.
@@ -2409,15 +2430,7 @@ impl<'db> ClassLiteral<'db> {
// __le__, __gt__, or __ge__ to be defined (either in this class or
// inherited from a superclass, excluding `object`).
if self.total_ordering(db) && matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") {
// Check if any class in the MRO (excluding object) defines at least one
// ordering method in its own body (not synthesized).
let has_ordering_method = self
.iter_mro(db, specialization)
.filter_map(super::class_base::ClassBase::into_class)
.filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object))
.any(|class| class.class_literal(db).0.has_own_ordering_method(db));
if has_ordering_method {
if self.has_ordering_method_in_mro(db, specialization) {
let instance_ty =
Type::instance(db, self.apply_optional_specialization(db, specialization));

View File

@@ -4684,6 +4684,29 @@ pub(super) fn report_invalid_total_ordering(
diagnostic.info("The decorator will raise `ValueError` at runtime");
}
/// Reports an invalid `total_ordering(cls)` function call where the class
/// does not define any ordering method.
pub(super) fn report_invalid_total_ordering_call(
context: &InferContext<'_, '_>,
class: ClassLiteral<'_>,
call_expression: &ast::ExprCall,
) {
let db = context.db();
let Some(builder) = context.report_lint(&INVALID_TOTAL_ORDERING, call_expression) else {
return;
};
let mut diagnostic = builder.into_diagnostic(
"`@functools.total_ordering` requires at least one ordering method (`__lt__`, `__le__`, `__gt__`, or `__ge__`) to be defined",
);
diagnostic.set_primary_message(format_args!(
"`{}` does not define `__lt__`, `__le__`, `__gt__`, or `__ge__`",
class.name(db)
));
diagnostic.info("The function will raise `ValueError` at runtime");
}
/// This function receives an unresolved `from foo import bar` import,
/// where `foo` can be resolved to a module but that module does not
/// have a `bar` member or submodule.

View File

@@ -70,6 +70,7 @@ use crate::types::context::InferContext;
use crate::types::diagnostic::{
INVALID_ARGUMENT_TYPE, REDUNDANT_CAST, STATIC_ASSERT_ERROR, TYPE_ASSERTION_FAILURE,
report_bad_argument_to_get_protocol_members, report_bad_argument_to_protocol_interface,
report_invalid_total_ordering_call,
report_runtime_check_against_non_runtime_checkable_protocol,
};
use crate::types::display::DisplaySettings;
@@ -2005,6 +2006,29 @@ impl KnownFunction {
overload.set_return_type(Type::module_literal(db, file, module));
}
KnownFunction::TotalOrdering => {
// When `total_ordering(cls)` is called as a function (not as a decorator),
// check that the class defines at least one ordering method.
let [Some(class_type)] = parameter_types else {
return;
};
let class = match class_type {
Type::ClassLiteral(class) => ClassType::NonGeneric(*class),
Type::GenericAlias(generic) => ClassType::Generic(*generic),
_ => return,
};
if !class.has_ordering_method_in_mro(db) {
report_invalid_total_ordering_call(
context,
class.class_literal(db).0,
call_expression,
);
}
}
_ => {}
}
}

View File

@@ -854,34 +854,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
// (5) Check that @total_ordering has a valid ordering method in the MRO
if class.total_ordering(self.db()) {
let has_ordering_method = class
.iter_mro(self.db(), None)
.filter_map(super::super::class_base::ClassBase::into_class)
.filter(|base_class| {
!base_class
.class_literal(self.db())
.0
.is_known(self.db(), KnownClass::Object)
})
.any(|base_class| {
base_class
.class_literal(self.db())
.0
.has_own_ordering_method(self.db())
});
if !has_ordering_method {
// Find the @total_ordering decorator to report the diagnostic at its location
if let Some(decorator) = class_node.decorator_list.iter().find(|decorator| {
self.expression_type(&decorator.expression)
.as_function_literal()
.is_some_and(|function| {
function.is_known(self.db(), KnownFunction::TotalOrdering)
})
}) {
report_invalid_total_ordering(&self.context, class, decorator);
}
if class.total_ordering(self.db()) && !class.has_ordering_method_in_mro(self.db(), None)
{
// Find the @total_ordering decorator to report the diagnostic at its location
if let Some(decorator) = class_node.decorator_list.iter().find(|decorator| {
self.expression_type(&decorator.expression)
.as_function_literal()
.is_some_and(|function| {
function.is_known(self.db(), KnownFunction::TotalOrdering)
})
}) {
report_invalid_total_ordering(&self.context, class, decorator);
}
}