diff --git a/crates/ty_python_semantic/resources/mdtest/call/function.md b/crates/ty_python_semantic/resources/mdtest/call/function.md index 852623a4f4..624c0976a6 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/function.md +++ b/crates/ty_python_semantic/resources/mdtest/call/function.md @@ -613,56 +613,6 @@ def _(args: str) -> None: takes_at_least_two_positional_only(*args) ``` -### Argument expansion regression - -This is a regression that was highlighted by the ecosystem check, which shows that we might need to -rethink how we perform argument expansion during overload resolution. In particular, we might need -to retry both `match_parameters` *and* `check_types` for each expansion. Currently we only retry -`check_types`. - -The issue is that argument expansion might produce a splatted value with a different arity than what -we originally inferred for the unexpanded value, and that in turn can affect which parameters the -splatted value is matched with. - -The first example correctly produces an error. The `tuple[int, str]` union element has a precise -arity of two, and so parameter matching chooses the first overload. The second element of the tuple -does not match the second parameter type, which yielding an `invalid-argument-type` error. - -The third example should produce the same error. However, because we have a union, we do not see the -precise arity of each union element during parameter matching. Instead, we infer an arity of "zero -or more" for the union as a whole, and use that less precise arity when matching parameters. We -therefore consider the second overload to still be a potential candidate for the `tuple[int, str]` -union element. During type checking, we have to force the arity of each union element to match the -inferred arity of the union as a whole (turning `tuple[int, str]` into `tuple[int | str, ...]`). 
-That less precise tuple type-checks successfully against the second overload, making us incorrectly -think that `tuple[int, str]` is a valid splatted call. - -If we update argument expansion to retry parameter matching with the precise arity of each union -element, we will correctly rule out the second overload for `tuple[int, str]`, just like we do when -splatting that tuple directly (instead of as part of a union). - -```py -from typing import overload - -@overload -def f(x: int, y: int) -> None: ... -@overload -def f(x: int, y: str, z: int) -> None: ... -def f(*args): ... - -# Test all of the above with a number of different splatted argument types - -def _(t: tuple[int, str]) -> None: - f(*t) # error: [invalid-argument-type] - -def _(t: tuple[int, str, int]) -> None: - f(*t) - -def _(t: tuple[int, str] | tuple[int, str, int]) -> None: - # TODO: error: [invalid-argument-type] - f(*t) -``` - ## Wrong argument type ### Positional argument, positional-or-keyword parameter diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index a59de5d27b..c769f904d6 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -889,6 +889,48 @@ def _(a: int | None): ) ``` +### Retry from parameter matching + +As per the spec, the argument type expansion should retry evaluating the expanded argument list from +the type checking step. However, that creates an issue when variadic arguments are involved because +if a variadic argument is a union type, it could be expanded to have different arities. So, ty +retries it from the start which includes parameter matching as well. + +`overloaded.pyi`: + +```pyi +from typing import overload + +@overload +def f(x: int, y: int) -> None: ... +@overload +def f(x: int, y: str, z: int) -> None: ... 
+``` + +```py +from overloaded import f + +# Test all of the above with a number of different splatted argument types + +def _(t: tuple[int, str]) -> None: + # This correctly produces an error because the tuple has a precise arity of + # 2, which matches the first overload, but the second element of the tuple doesn't match the + # second parameter type, yielding an `invalid-argument-type` error. + f(*t) # error: [invalid-argument-type] + +def _(t: tuple[int, str, int]) -> None: + # This correctly produces no error because the tuple has a precise arity of + # 3, which matches the second overload. + f(*t) + +def _(t: tuple[int, str] | tuple[int, str, int]) -> None: + # This produces an error because the expansion produces two argument lists: `[*tuple[int, str]]` + # and `[*tuple[int, str, int]]`. The first list produces a type checking error as + # described in the first example, while the second list matches the second overload. And, + # because not all of the expanded argument lists evaluate successfully, we produce an error. 
+ f(*t) # error: [no-matching-overload] +``` + ## Filtering based on `Any` / `Unknown` This is the step 5 of the overload call evaluation algorithm which specifies that: diff --git a/crates/ty_python_semantic/src/types/call/arguments.rs b/crates/ty_python_semantic/src/types/call/arguments.rs index 3bf78c6a75..2854f5a9e5 100644 --- a/crates/ty_python_semantic/src/types/call/arguments.rs +++ b/crates/ty_python_semantic/src/types/call/arguments.rs @@ -202,10 +202,25 @@ impl<'a, 'db> CallArguments<'a, 'db> { for subtype in &expanded_types { let mut new_expanded_types = pre_expanded_types.to_vec(); new_expanded_types[index] = Some(*subtype); - expanded_arguments.push(CallArguments::new( - self.arguments.clone(), - new_expanded_types, - )); + + // Update the arguments list to handle variadic argument expansion + let mut new_arguments = self.arguments.clone(); + if let Argument::Variadic(_) = self.arguments[index] { + // If the argument corresponding to this type is variadic, we need to + // update the tuple length because expanding could change the length. + // For example, in `tuple[int] | tuple[int, int]`, the length of the + // first type is 1, while the length of the second type is 2. + if let Some(expanded_type) = new_expanded_types[index] { + let length = expanded_type + .try_iterate(db) + .map(|tuple| tuple.len()) + .unwrap_or(TupleLength::unknown()); + new_arguments[index] = Argument::Variadic(length); + } + } + + expanded_arguments + .push(CallArguments::new(new_arguments, new_expanded_types)); } } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 6444e33b03..9866695233 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -3,7 +3,6 @@ //! [signatures][crate::types::signatures], we have to handle the fact that the callable might be a //! union of types, each of which might contain multiple overloads. 
-use std::borrow::Cow; use std::collections::HashSet; use std::fmt; @@ -28,7 +27,7 @@ use crate::types::function::{ }; use crate::types::generics::{Specialization, SpecializationBuilder, SpecializationError}; use crate::types::signatures::{Parameter, ParameterForm, Parameters}; -use crate::types::tuple::{Tuple, TupleLength, TupleType}; +use crate::types::tuple::{TupleLength, TupleType}; use crate::types::{ BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet, @@ -51,9 +50,7 @@ pub(crate) struct Bindings<'db> { elements: SmallVec<[CallableBinding<'db>; 1]>, /// Whether each argument will be used as a value and/or a type form in this call. - pub(crate) argument_forms: Box<[Option]>, - - conflicting_forms: Box<[bool]>, + argument_forms: ArgumentForms, } impl<'db> Bindings<'db> { @@ -71,8 +68,7 @@ impl<'db> Bindings<'db> { Self { callable_type, elements, - argument_forms: Box::from([]), - conflicting_forms: Box::from([]), + argument_forms: ArgumentForms::new(0), } } @@ -91,6 +87,10 @@ impl<'db> Bindings<'db> { } } + pub(crate) fn argument_forms(&self) -> &[Option] { + &self.argument_forms.values + } + /// Match the arguments of a call site against the parameters of a collection of possibly /// unioned, possibly overloaded signatures. 
/// @@ -105,13 +105,12 @@ impl<'db> Bindings<'db> { db: &'db dyn Db, arguments: &CallArguments<'_, 'db>, ) -> Self { - let mut argument_forms = vec![None; arguments.len()]; - let mut conflicting_forms = vec![false; arguments.len()]; + let mut argument_forms = ArgumentForms::new(arguments.len()); for binding in &mut self.elements { - binding.match_parameters(db, arguments, &mut argument_forms, &mut conflicting_forms); + binding.match_parameters(db, arguments, &mut argument_forms); } - self.argument_forms = argument_forms.into(); - self.conflicting_forms = conflicting_forms.into(); + argument_forms.shrink_to_fit(); + self.argument_forms = argument_forms; self } @@ -130,7 +129,12 @@ impl<'db> Bindings<'db> { argument_types: &CallArguments<'_, 'db>, ) -> Result> { for element in &mut self.elements { - element.check_types(db, argument_types); + if let Some(mut updated_argument_forms) = element.check_types(db, argument_types) { + // If this element returned a new set of argument forms (indicating successful + // argument type expansion), update the `Bindings` with these forms. 
+ updated_argument_forms.shrink_to_fit(); + self.argument_forms = updated_argument_forms; + } } self.evaluate_known_cases(db); @@ -153,7 +157,7 @@ impl<'db> Bindings<'db> { let mut all_ok = true; let mut any_binding_error = false; let mut all_not_callable = true; - if self.conflicting_forms.contains(&true) { + if self.argument_forms.conflicting.contains(&true) { all_ok = false; any_binding_error = true; all_not_callable = false; @@ -226,7 +230,7 @@ impl<'db> Bindings<'db> { return; } - for (index, conflicting_form) in self.conflicting_forms.iter().enumerate() { + for (index, conflicting_form) in self.argument_forms.conflicting.iter().enumerate() { if *conflicting_form { let node = BindingError::get_node(node, Some(index)); if let Some(builder) = context.report_lint(&CONFLICTING_ARGUMENT_FORMS, node) { @@ -1118,8 +1122,7 @@ impl<'db> From> for Bindings<'db> { Bindings { callable_type: from.callable_type, elements: smallvec_inline![from], - argument_forms: Box::from([]), - conflicting_forms: Box::from([]), + argument_forms: ArgumentForms::new(0), } } } @@ -1140,8 +1143,7 @@ impl<'db> From> for Bindings<'db> { Bindings { callable_type, elements: smallvec_inline![callable_binding], - argument_forms: Box::from([]), - conflicting_forms: Box::from([]), + argument_forms: ArgumentForms::new(0), } } } @@ -1262,19 +1264,22 @@ impl<'db> CallableBinding<'db> { &mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>, - argument_forms: &mut [Option], - conflicting_forms: &mut [bool], + argument_forms: &mut ArgumentForms, ) { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. 
let arguments = arguments.with_self(self.bound_type); for overload in &mut self.overloads { - overload.match_parameters(db, arguments.as_ref(), argument_forms, conflicting_forms); + overload.match_parameters(db, arguments.as_ref(), argument_forms); } } - fn check_types(&mut self, db: &'db dyn Db, argument_types: &CallArguments<'_, 'db>) { + fn check_types( + &mut self, + db: &'db dyn Db, + argument_types: &CallArguments<'_, 'db>, + ) -> Option { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. let argument_types = argument_types.with_self(self.bound_type); @@ -1288,14 +1293,14 @@ impl<'db> CallableBinding<'db> { if let [overload] = self.overloads.as_mut_slice() { overload.check_types(db, argument_types.as_ref()); } - return; + return None; } MatchingOverloadIndex::Single(index) => { // If only one candidate overload remains, it is the winning match. Evaluate it as // a regular (non-overloaded) call. self.matching_overload_index = Some(index); self.overloads[index].check_types(db, argument_types.as_ref()); - return; + return None; } MatchingOverloadIndex::Multiple(indexes) => { // If two or more candidate overloads remain, proceed to step 2. @@ -1303,12 +1308,6 @@ impl<'db> CallableBinding<'db> { } }; - let snapshotter = CallableBindingSnapshotter::new(matching_overload_indexes); - - // State of the bindings _before_ evaluating (type checking) the matching overloads using - // the non-expanded argument types. - let pre_evaluation_snapshot = snapshotter.take(self); - // Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine // whether it is compatible with the supplied argument list. for (_, overload) in self.matching_overloads_mut() { @@ -1321,7 +1320,7 @@ impl<'db> CallableBinding<'db> { } MatchingOverloadIndex::Single(_) => { // If only one overload evaluates without error, it is the winning match. 
- return; + return None; } MatchingOverloadIndex::Multiple(indexes) => { // If two or more candidate overloads remain, proceed to step 4. @@ -1330,8 +1329,8 @@ impl<'db> CallableBinding<'db> { // Step 5 self.filter_overloads_using_any_or_unknown(db, argument_types.as_ref(), &indexes); - // We're returning here because this shouldn't lead to argument type expansion. - return; + // This shouldn't lead to argument type expansion. + return None; } } @@ -1339,27 +1338,14 @@ impl<'db> CallableBinding<'db> { // https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion let mut expansions = argument_types.expand(db).peekable(); - if expansions.peek().is_none() { - // Return early if there are no argument types to expand. - return; - } - - // State of the bindings _after_ evaluating (type checking) the matching overloads using - // the non-expanded argument types. - let post_evaluation_snapshot = snapshotter.take(self); - - // Restore the bindings state to the one prior to the type checking step in preparation - // for evaluating the expanded argument lists. - snapshotter.restore(self, pre_evaluation_snapshot); + // Return early if there are no argument types to expand. + expansions.peek()?; // At this point, there's at least one argument that can be expanded. // // This heuristic tries to detect if there's any need to perform argument type expansion or // not by checking whether there are any non-expandable argument type that cannot be - // assigned to any of the remaining overloads. - // - // This heuristic needs to be applied after restoring the bindings state to the one before - // type checking as argument type expansion would evaluate it from that point on. + // assigned to any of the overloads. 
for (argument_index, (argument, argument_type)) in argument_types.iter().enumerate() { // TODO: Remove `Keywords` once `**kwargs` support is added if matches!(argument, Argument::Synthetic | Argument::Keywords) { @@ -1372,7 +1358,7 @@ impl<'db> CallableBinding<'db> { continue; } let mut is_argument_assignable_to_any_overload = false; - 'overload: for (_, overload) in self.matching_overloads() { + 'overload: for overload in &self.overloads { for parameter_index in &overload.argument_matches[argument_index].parameters { let parameter_type = overload.signature.parameters()[*parameter_index] .annotated_type() @@ -1389,11 +1375,16 @@ impl<'db> CallableBinding<'db> { remaining overloads, skipping argument type expansion", argument_type.display(db) ); - snapshotter.restore(self, post_evaluation_snapshot); - return; + return None; } } + let snapshotter = CallableBindingSnapshotter::new(matching_overload_indexes); + + // State of the bindings _after_ evaluating (type checking) the matching overloads using + // the non-expanded argument types. + let post_evaluation_snapshot = snapshotter.take(self); + for expansion in expansions { let expanded_argument_lists = match expansion { Expansion::LimitReached(index) => { @@ -1401,7 +1392,7 @@ impl<'db> CallableBinding<'db> { self.overload_call_return_type = Some( OverloadCallReturnType::ArgumentTypeExpansionLimitReached(index), ); - return; + return None; } Expansion::Expanded(argument_lists) => argument_lists, }; @@ -1411,13 +1402,33 @@ impl<'db> CallableBinding<'db> { // the expanded argument lists evaluated successfully. let mut merged_evaluation_state: Option> = None; + // Merged argument forms after evaluating all the argument lists in this expansion. + let mut merged_argument_forms = ArgumentForms::default(); + + // The return types of each of the expanded argument lists that evaluated successfully. 
let mut return_types = Vec::new(); - for expanded_argument_types in &expanded_argument_lists { - let pre_evaluation_snapshot = snapshotter.take(self); + for expanded_arguments in &expanded_argument_lists { + let mut argument_forms = ArgumentForms::new(expanded_arguments.len()); + + // The spec mentions that each expanded argument list should be re-evaluated from + // step 2 but we need to re-evaluate from step 1 because our step 1 does more than + // what the spec mentions. Step 1 of the spec means only "eliminate impossible + // overloads due to arity mismatch" while our step 1 (`match_parameters`) also + // includes "match arguments to the parameters". This is important because it + // allows us to correctly handle cases involving a variadic argument that could + // expand into different number of arguments with each expansion. Refer to + // https://github.com/astral-sh/ty/issues/735 for more details. + for overload in &mut self.overloads { + // Clear the state of all overloads before re-evaluating from step 1 + overload.reset(); + overload.match_parameters(db, expanded_arguments, &mut argument_forms); + } + + merged_argument_forms.merge(&argument_forms); for (_, overload) in self.matching_overloads_mut() { - overload.check_types(db, expanded_argument_types); + overload.check_types(db, expanded_arguments); } let return_type = match self.matching_overload_index() { @@ -1430,7 +1441,7 @@ impl<'db> CallableBinding<'db> { self.filter_overloads_using_any_or_unknown( db, - expanded_argument_types, + expanded_arguments, &matching_overload_indexes, ); @@ -1451,9 +1462,6 @@ impl<'db> CallableBinding<'db> { merged_evaluation_state = Some(snapshotter.take(self)); } - // Restore the bindings state before evaluating the next argument list. 
- snapshotter.restore(self, pre_evaluation_snapshot); - if let Some(return_type) = return_type { return_types.push(return_type); } else { @@ -1481,7 +1489,7 @@ impl<'db> CallableBinding<'db> { UnionType::from_elements(db, return_types), )); - return; + return Some(merged_argument_forms); } } @@ -1490,6 +1498,8 @@ impl<'db> CallableBinding<'db> { // argument types. This is necessary because we restore the state to the pre-evaluation // snapshot when processing the expanded argument lists. snapshotter.restore(self, post_evaluation_snapshot); + + None } /// Filter overloads based on [`Any`] or [`Unknown`] argument types. @@ -1915,10 +1925,59 @@ enum MatchingOverloadIndex { Multiple(Vec), } +#[derive(Default, Debug)] +struct ArgumentForms { + values: Vec>, + conflicting: Vec, +} + +impl ArgumentForms { + /// Create a new argument forms initialized to the given length and the default values. + fn new(len: usize) -> Self { + Self { + values: vec![None; len], + conflicting: vec![false; len], + } + } + + fn merge(&mut self, other: &ArgumentForms) { + if self.values.len() < other.values.len() { + self.values.resize(other.values.len(), None); + self.conflicting.resize(other.conflicting.len(), false); + } + + for (index, (other_form, other_conflict)) in other + .values + .iter() + .zip(other.conflicting.iter()) + .enumerate() + { + if let Some(self_form) = &mut self.values[index] { + if let Some(other_form) = other_form { + if *self_form != *other_form { + // Different parameter forms, mark as conflicting + self.conflicting[index] = true; + *self_form = *other_form; // Use the new form + } + } + } else { + self.values[index] = *other_form; + } + + // Update the conflicting form (true takes precedence) + self.conflicting[index] |= *other_conflict; + } + } + + fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + self.conflicting.shrink_to_fit(); + } +} + struct ArgumentMatcher<'a, 'db> { parameters: &'a Parameters<'db>, - argument_forms: &'a mut [Option], - 
conflicting_forms: &'a mut [bool], + argument_forms: &'a mut ArgumentForms, errors: &'a mut Vec>, argument_matches: Vec>, @@ -1932,14 +1991,12 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { fn new( arguments: &CallArguments, parameters: &'a Parameters<'db>, - argument_forms: &'a mut [Option], - conflicting_forms: &'a mut [bool], + argument_forms: &'a mut ArgumentForms, errors: &'a mut Vec>, ) -> Self { Self { parameters, argument_forms, - conflicting_forms, errors, argument_matches: vec![MatchedArgument::default(); arguments.len()], parameter_matched: vec![false; parameters.len()], @@ -1971,11 +2028,13 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { positional: bool, ) { if !matches!(argument, Argument::Synthetic) { - if let Some(existing) = self.argument_forms[argument_index - self.num_synthetic_args] - .replace(parameter.form) + let adjusted_argument_index = argument_index - self.num_synthetic_args; + if let Some(existing) = + self.argument_forms.values[adjusted_argument_index].replace(parameter.form) { if existing != parameter.form { - self.conflicting_forms[argument_index - self.num_synthetic_args] = true; + self.argument_forms.conflicting[argument_index - self.num_synthetic_args] = + true; } } } @@ -2295,22 +2354,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { // how many elements the iterator will produce. let argument_types = argument_type.iterate(self.db); - // TODO: When we perform argument expansion during overload resolution, we might need - // to retry both `match_parameters` _and_ `check_types` for each expansion. Currently - // we only retry `check_types`. The issue is that argument expansion might produce a - // splatted value with a different arity than what we originally inferred for the - // unexpanded value, and that in turn can affect which parameters the splatted value is - // matched with. 
As a workaround, make sure that the splatted tuple contains an - // arbitrary number of `Unknown`s at the end, so that if the expanded value has a - // smaller arity than the unexpanded value, we still have enough values to assign to - // the already matched parameters. - let argument_types = match argument_types.as_ref() { - Tuple::Fixed(_) => { - Cow::Owned(argument_types.concat(self.db, &Tuple::homogeneous(Type::unknown()))) - } - Tuple::Variable(_) => argument_types, - }; - // Resize the tuple of argument types to line up with the number of parameters this // argument was matched against. If parameter matching succeeded, then we can (TODO: // should be able to, see above) guarantee that all of the required elements of the @@ -2441,21 +2484,15 @@ impl<'db> Binding<'db> { } } - pub(crate) fn match_parameters( + fn match_parameters( &mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>, - argument_forms: &mut [Option], - conflicting_forms: &mut [bool], + argument_forms: &mut ArgumentForms, ) { let parameters = self.signature.parameters(); - let mut matcher = ArgumentMatcher::new( - arguments, - parameters, - argument_forms, - conflicting_forms, - &mut self.errors, - ); + let mut matcher = + ArgumentMatcher::new(arguments, parameters, argument_forms, &mut self.errors); for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() { match argument { Argument::Positional | Argument::Synthetic => { @@ -2610,6 +2647,16 @@ impl<'db> Binding<'db> { pub(crate) fn errors(&self) -> &[BindingError<'db>] { &self.errors } + + /// Resets the state of this binding to its initial state. 
+ fn reset(&mut self) { + self.return_ty = Type::unknown(); + self.specialization = None; + self.inherited_specialization = None; + self.argument_matches = Box::from([]); + self.parameter_tys = Box::from([]); + self.errors.clear(); + } } #[derive(Clone, Debug)] diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index ab51e9de07..0dc9492162 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -5833,7 +5833,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let bindings = callable_type .bindings(self.db()) .match_parameters(self.db(), &call_arguments); - self.infer_argument_types(arguments, &mut call_arguments, &bindings.argument_forms); + self.infer_argument_types(arguments, &mut call_arguments, bindings.argument_forms()); // Validate `TypedDict` constructor calls after argument type inference if let Some(class_literal) = callable_type.into_class_literal() { diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index e85fd9902f..5b76defefd 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -1015,11 +1015,6 @@ impl<'db> Tuple> { UnionType::from_elements(db, self.all_elements()) } - /// Concatenates another tuple to the end of this tuple, returning a new tuple. - pub(crate) fn concat(&self, db: &'db dyn Db, other: &Self) -> Self { - TupleSpecBuilder::from(self).concat(db, other).build() - } - /// Resizes this tuple to a different length, if possible. If this tuple cannot satisfy the /// desired minimum or maximum length, we return an error. If we return an `Ok` result, the /// [`len`][Self::len] of the resulting tuple is guaranteed to be equal to `new_length`.