Compare commits

...

3 Commits

Author SHA1 Message Date
Douglas Creager
7c976dc570 Merge branch 'main' into dcreager/function-enum
* main:
  Update pre-commit dependencies (#17506)
  [red-knot] Simplify visibility constraint handling for `*`-import definitions (#17486)
  [red-knot] Detect (some) invalid protocols (#17488)
  [red-knot] Correctly identify protocol classes (#17487)
  Update dependency ruff to v0.11.6 (#17516)
  Update Rust crate shellexpand to v3.1.1 (#17512)
  Update Rust crate proc-macro2 to v1.0.95 (#17510)
  Update Rust crate rand to v0.9.1 (#17511)
  Update Rust crate libc to v0.2.172 (#17509)
  Update Rust crate jiff to v0.2.9 (#17508)
  Update Rust crate clap to v4.5.37 (#17507)
  Update astral-sh/setup-uv action to v5.4.2 (#17504)
  Update taiki-e/install-action digest to 09dc018 (#17503)
  [red-knot] infer attribute assignments bound in comprehensions (#17396)
  [red-knot] simplify gradually-equivalent types out of unions and intersections (#17467)
  [red-knot] pull primer projects to run from file (#17473)
2025-04-21 13:18:36 -04:00
Douglas Creager
b44fb47f25 create generic context lazily 2025-04-21 13:15:09 -04:00
Douglas Creager
0a4dec0323 start pulling out enum 2025-04-18 17:04:26 -04:00
8 changed files with 276 additions and 161 deletions

View File

@@ -4605,8 +4605,8 @@ impl<'db> Type<'db> {
match self {
Type::TypeVar(typevar) => specialization.get(db, typevar).unwrap_or(self),
Type::FunctionLiteral(function) => {
Type::FunctionLiteral(function.apply_specialization(db, specialization))
Type::FunctionLiteral(function) =>{
Type::FunctionLiteral(FunctionType::Specialized(SpecializedFunction::new(db, function, specialization)))
}
// Note that we don't need to apply the specialization to `self_instance`, since it
@@ -4619,19 +4619,19 @@ impl<'db> Type<'db> {
// specialized.)
Type::BoundMethod(method) => Type::BoundMethod(BoundMethodType::new(
db,
method.function(db).apply_specialization(db, specialization),
FunctionType::Specialized(SpecializedFunction::new(db, method.function(db), specialization)),
method.self_instance(db),
)),
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => {
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(
function.apply_specialization(db, specialization),
FunctionType::Specialized(SpecializedFunction::new(db, function, specialization))
))
}
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderCall(function)) => {
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderCall(
function.apply_specialization(db, specialization),
FunctionType::Specialized(SpecializedFunction::new(db, function, specialization))
))
}
@@ -5834,6 +5834,34 @@ impl<'db> FunctionSignature<'db> {
pub(crate) fn iter(&self) -> Iter<Signature<'db>> {
self.as_slice().iter()
}
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
match self {
Self::Single(signature) => signature.apply_specialization(db, specialization),
Self::Overloaded(signatures, implementation) => {
signatures
.iter_mut()
.for_each(|signature| signature.apply_specialization(db, specialization));
implementation
.as_mut()
.map(|signature| signature.apply_specialization(db, specialization));
}
}
}
fn set_generic_context(&mut self, generic_context: GenericContext<'db>) {
match self {
Self::Single(signature) => signature.set_generic_context(generic_context),
Self::Overloaded(signatures, implementation) => {
signatures
.iter_mut()
.for_each(|signature| signature.set_generic_context(generic_context));
implementation
.as_mut()
.map(|signature| signature.set_generic_context(generic_context));
}
}
}
}
impl<'db> IntoIterator for &'db FunctionSignature<'db> {
@@ -5845,34 +5873,40 @@ impl<'db> IntoIterator for &'db FunctionSignature<'db> {
}
}
#[salsa::interned(debug)]
pub struct FunctionType<'db> {
/// Name of the function at definition.
#[return_ref]
pub name: ast::name::Name,
/// A callable type that represents a single Python function.
#[derive(
Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, salsa::Supertype, salsa::Update,
)]
pub enum FunctionType<'db> {
/// A function literal in the Python AST
FunctionLiteral(FunctionLiteral<'db>),
/// Is this a function that we special-case somehow? If so, which one?
known: Option<KnownFunction>,
/// A function that has a specialization applied to its signature.
///
/// (This does not necessarily mean that the function itself is generic — the methods of a
/// generic class, for instance, will have the class's specialization applied so that we
/// correctly substitute any class typevars that appear in the signature.)
Specialized(SpecializedFunction<'db>),
/// The scope that's created by the function, in which the function body is evaluated.
body_scope: ScopeId<'db>,
/// A set of special decorators that were applied to this function
decorators: FunctionDecorators,
/// The generic context of a generic function.
generic_context: Option<GenericContext<'db>>,
/// A specialization that should be applied to the function's parameter and return types,
/// either because the function is itself generic, or because it appears in the body of a
/// generic class.
specialization: Option<Specialization<'db>>,
/// A function that we treat as generic because it inherits a containing generic context.
///
/// This is currently only used for the `__new__` and `__init__` methods of a generic class.
/// That lets us pretend those methods are generic, so that we can infer a class specialization
/// from the arguments to its constructor.
InheritedGenericContext(FunctionWithInheritedGenericContext<'db>),
}
#[salsa::tracked]
impl<'db> FunctionType<'db> {
fn function_literal(self, db: &'db dyn Db) -> FunctionLiteral<'db> {
match self {
FunctionType::FunctionLiteral(literal) => literal,
FunctionType::Specialized(specialized) => specialized.function(db).function_literal(db),
FunctionType::InheritedGenericContext(inherited) => inherited.function(db),
}
}
pub(crate) fn has_known_decorator(self, db: &dyn Db, decorator: FunctionDecorators) -> bool {
self.decorators(db).contains(decorator)
self.function_literal(db).decorators(db).contains(decorator)
}
/// Convert the `FunctionType` into a [`Type::Callable`].
@@ -5885,20 +5919,80 @@ impl<'db> FunctionType<'db> {
/// Returns the [`FileRange`] of the function's name.
pub fn focus_range(self, db: &dyn Db) -> FileRange {
let body_scope = self.function_literal(db).body_scope(db);
FileRange::new(
self.body_scope(db).file(db),
self.body_scope(db).node(db).expect_function().name.range,
body_scope.file(db),
body_scope.node(db).expect_function().name.range,
)
}
pub fn full_range(self, db: &dyn Db) -> FileRange {
let body_scope = self.function_literal(db).body_scope(db);
FileRange::new(
self.body_scope(db).file(db),
self.body_scope(db).node(db).expect_function().range,
body_scope.file(db),
body_scope.node(db).expect_function().range,
)
}
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
self.function_literal(db).definition(db)
}
/// Typed externally-visible signature for this function.
///
/// This is the signature as seen by external callers, possibly modified by decorators and/or
/// overloaded.
///
/// ## Why is this a salsa query?
///
/// This is a salsa query to short-circuit the invalidation
/// when the function's AST node changes.
///
/// Were this not a salsa query, then the calling query
/// would depend on the function's AST and rerun for every change in that file.
pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
match self {
FunctionType::FunctionLiteral(literal) => literal.signature(db),
FunctionType::Specialized(specialized) => specialized.signature(db),
FunctionType::InheritedGenericContext(inherited) => inherited.signature(db),
}
}
pub(crate) fn known(self, db: &'db dyn Db) -> Option<KnownFunction> {
self.function_literal(db).known(db)
}
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
self.known(db) == Some(known_function)
}
}
#[salsa::interned(debug)]
pub struct FunctionLiteral<'db> {
/// Name of the function at definition.
#[return_ref]
pub name: ast::name::Name,
/// Is this a function that we special-case somehow? If so, which one?
known: Option<KnownFunction>,
/// The scope that's created by the function, in which the function body is evaluated.
body_scope: ScopeId<'db>,
/// The scope containing the PEP 695 type parameters in the function definition, if any.
type_params_scope: Option<ScopeId<'db>>,
/// A set of special decorators that were applied to this function
decorators: FunctionDecorators,
}
#[salsa::tracked]
impl<'db> FunctionLiteral<'db> {
fn has_known_decorator(self, db: &dyn Db, decorator: FunctionDecorators) -> bool {
self.decorators(db).contains(decorator)
}
fn definition(self, db: &'db dyn Db) -> Definition<'db> {
let body_scope = self.body_scope(db);
let index = semantic_index(db, body_scope.file(db));
index.expect_single_definition(body_scope.node(db).expect_function())
@@ -5916,13 +6010,9 @@ impl<'db> FunctionType<'db> {
///
/// Were this not a salsa query, then the calling query
/// would depend on the function's AST and rerun for every change in that file.
#[salsa::tracked(return_ref)]
pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
let mut internal_signature = self.internal_signature(db);
if let Some(specialization) = self.specialization(db) {
internal_signature = internal_signature.apply_specialization(db, specialization);
}
#[salsa::tracked]
fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
let internal_signature = self.internal_signature(db);
// The semantic model records a use for each function on the name node. This is used here
// to get the previous function definition with the same name.
@@ -5982,39 +6072,61 @@ impl<'db> FunctionType<'db> {
let scope = self.body_scope(db);
let function_stmt_node = scope.node(db).expect_function();
let definition = self.definition(db);
Signature::from_function(db, self.generic_context(db), definition, function_stmt_node)
let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| {
let index = semantic_index(db, scope.file(db));
GenericContext::from_type_params(db, index, type_params)
});
Signature::from_function(db, generic_context, definition, function_stmt_node)
}
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
self.known(db) == Some(known_function)
}
fn with_generic_context(self, db: &'db dyn Db, generic_context: GenericContext<'db>) -> Self {
Self::new(
fn with_generic_context(
self,
db: &'db dyn Db,
generic_context: GenericContext<'db>,
) -> FunctionType<'db> {
FunctionType::InheritedGenericContext(FunctionWithInheritedGenericContext::new(
db,
self.name(db).clone(),
self.known(db),
self.body_scope(db),
self.decorators(db),
Some(generic_context),
self.specialization(db),
)
self,
generic_context,
))
}
}
fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
let specialization = match self.specialization(db) {
Some(existing) => existing.apply_specialization(db, specialization),
None => specialization,
};
Self::new(
db,
self.name(db).clone(),
self.known(db),
self.body_scope(db),
self.decorators(db),
self.generic_context(db),
Some(specialization),
)
impl<'db> From<FunctionLiteral<'db>> for Type<'db> {
fn from(literal: FunctionLiteral<'db>) -> Type<'db> {
Type::FunctionLiteral(FunctionType::FunctionLiteral(literal))
}
}
#[salsa::interned(debug)]
pub struct SpecializedFunction<'db> {
function: FunctionType<'db>,
specialization: Specialization<'db>,
}
#[salsa::tracked]
impl<'db> SpecializedFunction<'db> {
#[salsa::tracked]
fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
let mut signature = self.function(db).signature(db);
signature.apply_specialization(db, self.specialization(db));
signature
}
}
#[salsa::interned(debug)]
pub struct FunctionWithInheritedGenericContext<'db> {
function: FunctionLiteral<'db>,
generic_context: GenericContext<'db>,
}
#[salsa::tracked]
impl<'db> FunctionWithInheritedGenericContext<'db> {
#[salsa::tracked]
fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
let mut signature = self.function(db).signature(db);
signature.set_generic_context(self.generic_context(db));
signature
}
}
@@ -6230,9 +6342,11 @@ impl<'db> CallableType<'db> {
fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
CallableType::from_overloads(
db,
self.signatures(db)
.iter()
.map(|signature| signature.apply_specialization(db, specialization)),
self.signatures(db).iter().map(|signature| {
let mut signature = signature.clone();
signature.apply_specialization(db, specialization);
signature
}),
)
}

View File

@@ -219,7 +219,8 @@ impl<'db> Bindings<'db> {
match binding_type {
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => {
if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) {
let function_literal = function.function_literal(db);
if function_literal.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) {
match overload.parameter_types() {
[_, Some(owner)] => {
overload.set_return_type(Type::BoundMethod(BoundMethodType::new(
@@ -250,7 +251,9 @@ impl<'db> Bindings<'db> {
if let [Some(function_ty @ Type::FunctionLiteral(function)), ..] =
overload.parameter_types()
{
if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) {
let function_literal = function.function_literal(db);
if function_literal.has_known_decorator(db, FunctionDecorators::CLASSMETHOD)
{
match overload.parameter_types() {
[_, _, Some(owner)] => {
overload.set_return_type(Type::BoundMethod(
@@ -298,7 +301,7 @@ impl<'db> Bindings<'db> {
if property.getter(db).is_some_and(|getter| {
getter
.into_function_literal()
.is_some_and(|f| f.name(db) == "__name__")
.is_some_and(|f| f.function_literal(db).name(db) == "__name__")
}) =>
{
overload.set_return_type(Type::string_literal(db, type_alias.name(db)));
@@ -307,7 +310,7 @@ impl<'db> Bindings<'db> {
if property.getter(db).is_some_and(|getter| {
getter
.into_function_literal()
.is_some_and(|f| f.name(db) == "__name__")
.is_some_and(|f| f.function_literal(db).name(db) == "__name__")
}) =>
{
overload.set_return_type(Type::string_literal(db, type_var.name(db)));
@@ -416,7 +419,12 @@ impl<'db> Bindings<'db> {
Type::BoundMethod(bound_method)
if bound_method.self_instance(db).is_property_instance() =>
{
match bound_method.function(db).name(db).as_str() {
match bound_method
.function(db)
.function_literal(db)
.name(db)
.as_str()
{
"setter" => {
if let [Some(_), Some(setter)] = overload.parameter_types() {
let mut ty_property = bound_method.self_instance(db);
@@ -456,7 +464,10 @@ impl<'db> Bindings<'db> {
}
}
Type::FunctionLiteral(function_type) => match function_type.known(db) {
Type::FunctionLiteral(function_type) => match function_type
.function_literal(db)
.known(db)
{
Some(KnownFunction::IsEquivalentTo) => {
if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() {
overload.set_return_type(Type::BooleanLiteral(
@@ -1175,7 +1186,7 @@ impl<'db> CallableDescription<'db> {
match callable_type {
Type::FunctionLiteral(function) => Some(CallableDescription {
kind: "function",
name: function.name(db),
name: function.function_literal(db).name(db),
}),
Type::ClassLiteral(class_type) => Some(CallableDescription {
kind: "class",
@@ -1183,12 +1194,12 @@ impl<'db> CallableDescription<'db> {
}),
Type::BoundMethod(bound_method) => Some(CallableDescription {
kind: "bound method",
name: bound_method.function(db).name(db),
name: bound_method.function(db).function_literal(db).name(db),
}),
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => {
Some(CallableDescription {
kind: "method wrapper `__get__` of function",
name: function.name(db),
name: function.function_literal(db).name(db),
})
}
Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(_)) => {
@@ -1313,7 +1324,7 @@ impl<'db> BindingError<'db> {
) -> Option<(Span, Span)> {
match callable_ty {
Type::FunctionLiteral(function) => {
let function_scope = function.body_scope(db);
let function_scope = function.function_literal(db).body_scope(db);
let span = Span::from(function_scope.file(db));
let node = function_scope.node(db);
if let Some(func_def) = node.as_function() {

View File

@@ -962,7 +962,9 @@ impl<'db> ClassLiteralType<'db> {
Some(_),
"__new__" | "__init__",
) => Type::FunctionLiteral(
function.with_generic_context(db, origin.generic_context(db)),
function
.function_literal(db)
.with_generic_context(db, origin.generic_context(db)),
),
_ => ty,
}

View File

@@ -169,7 +169,9 @@ impl<'db> InferContext<'db> {
// Iterate over all functions and test if any is decorated with `@no_type_check`.
function_scope_tys.any(|function_ty| {
function_ty.has_known_decorator(self.db, FunctionDecorators::NO_TYPE_CHECK)
function_ty
.function_literal(self.db)
.has_known_decorator(self.db, FunctionDecorators::NO_TYPE_CHECK)
})
}
InNoTypeCheck::Yes => true,

View File

@@ -1125,7 +1125,7 @@ fn report_invalid_assignment_with_message(
Type::FunctionLiteral(function) => {
context.report_lint_old(&INVALID_ASSIGNMENT, node, format_args!(
"Implicit shadowing of function `{}`; annotate to make it explicit if this is intentional",
function.name(context.db())));
function.function_literal(context.db()).name(context.db())));
}
_ => {
context.report_lint_old(&INVALID_ASSIGNMENT, node, message);

View File

@@ -10,7 +10,7 @@ use crate::types::class::{ClassType, GenericAlias, GenericClass};
use crate::types::generics::{GenericContext, Specialization};
use crate::types::signatures::{Parameter, Parameters, Signature};
use crate::types::{
FunctionSignature, InstanceType, IntersectionType, KnownClass, MethodWrapperKind,
FunctionSignature, FunctionType, InstanceType, IntersectionType, KnownClass, MethodWrapperKind,
StringLiteralType, SubclassOfInner, Type, TypeVarBoundOrConstraints, TypeVarInstance,
UnionType, WrapperDescriptorKind,
};
@@ -108,7 +108,7 @@ impl Display for DisplayRepresentation<'_> {
f,
// "def {name}{specialization}{signature}",
"def {name}{signature}",
name = function.name(self.db),
name = function.function_literal(self.db).name(self.db),
signature = signature.display(self.db)
)
}
@@ -135,7 +135,7 @@ impl Display for DisplayRepresentation<'_> {
write!(
f,
"bound method {instance}.{method}{signature}",
method = function.name(self.db),
method = function.function_literal(self.db).name(self.db),
instance = bound_method.self_instance(self.db).display(self.db),
signature = signature.bind_self().display(self.db)
)
@@ -155,10 +155,12 @@ impl Display for DisplayRepresentation<'_> {
write!(
f,
"<method-wrapper `__get__` of `{function}{specialization}`>",
function = function.name(self.db),
specialization = if let Some(specialization) = function.specialization(self.db)
{
specialization.display_short(self.db).to_string()
function = function.function_literal(self.db).name(self.db),
specialization = if let FunctionType::Specialized(specialized) = function {
specialized
.specialization(self.db)
.display_short(self.db)
.to_string()
} else {
String::new()
},
@@ -168,10 +170,12 @@ impl Display for DisplayRepresentation<'_> {
write!(
f,
"<method-wrapper `__call__` of `{function}{specialization}`>",
function = function.name(self.db),
specialization = if let Some(specialization) = function.specialization(self.db)
{
specialization.display_short(self.db).to_string()
function = function.function_literal(self.db).name(self.db),
specialization = if let FunctionType::Specialized(specialized) = function {
specialized
.specialization(self.db)
.display_short(self.db)
.to_string()
} else {
String::new()
},

View File

@@ -82,12 +82,12 @@ use crate::types::mro::MroErrorKind;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{
binding_type, todo_type, CallDunderError, CallableSignature, CallableType, Class,
ClassLiteralType, ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType,
GenericAlias, GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction,
KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter,
ParameterForm, Parameters, Signature, Signatures, SliceLiteralType, StringLiteralType,
SubclassOfType, Symbol, SymbolAndQualifiers, Truthiness, TupleType, Type, TypeAliasType,
TypeAndQualifiers, TypeArrayDisplay, TypeQualifiers, TypeVarBoundOrConstraints,
ClassLiteralType, ClassType, DataclassMetadata, DynamicType, FunctionDecorators,
FunctionLiteral, GenericAlias, GenericClass, IntersectionBuilder, IntersectionType, KnownClass,
KnownFunction, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, NonGenericClass,
Parameter, ParameterForm, Parameters, Signature, Signatures, SliceLiteralType,
StringLiteralType, SubclassOfType, Symbol, SymbolAndQualifiers, Truthiness, TupleType, Type,
TypeAliasType, TypeAndQualifiers, TypeArrayDisplay, TypeQualifiers, TypeVarBoundOrConstraints,
TypeVarInstance, UnionBuilder, UnionType,
};
use crate::unpack::{Unpack, UnpackPosition};
@@ -1503,10 +1503,6 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
let generic_context = type_params.as_ref().map(|type_params| {
GenericContext::from_type_params(self.db(), self.index, type_params)
});
let function_kind =
KnownFunction::try_from_definition_and_name(self.db(), definition, name);
@@ -1515,16 +1511,19 @@ impl<'db> TypeInferenceBuilder<'db> {
.node_scope(NodeWithScopeRef::Function(function))
.to_scope_id(self.db(), self.file());
let specialization = None;
let type_params_scope = type_params.as_ref().map(|_| {
self.index
.node_scope(NodeWithScopeRef::FunctionTypeParameters(function))
.to_scope_id(self.db(), self.file())
});
let mut inferred_ty = Type::FunctionLiteral(FunctionType::new(
let mut inferred_ty = Type::from(FunctionLiteral::new(
self.db(),
&name.id,
function_kind,
body_scope,
type_params_scope,
function_decorators,
generic_context,
specialization,
));
for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() {

View File

@@ -289,17 +289,18 @@ impl<'db> Signature<'db> {
}
pub(crate) fn apply_specialization(
&self,
&mut self,
db: &'db dyn Db,
specialization: Specialization<'db>,
) -> Self {
Self {
generic_context: self.generic_context,
parameters: self.parameters.apply_specialization(db, specialization),
return_ty: self
.return_ty
.map(|ty| ty.apply_specialization(db, specialization)),
}
) {
self.parameters.apply_specialization(db, specialization);
self.return_ty
.as_mut()
.map(|ty| *ty = ty.apply_specialization(db, specialization));
}
pub(crate) fn set_generic_context(&mut self, generic_context: GenericContext<'db>) {
self.generic_context = Some(generic_context);
}
/// Return the parameters in this signature.
@@ -1000,15 +1001,10 @@ impl<'db> Parameters<'db> {
)
}
fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
Self {
value: self
.value
.iter()
.map(|param| param.apply_specialization(db, specialization))
.collect(),
is_gradual: self.is_gradual,
}
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
self.value
.iter_mut()
.for_each(|param| param.apply_specialization(db, specialization));
}
pub(crate) fn len(&self) -> usize {
@@ -1172,14 +1168,11 @@ impl<'db> Parameter<'db> {
self
}
fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
Self {
annotated_type: self
.annotated_type
.map(|ty| ty.apply_specialization(db, specialization)),
kind: self.kind.apply_specialization(db, specialization),
form: self.form,
}
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
self.annotated_type
.as_mut()
.map(|ty| *ty = ty.apply_specialization(db, specialization));
self.kind.apply_specialization(db, specialization);
}
/// Strip information from the parameter so that two equivalent parameters compare equal.
@@ -1369,27 +1362,16 @@ pub(crate) enum ParameterKind<'db> {
}
impl<'db> ParameterKind<'db> {
fn apply_specialization(&self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
match self {
Self::PositionalOnly { default_type, name } => Self::PositionalOnly {
default_type: default_type
.as_ref()
.map(|ty| ty.apply_specialization(db, specialization)),
name: name.clone(),
},
Self::PositionalOrKeyword { default_type, name } => Self::PositionalOrKeyword {
default_type: default_type
.as_ref()
.map(|ty| ty.apply_specialization(db, specialization)),
name: name.clone(),
},
Self::KeywordOnly { default_type, name } => Self::KeywordOnly {
default_type: default_type
.as_ref()
.map(|ty| ty.apply_specialization(db, specialization)),
name: name.clone(),
},
Self::Variadic { .. } | Self::KeywordVariadic { .. } => self.clone(),
Self::PositionalOnly { default_type, .. }
| Self::PositionalOrKeyword { default_type, .. }
| Self::KeywordOnly { default_type, .. } => {
default_type
.as_mut()
.map(|ty| *ty = ty.apply_specialization(db, specialization));
}
Self::Variadic { .. } | Self::KeywordVariadic { .. } => {}
}
}
}
@@ -1406,16 +1388,20 @@ mod tests {
use super::*;
use crate::db::tests::{setup_db, TestDb};
use crate::symbol::global_symbol;
use crate::types::{FunctionSignature, FunctionType, KnownClass};
use crate::types::{FunctionLiteral, FunctionSignature, FunctionType, KnownClass};
use ruff_db::system::DbWithWritableSystem as _;
#[track_caller]
fn get_function_f<'db>(db: &'db TestDb, file: &'static str) -> FunctionType<'db> {
fn get_function_f<'db>(db: &'db TestDb, file: &'static str) -> FunctionLiteral<'db> {
let module = ruff_db::files::system_path_to_file(db, file).unwrap();
global_symbol(db, module, "f")
let function = global_symbol(db, module, "f")
.symbol
.expect_type()
.expect_function_literal()
.expect_function_literal();
let FunctionType::FunctionLiteral(literal) = function else {
panic!("function should be a function literal");
};
literal
}
#[track_caller]
@@ -1653,9 +1639,6 @@ mod tests {
let expected_sig = func.internal_signature(&db);
// With no decorators, internal and external signature are the same
assert_eq!(
func.signature(&db),
&FunctionSignature::Single(expected_sig)
);
assert_eq!(func.signature(&db), FunctionSignature::Single(expected_sig));
}
}