diff --git a/crates/red_knot_python_semantic/resources/mdtest/directives/cast.md b/crates/red_knot_python_semantic/resources/mdtest/directives/cast.md new file mode 100644 index 0000000000..600a3316cb --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/directives/cast.md @@ -0,0 +1,27 @@ +# `cast` + +`cast()` takes two arguments, one type and one value, and returns a value of the given type. + +The (inferred) type of the value and the given type do not need to have any correlation. + +```py +from typing import Literal, cast + +reveal_type(True) # revealed: Literal[True] +reveal_type(cast(str, True)) # revealed: str +reveal_type(cast("str", True)) # revealed: str + +reveal_type(cast(int | str, 1)) # revealed: int | str + +# error: [invalid-type-form] +reveal_type(cast(Literal, True)) # revealed: Unknown + +# TODO: These should be errors +cast(1) +cast(str) +cast(str, b"ar", "foo") + +# TODO: Either support keyword arguments properly, +# or give a comprehensible error message saying they're unsupported +cast(val="foo", typ=int) # error: [unresolved-reference] "Name `foo` used when not defined" +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 6606cb3640..d7d83c697d 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -2014,6 +2014,20 @@ impl<'db> Type<'db> { CallOutcome::asserted(binding, asserted_ty) } + Some(KnownFunction::Cast) => { + // TODO: Use `.two_parameter_tys()` exclusively + // when overloads are supported. + if binding.two_parameter_tys().is_none() { + return CallOutcome::callable(binding); + }; + + let Some(casted_ty) = arguments.first_argument() else { + return CallOutcome::callable(binding); + }; + + CallOutcome::casted(binding, casted_ty) + } + _ => CallOutcome::callable(binding), } } @@ -3353,6 +3367,8 @@ pub enum KnownFunction { /// `typing(_extensions).assert_type` AssertType, + /// `typing(_extensions).cast` + Cast, /// `knot_extensions.static_assert` StaticAssert, @@ -3399,6 +3415,7 @@ impl KnownFunction { Some(KnownFunction::NoTypeCheck) } "assert_type" if definition.is_typing_definition(db) => Some(KnownFunction::AssertType), + "cast" if definition.is_typing_definition(db) => Some(KnownFunction::Cast), "static_assert" if definition.is_knot_extensions_definition(db) => { Some(KnownFunction::StaticAssert) } @@ -3441,6 +3458,7 @@ impl KnownFunction { | Self::IsDisjointFrom => ParameterExpectations::TwoTypeExpressions, Self::AssertType => ParameterExpectations::ValueExpressionAndTypeExpression, + Self::Cast => ParameterExpectations::TypeExpressionAndValueExpression, Self::ConstraintFunction(_) | Self::Len @@ -3468,6 +3486,9 @@ enum ParameterExpectations { /// The first parameter in the function expects a value expression, /// and the second expects a type expression ValueExpressionAndTypeExpression, + /// The first parameter in the function expects a type expression, + /// and the second expects a value expression + TypeExpressionAndValueExpression, } impl ParameterExpectations { @@ -3475,7 +3496,7 @@ impl ParameterExpectations { fn expectation_at_index(self, parameter_index: usize) -> ParameterExpectation { match self { Self::AllValueExpressions => ParameterExpectation::ValueExpression, - Self::SingleTypeExpression => { + Self::SingleTypeExpression | Self::TypeExpressionAndValueExpression => { if parameter_index == 0 { ParameterExpectation::TypeExpression } else { @@ -3833,6 +3854,7 @@ impl<'db> Class<'db> { | CallOutcome::RevealType { binding, .. } | CallOutcome::StaticAssertionError { binding, .. } | CallOutcome::AssertType { binding, .. } => Ok(binding.return_ty()), + CallOutcome::Cast { casted_ty, .. } => Ok(casted_ty), }; return return_ty_result.map(|ty| ty.to_meta_type(db)); diff --git a/crates/red_knot_python_semantic/src/types/call.rs b/crates/red_knot_python_semantic/src/types/call.rs index 5b8558131f..b31ac891eb 100644 --- a/crates/red_knot_python_semantic/src/types/call.rs +++ b/crates/red_knot_python_semantic/src/types/call.rs @@ -48,6 +48,10 @@ pub(super) enum CallOutcome<'db> { binding: CallBinding<'db>, asserted_ty: Type<'db>, }, + Cast { + binding: CallBinding<'db>, + casted_ty: Type<'db>, + }, } impl<'db> CallOutcome<'db> { @@ -80,7 +84,7 @@ impl<'db> CallOutcome<'db> { } } - /// Create a new `CallOutcome::AssertType` with given revealed and return types. + /// Create a new `CallOutcome::AssertType` with given asserted and return types. pub(super) fn asserted(binding: CallBinding<'db>, asserted_ty: Type<'db>) -> CallOutcome<'db> { CallOutcome::AssertType { binding, @@ -88,6 +92,11 @@ impl<'db> CallOutcome<'db> { } } + /// Create a new `CallOutcome::Casted` with given casted and return types. + pub(super) fn casted(binding: CallBinding<'db>, casted_ty: Type<'db>) -> CallOutcome<'db> { + CallOutcome::Cast { binding, casted_ty } + } + /// Get the return type of the call, or `None` if not callable. pub(super) fn return_ty(&self, db: &'db dyn Db) -> Option> { match self { @@ -119,6 +128,10 @@ impl<'db> CallOutcome<'db> { binding, asserted_ty: _, } => Some(binding.return_ty()), + Self::Cast { + binding: _, + casted_ty, + } => Some(*casted_ty), } } @@ -280,7 +293,7 @@ impl<'db> CallOutcome<'db> { }), } } - CallOutcome::StaticAssertionError { + Self::StaticAssertionError { binding, error_kind, } => { @@ -325,7 +338,7 @@ impl<'db> CallOutcome<'db> { Ok(Type::unknown()) } - CallOutcome::AssertType { + Self::AssertType { binding, asserted_ty, } => { @@ -347,6 +360,10 @@ impl<'db> CallOutcome<'db> { Ok(binding.return_ty()) } + Self::Cast { + binding: _, + casted_ty, + } => Ok(*casted_ty), } } }