[ty] Detect invalid @total_ordering applications in non-decorator contexts (#22486)
## Summary E.g., `ValidOrderedClass = total_ordering(HasOrderingMethod)`.
This commit is contained in:
@@ -244,3 +244,36 @@ reveal_type(n1 > n2) # revealed: bool
|
||||
n1 <= n2 # error: [unsupported-operator]
|
||||
n1 >= n2 # error: [unsupported-operator]
|
||||
```
|
||||
|
||||
## Function call form
|
||||
|
||||
When `total_ordering` is called as a function (not as a decorator), the same validation is
|
||||
performed:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
|
||||
class NoOrderingMethod:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return True
|
||||
|
||||
# error: [invalid-total-ordering]
|
||||
InvalidOrderedClass = total_ordering(NoOrderingMethod)
|
||||
```
|
||||
|
||||
When the class does define an ordering method, no error is emitted:
|
||||
|
||||
```py
|
||||
from functools import total_ordering
|
||||
|
||||
class HasOrderingMethod:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return True
|
||||
|
||||
def __lt__(self, other: "HasOrderingMethod") -> bool:
|
||||
return True
|
||||
|
||||
# No error (class defines `__lt__`).
|
||||
ValidOrderedClass = total_ordering(HasOrderingMethod)
|
||||
reveal_type(ValidOrderedClass) # revealed: type[HasOrderingMethod]
|
||||
```
|
||||
|
||||
@@ -604,6 +604,13 @@ impl<'db> ClassType<'db> {
|
||||
class_literal.is_final(db)
|
||||
}
|
||||
|
||||
/// Returns `true` if any class in this class's MRO (excluding `object`) defines an ordering
|
||||
/// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation.
|
||||
pub(super) fn has_ordering_method_in_mro(self, db: &'db dyn Db) -> bool {
|
||||
let (class_literal, specialization) = self.class_literal(db);
|
||||
class_literal.has_ordering_method_in_mro(db, specialization)
|
||||
}
|
||||
|
||||
/// Return `true` if `other` is present in this class's MRO.
|
||||
pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: ClassType<'db>) -> bool {
|
||||
self.when_subclass_of(db, other, InferableTypeVars::None)
|
||||
@@ -1559,6 +1566,20 @@ impl<'db> ClassLiteral<'db> {
|
||||
.any(|method| !class_member(db, body_scope, method).is_undefined())
|
||||
}
|
||||
|
||||
/// Returns `true` if any class in this class's MRO (excluding `object`) defines an ordering
|
||||
/// method (`__lt__`, `__le__`, `__gt__`, `__ge__`). Used by `@total_ordering` validation and
|
||||
/// for synthesizing comparison methods.
|
||||
pub(super) fn has_ordering_method_in_mro(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
specialization: Option<Specialization<'db>>,
|
||||
) -> bool {
|
||||
self.iter_mro(db, specialization)
|
||||
.filter_map(ClassBase::into_class)
|
||||
.filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object))
|
||||
.any(|class| class.class_literal(db).0.has_own_ordering_method(db))
|
||||
}
|
||||
|
||||
pub(crate) fn generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
|
||||
// Several typeshed definitions examine `sys.version_info`. To break cycles, we hard-code
|
||||
// the knowledge that this class is not generic.
|
||||
@@ -2409,15 +2430,7 @@ impl<'db> ClassLiteral<'db> {
|
||||
// __le__, __gt__, or __ge__ to be defined (either in this class or
|
||||
// inherited from a superclass, excluding `object`).
|
||||
if self.total_ordering(db) && matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") {
|
||||
// Check if any class in the MRO (excluding object) defines at least one
|
||||
// ordering method in its own body (not synthesized).
|
||||
let has_ordering_method = self
|
||||
.iter_mro(db, specialization)
|
||||
.filter_map(super::class_base::ClassBase::into_class)
|
||||
.filter(|class| !class.class_literal(db).0.is_known(db, KnownClass::Object))
|
||||
.any(|class| class.class_literal(db).0.has_own_ordering_method(db));
|
||||
|
||||
if has_ordering_method {
|
||||
if self.has_ordering_method_in_mro(db, specialization) {
|
||||
let instance_ty =
|
||||
Type::instance(db, self.apply_optional_specialization(db, specialization));
|
||||
|
||||
|
||||
@@ -4684,6 +4684,29 @@ pub(super) fn report_invalid_total_ordering(
|
||||
diagnostic.info("The decorator will raise `ValueError` at runtime");
|
||||
}
|
||||
|
||||
/// Reports an invalid `total_ordering(cls)` function call where the class
|
||||
/// does not define any ordering method.
|
||||
pub(super) fn report_invalid_total_ordering_call(
|
||||
context: &InferContext<'_, '_>,
|
||||
class: ClassLiteral<'_>,
|
||||
call_expression: &ast::ExprCall,
|
||||
) {
|
||||
let db = context.db();
|
||||
|
||||
let Some(builder) = context.report_lint(&INVALID_TOTAL_ORDERING, call_expression) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut diagnostic = builder.into_diagnostic(
|
||||
"`@functools.total_ordering` requires at least one ordering method (`__lt__`, `__le__`, `__gt__`, or `__ge__`) to be defined",
|
||||
);
|
||||
diagnostic.set_primary_message(format_args!(
|
||||
"`{}` does not define `__lt__`, `__le__`, `__gt__`, or `__ge__`",
|
||||
class.name(db)
|
||||
));
|
||||
diagnostic.info("The function will raise `ValueError` at runtime");
|
||||
}
|
||||
|
||||
/// This function receives an unresolved `from foo import bar` import,
|
||||
/// where `foo` can be resolved to a module but that module does not
|
||||
/// have a `bar` member or submodule.
|
||||
|
||||
@@ -70,6 +70,7 @@ use crate::types::context::InferContext;
|
||||
use crate::types::diagnostic::{
|
||||
INVALID_ARGUMENT_TYPE, REDUNDANT_CAST, STATIC_ASSERT_ERROR, TYPE_ASSERTION_FAILURE,
|
||||
report_bad_argument_to_get_protocol_members, report_bad_argument_to_protocol_interface,
|
||||
report_invalid_total_ordering_call,
|
||||
report_runtime_check_against_non_runtime_checkable_protocol,
|
||||
};
|
||||
use crate::types::display::DisplaySettings;
|
||||
@@ -2005,6 +2006,29 @@ impl KnownFunction {
|
||||
|
||||
overload.set_return_type(Type::module_literal(db, file, module));
|
||||
}
|
||||
|
||||
KnownFunction::TotalOrdering => {
|
||||
// When `total_ordering(cls)` is called as a function (not as a decorator),
|
||||
// check that the class defines at least one ordering method.
|
||||
let [Some(class_type)] = parameter_types else {
|
||||
return;
|
||||
};
|
||||
|
||||
let class = match class_type {
|
||||
Type::ClassLiteral(class) => ClassType::NonGeneric(*class),
|
||||
Type::GenericAlias(generic) => ClassType::Generic(*generic),
|
||||
_ => return,
|
||||
};
|
||||
|
||||
if !class.has_ordering_method_in_mro(db) {
|
||||
report_invalid_total_ordering_call(
|
||||
context,
|
||||
class.class_literal(db).0,
|
||||
call_expression,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -854,34 +854,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
}
|
||||
|
||||
// (5) Check that @total_ordering has a valid ordering method in the MRO
|
||||
if class.total_ordering(self.db()) {
|
||||
let has_ordering_method = class
|
||||
.iter_mro(self.db(), None)
|
||||
.filter_map(super::super::class_base::ClassBase::into_class)
|
||||
.filter(|base_class| {
|
||||
!base_class
|
||||
.class_literal(self.db())
|
||||
.0
|
||||
.is_known(self.db(), KnownClass::Object)
|
||||
})
|
||||
.any(|base_class| {
|
||||
base_class
|
||||
.class_literal(self.db())
|
||||
.0
|
||||
.has_own_ordering_method(self.db())
|
||||
});
|
||||
|
||||
if !has_ordering_method {
|
||||
// Find the @total_ordering decorator to report the diagnostic at its location
|
||||
if let Some(decorator) = class_node.decorator_list.iter().find(|decorator| {
|
||||
self.expression_type(&decorator.expression)
|
||||
.as_function_literal()
|
||||
.is_some_and(|function| {
|
||||
function.is_known(self.db(), KnownFunction::TotalOrdering)
|
||||
})
|
||||
}) {
|
||||
report_invalid_total_ordering(&self.context, class, decorator);
|
||||
}
|
||||
if class.total_ordering(self.db()) && !class.has_ordering_method_in_mro(self.db(), None)
|
||||
{
|
||||
// Find the @total_ordering decorator to report the diagnostic at its location
|
||||
if let Some(decorator) = class_node.decorator_list.iter().find(|decorator| {
|
||||
self.expression_type(&decorator.expression)
|
||||
.as_function_literal()
|
||||
.is_some_and(|function| {
|
||||
function.is_known(self.db(), KnownFunction::TotalOrdering)
|
||||
})
|
||||
}) {
|
||||
report_invalid_total_ordering(&self.context, class, decorator);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user