[red-knot] Support re-export conventions for stub files (#16073)

This is an alternative implementation to #15848.

## Summary

This PR adds support for re-export conventions for imports for stub
files.

**How does this work?**
* Add a new flag on the `Import` and `ImportFrom` definitions to
indicate whether they're being exported or not
* Add a new enum to indicate whether the symbol lookup is happening
within the same file or is being queried from another file (e.g., an
import statement)
* When a `Symbol` is being queried, we'll skip the definitions that are
(a) coming from a stub file (b) external lookup and (c) check the
re-export flag on the definition

This implementation does not yet support `__all__` and `*` imports as
both are features that needs to be implemented independently.

closes: #14099
closes: #15476 

## Test Plan

Add test cases, update existing ones if required.
This commit is contained in:
Dhruv Manilawala
2025-02-14 15:17:51 +05:30
committed by GitHub
parent 3d0a58eb60
commit 60b3ef2c98
7 changed files with 594 additions and 101 deletions

View File

@@ -33,8 +33,8 @@ use crate::Db;
use super::constraint::{Constraint, ConstraintNode, PatternConstraint};
use super::definition::{
DefinitionCategory, ExceptHandlerDefinitionNodeRef, MatchPatternDefinitionNodeRef,
WithItemDefinitionNodeRef,
DefinitionCategory, ExceptHandlerDefinitionNodeRef, ImportDefinitionNodeRef,
MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef,
};
mod except_handlers;
@@ -886,22 +886,28 @@ where
self.imported_modules.extend(module_name.ancestors());
}
let symbol_name = if let Some(asname) = &alias.asname {
asname.id.clone()
let (symbol_name, is_reexported) = if let Some(asname) = &alias.asname {
(asname.id.clone(), asname.id == alias.name.id)
} else {
Name::new(alias.name.id.split('.').next().unwrap())
(Name::new(alias.name.id.split('.').next().unwrap()), false)
};
let symbol = self.add_symbol(symbol_name);
self.add_definition(symbol, alias);
self.add_definition(
symbol,
ImportDefinitionNodeRef {
alias,
is_reexported,
},
);
}
}
ast::Stmt::ImportFrom(node) => {
for (alias_index, alias) in node.names.iter().enumerate() {
let symbol_name = if let Some(asname) = &alias.asname {
&asname.id
let (symbol_name, is_reexported) = if let Some(asname) = &alias.asname {
(&asname.id, asname.id == alias.name.id)
} else {
&alias.name.id
(&alias.name.id, false)
};
// Look for imports `from __future__ import annotations`, ignore `as ...`
@@ -914,7 +920,14 @@ where
let symbol = self.add_symbol(symbol_name.clone());
self.add_definition(symbol, ImportFromDefinitionNodeRef { node, alias_index });
self.add_definition(
symbol,
ImportFromDefinitionNodeRef {
node,
alias_index,
is_reexported,
},
);
}
}
ast::Stmt::Assign(node) => {

View File

@@ -50,6 +50,10 @@ impl<'db> Definition<'db> {
self.kind(db).category()
}
pub(crate) fn in_stub(self, db: &'db dyn Db) -> bool {
self.file(db).is_stub(db.upcast())
}
pub(crate) fn is_declaration(self, db: &'db dyn Db) -> bool {
self.kind(db).category().is_declaration()
}
@@ -57,11 +61,15 @@ impl<'db> Definition<'db> {
pub(crate) fn is_binding(self, db: &'db dyn Db) -> bool {
self.kind(db).category().is_binding()
}
pub(crate) fn is_reexported(self, db: &'db dyn Db) -> bool {
self.kind(db).is_reexported()
}
}
#[derive(Copy, Clone, Debug)]
pub(crate) enum DefinitionNodeRef<'a> {
Import(&'a ast::Alias),
Import(ImportDefinitionNodeRef<'a>),
ImportFrom(ImportFromDefinitionNodeRef<'a>),
For(ForStmtDefinitionNodeRef<'a>),
Function(&'a ast::StmtFunctionDef),
@@ -119,12 +127,6 @@ impl<'a> From<&'a ast::StmtAugAssign> for DefinitionNodeRef<'a> {
}
}
impl<'a> From<&'a ast::Alias> for DefinitionNodeRef<'a> {
fn from(node_ref: &'a ast::Alias) -> Self {
Self::Import(node_ref)
}
}
impl<'a> From<&'a ast::TypeParamTypeVar> for DefinitionNodeRef<'a> {
fn from(value: &'a ast::TypeParamTypeVar) -> Self {
Self::TypeVar(value)
@@ -143,6 +145,12 @@ impl<'a> From<&'a ast::TypeParamTypeVarTuple> for DefinitionNodeRef<'a> {
}
}
impl<'a> From<ImportDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: ImportDefinitionNodeRef<'a>) -> Self {
Self::Import(node_ref)
}
}
impl<'a> From<ImportFromDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: ImportFromDefinitionNodeRef<'a>) -> Self {
Self::ImportFrom(node_ref)
@@ -185,10 +193,17 @@ impl<'a> From<MatchPatternDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
}
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct ImportDefinitionNodeRef<'a> {
pub(crate) alias: &'a ast::Alias,
pub(crate) is_reexported: bool,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct ImportFromDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::StmtImportFrom,
pub(crate) alias_index: usize,
pub(crate) is_reexported: bool,
}
#[derive(Copy, Clone, Debug)]
@@ -244,15 +259,22 @@ impl<'db> DefinitionNodeRef<'db> {
#[allow(unsafe_code)]
pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind<'db> {
match self {
DefinitionNodeRef::Import(alias) => {
DefinitionKind::Import(AstNodeRef::new(parsed, alias))
}
DefinitionNodeRef::ImportFrom(ImportFromDefinitionNodeRef { node, alias_index }) => {
DefinitionKind::ImportFrom(ImportFromDefinitionKind {
node: AstNodeRef::new(parsed, node),
alias_index,
})
}
DefinitionNodeRef::Import(ImportDefinitionNodeRef {
alias,
is_reexported,
}) => DefinitionKind::Import(ImportDefinitionKind {
alias: AstNodeRef::new(parsed, alias),
is_reexported,
}),
DefinitionNodeRef::ImportFrom(ImportFromDefinitionNodeRef {
node,
alias_index,
is_reexported,
}) => DefinitionKind::ImportFrom(ImportFromDefinitionKind {
node: AstNodeRef::new(parsed, node),
alias_index,
is_reexported,
}),
DefinitionNodeRef::Function(function) => {
DefinitionKind::Function(AstNodeRef::new(parsed, function))
}
@@ -354,10 +376,15 @@ impl<'db> DefinitionNodeRef<'db> {
pub(super) fn key(self) -> DefinitionNodeKey {
match self {
Self::Import(node) => node.into(),
Self::ImportFrom(ImportFromDefinitionNodeRef { node, alias_index }) => {
(&node.names[alias_index]).into()
}
Self::Import(ImportDefinitionNodeRef {
alias,
is_reexported: _,
}) => alias.into(),
Self::ImportFrom(ImportFromDefinitionNodeRef {
node,
alias_index,
is_reexported: _,
}) => (&node.names[alias_index]).into(),
Self::Function(node) => node.into(),
Self::Class(node) => node.into(),
Self::TypeAlias(node) => node.into(),
@@ -441,7 +468,7 @@ impl DefinitionCategory {
/// for an in-depth explanation of why this is necessary.
#[derive(Clone, Debug)]
pub enum DefinitionKind<'db> {
Import(AstNodeRef<ast::Alias>),
Import(ImportDefinitionKind),
ImportFrom(ImportFromDefinitionKind),
Function(AstNodeRef<ast::StmtFunctionDef>),
Class(AstNodeRef<ast::StmtClassDef>),
@@ -464,6 +491,14 @@ pub enum DefinitionKind<'db> {
}
impl DefinitionKind<'_> {
pub(crate) fn is_reexported(&self) -> bool {
match self {
DefinitionKind::Import(import) => import.is_reexported(),
DefinitionKind::ImportFrom(import) => import.is_reexported(),
_ => true,
}
}
/// Returns the [`TextRange`] of the definition target.
///
/// A definition target would mainly be the node representing the symbol being defined i.e.,
@@ -472,7 +507,7 @@ impl DefinitionKind<'_> {
/// This is mainly used for logging and debugging purposes.
pub(crate) fn target_range(&self) -> TextRange {
match self {
DefinitionKind::Import(alias) => alias.range(),
DefinitionKind::Import(import) => import.alias().range(),
DefinitionKind::ImportFrom(import) => import.alias().range(),
DefinitionKind::Function(function) => function.name.range(),
DefinitionKind::Class(class) => class.name.range(),
@@ -603,10 +638,27 @@ impl ComprehensionDefinitionKind {
}
}
#[derive(Clone, Debug)]
pub struct ImportDefinitionKind {
alias: AstNodeRef<ast::Alias>,
is_reexported: bool,
}
impl ImportDefinitionKind {
pub(crate) fn alias(&self) -> &ast::Alias {
self.alias.node()
}
pub(crate) fn is_reexported(&self) -> bool {
self.is_reexported
}
}
#[derive(Clone, Debug)]
pub struct ImportFromDefinitionKind {
node: AstNodeRef<ast::StmtImportFrom>,
alias_index: usize,
is_reexported: bool,
}
impl ImportFromDefinitionKind {
@@ -617,6 +669,10 @@ impl ImportFromDefinitionKind {
pub(crate) fn alias(&self) -> &ast::Alias {
&self.node.node().names[self.alias_index]
}
pub(crate) fn is_reexported(&self) -> bool {
self.is_reexported
}
}
#[derive(Clone, Debug)]

View File

@@ -2,7 +2,7 @@ use crate::module_resolver::{resolve_module, KnownModule};
use crate::semantic_index::global_scope;
use crate::semantic_index::symbol::ScopeId;
use crate::symbol::Symbol;
use crate::types::global_symbol;
use crate::types::{global_symbol, SymbolLookup};
use crate::Db;
/// Lookup the type of `symbol` in a given known module
@@ -14,7 +14,7 @@ pub(crate) fn known_module_symbol<'db>(
symbol: &str,
) -> Symbol<'db> {
resolve_module(db, &known_module.name())
.map(|module| global_symbol(db, module.file(), symbol))
.map(|module| global_symbol(db, SymbolLookup::External, module.file(), symbol))
.unwrap_or(Symbol::Unbound)
}

View File

@@ -106,11 +106,31 @@ fn widen_type_for_undeclared_public_symbol<'db>(
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) enum SymbolLookup {
/// Look up the symbol as seen from within the same module.
Internal,
/// Look up the symbol as seen from outside the module.
External,
}
impl SymbolLookup {
const fn is_external(self) -> bool {
matches!(self, Self::External)
}
}
/// Infer the public type of a symbol (its type as seen from outside its scope).
fn symbol<'db>(db: &'db dyn Db, scope: ScopeId<'db>, name: &str) -> Symbol<'db> {
fn symbol<'db>(
db: &'db dyn Db,
lookup: SymbolLookup,
scope: ScopeId<'db>,
name: &str,
) -> Symbol<'db> {
#[salsa::tracked]
fn symbol_by_id<'db>(
db: &'db dyn Db,
lookup: SymbolLookup,
scope: ScopeId<'db>,
symbol_id: ScopedSymbolId,
) -> Symbol<'db> {
@@ -120,7 +140,7 @@ fn symbol<'db>(db: &'db dyn Db, scope: ScopeId<'db>, name: &str) -> Symbol<'db>
// on inference from bindings.
let declarations = use_def.public_declarations(symbol_id);
let declared = symbol_from_declarations(db, declarations);
let declared = symbol_from_declarations(db, lookup, declarations);
let is_final = declared.as_ref().is_ok_and(SymbolAndQualifiers::is_final);
let declared = declared.map(|SymbolAndQualifiers(symbol, _)| symbol);
@@ -130,7 +150,7 @@ fn symbol<'db>(db: &'db dyn Db, scope: ScopeId<'db>, name: &str) -> Symbol<'db>
// Symbol is possibly declared
Ok(Symbol::Type(declared_ty, Boundness::PossiblyUnbound)) => {
let bindings = use_def.public_bindings(symbol_id);
let inferred = symbol_from_bindings(db, bindings);
let inferred = symbol_from_bindings(db, lookup, bindings);
match inferred {
// Symbol is possibly undeclared and definitely unbound
@@ -150,7 +170,7 @@ fn symbol<'db>(db: &'db dyn Db, scope: ScopeId<'db>, name: &str) -> Symbol<'db>
// Symbol is undeclared, return the union of `Unknown` with the inferred type
Ok(Symbol::Unbound) => {
let bindings = use_def.public_bindings(symbol_id);
let inferred = symbol_from_bindings(db, bindings);
let inferred = symbol_from_bindings(db, lookup, bindings);
// `__slots__` is a symbol with special behavior in Python's runtime. It can be
// modified externally, but those changes do not take effect. We therefore issue
@@ -212,7 +232,7 @@ fn symbol<'db>(db: &'db dyn Db, scope: ScopeId<'db>, name: &str) -> Symbol<'db>
symbol_table(db, scope)
.symbol_id_by_name(name)
.map(|symbol_id| symbol_by_id(db, scope, symbol_id))
.map(|symbol| symbol_by_id(db, lookup, scope, symbol))
.unwrap_or(Symbol::Unbound)
}
@@ -251,12 +271,16 @@ fn module_type_symbols<'db>(db: &'db dyn Db) -> smallvec::SmallVec<[ast::name::N
.collect()
}
/// Looks up a module-global symbol by name in a file.
pub(crate) fn global_symbol<'db>(db: &'db dyn Db, file: File, name: &str) -> Symbol<'db> {
pub(crate) fn global_symbol<'db>(
db: &'db dyn Db,
lookup: SymbolLookup,
file: File,
name: &str,
) -> Symbol<'db> {
// Not defined explicitly in the global scope?
// All modules are instances of `types.ModuleType`;
// look it up there (with a few very special exceptions)
symbol(db, global_scope(db, file), name).or_fall_back_to(db, || {
symbol(db, lookup, global_scope(db, file), name).or_fall_back_to(db, || {
if module_type_symbols(db)
.iter()
.any(|module_type_member| &**module_type_member == name)
@@ -316,20 +340,25 @@ fn definition_expression_type<'db>(
/// The type will be a union if there are multiple bindings with different types.
fn symbol_from_bindings<'db>(
db: &'db dyn Db,
lookup: SymbolLookup,
bindings_with_constraints: BindingWithConstraintsIterator<'_, 'db>,
) -> Symbol<'db> {
let visibility_constraints = bindings_with_constraints.visibility_constraints;
let mut bindings_with_constraints = bindings_with_constraints.peekable();
let unbound_visibility = if let Some(BindingWithConstraints {
binding: None,
constraints: _,
visibility_constraint,
}) = bindings_with_constraints.peek()
{
visibility_constraints.evaluate(db, *visibility_constraint)
} else {
Truthiness::AlwaysFalse
let is_non_exported = |binding: Definition<'db>| {
lookup.is_external() && !binding.is_reexported(db) && binding.in_stub(db)
};
let unbound_visibility = match bindings_with_constraints.peek() {
Some(BindingWithConstraints {
binding,
visibility_constraint,
constraints: _,
}) if binding.map_or(true, is_non_exported) => {
visibility_constraints.evaluate(db, *visibility_constraint)
}
_ => Truthiness::AlwaysFalse,
};
let mut types = bindings_with_constraints.filter_map(
@@ -339,6 +368,11 @@ fn symbol_from_bindings<'db>(
visibility_constraint,
}| {
let binding = binding?;
if is_non_exported(binding) {
return None;
}
let static_visibility = visibility_constraints.evaluate(db, visibility_constraint);
if static_visibility.is_always_false() {
@@ -437,19 +471,24 @@ type SymbolFromDeclarationsResult<'db> =
/// [`TypeQualifiers`] that have been specified on the declaration(s).
fn symbol_from_declarations<'db>(
db: &'db dyn Db,
lookup: SymbolLookup,
declarations: DeclarationsIterator<'_, 'db>,
) -> SymbolFromDeclarationsResult<'db> {
let visibility_constraints = declarations.visibility_constraints;
let mut declarations = declarations.peekable();
let undeclared_visibility = if let Some(DeclarationWithConstraint {
declaration: None,
visibility_constraint,
}) = declarations.peek()
{
visibility_constraints.evaluate(db, *visibility_constraint)
} else {
Truthiness::AlwaysFalse
let is_non_exported = |declaration: Definition<'db>| {
lookup.is_external() && !declaration.is_reexported(db) && declaration.in_stub(db)
};
let undeclared_visibility = match declarations.peek() {
Some(DeclarationWithConstraint {
declaration,
visibility_constraint,
}) if declaration.map_or(true, is_non_exported) => {
visibility_constraints.evaluate(db, *visibility_constraint)
}
_ => Truthiness::AlwaysFalse,
};
let mut types = declarations.filter_map(
@@ -458,6 +497,11 @@ fn symbol_from_declarations<'db>(
visibility_constraint,
}| {
let declaration = declaration?;
if is_non_exported(declaration) {
return None;
}
let static_visibility = visibility_constraints.evaluate(db, visibility_constraint);
if static_visibility.is_always_false() {
@@ -3810,13 +3854,16 @@ impl<'db> ModuleLiteralType<'db> {
// ignore `__getattr__`. Typeshed has a fake `__getattr__` on `types.ModuleType`
// to help out with dynamic imports; we shouldn't use it for `ModuleLiteral` types
// where we know exactly which module we're dealing with.
symbol(db, global_scope(db, self.module(db).file()), name).or_fall_back_to(db, || {
if name == "__getattr__" {
Symbol::Unbound
} else {
KnownClass::ModuleType.to_instance(db).member(db, name)
}
})
global_symbol(db, SymbolLookup::External, self.module(db).file(), name).or_fall_back_to(
db,
|| {
if name == "__getattr__" {
Symbol::Unbound
} else {
KnownClass::ModuleType.to_instance(db).member(db, name)
}
},
)
}
}
@@ -4151,7 +4198,7 @@ impl<'db> Class<'db> {
/// traverse through the MRO until it finds the member.
pub(crate) fn own_class_member(self, db: &'db dyn Db, name: &str) -> Symbol<'db> {
let scope = self.body_scope(db);
symbol(db, scope, name)
symbol(db, SymbolLookup::Internal, scope, name)
}
/// Returns the `name` attribute of an instance of this class.
@@ -4293,7 +4340,7 @@ impl<'db> Class<'db> {
let declarations = use_def.public_declarations(symbol_id);
match symbol_from_declarations(db, declarations) {
match symbol_from_declarations(db, SymbolLookup::Internal, declarations) {
Ok(SymbolAndQualifiers(Symbol::Type(declared_ty, _), qualifiers)) => {
// The attribute is declared in the class body.
@@ -4315,7 +4362,7 @@ impl<'db> Class<'db> {
// in a method, and it could also be *bound* in the class body (and/or in a method).
let bindings = use_def.public_bindings(symbol_id);
let inferred = symbol_from_bindings(db, bindings);
let inferred = symbol_from_bindings(db, SymbolLookup::Internal, bindings);
let inferred_ty = inferred.ignore_possibly_unbound();
Self::implicit_instance_attribute(db, body_scope, name, inferred_ty).into()
@@ -4933,7 +4980,7 @@ pub(crate) mod tests {
)?;
let bar = system_path_to_file(&db, "src/bar.py")?;
let a = global_symbol(&db, bar, "a");
let a = global_symbol(&db, SymbolLookup::Internal, bar, "a");
assert_eq!(
a.expect_type(),
@@ -4952,7 +4999,7 @@ pub(crate) mod tests {
)?;
db.clear_salsa_events();
let a = global_symbol(&db, bar, "a");
let a = global_symbol(&db, SymbolLookup::Internal, bar, "a");
assert_eq!(
a.expect_type(),

View File

@@ -63,11 +63,11 @@ use crate::types::diagnostic::{
use crate::types::mro::MroErrorKind;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{
builtins_symbol, global_symbol, symbol, symbol_from_bindings, symbol_from_declarations,
todo_type, typing_extensions_symbol, Boundness, CallDunderResult, Class, ClassLiteralType,
DynamicType, FunctionType, InstanceType, IntersectionBuilder, IntersectionType,
IterationOutcome, KnownClass, KnownFunction, KnownInstanceType, MetaclassCandidate,
MetaclassErrorKind, SliceLiteralType, SubclassOfType, Symbol, SymbolAndQualifiers, Truthiness,
builtins_symbol, symbol, symbol_from_bindings, symbol_from_declarations, todo_type,
typing_extensions_symbol, Boundness, CallDunderResult, Class, ClassLiteralType, DynamicType,
FunctionType, InstanceType, IntersectionBuilder, IntersectionType, IterationOutcome,
KnownClass, KnownFunction, KnownInstanceType, MetaclassCandidate, MetaclassErrorKind,
SliceLiteralType, SubclassOfType, Symbol, SymbolAndQualifiers, SymbolLookup, Truthiness,
TupleType, Type, TypeAliasType, TypeAndQualifiers, TypeArrayDisplay, TypeQualifiers,
TypeVarBoundOrConstraints, TypeVarInstance, UnionBuilder, UnionType,
};
@@ -86,7 +86,7 @@ use super::slots::check_class_slots;
use super::string_annotation::{
parse_string_annotation, BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION,
};
use super::{ParameterExpectation, ParameterExpectations};
use super::{global_symbol, ParameterExpectation, ParameterExpectations};
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the
@@ -735,7 +735,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_type_alias_definition(type_alias.node(), definition);
}
DefinitionKind::Import(import) => {
self.infer_import_definition(import.node(), definition);
self.infer_import_definition(import.alias(), definition);
}
DefinitionKind::ImportFrom(import_from) => {
self.infer_import_from_definition(
@@ -871,7 +871,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let use_def = self.index.use_def_map(binding.file_scope(self.db()));
let declarations = use_def.declarations_at_binding(binding);
let mut bound_ty = ty;
let declared_ty = symbol_from_declarations(self.db(), declarations)
let declared_ty = symbol_from_declarations(self.db(), SymbolLookup::Internal, declarations)
.map(|SymbolAndQualifiers(s, _)| s.ignore_possibly_unbound().unwrap_or(Type::unknown()))
.unwrap_or_else(|(ty, conflicting)| {
// TODO point out the conflicting declarations in the diagnostic?
@@ -906,7 +906,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let use_def = self.index.use_def_map(declaration.file_scope(self.db()));
let prior_bindings = use_def.bindings_at_declaration(declaration);
// unbound_ty is Never because for this check we don't care about unbound
let inferred_ty = symbol_from_bindings(self.db(), prior_bindings)
let inferred_ty = symbol_from_bindings(self.db(), SymbolLookup::Internal, prior_bindings)
.ignore_possibly_unbound()
.unwrap_or(Type::Never);
let ty = if inferred_ty.is_assignable_to(self.db(), ty.inner_type()) {
@@ -3307,7 +3307,11 @@ impl<'db> TypeInferenceBuilder<'db> {
// If we're inferring types of deferred expressions, always treat them as public symbols
let local_scope_symbol = if self.is_deferred() {
if let Some(symbol_id) = symbol_table.symbol_id_by_name(symbol_name) {
symbol_from_bindings(db, use_def.public_bindings(symbol_id))
symbol_from_bindings(
db,
SymbolLookup::Internal,
use_def.public_bindings(symbol_id),
)
} else {
assert!(
self.deferred_state.in_string_annotation(),
@@ -3317,7 +3321,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
} else {
let use_id = name_node.scoped_use_id(db, scope);
symbol_from_bindings(db, use_def.bindings_at_use(use_id))
symbol_from_bindings(db, SymbolLookup::Internal, use_def.bindings_at_use(use_id))
};
let symbol = local_scope_symbol.or_fall_back_to(db, || {
@@ -3368,7 +3372,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// runtime, it is the scope that creates the cell for our closure.) If the name
// isn't bound in that scope, we should get an unbound name, not continue
// falling back to other scopes / globals / builtins.
return symbol(db, enclosing_scope_id, symbol_name);
return symbol(db, SymbolLookup::Internal, enclosing_scope_id, symbol_name);
}
}
@@ -3379,7 +3383,7 @@ impl<'db> TypeInferenceBuilder<'db> {
if file_scope_id.is_global() {
Symbol::Unbound
} else {
global_symbol(db, self.file(), symbol_name)
global_symbol(db, SymbolLookup::Internal, self.file(), symbol_name)
}
})
// Not found in globals? Fallback to builtins
@@ -6051,7 +6055,7 @@ mod tests {
assert_eq!(scope.name(db), *expected_scope_name);
}
symbol(db, scope, symbol_name)
symbol(db, SymbolLookup::Internal, scope, symbol_name)
}
#[track_caller]
@@ -6076,7 +6080,7 @@ mod tests {
let mut db = setup_db();
let content = format!(
r#"
from typing_extensions import assert_type
from typing_extensions import Literal, assert_type
assert_type(not "{y}", bool)
assert_type(not 10*"{y}", bool)
@@ -6098,7 +6102,7 @@ mod tests {
let mut db = setup_db();
let content = format!(
r#"
from typing_extensions import assert_type
from typing_extensions import Literal, LiteralString, assert_type
assert_type(2 * "hello", Literal["hellohello"])
assert_type("goodbye" * 3, Literal["goodbyegoodbyegoodbye"])
@@ -6123,7 +6127,7 @@ mod tests {
let mut db = setup_db();
let content = format!(
r#"
from typing_extensions import assert_type
from typing_extensions import Literal, LiteralString, assert_type
assert_type("{y}", LiteralString)
assert_type(10*"{y}", LiteralString)
@@ -6145,7 +6149,7 @@ mod tests {
let mut db = setup_db();
let content = format!(
r#"
from typing_extensions import assert_type
from typing_extensions import LiteralString, assert_type
assert_type("{y}", LiteralString)
assert_type("a" + "{z}", LiteralString)
@@ -6165,7 +6169,7 @@ mod tests {
let mut db = setup_db();
let content = format!(
r#"
from typing_extensions import assert_type
from typing_extensions import LiteralString, assert_type
assert_type("{y}", LiteralString)
assert_type("{y}" + "a", LiteralString)
@@ -6267,7 +6271,7 @@ mod tests {
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = global_symbol(&db, a, "x").expect_type();
let x_ty = global_symbol(&db, SymbolLookup::Internal, a, "x").expect_type();
assert_eq!(x_ty.display(&db).to_string(), "int");
@@ -6276,7 +6280,7 @@ mod tests {
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty_2 = global_symbol(&db, a, "x").expect_type();
let x_ty_2 = global_symbol(&db, SymbolLookup::Internal, a, "x").expect_type();
assert_eq!(x_ty_2.display(&db).to_string(), "bool");
@@ -6293,7 +6297,7 @@ mod tests {
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = global_symbol(&db, a, "x").expect_type();
let x_ty = global_symbol(&db, SymbolLookup::Internal, a, "x").expect_type();
assert_eq!(x_ty.display(&db).to_string(), "int");
@@ -6303,7 +6307,7 @@ mod tests {
db.clear_salsa_events();
let x_ty_2 = global_symbol(&db, a, "x").expect_type();
let x_ty_2 = global_symbol(&db, SymbolLookup::Internal, a, "x").expect_type();
assert_eq!(x_ty_2.display(&db).to_string(), "int");
@@ -6329,7 +6333,7 @@ mod tests {
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = global_symbol(&db, a, "x").expect_type();
let x_ty = global_symbol(&db, SymbolLookup::Internal, a, "x").expect_type();
assert_eq!(x_ty.display(&db).to_string(), "int");
@@ -6339,7 +6343,7 @@ mod tests {
db.clear_salsa_events();
let x_ty_2 = global_symbol(&db, a, "x").expect_type();
let x_ty_2 = global_symbol(&db, SymbolLookup::Internal, a, "x").expect_type();
assert_eq!(x_ty_2.display(&db).to_string(), "int");
@@ -6386,7 +6390,7 @@ mod tests {
)?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
let attr_ty = global_symbol(&db, file_main, "x").expect_type();
let attr_ty = global_symbol(&db, SymbolLookup::Internal, file_main, "x").expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | int | None");
// Change the type of `attr` to `str | None`; this should trigger the type of `x` to be re-inferred
@@ -6401,7 +6405,7 @@ mod tests {
let events = {
db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").expect_type();
let attr_ty = global_symbol(&db, SymbolLookup::Internal, file_main, "x").expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events()
};
@@ -6420,7 +6424,7 @@ mod tests {
let events = {
db.clear_salsa_events();
let attr_ty = global_symbol(&db, file_main, "x").expect_type();
let attr_ty = global_symbol(&db, SymbolLookup::Internal, file_main, "x").expect_type();
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events()
};

View File

@@ -322,13 +322,13 @@ pub(crate) enum ParameterKind<'db> {
mod tests {
use super::*;
use crate::db::tests::{setup_db, TestDb};
use crate::types::{global_symbol, FunctionType, KnownClass};
use crate::types::{global_symbol, FunctionType, KnownClass, SymbolLookup};
use ruff_db::system::DbWithTestSystem;
#[track_caller]
fn get_function_f<'db>(db: &'db TestDb, file: &'static str) -> FunctionType<'db> {
let module = ruff_db::files::system_path_to_file(db, file).unwrap();
global_symbol(db, module, "f")
global_symbol(db, SymbolLookup::Internal, module, "f")
.expect_type()
.expect_function_literal()
}
@@ -357,6 +357,8 @@ mod tests {
db.write_dedented(
"/src/a.py",
"
from typing import Literal
def f(a, b: int, c = 1, d: int = 2, /,
e = 3, f: Literal[4] = 4, *args: object,
g = 5, h: Literal[6] = 6, **kwargs: str) -> bytes: ...