From 8531f4b3ca23b1bc877ddd3a4204832673851468 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Thu, 5 Jun 2025 11:43:18 -0400 Subject: [PATCH] [ty] Add infrastructure for AST garbage collection (#18445) ## Summary https://github.com/astral-sh/ty/issues/214 will require a couple invasive changes that I would like to get merged even before garbage collection is fully implemented (to avoid rebasing): - `ParsedModule` can no longer be dereferenced directly. Instead you need to load a `ParsedModuleRef` to access the AST, which requires a reference to the salsa database (as it may require re-parsing the AST if it was collected). - `AstNodeRef` can only be dereferenced with the `node` method, which takes a reference to the `ParsedModuleRef`. This allows us to encode the fact that ASTs do not live as long as the database and may be collected as soon a given instance of a `ParsedModuleRef` is dropped. There are a number of places where we currently merge the `'db` and `'ast` lifetimes, so this requires giving some types/functions two separate lifetime parameters. --- crates/ruff_db/src/parsed.rs | 58 ++- crates/ruff_python_formatter/src/lib.rs | 4 +- crates/ty/src/lib.rs | 3 +- crates/ty_ide/src/completion.rs | 10 +- crates/ty_ide/src/goto.rs | 15 +- crates/ty_ide/src/hover.rs | 6 +- crates/ty_ide/src/inlay_hints.rs | 2 +- crates/ty_project/src/lib.rs | 6 +- crates/ty_project/tests/check.rs | 2 +- crates/ty_python_semantic/src/ast_node_ref.rs | 143 +------ crates/ty_python_semantic/src/dunder_all.rs | 2 +- .../ty_python_semantic/src/semantic_index.rs | 100 +++-- .../src/semantic_index/builder.rs | 125 +++--- .../src/semantic_index/definition.rs | 393 ++++++++++-------- .../src/semantic_index/expression.rs | 13 +- .../src/semantic_index/place.rs | 45 +- .../src/semantic_index/re_exports.rs | 2 +- .../ty_python_semantic/src/semantic_model.rs | 6 +- crates/ty_python_semantic/src/suppression.rs | 2 +- crates/ty_python_semantic/src/types.rs | 22 +- .../ty_python_semantic/src/types/call/bind.rs | 18 +- crates/ty_python_semantic/src/types/class.rs | 98 +++-- .../ty_python_semantic/src/types/context.rs | 30 +- .../src/types/definition.rs | 11 +- .../src/types/diagnostic.rs | 14 +- .../ty_python_semantic/src/types/function.rs | 42 +- crates/ty_python_semantic/src/types/infer.rs | 283 ++++++++----- crates/ty_python_semantic/src/types/narrow.rs | 43 +- .../ty_python_semantic/src/types/unpacker.rs | 41 +- crates/ty_python_semantic/src/unpack.rs | 28 +- crates/ty_test/src/assertion.rs | 2 +- crates/ty_test/src/lib.rs | 2 +- crates/ty_wasm/src/lib.rs | 4 +- 33 files changed, 886 insertions(+), 689 deletions(-) diff --git a/crates/ruff_db/src/parsed.rs b/crates/ruff_db/src/parsed.rs index d1f6f6a125..a479713924 100644 --- a/crates/ruff_db/src/parsed.rs +++ b/crates/ruff_db/src/parsed.rs @@ -1,5 +1,4 @@ use std::fmt::Formatter; -use std::ops::Deref; use std::sync::Arc; use ruff_python_ast::ModModule; @@ -18,7 +17,7 @@ use crate::source::source_text; /// The query is only cached when the [`source_text()`] hasn't changed. This is because /// comparing two ASTs is a non-trivial operation and every offset change is directly /// reflected in the changed AST offsets. -/// The other reason is that Ruff's AST doesn't implement `Eq` which Sala requires +/// The other reason is that Ruff's AST doesn't implement `Eq` which Salsa requires /// for determining if a query result is unchanged. #[salsa::tracked(returns(ref), no_eq)] pub fn parsed_module(db: &dyn Db, file: File) -> ParsedModule { @@ -36,7 +35,10 @@ pub fn parsed_module(db: &dyn Db, file: File) -> ParsedModule { ParsedModule::new(parsed) } -/// Cheap cloneable wrapper around the parsed module. +/// A wrapper around a parsed module. +/// +/// This type manages instances of the module AST. A particular instance of the AST +/// is represented with the [`ParsedModuleRef`] type. #[derive(Clone)] pub struct ParsedModule { inner: Arc>, @@ -49,17 +51,11 @@ impl ParsedModule { } } - /// Consumes `self` and returns the Arc storing the parsed module. - pub fn into_arc(self) -> Arc> { - self.inner - } -} - -impl Deref for ParsedModule { - type Target = Parsed; - - fn deref(&self) -> &Self::Target { - &self.inner + /// Loads a reference to the parsed module. + pub fn load(&self, _db: &dyn Db) -> ParsedModuleRef { + ParsedModuleRef { + module_ref: self.inner.clone(), + } } } @@ -77,6 +73,30 @@ impl PartialEq for ParsedModule { impl Eq for ParsedModule {} +/// Cheap cloneable wrapper around an instance of a module AST. +#[derive(Clone)] +pub struct ParsedModuleRef { + module_ref: Arc>, +} + +impl ParsedModuleRef { + pub fn as_arc(&self) -> &Arc> { + &self.module_ref + } + + pub fn into_arc(self) -> Arc> { + self.module_ref + } +} + +impl std::ops::Deref for ParsedModuleRef { + type Target = Parsed; + + fn deref(&self) -> &Self::Target { + &self.module_ref + } +} + #[cfg(test)] mod tests { use crate::Db; @@ -98,7 +118,7 @@ mod tests { let file = system_path_to_file(&db, path).unwrap(); - let parsed = parsed_module(&db, file); + let parsed = parsed_module(&db, file).load(&db); assert!(parsed.has_valid_syntax()); @@ -114,7 +134,7 @@ mod tests { let file = system_path_to_file(&db, path).unwrap(); - let parsed = parsed_module(&db, file); + let parsed = parsed_module(&db, file).load(&db); assert!(parsed.has_valid_syntax()); @@ -130,7 +150,7 @@ mod tests { let virtual_file = db.files().virtual_file(&db, path); - let parsed = parsed_module(&db, virtual_file.file()); + let parsed = parsed_module(&db, virtual_file.file()).load(&db); assert!(parsed.has_valid_syntax()); @@ -146,7 +166,7 @@ mod tests { let virtual_file = db.files().virtual_file(&db, path); - let parsed = parsed_module(&db, virtual_file.file()); + let parsed = parsed_module(&db, virtual_file.file()).load(&db); assert!(parsed.has_valid_syntax()); @@ -177,7 +197,7 @@ else: let file = vendored_path_to_file(&db, VendoredPath::new("path.pyi")).unwrap(); - let parsed = parsed_module(&db, file); + let parsed = parsed_module(&db, file).load(&db); assert!(parsed.has_valid_syntax()); } diff --git a/crates/ruff_python_formatter/src/lib.rs b/crates/ruff_python_formatter/src/lib.rs index c9c1c35bf8..18e1ef4c62 100644 --- a/crates/ruff_python_formatter/src/lib.rs +++ b/crates/ruff_python_formatter/src/lib.rs @@ -165,7 +165,7 @@ where pub fn formatted_file(db: &dyn Db, file: File) -> Result, FormatModuleError> { let options = db.format_options(file); - let parsed = parsed_module(db.upcast(), file); + let parsed = parsed_module(db.upcast(), file).load(db.upcast()); if let Some(first) = parsed.errors().first() { return Err(FormatModuleError::ParseError(first.clone())); @@ -174,7 +174,7 @@ pub fn formatted_file(db: &dyn Db, file: File) -> Result, FormatM let comment_ranges = CommentRanges::from(parsed.tokens()); let source = source_text(db.upcast(), file); - let formatted = format_node(parsed, &comment_ranges, &source, options)?; + let formatted = format_node(&parsed, &comment_ranges, &source, options)?; let printed = formatted.print()?; if printed.as_code() == &*source { diff --git a/crates/ty/src/lib.rs b/crates/ty/src/lib.rs index 91efddde60..aa00d2e374 100644 --- a/crates/ty/src/lib.rs +++ b/crates/ty/src/lib.rs @@ -18,10 +18,9 @@ use clap::{CommandFactory, Parser}; use colored::Colorize; use crossbeam::channel as crossbeam_channel; use rayon::ThreadPoolBuilder; -use ruff_db::Upcast; use ruff_db::diagnostic::{Diagnostic, DisplayDiagnosticConfig, Severity}; -use ruff_db::max_parallelism; use ruff_db::system::{OsSystem, SystemPath, SystemPathBuf}; +use ruff_db::{Upcast, max_parallelism}; use salsa::plumbing::ZalsaDatabase; use ty_project::metadata::options::ProjectOptionsOverrides; use ty_project::watch::ProjectWatcher; diff --git a/crates/ty_ide/src/completion.rs b/crates/ty_ide/src/completion.rs index 6584b7339b..10bc2971d1 100644 --- a/crates/ty_ide/src/completion.rs +++ b/crates/ty_ide/src/completion.rs @@ -1,7 +1,7 @@ use std::cmp::Ordering; use ruff_db::files::File; -use ruff_db::parsed::{ParsedModule, parsed_module}; +use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_ast as ast; use ruff_python_parser::{Token, TokenAt, TokenKind}; use ruff_text_size::{Ranged, TextRange, TextSize}; @@ -15,9 +15,9 @@ pub struct Completion { } pub fn completion(db: &dyn Db, file: File, offset: TextSize) -> Vec { - let parsed = parsed_module(db.upcast(), file); + let parsed = parsed_module(db.upcast(), file).load(db.upcast()); - let Some(target) = CompletionTargetTokens::find(parsed, offset).ast(parsed) else { + let Some(target) = CompletionTargetTokens::find(&parsed, offset).ast(&parsed) else { return vec![]; }; @@ -63,7 +63,7 @@ enum CompletionTargetTokens<'t> { impl<'t> CompletionTargetTokens<'t> { /// Look for the best matching token pattern at the given offset. - fn find(parsed: &ParsedModule, offset: TextSize) -> CompletionTargetTokens<'_> { + fn find(parsed: &ParsedModuleRef, offset: TextSize) -> CompletionTargetTokens<'_> { static OBJECT_DOT_EMPTY: [TokenKind; 2] = [TokenKind::Name, TokenKind::Dot]; static OBJECT_DOT_NON_EMPTY: [TokenKind; 3] = [TokenKind::Name, TokenKind::Dot, TokenKind::Name]; @@ -97,7 +97,7 @@ impl<'t> CompletionTargetTokens<'t> { /// Returns a corresponding AST node for these tokens. /// /// If no plausible AST node could be found, then `None` is returned. - fn ast(&self, parsed: &'t ParsedModule) -> Option> { + fn ast(&self, parsed: &'t ParsedModuleRef) -> Option> { match *self { CompletionTargetTokens::ObjectDot { object, .. } => { let covering_node = covering_node(parsed.syntax().into(), object.range()) diff --git a/crates/ty_ide/src/goto.rs b/crates/ty_ide/src/goto.rs index b7fb38c314..89861b73ff 100644 --- a/crates/ty_ide/src/goto.rs +++ b/crates/ty_ide/src/goto.rs @@ -1,7 +1,7 @@ use crate::find_node::covering_node; use crate::{Db, HasNavigationTargets, NavigationTargets, RangedValue}; use ruff_db::files::{File, FileRange}; -use ruff_db::parsed::{ParsedModule, parsed_module}; +use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_ast::{self as ast, AnyNodeRef}; use ruff_python_parser::TokenKind; use ruff_text_size::{Ranged, TextRange, TextSize}; @@ -13,8 +13,8 @@ pub fn goto_type_definition( file: File, offset: TextSize, ) -> Option> { - let parsed = parsed_module(db.upcast(), file); - let goto_target = find_goto_target(parsed, offset)?; + let module = parsed_module(db.upcast(), file).load(db.upcast()); + let goto_target = find_goto_target(&module, offset)?; let model = SemanticModel::new(db.upcast(), file); let ty = goto_target.inferred_type(&model)?; @@ -128,8 +128,8 @@ pub(crate) enum GotoTarget<'a> { }, } -impl<'db> GotoTarget<'db> { - pub(crate) fn inferred_type(self, model: &SemanticModel<'db>) -> Option> { +impl GotoTarget<'_> { + pub(crate) fn inferred_type<'db>(self, model: &SemanticModel<'db>) -> Option> { let ty = match self { GotoTarget::Expression(expression) => expression.inferred_type(model), GotoTarget::FunctionDef(function) => function.inferred_type(model), @@ -183,7 +183,10 @@ impl Ranged for GotoTarget<'_> { } } -pub(crate) fn find_goto_target(parsed: &ParsedModule, offset: TextSize) -> Option { +pub(crate) fn find_goto_target( + parsed: &ParsedModuleRef, + offset: TextSize, +) -> Option> { let token = parsed .tokens() .at_offset(offset) diff --git a/crates/ty_ide/src/hover.rs b/crates/ty_ide/src/hover.rs index 6509b4bf76..34ce394b9f 100644 --- a/crates/ty_ide/src/hover.rs +++ b/crates/ty_ide/src/hover.rs @@ -8,9 +8,9 @@ use std::fmt::Formatter; use ty_python_semantic::SemanticModel; use ty_python_semantic::types::Type; -pub fn hover(db: &dyn Db, file: File, offset: TextSize) -> Option> { - let parsed = parsed_module(db.upcast(), file); - let goto_target = find_goto_target(parsed, offset)?; +pub fn hover(db: &dyn Db, file: File, offset: TextSize) -> Option>> { + let parsed = parsed_module(db.upcast(), file).load(db.upcast()); + let goto_target = find_goto_target(&parsed, offset)?; if let GotoTarget::Expression(expr) = goto_target { if expr.is_literal_expr() { diff --git a/crates/ty_ide/src/inlay_hints.rs b/crates/ty_ide/src/inlay_hints.rs index 0ef71ddc05..9c1c835024 100644 --- a/crates/ty_ide/src/inlay_hints.rs +++ b/crates/ty_ide/src/inlay_hints.rs @@ -54,7 +54,7 @@ impl fmt::Display for DisplayInlayHint<'_, '_> { pub fn inlay_hints(db: &dyn Db, file: File, range: TextRange) -> Vec> { let mut visitor = InlayHintVisitor::new(db, file, range); - let ast = parsed_module(db.upcast(), file); + let ast = parsed_module(db.upcast(), file).load(db.upcast()); visitor.visit_body(ast.suite()); diff --git a/crates/ty_project/src/lib.rs b/crates/ty_project/src/lib.rs index ecd9501a47..b57608d5f5 100644 --- a/crates/ty_project/src/lib.rs +++ b/crates/ty_project/src/lib.rs @@ -453,14 +453,16 @@ fn check_file_impl(db: &dyn Db, file: File) -> Vec { } let parsed = parsed_module(db.upcast(), file); + + let parsed_ref = parsed.load(db.upcast()); diagnostics.extend( - parsed + parsed_ref .errors() .iter() .map(|error| create_parse_diagnostic(file, error)), ); - diagnostics.extend(parsed.unsupported_syntax_errors().iter().map(|error| { + diagnostics.extend(parsed_ref.unsupported_syntax_errors().iter().map(|error| { let mut error = create_unsupported_syntax_diagnostic(file, error); add_inferred_python_version_hint_to_diagnostic(db.upcast(), &mut error, "parsing syntax"); error diff --git a/crates/ty_project/tests/check.rs b/crates/ty_project/tests/check.rs index 36f8b16f1b..9d4d461ff5 100644 --- a/crates/ty_project/tests/check.rs +++ b/crates/ty_project/tests/check.rs @@ -172,7 +172,7 @@ fn run_corpus_tests(pattern: &str) -> anyhow::Result<()> { fn pull_types(db: &ProjectDatabase, file: File) { let mut visitor = PullTypesVisitor::new(db, file); - let ast = parsed_module(db, file); + let ast = parsed_module(db, file).load(db); visitor.visit_body(ast.suite()); } diff --git a/crates/ty_python_semantic/src/ast_node_ref.rs b/crates/ty_python_semantic/src/ast_node_ref.rs index 7d460fed19..fd3e1726dc 100644 --- a/crates/ty_python_semantic/src/ast_node_ref.rs +++ b/crates/ty_python_semantic/src/ast_node_ref.rs @@ -1,15 +1,14 @@ -use std::hash::Hash; -use std::ops::Deref; +use std::sync::Arc; -use ruff_db::parsed::ParsedModule; +use ruff_db::parsed::ParsedModuleRef; /// Ref-counted owned reference to an AST node. /// -/// The type holds an owned reference to the node's ref-counted [`ParsedModule`]. -/// Holding on to the node's [`ParsedModule`] guarantees that the reference to the +/// The type holds an owned reference to the node's ref-counted [`ParsedModuleRef`]. +/// Holding on to the node's [`ParsedModuleRef`] guarantees that the reference to the /// node must still be valid. /// -/// Holding on to any [`AstNodeRef`] prevents the [`ParsedModule`] from being released. +/// Holding on to any [`AstNodeRef`] prevents the [`ParsedModuleRef`] from being released. /// /// ## Equality /// Two `AstNodeRef` are considered equal if their pointer addresses are equal. @@ -33,11 +32,11 @@ use ruff_db::parsed::ParsedModule; /// run on every AST change. All other queries only run when the expression's identity changes. #[derive(Clone)] pub struct AstNodeRef { - /// Owned reference to the node's [`ParsedModule`]. + /// Owned reference to the node's [`ParsedModuleRef`]. /// /// The node's reference is guaranteed to remain valid as long as it's enclosing - /// [`ParsedModule`] is alive. - parsed: ParsedModule, + /// [`ParsedModuleRef`] is alive. + parsed: ParsedModuleRef, /// Pointer to the referenced node. node: std::ptr::NonNull, @@ -45,15 +44,15 @@ pub struct AstNodeRef { #[expect(unsafe_code)] impl AstNodeRef { - /// Creates a new `AstNodeRef` that references `node`. The `parsed` is the [`ParsedModule`] to + /// Creates a new `AstNodeRef` that references `node`. The `parsed` is the [`ParsedModuleRef`] to /// which the `AstNodeRef` belongs. /// /// ## Safety /// /// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the - /// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that + /// [`ParsedModuleRef`] to which `node` belongs. It's the caller's responsibility to ensure that /// the invariant `node belongs to parsed` is upheld. - pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self { + pub(super) unsafe fn new(parsed: ParsedModuleRef, node: &T) -> Self { Self { parsed, node: std::ptr::NonNull::from(node), @@ -61,54 +60,26 @@ impl AstNodeRef { } /// Returns a reference to the wrapped node. - pub const fn node(&self) -> &T { + /// + /// Note that this method will panic if the provided module is from a different file or Salsa revision + /// than the module this node was created with. + pub fn node<'ast>(&self, parsed: &'ast ParsedModuleRef) -> &'ast T { + debug_assert!(Arc::ptr_eq(self.parsed.as_arc(), parsed.as_arc())); + // SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still // alive and not moved. unsafe { self.node.as_ref() } } } -impl Deref for AstNodeRef { - type Target = T; - - fn deref(&self) -> &Self::Target { - self.node() - } -} - impl std::fmt::Debug for AstNodeRef where T: std::fmt::Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_tuple("AstNodeRef").field(&self.node()).finish() - } -} - -impl PartialEq for AstNodeRef -where - T: PartialEq, -{ - fn eq(&self, other: &Self) -> bool { - if self.parsed == other.parsed { - // Comparing the pointer addresses is sufficient to determine equality - // if the parsed are the same. - self.node.eq(&other.node) - } else { - // Otherwise perform a deep comparison. - self.node().eq(other.node()) - } - } -} - -impl Eq for AstNodeRef where T: Eq {} - -impl Hash for AstNodeRef -where - T: Hash, -{ - fn hash(&self, state: &mut H) { - self.node().hash(state); + f.debug_tuple("AstNodeRef") + .field(self.node(&self.parsed)) + .finish() } } @@ -117,7 +88,9 @@ unsafe impl salsa::Update for AstNodeRef { unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { let old_ref = unsafe { &mut (*old_pointer) }; - if old_ref.parsed == new_value.parsed && old_ref.node.eq(&new_value.node) { + if Arc::ptr_eq(old_ref.parsed.as_arc(), new_value.parsed.as_arc()) + && old_ref.node.eq(&new_value.node) + { false } else { *old_ref = new_value; @@ -130,73 +103,3 @@ unsafe impl salsa::Update for AstNodeRef { unsafe impl Send for AstNodeRef where T: Send {} #[expect(unsafe_code)] unsafe impl Sync for AstNodeRef where T: Sync {} - -#[cfg(test)] -mod tests { - use crate::ast_node_ref::AstNodeRef; - use ruff_db::parsed::ParsedModule; - use ruff_python_ast::PySourceType; - use ruff_python_parser::parse_unchecked_source; - - #[test] - #[expect(unsafe_code)] - fn equality() { - let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python); - let parsed = ParsedModule::new(parsed_raw.clone()); - - let stmt = &parsed.syntax().body[0]; - - let node1 = unsafe { AstNodeRef::new(parsed.clone(), stmt) }; - let node2 = unsafe { AstNodeRef::new(parsed.clone(), stmt) }; - - assert_eq!(node1, node2); - - // Compare from different trees - let cloned = ParsedModule::new(parsed_raw); - let stmt_cloned = &cloned.syntax().body[0]; - let cloned_node = unsafe { AstNodeRef::new(cloned.clone(), stmt_cloned) }; - - assert_eq!(node1, cloned_node); - - let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python); - let other = ParsedModule::new(other_raw); - - let other_stmt = &other.syntax().body[0]; - let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) }; - - assert_ne!(node1, other_node); - } - - #[expect(unsafe_code)] - #[test] - fn inequality() { - let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python); - let parsed = ParsedModule::new(parsed_raw); - - let stmt = &parsed.syntax().body[0]; - let node = unsafe { AstNodeRef::new(parsed.clone(), stmt) }; - - let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python); - let other = ParsedModule::new(other_raw); - - let other_stmt = &other.syntax().body[0]; - let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) }; - - assert_ne!(node, other_node); - } - - #[test] - #[expect(unsafe_code)] - fn debug() { - let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python); - let parsed = ParsedModule::new(parsed_raw); - - let stmt = &parsed.syntax().body[0]; - - let stmt_node = unsafe { AstNodeRef::new(parsed.clone(), stmt) }; - - let debug = format!("{stmt_node:?}"); - - assert_eq!(debug, format!("AstNodeRef({stmt:?})")); - } -} diff --git a/crates/ty_python_semantic/src/dunder_all.rs b/crates/ty_python_semantic/src/dunder_all.rs index 6976e35e22..78ceb1250d 100644 --- a/crates/ty_python_semantic/src/dunder_all.rs +++ b/crates/ty_python_semantic/src/dunder_all.rs @@ -32,7 +32,7 @@ fn dunder_all_names_cycle_initial(_db: &dyn Db, _file: File) -> Option Option> { let _span = tracing::trace_span!("dunder_all_names", file=?file.path(db)).entered(); - let module = parsed_module(db.upcast(), file); + let module = parsed_module(db.upcast(), file).load(db.upcast()); let index = semantic_index(db, file); let mut collector = DunderAllNamesCollector::new(db, file, index); collector.visit_body(module.suite()); diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index c117b7f737..0b71130d1e 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -50,9 +50,9 @@ type PlaceSet = hashbrown::HashMap; pub(crate) fn semantic_index(db: &dyn Db, file: File) -> SemanticIndex<'_> { let _span = tracing::trace_span!("semantic_index", ?file).entered(); - let parsed = parsed_module(db.upcast(), file); + let module = parsed_module(db.upcast(), file).load(db.upcast()); - SemanticIndexBuilder::new(db, file, parsed).build() + SemanticIndexBuilder::new(db, file, &module).build() } /// Returns the place table for a specific `scope`. @@ -129,10 +129,11 @@ pub(crate) fn attribute_scopes<'db, 's>( class_body_scope: ScopeId<'db>, ) -> impl Iterator + use<'s, 'db> { let file = class_body_scope.file(db); + let module = parsed_module(db.upcast(), file).load(db.upcast()); let index = semantic_index(db, file); let class_scope_id = class_body_scope.file_scope_id(db); - ChildrenIter::new(index, class_scope_id).filter_map(|(child_scope_id, scope)| { + ChildrenIter::new(index, class_scope_id).filter_map(move |(child_scope_id, scope)| { let (function_scope_id, function_scope) = if scope.node().scope_kind() == ScopeKind::Annotation { // This could be a generic method with a type-params scope. @@ -144,7 +145,7 @@ pub(crate) fn attribute_scopes<'db, 's>( (child_scope_id, scope) }; - function_scope.node().as_function()?; + function_scope.node().as_function(&module)?; Some(function_scope_id) }) } @@ -559,7 +560,7 @@ impl FusedIterator for ChildrenIter<'_> {} #[cfg(test)] mod tests { use ruff_db::files::{File, system_path_to_file}; - use ruff_db::parsed::parsed_module; + use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_ast::{self as ast}; use ruff_text_size::{Ranged, TextRange}; @@ -742,6 +743,7 @@ y = 2 assert_eq!(names(global_table), vec!["C", "y"]); + let module = parsed_module(&db, file).load(&db); let index = semantic_index(&db, file); let [(class_scope_id, class_scope)] = index @@ -751,7 +753,10 @@ y = 2 panic!("expected one child scope") }; assert_eq!(class_scope.kind(), ScopeKind::Class); - assert_eq!(class_scope_id.to_scope_id(&db, file).name(&db), "C"); + assert_eq!( + class_scope_id.to_scope_id(&db, file).name(&db, &module), + "C" + ); let class_table = index.place_table(class_scope_id); assert_eq!(names(&class_table), vec!["x"]); @@ -772,6 +777,7 @@ def func(): y = 2 ", ); + let module = parsed_module(&db, file).load(&db); let index = semantic_index(&db, file); let global_table = index.place_table(FileScopeId::global()); @@ -784,7 +790,10 @@ y = 2 panic!("expected one child scope") }; assert_eq!(function_scope.kind(), ScopeKind::Function); - assert_eq!(function_scope_id.to_scope_id(&db, file).name(&db), "func"); + assert_eq!( + function_scope_id.to_scope_id(&db, file).name(&db, &module), + "func" + ); let function_table = index.place_table(function_scope_id); assert_eq!(names(&function_table), vec!["x"]); @@ -921,6 +930,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): ", ); + let module = parsed_module(&db, file).load(&db); let index = semantic_index(&db, file); let global_table = index.place_table(FileScopeId::global()); @@ -935,7 +945,9 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): assert_eq!(comprehension_scope.kind(), ScopeKind::Comprehension); assert_eq!( - comprehension_scope_id.to_scope_id(&db, file).name(&db), + comprehension_scope_id + .to_scope_id(&db, file) + .name(&db, &module), "" ); @@ -979,8 +991,9 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): let use_def = index.use_def_map(comprehension_scope_id); - let module = parsed_module(&db, file).syntax(); - let element = module.body[0] + let module = parsed_module(&db, file).load(&db); + let syntax = module.syntax(); + let element = syntax.body[0] .as_expr_stmt() .unwrap() .value @@ -996,7 +1009,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): let DefinitionKind::Comprehension(comprehension) = binding.kind(&db) else { panic!("expected generator definition") }; - let target = comprehension.target(); + let target = comprehension.target(&module); let name = target.as_name_expr().unwrap().id().as_str(); assert_eq!(name, "x"); @@ -1014,6 +1027,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): ", ); + let module = parsed_module(&db, file).load(&db); let index = semantic_index(&db, file); let global_table = index.place_table(FileScopeId::global()); @@ -1028,7 +1042,9 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): assert_eq!(comprehension_scope.kind(), ScopeKind::Comprehension); assert_eq!( - comprehension_scope_id.to_scope_id(&db, file).name(&db), + comprehension_scope_id + .to_scope_id(&db, file) + .name(&db, &module), "" ); @@ -1047,7 +1063,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): assert_eq!( inner_comprehension_scope_id .to_scope_id(&db, file) - .name(&db), + .name(&db, &module), "" ); @@ -1112,6 +1128,7 @@ def func(): y = 2 ", ); + let module = parsed_module(&db, file).load(&db); let index = semantic_index(&db, file); let global_table = index.place_table(FileScopeId::global()); @@ -1128,9 +1145,15 @@ def func(): assert_eq!(func_scope_1.kind(), ScopeKind::Function); - assert_eq!(func_scope1_id.to_scope_id(&db, file).name(&db), "func"); + assert_eq!( + func_scope1_id.to_scope_id(&db, file).name(&db, &module), + "func" + ); assert_eq!(func_scope_2.kind(), ScopeKind::Function); - assert_eq!(func_scope2_id.to_scope_id(&db, file).name(&db), "func"); + assert_eq!( + func_scope2_id.to_scope_id(&db, file).name(&db, &module), + "func" + ); let func1_table = index.place_table(func_scope1_id); let func2_table = index.place_table(func_scope2_id); @@ -1157,6 +1180,7 @@ def func[T](): ", ); + let module = parsed_module(&db, file).load(&db); let index = semantic_index(&db, file); let global_table = index.place_table(FileScopeId::global()); @@ -1170,7 +1194,10 @@ def func[T](): }; assert_eq!(ann_scope.kind(), ScopeKind::Annotation); - assert_eq!(ann_scope_id.to_scope_id(&db, file).name(&db), "func"); + assert_eq!( + ann_scope_id.to_scope_id(&db, file).name(&db, &module), + "func" + ); let ann_table = index.place_table(ann_scope_id); assert_eq!(names(&ann_table), vec!["T"]); @@ -1180,7 +1207,10 @@ def func[T](): panic!("expected one child scope"); }; assert_eq!(func_scope.kind(), ScopeKind::Function); - assert_eq!(func_scope_id.to_scope_id(&db, file).name(&db), "func"); + assert_eq!( + func_scope_id.to_scope_id(&db, file).name(&db, &module), + "func" + ); let func_table = index.place_table(func_scope_id); assert_eq!(names(&func_table), vec!["x"]); } @@ -1194,6 +1224,7 @@ class C[T]: ", ); + let module = parsed_module(&db, file).load(&db); let index = semantic_index(&db, file); let global_table = index.place_table(FileScopeId::global()); @@ -1207,7 +1238,7 @@ class C[T]: }; assert_eq!(ann_scope.kind(), ScopeKind::Annotation); - assert_eq!(ann_scope_id.to_scope_id(&db, file).name(&db), "C"); + assert_eq!(ann_scope_id.to_scope_id(&db, file).name(&db, &module), "C"); let ann_table = index.place_table(ann_scope_id); assert_eq!(names(&ann_table), vec!["T"]); assert!( @@ -1224,16 +1255,19 @@ class C[T]: }; assert_eq!(class_scope.kind(), ScopeKind::Class); - assert_eq!(class_scope_id.to_scope_id(&db, file).name(&db), "C"); + assert_eq!( + class_scope_id.to_scope_id(&db, file).name(&db, &module), + "C" + ); assert_eq!(names(&index.place_table(class_scope_id)), vec!["x"]); } #[test] fn reachability_trivial() { let TestCase { db, file } = test_case("x = 1; x"); - let parsed = parsed_module(&db, file); + let module = parsed_module(&db, file).load(&db); let scope = global_scope(&db, file); - let ast = parsed.syntax(); + let ast = module.syntax(); let ast::Stmt::Expr(ast::StmtExpr { value: x_use_expr, .. }) = &ast.body[1] @@ -1252,7 +1286,7 @@ class C[T]: let ast::Expr::NumberLiteral(ast::ExprNumberLiteral { value: ast::Number::Int(num), .. - }) = assignment.value() + }) = assignment.value(&module) else { panic!("should be a number literal") }; @@ -1264,8 +1298,8 @@ class C[T]: let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4"); let index = semantic_index(&db, file); - let parsed = parsed_module(&db, file); - let ast = parsed.syntax(); + let module = parsed_module(&db, file).load(&db); + let ast = module.syntax(); let x_stmt = ast.body[0].as_assign_stmt().unwrap(); let x = &x_stmt.targets[0]; @@ -1282,14 +1316,15 @@ class C[T]: #[test] fn scope_iterators() { - fn scope_names<'a>( - scopes: impl Iterator, - db: &'a dyn Db, + fn scope_names<'a, 'db>( + scopes: impl Iterator, + db: &'db dyn Db, file: File, + module: &'a ParsedModuleRef, ) -> Vec<&'a str> { scopes .into_iter() - .map(|(scope_id, _)| scope_id.to_scope_id(db, file).name(db)) + .map(|(scope_id, _)| scope_id.to_scope_id(db, file).name(db, module)) .collect() } @@ -1306,21 +1341,22 @@ def x(): pass", ); + let module = parsed_module(&db, file).load(&db); let index = semantic_index(&db, file); let descendants = index.descendent_scopes(FileScopeId::global()); assert_eq!( - scope_names(descendants, &db, file), + scope_names(descendants, &db, file, &module), vec!["Test", "foo", "bar", "baz", "x"] ); let children = index.child_scopes(FileScopeId::global()); - assert_eq!(scope_names(children, &db, file), vec!["Test", "x"]); + assert_eq!(scope_names(children, &db, file, &module), vec!["Test", "x"]); let test_class = index.child_scopes(FileScopeId::global()).next().unwrap().0; let test_child_scopes = index.child_scopes(test_class); assert_eq!( - scope_names(test_child_scopes, &db, file), + scope_names(test_child_scopes, &db, file, &module), vec!["foo", "baz"] ); @@ -1332,7 +1368,7 @@ def x(): let ancestors = index.ancestor_scopes(bar_scope); assert_eq!( - scope_names(ancestors, &db, file), + scope_names(ancestors, &db, file, &module), vec!["bar", "foo", "Test", ""] ); } diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index db4cf3da50..bb5f40d3be 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -5,7 +5,7 @@ use except_handlers::TryNodeContextStackManager; use rustc_hash::{FxHashMap, FxHashSet}; use ruff_db::files::File; -use ruff_db::parsed::ParsedModule; +use ruff_db::parsed::ParsedModuleRef; use ruff_db::source::{SourceText, source_text}; use ruff_index::IndexVec; use ruff_python_ast::name::Name; @@ -69,20 +69,20 @@ struct ScopeInfo { current_loop: Option, } -pub(super) struct SemanticIndexBuilder<'db> { +pub(super) struct SemanticIndexBuilder<'db, 'ast> { // Builder state db: &'db dyn Db, file: File, source_type: PySourceType, - module: &'db ParsedModule, + module: &'ast ParsedModuleRef, scope_stack: Vec, /// The assignments we're currently visiting, with /// the most recent visit at the end of the Vec - current_assignments: Vec>, + current_assignments: Vec>, /// The match case we're currently visiting. - current_match_case: Option>, + current_match_case: Option>, /// The name of the first function parameter of the innermost function that we're currently visiting. - current_first_parameter_name: Option<&'db str>, + current_first_parameter_name: Option<&'ast str>, /// Per-scope contexts regarding nested `try`/`except` statements try_node_context_stack_manager: TryNodeContextStackManager, @@ -116,13 +116,13 @@ pub(super) struct SemanticIndexBuilder<'db> { semantic_syntax_errors: RefCell>, } -impl<'db> SemanticIndexBuilder<'db> { - pub(super) fn new(db: &'db dyn Db, file: File, parsed: &'db ParsedModule) -> Self { +impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { + pub(super) fn new(db: &'db dyn Db, file: File, module_ref: &'ast ParsedModuleRef) -> Self { let mut builder = Self { db, file, source_type: file.source_type(db.upcast()), - module: parsed, + module: module_ref, scope_stack: Vec::new(), current_assignments: vec![], current_match_case: None, @@ -423,7 +423,7 @@ impl<'db> SemanticIndexBuilder<'db> { fn add_definition( &mut self, place: ScopedPlaceId, - definition_node: impl Into> + std::fmt::Debug + Copy, + definition_node: impl Into> + std::fmt::Debug + Copy, ) -> Definition<'db> { let (definition, num_definitions) = self.push_additional_definition(place, definition_node); debug_assert_eq!( @@ -463,16 +463,18 @@ impl<'db> SemanticIndexBuilder<'db> { fn push_additional_definition( &mut self, place: ScopedPlaceId, - definition_node: impl Into>, + definition_node: impl Into>, ) -> (Definition<'db>, usize) { - let definition_node: DefinitionNodeRef<'_> = definition_node.into(); + let definition_node: DefinitionNodeRef<'ast, 'db> = definition_node.into(); + #[expect(unsafe_code)] // SAFETY: `definition_node` is guaranteed to be a child of `self.module` let kind = unsafe { definition_node.into_owned(self.module.clone()) }; - let category = kind.category(self.source_type.is_stub()); + + let category = kind.category(self.source_type.is_stub(), self.module); let is_reexported = kind.is_reexported(); - let definition = Definition::new( + let definition: Definition<'db> = Definition::new( self.db, self.file, self.current_scope(), @@ -658,7 +660,7 @@ impl<'db> SemanticIndexBuilder<'db> { .record_reachability_constraint(negated_constraint); } - fn push_assignment(&mut self, assignment: CurrentAssignment<'db>) { + fn push_assignment(&mut self, assignment: CurrentAssignment<'ast, 'db>) { self.current_assignments.push(assignment); } @@ -667,11 +669,11 @@ impl<'db> SemanticIndexBuilder<'db> { debug_assert!(popped_assignment.is_some()); } - fn current_assignment(&self) -> Option> { + fn current_assignment(&self) -> Option> { self.current_assignments.last().copied() } - fn current_assignment_mut(&mut self) -> Option<&mut CurrentAssignment<'db>> { + fn current_assignment_mut(&mut self) -> Option<&mut CurrentAssignment<'ast, 'db>> { self.current_assignments.last_mut() } @@ -792,7 +794,7 @@ impl<'db> SemanticIndexBuilder<'db> { fn with_type_params( &mut self, with_scope: NodeWithScopeRef, - type_params: Option<&'db ast::TypeParams>, + type_params: Option<&'ast ast::TypeParams>, nested: impl FnOnce(&mut Self) -> FileScopeId, ) -> FileScopeId { if let Some(type_params) = type_params { @@ -858,7 +860,7 @@ impl<'db> SemanticIndexBuilder<'db> { fn with_generators_scope( &mut self, scope: NodeWithScopeRef, - generators: &'db [ast::Comprehension], + generators: &'ast [ast::Comprehension], visit_outer_elt: impl FnOnce(&mut Self), ) { let mut generators_iter = generators.iter(); @@ -908,7 +910,7 @@ impl<'db> SemanticIndexBuilder<'db> { self.pop_scope(); } - fn declare_parameters(&mut self, parameters: &'db ast::Parameters) { + fn declare_parameters(&mut self, parameters: &'ast ast::Parameters) { for parameter in parameters.iter_non_variadic_params() { self.declare_parameter(parameter); } @@ -925,7 +927,7 @@ impl<'db> SemanticIndexBuilder<'db> { } } - fn declare_parameter(&mut self, parameter: &'db ast::ParameterWithDefault) { + fn declare_parameter(&mut self, parameter: &'ast ast::ParameterWithDefault) { let symbol = self.add_symbol(parameter.name().id().clone()); let definition = self.add_definition(symbol, parameter); @@ -946,8 +948,8 @@ impl<'db> SemanticIndexBuilder<'db> { /// for statements, etc. fn add_unpackable_assignment( &mut self, - unpackable: &Unpackable<'db>, - target: &'db ast::Expr, + unpackable: &Unpackable<'ast>, + target: &'ast ast::Expr, value: Expression<'db>, ) { // We only handle assignments to names and unpackings here, other targets like @@ -1010,8 +1012,7 @@ impl<'db> SemanticIndexBuilder<'db> { } pub(super) fn build(mut self) -> SemanticIndex<'db> { - let module = self.module; - self.visit_body(module.suite()); + self.visit_body(self.module.suite()); // Pop the root scope self.pop_scope(); @@ -1081,10 +1082,7 @@ impl<'db> SemanticIndexBuilder<'db> { } } -impl<'db, 'ast> Visitor<'ast> for SemanticIndexBuilder<'db> -where - 'ast: 'db, -{ +impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) { self.with_semantic_checker(|semantic, context| semantic.visit_stmt(stmt, context)); @@ -2299,7 +2297,7 @@ where } } -impl SemanticSyntaxContext for SemanticIndexBuilder<'_> { +impl SemanticSyntaxContext for SemanticIndexBuilder<'_, '_> { fn future_annotations_or_stub(&self) -> bool { self.has_future_annotations } @@ -2324,7 +2322,7 @@ impl SemanticSyntaxContext for SemanticIndexBuilder<'_> { match scope.kind() { ScopeKind::Class | ScopeKind::Lambda => return false, ScopeKind::Function => { - return scope.node().expect_function().is_async; + return scope.node().expect_function(self.module).is_async; } ScopeKind::Comprehension | ScopeKind::Module @@ -2366,9 +2364,9 @@ impl SemanticSyntaxContext for SemanticIndexBuilder<'_> { for scope_info in self.scope_stack.iter().rev() { let scope = &self.scopes[scope_info.file_scope_id]; let generators = match scope.node() { - NodeWithScopeKind::ListComprehension(node) => &node.generators, - NodeWithScopeKind::SetComprehension(node) => &node.generators, - NodeWithScopeKind::DictComprehension(node) => &node.generators, + NodeWithScopeKind::ListComprehension(node) => &node.node(self.module).generators, + NodeWithScopeKind::SetComprehension(node) => &node.node(self.module).generators, + NodeWithScopeKind::DictComprehension(node) => &node.node(self.module).generators, _ => continue, }; if generators @@ -2409,31 +2407,31 @@ impl SemanticSyntaxContext for SemanticIndexBuilder<'_> { } #[derive(Copy, Clone, Debug, PartialEq)] -enum CurrentAssignment<'a> { +enum CurrentAssignment<'ast, 'db> { Assign { - node: &'a ast::StmtAssign, - unpack: Option<(UnpackPosition, Unpack<'a>)>, + node: &'ast ast::StmtAssign, + unpack: Option<(UnpackPosition, Unpack<'db>)>, }, - AnnAssign(&'a ast::StmtAnnAssign), - AugAssign(&'a ast::StmtAugAssign), + AnnAssign(&'ast ast::StmtAnnAssign), + AugAssign(&'ast ast::StmtAugAssign), For { - node: &'a ast::StmtFor, - unpack: Option<(UnpackPosition, Unpack<'a>)>, + node: &'ast ast::StmtFor, + unpack: Option<(UnpackPosition, Unpack<'db>)>, }, - Named(&'a ast::ExprNamed), + Named(&'ast ast::ExprNamed), Comprehension { - node: &'a ast::Comprehension, + node: &'ast ast::Comprehension, first: bool, - unpack: Option<(UnpackPosition, Unpack<'a>)>, + unpack: Option<(UnpackPosition, Unpack<'db>)>, }, WithItem { - item: &'a ast::WithItem, + item: &'ast ast::WithItem, is_async: bool, - unpack: Option<(UnpackPosition, Unpack<'a>)>, + unpack: Option<(UnpackPosition, Unpack<'db>)>, }, } -impl CurrentAssignment<'_> { +impl CurrentAssignment<'_, '_> { fn unpack_position_mut(&mut self) -> Option<&mut UnpackPosition> { match self { Self::Assign { unpack, .. } @@ -2445,28 +2443,28 @@ impl CurrentAssignment<'_> { } } -impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> { - fn from(value: &'a ast::StmtAnnAssign) -> Self { +impl<'ast> From<&'ast ast::StmtAnnAssign> for CurrentAssignment<'ast, '_> { + fn from(value: &'ast ast::StmtAnnAssign) -> Self { Self::AnnAssign(value) } } -impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> { - fn from(value: &'a ast::StmtAugAssign) -> Self { +impl<'ast> From<&'ast ast::StmtAugAssign> for CurrentAssignment<'ast, '_> { + fn from(value: &'ast ast::StmtAugAssign) -> Self { Self::AugAssign(value) } } -impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> { - fn from(value: &'a ast::ExprNamed) -> Self { +impl<'ast> From<&'ast ast::ExprNamed> for CurrentAssignment<'ast, '_> { + fn from(value: &'ast ast::ExprNamed) -> Self { Self::Named(value) } } #[derive(Debug, PartialEq)] -struct CurrentMatchCase<'a> { +struct CurrentMatchCase<'ast> { /// The pattern that's part of the current match case. - pattern: &'a ast::Pattern, + pattern: &'ast ast::Pattern, /// The index of the sub-pattern that's being currently visited within the pattern. /// @@ -2488,20 +2486,20 @@ impl<'a> CurrentMatchCase<'a> { } } -enum Unpackable<'a> { - Assign(&'a ast::StmtAssign), - For(&'a ast::StmtFor), +enum Unpackable<'ast> { + Assign(&'ast ast::StmtAssign), + For(&'ast ast::StmtFor), WithItem { - item: &'a ast::WithItem, + item: &'ast ast::WithItem, is_async: bool, }, Comprehension { first: bool, - node: &'a ast::Comprehension, + node: &'ast ast::Comprehension, }, } -impl<'a> Unpackable<'a> { +impl<'ast> Unpackable<'ast> { const fn kind(&self) -> UnpackKind { match self { Unpackable::Assign(_) => UnpackKind::Assign, @@ -2510,7 +2508,10 @@ impl<'a> Unpackable<'a> { } } - fn as_current_assignment(&self, unpack: Option>) -> CurrentAssignment<'a> { + fn as_current_assignment<'db>( + &self, + unpack: Option>, + ) -> CurrentAssignment<'ast, 'db> { let unpack = unpack.map(|unpack| (UnpackPosition::First, unpack)); match self { Unpackable::Assign(stmt) => CurrentAssignment::Assign { node: stmt, unpack }, diff --git a/crates/ty_python_semantic/src/semantic_index/definition.rs b/crates/ty_python_semantic/src/semantic_index/definition.rs index 3adbeb68c6..391bbf87eb 100644 --- a/crates/ty_python_semantic/src/semantic_index/definition.rs +++ b/crates/ty_python_semantic/src/semantic_index/definition.rs @@ -1,7 +1,7 @@ use std::ops::Deref; use ruff_db::files::{File, FileRange}; -use ruff_db::parsed::ParsedModule; +use ruff_db::parsed::ParsedModuleRef; use ruff_python_ast as ast; use ruff_text_size::{Ranged, TextRange}; @@ -49,12 +49,12 @@ impl<'db> Definition<'db> { self.file_scope(db).to_scope_id(db, self.file(db)) } - pub fn full_range(self, db: &'db dyn Db) -> FileRange { - FileRange::new(self.file(db), self.kind(db).full_range()) + pub fn full_range(self, db: &'db dyn Db, module: &ParsedModuleRef) -> FileRange { + FileRange::new(self.file(db), self.kind(db).full_range(module)) } - pub fn focus_range(self, db: &'db dyn Db) -> FileRange { - FileRange::new(self.file(db), self.kind(db).target_range()) + pub fn focus_range(self, db: &'db dyn Db, module: &ParsedModuleRef) -> FileRange { + FileRange::new(self.file(db), self.kind(db).target_range(module)) } } @@ -123,218 +123,218 @@ impl<'db> DefinitionState<'db> { } #[derive(Copy, Clone, Debug)] -pub(crate) enum DefinitionNodeRef<'a> { - Import(ImportDefinitionNodeRef<'a>), - ImportFrom(ImportFromDefinitionNodeRef<'a>), - ImportStar(StarImportDefinitionNodeRef<'a>), - For(ForStmtDefinitionNodeRef<'a>), - Function(&'a ast::StmtFunctionDef), - Class(&'a ast::StmtClassDef), - TypeAlias(&'a ast::StmtTypeAlias), - NamedExpression(&'a ast::ExprNamed), - Assignment(AssignmentDefinitionNodeRef<'a>), - AnnotatedAssignment(AnnotatedAssignmentDefinitionNodeRef<'a>), - AugmentedAssignment(&'a ast::StmtAugAssign), - Comprehension(ComprehensionDefinitionNodeRef<'a>), - VariadicPositionalParameter(&'a ast::Parameter), - VariadicKeywordParameter(&'a ast::Parameter), - Parameter(&'a ast::ParameterWithDefault), - WithItem(WithItemDefinitionNodeRef<'a>), - MatchPattern(MatchPatternDefinitionNodeRef<'a>), - ExceptHandler(ExceptHandlerDefinitionNodeRef<'a>), - TypeVar(&'a ast::TypeParamTypeVar), - ParamSpec(&'a ast::TypeParamParamSpec), - TypeVarTuple(&'a ast::TypeParamTypeVarTuple), +pub(crate) enum DefinitionNodeRef<'ast, 'db> { + Import(ImportDefinitionNodeRef<'ast>), + ImportFrom(ImportFromDefinitionNodeRef<'ast>), + ImportStar(StarImportDefinitionNodeRef<'ast>), + For(ForStmtDefinitionNodeRef<'ast, 'db>), + Function(&'ast ast::StmtFunctionDef), + Class(&'ast ast::StmtClassDef), + TypeAlias(&'ast ast::StmtTypeAlias), + NamedExpression(&'ast ast::ExprNamed), + Assignment(AssignmentDefinitionNodeRef<'ast, 'db>), + AnnotatedAssignment(AnnotatedAssignmentDefinitionNodeRef<'ast>), + AugmentedAssignment(&'ast ast::StmtAugAssign), + Comprehension(ComprehensionDefinitionNodeRef<'ast, 'db>), + VariadicPositionalParameter(&'ast ast::Parameter), + VariadicKeywordParameter(&'ast ast::Parameter), + Parameter(&'ast ast::ParameterWithDefault), + WithItem(WithItemDefinitionNodeRef<'ast, 'db>), + MatchPattern(MatchPatternDefinitionNodeRef<'ast>), + ExceptHandler(ExceptHandlerDefinitionNodeRef<'ast>), + TypeVar(&'ast ast::TypeParamTypeVar), + ParamSpec(&'ast ast::TypeParamParamSpec), + TypeVarTuple(&'ast ast::TypeParamTypeVarTuple), } -impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> { - fn from(node: &'a ast::StmtFunctionDef) -> Self { +impl<'ast> From<&'ast ast::StmtFunctionDef> for DefinitionNodeRef<'ast, '_> { + fn from(node: &'ast ast::StmtFunctionDef) -> Self { Self::Function(node) } } -impl<'a> From<&'a ast::StmtClassDef> for DefinitionNodeRef<'a> { - fn from(node: &'a ast::StmtClassDef) -> Self { +impl<'ast> From<&'ast ast::StmtClassDef> for DefinitionNodeRef<'ast, '_> { + fn from(node: &'ast ast::StmtClassDef) -> Self { Self::Class(node) } } -impl<'a> From<&'a ast::StmtTypeAlias> for DefinitionNodeRef<'a> { - fn from(node: &'a ast::StmtTypeAlias) -> Self { +impl<'ast> From<&'ast ast::StmtTypeAlias> for DefinitionNodeRef<'ast, '_> { + fn from(node: &'ast ast::StmtTypeAlias) -> Self { Self::TypeAlias(node) } } -impl<'a> From<&'a ast::ExprNamed> for DefinitionNodeRef<'a> { - fn from(node: &'a ast::ExprNamed) -> Self { +impl<'ast> From<&'ast ast::ExprNamed> for DefinitionNodeRef<'ast, '_> { + fn from(node: &'ast ast::ExprNamed) -> Self { Self::NamedExpression(node) } } -impl<'a> From<&'a ast::StmtAugAssign> for DefinitionNodeRef<'a> { - fn from(node: &'a ast::StmtAugAssign) -> Self { +impl<'ast> From<&'ast ast::StmtAugAssign> for DefinitionNodeRef<'ast, '_> { + fn from(node: &'ast ast::StmtAugAssign) -> Self { Self::AugmentedAssignment(node) } } -impl<'a> From<&'a ast::TypeParamTypeVar> for DefinitionNodeRef<'a> { - fn from(value: &'a ast::TypeParamTypeVar) -> Self { +impl<'ast> From<&'ast ast::TypeParamTypeVar> for DefinitionNodeRef<'ast, '_> { + fn from(value: &'ast ast::TypeParamTypeVar) -> Self { Self::TypeVar(value) } } -impl<'a> From<&'a ast::TypeParamParamSpec> for DefinitionNodeRef<'a> { - fn from(value: &'a ast::TypeParamParamSpec) -> Self { +impl<'ast> From<&'ast ast::TypeParamParamSpec> for DefinitionNodeRef<'ast, '_> { + fn from(value: &'ast ast::TypeParamParamSpec) -> Self { Self::ParamSpec(value) } } -impl<'a> From<&'a ast::TypeParamTypeVarTuple> for DefinitionNodeRef<'a> { - fn from(value: &'a ast::TypeParamTypeVarTuple) -> Self { +impl<'ast> From<&'ast ast::TypeParamTypeVarTuple> for DefinitionNodeRef<'ast, '_> { + fn from(value: &'ast ast::TypeParamTypeVarTuple) -> Self { Self::TypeVarTuple(value) } } -impl<'a> From> for DefinitionNodeRef<'a> { - fn from(node_ref: ImportDefinitionNodeRef<'a>) -> Self { +impl<'ast> From> for DefinitionNodeRef<'ast, '_> { + fn from(node_ref: ImportDefinitionNodeRef<'ast>) -> Self { Self::Import(node_ref) } } -impl<'a> From> for DefinitionNodeRef<'a> { - fn from(node_ref: ImportFromDefinitionNodeRef<'a>) -> Self { +impl<'ast> From> for DefinitionNodeRef<'ast, '_> { + fn from(node_ref: ImportFromDefinitionNodeRef<'ast>) -> Self { Self::ImportFrom(node_ref) } } -impl<'a> From> for DefinitionNodeRef<'a> { - fn from(value: ForStmtDefinitionNodeRef<'a>) -> Self { +impl<'ast, 'db> From> for DefinitionNodeRef<'ast, 'db> { + fn from(value: ForStmtDefinitionNodeRef<'ast, 'db>) -> Self { Self::For(value) } } -impl<'a> From> for DefinitionNodeRef<'a> { - fn from(node_ref: AssignmentDefinitionNodeRef<'a>) -> Self { +impl<'ast, 'db> From> for DefinitionNodeRef<'ast, 'db> { + fn from(node_ref: AssignmentDefinitionNodeRef<'ast, 'db>) -> Self { Self::Assignment(node_ref) } } -impl<'a> From> for DefinitionNodeRef<'a> { - fn from(node_ref: AnnotatedAssignmentDefinitionNodeRef<'a>) -> Self { +impl<'ast> From> for DefinitionNodeRef<'ast, '_> { + fn from(node_ref: AnnotatedAssignmentDefinitionNodeRef<'ast>) -> Self { Self::AnnotatedAssignment(node_ref) } } -impl<'a> From> for DefinitionNodeRef<'a> { - fn from(node_ref: WithItemDefinitionNodeRef<'a>) -> Self { +impl<'ast, 'db> From> for DefinitionNodeRef<'ast, 'db> { + fn from(node_ref: WithItemDefinitionNodeRef<'ast, 'db>) -> Self { Self::WithItem(node_ref) } } -impl<'a> From> for DefinitionNodeRef<'a> { - fn from(node: ComprehensionDefinitionNodeRef<'a>) -> Self { +impl<'ast, 'db> From> for DefinitionNodeRef<'ast, 'db> { + fn from(node: ComprehensionDefinitionNodeRef<'ast, 'db>) -> Self { Self::Comprehension(node) } } -impl<'a> From<&'a ast::ParameterWithDefault> for DefinitionNodeRef<'a> { - fn from(node: &'a ast::ParameterWithDefault) -> Self { +impl<'ast> From<&'ast ast::ParameterWithDefault> for DefinitionNodeRef<'ast, '_> { + fn from(node: &'ast ast::ParameterWithDefault) -> Self { Self::Parameter(node) } } -impl<'a> From> for DefinitionNodeRef<'a> { - fn from(node: MatchPatternDefinitionNodeRef<'a>) -> Self { +impl<'ast> From> for DefinitionNodeRef<'ast, '_> { + fn from(node: MatchPatternDefinitionNodeRef<'ast>) -> Self { Self::MatchPattern(node) } } -impl<'a> From> for DefinitionNodeRef<'a> { - fn from(node: StarImportDefinitionNodeRef<'a>) -> Self { +impl<'ast> From> for DefinitionNodeRef<'ast, '_> { + fn from(node: StarImportDefinitionNodeRef<'ast>) -> Self { Self::ImportStar(node) } } #[derive(Copy, Clone, Debug)] -pub(crate) struct ImportDefinitionNodeRef<'a> { - pub(crate) node: &'a ast::StmtImport, +pub(crate) struct ImportDefinitionNodeRef<'ast> { + pub(crate) node: &'ast ast::StmtImport, pub(crate) alias_index: usize, pub(crate) is_reexported: bool, } #[derive(Copy, Clone, Debug)] -pub(crate) struct StarImportDefinitionNodeRef<'a> { - pub(crate) node: &'a ast::StmtImportFrom, +pub(crate) struct StarImportDefinitionNodeRef<'ast> { + pub(crate) node: &'ast ast::StmtImportFrom, pub(crate) place_id: ScopedPlaceId, } #[derive(Copy, Clone, Debug)] -pub(crate) struct ImportFromDefinitionNodeRef<'a> { - pub(crate) node: &'a ast::StmtImportFrom, +pub(crate) struct ImportFromDefinitionNodeRef<'ast> { + pub(crate) node: &'ast ast::StmtImportFrom, pub(crate) alias_index: usize, pub(crate) is_reexported: bool, } #[derive(Copy, Clone, Debug)] -pub(crate) struct AssignmentDefinitionNodeRef<'a> { - pub(crate) unpack: Option<(UnpackPosition, Unpack<'a>)>, - pub(crate) value: &'a ast::Expr, - pub(crate) target: &'a ast::Expr, +pub(crate) struct AssignmentDefinitionNodeRef<'ast, 'db> { + pub(crate) unpack: Option<(UnpackPosition, Unpack<'db>)>, + pub(crate) value: &'ast ast::Expr, + pub(crate) target: &'ast ast::Expr, } #[derive(Copy, Clone, Debug)] -pub(crate) struct AnnotatedAssignmentDefinitionNodeRef<'a> { - pub(crate) node: &'a ast::StmtAnnAssign, - pub(crate) annotation: &'a ast::Expr, - pub(crate) value: Option<&'a ast::Expr>, - pub(crate) target: &'a ast::Expr, +pub(crate) struct AnnotatedAssignmentDefinitionNodeRef<'ast> { + pub(crate) node: &'ast ast::StmtAnnAssign, + pub(crate) annotation: &'ast ast::Expr, + pub(crate) value: Option<&'ast ast::Expr>, + pub(crate) target: &'ast ast::Expr, } #[derive(Copy, Clone, Debug)] -pub(crate) struct WithItemDefinitionNodeRef<'a> { - pub(crate) unpack: Option<(UnpackPosition, Unpack<'a>)>, - pub(crate) context_expr: &'a ast::Expr, - pub(crate) target: &'a ast::Expr, +pub(crate) struct WithItemDefinitionNodeRef<'ast, 'db> { + pub(crate) unpack: Option<(UnpackPosition, Unpack<'db>)>, + pub(crate) context_expr: &'ast ast::Expr, + pub(crate) target: &'ast ast::Expr, pub(crate) is_async: bool, } #[derive(Copy, Clone, Debug)] -pub(crate) struct ForStmtDefinitionNodeRef<'a> { - pub(crate) unpack: Option<(UnpackPosition, Unpack<'a>)>, - pub(crate) iterable: &'a ast::Expr, - pub(crate) target: &'a ast::Expr, +pub(crate) struct ForStmtDefinitionNodeRef<'ast, 'db> { + pub(crate) unpack: Option<(UnpackPosition, Unpack<'db>)>, + pub(crate) iterable: &'ast ast::Expr, + pub(crate) target: &'ast ast::Expr, pub(crate) is_async: bool, } #[derive(Copy, Clone, Debug)] -pub(crate) struct ExceptHandlerDefinitionNodeRef<'a> { - pub(crate) handler: &'a ast::ExceptHandlerExceptHandler, +pub(crate) struct ExceptHandlerDefinitionNodeRef<'ast> { + pub(crate) handler: &'ast ast::ExceptHandlerExceptHandler, pub(crate) is_star: bool, } #[derive(Copy, Clone, Debug)] -pub(crate) struct ComprehensionDefinitionNodeRef<'a> { - pub(crate) unpack: Option<(UnpackPosition, Unpack<'a>)>, - pub(crate) iterable: &'a ast::Expr, - pub(crate) target: &'a ast::Expr, +pub(crate) struct ComprehensionDefinitionNodeRef<'ast, 'db> { + pub(crate) unpack: Option<(UnpackPosition, Unpack<'db>)>, + pub(crate) iterable: &'ast ast::Expr, + pub(crate) target: &'ast ast::Expr, pub(crate) first: bool, pub(crate) is_async: bool, } #[derive(Copy, Clone, Debug)] -pub(crate) struct MatchPatternDefinitionNodeRef<'a> { +pub(crate) struct MatchPatternDefinitionNodeRef<'ast> { /// The outermost pattern node in which the identifier being defined occurs. - pub(crate) pattern: &'a ast::Pattern, + pub(crate) pattern: &'ast ast::Pattern, /// The identifier being defined. - pub(crate) identifier: &'a ast::Identifier, + pub(crate) identifier: &'ast ast::Identifier, /// The index of the identifier in the pattern when visiting the `pattern` node in evaluation /// order. pub(crate) index: u32, } -impl<'db> DefinitionNodeRef<'db> { +impl<'db> DefinitionNodeRef<'_, 'db> { #[expect(unsafe_code)] - pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind<'db> { + pub(super) unsafe fn into_owned(self, parsed: ParsedModuleRef) -> DefinitionKind<'db> { match self { DefinitionNodeRef::Import(ImportDefinitionNodeRef { node, @@ -626,60 +626,74 @@ impl DefinitionKind<'_> { /// /// A definition target would mainly be the node representing the place being defined i.e., /// [`ast::ExprName`], [`ast::Identifier`], [`ast::ExprAttribute`] or [`ast::ExprSubscript`] but could also be other nodes. - pub(crate) fn target_range(&self) -> TextRange { + pub(crate) fn target_range(&self, module: &ParsedModuleRef) -> TextRange { match self { - DefinitionKind::Import(import) => import.alias().range(), - DefinitionKind::ImportFrom(import) => import.alias().range(), - DefinitionKind::StarImport(import) => import.alias().range(), - DefinitionKind::Function(function) => function.name.range(), - DefinitionKind::Class(class) => class.name.range(), - DefinitionKind::TypeAlias(type_alias) => type_alias.name.range(), - DefinitionKind::NamedExpression(named) => named.target.range(), - DefinitionKind::Assignment(assignment) => assignment.target.range(), - DefinitionKind::AnnotatedAssignment(assign) => assign.target.range(), - DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.target.range(), - DefinitionKind::For(for_stmt) => for_stmt.target.range(), - DefinitionKind::Comprehension(comp) => comp.target().range(), - DefinitionKind::VariadicPositionalParameter(parameter) => parameter.name.range(), - DefinitionKind::VariadicKeywordParameter(parameter) => parameter.name.range(), - DefinitionKind::Parameter(parameter) => parameter.parameter.name.range(), - DefinitionKind::WithItem(with_item) => with_item.target.range(), - DefinitionKind::MatchPattern(match_pattern) => match_pattern.identifier.range(), - DefinitionKind::ExceptHandler(handler) => handler.node().range(), - DefinitionKind::TypeVar(type_var) => type_var.name.range(), - DefinitionKind::ParamSpec(param_spec) => param_spec.name.range(), - DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.name.range(), + DefinitionKind::Import(import) => import.alias(module).range(), + DefinitionKind::ImportFrom(import) => import.alias(module).range(), + DefinitionKind::StarImport(import) => import.alias(module).range(), + DefinitionKind::Function(function) => function.node(module).name.range(), + DefinitionKind::Class(class) => class.node(module).name.range(), + DefinitionKind::TypeAlias(type_alias) => type_alias.node(module).name.range(), + DefinitionKind::NamedExpression(named) => named.node(module).target.range(), + DefinitionKind::Assignment(assignment) => assignment.target.node(module).range(), + DefinitionKind::AnnotatedAssignment(assign) => assign.target.node(module).range(), + DefinitionKind::AugmentedAssignment(aug_assign) => { + aug_assign.node(module).target.range() + } + DefinitionKind::For(for_stmt) => for_stmt.target.node(module).range(), + DefinitionKind::Comprehension(comp) => comp.target(module).range(), + DefinitionKind::VariadicPositionalParameter(parameter) => { + parameter.node(module).name.range() + } + DefinitionKind::VariadicKeywordParameter(parameter) => { + parameter.node(module).name.range() + } + DefinitionKind::Parameter(parameter) => parameter.node(module).parameter.name.range(), + DefinitionKind::WithItem(with_item) => with_item.target.node(module).range(), + DefinitionKind::MatchPattern(match_pattern) => { + match_pattern.identifier.node(module).range() + } + DefinitionKind::ExceptHandler(handler) => handler.node(module).range(), + DefinitionKind::TypeVar(type_var) => type_var.node(module).name.range(), + DefinitionKind::ParamSpec(param_spec) => param_spec.node(module).name.range(), + DefinitionKind::TypeVarTuple(type_var_tuple) => { + type_var_tuple.node(module).name.range() + } } } /// Returns the [`TextRange`] of the entire definition. - pub(crate) fn full_range(&self) -> TextRange { + pub(crate) fn full_range(&self, module: &ParsedModuleRef) -> TextRange { match self { - DefinitionKind::Import(import) => import.alias().range(), - DefinitionKind::ImportFrom(import) => import.alias().range(), - DefinitionKind::StarImport(import) => import.import().range(), - DefinitionKind::Function(function) => function.range(), - DefinitionKind::Class(class) => class.range(), - DefinitionKind::TypeAlias(type_alias) => type_alias.range(), - DefinitionKind::NamedExpression(named) => named.range(), - DefinitionKind::Assignment(assignment) => assignment.target.range(), - DefinitionKind::AnnotatedAssignment(assign) => assign.target.range(), - DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.range(), - DefinitionKind::For(for_stmt) => for_stmt.target.range(), - DefinitionKind::Comprehension(comp) => comp.target().range(), - DefinitionKind::VariadicPositionalParameter(parameter) => parameter.range(), - DefinitionKind::VariadicKeywordParameter(parameter) => parameter.range(), - DefinitionKind::Parameter(parameter) => parameter.parameter.range(), - DefinitionKind::WithItem(with_item) => with_item.target.range(), - DefinitionKind::MatchPattern(match_pattern) => match_pattern.identifier.range(), - DefinitionKind::ExceptHandler(handler) => handler.node().range(), - DefinitionKind::TypeVar(type_var) => type_var.range(), - DefinitionKind::ParamSpec(param_spec) => param_spec.range(), - DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.range(), + DefinitionKind::Import(import) => import.alias(module).range(), + DefinitionKind::ImportFrom(import) => import.alias(module).range(), + DefinitionKind::StarImport(import) => import.import(module).range(), + DefinitionKind::Function(function) => function.node(module).range(), + DefinitionKind::Class(class) => class.node(module).range(), + DefinitionKind::TypeAlias(type_alias) => type_alias.node(module).range(), + DefinitionKind::NamedExpression(named) => named.node(module).range(), + DefinitionKind::Assignment(assignment) => assignment.target.node(module).range(), + DefinitionKind::AnnotatedAssignment(assign) => assign.target.node(module).range(), + DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.node(module).range(), + DefinitionKind::For(for_stmt) => for_stmt.target.node(module).range(), + DefinitionKind::Comprehension(comp) => comp.target(module).range(), + DefinitionKind::VariadicPositionalParameter(parameter) => { + parameter.node(module).range() + } + DefinitionKind::VariadicKeywordParameter(parameter) => parameter.node(module).range(), + DefinitionKind::Parameter(parameter) => parameter.node(module).parameter.range(), + DefinitionKind::WithItem(with_item) => with_item.target.node(module).range(), + DefinitionKind::MatchPattern(match_pattern) => { + match_pattern.identifier.node(module).range() + } + DefinitionKind::ExceptHandler(handler) => handler.node(module).range(), + DefinitionKind::TypeVar(type_var) => type_var.node(module).range(), + DefinitionKind::ParamSpec(param_spec) => param_spec.node(module).range(), + DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.node(module).range(), } } - pub(crate) fn category(&self, in_stub: bool) -> DefinitionCategory { + pub(crate) fn category(&self, in_stub: bool, module: &ParsedModuleRef) -> DefinitionCategory { match self { // functions, classes, and imports always bind, and we consider them declarations DefinitionKind::Function(_) @@ -694,7 +708,7 @@ impl DefinitionKind<'_> { // a parameter always binds a value, but is only a declaration if annotated DefinitionKind::VariadicPositionalParameter(parameter) | DefinitionKind::VariadicKeywordParameter(parameter) => { - if parameter.annotation.is_some() { + if parameter.node(module).annotation.is_some() { DefinitionCategory::DeclarationAndBinding } else { DefinitionCategory::Binding @@ -702,7 +716,12 @@ impl DefinitionKind<'_> { } // presence of a default is irrelevant, same logic as for a no-default parameter DefinitionKind::Parameter(parameter_with_default) => { - if parameter_with_default.parameter.annotation.is_some() { + if parameter_with_default + .node(module) + .parameter + .annotation + .is_some() + { DefinitionCategory::DeclarationAndBinding } else { DefinitionCategory::Binding @@ -753,15 +772,15 @@ pub struct StarImportDefinitionKind { } impl StarImportDefinitionKind { - pub(crate) fn import(&self) -> &ast::StmtImportFrom { - self.node.node() + pub(crate) fn import<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtImportFrom { + self.node.node(module) } - pub(crate) fn alias(&self) -> &ast::Alias { + pub(crate) fn alias<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Alias { // INVARIANT: for an invalid-syntax statement such as `from foo import *, bar, *`, // we only create a `StarImportDefinitionKind` for the *first* `*` alias in the names list. self.node - .node() + .node(module) .names .iter() .find(|alias| &alias.name == "*") @@ -784,8 +803,8 @@ pub struct MatchPatternDefinitionKind { } impl MatchPatternDefinitionKind { - pub(crate) fn pattern(&self) -> &ast::Pattern { - self.pattern.node() + pub(crate) fn pattern<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Pattern { + self.pattern.node(module) } pub(crate) fn index(&self) -> u32 { @@ -808,16 +827,16 @@ pub struct ComprehensionDefinitionKind<'db> { } impl<'db> ComprehensionDefinitionKind<'db> { - pub(crate) fn iterable(&self) -> &ast::Expr { - self.iterable.node() + pub(crate) fn iterable<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr { + self.iterable.node(module) } pub(crate) fn target_kind(&self) -> TargetKind<'db> { self.target_kind } - pub(crate) fn target(&self) -> &ast::Expr { - self.target.node() + pub(crate) fn target<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr { + self.target.node(module) } pub(crate) fn is_first(&self) -> bool { @@ -837,12 +856,12 @@ pub struct ImportDefinitionKind { } impl ImportDefinitionKind { - pub(crate) fn import(&self) -> &ast::StmtImport { - self.node.node() + pub(crate) fn import<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtImport { + self.node.node(module) } - pub(crate) fn alias(&self) -> &ast::Alias { - &self.node.node().names[self.alias_index] + pub(crate) fn alias<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Alias { + &self.node.node(module).names[self.alias_index] } pub(crate) fn is_reexported(&self) -> bool { @@ -858,12 +877,12 @@ pub struct ImportFromDefinitionKind { } impl ImportFromDefinitionKind { - pub(crate) fn import(&self) -> &ast::StmtImportFrom { - self.node.node() + pub(crate) fn import<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtImportFrom { + self.node.node(module) } - pub(crate) fn alias(&self) -> &ast::Alias { - &self.node.node().names[self.alias_index] + pub(crate) fn alias<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Alias { + &self.node.node(module).names[self.alias_index] } pub(crate) fn is_reexported(&self) -> bool { @@ -883,12 +902,12 @@ impl<'db> AssignmentDefinitionKind<'db> { self.target_kind } - pub(crate) fn value(&self) -> &ast::Expr { - self.value.node() + pub(crate) fn value<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr { + self.value.node(module) } - pub(crate) fn target(&self) -> &ast::Expr { - self.target.node() + pub(crate) fn target<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr { + self.target.node(module) } } @@ -900,16 +919,16 @@ pub struct AnnotatedAssignmentDefinitionKind { } impl AnnotatedAssignmentDefinitionKind { - pub(crate) fn value(&self) -> Option<&ast::Expr> { - self.value.as_deref() + pub(crate) fn value<'ast>(&self, module: &'ast ParsedModuleRef) -> Option<&'ast ast::Expr> { + self.value.as_ref().map(|value| value.node(module)) } - pub(crate) fn annotation(&self) -> &ast::Expr { - self.annotation.node() + pub(crate) fn annotation<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr { + self.annotation.node(module) } - pub(crate) fn target(&self) -> &ast::Expr { - self.target.node() + pub(crate) fn target<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr { + self.target.node(module) } } @@ -922,16 +941,16 @@ pub struct WithItemDefinitionKind<'db> { } impl<'db> WithItemDefinitionKind<'db> { - pub(crate) fn context_expr(&self) -> &ast::Expr { - self.context_expr.node() + pub(crate) fn context_expr<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr { + self.context_expr.node(module) } pub(crate) fn target_kind(&self) -> TargetKind<'db> { self.target_kind } - pub(crate) fn target(&self) -> &ast::Expr { - self.target.node() + pub(crate) fn target<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr { + self.target.node(module) } pub(crate) const fn is_async(&self) -> bool { @@ -948,16 +967,16 @@ pub struct ForStmtDefinitionKind<'db> { } impl<'db> ForStmtDefinitionKind<'db> { - pub(crate) fn iterable(&self) -> &ast::Expr { - self.iterable.node() + pub(crate) fn iterable<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr { + self.iterable.node(module) } pub(crate) fn target_kind(&self) -> TargetKind<'db> { self.target_kind } - pub(crate) fn target(&self) -> &ast::Expr { - self.target.node() + pub(crate) fn target<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr { + self.target.node(module) } pub(crate) const fn is_async(&self) -> bool { @@ -972,12 +991,18 @@ pub struct ExceptHandlerDefinitionKind { } impl ExceptHandlerDefinitionKind { - pub(crate) fn node(&self) -> &ast::ExceptHandlerExceptHandler { - self.handler.node() + pub(crate) fn node<'ast>( + &self, + module: &'ast ParsedModuleRef, + ) -> &'ast ast::ExceptHandlerExceptHandler { + self.handler.node(module) } - pub(crate) fn handled_exceptions(&self) -> Option<&ast::Expr> { - self.node().type_.as_deref() + pub(crate) fn handled_exceptions<'ast>( + &self, + module: &'ast ParsedModuleRef, + ) -> Option<&'ast ast::Expr> { + self.node(module).type_.as_deref() } pub(crate) fn is_star(&self) -> bool { diff --git a/crates/ty_python_semantic/src/semantic_index/expression.rs b/crates/ty_python_semantic/src/semantic_index/expression.rs index 1c5178b244..18a64b54e7 100644 --- a/crates/ty_python_semantic/src/semantic_index/expression.rs +++ b/crates/ty_python_semantic/src/semantic_index/expression.rs @@ -2,6 +2,7 @@ use crate::ast_node_ref::AstNodeRef; use crate::db::Db; use crate::semantic_index::place::{FileScopeId, ScopeId}; use ruff_db::files::File; +use ruff_db::parsed::ParsedModuleRef; use ruff_python_ast as ast; use salsa; @@ -41,8 +42,8 @@ pub(crate) struct Expression<'db> { /// The expression node. #[no_eq] #[tracked] - #[returns(deref)] - pub(crate) node_ref: AstNodeRef, + #[returns(ref)] + pub(crate) _node_ref: AstNodeRef, /// An assignment statement, if this expression is immediately used as the rhs of that /// assignment. @@ -62,6 +63,14 @@ pub(crate) struct Expression<'db> { } impl<'db> Expression<'db> { + pub(crate) fn node_ref<'ast>( + self, + db: &'db dyn Db, + parsed: &'ast ParsedModuleRef, + ) -> &'ast ast::Expr { + self._node_ref(db).node(parsed) + } + pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { self.file_scope(db).to_scope_id(db, self.file(db)) } diff --git a/crates/ty_python_semantic/src/semantic_index/place.rs b/crates/ty_python_semantic/src/semantic_index/place.rs index 4862c61b2a..d8295ef50e 100644 --- a/crates/ty_python_semantic/src/semantic_index/place.rs +++ b/crates/ty_python_semantic/src/semantic_index/place.rs @@ -5,7 +5,7 @@ use std::ops::Range; use bitflags::bitflags; use hashbrown::hash_map::RawEntryMut; use ruff_db::files::File; -use ruff_db::parsed::ParsedModule; +use ruff_db::parsed::ParsedModuleRef; use ruff_index::{IndexVec, newtype_index}; use ruff_python_ast as ast; use ruff_python_ast::name::Name; @@ -381,16 +381,19 @@ impl<'db> ScopeId<'db> { } #[cfg(test)] - pub(crate) fn name(self, db: &'db dyn Db) -> &'db str { + pub(crate) fn name<'ast>(self, db: &'db dyn Db, module: &'ast ParsedModuleRef) -> &'ast str { match self.node(db) { NodeWithScopeKind::Module => "", NodeWithScopeKind::Class(class) | NodeWithScopeKind::ClassTypeParameters(class) => { - class.name.as_str() + class.node(module).name.as_str() } NodeWithScopeKind::Function(function) - | NodeWithScopeKind::FunctionTypeParameters(function) => function.name.as_str(), + | NodeWithScopeKind::FunctionTypeParameters(function) => { + function.node(module).name.as_str() + } NodeWithScopeKind::TypeAlias(type_alias) | NodeWithScopeKind::TypeAliasTypeParameters(type_alias) => type_alias + .node(module) .name .as_name_expr() .map(|name| name.id.as_str()) @@ -778,7 +781,7 @@ impl NodeWithScopeRef<'_> { /// # Safety /// The node wrapped by `self` must be a child of `module`. #[expect(unsafe_code)] - pub(super) unsafe fn to_kind(self, module: ParsedModule) -> NodeWithScopeKind { + pub(super) unsafe fn to_kind(self, module: ParsedModuleRef) -> NodeWithScopeKind { unsafe { match self { NodeWithScopeRef::Module => NodeWithScopeKind::Module, @@ -892,34 +895,46 @@ impl NodeWithScopeKind { } } - pub fn expect_class(&self) -> &ast::StmtClassDef { + pub fn expect_class<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtClassDef { match self { - Self::Class(class) => class.node(), + Self::Class(class) => class.node(module), _ => panic!("expected class"), } } - pub(crate) const fn as_class(&self) -> Option<&ast::StmtClassDef> { + pub(crate) fn as_class<'ast>( + &self, + module: &'ast ParsedModuleRef, + ) -> Option<&'ast ast::StmtClassDef> { match self { - Self::Class(class) => Some(class.node()), + Self::Class(class) => Some(class.node(module)), _ => None, } } - pub fn expect_function(&self) -> &ast::StmtFunctionDef { - self.as_function().expect("expected function") + pub fn expect_function<'ast>( + &self, + module: &'ast ParsedModuleRef, + ) -> &'ast ast::StmtFunctionDef { + self.as_function(module).expect("expected function") } - pub fn expect_type_alias(&self) -> &ast::StmtTypeAlias { + pub fn expect_type_alias<'ast>( + &self, + module: &'ast ParsedModuleRef, + ) -> &'ast ast::StmtTypeAlias { match self { - Self::TypeAlias(type_alias) => type_alias.node(), + Self::TypeAlias(type_alias) => type_alias.node(module), _ => panic!("expected type alias"), } } - pub const fn as_function(&self) -> Option<&ast::StmtFunctionDef> { + pub fn as_function<'ast>( + &self, + module: &'ast ParsedModuleRef, + ) -> Option<&'ast ast::StmtFunctionDef> { match self { - Self::Function(function) => Some(function.node()), + Self::Function(function) => Some(function.node(module)), _ => None, } } diff --git a/crates/ty_python_semantic/src/semantic_index/re_exports.rs b/crates/ty_python_semantic/src/semantic_index/re_exports.rs index 049e751bf0..1f31e05f7d 100644 --- a/crates/ty_python_semantic/src/semantic_index/re_exports.rs +++ b/crates/ty_python_semantic/src/semantic_index/re_exports.rs @@ -45,7 +45,7 @@ fn exports_cycle_initial(_db: &dyn Db, _file: File) -> Box<[Name]> { #[salsa::tracked(returns(deref), cycle_fn=exports_cycle_recover, cycle_initial=exports_cycle_initial)] pub(super) fn exported_names(db: &dyn Db, file: File) -> Box<[Name]> { - let module = parsed_module(db.upcast(), file); + let module = parsed_module(db.upcast(), file).load(db.upcast()); let mut finder = ExportFinder::new(db, file); finder.visit_body(module.suite()); finder.resolve_exports() diff --git a/crates/ty_python_semantic/src/semantic_model.rs b/crates/ty_python_semantic/src/semantic_model.rs index b850d9d6c3..9237e75ee2 100644 --- a/crates/ty_python_semantic/src/semantic_model.rs +++ b/crates/ty_python_semantic/src/semantic_model.rs @@ -232,7 +232,7 @@ mod tests { let foo = system_path_to_file(&db, "/src/foo.py").unwrap(); - let ast = parsed_module(&db, foo); + let ast = parsed_module(&db, foo).load(&db); let function = ast.suite()[0].as_function_def_stmt().unwrap(); let model = SemanticModel::new(&db, foo); @@ -251,7 +251,7 @@ mod tests { let foo = system_path_to_file(&db, "/src/foo.py").unwrap(); - let ast = parsed_module(&db, foo); + let ast = parsed_module(&db, foo).load(&db); let class = ast.suite()[0].as_class_def_stmt().unwrap(); let model = SemanticModel::new(&db, foo); @@ -271,7 +271,7 @@ mod tests { let bar = system_path_to_file(&db, "/src/bar.py").unwrap(); - let ast = parsed_module(&db, bar); + let ast = parsed_module(&db, bar).load(&db); let import = ast.suite()[0].as_import_from_stmt().unwrap(); let alias = &import.names[0]; diff --git a/crates/ty_python_semantic/src/suppression.rs b/crates/ty_python_semantic/src/suppression.rs index 51efb2c72d..26518ec7fa 100644 --- a/crates/ty_python_semantic/src/suppression.rs +++ b/crates/ty_python_semantic/src/suppression.rs @@ -88,7 +88,7 @@ declare_lint! { #[salsa::tracked(returns(ref))] pub(crate) fn suppressions(db: &dyn Db, file: File) -> Suppressions { - let parsed = parsed_module(db.upcast(), file); + let parsed = parsed_module(db.upcast(), file).load(db.upcast()); let source = source_text(db.upcast(), file); let mut builder = SuppressionsBuilder::new(&source, db.lint_registry()); diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 468fd018e9..c6407e0001 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1,5 +1,6 @@ use infer::nearest_enclosing_class; use itertools::Either; +use ruff_db::parsed::parsed_module; use std::slice::Iter; @@ -5065,8 +5066,9 @@ impl<'db> Type<'db> { SpecialFormType::Callable => Ok(CallableType::unknown(db)), SpecialFormType::TypingSelf => { + let module = parsed_module(db.upcast(), scope_id.file(db)).load(db.upcast()); let index = semantic_index(db, scope_id.file(db)); - let Some(class) = nearest_enclosing_class(db, index, scope_id) else { + let Some(class) = nearest_enclosing_class(db, index, scope_id, &module) else { return Err(InvalidTypeExpressionError { fallback_type: Type::unknown(), invalid_expressions: smallvec::smallvec![ @@ -6302,7 +6304,7 @@ impl<'db> ContextManagerError<'db> { fn report_diagnostic( &self, - context: &InferContext<'db>, + context: &InferContext<'db, '_>, context_expression_type: Type<'db>, context_expression_node: ast::AnyNodeRef, ) { @@ -6475,7 +6477,7 @@ impl<'db> IterationError<'db> { /// Reports the diagnostic for this error. fn report_diagnostic( &self, - context: &InferContext<'db>, + context: &InferContext<'db, '_>, iterable_type: Type<'db>, iterable_node: ast::AnyNodeRef, ) { @@ -6951,7 +6953,7 @@ impl<'db> ConstructorCallError<'db> { fn report_diagnostic( &self, - context: &InferContext<'db>, + context: &InferContext<'db, '_>, context_expression_type: Type<'db>, context_expression_node: ast::AnyNodeRef, ) { @@ -7578,7 +7580,8 @@ pub struct PEP695TypeAliasType<'db> { impl<'db> PEP695TypeAliasType<'db> { pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { let scope = self.rhs_scope(db); - let type_alias_stmt_node = scope.node(db).expect_type_alias(); + let module = parsed_module(db.upcast(), scope.file(db)).load(db.upcast()); + let type_alias_stmt_node = scope.node(db).expect_type_alias(&module); semantic_index(db, scope.file(db)).expect_single_definition(type_alias_stmt_node) } @@ -7586,7 +7589,8 @@ impl<'db> PEP695TypeAliasType<'db> { #[salsa::tracked] pub(crate) fn value_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.rhs_scope(db); - let type_alias_stmt_node = scope.node(db).expect_type_alias(); + let module = parsed_module(db.upcast(), scope.file(db)).load(db.upcast()); + let type_alias_stmt_node = scope.node(db).expect_type_alias(&module); let definition = self.definition(db); definition_expression_type(db, definition, &type_alias_stmt_node.value) } @@ -8654,10 +8658,8 @@ pub(crate) mod tests { ); let events = db.take_salsa_events(); - let call = &*parsed_module(&db, bar).syntax().body[1] - .as_assign_stmt() - .unwrap() - .value; + let module = parsed_module(&db, bar).load(&db); + let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value; let foo_call = semantic_index(&db, bar).expression(call); assert_function_query_was_not_run(&db, infer_expression_types, foo_call, &events); diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 14906fd6da..a7b6e9d94b 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -4,6 +4,7 @@ //! union of types, each of which might contain multiple overloads. use itertools::Itertools; +use ruff_db::parsed::parsed_module; use smallvec::{SmallVec, smallvec}; use super::{ @@ -198,7 +199,11 @@ impl<'db> Bindings<'db> { /// report a single diagnostic if we couldn't match any union element or overload. /// TODO: Update this to add subdiagnostics about how we failed to match each union element and /// overload. - pub(crate) fn report_diagnostics(&self, context: &InferContext<'db>, node: ast::AnyNodeRef) { + pub(crate) fn report_diagnostics( + &self, + context: &InferContext<'db, '_>, + node: ast::AnyNodeRef, + ) { // If all union elements are not callable, report that the union as a whole is not // callable. if self.into_iter().all(|b| !b.is_callable()) { @@ -1367,7 +1372,7 @@ impl<'db> CallableBinding<'db> { fn report_diagnostics( &self, - context: &InferContext<'db>, + context: &InferContext<'db, '_>, node: ast::AnyNodeRef, union_diag: Option<&UnionDiagnostic<'_, '_>>, ) { @@ -1840,7 +1845,7 @@ impl<'db> Binding<'db> { fn report_diagnostics( &self, - context: &InferContext<'db>, + context: &InferContext<'db, '_>, node: ast::AnyNodeRef, callable_ty: Type<'db>, callable_description: Option<&CallableDescription>, @@ -2128,7 +2133,7 @@ pub(crate) enum BindingError<'db> { impl<'db> BindingError<'db> { fn report_diagnostic( &self, - context: &InferContext<'db>, + context: &InferContext<'db, '_>, node: ast::AnyNodeRef, callable_ty: Type<'db>, callable_description: Option<&CallableDescription>, @@ -2285,7 +2290,10 @@ impl<'db> BindingError<'db> { )); if let Some(typevar_definition) = typevar.definition(context.db()) { - let typevar_range = typevar_definition.full_range(context.db()); + let module = + parsed_module(context.db().upcast(), typevar_definition.file(context.db())) + .load(context.db().upcast()); + let typevar_range = typevar_definition.full_range(context.db(), &module); let mut sub = SubDiagnostic::new(Severity::Info, "Type variable defined here"); sub.annotate(Annotation::primary(typevar_range.into())); diag.sub(sub); diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index ffc558e969..344c200237 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -37,6 +37,7 @@ use indexmap::IndexSet; use itertools::Itertools as _; use ruff_db::diagnostic::Span; use ruff_db::files::File; +use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_ast::name::Name; use ruff_python_ast::{self as ast, PythonVersion}; use ruff_text_size::{Ranged, TextRange}; @@ -715,7 +716,8 @@ impl<'db> ClassLiteral<'db> { #[salsa::tracked(cycle_fn=pep695_generic_context_cycle_recover, cycle_initial=pep695_generic_context_cycle_initial)] pub(crate) fn pep695_generic_context(self, db: &'db dyn Db) -> Option> { let scope = self.body_scope(db); - let class_def_node = scope.node(db).expect_class(); + let parsed = parsed_module(db.upcast(), scope.file(db)).load(db.upcast()); + let class_def_node = scope.node(db).expect_class(&parsed); class_def_node.type_params.as_ref().map(|type_params| { let index = semantic_index(db, scope.file(db)); GenericContext::from_type_params(db, index, type_params) @@ -754,14 +756,16 @@ impl<'db> ClassLiteral<'db> { /// ## Note /// Only call this function from queries in the same file or your /// query depends on the AST of another file (bad!). - fn node(self, db: &'db dyn Db) -> &'db ast::StmtClassDef { - self.body_scope(db).node(db).expect_class() + fn node<'ast>(self, db: &'db dyn Db, module: &'ast ParsedModuleRef) -> &'ast ast::StmtClassDef { + let scope = self.body_scope(db); + scope.node(db).expect_class(module) } pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { let body_scope = self.body_scope(db); + let module = parsed_module(db.upcast(), body_scope.file(db)).load(db.upcast()); let index = semantic_index(db, body_scope.file(db)); - index.expect_single_definition(body_scope.node(db).expect_class()) + index.expect_single_definition(body_scope.node(db).expect_class(&module)) } pub(crate) fn apply_optional_specialization( @@ -835,7 +839,8 @@ impl<'db> ClassLiteral<'db> { pub(super) fn explicit_bases(self, db: &'db dyn Db) -> Box<[Type<'db>]> { tracing::trace!("ClassLiteral::explicit_bases_query: {}", self.name(db)); - let class_stmt = self.node(db); + let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast()); + let class_stmt = self.node(db, &module); let class_definition = semantic_index(db, self.file(db)).expect_single_definition(class_stmt); @@ -897,7 +902,9 @@ impl<'db> ClassLiteral<'db> { fn decorators(self, db: &'db dyn Db) -> Box<[Type<'db>]> { tracing::trace!("ClassLiteral::decorators: {}", self.name(db)); - let class_stmt = self.node(db); + let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast()); + + let class_stmt = self.node(db, &module); if class_stmt.decorator_list.is_empty() { return Box::new([]); } @@ -983,8 +990,8 @@ impl<'db> ClassLiteral<'db> { /// ## Note /// Only call this function from queries in the same file or your /// query depends on the AST of another file (bad!). - fn explicit_metaclass(self, db: &'db dyn Db) -> Option> { - let class_stmt = self.node(db); + fn explicit_metaclass(self, db: &'db dyn Db, module: &ParsedModuleRef) -> Option> { + let class_stmt = self.node(db, module); let metaclass_node = &class_stmt .arguments .as_ref()? @@ -1039,7 +1046,9 @@ impl<'db> ClassLiteral<'db> { return Ok((SubclassOfType::subclass_of_unknown(), None)); } - let explicit_metaclass = self.explicit_metaclass(db); + let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast()); + + let explicit_metaclass = self.explicit_metaclass(db, &module); let (metaclass, class_metaclass_was_from) = if let Some(metaclass) = explicit_metaclass { (metaclass, self) } else if let Some(base_class) = base_classes.next() { @@ -1608,6 +1617,7 @@ impl<'db> ClassLiteral<'db> { let mut is_attribute_bound = Truthiness::AlwaysFalse; let file = class_body_scope.file(db); + let module = parsed_module(db.upcast(), file).load(db.upcast()); let index = semantic_index(db, file); let class_map = use_def_map(db, class_body_scope); let class_table = place_table(db, class_body_scope); @@ -1619,19 +1629,20 @@ impl<'db> ClassLiteral<'db> { let method_map = use_def_map(db, method_scope); // The attribute assignment inherits the visibility of the method which contains it - let is_method_visible = if let Some(method_def) = method_scope.node(db).as_function() { - let method = index.expect_single_definition(method_def); - let method_place = class_table.place_id_by_name(&method_def.name).unwrap(); - class_map - .public_bindings(method_place) - .find_map(|bind| { - (bind.binding.is_defined_and(|def| def == method)) - .then(|| class_map.is_binding_visible(db, &bind)) - }) - .unwrap_or(Truthiness::AlwaysFalse) - } else { - Truthiness::AlwaysFalse - }; + let is_method_visible = + if let Some(method_def) = method_scope.node(db).as_function(&module) { + let method = index.expect_single_definition(method_def); + let method_place = class_table.place_id_by_name(&method_def.name).unwrap(); + class_map + .public_bindings(method_place) + .find_map(|bind| { + (bind.binding.is_defined_and(|def| def == method)) + .then(|| class_map.is_binding_visible(db, &bind)) + }) + .unwrap_or(Truthiness::AlwaysFalse) + } else { + Truthiness::AlwaysFalse + }; if is_method_visible.is_always_false() { continue; } @@ -1688,8 +1699,10 @@ impl<'db> ClassLiteral<'db> { // self.name: // self.name: = … - let annotation_ty = - infer_expression_type(db, index.expression(ann_assign.annotation())); + let annotation_ty = infer_expression_type( + db, + index.expression(ann_assign.annotation(&module)), + ); // TODO: check if there are conflicting declarations match is_attribute_bound { @@ -1714,8 +1727,9 @@ impl<'db> ClassLiteral<'db> { // [.., self.name, ..] = let unpacked = infer_unpack_types(db, unpack); - let target_ast_id = - assign.target().scoped_expression_id(db, method_scope); + let target_ast_id = assign + .target(&module) + .scoped_expression_id(db, method_scope); let inferred_ty = unpacked.expression_type(target_ast_id); union_of_inferred_types = union_of_inferred_types.add(inferred_ty); @@ -1725,8 +1739,10 @@ impl<'db> ClassLiteral<'db> { // // self.name = - let inferred_ty = - infer_expression_type(db, index.expression(assign.value())); + let inferred_ty = infer_expression_type( + db, + index.expression(assign.value(&module)), + ); union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } @@ -1740,8 +1756,9 @@ impl<'db> ClassLiteral<'db> { // for .., self.name, .. in : let unpacked = infer_unpack_types(db, unpack); - let target_ast_id = - for_stmt.target().scoped_expression_id(db, method_scope); + let target_ast_id = for_stmt + .target(&module) + .scoped_expression_id(db, method_scope); let inferred_ty = unpacked.expression_type(target_ast_id); union_of_inferred_types = union_of_inferred_types.add(inferred_ty); @@ -1753,7 +1770,7 @@ impl<'db> ClassLiteral<'db> { let iterable_ty = infer_expression_type( db, - index.expression(for_stmt.iterable()), + index.expression(for_stmt.iterable(&module)), ); // TODO: Potential diagnostics resulting from the iterable are currently not reported. let inferred_ty = iterable_ty.iterate(db); @@ -1770,8 +1787,9 @@ impl<'db> ClassLiteral<'db> { // with as .., self.name, ..: let unpacked = infer_unpack_types(db, unpack); - let target_ast_id = - with_item.target().scoped_expression_id(db, method_scope); + let target_ast_id = with_item + .target(&module) + .scoped_expression_id(db, method_scope); let inferred_ty = unpacked.expression_type(target_ast_id); union_of_inferred_types = union_of_inferred_types.add(inferred_ty); @@ -1783,7 +1801,7 @@ impl<'db> ClassLiteral<'db> { let context_ty = infer_expression_type( db, - index.expression(with_item.context_expr()), + index.expression(with_item.context_expr(&module)), ); let inferred_ty = context_ty.enter(db); @@ -1800,7 +1818,7 @@ impl<'db> ClassLiteral<'db> { let unpacked = infer_unpack_types(db, unpack); let target_ast_id = comprehension - .target() + .target(&module) .scoped_expression_id(db, unpack.target_scope(db)); let inferred_ty = unpacked.expression_type(target_ast_id); @@ -1813,7 +1831,7 @@ impl<'db> ClassLiteral<'db> { let iterable_ty = infer_expression_type( db, - index.expression(comprehension.iterable()), + index.expression(comprehension.iterable(&module)), ); // TODO: Potential diagnostics resulting from the iterable are currently not reported. let inferred_ty = iterable_ty.iterate(db); @@ -2003,8 +2021,8 @@ impl<'db> ClassLiteral<'db> { /// Returns a [`Span`] with the range of the class's header. /// /// See [`Self::header_range`] for more details. - pub(super) fn header_span(self, db: &'db dyn Db) -> Span { - Span::from(self.file(db)).with_range(self.header_range(db)) + pub(super) fn header_span(self, db: &'db dyn Db, module: &ParsedModuleRef) -> Span { + Span::from(self.file(db)).with_range(self.header_range(db, module)) } /// Returns the range of the class's "header": the class name @@ -2014,9 +2032,9 @@ impl<'db> ClassLiteral<'db> { /// class Foo(Bar, metaclass=Baz): ... /// ^^^^^^^^^^^^^^^^^^^^^^^ /// ``` - pub(super) fn header_range(self, db: &'db dyn Db) -> TextRange { + pub(super) fn header_range(self, db: &'db dyn Db, module: &ParsedModuleRef) -> TextRange { let class_scope = self.body_scope(db); - let class_node = class_scope.node(db).expect_class(); + let class_node = class_scope.node(db).expect_class(module); let class_name = &class_node.name; TextRange::new( class_name.start(), diff --git a/crates/ty_python_semantic/src/types/context.rs b/crates/ty_python_semantic/src/types/context.rs index 498a1b644a..f36b19873a 100644 --- a/crates/ty_python_semantic/src/types/context.rs +++ b/crates/ty_python_semantic/src/types/context.rs @@ -2,6 +2,7 @@ use std::fmt; use drop_bomb::DebugDropBomb; use ruff_db::diagnostic::{DiagnosticTag, SubDiagnostic}; +use ruff_db::parsed::ParsedModuleRef; use ruff_db::{ diagnostic::{Annotation, Diagnostic, DiagnosticId, IntoDiagnosticMessage, Severity, Span}, files::File, @@ -32,20 +33,22 @@ use crate::{ /// It's important that the context is explicitly consumed before dropping by calling /// [`InferContext::finish`] and the returned diagnostics must be stored /// on the current [`TypeInference`](super::infer::TypeInference) result. -pub(crate) struct InferContext<'db> { +pub(crate) struct InferContext<'db, 'ast> { db: &'db dyn Db, scope: ScopeId<'db>, file: File, + module: &'ast ParsedModuleRef, diagnostics: std::cell::RefCell, no_type_check: InNoTypeCheck, bomb: DebugDropBomb, } -impl<'db> InferContext<'db> { - pub(crate) fn new(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { +impl<'db, 'ast> InferContext<'db, 'ast> { + pub(crate) fn new(db: &'db dyn Db, scope: ScopeId<'db>, module: &'ast ParsedModuleRef) -> Self { Self { db, scope, + module, file: scope.file(db), diagnostics: std::cell::RefCell::new(TypeCheckDiagnostics::default()), no_type_check: InNoTypeCheck::default(), @@ -60,6 +63,11 @@ impl<'db> InferContext<'db> { self.file } + /// The module for which the types are inferred. + pub(crate) fn module(&self) -> &'ast ParsedModuleRef { + self.module + } + /// Create a span with the range of the given expression /// in the file being currently type checked. /// @@ -160,7 +168,7 @@ impl<'db> InferContext<'db> { // Inspect all ancestor function scopes by walking bottom up and infer the function's type. let mut function_scope_tys = index .ancestor_scopes(scope_id) - .filter_map(|(_, scope)| scope.node().as_function()) + .filter_map(|(_, scope)| scope.node().as_function(self.module())) .map(|node| binding_type(self.db, index.expect_single_definition(node))) .filter_map(Type::into_function_literal); @@ -187,7 +195,7 @@ impl<'db> InferContext<'db> { } } -impl fmt::Debug for InferContext<'_> { +impl fmt::Debug for InferContext<'_, '_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("TyContext") .field("file", &self.file) @@ -221,7 +229,7 @@ pub(crate) enum InNoTypeCheck { /// will attach a message to the primary span on the diagnostic. pub(super) struct LintDiagnosticGuard<'db, 'ctx> { /// The typing context. - ctx: &'ctx InferContext<'db>, + ctx: &'ctx InferContext<'db, 'ctx>, /// The diagnostic that we want to report. /// /// This is always `Some` until the `Drop` impl. @@ -363,7 +371,7 @@ impl Drop for LintDiagnosticGuard<'_, '_> { /// it is known that the diagnostic should not be reported. This can happen /// when the diagnostic is disabled or suppressed (among other reasons). pub(super) struct LintDiagnosticGuardBuilder<'db, 'ctx> { - ctx: &'ctx InferContext<'db>, + ctx: &'ctx InferContext<'db, 'ctx>, id: DiagnosticId, severity: Severity, source: LintSource, @@ -372,7 +380,7 @@ pub(super) struct LintDiagnosticGuardBuilder<'db, 'ctx> { impl<'db, 'ctx> LintDiagnosticGuardBuilder<'db, 'ctx> { fn new( - ctx: &'ctx InferContext<'db>, + ctx: &'ctx InferContext<'db, 'ctx>, lint: &'static LintMetadata, range: TextRange, ) -> Option> { @@ -462,7 +470,7 @@ impl<'db, 'ctx> LintDiagnosticGuardBuilder<'db, 'ctx> { /// if either is violated, then the `Drop` impl on `DiagnosticGuard` will /// panic. pub(super) struct DiagnosticGuard<'db, 'ctx> { - ctx: &'ctx InferContext<'db>, + ctx: &'ctx InferContext<'db, 'ctx>, /// The diagnostic that we want to report. /// /// This is always `Some` until the `Drop` impl. @@ -550,14 +558,14 @@ impl Drop for DiagnosticGuard<'_, '_> { /// minimal amount of information with which to construct a diagnostic) before /// one can mutate the diagnostic. pub(super) struct DiagnosticGuardBuilder<'db, 'ctx> { - ctx: &'ctx InferContext<'db>, + ctx: &'ctx InferContext<'db, 'ctx>, id: DiagnosticId, severity: Severity, } impl<'db, 'ctx> DiagnosticGuardBuilder<'db, 'ctx> { fn new( - ctx: &'ctx InferContext<'db>, + ctx: &'ctx InferContext<'db, 'ctx>, id: DiagnosticId, severity: Severity, ) -> Option> { diff --git a/crates/ty_python_semantic/src/types/definition.rs b/crates/ty_python_semantic/src/types/definition.rs index 466673b09f..e6b553d374 100644 --- a/crates/ty_python_semantic/src/types/definition.rs +++ b/crates/ty_python_semantic/src/types/definition.rs @@ -1,6 +1,7 @@ use crate::semantic_index::definition::Definition; use crate::{Db, Module}; use ruff_db::files::FileRange; +use ruff_db::parsed::parsed_module; use ruff_db::source::source_text; use ruff_text_size::{TextLen, TextRange}; @@ -20,7 +21,10 @@ impl TypeDefinition<'_> { Self::Class(definition) | Self::Function(definition) | Self::TypeVar(definition) - | Self::TypeAlias(definition) => Some(definition.focus_range(db)), + | Self::TypeAlias(definition) => { + let module = parsed_module(db.upcast(), definition.file(db)).load(db.upcast()); + Some(definition.focus_range(db, &module)) + } } } @@ -34,7 +38,10 @@ impl TypeDefinition<'_> { Self::Class(definition) | Self::Function(definition) | Self::TypeVar(definition) - | Self::TypeAlias(definition) => Some(definition.full_range(db)), + | Self::TypeAlias(definition) => { + let module = parsed_module(db.upcast(), definition.file(db)).load(db.upcast()); + Some(definition.full_range(db, &module)) + } } } } diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index 8f62976c1e..7233c81b6e 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -1730,7 +1730,7 @@ pub(super) fn report_implicit_return_type( or `typing_extensions.Protocol` are considered protocol classes", ); sub_diagnostic.annotate( - Annotation::primary(class.header_span(db)).message(format_args!( + Annotation::primary(class.header_span(db, context.module())).message(format_args!( "`Protocol` not present in `{class}`'s immediate bases", class = class.name(db) )), @@ -1850,7 +1850,7 @@ pub(crate) fn report_bad_argument_to_get_protocol_members( class.name(db) ), ); - class_def_diagnostic.annotate(Annotation::primary(class.header_span(db))); + class_def_diagnostic.annotate(Annotation::primary(class.header_span(db, context.module()))); diagnostic.sub(class_def_diagnostic); diagnostic.info( @@ -1910,7 +1910,7 @@ pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol( ), ); class_def_diagnostic.annotate( - Annotation::primary(protocol.header_span(db)) + Annotation::primary(protocol.header_span(db, context.module())) .message(format_args!("`{class_name}` declared here")), ); diagnostic.sub(class_def_diagnostic); @@ -1941,7 +1941,7 @@ pub(crate) fn report_attempted_protocol_instantiation( format_args!("Protocol classes cannot be instantiated"), ); class_def_diagnostic.annotate( - Annotation::primary(protocol.header_span(db)) + Annotation::primary(protocol.header_span(db, context.module())) .message(format_args!("`{class_name}` declared as a protocol here")), ); diagnostic.sub(class_def_diagnostic); @@ -1955,7 +1955,9 @@ pub(crate) fn report_duplicate_bases( ) { let db = context.db(); - let Some(builder) = context.report_lint(&DUPLICATE_BASE, class.header_range(db)) else { + let Some(builder) = + context.report_lint(&DUPLICATE_BASE, class.header_range(db, context.module())) + else { return; }; @@ -2104,7 +2106,7 @@ fn report_unsupported_base( } fn report_invalid_base<'ctx, 'db>( - context: &'ctx InferContext<'db>, + context: &'ctx InferContext<'db, '_>, base_node: &ast::Expr, base_type: Type<'db>, class: ClassLiteral<'db>, diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 3b67e9c682..c6164e6d92 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -54,6 +54,7 @@ use std::str::FromStr; use bitflags::bitflags; use ruff_db::diagnostic::Span; use ruff_db::files::{File, FileRange}; +use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_ast as ast; use ruff_text_size::Ranged; @@ -187,7 +188,12 @@ impl<'db> OverloadLiteral<'db> { self.has_known_decorator(db, FunctionDecorators::OVERLOAD) } - fn node(self, db: &'db dyn Db, file: File) -> &'db ast::StmtFunctionDef { + fn node<'ast>( + self, + db: &dyn Db, + file: File, + module: &'ast ParsedModuleRef, + ) -> &'ast ast::StmtFunctionDef { debug_assert_eq!( file, self.file(db), @@ -195,14 +201,18 @@ impl<'db> OverloadLiteral<'db> { the function is defined." ); - self.body_scope(db).node(db).expect_function() + self.body_scope(db).node(db).expect_function(module) } /// Returns the [`FileRange`] of the function's name. - pub(crate) fn focus_range(self, db: &dyn Db) -> FileRange { + pub(crate) fn focus_range(self, db: &dyn Db, module: &ParsedModuleRef) -> FileRange { FileRange::new( self.file(db), - self.body_scope(db).node(db).expect_function().name.range, + self.body_scope(db) + .node(db) + .expect_function(module) + .name + .range, ) } @@ -216,8 +226,9 @@ impl<'db> OverloadLiteral<'db> { /// over-invalidation. fn definition(self, db: &'db dyn Db) -> Definition<'db> { let body_scope = self.body_scope(db); + let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast()); let index = semantic_index(db, body_scope.file(db)); - index.expect_single_definition(body_scope.node(db).expect_function()) + index.expect_single_definition(body_scope.node(db).expect_function(&module)) } /// Returns the overload immediately before this one in the AST. Returns `None` if there is no @@ -226,11 +237,12 @@ impl<'db> OverloadLiteral<'db> { // The semantic model records a use for each function on the name node. This is used // here to get the previous function definition with the same name. let scope = self.definition(db).scope(db); + let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast()); let use_def = semantic_index(db, scope.file(db)).use_def_map(scope.file_scope_id(db)); let use_id = self .body_scope(db) .node(db) - .expect_function() + .expect_function(&module) .name .scoped_use_id(db, scope); @@ -266,7 +278,8 @@ impl<'db> OverloadLiteral<'db> { inherited_generic_context: Option>, ) -> Signature<'db> { let scope = self.body_scope(db); - let function_stmt_node = scope.node(db).expect_function(); + let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast()); + let function_stmt_node = scope.node(db).expect_function(&module); let definition = self.definition(db); let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| { let index = semantic_index(db, scope.file(db)); @@ -289,7 +302,8 @@ impl<'db> OverloadLiteral<'db> { let function_scope = self.body_scope(db); let span = Span::from(function_scope.file(db)); let node = function_scope.node(db); - let func_def = node.as_function()?; + let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast()); + let func_def = node.as_function(&module)?; let range = parameter_index .and_then(|parameter_index| { func_def @@ -308,7 +322,8 @@ impl<'db> OverloadLiteral<'db> { let function_scope = self.body_scope(db); let span = Span::from(function_scope.file(db)); let node = function_scope.node(db); - let func_def = node.as_function()?; + let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast()); + let func_def = node.as_function(&module)?; let return_type_range = func_def.returns.as_ref().map(|returns| returns.range()); let mut signature = func_def.name.range.cover(func_def.parameters.range); if let Some(return_type_range) = return_type_range { @@ -553,8 +568,13 @@ impl<'db> FunctionType<'db> { } /// Returns the AST node for this function. - pub(crate) fn node(self, db: &'db dyn Db, file: File) -> &'db ast::StmtFunctionDef { - self.literal(db).last_definition(db).node(db, file) + pub(crate) fn node<'ast>( + self, + db: &dyn Db, + file: File, + module: &'ast ParsedModuleRef, + ) -> &'ast ast::StmtFunctionDef { + self.literal(db).last_definition(db).node(db, file, module) } pub(crate) fn name(self, db: &'db dyn Db) -> &'db ast::name::Name { diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index b0e6a19ddd..2c90d14cb2 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -36,7 +36,7 @@ use itertools::{Either, Itertools}; use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity}; use ruff_db::files::File; -use ruff_db::parsed::parsed_module; +use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_ast::visitor::{Visitor, walk_expr}; use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, PythonVersion}; use ruff_python_stdlib::builtins::version_builtin_was_added; @@ -136,11 +136,13 @@ pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Ty let file = scope.file(db); let _span = tracing::trace_span!("infer_scope_types", scope=?scope.as_id(), ?file).entered(); + let module = parsed_module(db.upcast(), file).load(db.upcast()); + // Using the index here is fine because the code below depends on the AST anyway. // The isolation of the query is by the return inferred types. let index = semantic_index(db, file); - TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index).finish() + TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index, &module).finish() } fn scope_cycle_recover<'db>( @@ -164,16 +166,17 @@ pub(crate) fn infer_definition_types<'db>( definition: Definition<'db>, ) -> TypeInference<'db> { let file = definition.file(db); + let module = parsed_module(db.upcast(), file).load(db.upcast()); let _span = tracing::trace_span!( "infer_definition_types", - range = ?definition.kind(db).target_range(), + range = ?definition.kind(db).target_range(&module), ?file ) .entered(); let index = semantic_index(db, file); - TypeInferenceBuilder::new(db, InferenceRegion::Definition(definition), index).finish() + TypeInferenceBuilder::new(db, InferenceRegion::Definition(definition), index, &module).finish() } fn definition_cycle_recover<'db>( @@ -202,17 +205,18 @@ pub(crate) fn infer_deferred_types<'db>( definition: Definition<'db>, ) -> TypeInference<'db> { let file = definition.file(db); + let module = parsed_module(db.upcast(), file).load(db.upcast()); let _span = tracing::trace_span!( "infer_deferred_types", definition = ?definition.as_id(), - range = ?definition.kind(db).target_range(), + range = ?definition.kind(db).target_range(&module), ?file ) .entered(); let index = semantic_index(db, file); - TypeInferenceBuilder::new(db, InferenceRegion::Deferred(definition), index).finish() + TypeInferenceBuilder::new(db, InferenceRegion::Deferred(definition), index, &module).finish() } fn deferred_cycle_recover<'db>( @@ -238,17 +242,18 @@ pub(crate) fn infer_expression_types<'db>( expression: Expression<'db>, ) -> TypeInference<'db> { let file = expression.file(db); + let module = parsed_module(db.upcast(), file).load(db.upcast()); let _span = tracing::trace_span!( "infer_expression_types", expression = ?expression.as_id(), - range = ?expression.node_ref(db).range(), + range = ?expression.node_ref(db, &module).range(), ?file ) .entered(); let index = semantic_index(db, file); - TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index).finish() + TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index, &module).finish() } fn expression_cycle_recover<'db>( @@ -275,10 +280,15 @@ fn expression_cycle_initial<'db>( pub(super) fn infer_same_file_expression_type<'db>( db: &'db dyn Db, expression: Expression<'db>, + parsed: &ParsedModuleRef, ) -> Type<'db> { let inference = infer_expression_types(db, expression); let scope = expression.scope(db); - inference.expression_type(expression.node_ref(db).scoped_expression_id(db, scope)) + inference.expression_type( + expression + .node_ref(db, parsed) + .scoped_expression_id(db, scope), + ) } /// Infers the type of an expression where the expression might come from another file. @@ -293,8 +303,11 @@ pub(crate) fn infer_expression_type<'db>( db: &'db dyn Db, expression: Expression<'db>, ) -> Type<'db> { + let file = expression.file(db); + let module = parsed_module(db.upcast(), file).load(db.upcast()); + // It's okay to call the "same file" version here because we're inside a salsa query. - infer_same_file_expression_type(db, expression) + infer_same_file_expression_type(db, expression, &module) } fn single_expression_cycle_recover<'db>( @@ -322,11 +335,12 @@ fn single_expression_cycle_initial<'db>( #[salsa::tracked(returns(ref), cycle_fn=unpack_cycle_recover, cycle_initial=unpack_cycle_initial)] pub(super) fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult<'db> { let file = unpack.file(db); - let _span = - tracing::trace_span!("infer_unpack_types", range=?unpack.range(db), ?file).entered(); + let module = parsed_module(db.upcast(), file).load(db.upcast()); + let _span = tracing::trace_span!("infer_unpack_types", range=?unpack.range(db, &module), ?file) + .entered(); - let mut unpacker = Unpacker::new(db, unpack.target_scope(db), unpack.value_scope(db)); - unpacker.unpack(unpack.target(db), unpack.value(db)); + let mut unpacker = Unpacker::new(db, unpack.target_scope(db), unpack.value_scope(db), &module); + unpacker.unpack(unpack.target(db, &module), unpack.value(db)); unpacker.finish() } @@ -356,11 +370,12 @@ pub(crate) fn nearest_enclosing_class<'db>( db: &'db dyn Db, semantic: &SemanticIndex<'db>, scope: ScopeId, + parsed: &ParsedModuleRef, ) -> Option> { semantic .ancestor_scopes(scope.file_scope_id(db)) .find_map(|(_, ancestor_scope)| { - let class = ancestor_scope.node().as_class()?; + let class = ancestor_scope.node().as_class(parsed)?; let definition = semantic.expect_single_definition(class); infer_definition_types(db, definition) .declaration_type(definition) @@ -569,8 +584,8 @@ enum DeclaredAndInferredType<'db> { /// Similarly, when we encounter a standalone-inferable expression (right-hand side of an /// assignment, type narrowing guard), we use the [`infer_expression_types()`] query to ensure we /// don't infer its types more than once. -pub(super) struct TypeInferenceBuilder<'db> { - context: InferContext<'db>, +pub(super) struct TypeInferenceBuilder<'db, 'ast> { + context: InferContext<'db, 'ast>, index: &'db SemanticIndex<'db>, region: InferenceRegion<'db>, @@ -617,7 +632,7 @@ pub(super) struct TypeInferenceBuilder<'db> { deferred_state: DeferredExpressionState, } -impl<'db> TypeInferenceBuilder<'db> { +impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// How big a string do we build before bailing? /// /// This is a fairly arbitrary number. It should be *far* more than enough @@ -629,11 +644,12 @@ impl<'db> TypeInferenceBuilder<'db> { db: &'db dyn Db, region: InferenceRegion<'db>, index: &'db SemanticIndex<'db>, + module: &'ast ParsedModuleRef, ) -> Self { let scope = region.scope(db); Self { - context: InferContext::new(db, scope), + context: InferContext::new(db, scope, module), index, region, return_types_and_ranges: vec![], @@ -659,6 +675,10 @@ impl<'db> TypeInferenceBuilder<'db> { self.context.file() } + fn module(&self) -> &'ast ParsedModuleRef { + self.context.module() + } + fn db(&self) -> &'db dyn Db { self.context.db() } @@ -756,35 +776,36 @@ impl<'db> TypeInferenceBuilder<'db> { let node = scope.node(self.db()); match node { NodeWithScopeKind::Module => { - let parsed = parsed_module(self.db().upcast(), self.file()); - self.infer_module(parsed.syntax()); + self.infer_module(self.module().syntax()); } - NodeWithScopeKind::Function(function) => self.infer_function_body(function.node()), - NodeWithScopeKind::Lambda(lambda) => self.infer_lambda_body(lambda.node()), - NodeWithScopeKind::Class(class) => self.infer_class_body(class.node()), + NodeWithScopeKind::Function(function) => { + self.infer_function_body(function.node(self.module())); + } + NodeWithScopeKind::Lambda(lambda) => self.infer_lambda_body(lambda.node(self.module())), + NodeWithScopeKind::Class(class) => self.infer_class_body(class.node(self.module())), NodeWithScopeKind::ClassTypeParameters(class) => { - self.infer_class_type_params(class.node()); + self.infer_class_type_params(class.node(self.module())); } NodeWithScopeKind::FunctionTypeParameters(function) => { - self.infer_function_type_params(function.node()); + self.infer_function_type_params(function.node(self.module())); } NodeWithScopeKind::TypeAliasTypeParameters(type_alias) => { - self.infer_type_alias_type_params(type_alias.node()); + self.infer_type_alias_type_params(type_alias.node(self.module())); } NodeWithScopeKind::TypeAlias(type_alias) => { - self.infer_type_alias(type_alias.node()); + self.infer_type_alias(type_alias.node(self.module())); } NodeWithScopeKind::ListComprehension(comprehension) => { - self.infer_list_comprehension_expression_scope(comprehension.node()); + self.infer_list_comprehension_expression_scope(comprehension.node(self.module())); } NodeWithScopeKind::SetComprehension(comprehension) => { - self.infer_set_comprehension_expression_scope(comprehension.node()); + self.infer_set_comprehension_expression_scope(comprehension.node(self.module())); } NodeWithScopeKind::DictComprehension(comprehension) => { - self.infer_dict_comprehension_expression_scope(comprehension.node()); + self.infer_dict_comprehension_expression_scope(comprehension.node(self.module())); } NodeWithScopeKind::GeneratorExpression(generator) => { - self.infer_generator_expression_scope(generator.node()); + self.infer_generator_expression_scope(generator.node(self.module())); } } @@ -823,7 +844,7 @@ impl<'db> TypeInferenceBuilder<'db> { if let DefinitionKind::Class(class) = definition.kind(self.db()) { ty.inner_type() .into_class_literal() - .map(|class_literal| (class_literal, class.node())) + .map(|class_literal| (class_literal, class.node(self.module()))) } else { None } @@ -1143,7 +1164,7 @@ impl<'db> TypeInferenceBuilder<'db> { // Check that the overloaded function has at least two overloads if let [single_overload] = overloads.as_ref() { - let function_node = function.node(self.db(), self.file()); + let function_node = function.node(self.db(), self.file(), self.module()); if let Some(builder) = self .context .report_lint(&INVALID_OVERLOAD, &function_node.name) @@ -1154,7 +1175,7 @@ impl<'db> TypeInferenceBuilder<'db> { )); diagnostic.annotate( self.context - .secondary(single_overload.focus_range(self.db())) + .secondary(single_overload.focus_range(self.db(), self.module())) .message(format_args!("Only one overload defined here")), ); } @@ -1169,7 +1190,8 @@ impl<'db> TypeInferenceBuilder<'db> { if let NodeWithScopeKind::Class(class_node_ref) = scope { let class = binding_type( self.db(), - self.index.expect_single_definition(class_node_ref.node()), + self.index + .expect_single_definition(class_node_ref.node(self.module())), ) .expect_class_literal(); @@ -1187,7 +1209,7 @@ impl<'db> TypeInferenceBuilder<'db> { } if implementation_required { - let function_node = function.node(self.db(), self.file()); + let function_node = function.node(self.db(), self.file(), self.module()); if let Some(builder) = self .context .report_lint(&INVALID_OVERLOAD, &function_node.name) @@ -1222,7 +1244,7 @@ impl<'db> TypeInferenceBuilder<'db> { continue; } - let function_node = function.node(self.db(), self.file()); + let function_node = function.node(self.db(), self.file(), self.module()); if let Some(builder) = self .context .report_lint(&INVALID_OVERLOAD, &function_node.name) @@ -1235,7 +1257,7 @@ impl<'db> TypeInferenceBuilder<'db> { for function in decorator_missing { diagnostic.annotate( self.context - .secondary(function.focus_range(self.db())) + .secondary(function.focus_range(self.db(), self.module())) .message(format_args!("Missing here")), ); } @@ -1251,7 +1273,7 @@ impl<'db> TypeInferenceBuilder<'db> { if !overload.has_known_decorator(self.db(), decorator) { continue; } - let function_node = function.node(self.db(), self.file()); + let function_node = function.node(self.db(), self.file(), self.module()); let Some(builder) = self .context .report_lint(&INVALID_OVERLOAD, &function_node.name) @@ -1264,7 +1286,7 @@ impl<'db> TypeInferenceBuilder<'db> { )); diagnostic.annotate( self.context - .secondary(implementation.focus_range(self.db())) + .secondary(implementation.focus_range(self.db(), self.module())) .message(format_args!("Implementation defined here")), ); } @@ -1277,7 +1299,7 @@ impl<'db> TypeInferenceBuilder<'db> { if !overload.has_known_decorator(self.db(), decorator) { continue; } - let function_node = function.node(self.db(), self.file()); + let function_node = function.node(self.db(), self.file(), self.module()); let Some(builder) = self .context .report_lint(&INVALID_OVERLOAD, &function_node.name) @@ -1290,7 +1312,7 @@ impl<'db> TypeInferenceBuilder<'db> { )); diagnostic.annotate( self.context - .secondary(first_overload.focus_range(self.db())) + .secondary(first_overload.focus_range(self.db(), self.module())) .message(format_args!("First overload defined here")), ); } @@ -1302,24 +1324,34 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_region_definition(&mut self, definition: Definition<'db>) { match definition.kind(self.db()) { DefinitionKind::Function(function) => { - self.infer_function_definition(function.node(), definition); + self.infer_function_definition(function.node(self.module()), definition); + } + DefinitionKind::Class(class) => { + self.infer_class_definition(class.node(self.module()), definition); } - DefinitionKind::Class(class) => self.infer_class_definition(class.node(), definition), DefinitionKind::TypeAlias(type_alias) => { - self.infer_type_alias_definition(type_alias.node(), definition); + self.infer_type_alias_definition(type_alias.node(self.module()), definition); } DefinitionKind::Import(import) => { - self.infer_import_definition(import.import(), import.alias(), definition); + self.infer_import_definition( + import.import(self.module()), + import.alias(self.module()), + definition, + ); } DefinitionKind::ImportFrom(import_from) => { self.infer_import_from_definition( - import_from.import(), - import_from.alias(), + import_from.import(self.module()), + import_from.alias(self.module()), definition, ); } DefinitionKind::StarImport(import) => { - self.infer_import_from_definition(import.import(), import.alias(), definition); + self.infer_import_from_definition( + import.import(self.module()), + import.alias(self.module()), + definition, + ); } DefinitionKind::Assignment(assignment) => { self.infer_assignment_definition(assignment, definition); @@ -1328,32 +1360,47 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_annotated_assignment_definition(annotated_assignment, definition); } DefinitionKind::AugmentedAssignment(augmented_assignment) => { - self.infer_augment_assignment_definition(augmented_assignment.node(), definition); + self.infer_augment_assignment_definition( + augmented_assignment.node(self.module()), + definition, + ); } DefinitionKind::For(for_statement_definition) => { self.infer_for_statement_definition(for_statement_definition, definition); } DefinitionKind::NamedExpression(named_expression) => { - self.infer_named_expression_definition(named_expression.node(), definition); + self.infer_named_expression_definition( + named_expression.node(self.module()), + definition, + ); } DefinitionKind::Comprehension(comprehension) => { self.infer_comprehension_definition(comprehension, definition); } DefinitionKind::VariadicPositionalParameter(parameter) => { - self.infer_variadic_positional_parameter_definition(parameter, definition); + self.infer_variadic_positional_parameter_definition( + parameter.node(self.module()), + definition, + ); } DefinitionKind::VariadicKeywordParameter(parameter) => { - self.infer_variadic_keyword_parameter_definition(parameter, definition); + self.infer_variadic_keyword_parameter_definition( + parameter.node(self.module()), + definition, + ); } DefinitionKind::Parameter(parameter_with_default) => { - self.infer_parameter_definition(parameter_with_default, definition); + self.infer_parameter_definition( + parameter_with_default.node(self.module()), + definition, + ); } DefinitionKind::WithItem(with_item_definition) => { self.infer_with_item_definition(with_item_definition, definition); } DefinitionKind::MatchPattern(match_pattern) => { self.infer_match_pattern_definition( - match_pattern.pattern(), + match_pattern.pattern(self.module()), match_pattern.index(), definition, ); @@ -1362,13 +1409,13 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_except_handler_definition(except_handler_definition, definition); } DefinitionKind::TypeVar(node) => { - self.infer_typevar_definition(node, definition); + self.infer_typevar_definition(node.node(self.module()), definition); } DefinitionKind::ParamSpec(node) => { - self.infer_paramspec_definition(node, definition); + self.infer_paramspec_definition(node.node(self.module()), definition); } DefinitionKind::TypeVarTuple(node) => { - self.infer_typevartuple_definition(node, definition); + self.infer_typevartuple_definition(node.node(self.module()), definition); } } } @@ -1384,8 +1431,10 @@ impl<'db> TypeInferenceBuilder<'db> { // implementation to allow this "split" to happen. match definition.kind(self.db()) { - DefinitionKind::Function(function) => self.infer_function_deferred(function.node()), - DefinitionKind::Class(class) => self.infer_class_deferred(class.node()), + DefinitionKind::Function(function) => { + self.infer_function_deferred(function.node(self.module())); + } + DefinitionKind::Class(class) => self.infer_class_deferred(class.node(self.module())), _ => {} } } @@ -1393,10 +1442,10 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_region_expression(&mut self, expression: Expression<'db>) { match expression.kind(self.db()) { ExpressionKind::Normal => { - self.infer_expression_impl(expression.node_ref(self.db())); + self.infer_expression_impl(expression.node_ref(self.db(), self.module())); } ExpressionKind::TypeExpression => { - self.infer_type_expression(expression.node_ref(self.db())); + self.infer_type_expression(expression.node_ref(self.db(), self.module())); } } } @@ -1441,7 +1490,7 @@ impl<'db> TypeInferenceBuilder<'db> { debug_assert!( binding .kind(self.db()) - .category(self.context.in_stub()) + .category(self.context.in_stub(), self.module()) .is_binding() ); @@ -1555,7 +1604,7 @@ impl<'db> TypeInferenceBuilder<'db> { debug_assert!( declaration .kind(self.db()) - .category(self.context.in_stub()) + .category(self.context.in_stub(), self.module()) .is_declaration() ); let use_def = self.index.use_def_map(declaration.file_scope(self.db())); @@ -1601,13 +1650,13 @@ impl<'db> TypeInferenceBuilder<'db> { debug_assert!( definition .kind(self.db()) - .category(self.context.in_stub()) + .category(self.context.in_stub(), self.module()) .is_binding() ); debug_assert!( definition .kind(self.db()) - .category(self.context.in_stub()) + .category(self.context.in_stub(), self.module()) .is_declaration() ); @@ -1763,7 +1812,7 @@ impl<'db> TypeInferenceBuilder<'db> { _ => return None, }; - let class_stmt = class_scope.node().as_class()?; + let class_stmt = class_scope.node().as_class(self.module())?; let class_definition = self.index.expect_single_definition(class_stmt); binding_type(self.db(), class_definition).into_class_literal() } @@ -1784,17 +1833,21 @@ impl<'db> TypeInferenceBuilder<'db> { return false; }; - node_ref.decorator_list.iter().any(|decorator| { - let decorator_type = self.file_expression_type(&decorator.expression); + node_ref + .node(self.module()) + .decorator_list + .iter() + .any(|decorator| { + let decorator_type = self.file_expression_type(&decorator.expression); - match decorator_type { - Type::FunctionLiteral(function) => matches!( - function.known(self.db()), - Some(KnownFunction::Overload | KnownFunction::AbstractMethod) - ), - _ => false, - } - }) + match decorator_type { + Type::FunctionLiteral(function) => matches!( + function.known(self.db()), + Some(KnownFunction::Overload | KnownFunction::AbstractMethod) + ), + _ => false, + } + }) } fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) { @@ -2558,8 +2611,8 @@ impl<'db> TypeInferenceBuilder<'db> { with_item: &WithItemDefinitionKind<'db>, definition: Definition<'db>, ) { - let context_expr = with_item.context_expr(); - let target = with_item.target(); + let context_expr = with_item.context_expr(self.module()); + let target = with_item.target(self.module()); let context_expr_ty = self.infer_standalone_expression(context_expr); @@ -2707,12 +2760,12 @@ impl<'db> TypeInferenceBuilder<'db> { definition: Definition<'db>, ) { let symbol_ty = self.infer_exception( - except_handler_definition.handled_exceptions(), + except_handler_definition.handled_exceptions(self.module()), except_handler_definition.is_star(), ); self.add_binding( - except_handler_definition.node().into(), + except_handler_definition.node(self.module()).into(), definition, symbol_ty, ); @@ -3001,7 +3054,7 @@ impl<'db> TypeInferenceBuilder<'db> { /// `target`. fn infer_target(&mut self, target: &ast::Expr, value: &ast::Expr, infer_value_expr: F) where - F: Fn(&mut TypeInferenceBuilder<'db>, &ast::Expr) -> Type<'db>, + F: Fn(&mut TypeInferenceBuilder<'db, '_>, &ast::Expr) -> Type<'db>, { let assigned_ty = match target { ast::Expr::Name(_) => None, @@ -3542,8 +3595,8 @@ impl<'db> TypeInferenceBuilder<'db> { assignment: &AssignmentDefinitionKind<'db>, definition: Definition<'db>, ) { - let value = assignment.value(); - let target = assignment.target(); + let value = assignment.value(self.module()); + let target = assignment.target(self.module()); let value_ty = self.infer_standalone_expression(value); @@ -3625,9 +3678,9 @@ impl<'db> TypeInferenceBuilder<'db> { assignment: &'db AnnotatedAssignmentDefinitionKind, definition: Definition<'db>, ) { - let annotation = assignment.annotation(); - let target = assignment.target(); - let value = assignment.value(); + let annotation = assignment.annotation(self.module()); + let target = assignment.target(self.module()); + let value = assignment.value(self.module()); let mut declared_ty = self.infer_annotation_expression( annotation, @@ -3858,8 +3911,8 @@ impl<'db> TypeInferenceBuilder<'db> { for_stmt: &ForStmtDefinitionKind<'db>, definition: Definition<'db>, ) { - let iterable = for_stmt.iterable(); - let target = for_stmt.target(); + let iterable = for_stmt.iterable(self.module()); + let target = for_stmt.target(self.module()); let iterable_type = self.infer_standalone_expression(iterable); @@ -3967,7 +4020,7 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_import_definition( &mut self, node: &ast::StmtImport, - alias: &'db ast::Alias, + alias: &ast::Alias, definition: Definition<'db>, ) { let ast::Alias { @@ -4146,7 +4199,7 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_import_from_definition( &mut self, - import_from: &'db ast::StmtImportFrom, + import_from: &ast::StmtImportFrom, alias: &ast::Alias, definition: Definition<'db>, ) { @@ -4804,7 +4857,11 @@ impl<'db> TypeInferenceBuilder<'db> { // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `[... for a.x in not_iterable] if is_first { - infer_same_file_expression_type(builder.db(), builder.index.expression(iter_expr)) + infer_same_file_expression_type( + builder.db(), + builder.index.expression(iter_expr), + builder.module(), + ) } else { builder.infer_standalone_expression(iter_expr) } @@ -4820,8 +4877,8 @@ impl<'db> TypeInferenceBuilder<'db> { comprehension: &ComprehensionDefinitionKind<'db>, definition: Definition<'db>, ) { - let iterable = comprehension.iterable(); - let target = comprehension.target(); + let iterable = comprehension.iterable(self.module()); + let target = comprehension.target(self.module()); let expression = self.index.expression(iterable); let result = infer_expression_types(self.db(), expression); @@ -5009,8 +5066,10 @@ impl<'db> TypeInferenceBuilder<'db> { /// Returns `None` if the scope is not function-like, or has no parameters. fn first_param_type_in_scope(&self, scope: ScopeId) -> Option> { let first_param = match scope.node(self.db()) { - NodeWithScopeKind::Function(f) => f.parameters.iter().next(), - NodeWithScopeKind::Lambda(l) => l.parameters.as_ref()?.iter().next(), + NodeWithScopeKind::Function(f) => f.node(self.module()).parameters.iter().next(), + NodeWithScopeKind::Lambda(l) => { + l.node(self.module()).parameters.as_ref()?.iter().next() + } _ => None, }?; @@ -5371,6 +5430,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), self.index, scope, + self.module(), ) else { overload.set_return_type(Type::unknown()); BoundSuperError::UnavailableImplicitArguments @@ -5435,7 +5495,11 @@ impl<'db> TypeInferenceBuilder<'db> { let Some(target) = assigned_to.as_ref().and_then(|assigned_to| { - match assigned_to.node().targets.as_slice() { + match assigned_to + .node(self.module()) + .targets + .as_slice() + { [ast::Expr::Name(target)] => Some(target), _ => None, } @@ -5605,7 +5669,11 @@ impl<'db> TypeInferenceBuilder<'db> { let containing_assignment = assigned_to.as_ref().and_then(|assigned_to| { - match assigned_to.node().targets.as_slice() { + match assigned_to + .node(self.module()) + .targets + .as_slice() + { [ast::Expr::Name(target)] => Some( self.index.expect_single_definition(target), ), @@ -8125,7 +8193,7 @@ impl<'db> TypeInferenceBuilder<'db> { } /// Annotation expressions. -impl<'db> TypeInferenceBuilder<'db> { +impl<'db> TypeInferenceBuilder<'db, '_> { /// Infer the type of an annotation expression with the given [`DeferredExpressionState`]. fn infer_annotation_expression( &mut self, @@ -8314,7 +8382,7 @@ impl<'db> TypeInferenceBuilder<'db> { } /// Type expressions -impl<'db> TypeInferenceBuilder<'db> { +impl<'db> TypeInferenceBuilder<'db, '_> { /// Infer the type of a type expression. fn infer_type_expression(&mut self, expression: &ast::Expr) -> Type<'db> { let ty = self.infer_type_expression_no_store(expression); @@ -9340,10 +9408,10 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_literal_parameter_type<'ast>( + fn infer_literal_parameter_type<'param>( &mut self, - parameters: &'ast ast::Expr, - ) -> Result, Vec<&'ast ast::Expr>> { + parameters: &'param ast::Expr, + ) -> Result, Vec<&'param ast::Expr>> { Ok(match parameters { // TODO handle type aliases ast::Expr::Subscript(ast::ExprSubscript { value, slice, .. }) => { @@ -9723,6 +9791,7 @@ mod tests { symbol_name: &str, ) -> Place<'db> { let file = system_path_to_file(db, file_name).expect("file to exist"); + let module = parsed_module(db, file).load(db); let index = semantic_index(db, file); let mut file_scope_id = FileScopeId::global(); let mut scope = file_scope_id.to_scope_id(db, file); @@ -9733,7 +9802,7 @@ mod tests { .unwrap_or_else(|| panic!("scope of {expected_scope_name}")) .0; scope = file_scope_id.to_scope_id(db, file); - assert_eq!(scope.name(db), *expected_scope_name); + assert_eq!(scope.name(db, &module), *expected_scope_name); } symbol(db, scope, symbol_name).place @@ -10087,7 +10156,7 @@ mod tests { fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { fn x_rhs_expression(db: &TestDb) -> Expression<'_> { let file_main = system_path_to_file(db, "/src/main.py").unwrap(); - let ast = parsed_module(db, file_main); + let ast = parsed_module(db, file_main).load(db); // Get the second statement in `main.py` (x = …) and extract the expression // node on the right-hand side: let x_rhs_node = &ast.syntax().body[1].as_assign_stmt().unwrap().value; @@ -10170,7 +10239,7 @@ mod tests { fn dependency_own_instance_member() -> anyhow::Result<()> { fn x_rhs_expression(db: &TestDb) -> Expression<'_> { let file_main = system_path_to_file(db, "/src/main.py").unwrap(); - let ast = parsed_module(db, file_main); + let ast = parsed_module(db, file_main).load(db); // Get the second statement in `main.py` (x = …) and extract the expression // node on the right-hand side: let x_rhs_node = &ast.syntax().body[1].as_assign_stmt().unwrap().value; diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index e6cc62c383..a0e4ce3624 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -13,6 +13,7 @@ use crate::types::{ infer_expression_types, }; +use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_stdlib::identifiers::is_identifier; use itertools::Itertools; @@ -73,7 +74,8 @@ fn all_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternPredicate<'db>, ) -> Option> { - NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), true).finish() + let module = parsed_module(db.upcast(), pattern.file(db)).load(db.upcast()); + NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), true).finish() } #[salsa::tracked( @@ -85,7 +87,9 @@ fn all_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, ) -> Option> { - NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), true).finish() + let module = parsed_module(db.upcast(), expression.file(db)).load(db.upcast()); + NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), true) + .finish() } #[salsa::tracked( @@ -97,7 +101,9 @@ fn all_negative_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, ) -> Option> { - NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), false).finish() + let module = parsed_module(db.upcast(), expression.file(db)).load(db.upcast()); + NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), false) + .finish() } #[salsa::tracked(returns(as_ref))] @@ -105,7 +111,8 @@ fn all_negative_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternPredicate<'db>, ) -> Option> { - NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), false).finish() + let module = parsed_module(db.upcast(), pattern.file(db)).load(db.upcast()); + NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), false).finish() } #[expect(clippy::ref_option)] @@ -251,16 +258,23 @@ fn expr_name(expr: &ast::Expr) -> Option<&ast::name::Name> { } } -struct NarrowingConstraintsBuilder<'db> { +struct NarrowingConstraintsBuilder<'db, 'ast> { db: &'db dyn Db, + module: &'ast ParsedModuleRef, predicate: PredicateNode<'db>, is_positive: bool, } -impl<'db> NarrowingConstraintsBuilder<'db> { - fn new(db: &'db dyn Db, predicate: PredicateNode<'db>, is_positive: bool) -> Self { +impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { + fn new( + db: &'db dyn Db, + module: &'ast ParsedModuleRef, + predicate: PredicateNode<'db>, + is_positive: bool, + ) -> Self { Self { db, + module, predicate, is_positive, } @@ -289,7 +303,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> { expression: Expression<'db>, is_positive: bool, ) -> Option> { - let expression_node = expression.node_ref(self.db); + let expression_node = expression.node_ref(self.db, self.module); self.evaluate_expression_node_predicate(expression_node, expression, is_positive) } @@ -775,7 +789,8 @@ impl<'db> NarrowingConstraintsBuilder<'db> { subject: Expression<'db>, singleton: ast::Singleton, ) -> Option> { - let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id); + let symbol = self + .expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id); let ty = match singleton { ast::Singleton::None => Type::none(self.db), @@ -790,8 +805,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> { subject: Expression<'db>, cls: Expression<'db>, ) -> Option> { - let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id); - let ty = infer_same_file_expression_type(self.db, cls).to_instance(self.db)?; + let symbol = self + .expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id); + let ty = infer_same_file_expression_type(self.db, cls, self.module).to_instance(self.db)?; Some(NarrowingConstraints::from_iter([(symbol, ty)])) } @@ -801,8 +817,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> { subject: Expression<'db>, value: Expression<'db>, ) -> Option> { - let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id); - let ty = infer_same_file_expression_type(self.db, value); + let symbol = self + .expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id); + let ty = infer_same_file_expression_type(self.db, value, self.module); Some(NarrowingConstraints::from_iter([(symbol, ty)])) } diff --git a/crates/ty_python_semantic/src/types/unpacker.rs b/crates/ty_python_semantic/src/types/unpacker.rs index f06ad7c517..af40005580 100644 --- a/crates/ty_python_semantic/src/types/unpacker.rs +++ b/crates/ty_python_semantic/src/types/unpacker.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use std::cmp::Ordering; +use ruff_db::parsed::ParsedModuleRef; use rustc_hash::FxHashMap; use ruff_python_ast::{self as ast, AnyNodeRef}; @@ -16,21 +17,22 @@ use super::diagnostic::INVALID_ASSIGNMENT; use super::{KnownClass, TupleType, UnionType}; /// Unpacks the value expression type to their respective targets. -pub(crate) struct Unpacker<'db> { - context: InferContext<'db>, +pub(crate) struct Unpacker<'db, 'ast> { + context: InferContext<'db, 'ast>, target_scope: ScopeId<'db>, value_scope: ScopeId<'db>, targets: FxHashMap>, } -impl<'db> Unpacker<'db> { +impl<'db, 'ast> Unpacker<'db, 'ast> { pub(crate) fn new( db: &'db dyn Db, target_scope: ScopeId<'db>, value_scope: ScopeId<'db>, + module: &'ast ParsedModuleRef, ) -> Self { Self { - context: InferContext::new(db, target_scope), + context: InferContext::new(db, target_scope, module), targets: FxHashMap::default(), target_scope, value_scope, @@ -41,6 +43,10 @@ impl<'db> Unpacker<'db> { self.context.db() } + fn module(&self) -> &'ast ParsedModuleRef { + self.context.module() + } + /// Unpack the value to the target expression. pub(crate) fn unpack(&mut self, target: &ast::Expr, value: UnpackValue<'db>) { debug_assert!( @@ -48,15 +54,16 @@ impl<'db> Unpacker<'db> { "Unpacking target must be a list or tuple expression" ); - let value_type = infer_expression_types(self.db(), value.expression()) - .expression_type(value.scoped_expression_id(self.db(), self.value_scope)); + let value_type = infer_expression_types(self.db(), value.expression()).expression_type( + value.scoped_expression_id(self.db(), self.value_scope, self.module()), + ); let value_type = match value.kind() { UnpackKind::Assign => { if self.context.in_stub() && value .expression() - .node_ref(self.db()) + .node_ref(self.db(), self.module()) .is_ellipsis_literal_expr() { Type::unknown() @@ -65,22 +72,34 @@ impl<'db> Unpacker<'db> { } } UnpackKind::Iterable => value_type.try_iterate(self.db()).unwrap_or_else(|err| { - err.report_diagnostic(&self.context, value_type, value.as_any_node_ref(self.db())); + err.report_diagnostic( + &self.context, + value_type, + value.as_any_node_ref(self.db(), self.module()), + ); err.fallback_element_type(self.db()) }), UnpackKind::ContextManager => value_type.try_enter(self.db()).unwrap_or_else(|err| { - err.report_diagnostic(&self.context, value_type, value.as_any_node_ref(self.db())); + err.report_diagnostic( + &self.context, + value_type, + value.as_any_node_ref(self.db(), self.module()), + ); err.fallback_enter_type(self.db()) }), }; - self.unpack_inner(target, value.as_any_node_ref(self.db()), value_type); + self.unpack_inner( + target, + value.as_any_node_ref(self.db(), self.module()), + value_type, + ); } fn unpack_inner( &mut self, target: &ast::Expr, - value_expr: AnyNodeRef<'db>, + value_expr: AnyNodeRef<'_>, value_ty: Type<'db>, ) { match target { diff --git a/crates/ty_python_semantic/src/unpack.rs b/crates/ty_python_semantic/src/unpack.rs index 0e34dbe765..3dda4dd2f2 100644 --- a/crates/ty_python_semantic/src/unpack.rs +++ b/crates/ty_python_semantic/src/unpack.rs @@ -1,4 +1,5 @@ use ruff_db::files::File; +use ruff_db::parsed::ParsedModuleRef; use ruff_python_ast::{self as ast, AnyNodeRef}; use ruff_text_size::{Ranged, TextRange}; @@ -37,9 +38,9 @@ pub(crate) struct Unpack<'db> { /// The target expression that is being unpacked. For example, in `(a, b) = (1, 2)`, the target /// expression is `(a, b)`. #[no_eq] - #[returns(deref)] #[tracked] - pub(crate) target: AstNodeRef, + #[returns(ref)] + pub(crate) _target: AstNodeRef, /// The ingredient representing the value expression of the unpacking. For example, in /// `(a, b) = (1, 2)`, the value expression is `(1, 2)`. @@ -49,6 +50,14 @@ pub(crate) struct Unpack<'db> { } impl<'db> Unpack<'db> { + pub(crate) fn target<'ast>( + self, + db: &'db dyn Db, + parsed: &'ast ParsedModuleRef, + ) -> &'ast ast::Expr { + self._target(db).node(parsed) + } + /// Returns the scope in which the unpack value expression belongs. /// /// The scope in which the target and value expression belongs to are usually the same @@ -65,8 +74,8 @@ impl<'db> Unpack<'db> { } /// Returns the range of the unpack target expression. - pub(crate) fn range(self, db: &'db dyn Db) -> TextRange { - self.target(db).range() + pub(crate) fn range(self, db: &'db dyn Db, module: &ParsedModuleRef) -> TextRange { + self.target(db, module).range() } } @@ -94,15 +103,20 @@ impl<'db> UnpackValue<'db> { self, db: &'db dyn Db, scope: ScopeId<'db>, + module: &ParsedModuleRef, ) -> ScopedExpressionId { self.expression() - .node_ref(db) + .node_ref(db, module) .scoped_expression_id(db, scope) } /// Returns the expression as an [`AnyNodeRef`]. - pub(crate) fn as_any_node_ref(self, db: &'db dyn Db) -> AnyNodeRef<'db> { - self.expression().node_ref(db).into() + pub(crate) fn as_any_node_ref<'ast>( + self, + db: &'db dyn Db, + module: &'ast ParsedModuleRef, + ) -> AnyNodeRef<'ast> { + self.expression().node_ref(db, module).into() } pub(crate) const fn kind(self) -> UnpackKind { diff --git a/crates/ty_test/src/assertion.rs b/crates/ty_test/src/assertion.rs index 38ff0494d0..fd567dc60a 100644 --- a/crates/ty_test/src/assertion.rs +++ b/crates/ty_test/src/assertion.rs @@ -57,7 +57,7 @@ impl InlineFileAssertions { pub(crate) fn from_file(db: &Db, file: File) -> Self { let source = source_text(db, file); let lines = line_index(db, file); - let parsed = parsed_module(db, file); + let parsed = parsed_module(db, file).load(db); let comment_ranges = CommentRanges::from(parsed.tokens()); Self { comment_ranges, diff --git a/crates/ty_test/src/lib.rs b/crates/ty_test/src/lib.rs index ba24781106..4749f477bd 100644 --- a/crates/ty_test/src/lib.rs +++ b/crates/ty_test/src/lib.rs @@ -294,7 +294,7 @@ fn run_test( let failures: Failures = test_files .into_iter() .filter_map(|test_file| { - let parsed = parsed_module(db, test_file.file); + let parsed = parsed_module(db, test_file.file).load(db); let mut diagnostics: Vec = parsed .errors() diff --git a/crates/ty_wasm/src/lib.rs b/crates/ty_wasm/src/lib.rs index 20dde86bc6..e20ab2a803 100644 --- a/crates/ty_wasm/src/lib.rs +++ b/crates/ty_wasm/src/lib.rs @@ -201,7 +201,7 @@ impl Workspace { /// Returns the parsed AST for `path` pub fn parsed(&self, file_id: &FileHandle) -> Result { - let parsed = ruff_db::parsed::parsed_module(&self.db, file_id.file); + let parsed = ruff_db::parsed::parsed_module(&self.db, file_id.file).load(&self.db); Ok(format!("{:#?}", parsed.syntax())) } @@ -212,7 +212,7 @@ impl Workspace { /// Returns the token stream for `path` serialized as a string. pub fn tokens(&self, file_id: &FileHandle) -> Result { - let parsed = ruff_db::parsed::parsed_module(&self.db, file_id.file); + let parsed = ruff_db::parsed::parsed_module(&self.db, file_id.file).load(&self.db); Ok(format!("{:#?}", parsed.tokens())) }