Compare commits

...

2 Commits

Author SHA1 Message Date
Dhruv Manilawala
3420cdda09 [ty] Add truncation policy for displaying overloads on single line 2025-12-12 20:24:59 +05:30
Dhruv Manilawala
76ffc56e85 [ty] Add basic support for overloads in ParamSpec 2025-12-12 16:23:48 +05:30
3 changed files with 99 additions and 19 deletions

View File

@@ -664,9 +664,15 @@ reveal_type(change_return_type(int_str)) # revealed: Overload[(x: int) -> str,
# error: [invalid-argument-type]
reveal_type(change_return_type(str_str)) # revealed: Overload[(x: int) -> str, (x: str) -> str]
# TODO: Both of these shouldn't raise an error
# error: [invalid-argument-type]
# TODO: This should reveal the matching overload instead
reveal_type(with_parameters(int_int, 1)) # revealed: Overload[(x: int) -> str, (x: str) -> str]
# error: [invalid-argument-type]
reveal_type(with_parameters(int_int, "a")) # revealed: Overload[(x: int) -> str, (x: str) -> str]
# error: [invalid-argument-type] "Argument to function `with_parameters` is incorrect: Expected `int`, found `None`"
reveal_type(with_parameters(int_int, None)) # revealed: Overload[(x: int) -> str, (x: str) -> str]
def foo(int_or_str: int | str):
# Argument type expansion leads to matching both overloads.
# TODO: Should this be an error instead?
reveal_type(with_parameters(int_int, int_or_str)) # revealed: Overload[(x: int) -> str, (x: str) -> str]
```

View File

@@ -3313,8 +3313,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
/// are passed.
///
/// This method returns `false` if the specialization does not contain a mapping for the given
/// `paramspec`, contains an invalid mapping (i.e., not a `Callable` of kind `ParamSpecValue`)
/// or if the value is an overloaded callable.
/// `paramspec` or contains an invalid mapping (i.e., not a `Callable` of kind `ParamSpecValue`).
///
/// For more details, refer to [`Self::try_paramspec_evaluation_at`].
fn evaluate_paramspec_sub_call(
@@ -3333,10 +3332,10 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
return false;
}
// TODO: Support overloads?
let [signature] = callable.signatures(self.db).overloads.as_slice() else {
let signatures = &callable.signatures(self.db).overloads;
if signatures.is_empty() {
return false;
};
}
let sub_arguments = if let Some(argument_index) = argument_index {
self.arguments.start_from(argument_index)
@@ -3344,21 +3343,61 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
CallArguments::none()
};
// TODO: What should be the `signature_type` here?
let bindings = match Bindings::from(Binding::single(self.signature_type, signature.clone()))
// Create Bindings with all overloads and perform full overload resolution
let callable_binding =
CallableBinding::from_overloads(self.signature_type, signatures.iter().cloned());
let bindings = match Bindings::from(callable_binding)
.match_parameters(self.db, &sub_arguments)
.check_types(self.db, &sub_arguments, self.call_expression_tcx, &[])
{
Ok(bindings) => Box::new(bindings),
Err(CallError(_, bindings)) => bindings,
Ok(bindings) => bindings,
Err(CallError(_, bindings)) => *bindings,
};
// SAFETY: `bindings` was created from a single binding above.
let [binding] = bindings.single_element().unwrap().overloads.as_slice() else {
unreachable!("ParamSpec sub-call should only contain a single binding");
// SAFETY: `bindings` was created from a single `CallableBinding` above.
let Some(callable_binding) = bindings.single_element() else {
unreachable!("ParamSpec sub-call should only contain a single CallableBinding");
};
self.errors.extend(binding.errors.iter().cloned());
match callable_binding.matching_overload_index() {
MatchingOverloadIndex::None => {
if let [binding] = callable_binding.overloads() {
// This is not an overloaded function, so we can propagate its errors
// to the outer bindings.
self.errors.extend(binding.errors.iter().cloned());
} else {
let index = callable_binding
.matching_overload_before_type_checking
.unwrap_or(0);
// TODO: We should also update the specialization for the `ParamSpec` to reflect
// the matching overload here.
self.errors
.extend(callable_binding.overloads()[index].errors.iter().cloned());
}
}
MatchingOverloadIndex::Single(index) => {
// TODO: We should also update the specialization for the `ParamSpec` to reflect the
// matching overload here.
self.errors
.extend(callable_binding.overloads()[index].errors.iter().cloned());
}
MatchingOverloadIndex::Multiple(_) => {
if !matches!(
callable_binding.overload_call_return_type,
Some(OverloadCallReturnType::ArgumentTypeExpansion(_))
) {
self.errors.extend(
callable_binding
.overloads()
.first()
.unwrap()
.errors
.iter()
.cloned(),
);
}
}
}
true
}

View File

@@ -746,13 +746,23 @@ impl<'db> FmtDetailed<'db> for DisplayRepresentation<'db> {
}
let separator = if self.settings.multiline { "\n" } else { ", " };
let mut join = f.join(separator);
for signature in signatures {
let display_limit = OVERLOAD_POLICY
.display_limit(signatures.len(), self.settings.multiline);
for signature in signatures.iter().take(display_limit) {
join.entry(
&signature
.bind_self(self.db, Some(typing_self_ty))
.display_with(self.db, self.settings.clone()),
);
}
if !self.settings.multiline {
let omitted = signatures.len().saturating_sub(display_limit);
join.entry(&DisplayOmitted {
count: omitted,
singular: "overload",
plural: "overloads",
});
}
join.finish()?;
if !self.settings.multiline {
f.write_str("]")?;
@@ -1152,9 +1162,19 @@ impl<'db> FmtDetailed<'db> for DisplayFunctionType<'db> {
}
let separator = if self.settings.multiline { "\n" } else { ", " };
let mut join = f.join(separator);
for signature in signatures {
let display_limit =
OVERLOAD_POLICY.display_limit(signatures.len(), self.settings.multiline);
for signature in signatures.iter().take(display_limit) {
join.entry(&signature.display_with(self.db, self.settings.clone()));
}
if !self.settings.multiline {
let omitted = signatures.len().saturating_sub(display_limit);
join.entry(&DisplayOmitted {
count: omitted,
singular: "overload",
plural: "overloads",
});
}
join.finish()?;
if !self.settings.multiline {
f.write_str("]")?;
@@ -1470,6 +1490,11 @@ impl TupleSpecialization {
}
}
const OVERLOAD_POLICY: TruncationPolicy = TruncationPolicy {
max: 3,
max_when_elided: 2,
};
impl<'db> CallableType<'db> {
pub(crate) fn display<'a>(&'a self, db: &'db dyn Db) -> DisplayCallableType<'a, 'db> {
Self::display_with(self, db, DisplaySettings::default())
@@ -1521,9 +1546,19 @@ impl<'db> FmtDetailed<'db> for DisplayCallableType<'_, 'db> {
}
let separator = if self.settings.multiline { "\n" } else { ", " };
let mut join = f.join(separator);
for signature in signatures {
let display_limit =
OVERLOAD_POLICY.display_limit(signatures.len(), self.settings.multiline);
for signature in signatures.iter().take(display_limit) {
join.entry(&signature.display_with(self.db, self.settings.clone()));
}
if !self.settings.multiline {
let omitted = signatures.len().saturating_sub(display_limit);
join.entry(&DisplayOmitted {
count: omitted,
singular: "overload",
plural: "overloads",
});
}
join.finish()?;
if !self.settings.multiline {
f.write_char(']')?;