[red-knot] Support cast (#15413)

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
InSync
2025-01-12 20:05:45 +07:00
committed by GitHub
parent 60d7a464fb
commit 6ae3e8f8d7
3 changed files with 70 additions and 4 deletions

View File

@@ -0,0 +1,27 @@
# `cast`
`cast()` takes two arguments, one type and one value, and returns a value of the given type.
The (inferred) type of the value and the given type do not need to have any correlation.
```py
from typing import Literal, cast
reveal_type(True) # revealed: Literal[True]
reveal_type(cast(str, True)) # revealed: str
reveal_type(cast("str", True)) # revealed: str
reveal_type(cast(int | str, 1)) # revealed: int | str
# error: [invalid-type-form]
reveal_type(cast(Literal, True)) # revealed: Unknown
# TODO: These should be errors
cast(1)
cast(str)
cast(str, b"ar", "foo")
# TODO: Either support keyword arguments properly,
# or give a comprehensible error message saying they're unsupported
cast(val="foo", typ=int) # error: [unresolved-reference] "Name `foo` used when not defined"
```

View File

@@ -2014,6 +2014,20 @@ impl<'db> Type<'db> {
CallOutcome::asserted(binding, asserted_ty)
}
Some(KnownFunction::Cast) => {
// TODO: Use `.two_parameter_tys()` exclusively
// when overloads are supported.
if binding.two_parameter_tys().is_none() {
return CallOutcome::callable(binding);
};
let Some(casted_ty) = arguments.first_argument() else {
return CallOutcome::callable(binding);
};
CallOutcome::casted(binding, casted_ty)
}
_ => CallOutcome::callable(binding),
}
}
@@ -3353,6 +3367,8 @@ pub enum KnownFunction {
/// `typing(_extensions).assert_type`
AssertType,
/// `typing(_extensions).cast`
Cast,
/// `knot_extensions.static_assert`
StaticAssert,
@@ -3399,6 +3415,7 @@ impl KnownFunction {
Some(KnownFunction::NoTypeCheck)
}
"assert_type" if definition.is_typing_definition(db) => Some(KnownFunction::AssertType),
"cast" if definition.is_typing_definition(db) => Some(KnownFunction::Cast),
"static_assert" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::StaticAssert)
}
@@ -3441,6 +3458,7 @@ impl KnownFunction {
| Self::IsDisjointFrom => ParameterExpectations::TwoTypeExpressions,
Self::AssertType => ParameterExpectations::ValueExpressionAndTypeExpression,
Self::Cast => ParameterExpectations::TypeExpressionAndValueExpression,
Self::ConstraintFunction(_)
| Self::Len
@@ -3468,6 +3486,9 @@ enum ParameterExpectations {
/// The first parameter in the function expects a value expression,
/// and the second expects a type expression
ValueExpressionAndTypeExpression,
/// The first parameter in the function expects a type expression,
/// and the second expects a value expression
TypeExpressionAndValueExpression,
}
impl ParameterExpectations {
@@ -3475,7 +3496,7 @@ impl ParameterExpectations {
fn expectation_at_index(self, parameter_index: usize) -> ParameterExpectation {
match self {
Self::AllValueExpressions => ParameterExpectation::ValueExpression,
Self::SingleTypeExpression => {
Self::SingleTypeExpression | Self::TypeExpressionAndValueExpression => {
if parameter_index == 0 {
ParameterExpectation::TypeExpression
} else {
@@ -3833,6 +3854,7 @@ impl<'db> Class<'db> {
| CallOutcome::RevealType { binding, .. }
| CallOutcome::StaticAssertionError { binding, .. }
| CallOutcome::AssertType { binding, .. } => Ok(binding.return_ty()),
CallOutcome::Cast { casted_ty, .. } => Ok(casted_ty),
};
return return_ty_result.map(|ty| ty.to_meta_type(db));

View File

@@ -48,6 +48,10 @@ pub(super) enum CallOutcome<'db> {
binding: CallBinding<'db>,
asserted_ty: Type<'db>,
},
Cast {
binding: CallBinding<'db>,
casted_ty: Type<'db>,
},
}
impl<'db> CallOutcome<'db> {
@@ -80,7 +84,7 @@ impl<'db> CallOutcome<'db> {
}
}
/// Create a new `CallOutcome::AssertType` with given revealed and return types.
/// Create a new `CallOutcome::AssertType` with given asserted and return types.
pub(super) fn asserted(binding: CallBinding<'db>, asserted_ty: Type<'db>) -> CallOutcome<'db> {
CallOutcome::AssertType {
binding,
@@ -88,6 +92,11 @@ impl<'db> CallOutcome<'db> {
}
}
/// Create a new `CallOutcome::Casted` with given casted and return types.
pub(super) fn casted(binding: CallBinding<'db>, casted_ty: Type<'db>) -> CallOutcome<'db> {
CallOutcome::Cast { binding, casted_ty }
}
/// Get the return type of the call, or `None` if not callable.
pub(super) fn return_ty(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
@@ -119,6 +128,10 @@ impl<'db> CallOutcome<'db> {
binding,
asserted_ty: _,
} => Some(binding.return_ty()),
Self::Cast {
binding: _,
casted_ty,
} => Some(*casted_ty),
}
}
@@ -280,7 +293,7 @@ impl<'db> CallOutcome<'db> {
}),
}
}
CallOutcome::StaticAssertionError {
Self::StaticAssertionError {
binding,
error_kind,
} => {
@@ -325,7 +338,7 @@ impl<'db> CallOutcome<'db> {
Ok(Type::unknown())
}
CallOutcome::AssertType {
Self::AssertType {
binding,
asserted_ty,
} => {
@@ -347,6 +360,10 @@ impl<'db> CallOutcome<'db> {
Ok(binding.return_ty())
}
Self::Cast {
binding: _,
casted_ty,
} => Ok(*casted_ty),
}
}
}