[red-knot] Infer the members of a protocol class (#17556)

This commit is contained in:
Alex Waygood
2025-04-23 22:36:12 +01:00
committed by GitHub
parent 7b6222700b
commit 00e73dc331
4 changed files with 205 additions and 28 deletions

View File

@@ -437,6 +437,15 @@ impl<'db> UseDefMap<'db> {
.map(|symbol_id| (symbol_id, self.public_declarations(symbol_id)))
}
pub(crate) fn all_public_bindings<'map>(
&'map self,
) -> impl Iterator<Item = (ScopedSymbolId, BindingWithConstraintsIterator<'map, 'db>)> + 'map
{
(0..self.public_symbols.len())
.map(ScopedSymbolId::from_usize)
.map(|symbol_id| (symbol_id, self.public_bindings(symbol_id)))
}
/// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`.
pub(crate) fn can_implicit_return(&self, db: &dyn crate::Db) -> bool {
!self

View File

@@ -20,8 +20,8 @@ use crate::types::generics::{Specialization, SpecializationBuilder};
use crate::types::signatures::{Parameter, ParameterForm};
use crate::types::{
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, KnownClass,
KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType,
WrapperDescriptorKind,
KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, TupleType,
UnionType, WrapperDescriptorKind,
};
use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic};
use ruff_python_ast as ast;
@@ -561,6 +561,22 @@ impl<'db> Bindings<'db> {
}
}
Some(KnownFunction::GetProtocolMembers) => {
if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() {
if let Some(protocol_class) = class.into_protocol_class(db) {
// TODO: actually a frozenset at runtime (requires support for legacy generic classes)
overload.set_return_type(Type::Tuple(TupleType::new(
db,
protocol_class
.protocol_members(db)
.iter()
.map(|member| Type::string_literal(db, member))
.collect::<Box<[Type<'db>]>>(),
)));
}
}
}
Some(KnownFunction::Overload) => {
// TODO: This can be removed once we understand legacy generics because the
// typeshed definition for `typing.overload` is an identity function.

View File

@@ -1,4 +1,5 @@
use std::hash::BuildHasherDefault;
use std::ops::Deref;
use std::sync::{LazyLock, Mutex};
use super::{
@@ -13,6 +14,7 @@ use crate::types::signatures::{Parameter, Parameters};
use crate::types::{
CallableType, DataclassParams, DataclassTransformerParams, KnownInstanceType, Signature,
};
use crate::FxOrderSet;
use crate::{
module_resolver::file_to_module,
semantic_index::{
@@ -1710,6 +1712,11 @@ impl<'db> ClassLiteralType<'db> {
Some(InheritanceCycle::Inherited)
}
}
/// Returns `Some` if this is a protocol class, `None` otherwise.
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> {
self.is_protocol(db).then_some(ProtocolClassLiteral(self))
}
}
impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
@@ -1721,6 +1728,125 @@ impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
}
}
/// Representation of a single `Protocol` class definition.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) struct ProtocolClassLiteral<'db>(ClassLiteralType<'db>);
impl<'db> ProtocolClassLiteral<'db> {
/// Returns the protocol members of this class.
///
/// A protocol's members define the interface declared by the protocol.
/// They therefore determine how the protocol should behave with regards to
/// assignability and subtyping.
///
/// The list of members consists of all bindings and declarations that take place
/// in the protocol's class body, except for a list of excluded attributes which should
/// not be taken into account. (This list includes `__init__` and `__new__`, which can
/// legally be defined on protocol classes but do not constitute protocol members.)
///
/// It is illegal for a protocol class to have any instance attributes that are not declared
/// in the protocol's class body. If any are assigned to, they are not taken into account in
/// the protocol's list of members.
pub(super) fn protocol_members(self, db: &'db dyn Db) -> &'db ordermap::set::Slice<Name> {
/// The list of excluded members is subject to change between Python versions,
/// especially for dunders, but it probably doesn't matter *too* much if this
/// list goes out of date. It's up to date as of Python commit 87b1ea016b1454b1e83b9113fa9435849b7743aa
/// (<https://github.com/python/cpython/blob/87b1ea016b1454b1e83b9113fa9435849b7743aa/Lib/typing.py#L1776-L1791>)
fn excluded_from_proto_members(member: &str) -> bool {
matches!(
member,
"_is_protocol"
| "__non_callable_proto_members__"
| "__static_attributes__"
| "__orig_class__"
| "__match_args__"
| "__weakref__"
| "__doc__"
| "__parameters__"
| "__module__"
| "_MutableMapping__marker"
| "__slots__"
| "__dict__"
| "__new__"
| "__protocol_attrs__"
| "__init__"
| "__class_getitem__"
| "__firstlineno__"
| "__abstractmethods__"
| "__orig_bases__"
| "_is_runtime_protocol"
| "__subclasshook__"
| "__type_params__"
| "__annotations__"
| "__annotate__"
| "__annotate_func__"
| "__annotations_cache__"
)
}
#[salsa::tracked(return_ref)]
fn cached_protocol_members<'db>(
db: &'db dyn Db,
class: ClassLiteralType<'db>,
) -> Box<ordermap::set::Slice<Name>> {
let mut members = FxOrderSet::default();
for parent_protocol in class
.iter_mro(db, None)
.filter_map(ClassBase::into_class)
.filter_map(|class| class.class_literal(db).0.into_protocol_class(db))
{
let parent_scope = parent_protocol.body_scope(db);
let use_def_map = use_def_map(db, parent_scope);
let symbol_table = symbol_table(db, parent_scope);
members.extend(
use_def_map
.all_public_declarations()
.flat_map(|(symbol_id, declarations)| {
symbol_from_declarations(db, declarations)
.map(|symbol| (symbol_id, symbol))
})
.filter_map(|(symbol_id, symbol)| {
symbol.symbol.ignore_possibly_unbound().map(|_| symbol_id)
})
// Bindings in the class body that are not declared in the class body
// are not valid protocol members, and we plan to emit diagnostics for them
// elsewhere. Invalid or not, however, it's important that we still consider
// them to be protocol members. The implementation of `issubclass()` and
// `isinstance()` for runtime-checkable protocols considers them to be protocol
// members at runtime, and it's important that we accurately understand
// type narrowing that uses `isinstance()` or `issubclass()` with
// runtime-checkable protocols.
.chain(use_def_map.all_public_bindings().filter_map(
|(symbol_id, bindings)| {
symbol_from_bindings(db, bindings)
.ignore_possibly_unbound()
.map(|_| symbol_id)
},
))
.map(|symbol_id| symbol_table.symbol(symbol_id).name())
.filter(|name| !excluded_from_proto_members(name))
.cloned(),
);
}
members.sort();
members.into_boxed_slice()
}
cached_protocol_members(db, *self)
}
}
impl<'db> Deref for ProtocolClassLiteral<'db> {
type Target = ClassLiteralType<'db>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub(super) enum InheritanceCycle {
/// The class is cyclically defined and is a participant in the cycle.