From d2a238dfad2bd5d3572a7a8a5b11fdf91d22b8f8 Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Thu, 1 May 2025 01:33:21 +0530 Subject: [PATCH] [red-knot] Update call binding to return all matching overloads (#17618) ## Summary This PR updates the existing overload matching methods to return an iterator of all the matched overloads instead. This would be useful once the overload call evaluation algorithm is implemented which should provide an accurate picture of all the matched overloads. The return type would then be picked from either the only matched overload or the first overload from the ones that are matched. In an earlier version of this PR, it tried to check if using an intersection of return types from the matched overload would help reduce the false positives but that's not enough. [This comment](https://github.com/astral-sh/ruff/pull/17618#issuecomment-2842891696) keep the ecosystem analysis for that change for prosperity. > [!NOTE] > > The best way to review this PR is by hiding the whitespace changes because there are two instances where a large match expression is indented to be inside a loop over matching overlods > > Screenshot 2025-04-28 at 15 12 16 ## Test Plan Make sure existing test cases are unaffected and no ecosystem changes. --- crates/red_knot_python_semantic/src/types.rs | 42 +- .../src/types/call/bind.rs | 1126 +++++++++-------- .../src/types/infer.rs | 593 ++++----- 3 files changed, 903 insertions(+), 858 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index f4d12a06f5..a942a4acd9 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -4527,27 +4527,43 @@ impl<'db> Type<'db> { new_call_outcome @ (None | Some(Ok(_))), init_call_outcome @ (None | Some(Ok(_))), ) => { + fn combine_specializations<'db>( + db: &'db dyn Db, + s1: Option>, + s2: Option>, + ) -> Option> { + match (s1, s2) { + (None, None) => None, + (Some(s), None) | (None, Some(s)) => Some(s), + (Some(s1), Some(s2)) => Some(s1.combine(db, s2)), + } + } + + fn combine_binding_specialization<'db>( + db: &'db dyn Db, + binding: &CallableBinding<'db>, + ) -> Option> { + binding + .matching_overloads() + .map(|(_, binding)| binding.inherited_specialization()) + .reduce(|acc, specialization| { + combine_specializations(db, acc, specialization) + }) + .flatten() + } + let new_specialization = new_call_outcome .and_then(Result::ok) .as_ref() .and_then(Bindings::single_element) - .and_then(CallableBinding::matching_overload) - .and_then(|(_, binding)| binding.inherited_specialization()); + .and_then(|binding| combine_binding_specialization(db, binding)); let init_specialization = init_call_outcome .and_then(Result::ok) .as_ref() .and_then(Bindings::single_element) - .and_then(CallableBinding::matching_overload) - .and_then(|(_, binding)| binding.inherited_specialization()); - let specialization = match (new_specialization, init_specialization) { - (None, None) => None, - (Some(specialization), None) | (None, Some(specialization)) => { - Some(specialization) - } - (Some(new_specialization), Some(init_specialization)) => { - Some(new_specialization.combine(db, init_specialization)) - } - }; + .and_then(|binding| combine_binding_specialization(db, binding)); + let specialization = + combine_specializations(db, new_specialization, init_specialization); let specialized = specialization .map(|specialization| { Type::instance( diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index 09bf3b0d34..3bac167755 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -221,587 +221,605 @@ impl<'db> Bindings<'db> { // Each special case listed here should have a corresponding clause in `Type::signatures`. for (binding, callable_signature) in self.elements.iter_mut().zip(self.signatures.iter()) { let binding_type = binding.callable_type; - let Some((overload_index, overload)) = binding.matching_overload_mut() else { - continue; - }; - - match binding_type { - Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => { - if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) { - match overload.parameter_types() { - [_, Some(owner)] => { - overload.set_return_type(Type::BoundMethod(BoundMethodType::new( - db, function, *owner, - ))); - } - [Some(instance), None] => { - overload.set_return_type(Type::BoundMethod(BoundMethodType::new( - db, - function, - instance.to_meta_type(db), - ))); - } - _ => {} - } - } else if let [Some(first), _] = overload.parameter_types() { - if first.is_none(db) { - overload.set_return_type(Type::FunctionLiteral(function)); - } else { - overload.set_return_type(Type::BoundMethod(BoundMethodType::new( - db, function, *first, - ))); - } - } - } - - Type::WrapperDescriptor(WrapperDescriptorKind::FunctionTypeDunderGet) => { - if let [Some(function_ty @ Type::FunctionLiteral(function)), ..] = - overload.parameter_types() - { + for (overload_index, overload) in binding.matching_overloads_mut() { + match binding_type { + Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => { if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) { match overload.parameter_types() { - [_, _, Some(owner)] => { + [_, Some(owner)] => { overload.set_return_type(Type::BoundMethod( - BoundMethodType::new(db, *function, *owner), + BoundMethodType::new(db, function, *owner), )); } - - [_, Some(instance), None] => { + [Some(instance), None] => { overload.set_return_type(Type::BoundMethod( BoundMethodType::new( db, - *function, + function, instance.to_meta_type(db), ), )); } - _ => {} } - } else { - match overload.parameter_types() { - [_, Some(instance), _] if instance.is_none(db) => { - overload.set_return_type(*function_ty); - } - [_, Some(instance), _] => { - overload.set_return_type(Type::BoundMethod( - BoundMethodType::new(db, *function, *instance), - )); - } - - _ => {} - } - } - } - } - - Type::WrapperDescriptor(WrapperDescriptorKind::PropertyDunderGet) => match overload - .parameter_types() - { - [Some(property @ Type::PropertyInstance(_)), Some(instance), ..] - if instance.is_none(db) => - { - overload.set_return_type(*property); - } - [Some(Type::PropertyInstance(property)), Some(Type::KnownInstance(KnownInstanceType::TypeAliasType(type_alias))), ..] - if property.getter(db).is_some_and(|getter| { - getter - .into_function_literal() - .is_some_and(|f| f.name(db) == "__name__") - }) => - { - overload.set_return_type(Type::string_literal(db, type_alias.name(db))); - } - [Some(Type::PropertyInstance(property)), Some(Type::KnownInstance(KnownInstanceType::TypeVar(typevar))), ..] => { - match property - .getter(db) - .and_then(Type::into_function_literal) - .map(|f| f.name(db).as_str()) - { - Some("__name__") => { - overload - .set_return_type(Type::string_literal(db, typevar.name(db))); - } - Some("__bound__") => { - overload.set_return_type( - typevar.upper_bound(db).unwrap_or_else(|| Type::none(db)), - ); - } - Some("__constraints__") => { - overload.set_return_type(TupleType::from_elements( - db, - typevar.constraints(db).into_iter().flatten(), - )); - } - Some("__default__") => { - overload.set_return_type( - typevar.default_ty(db).unwrap_or_else(|| { - KnownClass::NoDefaultType.to_instance(db) - }), - ); - } - _ => {} - } - } - [Some(Type::PropertyInstance(property)), Some(instance), ..] => { - if let Some(getter) = property.getter(db) { - if let Ok(return_ty) = getter - .try_call(db, &mut CallArgumentTypes::positional([*instance])) - .map(|binding| binding.return_type(db)) - { - overload.set_return_type(return_ty); + } else if let [Some(first), _] = overload.parameter_types() { + if first.is_none(db) { + overload.set_return_type(Type::FunctionLiteral(function)); } else { - overload.errors.push(BindingError::InternalCallError( - "calling the getter failed", - )); - overload.set_return_type(Type::unknown()); - } - } else { - overload - .errors - .push(BindingError::InternalCallError("property has no getter")); - overload.set_return_type(Type::Never); - } - } - _ => {} - }, - - Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(property)) => { - match overload.parameter_types() { - [Some(instance), ..] if instance.is_none(db) => { - overload.set_return_type(Type::PropertyInstance(property)); - } - [Some(instance), ..] => { - if let Some(getter) = property.getter(db) { - if let Ok(return_ty) = getter - .try_call(db, &mut CallArgumentTypes::positional([*instance])) - .map(|binding| binding.return_type(db)) - { - overload.set_return_type(return_ty); - } else { - overload.errors.push(BindingError::InternalCallError( - "calling the getter failed", - )); - overload.set_return_type(Type::unknown()); - } - } else { - overload.set_return_type(Type::Never); - overload.errors.push(BindingError::InternalCallError( - "property has no getter", - )); - } - } - _ => {} - } - } - - Type::WrapperDescriptor(WrapperDescriptorKind::PropertyDunderSet) => { - if let [Some(Type::PropertyInstance(property)), Some(instance), Some(value), ..] = - overload.parameter_types() - { - if let Some(setter) = property.setter(db) { - if let Err(_call_error) = setter.try_call( - db, - &mut CallArgumentTypes::positional([*instance, *value]), - ) { - overload.errors.push(BindingError::InternalCallError( - "calling the setter failed", - )); - } - } else { - overload - .errors - .push(BindingError::InternalCallError("property has no setter")); - } - } - } - - Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet(property)) => { - if let [Some(instance), Some(value), ..] = overload.parameter_types() { - if let Some(setter) = property.setter(db) { - if let Err(_call_error) = setter.try_call( - db, - &mut CallArgumentTypes::positional([*instance, *value]), - ) { - overload.errors.push(BindingError::InternalCallError( - "calling the setter failed", - )); - } - } else { - overload - .errors - .push(BindingError::InternalCallError("property has no setter")); - } - } - } - - Type::MethodWrapper(MethodWrapperKind::StrStartswith(literal)) => { - if let [Some(Type::StringLiteral(prefix)), None, None] = - overload.parameter_types() - { - overload.set_return_type(Type::BooleanLiteral( - literal.value(db).starts_with(&**prefix.value(db)), - )); - } - } - - Type::DataclassTransformer(params) => { - if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types() { - overload.set_return_type(Type::FunctionLiteral( - function.with_dataclass_transformer_params(db, params), - )); - } - } - - Type::BoundMethod(bound_method) - if bound_method.self_instance(db).is_property_instance() => - { - match bound_method.function(db).name(db).as_str() { - "setter" => { - if let [Some(_), Some(setter)] = overload.parameter_types() { - let mut ty_property = bound_method.self_instance(db); - if let Type::PropertyInstance(property) = ty_property { - ty_property = - Type::PropertyInstance(PropertyInstanceType::new( - db, - property.getter(db), - Some(*setter), - )); - } - overload.set_return_type(ty_property); - } - } - "getter" => { - if let [Some(_), Some(getter)] = overload.parameter_types() { - let mut ty_property = bound_method.self_instance(db); - if let Type::PropertyInstance(property) = ty_property { - ty_property = - Type::PropertyInstance(PropertyInstanceType::new( - db, - Some(*getter), - property.setter(db), - )); - } - overload.set_return_type(ty_property); - } - } - "deleter" => { - // TODO: we do not store deleters yet - let ty_property = bound_method.self_instance(db); - overload.set_return_type(ty_property); - } - _ => { - // Fall back to typeshed stubs for all other methods - } - } - } - - Type::FunctionLiteral(function_type) => match function_type.known(db) { - Some(KnownFunction::IsEquivalentTo) => { - if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { - overload.set_return_type(Type::BooleanLiteral( - ty_a.is_equivalent_to(db, *ty_b), - )); - } - } - - Some(KnownFunction::IsSubtypeOf) => { - if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { - overload.set_return_type(Type::BooleanLiteral( - ty_a.is_subtype_of(db, *ty_b), - )); - } - } - - Some(KnownFunction::IsAssignableTo) => { - if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { - overload.set_return_type(Type::BooleanLiteral( - ty_a.is_assignable_to(db, *ty_b), - )); - } - } - - Some(KnownFunction::IsDisjointFrom) => { - if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { - overload.set_return_type(Type::BooleanLiteral( - ty_a.is_disjoint_from(db, *ty_b), - )); - } - } - - Some(KnownFunction::IsGradualEquivalentTo) => { - if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { - overload.set_return_type(Type::BooleanLiteral( - ty_a.is_gradual_equivalent_to(db, *ty_b), - )); - } - } - - Some(KnownFunction::IsFullyStatic) => { - if let [Some(ty)] = overload.parameter_types() { - overload.set_return_type(Type::BooleanLiteral(ty.is_fully_static(db))); - } - } - - Some(KnownFunction::IsSingleton) => { - if let [Some(ty)] = overload.parameter_types() { - overload.set_return_type(Type::BooleanLiteral(ty.is_singleton(db))); - } - } - - Some(KnownFunction::IsSingleValued) => { - if let [Some(ty)] = overload.parameter_types() { - overload.set_return_type(Type::BooleanLiteral(ty.is_single_valued(db))); - } - } - - Some(KnownFunction::Len) => { - if let [Some(first_arg)] = overload.parameter_types() { - if let Some(len_ty) = first_arg.len(db) { - overload.set_return_type(len_ty); - } - } - } - - Some(KnownFunction::Repr) => { - if let [Some(first_arg)] = overload.parameter_types() { - overload.set_return_type(first_arg.repr(db)); - } - } - - Some(KnownFunction::Cast) => { - if let [Some(casted_ty), Some(_)] = overload.parameter_types() { - overload.set_return_type(*casted_ty); - } - } - - Some(KnownFunction::IsProtocol) => { - if let [Some(ty)] = overload.parameter_types() { - overload.set_return_type(Type::BooleanLiteral( - ty.into_class_literal() - .is_some_and(|class| class.is_protocol(db)), - )); - } - } - - Some(KnownFunction::GetProtocolMembers) => { - if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() { - if let Some(protocol_class) = class.into_protocol_class(db) { - // TODO: actually a frozenset at runtime (requires support for legacy generic classes) - overload.set_return_type(Type::Tuple(TupleType::new( - db, - protocol_class - .protocol_members(db) - .iter() - .map(|member| Type::string_literal(db, member)) - .collect::]>>(), + overload.set_return_type(Type::BoundMethod(BoundMethodType::new( + db, function, *first, ))); } } } - Some(KnownFunction::Overload) => { - // TODO: This can be removed once we understand legacy generics because the - // typeshed definition for `typing.overload` is an identity function. - if let [Some(ty)] = overload.parameter_types() { - overload.set_return_type(*ty); - } - } - - Some(KnownFunction::Override) => { - // TODO: This can be removed once we understand legacy generics because the - // typeshed definition for `typing.overload` is an identity function. - if let [Some(ty)] = overload.parameter_types() { - overload.set_return_type(*ty); - } - } - - Some(KnownFunction::AbstractMethod) => { - // TODO: This can be removed once we understand legacy generics because the - // typeshed definition for `abc.abstractmethod` is an identity function. - if let [Some(ty)] = overload.parameter_types() { - overload.set_return_type(*ty); - } - } - - Some(KnownFunction::Final) => { - // TODO: This can be removed once we understand legacy generics because the - // typeshed definition for `abc.abstractmethod` is an identity function. - if let [Some(ty)] = overload.parameter_types() { - overload.set_return_type(*ty); - } - } - - Some(KnownFunction::GetattrStatic) => { - let [Some(instance_ty), Some(attr_name), default] = + Type::WrapperDescriptor(WrapperDescriptorKind::FunctionTypeDunderGet) => { + if let [Some(function_ty @ Type::FunctionLiteral(function)), ..] = overload.parameter_types() - else { - continue; - }; - - let Some(attr_name) = attr_name.into_string_literal() else { - continue; - }; - - let default = if let Some(default) = default { - *default - } else { - Type::Never - }; - - let union_with_default = |ty| UnionType::from_elements(db, [ty, default]); - - // TODO: we could emit a diagnostic here (if default is not set) - overload.set_return_type( - match instance_ty.static_member(db, attr_name.value(db)) { - Symbol::Type(ty, Boundness::Bound) => { - if instance_ty.is_fully_static(db) { - ty - } else { - // Here, we attempt to model the fact that an attribute lookup on - // a non-fully static type could fail. This is an approximation, - // as there are gradual types like `tuple[Any]`, on which a lookup - // of (e.g. of the `index` method) would always succeed. - - union_with_default(ty) + { + if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) { + match overload.parameter_types() { + [_, _, Some(owner)] => { + overload.set_return_type(Type::BoundMethod( + BoundMethodType::new(db, *function, *owner), + )); } + + [_, Some(instance), None] => { + overload.set_return_type(Type::BoundMethod( + BoundMethodType::new( + db, + *function, + instance.to_meta_type(db), + ), + )); + } + + _ => {} } - Symbol::Type(ty, Boundness::PossiblyUnbound) => { - union_with_default(ty) + } else { + match overload.parameter_types() { + [_, Some(instance), _] if instance.is_none(db) => { + overload.set_return_type(*function_ty); + } + [_, Some(instance), _] => { + overload.set_return_type(Type::BoundMethod( + BoundMethodType::new(db, *function, *instance), + )); + } + + _ => {} } - Symbol::Unbound => default, - }, - ); - } - - Some(KnownFunction::Dataclass) => { - if let [init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots, weakref_slot] = - overload.parameter_types() - { - let mut params = DataclassParams::empty(); - - if to_bool(init, true) { - params |= DataclassParams::INIT; } - if to_bool(repr, true) { - params |= DataclassParams::REPR; - } - if to_bool(eq, true) { - params |= DataclassParams::EQ; - } - if to_bool(order, false) { - params |= DataclassParams::ORDER; - } - if to_bool(unsafe_hash, false) { - params |= DataclassParams::UNSAFE_HASH; - } - if to_bool(frozen, false) { - params |= DataclassParams::FROZEN; - } - if to_bool(match_args, true) { - params |= DataclassParams::MATCH_ARGS; - } - if to_bool(kw_only, false) { - params |= DataclassParams::KW_ONLY; - } - if to_bool(slots, false) { - params |= DataclassParams::SLOTS; - } - if to_bool(weakref_slot, false) { - params |= DataclassParams::WEAKREF_SLOT; - } - - overload.set_return_type(Type::DataclassDecorator(params)); } } - Some(KnownFunction::DataclassTransform) => { - if let [eq_default, order_default, kw_only_default, frozen_default, _field_specifiers, _kwargs] = - overload.parameter_types() - { - let mut params = DataclassTransformerParams::empty(); - - if to_bool(eq_default, true) { - params |= DataclassTransformerParams::EQ_DEFAULT; - } - if to_bool(order_default, false) { - params |= DataclassTransformerParams::ORDER_DEFAULT; - } - if to_bool(kw_only_default, false) { - params |= DataclassTransformerParams::KW_ONLY_DEFAULT; - } - if to_bool(frozen_default, false) { - params |= DataclassTransformerParams::FROZEN_DEFAULT; - } - - overload.set_return_type(Type::DataclassTransformer(params)); - } - } - - _ => { - if let Some(params) = function_type.dataclass_transformer_params(db) { - // This is a call to a custom function that was decorated with `@dataclass_transformer`. - // If this function was called with a keyword argument like `order=False`, we extract - // the argument type and overwrite the corresponding flag in `dataclass_params` after - // constructing them from the `dataclass_transformer`-parameter defaults. - - let mut dataclass_params = DataclassParams::from(params); - - if let Some(Some(Type::BooleanLiteral(order))) = callable_signature - .iter() - .nth(overload_index) - .and_then(|signature| { - let (idx, _) = - signature.parameters().keyword_by_name("order")?; - overload.parameter_types().get(idx) - }) - { - dataclass_params.set(DataclassParams::ORDER, *order); - } - - overload.set_return_type(Type::DataclassDecorator(dataclass_params)); - } - } - }, - - Type::ClassLiteral(class) => match class.known(db) { - Some(KnownClass::Bool) => match overload.parameter_types() { - [Some(arg)] => overload.set_return_type(arg.bool(db).into_type(db)), - [None] => overload.set_return_type(Type::BooleanLiteral(false)), - _ => {} - }, - - Some(KnownClass::Str) if overload_index == 0 => { + Type::WrapperDescriptor(WrapperDescriptorKind::PropertyDunderGet) => { match overload.parameter_types() { - [Some(arg)] => overload.set_return_type(arg.str(db)), - [None] => overload.set_return_type(Type::string_literal(db, "")), + [Some(property @ Type::PropertyInstance(_)), Some(instance), ..] + if instance.is_none(db) => + { + overload.set_return_type(*property); + } + [Some(Type::PropertyInstance(property)), Some(Type::KnownInstance(KnownInstanceType::TypeAliasType( + type_alias, + ))), ..] + if property.getter(db).is_some_and(|getter| { + getter + .into_function_literal() + .is_some_and(|f| f.name(db) == "__name__") + }) => + { + overload + .set_return_type(Type::string_literal(db, type_alias.name(db))); + } + [Some(Type::PropertyInstance(property)), Some(Type::KnownInstance(KnownInstanceType::TypeVar(typevar))), ..] => { + match property + .getter(db) + .and_then(Type::into_function_literal) + .map(|f| f.name(db).as_str()) + { + Some("__name__") => { + overload.set_return_type(Type::string_literal( + db, + typevar.name(db), + )); + } + Some("__bound__") => { + overload.set_return_type( + typevar + .upper_bound(db) + .unwrap_or_else(|| Type::none(db)), + ); + } + Some("__constraints__") => { + overload.set_return_type(TupleType::from_elements( + db, + typevar.constraints(db).into_iter().flatten(), + )); + } + Some("__default__") => { + overload.set_return_type( + typevar.default_ty(db).unwrap_or_else(|| { + KnownClass::NoDefaultType.to_instance(db) + }), + ); + } + _ => {} + } + } + [Some(Type::PropertyInstance(property)), Some(instance), ..] => { + if let Some(getter) = property.getter(db) { + if let Ok(return_ty) = getter + .try_call( + db, + &mut CallArgumentTypes::positional([*instance]), + ) + .map(|binding| binding.return_type(db)) + { + overload.set_return_type(return_ty); + } else { + overload.errors.push(BindingError::InternalCallError( + "calling the getter failed", + )); + overload.set_return_type(Type::unknown()); + } + } else { + overload.errors.push(BindingError::InternalCallError( + "property has no getter", + )); + overload.set_return_type(Type::Never); + } + } _ => {} } } - Some(KnownClass::Type) if overload_index == 0 => { - if let [Some(arg)] = overload.parameter_types() { - overload.set_return_type(arg.to_meta_type(db)); + Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(property)) => { + match overload.parameter_types() { + [Some(instance), ..] if instance.is_none(db) => { + overload.set_return_type(Type::PropertyInstance(property)); + } + [Some(instance), ..] => { + if let Some(getter) = property.getter(db) { + if let Ok(return_ty) = getter + .try_call( + db, + &mut CallArgumentTypes::positional([*instance]), + ) + .map(|binding| binding.return_type(db)) + { + overload.set_return_type(return_ty); + } else { + overload.errors.push(BindingError::InternalCallError( + "calling the getter failed", + )); + overload.set_return_type(Type::unknown()); + } + } else { + overload.set_return_type(Type::Never); + overload.errors.push(BindingError::InternalCallError( + "property has no getter", + )); + } + } + _ => {} } } - Some(KnownClass::Property) => { - if let [getter, setter, ..] = overload.parameter_types() { - overload.set_return_type(Type::PropertyInstance( - PropertyInstanceType::new(db, *getter, *setter), + Type::WrapperDescriptor(WrapperDescriptorKind::PropertyDunderSet) => { + if let [Some(Type::PropertyInstance(property)), Some(instance), Some(value), ..] = + overload.parameter_types() + { + if let Some(setter) = property.setter(db) { + if let Err(_call_error) = setter.try_call( + db, + &mut CallArgumentTypes::positional([*instance, *value]), + ) { + overload.errors.push(BindingError::InternalCallError( + "calling the setter failed", + )); + } + } else { + overload.errors.push(BindingError::InternalCallError( + "property has no setter", + )); + } + } + } + + Type::MethodWrapper(MethodWrapperKind::PropertyDunderSet(property)) => { + if let [Some(instance), Some(value), ..] = overload.parameter_types() { + if let Some(setter) = property.setter(db) { + if let Err(_call_error) = setter.try_call( + db, + &mut CallArgumentTypes::positional([*instance, *value]), + ) { + overload.errors.push(BindingError::InternalCallError( + "calling the setter failed", + )); + } + } else { + overload.errors.push(BindingError::InternalCallError( + "property has no setter", + )); + } + } + } + + Type::MethodWrapper(MethodWrapperKind::StrStartswith(literal)) => { + if let [Some(Type::StringLiteral(prefix)), None, None] = + overload.parameter_types() + { + overload.set_return_type(Type::BooleanLiteral( + literal.value(db).starts_with(&**prefix.value(db)), )); } } + Type::DataclassTransformer(params) => { + if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types() + { + overload.set_return_type(Type::FunctionLiteral( + function.with_dataclass_transformer_params(db, params), + )); + } + } + + Type::BoundMethod(bound_method) + if bound_method.self_instance(db).is_property_instance() => + { + match bound_method.function(db).name(db).as_str() { + "setter" => { + if let [Some(_), Some(setter)] = overload.parameter_types() { + let mut ty_property = bound_method.self_instance(db); + if let Type::PropertyInstance(property) = ty_property { + ty_property = + Type::PropertyInstance(PropertyInstanceType::new( + db, + property.getter(db), + Some(*setter), + )); + } + overload.set_return_type(ty_property); + } + } + "getter" => { + if let [Some(_), Some(getter)] = overload.parameter_types() { + let mut ty_property = bound_method.self_instance(db); + if let Type::PropertyInstance(property) = ty_property { + ty_property = + Type::PropertyInstance(PropertyInstanceType::new( + db, + Some(*getter), + property.setter(db), + )); + } + overload.set_return_type(ty_property); + } + } + "deleter" => { + // TODO: we do not store deleters yet + let ty_property = bound_method.self_instance(db); + overload.set_return_type(ty_property); + } + _ => { + // Fall back to typeshed stubs for all other methods + } + } + } + + Type::FunctionLiteral(function_type) => match function_type.known(db) { + Some(KnownFunction::IsEquivalentTo) => { + if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { + overload.set_return_type(Type::BooleanLiteral( + ty_a.is_equivalent_to(db, *ty_b), + )); + } + } + + Some(KnownFunction::IsSubtypeOf) => { + if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { + overload.set_return_type(Type::BooleanLiteral( + ty_a.is_subtype_of(db, *ty_b), + )); + } + } + + Some(KnownFunction::IsAssignableTo) => { + if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { + overload.set_return_type(Type::BooleanLiteral( + ty_a.is_assignable_to(db, *ty_b), + )); + } + } + + Some(KnownFunction::IsDisjointFrom) => { + if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { + overload.set_return_type(Type::BooleanLiteral( + ty_a.is_disjoint_from(db, *ty_b), + )); + } + } + + Some(KnownFunction::IsGradualEquivalentTo) => { + if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() { + overload.set_return_type(Type::BooleanLiteral( + ty_a.is_gradual_equivalent_to(db, *ty_b), + )); + } + } + + Some(KnownFunction::IsFullyStatic) => { + if let [Some(ty)] = overload.parameter_types() { + overload + .set_return_type(Type::BooleanLiteral(ty.is_fully_static(db))); + } + } + + Some(KnownFunction::IsSingleton) => { + if let [Some(ty)] = overload.parameter_types() { + overload.set_return_type(Type::BooleanLiteral(ty.is_singleton(db))); + } + } + + Some(KnownFunction::IsSingleValued) => { + if let [Some(ty)] = overload.parameter_types() { + overload + .set_return_type(Type::BooleanLiteral(ty.is_single_valued(db))); + } + } + + Some(KnownFunction::Len) => { + if let [Some(first_arg)] = overload.parameter_types() { + if let Some(len_ty) = first_arg.len(db) { + overload.set_return_type(len_ty); + } + } + } + + Some(KnownFunction::Repr) => { + if let [Some(first_arg)] = overload.parameter_types() { + overload.set_return_type(first_arg.repr(db)); + } + } + + Some(KnownFunction::Cast) => { + if let [Some(casted_ty), Some(_)] = overload.parameter_types() { + overload.set_return_type(*casted_ty); + } + } + + Some(KnownFunction::IsProtocol) => { + if let [Some(ty)] = overload.parameter_types() { + overload.set_return_type(Type::BooleanLiteral( + ty.into_class_literal() + .is_some_and(|class| class.is_protocol(db)), + )); + } + } + + Some(KnownFunction::GetProtocolMembers) => { + if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() { + if let Some(protocol_class) = class.into_protocol_class(db) { + // TODO: actually a frozenset at runtime (requires support for legacy generic classes) + overload.set_return_type(Type::Tuple(TupleType::new( + db, + protocol_class + .protocol_members(db) + .iter() + .map(|member| Type::string_literal(db, member)) + .collect::]>>(), + ))); + } + } + } + + Some(KnownFunction::Overload) => { + // TODO: This can be removed once we understand legacy generics because the + // typeshed definition for `typing.overload` is an identity function. + if let [Some(ty)] = overload.parameter_types() { + overload.set_return_type(*ty); + } + } + + Some(KnownFunction::Override) => { + // TODO: This can be removed once we understand legacy generics because the + // typeshed definition for `typing.overload` is an identity function. + if let [Some(ty)] = overload.parameter_types() { + overload.set_return_type(*ty); + } + } + + Some(KnownFunction::AbstractMethod) => { + // TODO: This can be removed once we understand legacy generics because the + // typeshed definition for `abc.abstractmethod` is an identity function. + if let [Some(ty)] = overload.parameter_types() { + overload.set_return_type(*ty); + } + } + + Some(KnownFunction::Final) => { + // TODO: This can be removed once we understand legacy generics because the + // typeshed definition for `abc.abstractmethod` is an identity function. + if let [Some(ty)] = overload.parameter_types() { + overload.set_return_type(*ty); + } + } + + Some(KnownFunction::GetattrStatic) => { + let [Some(instance_ty), Some(attr_name), default] = + overload.parameter_types() + else { + continue; + }; + + let Some(attr_name) = attr_name.into_string_literal() else { + continue; + }; + + let default = if let Some(default) = default { + *default + } else { + Type::Never + }; + + let union_with_default = + |ty| UnionType::from_elements(db, [ty, default]); + + // TODO: we could emit a diagnostic here (if default is not set) + overload.set_return_type( + match instance_ty.static_member(db, attr_name.value(db)) { + Symbol::Type(ty, Boundness::Bound) => { + if instance_ty.is_fully_static(db) { + ty + } else { + // Here, we attempt to model the fact that an attribute lookup on + // a non-fully static type could fail. This is an approximation, + // as there are gradual types like `tuple[Any]`, on which a lookup + // of (e.g. of the `index` method) would always succeed. + + union_with_default(ty) + } + } + Symbol::Type(ty, Boundness::PossiblyUnbound) => { + union_with_default(ty) + } + Symbol::Unbound => default, + }, + ); + } + + Some(KnownFunction::Dataclass) => { + if let [init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots, weakref_slot] = + overload.parameter_types() + { + let mut params = DataclassParams::empty(); + + if to_bool(init, true) { + params |= DataclassParams::INIT; + } + if to_bool(repr, true) { + params |= DataclassParams::REPR; + } + if to_bool(eq, true) { + params |= DataclassParams::EQ; + } + if to_bool(order, false) { + params |= DataclassParams::ORDER; + } + if to_bool(unsafe_hash, false) { + params |= DataclassParams::UNSAFE_HASH; + } + if to_bool(frozen, false) { + params |= DataclassParams::FROZEN; + } + if to_bool(match_args, true) { + params |= DataclassParams::MATCH_ARGS; + } + if to_bool(kw_only, false) { + params |= DataclassParams::KW_ONLY; + } + if to_bool(slots, false) { + params |= DataclassParams::SLOTS; + } + if to_bool(weakref_slot, false) { + params |= DataclassParams::WEAKREF_SLOT; + } + + overload.set_return_type(Type::DataclassDecorator(params)); + } + } + + Some(KnownFunction::DataclassTransform) => { + if let [eq_default, order_default, kw_only_default, frozen_default, _field_specifiers, _kwargs] = + overload.parameter_types() + { + let mut params = DataclassTransformerParams::empty(); + + if to_bool(eq_default, true) { + params |= DataclassTransformerParams::EQ_DEFAULT; + } + if to_bool(order_default, false) { + params |= DataclassTransformerParams::ORDER_DEFAULT; + } + if to_bool(kw_only_default, false) { + params |= DataclassTransformerParams::KW_ONLY_DEFAULT; + } + if to_bool(frozen_default, false) { + params |= DataclassTransformerParams::FROZEN_DEFAULT; + } + + overload.set_return_type(Type::DataclassTransformer(params)); + } + } + + _ => { + if let Some(params) = function_type.dataclass_transformer_params(db) { + // This is a call to a custom function that was decorated with `@dataclass_transformer`. + // If this function was called with a keyword argument like `order=False`, we extract + // the argument type and overwrite the corresponding flag in `dataclass_params` after + // constructing them from the `dataclass_transformer`-parameter defaults. + + let mut dataclass_params = DataclassParams::from(params); + + if let Some(Some(Type::BooleanLiteral(order))) = callable_signature + .iter() + .nth(overload_index) + .and_then(|signature| { + let (idx, _) = + signature.parameters().keyword_by_name("order")?; + overload.parameter_types().get(idx) + }) + { + dataclass_params.set(DataclassParams::ORDER, *order); + } + + overload + .set_return_type(Type::DataclassDecorator(dataclass_params)); + } + } + }, + + Type::ClassLiteral(class) => match class.known(db) { + Some(KnownClass::Bool) => match overload.parameter_types() { + [Some(arg)] => overload.set_return_type(arg.bool(db).into_type(db)), + [None] => overload.set_return_type(Type::BooleanLiteral(false)), + _ => {} + }, + + Some(KnownClass::Str) if overload_index == 0 => { + match overload.parameter_types() { + [Some(arg)] => overload.set_return_type(arg.str(db)), + [None] => overload.set_return_type(Type::string_literal(db, "")), + _ => {} + } + } + + Some(KnownClass::Type) if overload_index == 0 => { + if let [Some(arg)] = overload.parameter_types() { + overload.set_return_type(arg.to_meta_type(db)); + } + } + + Some(KnownClass::Property) => { + if let [getter, setter, ..] = overload.parameter_types() { + overload.set_return_type(Type::PropertyInstance( + PropertyInstanceType::new(db, *getter, *setter), + )); + } + } + + _ => {} + }, + + Type::KnownInstance(KnownInstanceType::TypedDict) => { + overload.set_return_type(todo_type!("TypedDict")); + } + + // Not a special case _ => {} - }, - - Type::KnownInstance(KnownInstanceType::TypedDict) => { - overload.set_return_type(todo_type!("TypedDict")); } - - // Not a special case - _ => {} } } } @@ -868,7 +886,7 @@ impl<'db> CallableBinding<'db> { // the matching overloads. Make sure to implement that as part of separating call binding into // two phases. // - // [1] https://github.com/python/typing/pull/1839 + // [1] https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation let overloads = signature .into_iter() .map(|signature| { @@ -928,35 +946,39 @@ impl<'db> CallableBinding<'db> { /// Returns whether there were any errors binding this call site. If the callable has multiple /// overloads, they must _all_ have errors. pub(crate) fn has_binding_errors(&self) -> bool { - self.matching_overload().is_none() + self.matching_overloads().next().is_none() } - /// Returns the overload that matched for this call binding. Returns `None` if none of the - /// overloads matched. - pub(crate) fn matching_overload(&self) -> Option<(usize, &Binding<'db>)> { + /// Returns an iterator over all the overloads that matched for this call binding. + pub(crate) fn matching_overloads(&self) -> impl Iterator)> { self.overloads .iter() .enumerate() - .find(|(_, overload)| overload.as_result().is_ok()) + .filter(|(_, overload)| overload.as_result().is_ok()) } - /// Returns the overload that matched for this call binding. Returns `None` if none of the - /// overloads matched. - pub(crate) fn matching_overload_mut(&mut self) -> Option<(usize, &mut Binding<'db>)> { + /// Returns an iterator over all the mutable overloads that matched for this call binding. + pub(crate) fn matching_overloads_mut( + &mut self, + ) -> impl Iterator)> { self.overloads .iter_mut() .enumerate() - .find(|(_, overload)| overload.as_result().is_ok()) + .filter(|(_, overload)| overload.as_result().is_ok()) } /// Returns the return type of this call. For a valid call, this is the return type of the - /// overload that the arguments matched against. For an invalid call to a non-overloaded + /// first overload that the arguments matched against. For an invalid call to a non-overloaded /// function, this is the return type of the function. For an invalid call to an overloaded /// function, we return `Type::unknown`, since we cannot make any useful conclusions about /// which overload was intended to be called. pub(crate) fn return_type(&self) -> Type<'db> { - if let Some((_, overload)) = self.matching_overload() { - return overload.return_type(); + // TODO: Implement the overload call evaluation algorithm as mentioned in the spec [1] to + // get the matching overload and use that to get the return type. + // + // [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation + if let Some((_, first_overload)) = self.matching_overloads().next() { + return first_overload.return_type(); } if let [overload] = self.overloads.as_slice() { return overload.return_type(); diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 22b821bdd3..d39b5dc34c 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -4627,307 +4627,310 @@ impl<'db> TypeInferenceBuilder<'db> { Ok(mut bindings) => { for binding in &mut bindings { let binding_type = binding.callable_type; - let Some((_, overload)) = binding.matching_overload_mut() else { - continue; - }; + for (_, overload) in binding.matching_overloads_mut() { + match binding_type { + Type::FunctionLiteral(function_literal) => { + let Some(known_function) = function_literal.known(self.db()) else { + continue; + }; - match binding_type { - Type::FunctionLiteral(function_literal) => { - let Some(known_function) = function_literal.known(self.db()) else { - continue; - }; - - match known_function { - KnownFunction::RevealType => { - if let [Some(revealed_type)] = overload.parameter_types() { - if let Some(builder) = self.context.report_diagnostic( - DiagnosticId::RevealedType, - Severity::Info, - ) { - let mut diag = builder.into_diagnostic("Revealed type"); - let span = self.context.span(call_expression); - diag.annotate(Annotation::primary(span).message( - format_args!( - "`{}`", - revealed_type.display(self.db()) - ), - )); + match known_function { + KnownFunction::RevealType => { + if let [Some(revealed_type)] = overload.parameter_types() { + if let Some(builder) = self.context.report_diagnostic( + DiagnosticId::RevealedType, + Severity::Info, + ) { + let mut diag = + builder.into_diagnostic("Revealed type"); + let span = self.context.span(call_expression); + diag.annotate(Annotation::primary(span).message( + format_args!( + "`{}`", + revealed_type.display(self.db()) + ), + )); + } } } - } - KnownFunction::AssertType => { - if let [Some(actual_ty), Some(asserted_ty)] = - overload.parameter_types() - { - if !actual_ty - .is_gradual_equivalent_to(self.db(), *asserted_ty) + KnownFunction::AssertType => { + if let [Some(actual_ty), Some(asserted_ty)] = + overload.parameter_types() { - if let Some(builder) = self.context.report_lint( - &TYPE_ASSERTION_FAILURE, - call_expression, - ) { - builder.into_diagnostic(format_args!( - "Actual type `{}` is not the same \ + if !actual_ty + .is_gradual_equivalent_to(self.db(), *asserted_ty) + { + if let Some(builder) = self.context.report_lint( + &TYPE_ASSERTION_FAILURE, + call_expression, + ) { + builder.into_diagnostic(format_args!( + "Actual type `{}` is not the same \ as asserted type `{}`", - actual_ty.display(self.db()), - asserted_ty.display(self.db()), - )); + actual_ty.display(self.db()), + asserted_ty.display(self.db()), + )); + } } } } - } - KnownFunction::AssertNever => { - if let [Some(actual_ty)] = overload.parameter_types() { - if !actual_ty.is_equivalent_to(self.db(), Type::Never) { - if let Some(builder) = self.context.report_lint( - &TYPE_ASSERTION_FAILURE, - call_expression, - ) { - builder.into_diagnostic(format_args!( - "Expected type `Never`, got `{}` instead", - actual_ty.display(self.db()), - )); + KnownFunction::AssertNever => { + if let [Some(actual_ty)] = overload.parameter_types() { + if !actual_ty.is_equivalent_to(self.db(), Type::Never) { + if let Some(builder) = self.context.report_lint( + &TYPE_ASSERTION_FAILURE, + call_expression, + ) { + builder.into_diagnostic(format_args!( + "Expected type `Never`, got `{}` instead", + actual_ty.display(self.db()), + )); + } } } } - } - KnownFunction::StaticAssert => { - if let [Some(parameter_ty), message] = - overload.parameter_types() - { - let truthiness = match parameter_ty.try_bool(self.db()) { - Ok(truthiness) => truthiness, - Err(err) => { - let condition = arguments - .find_argument("condition", 0) - .map(|argument| match argument { + KnownFunction::StaticAssert => { + if let [Some(parameter_ty), message] = + overload.parameter_types() + { + let truthiness = match parameter_ty.try_bool(self.db()) + { + Ok(truthiness) => truthiness, + Err(err) => { + let condition = arguments + .find_argument("condition", 0) + .map(|argument| { + match argument { ruff_python_ast::ArgOrKeyword::Arg( expr, ) => ast::AnyNodeRef::from(expr), ruff_python_ast::ArgOrKeyword::Keyword( keyword, ) => ast::AnyNodeRef::from(keyword), - }) - .unwrap_or(ast::AnyNodeRef::from( - call_expression, - )); + } + }) + .unwrap_or(ast::AnyNodeRef::from( + call_expression, + )); - err.report_diagnostic(&self.context, condition); + err.report_diagnostic(&self.context, condition); - continue; - } - }; + continue; + } + }; - if let Some(builder) = self - .context - .report_lint(&STATIC_ASSERT_ERROR, call_expression) - { - if !truthiness.is_always_true() { - if let Some(message) = message - .and_then(Type::into_string_literal) - .map(|s| &**s.value(self.db())) - { - builder.into_diagnostic(format_args!( - "Static assertion error: {message}" - )); - } else if *parameter_ty - == Type::BooleanLiteral(false) - { - builder.into_diagnostic( - "Static assertion error: \ + if let Some(builder) = self + .context + .report_lint(&STATIC_ASSERT_ERROR, call_expression) + { + if !truthiness.is_always_true() { + if let Some(message) = message + .and_then(Type::into_string_literal) + .map(|s| &**s.value(self.db())) + { + builder.into_diagnostic(format_args!( + "Static assertion error: {message}" + )); + } else if *parameter_ty + == Type::BooleanLiteral(false) + { + builder.into_diagnostic( + "Static assertion error: \ argument evaluates to `False`", - ); - } else if truthiness.is_always_false() { - builder.into_diagnostic(format_args!( - "Static assertion error: \ + ); + } else if truthiness.is_always_false() { + builder.into_diagnostic(format_args!( + "Static assertion error: \ argument of type `{parameter_ty}` \ is statically known to be falsy", - parameter_ty = - parameter_ty.display(self.db()) - )); - } else { - builder.into_diagnostic(format_args!( - "Static assertion error: \ + parameter_ty = + parameter_ty.display(self.db()) + )); + } else { + builder.into_diagnostic(format_args!( + "Static assertion error: \ argument of type `{parameter_ty}` \ has an ambiguous static truthiness", - parameter_ty = - parameter_ty.display(self.db()) + parameter_ty = + parameter_ty.display(self.db()) + )); + } + } + } + } + } + KnownFunction::Cast => { + if let [Some(casted_type), Some(source_type)] = + overload.parameter_types() + { + let db = self.db(); + if (source_type.is_equivalent_to(db, *casted_type) + || source_type.normalized(db) + == casted_type.normalized(db)) + && !source_type.contains_todo(db) + { + if let Some(builder) = self + .context + .report_lint(&REDUNDANT_CAST, call_expression) + { + builder.into_diagnostic(format_args!( + "Value is already of type `{}`", + casted_type.display(db), )); } } } } - } - KnownFunction::Cast => { - if let [Some(casted_type), Some(source_type)] = - overload.parameter_types() - { - let db = self.db(); - if (source_type.is_equivalent_to(db, *casted_type) - || source_type.normalized(db) - == casted_type.normalized(db)) - && !source_type.contains_todo(db) + KnownFunction::GetProtocolMembers => { + if let [Some(Type::ClassLiteral(class))] = + overload.parameter_types() { - if let Some(builder) = self - .context - .report_lint(&REDUNDANT_CAST, call_expression) - { - builder.into_diagnostic(format_args!( - "Value is already of type `{}`", - casted_type.display(db), - )); + if !class.is_protocol(self.db()) { + report_bad_argument_to_get_protocol_members( + &self.context, + call_expression, + *class, + ); } } } - } - KnownFunction::GetProtocolMembers => { - if let [Some(Type::ClassLiteral(class))] = - overload.parameter_types() - { - if !class.is_protocol(self.db()) { - report_bad_argument_to_get_protocol_members( - &self.context, - call_expression, - *class, - ); - } - } - } - KnownFunction::IsInstance | KnownFunction::IsSubclass => { - if let [_, Some(Type::ClassLiteral(class))] = - overload.parameter_types() - { - if let Some(protocol_class) = - class.into_protocol_class(self.db()) + KnownFunction::IsInstance | KnownFunction::IsSubclass => { + if let [_, Some(Type::ClassLiteral(class))] = + overload.parameter_types() { - if !protocol_class.is_runtime_checkable(self.db()) { - report_runtime_check_against_non_runtime_checkable_protocol( + if let Some(protocol_class) = + class.into_protocol_class(self.db()) + { + if !protocol_class.is_runtime_checkable(self.db()) { + report_runtime_check_against_non_runtime_checkable_protocol( &self.context, call_expression, protocol_class, known_function ); + } } } } + _ => {} } - _ => {} } - } - Type::ClassLiteral(class) => { - let Some(known_class) = class.known(self.db()) else { - continue; - }; + Type::ClassLiteral(class) => { + let Some(known_class) = class.known(self.db()) else { + continue; + }; - match known_class { - KnownClass::Super => { - // Handle the case where `super()` is called with no arguments. - // In this case, we need to infer the two arguments: - // 1. The nearest enclosing class - // 2. The first parameter of the current function (typically `self` or `cls`) - match overload.parameter_types() { - [] => { - let scope = self.scope(); + match known_class { + KnownClass::Super => { + // Handle the case where `super()` is called with no arguments. + // In this case, we need to infer the two arguments: + // 1. The nearest enclosing class + // 2. The first parameter of the current function (typically `self` or `cls`) + match overload.parameter_types() { + [] => { + let scope = self.scope(); - let Some(enclosing_class) = - self.enclosing_class_symbol(scope) - else { - overload.set_return_type(Type::unknown()); - BoundSuperError::UnavailableImplicitArguments - .report_diagnostic( + let Some(enclosing_class) = + self.enclosing_class_symbol(scope) + else { + overload.set_return_type(Type::unknown()); + BoundSuperError::UnavailableImplicitArguments + .report_diagnostic( + &self.context, + call_expression.into(), + ); + continue; + }; + + let Some(first_param) = + self.first_param_type_in_scope(scope) + else { + overload.set_return_type(Type::unknown()); + BoundSuperError::UnavailableImplicitArguments + .report_diagnostic( + &self.context, + call_expression.into(), + ); + continue; + }; + + let bound_super = BoundSuperType::build( + self.db(), + enclosing_class, + first_param, + ) + .unwrap_or_else(|err| { + err.report_diagnostic( &self.context, call_expression.into(), ); - continue; - }; + Type::unknown() + }); - let Some(first_param) = - self.first_param_type_in_scope(scope) - else { - overload.set_return_type(Type::unknown()); - BoundSuperError::UnavailableImplicitArguments - .report_diagnostic( - &self.context, - call_expression.into(), - ); - continue; - }; - - let bound_super = BoundSuperType::build( - self.db(), - enclosing_class, - first_param, - ) - .unwrap_or_else(|err| { - err.report_diagnostic( - &self.context, - call_expression.into(), - ); - Type::unknown() - }); - - overload.set_return_type(bound_super); - } - [Some(pivot_class_type), Some(owner_type)] => { - let bound_super = BoundSuperType::build( - self.db(), - *pivot_class_type, - *owner_type, - ) - .unwrap_or_else(|err| { - err.report_diagnostic( - &self.context, - call_expression.into(), - ); - Type::unknown() - }); - - overload.set_return_type(bound_super); - } - _ => (), - } - } - - KnownClass::TypeVar => { - let assigned_to = (self.index) - .try_expression(call_expression_node) - .and_then(|expr| expr.assigned_to(self.db())); - - let Some(target) = - assigned_to.as_ref().and_then(|assigned_to| { - match assigned_to.node().targets.as_slice() { - [ast::Expr::Name(target)] => Some(target), - _ => None, + overload.set_return_type(bound_super); } - }) - else { - if let Some(builder) = self.context.report_lint( - &INVALID_LEGACY_TYPE_VARIABLE, - call_expression, - ) { - builder.into_diagnostic(format_args!( + [Some(pivot_class_type), Some(owner_type)] => { + let bound_super = BoundSuperType::build( + self.db(), + *pivot_class_type, + *owner_type, + ) + .unwrap_or_else(|err| { + err.report_diagnostic( + &self.context, + call_expression.into(), + ); + Type::unknown() + }); + + overload.set_return_type(bound_super); + } + _ => (), + } + } + + KnownClass::TypeVar => { + let assigned_to = (self.index) + .try_expression(call_expression_node) + .and_then(|expr| expr.assigned_to(self.db())); + + let Some(target) = + assigned_to.as_ref().and_then(|assigned_to| { + match assigned_to.node().targets.as_slice() { + [ast::Expr::Name(target)] => Some(target), + _ => None, + } + }) + else { + if let Some(builder) = self.context.report_lint( + &INVALID_LEGACY_TYPE_VARIABLE, + call_expression, + ) { + builder.into_diagnostic(format_args!( "A legacy `typing.TypeVar` must be immediately assigned to a variable", )); - } - continue; - }; + } + continue; + }; - let [Some(name_param), constraints, bound, default, _contravariant, _covariant, _infer_variance] = - overload.parameter_types() - else { - continue; - }; + let [Some(name_param), constraints, bound, default, _contravariant, _covariant, _infer_variance] = + overload.parameter_types() + else { + continue; + }; - let name_param = name_param - .into_string_literal() - .map(|name| name.value(self.db()).as_ref()); - if name_param.is_none_or(|name_param| name_param != target.id) { - if let Some(builder) = self.context.report_lint( - &INVALID_LEGACY_TYPE_VARIABLE, - call_expression, - ) { - builder.into_diagnostic(format_args!( + let name_param = name_param + .into_string_literal() + .map(|name| name.value(self.db()).as_ref()); + if name_param + .is_none_or(|name_param| name_param != target.id) + { + if let Some(builder) = self.context.report_lint( + &INVALID_LEGACY_TYPE_VARIABLE, + call_expression, + ) { + builder.into_diagnostic(format_args!( "The name of a legacy `typing.TypeVar`{} must match \ the name of the variable it is assigned to (`{}`)", if let Some(name_param) = name_param { @@ -4937,60 +4940,63 @@ impl<'db> TypeInferenceBuilder<'db> { }, target.id, )); + } + continue; } - continue; + + let bound_or_constraint = match (bound, constraints) { + (Some(bound), None) => { + Some(TypeVarBoundOrConstraints::UpperBound(*bound)) + } + + (None, Some(_constraints)) => { + // We don't use UnionType::from_elements or UnionBuilder here, + // because we don't want to simplify the list of constraints like + // we do with the elements of an actual union type. + // TODO: Consider using a new `OneOfType` connective here instead, + // since that more accurately represents the actual semantics of + // typevar constraints. + let elements = UnionType::new( + self.db(), + overload + .arguments_for_parameter( + &call_argument_types, + 1, + ) + .map(|(_, ty)| ty) + .collect::>(), + ); + Some(TypeVarBoundOrConstraints::Constraints( + elements, + )) + } + + // TODO: Emit a diagnostic that TypeVar cannot be both bounded and + // constrained + (Some(_), Some(_)) => continue, + + (None, None) => None, + }; + + let containing_assignment = + self.index.expect_single_definition(target); + overload.set_return_type(Type::KnownInstance( + KnownInstanceType::TypeVar(TypeVarInstance::new( + self.db(), + target.id.clone(), + containing_assignment, + bound_or_constraint, + *default, + TypeVarKind::Legacy, + )), + )); } - let bound_or_constraint = match (bound, constraints) { - (Some(bound), None) => { - Some(TypeVarBoundOrConstraints::UpperBound(*bound)) - } - - (None, Some(_constraints)) => { - // We don't use UnionType::from_elements or UnionBuilder here, - // because we don't want to simplify the list of constraints like - // we do with the elements of an actual union type. - // TODO: Consider using a new `OneOfType` connective here instead, - // since that more accurately represents the actual semantics of - // typevar constraints. - let elements = UnionType::new( - self.db(), - overload - .arguments_for_parameter( - &call_argument_types, - 1, - ) - .map(|(_, ty)| ty) - .collect::>(), - ); - Some(TypeVarBoundOrConstraints::Constraints(elements)) - } - - // TODO: Emit a diagnostic that TypeVar cannot be both bounded and - // constrained - (Some(_), Some(_)) => continue, - - (None, None) => None, - }; - - let containing_assignment = - self.index.expect_single_definition(target); - overload.set_return_type(Type::KnownInstance( - KnownInstanceType::TypeVar(TypeVarInstance::new( - self.db(), - target.id.clone(), - containing_assignment, - bound_or_constraint, - *default, - TypeVarKind::Legacy, - )), - )); + _ => (), } - - _ => (), } + _ => (), } - _ => (), } } bindings.return_type(self.db()) @@ -6637,7 +6643,8 @@ impl<'db> TypeInferenceBuilder<'db> { .next() .expect("valid bindings should have one callable"); let (_, overload) = callable - .matching_overload() + .matching_overloads() + .next() .expect("valid bindings should have matching overload"); let specialization = generic_context.specialize( self.db(),