diff --git a/Cargo.lock b/Cargo.lock index 7460c790b3..82908741fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -503,6 +503,11 @@ name = "countme" version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7704b5fdd17b18ae31c4c1da5a2e0305a2bf17b5249300a9ee9ed7b72114c636" +dependencies = [ + "dashmap 5.5.3", + "once_cell", + "rustc-hash 1.1.0", +] [[package]] name = "crc32fast" @@ -1853,27 +1858,20 @@ dependencies = [ [[package]] name = "red_knot" -version = "0.1.0" +version = "0.0.0" dependencies = [ "anyhow", - "bitflags 2.6.0", + "countme", "crossbeam", "ctrlc", - "dashmap 6.0.1", - "hashbrown 0.14.5", - "indexmap", - "is-macro", "notify", - "parking_lot", "rayon", "red_knot_module_resolver", - "ruff_index", - "ruff_notebook", + "red_knot_python_semantic", + "ruff_db", "ruff_python_ast", - "ruff_python_parser", - "ruff_text_size", "rustc-hash 2.0.0", - "tempfile", + "salsa", "tracing", "tracing-subscriber", "tracing-tree", diff --git a/Cargo.toml b/Cargo.toml index d2c8b1d6a2..a563af269b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ ruff_text_size = { path = "crates/ruff_text_size" } ruff_workspace = { path = "crates/ruff_workspace" } red_knot_module_resolver = { path = "crates/red_knot_module_resolver" } +red_knot_python_semantic = { path = "crates/red_knot_python_semantic" } aho-corasick = { version = "1.1.3" } annotate-snippets = { version = "0.9.2", features = ["color"] } @@ -96,7 +97,6 @@ once_cell = { version = "1.19.0" } path-absolutize = { version = "3.1.1" } path-slash = { version = "0.2.1" } pathdiff = { version = "0.2.1" } -parking_lot = "0.12.1" pep440_rs = { version = "0.6.0", features = ["serde"] } pretty_assertions = "1.3.0" proc-macro2 = { version = "1.0.79" } diff --git a/crates/red_knot/Cargo.toml b/crates/red_knot/Cargo.toml index 6ac07c1777..c155e627fa 100644 --- a/crates/red_knot/Cargo.toml +++ b/crates/red_knot/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "red_knot" -version = "0.1.0" +version = 
"0.0.0" edition.workspace = true rust-version.workspace = true homepage.workspace = true @@ -8,36 +8,29 @@ documentation.workspace = true repository.workspace = true authors.workspace = true license.workspace = true +default-run = "red_knot" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] red_knot_module_resolver = { workspace = true } +red_knot_python_semantic = { workspace = true } -ruff_python_parser = { workspace = true } +ruff_db = { workspace = true } ruff_python_ast = { workspace = true } -ruff_text_size = { workspace = true } -ruff_index = { workspace = true } -ruff_notebook = { workspace = true } anyhow = { workspace = true } -bitflags = { workspace = true } +countme = { workspace = true, features = ["enable"] } crossbeam = { workspace = true } ctrlc = { version = "3.4.4" } -dashmap = { workspace = true } -hashbrown = { workspace = true } -indexmap = { workspace = true } -is-macro = { workspace = true } notify = { workspace = true } -parking_lot = { workspace = true } rayon = { workspace = true } rustc-hash = { workspace = true } +salsa = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } tracing-tree = { workspace = true } -[dev-dependencies] -tempfile = { workspace = true } [lints] workspace = true diff --git a/crates/red_knot/src/ast_ids.rs b/crates/red_knot/src/ast_ids.rs deleted file mode 100644 index 5d88bf2f46..0000000000 --- a/crates/red_knot/src/ast_ids.rs +++ /dev/null @@ -1,418 +0,0 @@ -use std::any::type_name; -use std::fmt::{Debug, Formatter}; -use std::hash::{Hash, Hasher}; -use std::marker::PhantomData; - -use rustc_hash::FxHashMap; - -use ruff_index::{Idx, IndexVec}; -use ruff_python_ast::visitor::source_order; -use ruff_python_ast::visitor::source_order::{SourceOrderVisitor, TraversalSignal}; -use ruff_python_ast::{ - AnyNodeRef, AstNode, ExceptHandler, ExceptHandlerExceptHandler, Expr, MatchCase, ModModule, - NodeKind, Parameter, 
Stmt, StmtAnnAssign, StmtAssign, StmtAugAssign, StmtClassDef, - StmtFunctionDef, StmtGlobal, StmtImport, StmtImportFrom, StmtNonlocal, StmtTypeAlias, - TypeParam, TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple, WithItem, -}; -use ruff_text_size::{Ranged, TextRange}; - -/// A type agnostic ID that uniquely identifies an AST node in a file. -#[ruff_index::newtype_index] -pub struct AstId; - -/// A typed ID that uniquely identifies an AST node in a file. -/// -/// This is different from [`AstId`] in that it is a combination of ID and the type of the node the ID identifies. -/// Typing the ID prevents mixing IDs of different node types and allows to restrict the API to only accept -/// nodes for which an ID has been created (not all AST nodes get an ID). -pub struct TypedAstId { - erased: AstId, - _marker: PhantomData N>, -} - -impl TypedAstId { - /// Upcasts this ID from a more specific node type to a more general node type. - pub fn upcast(self) -> TypedAstId - where - N: Into, - { - TypedAstId { - erased: self.erased, - _marker: PhantomData, - } - } -} - -impl Copy for TypedAstId {} -impl Clone for TypedAstId { - fn clone(&self) -> Self { - *self - } -} - -impl PartialEq for TypedAstId { - fn eq(&self, other: &Self) -> bool { - self.erased == other.erased - } -} - -impl Eq for TypedAstId {} -impl Hash for TypedAstId { - fn hash(&self, state: &mut H) { - self.erased.hash(state); - } -} - -impl Debug for TypedAstId { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_tuple("TypedAstId") - .field(&self.erased) - .field(&type_name::()) - .finish() - } -} - -pub struct AstIds { - ids: IndexVec, - reverse: FxHashMap, -} - -impl AstIds { - // TODO rust analyzer doesn't allocate an ID for every node. It only allocates ids for - // nodes with a corresponding HIR element, that is nodes that are definitions. - pub fn from_module(module: &ModModule) -> Self { - let mut visitor = AstIdsVisitor::default(); - - // TODO: visit_module? 
- // Make sure we visit the root - visitor.create_id(module); - visitor.visit_body(&module.body); - - while let Some(deferred) = visitor.deferred.pop() { - match deferred { - DeferredNode::FunctionDefinition(def) => { - def.visit_source_order(&mut visitor); - } - DeferredNode::ClassDefinition(def) => def.visit_source_order(&mut visitor), - } - } - - AstIds { - ids: visitor.ids, - reverse: visitor.reverse, - } - } - - /// Returns the ID to the root node. - pub fn root(&self) -> NodeKey { - self.ids[AstId::new(0)] - } - - /// Returns the [`TypedAstId`] for a node. - pub fn ast_id(&self, node: &N) -> TypedAstId { - let key = node.syntax_node_key(); - TypedAstId { - erased: self.reverse.get(&key).copied().unwrap(), - _marker: PhantomData, - } - } - - /// Returns the [`TypedAstId`] for the node identified with the given [`TypedNodeKey`]. - pub fn ast_id_for_key(&self, node: &TypedNodeKey) -> TypedAstId { - let ast_id = self.ast_id_for_node_key(node.inner); - - TypedAstId { - erased: ast_id, - _marker: PhantomData, - } - } - - /// Returns the untyped [`AstId`] for the node identified by the given `node` key. - pub fn ast_id_for_node_key(&self, node: NodeKey) -> AstId { - self.reverse - .get(&node) - .copied() - .expect("Can't find node in AstIds map.") - } - - /// Returns the [`TypedNodeKey`] for the node identified by the given [`TypedAstId`]. 
- pub fn key(&self, id: TypedAstId) -> TypedNodeKey { - let syntax_key = self.ids[id.erased]; - - TypedNodeKey::new(syntax_key).unwrap() - } - - pub fn node_key(&self, id: TypedAstId) -> NodeKey { - self.ids[id.erased] - } -} - -impl std::fmt::Debug for AstIds { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let mut map = f.debug_map(); - for (key, value) in self.ids.iter_enumerated() { - map.entry(&key, &value); - } - - map.finish() - } -} - -impl PartialEq for AstIds { - fn eq(&self, other: &Self) -> bool { - self.ids == other.ids - } -} - -impl Eq for AstIds {} - -#[derive(Default)] -struct AstIdsVisitor<'a> { - ids: IndexVec, - reverse: FxHashMap, - deferred: Vec>, -} - -impl<'a> AstIdsVisitor<'a> { - fn create_id(&mut self, node: &A) { - let node_key = node.syntax_node_key(); - - let id = self.ids.push(node_key); - self.reverse.insert(node_key, id); - } -} - -impl<'a> SourceOrderVisitor<'a> for AstIdsVisitor<'a> { - fn visit_stmt(&mut self, stmt: &'a Stmt) { - match stmt { - Stmt::FunctionDef(def) => { - self.create_id(def); - self.deferred.push(DeferredNode::FunctionDefinition(def)); - return; - } - // TODO defer visiting the assignment body, type alias parameters etc? 
- Stmt::ClassDef(def) => { - self.create_id(def); - self.deferred.push(DeferredNode::ClassDefinition(def)); - return; - } - Stmt::Expr(_) => { - // Skip - return; - } - Stmt::Return(_) => {} - Stmt::Delete(_) => {} - Stmt::Assign(assignment) => self.create_id(assignment), - Stmt::AugAssign(assignment) => { - self.create_id(assignment); - } - Stmt::AnnAssign(assignment) => self.create_id(assignment), - Stmt::TypeAlias(assignment) => self.create_id(assignment), - Stmt::For(_) => {} - Stmt::While(_) => {} - Stmt::If(_) => {} - Stmt::With(_) => {} - Stmt::Match(_) => {} - Stmt::Raise(_) => {} - Stmt::Try(_) => {} - Stmt::Assert(_) => {} - Stmt::Import(import) => self.create_id(import), - Stmt::ImportFrom(import_from) => self.create_id(import_from), - Stmt::Global(global) => self.create_id(global), - Stmt::Nonlocal(non_local) => self.create_id(non_local), - Stmt::Pass(_) => {} - Stmt::Break(_) => {} - Stmt::Continue(_) => {} - Stmt::IpyEscapeCommand(_) => {} - } - - source_order::walk_stmt(self, stmt); - } - - fn visit_expr(&mut self, _expr: &'a Expr) {} - - fn visit_parameter(&mut self, parameter: &'a Parameter) { - self.create_id(parameter); - source_order::walk_parameter(self, parameter); - } - - fn visit_except_handler(&mut self, except_handler: &'a ExceptHandler) { - match except_handler { - ExceptHandler::ExceptHandler(except_handler) => { - self.create_id(except_handler); - } - } - - source_order::walk_except_handler(self, except_handler); - } - - fn visit_with_item(&mut self, with_item: &'a WithItem) { - self.create_id(with_item); - source_order::walk_with_item(self, with_item); - } - - fn visit_match_case(&mut self, match_case: &'a MatchCase) { - self.create_id(match_case); - source_order::walk_match_case(self, match_case); - } - - fn visit_type_param(&mut self, type_param: &'a TypeParam) { - self.create_id(type_param); - } -} - -enum DeferredNode<'a> { - FunctionDefinition(&'a StmtFunctionDef), - ClassDefinition(&'a StmtClassDef), -} - -#[derive(Copy, Clone, 
Debug, Eq, PartialEq, Hash)] -pub struct TypedNodeKey { - /// The type erased node key. - inner: NodeKey, - _marker: PhantomData N>, -} - -impl TypedNodeKey { - pub fn from_node(node: &N) -> Self { - let inner = NodeKey::from_node(node.as_any_node_ref()); - Self { - inner, - _marker: PhantomData, - } - } - - pub fn new(node_key: NodeKey) -> Option { - N::can_cast(node_key.kind).then_some(TypedNodeKey { - inner: node_key, - _marker: PhantomData, - }) - } - - pub fn resolve<'a>(&self, root: AnyNodeRef<'a>) -> Option> { - let node_ref = self.inner.resolve(root)?; - - Some(N::cast_ref(node_ref).unwrap()) - } - - pub fn resolve_unwrap<'a>(&self, root: AnyNodeRef<'a>) -> N::Ref<'a> { - self.resolve(root).expect("node should resolve") - } - - pub fn erased(&self) -> &NodeKey { - &self.inner - } -} - -struct FindNodeKeyVisitor<'a> { - key: NodeKey, - result: Option>, -} - -impl<'a> SourceOrderVisitor<'a> for FindNodeKeyVisitor<'a> { - fn enter_node(&mut self, node: AnyNodeRef<'a>) -> TraversalSignal { - if self.result.is_some() { - return TraversalSignal::Skip; - } - - if node.range() == self.key.range && node.kind() == self.key.kind { - self.result = Some(node); - TraversalSignal::Skip - } else if node.range().contains_range(self.key.range) { - TraversalSignal::Traverse - } else { - TraversalSignal::Skip - } - } - - fn visit_body(&mut self, body: &'a [Stmt]) { - // TODO it would be more efficient to use binary search instead of linear - for stmt in body { - if stmt.range().start() > self.key.range.end() { - break; - } - - self.visit_stmt(stmt); - } - } -} - -// TODO an alternative to this is to have a `NodeId` on each node (in increasing order depending on the position). -// This would allow to reduce the size of this to a u32. -// What would be nice if we could use an `Arc::weak_ref` here but that only works if we use -// `Arc` internally -// TODO: Implement the logic to resolve a node, given a db (and the correct file). 
-#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] -pub struct NodeKey { - kind: NodeKind, - range: TextRange, -} - -impl NodeKey { - pub fn from_node(node: AnyNodeRef) -> Self { - NodeKey { - kind: node.kind(), - range: node.range(), - } - } - pub fn resolve<'a>(&self, root: AnyNodeRef<'a>) -> Option> { - // We need to do a binary search here. Only traverse into a node if the range is withint the node - let mut visitor = FindNodeKeyVisitor { - key: *self, - result: None, - }; - - if visitor.enter_node(root) == TraversalSignal::Traverse { - root.visit_preorder(&mut visitor); - } - - visitor.result - } -} - -/// Marker trait implemented by AST nodes for which we extract the `AstId`. -pub trait HasAstId: AstNode { - fn node_key(&self) -> TypedNodeKey - where - Self: Sized, - { - TypedNodeKey { - inner: self.syntax_node_key(), - _marker: PhantomData, - } - } - - fn syntax_node_key(&self) -> NodeKey { - NodeKey { - kind: self.as_any_node_ref().kind(), - range: self.range(), - } - } -} - -impl HasAstId for StmtFunctionDef {} -impl HasAstId for StmtClassDef {} -impl HasAstId for StmtAnnAssign {} -impl HasAstId for StmtAugAssign {} -impl HasAstId for StmtAssign {} -impl HasAstId for StmtTypeAlias {} - -impl HasAstId for ModModule {} - -impl HasAstId for StmtImport {} - -impl HasAstId for StmtImportFrom {} - -impl HasAstId for Parameter {} - -impl HasAstId for TypeParam {} -impl HasAstId for Stmt {} -impl HasAstId for TypeParamTypeVar {} -impl HasAstId for TypeParamTypeVarTuple {} -impl HasAstId for TypeParamParamSpec {} -impl HasAstId for StmtGlobal {} -impl HasAstId for StmtNonlocal {} - -impl HasAstId for ExceptHandlerExceptHandler {} -impl HasAstId for WithItem {} -impl HasAstId for MatchCase {} diff --git a/crates/red_knot/src/cache.rs b/crates/red_knot/src/cache.rs deleted file mode 100644 index 719a1449ed..0000000000 --- a/crates/red_knot/src/cache.rs +++ /dev/null @@ -1,165 +0,0 @@ -use std::fmt::Formatter; -use std::hash::Hash; -use 
std::sync::atomic::{AtomicUsize, Ordering}; - -use crate::db::QueryResult; -use dashmap::mapref::entry::Entry; - -use crate::FxDashMap; - -/// Simple key value cache that locks on a per-key level. -pub struct KeyValueCache { - map: FxDashMap, - statistics: CacheStatistics, -} - -impl KeyValueCache -where - K: Eq + Hash + Clone, - V: Clone, -{ - pub fn try_get(&self, key: &K) -> Option { - if let Some(existing) = self.map.get(key) { - self.statistics.hit(); - Some(existing.clone()) - } else { - self.statistics.miss(); - None - } - } - - pub fn get(&self, key: &K, compute: F) -> QueryResult - where - F: FnOnce(&K) -> QueryResult, - { - Ok(match self.map.entry(key.clone()) { - Entry::Occupied(cached) => { - self.statistics.hit(); - - cached.get().clone() - } - Entry::Vacant(vacant) => { - self.statistics.miss(); - - let value = compute(key)?; - vacant.insert(value.clone()); - value - } - }) - } - - pub fn set(&mut self, key: K, value: V) { - self.map.insert(key, value); - } - - pub fn remove(&mut self, key: &K) -> Option { - self.map.remove(key).map(|(_, value)| value) - } - - pub fn clear(&mut self) { - self.map.clear(); - self.map.shrink_to_fit(); - } - - pub fn statistics(&self) -> Option { - self.statistics.to_statistics() - } -} - -impl Default for KeyValueCache -where - K: Eq + Hash, - V: Clone, -{ - fn default() -> Self { - Self { - map: FxDashMap::default(), - statistics: CacheStatistics::default(), - } - } -} - -impl std::fmt::Debug for KeyValueCache -where - K: std::fmt::Debug + Eq + Hash, - V: std::fmt::Debug, -{ - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let mut debug = f.debug_map(); - - for entry in &self.map { - debug.entry(&entry.value(), &entry.key()); - } - - debug.finish() - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Statistics { - pub hits: usize, - pub misses: usize, -} - -impl Statistics { - #[allow(clippy::cast_precision_loss)] - pub fn hit_rate(&self) -> Option { - if self.hits + self.misses == 0 { - return 
None; - } - - Some((self.hits as f64) / (self.hits + self.misses) as f64) - } -} - -#[cfg(debug_assertions)] -pub type CacheStatistics = DebugStatistics; - -#[cfg(not(debug_assertions))] -pub type CacheStatistics = ReleaseStatistics; - -pub trait StatisticsRecorder { - fn hit(&self); - fn miss(&self); - fn to_statistics(&self) -> Option; -} - -#[derive(Debug, Default)] -pub struct DebugStatistics { - hits: AtomicUsize, - misses: AtomicUsize, -} - -impl StatisticsRecorder for DebugStatistics { - // TODO figure out appropriate Ordering - fn hit(&self) { - self.hits.fetch_add(1, Ordering::SeqCst); - } - - fn miss(&self) { - self.misses.fetch_add(1, Ordering::SeqCst); - } - - fn to_statistics(&self) -> Option { - let hits = self.hits.load(Ordering::SeqCst); - let misses = self.misses.load(Ordering::SeqCst); - - Some(Statistics { hits, misses }) - } -} - -#[derive(Debug, Default)] -pub struct ReleaseStatistics; - -impl StatisticsRecorder for ReleaseStatistics { - #[inline] - fn hit(&self) {} - - #[inline] - fn miss(&self) {} - - #[inline] - fn to_statistics(&self) -> Option { - None - } -} diff --git a/crates/red_knot/src/cancellation.rs b/crates/red_knot/src/cancellation.rs deleted file mode 100644 index 6f91bc8e2b..0000000000 --- a/crates/red_knot/src/cancellation.rs +++ /dev/null @@ -1,42 +0,0 @@ -use std::sync::atomic::AtomicBool; -use std::sync::Arc; - -#[derive(Debug, Clone, Default)] -pub struct CancellationTokenSource { - signal: Arc, -} - -impl CancellationTokenSource { - pub fn new() -> Self { - Self { - signal: Arc::new(AtomicBool::new(false)), - } - } - - #[tracing::instrument(level = "trace", skip_all)] - pub fn cancel(&self) { - self.signal.store(true, std::sync::atomic::Ordering::SeqCst); - } - - pub fn is_cancelled(&self) -> bool { - self.signal.load(std::sync::atomic::Ordering::SeqCst) - } - - pub fn token(&self) -> CancellationToken { - CancellationToken { - signal: self.signal.clone(), - } - } -} - -#[derive(Clone, Debug)] -pub struct 
CancellationToken { - signal: Arc, -} - -impl CancellationToken { - /// Returns `true` if cancellation has been requested. - pub fn is_cancelled(&self) -> bool { - self.signal.load(std::sync::atomic::Ordering::SeqCst) - } -} diff --git a/crates/red_knot/src/db.rs b/crates/red_knot/src/db.rs index 9e6a540a79..a61a6ff026 100644 --- a/crates/red_knot/src/db.rs +++ b/crates/red_knot/src/db.rs @@ -1,248 +1,10 @@ -use std::sync::Arc; +use red_knot_python_semantic::Db as SemanticDb; +use ruff_db::Upcast; +use salsa::DbWithJar; -pub use jars::{HasJar, HasJars}; -pub use query::{QueryError, QueryResult}; -pub use runtime::DbRuntime; -pub use storage::JarsStorage; +use crate::lint::{lint_semantic, lint_syntax, unwind_if_cancelled}; -use crate::files::FileId; -use crate::lint::{LintSemanticStorage, LintSyntaxStorage}; -use crate::module::ModuleResolver; -use crate::parse::ParsedStorage; -use crate::semantic::SemanticIndexStorage; -use crate::semantic::TypeStore; -use crate::source::SourceStorage; +pub trait Db: DbWithJar + SemanticDb + Upcast {} -mod jars; -mod query; -mod runtime; -mod storage; - -pub trait Database { - /// Returns a reference to the runtime of the current worker. - fn runtime(&self) -> &DbRuntime; - - /// Returns a mutable reference to the runtime. Only one worker can hold a mutable reference to the runtime. - fn runtime_mut(&mut self) -> &mut DbRuntime; - - /// Returns `Ok` if the queries have not been cancelled and `Err(QueryError::Cancelled)` otherwise. - fn cancelled(&self) -> QueryResult<()> { - self.runtime().cancelled() - } - - /// Returns `true` if the queries have been cancelled. - fn is_cancelled(&self) -> bool { - self.runtime().is_cancelled() - } -} - -/// Database that supports running queries from multiple threads. -pub trait ParallelDatabase: Database + Send { - /// Creates a snapshot of the database state that can be used to query the database in another thread. 
- /// - /// The snapshot is a read-only view of the database but query results are shared between threads. - /// All queries will be automatically cancelled when applying any mutations (calling [`HasJars::jars_mut`]) - /// to the database (not the snapshot, because they're readonly). - /// - /// ## Creating a snapshot - /// - /// Creating a snapshot of the database's jars is cheap but creating a snapshot of - /// other state stored on the database might require deep-cloning data. That's why you should - /// avoid creating snapshots in a hot function (e.g. don't create a snapshot for each file, instead - /// create a snapshot when scheduling the check of an entire program). - /// - /// ## Salsa compatibility - /// Salsa prohibits creating a snapshot while running a local query (it's fine if other workers run a query) [[source](https://github.com/salsa-rs/salsa/issues/80)]. - /// We should avoid creating snapshots while running a query because we might want to adopt Salsa in the future (if we can figure out persistent caching). - /// Unfortunately, the infrastructure doesn't provide an automated way of knowing when a query is run, that's - /// why we have to "enforce" this constraint manually. - #[must_use] - fn snapshot(&self) -> Snapshot; -} - -pub trait DbWithJar: Database + HasJar {} - -/// Readonly snapshot of a database. -/// -/// ## Dead locks -/// A snapshot should always be dropped as soon as it is no longer necessary to run queries. -/// Storing the snapshot without running a query or periodically checking if cancellation was requested -/// can lead to deadlocks because mutating the [`Database`] requires cancels all pending queries -/// and waiting for all [`Snapshot`]s to be dropped. 
-#[derive(Debug)] -pub struct Snapshot -where - DB: ParallelDatabase, -{ - db: DB, -} - -impl Snapshot -where - DB: ParallelDatabase, -{ - pub fn new(db: DB) -> Self { - Snapshot { db } - } -} - -impl std::ops::Deref for Snapshot -where - DB: ParallelDatabase, -{ - type Target = DB; - - fn deref(&self) -> &DB { - &self.db - } -} - -pub trait Upcast { - fn upcast(&self) -> &T; -} - -// Red knot specific databases code. - -pub trait SourceDb: DbWithJar { - // queries - fn file_id(&self, path: &std::path::Path) -> FileId; - - fn file_path(&self, file_id: FileId) -> Arc; -} - -pub trait SemanticDb: SourceDb + DbWithJar + Upcast {} - -pub trait LintDb: SemanticDb + DbWithJar + Upcast {} - -pub trait Db: LintDb + Upcast {} - -#[derive(Debug, Default)] -pub struct SourceJar { - pub sources: SourceStorage, - pub parsed: ParsedStorage, -} - -#[derive(Debug, Default)] -pub struct SemanticJar { - pub module_resolver: ModuleResolver, - pub semantic_indices: SemanticIndexStorage, - pub type_store: TypeStore, -} - -#[derive(Debug, Default)] -pub struct LintJar { - pub lint_syntax: LintSyntaxStorage, - pub lint_semantic: LintSemanticStorage, -} - -#[cfg(test)] -pub(crate) mod tests { - use std::path::Path; - use std::sync::Arc; - - use crate::db::{ - Database, DbRuntime, DbWithJar, HasJar, HasJars, JarsStorage, LintDb, LintJar, QueryResult, - SourceDb, SourceJar, Upcast, - }; - use crate::files::{FileId, Files}; - - use super::{SemanticDb, SemanticJar}; - - // This can be a partial database used in a single crate for testing. - // It would hold fewer data than the full database. 
- #[derive(Debug, Default)] - pub(crate) struct TestDb { - files: Files, - jars: JarsStorage, - } - - impl HasJar for TestDb { - fn jar(&self) -> QueryResult<&SourceJar> { - Ok(&self.jars()?.0) - } - - fn jar_mut(&mut self) -> &mut SourceJar { - &mut self.jars_mut().0 - } - } - - impl HasJar for TestDb { - fn jar(&self) -> QueryResult<&SemanticJar> { - Ok(&self.jars()?.1) - } - - fn jar_mut(&mut self) -> &mut SemanticJar { - &mut self.jars_mut().1 - } - } - - impl HasJar for TestDb { - fn jar(&self) -> QueryResult<&LintJar> { - Ok(&self.jars()?.2) - } - - fn jar_mut(&mut self) -> &mut LintJar { - &mut self.jars_mut().2 - } - } - - impl SourceDb for TestDb { - fn file_id(&self, path: &Path) -> FileId { - self.files.intern(path) - } - - fn file_path(&self, file_id: FileId) -> Arc { - self.files.path(file_id) - } - } - - impl DbWithJar for TestDb {} - - impl Upcast for TestDb { - fn upcast(&self) -> &(dyn SourceDb + 'static) { - self - } - } - - impl SemanticDb for TestDb {} - - impl DbWithJar for TestDb {} - - impl Upcast for TestDb { - fn upcast(&self) -> &(dyn SemanticDb + 'static) { - self - } - } - - impl LintDb for TestDb {} - - impl Upcast for TestDb { - fn upcast(&self) -> &(dyn LintDb + 'static) { - self - } - } - - impl DbWithJar for TestDb {} - - impl HasJars for TestDb { - type Jars = (SourceJar, SemanticJar, LintJar); - - fn jars(&self) -> QueryResult<&Self::Jars> { - self.jars.jars() - } - - fn jars_mut(&mut self) -> &mut Self::Jars { - self.jars.jars_mut() - } - } - - impl Database for TestDb { - fn runtime(&self) -> &DbRuntime { - self.jars.runtime() - } - - fn runtime_mut(&mut self) -> &mut DbRuntime { - self.jars.runtime_mut() - } - } -} +#[salsa::jar(db=Db)] +pub struct Jar(lint_syntax, lint_semantic, unwind_if_cancelled); diff --git a/crates/red_knot/src/db/jars.rs b/crates/red_knot/src/db/jars.rs deleted file mode 100644 index 7fd24e4dd3..0000000000 --- a/crates/red_knot/src/db/jars.rs +++ /dev/null @@ -1,37 +0,0 @@ -use 
crate::db::query::QueryResult; - -/// Gives access to a specific jar in the database. -/// -/// Nope, the terminology isn't borrowed from Java but from Salsa , -/// which is an analogy to storing the salsa in different jars. -/// -/// The basic idea is that each crate can define its own jar and the jars can be combined to a single -/// database in the top level crate. Each crate also defines its own `Database` trait. The combination of -/// `Database` trait and the jar allows to write queries in isolation without having to know how they get composed at the upper levels. -/// -/// Salsa further defines a `HasIngredient` trait which slices the jar to a specific storage (e.g. a specific cache). -/// We don't need this just yet because we write our queries by hand. We may want a similar trait if we decide -/// to use a macro to generate the queries. -pub trait HasJar { - /// Gives a read-only reference to the jar. - fn jar(&self) -> QueryResult<&T>; - - /// Gives a mutable reference to the jar. - fn jar_mut(&mut self) -> &mut T; -} - -/// Gives access to the jars in a database. -pub trait HasJars { - /// A type storing the jars. - /// - /// Most commonly, this is a tuple where each jar is a tuple element. - type Jars: Default; - - /// Gives access to the underlying jars but tests if the queries have been cancelled. - /// - /// Returns `Err(QueryError::Cancelled)` if the queries have been cancelled. - fn jars(&self) -> QueryResult<&Self::Jars>; - - /// Gives mutable access to the underlying jars. - fn jars_mut(&mut self) -> &mut Self::Jars; -} diff --git a/crates/red_knot/src/db/query.rs b/crates/red_knot/src/db/query.rs deleted file mode 100644 index d020decd6e..0000000000 --- a/crates/red_knot/src/db/query.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::fmt::{Display, Formatter}; - -/// Reason why a db query operation failed. 
-#[derive(Debug, Clone, Copy)] -pub enum QueryError { - /// The query was cancelled because the DB was mutated or the query was cancelled by the host (e.g. on a file change or when pressing CTRL+C). - Cancelled, -} - -impl Display for QueryError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - QueryError::Cancelled => f.write_str("query was cancelled"), - } - } -} - -impl std::error::Error for QueryError {} - -pub type QueryResult = Result; diff --git a/crates/red_knot/src/db/runtime.rs b/crates/red_knot/src/db/runtime.rs deleted file mode 100644 index c8530eb168..0000000000 --- a/crates/red_knot/src/db/runtime.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::cancellation::CancellationTokenSource; -use crate::db::{QueryError, QueryResult}; - -/// Holds the jar agnostic state of the database. -#[derive(Debug, Default)] -pub struct DbRuntime { - /// The cancellation token source used to signal other works that the queries should be aborted and - /// exit at the next possible point. - cancellation_token: CancellationTokenSource, -} - -impl DbRuntime { - pub(super) fn snapshot(&self) -> Self { - Self { - cancellation_token: self.cancellation_token.clone(), - } - } - - /// Cancels the pending queries of other workers. The current worker cannot have any pending - /// queries because we're holding a mutable reference to the runtime. - pub(super) fn cancel_other_workers(&mut self) { - self.cancellation_token.cancel(); - // Set a new cancellation token so that we're in a non-cancelled state again when running the next - // query. - self.cancellation_token = CancellationTokenSource::default(); - } - - /// Returns `Ok` if the queries have not been cancelled and `Err(QueryError::Cancelled)` otherwise. - pub(super) fn cancelled(&self) -> QueryResult<()> { - if self.cancellation_token.is_cancelled() { - Err(QueryError::Cancelled) - } else { - Ok(()) - } - } - - /// Returns `true` if the queries have been cancelled. 
- pub(super) fn is_cancelled(&self) -> bool { - self.cancellation_token.is_cancelled() - } -} diff --git a/crates/red_knot/src/db/storage.rs b/crates/red_knot/src/db/storage.rs deleted file mode 100644 index afb57e3230..0000000000 --- a/crates/red_knot/src/db/storage.rs +++ /dev/null @@ -1,117 +0,0 @@ -use std::fmt::Formatter; -use std::sync::Arc; - -use crossbeam::sync::WaitGroup; - -use crate::db::query::QueryResult; -use crate::db::runtime::DbRuntime; -use crate::db::{HasJars, ParallelDatabase}; - -/// Stores the jars of a database and the state for each worker. -/// -/// Today, all state is shared across all workers, but it may be desired to store data per worker in the future. -pub struct JarsStorage -where - T: HasJars + Sized, -{ - // It's important that `jars_wait_group` is declared after `jars` to ensure that `jars` is dropped first. - // See https://doc.rust-lang.org/reference/destructors.html - /// Stores the jars of the database. - jars: Arc, - - /// Used to count the references to `jars`. Allows implementing `jars_mut` without requiring to clone `jars`. - jars_wait_group: WaitGroup, - - /// The data agnostic state. - runtime: DbRuntime, -} - -impl JarsStorage -where - Db: HasJars, -{ - pub(super) fn new() -> Self { - Self { - jars: Arc::new(Db::Jars::default()), - jars_wait_group: WaitGroup::default(), - runtime: DbRuntime::default(), - } - } - - /// Creates a snapshot of the jars. - /// - /// Creating the snapshot is cheap because it doesn't clone the jars, it only increments a ref counter. - #[must_use] - pub fn snapshot(&self) -> JarsStorage - where - Db: ParallelDatabase, - { - Self { - jars: self.jars.clone(), - jars_wait_group: self.jars_wait_group.clone(), - runtime: self.runtime.snapshot(), - } - } - - pub(crate) fn jars(&self) -> QueryResult<&Db::Jars> { - self.runtime.cancelled()?; - Ok(&self.jars) - } - - /// Returns a mutable reference to the jars without cloning their content. 
- /// - /// The method cancels any pending queries of other works and waits for them to complete so that - /// this instance is the only instance holding a reference to the jars. - pub(crate) fn jars_mut(&mut self) -> &mut Db::Jars { - // We have a mutable ref here, so no more workers can be spawned between calling this function and taking the mut ref below. - self.cancel_other_workers(); - - // Now all other references to `self.jars` should have been released. We can now safely return a mutable reference - // to the Arc's content. - let jars = - Arc::get_mut(&mut self.jars).expect("All references to jars should have been released"); - - jars - } - - pub(crate) fn runtime(&self) -> &DbRuntime { - &self.runtime - } - - pub(crate) fn runtime_mut(&mut self) -> &mut DbRuntime { - // Note: This method may need to use a similar trick to `jars_mut` if `DbRuntime` is ever to store data that is shared between workers. - &mut self.runtime - } - - #[tracing::instrument(level = "trace", skip(self))] - fn cancel_other_workers(&mut self) { - self.runtime.cancel_other_workers(); - - // Wait for all other works to complete. 
- let existing_wait = std::mem::take(&mut self.jars_wait_group); - existing_wait.wait(); - } -} - -impl Default for JarsStorage -where - Db: HasJars, -{ - fn default() -> Self { - Self::new() - } -} - -impl std::fmt::Debug for JarsStorage -where - T: HasJars, - ::Jars: std::fmt::Debug, -{ - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SharedStorage") - .field("jars", &self.jars) - .field("jars_wait_group", &self.jars_wait_group) - .field("runtime", &self.runtime) - .finish() - } -} diff --git a/crates/red_knot/src/files.rs b/crates/red_knot/src/files.rs deleted file mode 100644 index de8bf68a4f..0000000000 --- a/crates/red_knot/src/files.rs +++ /dev/null @@ -1,180 +0,0 @@ -use std::fmt::{Debug, Formatter}; -use std::hash::{Hash, Hasher}; -use std::path::Path; -use std::sync::Arc; - -use hashbrown::hash_map::RawEntryMut; -use parking_lot::RwLock; -use rustc_hash::FxHasher; - -use ruff_index::{newtype_index, IndexVec}; - -type Map = hashbrown::HashMap; - -#[newtype_index] -pub struct FileId; - -// TODO we'll need a higher level virtual file system abstraction that allows testing if a file exists -// or retrieving its content (ideally lazily and in a way that the memory can be retained later) -// I suspect that we'll end up with a FileSystem trait and our own Path abstraction. -#[derive(Default)] -pub struct Files { - inner: Arc>, -} - -impl Files { - #[tracing::instrument(level = "debug", skip(self))] - pub fn intern(&self, path: &Path) -> FileId { - self.inner.write().intern(path) - } - - pub fn try_get(&self, path: &Path) -> Option { - self.inner.read().try_get(path) - } - - #[tracing::instrument(level = "debug", skip(self))] - pub fn path(&self, id: FileId) -> Arc { - self.inner.read().path(id) - } - - /// Snapshots files for a new database snapshot. - /// - /// This method should not be used outside a database snapshot. 
- #[must_use] - pub fn snapshot(&self) -> Files { - Files { - inner: self.inner.clone(), - } - } -} - -impl Debug for Files { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let files = self.inner.read(); - let mut debug = f.debug_map(); - for item in files.iter() { - debug.entry(&item.0, &item.1); - } - - debug.finish() - } -} - -impl PartialEq for Files { - fn eq(&self, other: &Self) -> bool { - self.inner.read().eq(&other.inner.read()) - } -} - -impl Eq for Files {} - -#[derive(Default)] -struct FilesInner { - by_path: Map, - // TODO should we use a map here to reclaim the space for removed files? - // TODO I think we should use our own path abstraction here to avoid having to normalize paths - // and dealing with non-utf paths everywhere. - by_id: IndexVec>, -} - -impl FilesInner { - /// Inserts the path and returns a new id for it or returns the id if it is an existing path. - // TODO should this accept Path or PathBuf? - pub(crate) fn intern(&mut self, path: &Path) -> FileId { - let hash = FilesInner::hash_path(path); - - let entry = self - .by_path - .raw_entry_mut() - .from_hash(hash, |existing_file| &*self.by_id[*existing_file] == path); - - match entry { - RawEntryMut::Occupied(entry) => *entry.key(), - RawEntryMut::Vacant(entry) => { - let id = self.by_id.push(Arc::from(path)); - entry.insert_with_hasher(hash, id, (), |file| { - FilesInner::hash_path(&self.by_id[*file]) - }); - id - } - } - } - - fn hash_path(path: &Path) -> u64 { - let mut hasher = FxHasher::default(); - path.hash(&mut hasher); - hasher.finish() - } - - pub(crate) fn try_get(&self, path: &Path) -> Option { - let mut hasher = FxHasher::default(); - path.hash(&mut hasher); - let hash = hasher.finish(); - - Some( - *self - .by_path - .raw_entry() - .from_hash(hash, |existing_file| &*self.by_id[*existing_file] == path)? - .0, - ) - } - - /// Returns the path for the file with the given id. 
- pub(crate) fn path(&self, id: FileId) -> Arc { - self.by_id[id].clone() - } - - pub(crate) fn iter(&self) -> impl Iterator)> + '_ { - self.by_path.keys().map(|id| (*id, self.by_id[*id].clone())) - } -} - -impl PartialEq for FilesInner { - fn eq(&self, other: &Self) -> bool { - self.by_id == other.by_id - } -} - -impl Eq for FilesInner {} - -#[cfg(test)] -mod tests { - use super::*; - use std::path::PathBuf; - - #[test] - fn insert_path_twice_same_id() { - let files = Files::default(); - let path = PathBuf::from("foo/bar"); - let id1 = files.intern(&path); - let id2 = files.intern(&path); - assert_eq!(id1, id2); - } - - #[test] - fn insert_different_paths_different_ids() { - let files = Files::default(); - let path1 = PathBuf::from("foo/bar"); - let path2 = PathBuf::from("foo/bar/baz"); - let id1 = files.intern(&path1); - let id2 = files.intern(&path2); - assert_ne!(id1, id2); - } - - #[test] - fn four_files() { - let files = Files::default(); - let foo_path = PathBuf::from("foo"); - let foo_id = files.intern(&foo_path); - let bar_path = PathBuf::from("bar"); - files.intern(&bar_path); - let baz_path = PathBuf::from("baz"); - files.intern(&baz_path); - let qux_path = PathBuf::from("qux"); - files.intern(&qux_path); - - let foo_id_2 = files.try_get(&foo_path).expect("foo_path to be found"); - assert_eq!(foo_id_2, foo_id); - } -} diff --git a/crates/red_knot/src/hir.rs b/crates/red_knot/src/hir.rs deleted file mode 100644 index 5b7eeeafdf..0000000000 --- a/crates/red_knot/src/hir.rs +++ /dev/null @@ -1,67 +0,0 @@ -//! Key observations -//! -//! The HIR (High-Level Intermediate Representation) avoids allocations to large extends by: -//! * Using an arena per node type -//! * using ids and id ranges to reference items. -//! -//! Using separate arena per node type has the advantage that the IDs are relatively stable, because -//! they only change when a node of the same kind has been added or removed. (What's unclear is if that matters or if -//! 
it still triggers a re-compute because the AST-id in the node has changed). -//! -//! The HIR does not store all details. It mainly stores the *public* interface. There's a reference -//! back to the AST node to get more details. -//! -//! - -use crate::ast_ids::{HasAstId, TypedAstId}; -use crate::files::FileId; -use std::fmt::Formatter; -use std::hash::{Hash, Hasher}; - -pub struct HirAstId { - file_id: FileId, - node_id: TypedAstId, -} - -impl Copy for HirAstId {} -impl Clone for HirAstId { - fn clone(&self) -> Self { - *self - } -} - -impl PartialEq for HirAstId { - fn eq(&self, other: &Self) -> bool { - self.file_id == other.file_id && self.node_id == other.node_id - } -} - -impl Eq for HirAstId {} - -impl std::fmt::Debug for HirAstId { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("HirAstId") - .field("file_id", &self.file_id) - .field("node_id", &self.node_id) - .finish() - } -} - -impl Hash for HirAstId { - fn hash(&self, state: &mut H) { - self.file_id.hash(state); - self.node_id.hash(state); - } -} - -impl HirAstId { - pub fn upcast(self) -> HirAstId - where - N: Into, - { - HirAstId { - file_id: self.file_id, - node_id: self.node_id.upcast(), - } - } -} diff --git a/crates/red_knot/src/hir/definition.rs b/crates/red_knot/src/hir/definition.rs deleted file mode 100644 index 35b239796a..0000000000 --- a/crates/red_knot/src/hir/definition.rs +++ /dev/null @@ -1,556 +0,0 @@ -use std::ops::{Index, Range}; - -use ruff_index::{newtype_index, IndexVec}; -use ruff_python_ast::visitor::preorder; -use ruff_python_ast::visitor::preorder::PreorderVisitor; -use ruff_python_ast::{ - Decorator, ExceptHandler, ExceptHandlerExceptHandler, Expr, MatchCase, ModModule, Stmt, - StmtAnnAssign, StmtAssign, StmtClassDef, StmtFunctionDef, StmtGlobal, StmtImport, - StmtImportFrom, StmtNonlocal, StmtTypeAlias, TypeParam, TypeParamParamSpec, TypeParamTypeVar, - TypeParamTypeVarTuple, WithItem, -}; - -use crate::ast_ids::{AstIds, HasAstId}; -use 
crate::files::FileId; -use crate::hir::HirAstId; -use crate::Name; - -#[newtype_index] -pub struct FunctionId; - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct Function { - ast_id: HirAstId, - name: Name, - parameters: Range, - type_parameters: Range, // TODO: type_parameters, return expression, decorators -} - -#[newtype_index] -pub struct ParameterId; - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct Parameter { - kind: ParameterKind, - name: Name, - default: Option<()>, // TODO use expression HIR - ast_id: HirAstId, -} - -// TODO or should `Parameter` be an enum? -#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] -pub enum ParameterKind { - PositionalOnly, - Arguments, - Vararg, - KeywordOnly, - Kwarg, -} - -#[newtype_index] -pub struct ClassId; - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct Class { - name: Name, - ast_id: HirAstId, - // TODO type parameters, inheritance, decorators, members -} - -#[newtype_index] -pub struct AssignmentId; - -// This can have more than one name... -// but that means we can't implement `name()` on `ModuleItem`. 
- -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct Assignment { - // TODO: Handle multiple names / targets - name: Name, - ast_id: HirAstId, -} - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct AnnotatedAssignment { - name: Name, - ast_id: HirAstId, -} - -#[newtype_index] -pub struct AnnotatedAssignmentId; - -#[newtype_index] -pub struct TypeAliasId; - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct TypeAlias { - name: Name, - ast_id: HirAstId, - parameters: Range, -} - -#[newtype_index] -pub struct TypeParameterId; - -#[derive(Debug, Clone, Eq, PartialEq)] -pub enum TypeParameter { - TypeVar(TypeParameterTypeVar), - ParamSpec(TypeParameterParamSpec), - TypeVarTuple(TypeParameterTypeVarTuple), -} - -impl TypeParameter { - pub fn ast_id(&self) -> HirAstId { - match self { - TypeParameter::TypeVar(type_var) => type_var.ast_id.upcast(), - TypeParameter::ParamSpec(param_spec) => param_spec.ast_id.upcast(), - TypeParameter::TypeVarTuple(type_var_tuple) => type_var_tuple.ast_id.upcast(), - } - } -} - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct TypeParameterTypeVar { - name: Name, - ast_id: HirAstId, -} - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct TypeParameterParamSpec { - name: Name, - ast_id: HirAstId, -} - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct TypeParameterTypeVarTuple { - name: Name, - ast_id: HirAstId, -} - -#[newtype_index] -pub struct GlobalId; - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct Global { - // TODO track names - ast_id: HirAstId, -} - -#[newtype_index] -pub struct NonLocalId; - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct NonLocal { - // TODO track names - ast_id: HirAstId, -} - -pub enum DefinitionId { - Function(FunctionId), - Parameter(ParameterId), - Class(ClassId), - Assignment(AssignmentId), - AnnotatedAssignment(AnnotatedAssignmentId), - Global(GlobalId), - NonLocal(NonLocalId), - TypeParameter(TypeParameterId), - TypeAlias(TypeAlias), -} - -pub enum DefinitionItem { - Function(Function), - 
Parameter(Parameter), - Class(Class), - Assignment(Assignment), - AnnotatedAssignment(AnnotatedAssignment), - Global(Global), - NonLocal(NonLocal), - TypeParameter(TypeParameter), - TypeAlias(TypeAlias), -} - -// The closest is rust-analyzers item-tree. It only represents "Items" which make the public interface of a module -// (it excludes any other statement or expressions). rust-analyzer uses it as the main input to the name resolution -// algorithm -// > It is the input to the name resolution algorithm, as well as to the queries defined in `adt.rs`, -// > `data.rs`, and most things in `attr.rs`. -// -// > One important purpose of this layer is to provide an "invalidation barrier" for incremental -// > computations: when typing inside an item body, the `ItemTree` of the modified file is typically -// > unaffected, so we don't have to recompute name resolution results or item data (see `data.rs`). -// -// I haven't fully figured this out but I think that this composes the "public" interface of a module? -// But maybe that's too optimistic. 
-// -// -#[derive(Debug, Clone, Default, Eq, PartialEq)] -pub struct Definitions { - functions: IndexVec, - parameters: IndexVec, - classes: IndexVec, - assignments: IndexVec, - annotated_assignments: IndexVec, - type_aliases: IndexVec, - type_parameters: IndexVec, - globals: IndexVec, - non_locals: IndexVec, -} - -impl Definitions { - pub fn from_module(module: &ModModule, ast_ids: &AstIds, file_id: FileId) -> Self { - let mut visitor = DefinitionsVisitor { - definitions: Definitions::default(), - ast_ids, - file_id, - }; - - visitor.visit_body(&module.body); - - visitor.definitions - } -} - -impl Index for Definitions { - type Output = Function; - - fn index(&self, index: FunctionId) -> &Self::Output { - &self.functions[index] - } -} - -impl Index for Definitions { - type Output = Parameter; - - fn index(&self, index: ParameterId) -> &Self::Output { - &self.parameters[index] - } -} - -impl Index for Definitions { - type Output = Class; - - fn index(&self, index: ClassId) -> &Self::Output { - &self.classes[index] - } -} - -impl Index for Definitions { - type Output = Assignment; - - fn index(&self, index: AssignmentId) -> &Self::Output { - &self.assignments[index] - } -} - -impl Index for Definitions { - type Output = AnnotatedAssignment; - - fn index(&self, index: AnnotatedAssignmentId) -> &Self::Output { - &self.annotated_assignments[index] - } -} - -impl Index for Definitions { - type Output = TypeAlias; - - fn index(&self, index: TypeAliasId) -> &Self::Output { - &self.type_aliases[index] - } -} - -impl Index for Definitions { - type Output = Global; - - fn index(&self, index: GlobalId) -> &Self::Output { - &self.globals[index] - } -} - -impl Index for Definitions { - type Output = NonLocal; - - fn index(&self, index: NonLocalId) -> &Self::Output { - &self.non_locals[index] - } -} - -impl Index for Definitions { - type Output = TypeParameter; - - fn index(&self, index: TypeParameterId) -> &Self::Output { - &self.type_parameters[index] - } -} - -struct 
DefinitionsVisitor<'a> { - definitions: Definitions, - ast_ids: &'a AstIds, - file_id: FileId, -} - -impl DefinitionsVisitor<'_> { - fn ast_id(&self, node: &N) -> HirAstId { - HirAstId { - file_id: self.file_id, - node_id: self.ast_ids.ast_id(node), - } - } - - fn lower_function_def(&mut self, function: &StmtFunctionDef) -> FunctionId { - let name = Name::new(&function.name); - - let first_type_parameter_id = self.definitions.type_parameters.next_index(); - let mut last_type_parameter_id = first_type_parameter_id; - - if let Some(type_params) = &function.type_params { - for parameter in &type_params.type_params { - let id = self.lower_type_parameter(parameter); - last_type_parameter_id = id; - } - } - - let parameters = self.lower_parameters(&function.parameters); - - self.definitions.functions.push(Function { - name, - ast_id: self.ast_id(function), - parameters, - type_parameters: first_type_parameter_id..last_type_parameter_id, - }) - } - - fn lower_parameters(&mut self, parameters: &ruff_python_ast::Parameters) -> Range { - let first_parameter_id = self.definitions.parameters.next_index(); - let mut last_parameter_id = first_parameter_id; - - for parameter in ¶meters.posonlyargs { - last_parameter_id = self.definitions.parameters.push(Parameter { - kind: ParameterKind::PositionalOnly, - name: Name::new(¶meter.parameter.name), - default: None, - ast_id: self.ast_id(¶meter.parameter), - }); - } - - if let Some(vararg) = ¶meters.vararg { - last_parameter_id = self.definitions.parameters.push(Parameter { - kind: ParameterKind::Vararg, - name: Name::new(&vararg.name), - default: None, - ast_id: self.ast_id(vararg), - }); - } - - for parameter in ¶meters.kwonlyargs { - last_parameter_id = self.definitions.parameters.push(Parameter { - kind: ParameterKind::KeywordOnly, - name: Name::new(¶meter.parameter.name), - default: None, - ast_id: self.ast_id(¶meter.parameter), - }); - } - - if let Some(kwarg) = ¶meters.kwarg { - last_parameter_id = 
self.definitions.parameters.push(Parameter { - kind: ParameterKind::KeywordOnly, - name: Name::new(&kwarg.name), - default: None, - ast_id: self.ast_id(kwarg), - }); - } - - first_parameter_id..last_parameter_id - } - - fn lower_class_def(&mut self, class: &StmtClassDef) -> ClassId { - let name = Name::new(&class.name); - - self.definitions.classes.push(Class { - name, - ast_id: self.ast_id(class), - }) - } - - fn lower_assignment(&mut self, assignment: &StmtAssign) { - // FIXME handle multiple names - if let Some(Expr::Name(name)) = assignment.targets.first() { - self.definitions.assignments.push(Assignment { - name: Name::new(&name.id), - ast_id: self.ast_id(assignment), - }); - } - } - - fn lower_annotated_assignment(&mut self, annotated_assignment: &StmtAnnAssign) { - if let Expr::Name(name) = &*annotated_assignment.target { - self.definitions - .annotated_assignments - .push(AnnotatedAssignment { - name: Name::new(&name.id), - ast_id: self.ast_id(annotated_assignment), - }); - } - } - - fn lower_type_alias(&mut self, type_alias: &StmtTypeAlias) { - if let Expr::Name(name) = &*type_alias.name { - let name = Name::new(&name.id); - - let lower_parameters_id = self.definitions.type_parameters.next_index(); - let mut last_parameter_id = lower_parameters_id; - - if let Some(type_params) = &type_alias.type_params { - for type_parameter in &type_params.type_params { - let id = self.lower_type_parameter(type_parameter); - last_parameter_id = id; - } - } - - self.definitions.type_aliases.push(TypeAlias { - name, - ast_id: self.ast_id(type_alias), - parameters: lower_parameters_id..last_parameter_id, - }); - } - } - - fn lower_type_parameter(&mut self, type_parameter: &TypeParam) -> TypeParameterId { - match type_parameter { - TypeParam::TypeVar(type_var) => { - self.definitions - .type_parameters - .push(TypeParameter::TypeVar(TypeParameterTypeVar { - name: Name::new(&type_var.name), - ast_id: self.ast_id(type_var), - })) - } - TypeParam::ParamSpec(param_spec) => { - 
self.definitions - .type_parameters - .push(TypeParameter::ParamSpec(TypeParameterParamSpec { - name: Name::new(¶m_spec.name), - ast_id: self.ast_id(param_spec), - })) - } - TypeParam::TypeVarTuple(type_var_tuple) => { - self.definitions - .type_parameters - .push(TypeParameter::TypeVarTuple(TypeParameterTypeVarTuple { - name: Name::new(&type_var_tuple.name), - ast_id: self.ast_id(type_var_tuple), - })) - } - } - } - - fn lower_import(&mut self, _import: &StmtImport) { - // TODO - } - - fn lower_import_from(&mut self, _import_from: &StmtImportFrom) { - // TODO - } - - fn lower_global(&mut self, global: &StmtGlobal) -> GlobalId { - self.definitions.globals.push(Global { - ast_id: self.ast_id(global), - }) - } - - fn lower_non_local(&mut self, non_local: &StmtNonlocal) -> NonLocalId { - self.definitions.non_locals.push(NonLocal { - ast_id: self.ast_id(non_local), - }) - } - - fn lower_except_handler(&mut self, _except_handler: &ExceptHandlerExceptHandler) { - // TODO - } - - fn lower_with_item(&mut self, _with_item: &WithItem) { - // TODO - } - - fn lower_match_case(&mut self, _match_case: &MatchCase) { - // TODO - } -} - -impl PreorderVisitor<'_> for DefinitionsVisitor<'_> { - fn visit_stmt(&mut self, stmt: &Stmt) { - match stmt { - // Definition statements - Stmt::FunctionDef(definition) => { - self.lower_function_def(definition); - self.visit_body(&definition.body); - } - Stmt::ClassDef(definition) => { - self.lower_class_def(definition); - self.visit_body(&definition.body); - } - Stmt::Assign(assignment) => { - self.lower_assignment(assignment); - } - Stmt::AnnAssign(annotated_assignment) => { - self.lower_annotated_assignment(annotated_assignment); - } - Stmt::TypeAlias(type_alias) => { - self.lower_type_alias(type_alias); - } - - Stmt::Import(import) => self.lower_import(import), - Stmt::ImportFrom(import_from) => self.lower_import_from(import_from), - Stmt::Global(global) => { - self.lower_global(global); - } - Stmt::Nonlocal(non_local) => { - 
self.lower_non_local(non_local); - } - - // Visit the compound statement bodies because they can contain other definitions. - Stmt::For(_) - | Stmt::While(_) - | Stmt::If(_) - | Stmt::With(_) - | Stmt::Match(_) - | Stmt::Try(_) => { - preorder::walk_stmt(self, stmt); - } - - // Skip over simple statements because they can't contain any other definitions. - Stmt::Return(_) - | Stmt::Delete(_) - | Stmt::AugAssign(_) - | Stmt::Raise(_) - | Stmt::Assert(_) - | Stmt::Expr(_) - | Stmt::Pass(_) - | Stmt::Break(_) - | Stmt::Continue(_) - | Stmt::IpyEscapeCommand(_) => { - // No op - } - } - } - - fn visit_expr(&mut self, _: &'_ Expr) {} - - fn visit_decorator(&mut self, _decorator: &'_ Decorator) {} - - fn visit_except_handler(&mut self, except_handler: &'_ ExceptHandler) { - match except_handler { - ExceptHandler::ExceptHandler(except_handler) => { - self.lower_except_handler(except_handler); - } - } - } - - fn visit_with_item(&mut self, with_item: &'_ WithItem) { - self.lower_with_item(with_item); - } - - fn visit_match_case(&mut self, match_case: &'_ MatchCase) { - self.lower_match_case(match_case); - self.visit_body(&match_case.body); - } -} diff --git a/crates/red_knot/src/lib.rs b/crates/red_knot/src/lib.rs index b04d8ed8a5..7d1629c24b 100644 --- a/crates/red_knot/src/lib.rs +++ b/crates/red_knot/src/lib.rs @@ -1,68 +1,52 @@ -use std::hash::BuildHasherDefault; -use std::path::{Path, PathBuf}; +use rustc_hash::FxHashSet; -use rustc_hash::{FxHashSet, FxHasher}; +use ruff_db::file_system::{FileSystemPath, FileSystemPathBuf}; +use ruff_db::vfs::VfsFile; -use crate::files::FileId; +use crate::db::Jar; -pub mod ast_ids; -pub mod cache; -pub mod cancellation; pub mod db; -pub mod files; -pub mod hir; pub mod lint; -pub mod module; -mod parse; pub mod program; -mod semantic; -pub mod source; pub mod watch; -pub(crate) type FxDashMap = dashmap::DashMap>; -#[allow(unused)] -pub(crate) type FxDashSet = dashmap::DashSet>; -pub(crate) type FxIndexSet = indexmap::set::IndexSet>; - 
#[derive(Debug, Clone)] pub struct Workspace { - /// TODO this should be a resolved path. We should probably use a newtype wrapper that guarantees that - /// PATH is a UTF-8 path and is normalized. - root: PathBuf, + root: FileSystemPathBuf, /// The files that are open in the workspace. /// /// * Editor: The files that are actively being edited in the editor (the user has a tab open with the file). /// * CLI: The resolved files passed as arguments to the CLI. - open_files: FxHashSet, + open_files: FxHashSet, } impl Workspace { - pub fn new(root: PathBuf) -> Self { + pub fn new(root: FileSystemPathBuf) -> Self { Self { root, open_files: FxHashSet::default(), } } - pub fn root(&self) -> &Path { + pub fn root(&self) -> &FileSystemPath { self.root.as_path() } // TODO having the content in workspace feels wrong. - pub fn open_file(&mut self, file_id: FileId) { + pub fn open_file(&mut self, file_id: VfsFile) { self.open_files.insert(file_id); } - pub fn close_file(&mut self, file_id: FileId) { + pub fn close_file(&mut self, file_id: VfsFile) { self.open_files.remove(&file_id); } // TODO introduce an `OpenFile` type instead of using an anonymous tuple. 
- pub fn open_files(&self) -> impl Iterator + '_ { + pub fn open_files(&self) -> impl Iterator + '_ { self.open_files.iter().copied() } - pub fn is_file_open(&self, file_id: FileId) -> bool { + pub fn is_file_open(&self, file_id: VfsFile) -> bool { self.open_files.contains(&file_id) } } diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index a801bf9196..e32a70424e 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -1,61 +1,59 @@ -use red_knot_module_resolver::ModuleName; use std::cell::RefCell; -use std::ops::{Deref, DerefMut}; -use std::sync::Arc; +use std::ops::Deref; use std::time::Duration; -use ruff_python_ast::visitor::Visitor; -use ruff_python_ast::{ModModule, StringLiteral}; -use ruff_python_parser::Parsed; +use tracing::trace_span; -use crate::cache::KeyValueCache; -use crate::db::{LintDb, LintJar, QueryResult}; -use crate::files::FileId; -use crate::module::resolve_module; -use crate::parse::parse; -use crate::semantic::{infer_definition_type, infer_symbol_public_type, Type}; -use crate::semantic::{ - resolve_global_symbol, semantic_index, Definition, GlobalSymbolId, SemanticIndex, SymbolId, -}; -use crate::source::{source_text, Source}; +use red_knot_module_resolver::ModuleName; +use red_knot_python_semantic::types::Type; +use red_knot_python_semantic::{HasTy, SemanticModel}; +use ruff_db::parsed::{parsed_module, ParsedModule}; +use ruff_db::source::{source_text, SourceText}; +use ruff_db::vfs::VfsFile; +use ruff_python_ast as ast; +use ruff_python_ast::visitor::{walk_stmt, Visitor}; -#[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn lint_syntax(db: &dyn LintDb, file_id: FileId) -> QueryResult { - let lint_jar: &LintJar = db.jar()?; - let storage = &lint_jar.lint_syntax; +use crate::db::Db; +/// Workaround query to test for if the computation should be cancelled. +/// Ideally, push for Salsa to expose an API for testing if cancellation was requested. 
+#[salsa::tracked] +#[allow(unused_variables)] +pub(crate) fn unwind_if_cancelled(db: &dyn Db) {} + +#[salsa::tracked(return_ref)] +pub(crate) fn lint_syntax(db: &dyn Db, file_id: VfsFile) -> Diagnostics { #[allow(clippy::print_stdout)] if std::env::var("RED_KNOT_SLOW_LINT").is_ok() { for i in 0..10 { - db.cancelled()?; + unwind_if_cancelled(db); + println!("RED_KNOT_SLOW_LINT is set, sleeping for {i}/10 seconds"); std::thread::sleep(Duration::from_secs(1)); } } - storage.get(&file_id, |file_id| { - let mut diagnostics = Vec::new(); + let mut diagnostics = Vec::new(); - let source = source_text(db.upcast(), *file_id)?; - lint_lines(source.text(), &mut diagnostics); + let source = source_text(db.upcast(), file_id); + lint_lines(&source, &mut diagnostics); - let parsed = parse(db.upcast(), *file_id)?; + let parsed = parsed_module(db.upcast(), file_id); - if parsed.errors().is_empty() { - let ast = parsed.syntax(); + if parsed.errors().is_empty() { + let ast = parsed.syntax(); - let mut visitor = SyntaxLintVisitor { - diagnostics, - source: source.text(), - }; - visitor.visit_body(&ast.body); - diagnostics = visitor.diagnostics; - } else { - diagnostics.extend(parsed.errors().iter().map(std::string::ToString::to_string)); - } + let mut visitor = SyntaxLintVisitor { + diagnostics, + source: &source, + }; + visitor.visit_body(&ast.body); + diagnostics = visitor.diagnostics; + } else { + diagnostics.extend(parsed.errors().iter().map(ToString::to_string)); + } - Ok(Diagnostics::from(diagnostics)) - }) + Diagnostics::from(diagnostics) } fn lint_lines(source: &str, diagnostics: &mut Vec) { @@ -75,179 +73,127 @@ fn lint_lines(source: &str, diagnostics: &mut Vec) { } } -#[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn lint_semantic(db: &dyn LintDb, file_id: FileId) -> QueryResult { - let lint_jar: &LintJar = db.jar()?; - let storage = &lint_jar.lint_semantic; +#[salsa::tracked(return_ref)] +pub(crate) fn lint_semantic(db: &dyn Db, file_id: VfsFile) -> 
Diagnostics { + let _span = trace_span!("lint_semantic", ?file_id).entered(); - storage.get(&file_id, |file_id| { - let source = source_text(db.upcast(), *file_id)?; - let parsed = parse(db.upcast(), *file_id)?; - let semantic_index = semantic_index(db.upcast(), *file_id)?; + let source = source_text(db.upcast(), file_id); + let parsed = parsed_module(db.upcast(), file_id); + let semantic = SemanticModel::new(db.upcast(), file_id); - let context = SemanticLintContext { - file_id: *file_id, - source, - parsed: &parsed, - semantic_index, - db, - diagnostics: RefCell::new(Vec::new()), - }; - - lint_unresolved_imports(&context)?; - lint_bad_overrides(&context)?; - - Ok(Diagnostics::from(context.diagnostics.take())) - }) -} - -fn lint_unresolved_imports(context: &SemanticLintContext) -> QueryResult<()> { - // TODO: Consider iterating over the dependencies (imports) only instead of all definitions. - for (symbol, definition) in context.semantic_index().symbol_table().all_definitions() { - match definition { - Definition::Import(import) => { - let ty = context.infer_symbol_public_type(symbol)?; - - if ty.is_unknown() { - context.push_diagnostic(format!("Unresolved module {}", import.module)); - } - } - Definition::ImportFrom(import) => { - let ty = context.infer_symbol_public_type(symbol)?; - - if ty.is_unknown() { - let module_name = import.module().map(Deref::deref).unwrap_or_default(); - let message = if import.level() > 0 { - format!( - "Unresolved relative import '{}' from {}{}", - import.name(), - ".".repeat(import.level() as usize), - module_name - ) - } else { - format!( - "Unresolved import '{}' from '{}'", - import.name(), - module_name - ) - }; - - context.push_diagnostic(message); - } - } - _ => {} - } + if !parsed.is_valid() { + return Diagnostics::Empty; } - Ok(()) -} - -fn lint_bad_overrides(context: &SemanticLintContext) -> QueryResult<()> { - // TODO we should have a special marker on the real typing module (from typeshed) so if you - // have your own 
"typing" module in your project, we don't consider it THE typing module (and - // same for other stdlib modules that our lint rules care about) - let Some(typing_override) = - context.resolve_global_symbol(&ModuleName::new_static("typing").unwrap(), "override")? - else { - // TODO once we bundle typeshed, this should be unreachable!() - return Ok(()); + let context = SemanticLintContext { + source, + parsed, + semantic, + diagnostics: RefCell::new(Vec::new()), }; - // TODO we should maybe index definitions by type instead of iterating all, or else iterate all - // just once, match, and branch to all lint rules that care about a type of definition - for (symbol, definition) in context.semantic_index().symbol_table().all_definitions() { - if !matches!(definition, Definition::FunctionDef(_)) { - continue; + SemanticVisitor { context: &context }.visit_body(parsed.suite()); + + Diagnostics::from(context.diagnostics.take()) +} + +fn lint_unresolved_imports(context: &SemanticLintContext, import: AnyImportRef) { + match import { + AnyImportRef::Import(import) => { + for alias in &import.names { + let ty = alias.ty(&context.semantic); + + if ty.is_unknown() { + context.push_diagnostic(format!("Unresolved import '{}'", &alias.name)); + } + } } - let ty = infer_definition_type( - context.db.upcast(), - GlobalSymbolId { - file_id: context.file_id, - symbol_id: symbol, - }, - definition.clone(), - )?; - let Type::Function(func) = ty else { - unreachable!("type of a FunctionDef should always be a Function"); + AnyImportRef::ImportFrom(import) => { + for alias in &import.names { + let ty = alias.ty(&context.semantic); + + if ty.is_unknown() { + context.push_diagnostic(format!("Unresolved import '{}'", &alias.name)); + } + } + } + } +} + +fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) { + let semantic = &context.semantic; + let typing_context = semantic.typing_context(); + + // TODO we should have a special marker on the real typing module (from 
typeshed) so if you + // have your own "typing" module in your project, we don't consider it THE typing module (and + // same for other stdlib modules that our lint rules care about) + let Some(typing) = semantic.resolve_module(ModuleName::new("typing").unwrap()) else { + return; + }; + + let Some(typing_override) = semantic.public_symbol(&typing, "override") else { + return; + }; + + let override_ty = semantic.public_symbol_ty(typing_override); + + let Type::Class(class_ty) = class.ty(semantic) else { + return; + }; + + for function in class + .body + .iter() + .filter_map(|stmt| stmt.as_function_def_stmt()) + { + let Type::Function(ty) = function.ty(semantic) else { + return; }; - let Some(class) = func.get_containing_class(context.db.upcast())? else { - // not a method of a class - continue; - }; - if func.has_decorator(context.db.upcast(), typing_override)? { - let method_name = func.name(context.db.upcast())?; - if class - .get_super_class_member(context.db.upcast(), &method_name)? 
+ + if ty.has_decorator(&typing_context, override_ty) { + let method_name = ty.name(&typing_context); + if class_ty + .inherited_class_member(&typing_context, method_name) .is_none() { // TODO should have a qualname() method to support nested classes context.push_diagnostic( format!( "Method {}.{} is decorated with `typing.override` but does not override any base class method", - class.name(context.db.upcast())?, + class_ty.name(&typing_context), method_name, )); } } } - Ok(()) } -pub struct SemanticLintContext<'a> { - file_id: FileId, - source: Source, - parsed: &'a Parsed, - semantic_index: Arc, - db: &'a dyn LintDb, +pub(crate) struct SemanticLintContext<'a> { + source: SourceText, + parsed: &'a ParsedModule, + semantic: SemanticModel<'a>, diagnostics: RefCell>, } -impl<'a> SemanticLintContext<'a> { - pub fn source_text(&self) -> &str { - self.source.text() +impl<'db> SemanticLintContext<'db> { + #[allow(unused)] + pub(crate) fn source_text(&self) -> &str { + self.source.as_str() } - pub fn file_id(&self) -> FileId { - self.file_id - } - - pub fn ast(&self) -> &'a ModModule { + #[allow(unused)] + pub(crate) fn ast(&self) -> &'db ast::ModModule { self.parsed.syntax() } - pub fn semantic_index(&self) -> &SemanticIndex { - &self.semantic_index - } - - pub fn infer_symbol_public_type(&self, symbol_id: SymbolId) -> QueryResult { - infer_symbol_public_type( - self.db.upcast(), - GlobalSymbolId { - file_id: self.file_id, - symbol_id, - }, - ) - } - - pub fn push_diagnostic(&self, diagnostic: String) { + pub(crate) fn push_diagnostic(&self, diagnostic: String) { self.diagnostics.borrow_mut().push(diagnostic); } - pub fn extend_diagnostics(&mut self, diagnostics: impl IntoIterator) { + #[allow(unused)] + pub(crate) fn extend_diagnostics(&mut self, diagnostics: impl IntoIterator) { self.diagnostics.get_mut().extend(diagnostics); } - - pub fn resolve_global_symbol( - &self, - module: &ModuleName, - symbol_name: &str, - ) -> QueryResult> { - let Some(module) = 
resolve_module(self.db.upcast(), module)? else { - return Ok(None); - }; - - resolve_global_symbol(self.db.upcast(), module, symbol_name) - } } #[derive(Debug)] @@ -257,7 +203,7 @@ struct SyntaxLintVisitor<'a> { } impl Visitor<'_> for SyntaxLintVisitor<'_> { - fn visit_string_literal(&mut self, string_literal: &'_ StringLiteral) { + fn visit_string_literal(&mut self, string_literal: &'_ ast::StringLiteral) { // A very naive implementation of use double quotes let text = &self.source[string_literal.range]; @@ -268,10 +214,33 @@ impl Visitor<'_> for SyntaxLintVisitor<'_> { } } -#[derive(Debug, Clone)] +struct SemanticVisitor<'a> { + context: &'a SemanticLintContext<'a>, +} + +impl Visitor<'_> for SemanticVisitor<'_> { + fn visit_stmt(&mut self, stmt: &ast::Stmt) { + match stmt { + ast::Stmt::ClassDef(class) => { + lint_bad_override(self.context, class); + } + ast::Stmt::Import(import) => { + lint_unresolved_imports(self.context, AnyImportRef::Import(import)); + } + ast::Stmt::ImportFrom(import) => { + lint_unresolved_imports(self.context, AnyImportRef::ImportFrom(import)); + } + _ => {} + } + + walk_stmt(self, stmt); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] pub enum Diagnostics { Empty, - List(Arc>), + List(Vec), } impl Diagnostics { @@ -295,41 +264,13 @@ impl From> for Diagnostics { if value.is_empty() { Diagnostics::Empty } else { - Diagnostics::List(Arc::new(value)) + Diagnostics::List(value) } } } -#[derive(Default, Debug)] -pub struct LintSyntaxStorage(KeyValueCache); - -impl Deref for LintSyntaxStorage { - type Target = KeyValueCache; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for LintSyntaxStorage { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -#[derive(Default, Debug)] -pub struct LintSemanticStorage(KeyValueCache); - -impl Deref for LintSemanticStorage { - type Target = KeyValueCache; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for LintSemanticStorage { - fn 
deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } +#[derive(Copy, Clone, Debug)] +enum AnyImportRef<'a> { + Import(&'a ast::StmtImport), + ImportFrom(&'a ast::StmtImportFrom), } diff --git a/crates/red_knot/src/main.rs b/crates/red_knot/src/main.rs index da6075f4df..0a34e38dd2 100644 --- a/crates/red_knot/src/main.rs +++ b/crates/red_knot/src/main.rs @@ -1,9 +1,7 @@ -#![allow(clippy::dbg_macro)] - -use std::path::Path; use std::sync::Mutex; use crossbeam::channel as crossbeam_channel; +use salsa::ParallelDatabase; use tracing::subscriber::Interest; use tracing::{Level, Metadata}; use tracing_subscriber::filter::LevelFilter; @@ -11,15 +9,21 @@ use tracing_subscriber::layer::{Context, Filter, SubscriberExt}; use tracing_subscriber::{Layer, Registry}; use tracing_tree::time::Uptime; -use red_knot::db::{HasJar, ParallelDatabase, QueryError, SourceDb, SourceJar}; -use red_knot::module::{set_module_search_paths, ModuleResolutionInputs}; -use red_knot::program::check::ExecutionMode; use red_knot::program::{FileWatcherChange, Program}; use red_knot::watch::FileWatcher; use red_knot::Workspace; +use red_knot_module_resolver::{set_module_resolution_settings, ModuleResolutionSettings}; +use ruff_db::file_system::{FileSystem, FileSystemPath, OsFileSystem}; +use ruff_db::vfs::system_path_to_file; -#[allow(clippy::print_stdout, clippy::unnecessary_wraps, clippy::print_stderr)] -fn main() -> anyhow::Result<()> { +#[allow( + clippy::print_stdout, + clippy::unnecessary_wraps, + clippy::print_stderr, + clippy::dbg_macro +)] +pub fn main() -> anyhow::Result<()> { + countme::enable(true); setup_tracing(); let arguments: Vec<_> = std::env::args().collect(); @@ -29,34 +33,39 @@ fn main() -> anyhow::Result<()> { return Err(anyhow::anyhow!("Invalid arguments")); } - let entry_point = Path::new(&arguments[1]); + let fs = OsFileSystem; + let entry_point = FileSystemPath::new(&arguments[1]); - if !entry_point.exists() { + if !fs.exists(entry_point) { eprintln!("The entry point 
does not exist."); return Err(anyhow::anyhow!("Invalid arguments")); } - if !entry_point.is_file() { + if !fs.is_file(entry_point) { eprintln!("The entry point is not a file."); return Err(anyhow::anyhow!("Invalid arguments")); } + let entry_point = entry_point.to_path_buf(); + let workspace_folder = entry_point.parent().unwrap(); let workspace = Workspace::new(workspace_folder.to_path_buf()); let workspace_search_path = workspace.root().to_path_buf(); - let search_paths = ModuleResolutionInputs { - extra_paths: vec![], - workspace_root: workspace_search_path, - site_packages: None, - custom_typeshed: None, - }; + let mut program = Program::new(workspace, fs); - let mut program = Program::new(workspace); - set_module_search_paths(&mut program, search_paths); + set_module_resolution_settings( + &mut program, + ModuleResolutionSettings { + extra_paths: vec![], + workspace_root: workspace_search_path, + site_packages: None, + custom_typeshed: None, + }, + ); - let entry_id = program.file_id(entry_point); + let entry_id = system_path_to_file(&program, entry_point.clone()).unwrap(); program.workspace_mut().open_file(entry_id); let (main_loop, main_loop_cancellation_token) = MainLoop::new(); @@ -78,14 +87,11 @@ fn main() -> anyhow::Result<()> { file_changes_notifier.notify(changes); })?; - file_watcher.watch_folder(workspace_folder)?; + file_watcher.watch_folder(workspace_folder.as_std_path())?; main_loop.run(&mut program); - let source_jar: &SourceJar = program.jar().unwrap(); - - dbg!(source_jar.parsed.statistics()); - dbg!(source_jar.sources.statistics()); + println!("{}", countme::get_all()); Ok(()) } @@ -127,6 +133,7 @@ impl MainLoop { } } + #[allow(clippy::print_stderr)] fn run(self, program: &mut Program) { self.orchestrator_sender .send(OrchestratorMessage::Run) @@ -142,8 +149,8 @@ impl MainLoop { // Spawn a new task that checks the program. This needs to be done in a separate thread // to prevent blocking the main loop here. 
- rayon::spawn(move || match program.check(ExecutionMode::ThreadPool) { - Ok(result) => { + rayon::spawn(move || { + if let Ok(result) = program.check() { sender .send(OrchestratorMessage::CheckProgramCompleted { diagnostics: result, @@ -151,7 +158,6 @@ impl MainLoop { }) .unwrap(); } - Err(QueryError::Cancelled) => {} }); } MainLoopMessage::ApplyChanges(changes) => { @@ -159,9 +165,11 @@ impl MainLoop { program.apply_changes(changes); } MainLoopMessage::CheckCompleted(diagnostics) => { - dbg!(diagnostics); + eprintln!("{}", diagnostics.join("\n")); + eprintln!("{}", countme::get_all()); } MainLoopMessage::Exit => { + eprintln!("{}", countme::get_all()); return; } } @@ -210,6 +218,7 @@ struct Orchestrator { } impl Orchestrator { + #[allow(clippy::print_stderr)] fn run(&mut self) { while let Ok(message) = self.receiver.recv() { match message { diff --git a/crates/red_knot/src/module.rs b/crates/red_knot/src/module.rs deleted file mode 100644 index 3e7672b899..0000000000 --- a/crates/red_knot/src/module.rs +++ /dev/null @@ -1,1239 +0,0 @@ -use std::fmt::Formatter; -use std::ops::Deref; -use std::path::{Path, PathBuf}; -use std::sync::atomic::AtomicU32; -use std::sync::Arc; - -use dashmap::mapref::entry::Entry; - -use red_knot_module_resolver::{ModuleKind, ModuleName}; - -use crate::db::{QueryResult, SemanticDb, SemanticJar}; -use crate::files::FileId; -use crate::semantic::Dependency; -use crate::FxDashMap; - -/// Representation of a Python module. -/// -/// The inner type wrapped by this struct is a unique identifier for the module -/// that is used by the struct's methods to lazily query information about the module. -#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] -pub struct Module(u32); - -impl Module { - /// Return the absolute name of the module (e.g. 
`foo.bar`) - pub fn name(&self, db: &dyn SemanticDb) -> QueryResult { - let jar: &SemanticJar = db.jar()?; - let modules = &jar.module_resolver; - - Ok(modules.modules.get(self).unwrap().name.clone()) - } - - /// Return the path to the source code that defines this module - pub fn path(&self, db: &dyn SemanticDb) -> QueryResult { - let jar: &SemanticJar = db.jar()?; - let modules = &jar.module_resolver; - - Ok(modules.modules.get(self).unwrap().path.clone()) - } - - /// Determine whether this module is a single-file module or a package - pub fn kind(&self, db: &dyn SemanticDb) -> QueryResult { - let jar: &SemanticJar = db.jar()?; - let modules = &jar.module_resolver; - - Ok(modules.modules.get(self).unwrap().kind) - } - - /// Attempt to resolve a dependency of this module to an absolute [`ModuleName`]. - /// - /// A dependency could be either absolute (e.g. the `foo` dependency implied by `from foo import bar`) - /// or relative to this module (e.g. the `.foo` dependency implied by `from .foo import bar`) - /// - /// - Returns an error if the query failed. - /// - Returns `Ok(None)` if the query succeeded, - /// but the dependency refers to a module that does not exist. - /// - Returns `Ok(Some(ModuleName))` if the query succeeded, - /// and the dependency refers to a module that exists. - pub fn resolve_dependency( - &self, - db: &dyn SemanticDb, - dependency: &Dependency, - ) -> QueryResult> { - let (level, module) = match dependency { - Dependency::Module(module) => return Ok(Some(module.clone())), - Dependency::Relative { level, module } => (*level, module.as_deref()), - }; - - let name = self.name(db)?; - let kind = self.kind(db)?; - - let mut components = name.components().peekable(); - - let start = match kind { - // `.` resolves to the enclosing package - ModuleKind::Module => 0, - // `.` resolves to the current package - ModuleKind::Package => 1, - }; - - // Skip over the relative parts. 
- for _ in start..level.get() { - if components.next_back().is_none() { - return Ok(None); - } - } - - let mut name = String::new(); - - for part in components.chain(module) { - if !name.is_empty() { - name.push('.'); - } - - name.push_str(part); - } - - Ok(ModuleName::new(&name)) - } -} - -/// A search path in which to search modules. -/// Corresponds to a path in [`sys.path`](https://docs.python.org/3/library/sys_path_init.html) at runtime. -/// -/// Cloning a search path is cheap because it's an `Arc`. -#[derive(Clone, PartialEq, Eq)] -pub struct ModuleSearchPath { - inner: Arc, -} - -impl ModuleSearchPath { - pub fn new(path: PathBuf, kind: ModuleSearchPathKind) -> Self { - Self { - inner: Arc::new(ModuleSearchPathInner { path, kind }), - } - } - - /// Determine whether this is a first-party, third-party or standard-library search path - pub fn kind(&self) -> ModuleSearchPathKind { - self.inner.kind - } - - /// Return the location of the search path on the file system - pub fn path(&self) -> &Path { - &self.inner.path - } -} - -impl std::fmt::Debug for ModuleSearchPath { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - self.inner.fmt(f) - } -} - -#[derive(Debug, Eq, PartialEq)] -struct ModuleSearchPathInner { - path: PathBuf, - kind: ModuleSearchPathKind, -} - -/// Enumeration of the different kinds of search paths type checkers are expected to support. -/// -/// N.B. Although we don't implement `Ord` for this enum, they are ordered in terms of the -/// priority that we want to give these modules when resolving them. -/// This is roughly [the order given in the typing spec], but typeshed's stubs -/// for the standard library are moved higher up to match Python's semantics at runtime. 
-/// -/// [the order given in the typing spec]: https://typing.readthedocs.io/en/latest/spec/distributing.html#import-resolution-ordering -#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, is_macro::Is)] -pub enum ModuleSearchPathKind { - /// "Extra" paths provided by the user in a config file, env var or CLI flag. - /// E.g. mypy's `MYPYPATH` env var, or pyright's `stubPath` configuration setting - Extra, - - /// Files in the project we're directly being invoked on - FirstParty, - - /// The `stdlib` directory of typeshed (either vendored or custom) - StandardLibrary, - - /// Stubs or runtime modules installed in site-packages - SitePackagesThirdParty, - - /// Vendored third-party stubs from typeshed - VendoredThirdParty, -} - -#[derive(Debug, Eq, PartialEq)] -pub struct ModuleData { - name: ModuleName, - path: ModulePath, - kind: ModuleKind, -} - -////////////////////////////////////////////////////// -// Queries -////////////////////////////////////////////////////// - -/// Resolves a module name to a module. -/// -/// TODO: This would not work with Salsa because `ModuleName` isn't an ingredient -/// and, therefore, cannot be used as part of a query. -/// For this to work with salsa, it would be necessary to intern all `ModuleName`s. 
-#[tracing::instrument(level = "debug", skip(db))] -pub fn resolve_module(db: &dyn SemanticDb, name: &ModuleName) -> QueryResult> { - let jar: &SemanticJar = db.jar()?; - let modules = &jar.module_resolver; - - let entry = modules.by_name.entry(name.clone()); - - match entry { - Entry::Occupied(entry) => Ok(Some(*entry.get())), - Entry::Vacant(entry) => { - let Some((root_path, absolute_path, kind)) = resolve_name(name, &modules.search_paths) - else { - return Ok(None); - }; - let Ok(normalized) = absolute_path.canonicalize() else { - return Ok(None); - }; - - let file_id = db.file_id(&normalized); - let path = ModulePath::new(root_path.clone(), file_id); - - let module = Module( - modules - .next_module_id - .fetch_add(1, std::sync::atomic::Ordering::Relaxed), - ); - - modules.modules.insert( - module, - Arc::from(ModuleData { - name: name.clone(), - path, - kind, - }), - ); - - // A path can map to multiple modules because of symlinks: - // ``` - // foo.py - // bar.py -> foo.py - // ``` - // Here, both `foo` and `bar` resolve to the same module but through different paths. - // That's why we need to insert the absolute path and not the normalized path here. - let absolute_file_id = if absolute_path == normalized { - file_id - } else { - db.file_id(&absolute_path) - }; - - modules.by_file.insert(absolute_file_id, module); - - entry.insert_entry(module); - - Ok(Some(module)) - } - } -} - -/// Resolves the module for the given path. -/// -/// Returns `None` if the path is not a module locatable via `sys.path`. -#[tracing::instrument(level = "debug", skip(db))] -pub fn path_to_module(db: &dyn SemanticDb, path: &Path) -> QueryResult> { - let file = db.file_id(path); - file_to_module(db, file) -} - -/// Resolves the module for the file with the given id. -/// -/// Returns `None` if the file is not a module locatable via `sys.path`. 
-#[tracing::instrument(level = "debug", skip(db))] -pub fn file_to_module(db: &dyn SemanticDb, file: FileId) -> QueryResult> { - let jar: &SemanticJar = db.jar()?; - let modules = &jar.module_resolver; - - if let Some(existing) = modules.by_file.get(&file) { - return Ok(Some(*existing)); - } - - let path = db.file_path(file); - - debug_assert!(path.is_absolute()); - - let Some((root_path, relative_path)) = modules.search_paths.iter().find_map(|root| { - let relative_path = path.strip_prefix(root.path()).ok()?; - Some((root.clone(), relative_path)) - }) else { - return Ok(None); - }; - - let Some(module_name) = from_relative_path(relative_path) else { - return Ok(None); - }; - - // Resolve the module name to see if Python would resolve the name to the same path. - // If it doesn't, then that means that multiple modules have the same in different - // root paths, but that the module corresponding to the past path is in a lower priority search path, - // in which case we ignore it. - let Some(module) = resolve_module(db, &module_name)? else { - return Ok(None); - }; - let module_path = module.path(db)?; - - if module_path.root() == &root_path { - let Ok(normalized) = path.canonicalize() else { - return Ok(None); - }; - let interned_normalized = db.file_id(&normalized); - - if interned_normalized != module_path.file() { - // This path is for a module with the same name but with a different precedence. For example: - // ``` - // src/foo.py - // src/foo/__init__.py - // ``` - // The module name of `src/foo.py` is `foo`, but the module loaded by Python is `src/foo/__init__.py`. - // That means we need to ignore `src/foo.py` even though it resolves to the same module name. - return Ok(None); - } - - // Path has been inserted by `resolved` - Ok(Some(module)) - } else { - // This path is for a module with the same name but in a module search path with a lower priority. - // Ignore it. 
- Ok(None) - } -} - -fn from_relative_path(path: &Path) -> Option { - let path = if path.ends_with("__init__.py") || path.ends_with("__init__.pyi") { - path.parent()? - } else { - path - }; - - let name = if let Some(parent) = path.parent() { - let mut name = String::with_capacity(path.to_str().unwrap().len()); - - for component in parent.components() { - name.push_str(component.as_os_str().to_str()?); - name.push('.'); - } - - // SAFETY: Unwrap is safe here or `parent` would have returned `None`. - name.push_str(path.file_stem().unwrap().to_str().unwrap()); - - name - } else { - path.file_stem()?.to_str().unwrap().to_string() - }; - - ModuleName::new(&name) -} - -////////////////////////////////////////////////////// -// Mutations -////////////////////////////////////////////////////// - -/// Changes the module search paths to `search_paths`. -pub fn set_module_search_paths(db: &mut dyn SemanticDb, search_paths: ModuleResolutionInputs) { - let jar: &mut SemanticJar = db.jar_mut(); - - jar.module_resolver = ModuleResolver::new(search_paths.into_ordered_search_paths()); -} - -/// Struct for holding the various paths that are put together -/// to create an `OrderedSearchPatsh` instance -/// -/// - `extra_paths` is a list of user-provided paths -/// that should take first priority in the module resolution. -/// Examples in other type checkers are mypy's MYPYPATH environment variable, -/// or pyright's stubPath configuration setting. -/// - `workspace_root` is the root of the workspace, -/// used for finding first-party modules -/// - `site-packages` is the path to the user's `site-packages` directory, -/// where third-party packages from ``PyPI`` are installed -/// - `custom_typeshed` is a path to standard-library typeshed stubs. -/// Currently this has to be a directory that exists on disk. -/// (TODO: fall back to vendored stubs if no custom directory is provided.) 
-#[derive(Debug)] -pub struct ModuleResolutionInputs { - pub extra_paths: Vec, - pub workspace_root: PathBuf, - pub site_packages: Option, - pub custom_typeshed: Option, -} - -impl ModuleResolutionInputs { - /// Implementation of PEP 561's module resolution order - /// (with some small, deliberate, differences) - fn into_ordered_search_paths(self) -> OrderedSearchPaths { - let ModuleResolutionInputs { - extra_paths, - workspace_root, - site_packages, - custom_typeshed, - } = self; - - OrderedSearchPaths( - extra_paths - .into_iter() - .map(|path| ModuleSearchPath::new(path, ModuleSearchPathKind::Extra)) - .chain(std::iter::once(ModuleSearchPath::new( - workspace_root, - ModuleSearchPathKind::FirstParty, - ))) - // TODO fallback to vendored typeshed stubs if no custom typeshed directory is provided by the user - .chain(custom_typeshed.into_iter().map(|path| { - ModuleSearchPath::new( - path.join(TYPESHED_STDLIB_DIRECTORY), - ModuleSearchPathKind::StandardLibrary, - ) - })) - .chain(site_packages.into_iter().map(|path| { - ModuleSearchPath::new(path, ModuleSearchPathKind::SitePackagesThirdParty) - })) - // TODO vendor typeshed's third-party stubs as well as the stdlib and fallback to them as a final step - .collect(), - ) - } -} - -const TYPESHED_STDLIB_DIRECTORY: &str = "stdlib"; - -/// A resolved module resolution order, implementing PEP 561 -/// (with some small, deliberate differences) -#[derive(Clone, Debug, Default, Eq, PartialEq)] -struct OrderedSearchPaths(Vec); - -impl Deref for OrderedSearchPaths { - type Target = [ModuleSearchPath]; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// Adds a module located at `path` to the resolver. -/// -/// Returns `None` if the path doesn't resolve to a module. -/// -/// Returns `Some(module, other_modules)`, where `module` is the resolved module -/// with file location `path`, and `other_modules` is a `Vec` of `ModuleData` instances. 
-/// Each element in `other_modules` provides information regarding a single module that needs -/// re-resolving because it was part of a namespace package and might now resolve differently. -/// -/// Note: This won't work with salsa because `Path` is not an ingredient. -pub fn add_module(db: &mut dyn SemanticDb, path: &Path) -> Option<(Module, Vec>)> { - // No locking is required because we're holding a mutable reference to `modules`. - - // TODO This needs tests - - // Note: Intentionally bypass caching here. Module should not be in the cache yet. - let module = path_to_module(db, path).ok()??; - - // The code below is to handle the addition of `__init__.py` files. - // When an `__init__.py` file is added, we need to remove all modules that are part of the same package. - // For example, an `__init__.py` is added to `foo`, we need to remove `foo.bar`, `foo.baz`, etc. - // because they were namespace packages before and could have been from different search paths. - let Some(filename) = path.file_name() else { - return Some((module, Vec::new())); - }; - - if !matches!(filename.to_str(), Some("__init__.py" | "__init__.pyi")) { - return Some((module, Vec::new())); - } - - let Some(parent_name) = module.name(db).ok()?.parent() else { - return Some((module, Vec::new())); - }; - - let mut to_remove = Vec::new(); - - let jar: &mut SemanticJar = db.jar_mut(); - let modules = &mut jar.module_resolver; - - modules.by_file.retain(|_, module| { - if modules - .modules - .get(module) - .unwrap() - .name - .starts_with(&parent_name) - { - to_remove.push(*module); - false - } else { - true - } - }); - - // TODO remove need for this vec - let mut removed = Vec::with_capacity(to_remove.len()); - for module in &to_remove { - removed.push(modules.remove_module(*module)); - } - - Some((module, removed)) -} - -#[derive(Default)] -pub struct ModuleResolver { - /// The search paths where modules are located (and searched). Corresponds to `sys.path` at runtime. 
- search_paths: OrderedSearchPaths, - - // Locking: Locking is done by acquiring a (write) lock on `by_name`. This is because `by_name` is the primary - // lookup method. Acquiring locks in any other ordering can result in deadlocks. - /// Looks up a module by name - by_name: FxDashMap, - - /// A map of all known modules to data about those modules - modules: FxDashMap>, - - /// Lookup from absolute path to module. - /// The same module might be reachable from different paths when symlinks are involved. - by_file: FxDashMap, - next_module_id: AtomicU32, -} - -impl ModuleResolver { - fn new(search_paths: OrderedSearchPaths) -> Self { - Self { - search_paths, - modules: FxDashMap::default(), - by_name: FxDashMap::default(), - by_file: FxDashMap::default(), - next_module_id: AtomicU32::new(0), - } - } - - /// Remove a module from the inner cache - pub(crate) fn remove_module_by_file(&mut self, file_id: FileId) { - // No locking is required because we're holding a mutable reference to `self`. - let Some((_, module)) = self.by_file.remove(&file_id) else { - return; - }; - - self.remove_module(module); - } - - fn remove_module(&mut self, module: Module) -> Arc { - let (_, module_data) = self.modules.remove(&module).unwrap(); - - self.by_name.remove(&module_data.name).unwrap(); - - // It's possible that multiple paths map to the same module. - // Search all other paths referencing the same module. - self.by_file - .retain(|_, current_module| *current_module != module); - - module_data - } -} - -#[allow(clippy::missing_fields_in_debug)] -impl std::fmt::Debug for ModuleResolver { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ModuleResolver") - .field("search_paths", &self.search_paths) - .field("modules", &self.by_name) - .finish() - } -} - -/// The resolved path of a module. 
-/// -/// It should be highly likely that the file still exists when accessing but it isn't 100% guaranteed -/// because the file could have been deleted between resolving the module name and accessing it. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ModulePath { - root: ModuleSearchPath, - file_id: FileId, -} - -impl ModulePath { - pub fn new(root: ModuleSearchPath, file_id: FileId) -> Self { - Self { root, file_id } - } - - /// The search path that was used to locate the module - pub fn root(&self) -> &ModuleSearchPath { - &self.root - } - - /// The file containing the source code for the module - pub fn file(&self) -> FileId { - self.file_id - } -} - -/// Given a module name and a list of search paths in which to lookup modules, -/// attempt to resolve the module name -fn resolve_name( - name: &ModuleName, - search_paths: &[ModuleSearchPath], -) -> Option<(ModuleSearchPath, PathBuf, ModuleKind)> { - for search_path in search_paths { - let mut components = name.components(); - let module_name = components.next_back()?; - - match resolve_package(search_path, components) { - Ok(resolved_package) => { - let mut package_path = resolved_package.path; - - package_path.push(module_name); - - // Must be a `__init__.pyi` or `__init__.py` or it isn't a package. - let kind = if package_path.is_dir() { - package_path.push("__init__"); - ModuleKind::Package - } else { - ModuleKind::Module - }; - - // TODO Implement full https://peps.python.org/pep-0561/#type-checker-module-resolution-order resolution - let stub = package_path.with_extension("pyi"); - - if stub.is_file() { - return Some((search_path.clone(), stub, kind)); - } - - let module = package_path.with_extension("py"); - - if module.is_file() { - return Some((search_path.clone(), module, kind)); - } - - // For regular packages, don't search the next search path. 
All files of that - // package must be in the same location - if resolved_package.kind.is_regular_package() { - return None; - } - } - Err(parent_kind) => { - if parent_kind.is_regular_package() { - // For regular packages, don't search the next search path. - return None; - } - } - } - } - - None -} - -fn resolve_package<'a, I>( - module_search_path: &ModuleSearchPath, - components: I, -) -> Result -where - I: Iterator, -{ - let mut package_path = module_search_path.path().to_path_buf(); - - // `true` if inside a folder that is a namespace package (has no `__init__.py`). - // Namespace packages are special because they can be spread across multiple search paths. - // https://peps.python.org/pep-0420/ - let mut in_namespace_package = false; - - // `true` if resolving a sub-package. For example, `true` when resolving `bar` of `foo.bar`. - let mut in_sub_package = false; - - // For `foo.bar.baz`, test that `foo` and `baz` both contain a `__init__.py`. - for folder in components { - package_path.push(folder); - - let has_init_py = package_path.join("__init__.py").is_file() - || package_path.join("__init__.pyi").is_file(); - - if has_init_py { - in_namespace_package = false; - } else if package_path.is_dir() { - // A directory without an `__init__.py` is a namespace package, continue with the next folder. - in_namespace_package = true; - } else if in_namespace_package { - // Package not found but it is part of a namespace package. - return Err(PackageKind::Namespace); - } else if in_sub_package { - // A regular sub package wasn't found. - return Err(PackageKind::Regular); - } else { - // We couldn't find `foo` for `foo.bar.baz`, search the next search path. 
- return Err(PackageKind::Root); - } - - in_sub_package = true; - } - - let kind = if in_namespace_package { - PackageKind::Namespace - } else if in_sub_package { - PackageKind::Regular - } else { - PackageKind::Root - }; - - Ok(ResolvedPackage { - kind, - path: package_path, - }) -} - -#[derive(Debug)] -struct ResolvedPackage { - path: PathBuf, - kind: PackageKind, -} - -#[derive(Copy, Clone, Eq, PartialEq, Debug)] -enum PackageKind { - /// A root package or module. E.g. `foo` in `foo.bar.baz` or just `foo`. - Root, - - /// A regular sub-package where the parent contains an `__init__.py`. - /// - /// For example, `bar` in `foo.bar` when the `foo` directory contains an `__init__.py`. - Regular, - - /// A sub-package in a namespace package. A namespace package is a package without an `__init__.py`. - /// - /// For example, `bar` in `foo.bar` if the `foo` directory contains no `__init__.py`. - Namespace, -} - -impl PackageKind { - const fn is_regular_package(self) -> bool { - matches!(self, PackageKind::Regular) - } -} - -#[cfg(test)] -mod tests { - use red_knot_module_resolver::ModuleName; - use std::num::NonZeroU32; - use std::path::PathBuf; - - use crate::db::tests::TestDb; - use crate::db::SourceDb; - use crate::module::{ - path_to_module, resolve_module, set_module_search_paths, ModuleKind, - ModuleResolutionInputs, TYPESHED_STDLIB_DIRECTORY, - }; - use crate::semantic::Dependency; - - struct TestCase { - temp_dir: tempfile::TempDir, - db: TestDb, - - src: PathBuf, - custom_typeshed: PathBuf, - site_packages: PathBuf, - } - - fn create_resolver() -> std::io::Result { - let temp_dir = tempfile::tempdir()?; - - let src = temp_dir.path().join("src"); - let site_packages = temp_dir.path().join("site_packages"); - let custom_typeshed = temp_dir.path().join("typeshed"); - - std::fs::create_dir(&src)?; - std::fs::create_dir(&site_packages)?; - std::fs::create_dir(&custom_typeshed)?; - - let src = src.canonicalize()?; - let site_packages = site_packages.canonicalize()?; 
- let custom_typeshed = custom_typeshed.canonicalize()?; - - let search_paths = ModuleResolutionInputs { - extra_paths: vec![], - workspace_root: src.clone(), - site_packages: Some(site_packages.clone()), - custom_typeshed: Some(custom_typeshed.clone()), - }; - - let mut db = TestDb::default(); - set_module_search_paths(&mut db, search_paths); - - Ok(TestCase { - temp_dir, - db, - src, - custom_typeshed, - site_packages, - }) - } - - #[test] - fn first_party_module() -> anyhow::Result<()> { - let TestCase { - db, - src, - temp_dir: _temp_dir, - .. - } = create_resolver()?; - - let foo_path = src.join("foo.py"); - std::fs::write(&foo_path, "print('Hello, world!')")?; - - let foo_name = ModuleName::new_static("foo").unwrap(); - let foo_module = resolve_module(&db, &foo_name)?.unwrap(); - - assert_eq!(Some(foo_module), resolve_module(&db, &foo_name)?); - - assert_eq!(foo_name, foo_module.name(&db)?); - assert_eq!(&src, foo_module.path(&db)?.root().path()); - assert_eq!(ModuleKind::Module, foo_module.kind(&db)?); - assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file())); - - assert_eq!(Some(foo_module), path_to_module(&db, &foo_path)?); - - Ok(()) - } - - #[test] - fn stdlib() -> anyhow::Result<()> { - let TestCase { - db, - custom_typeshed, - .. - } = create_resolver()?; - let stdlib_dir = custom_typeshed.join(TYPESHED_STDLIB_DIRECTORY); - std::fs::create_dir_all(&stdlib_dir).unwrap(); - let functools_name = ModuleName::new_static("functools").unwrap(); - let functools_path = stdlib_dir.join("functools.py"); - std::fs::write(&functools_path, "def update_wrapper(): ...").unwrap(); - let functools_module = resolve_module(&db, &functools_name)?.unwrap(); - - assert_eq!( - Some(functools_module), - resolve_module(&db, &functools_name)? 
- ); - assert_eq!(&stdlib_dir, functools_module.path(&db)?.root().path()); - assert_eq!(ModuleKind::Module, functools_module.kind(&db)?); - assert_eq!( - &functools_path, - &*db.file_path(functools_module.path(&db)?.file()) - ); - - assert_eq!( - Some(functools_module), - path_to_module(&db, &functools_path)? - ); - - Ok(()) - } - - #[test] - fn first_party_precedence_over_stdlib() -> anyhow::Result<()> { - let TestCase { - db, - src, - custom_typeshed, - .. - } = create_resolver()?; - - let stdlib_dir = custom_typeshed.join(TYPESHED_STDLIB_DIRECTORY); - std::fs::create_dir_all(&stdlib_dir).unwrap(); - std::fs::create_dir_all(&src).unwrap(); - - let stdlib_functools_path = stdlib_dir.join("functools.py"); - let first_party_functools_path = src.join("functools.py"); - std::fs::write(stdlib_functools_path, "def update_wrapper(): ...").unwrap(); - std::fs::write(&first_party_functools_path, "def update_wrapper(): ...").unwrap(); - let functools_name = ModuleName::new_static("functools").unwrap(); - let functools_module = resolve_module(&db, &functools_name)?.unwrap(); - - assert_eq!( - Some(functools_module), - resolve_module(&db, &functools_name)? - ); - assert_eq!(&src, functools_module.path(&db).unwrap().root().path()); - assert_eq!(ModuleKind::Module, functools_module.kind(&db)?); - assert_eq!( - &first_party_functools_path, - &*db.file_path(functools_module.path(&db)?.file()) - ); - - assert_eq!( - Some(functools_module), - path_to_module(&db, &first_party_functools_path)? - ); - - Ok(()) - } - - #[test] - fn resolve_package() -> anyhow::Result<()> { - let TestCase { - src, - db, - temp_dir: _temp_dir, - .. 
- } = create_resolver()?; - - let foo_name = ModuleName::new("foo").unwrap(); - let foo_dir = src.join("foo"); - let foo_path = foo_dir.join("__init__.py"); - std::fs::create_dir(&foo_dir)?; - std::fs::write(&foo_path, "print('Hello, world!')")?; - - let foo_module = resolve_module(&db, &foo_name)?.unwrap(); - - assert_eq!(foo_name, foo_module.name(&db)?); - assert_eq!(&src, foo_module.path(&db)?.root().path()); - assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file())); - - assert_eq!(Some(foo_module), path_to_module(&db, &foo_path)?); - - // Resolving by directory doesn't resolve to the init file. - assert_eq!(None, path_to_module(&db, &foo_dir)?); - - Ok(()) - } - - #[test] - fn package_priority_over_module() -> anyhow::Result<()> { - let TestCase { - db, - temp_dir: _temp_dir, - src, - .. - } = create_resolver()?; - - let foo_dir = src.join("foo"); - let foo_init = foo_dir.join("__init__.py"); - std::fs::create_dir(&foo_dir)?; - std::fs::write(&foo_init, "print('Hello, world!')")?; - - let foo_py = src.join("foo.py"); - std::fs::write(&foo_py, "print('Hello, world!')")?; - - let foo_module = resolve_module(&db, &ModuleName::new("foo").unwrap())?.unwrap(); - - assert_eq!(&src, foo_module.path(&db)?.root().path()); - assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db)?.file())); - assert_eq!(ModuleKind::Package, foo_module.kind(&db)?); - - assert_eq!(Some(foo_module), path_to_module(&db, &foo_init)?); - assert_eq!(None, path_to_module(&db, &foo_py)?); - - Ok(()) - } - - #[test] - fn typing_stub_over_module() -> anyhow::Result<()> { - let TestCase { - db, - src, - temp_dir: _temp_dir, - .. 
- } = create_resolver()?; - - let foo_stub = src.join("foo.pyi"); - let foo_py = src.join("foo.py"); - std::fs::write(&foo_stub, "x: int")?; - std::fs::write(&foo_py, "print('Hello, world!')")?; - - let foo = resolve_module(&db, &ModuleName::new("foo").unwrap())?.unwrap(); - - assert_eq!(&src, foo.path(&db)?.root().path()); - assert_eq!(&foo_stub, &*db.file_path(foo.path(&db)?.file())); - - assert_eq!(Some(foo), path_to_module(&db, &foo_stub)?); - assert_eq!(None, path_to_module(&db, &foo_py)?); - - Ok(()) - } - - #[test] - fn sub_packages() -> anyhow::Result<()> { - let TestCase { - db, - src, - temp_dir: _temp_dir, - .. - } = create_resolver()?; - - let foo = src.join("foo"); - let bar = foo.join("bar"); - let baz = bar.join("baz.py"); - - std::fs::create_dir_all(&bar)?; - std::fs::write(foo.join("__init__.py"), "")?; - std::fs::write(bar.join("__init__.py"), "")?; - std::fs::write(&baz, "print('Hello, world!')")?; - - let baz_module = resolve_module(&db, &ModuleName::new("foo.bar.baz").unwrap())?.unwrap(); - - assert_eq!(&src, baz_module.path(&db)?.root().path()); - assert_eq!(&baz, &*db.file_path(baz_module.path(&db)?.file())); - - assert_eq!(Some(baz_module), path_to_module(&db, &baz)?); - - Ok(()) - } - - #[test] - fn namespace_package() -> anyhow::Result<()> { - let TestCase { - db, - temp_dir: _, - src, - site_packages, - .. - } = create_resolver()?; - - // From [PEP420](https://peps.python.org/pep-0420/#nested-namespace-packages). - // But uses `src` for `project1` and `site_packages2` for `project2`. 
- // ``` - // src - // parent - // child - // one.py - // site_packages - // parent - // child - // two.py - // ``` - - let parent1 = src.join("parent"); - let child1 = parent1.join("child"); - let one = child1.join("one.py"); - - std::fs::create_dir_all(child1)?; - std::fs::write(&one, "print('Hello, world!')")?; - - let parent2 = site_packages.join("parent"); - let child2 = parent2.join("child"); - let two = child2.join("two.py"); - - std::fs::create_dir_all(&child2)?; - std::fs::write(&two, "print('Hello, world!')")?; - - let one_module = - resolve_module(&db, &ModuleName::new("parent.child.one").unwrap())?.unwrap(); - - assert_eq!(Some(one_module), path_to_module(&db, &one)?); - - let two_module = - resolve_module(&db, &ModuleName::new("parent.child.two").unwrap())?.unwrap(); - assert_eq!(Some(two_module), path_to_module(&db, &two)?); - - Ok(()) - } - - #[test] - fn regular_package_in_namespace_package() -> anyhow::Result<()> { - let TestCase { - db, - temp_dir: _, - src, - site_packages, - .. - } = create_resolver()?; - - // Adopted test case from the [PEP420 examples](https://peps.python.org/pep-0420/#nested-namespace-packages). - // The `src/parent/child` package is a regular package. Therefore, `site_packages/parent/child/two.py` should not be resolved. 
- // ``` - // src - // parent - // child - // one.py - // site_packages - // parent - // child - // two.py - // ``` - - let parent1 = src.join("parent"); - let child1 = parent1.join("child"); - let one = child1.join("one.py"); - - std::fs::create_dir_all(&child1)?; - std::fs::write(child1.join("__init__.py"), "print('Hello, world!')")?; - std::fs::write(&one, "print('Hello, world!')")?; - - let parent2 = site_packages.join("parent"); - let child2 = parent2.join("child"); - let two = child2.join("two.py"); - - std::fs::create_dir_all(&child2)?; - std::fs::write(two, "print('Hello, world!')")?; - - let one_module = - resolve_module(&db, &ModuleName::new("parent.child.one").unwrap())?.unwrap(); - - assert_eq!(Some(one_module), path_to_module(&db, &one)?); - - assert_eq!( - None, - resolve_module(&db, &ModuleName::new("parent.child.two").unwrap())? - ); - Ok(()) - } - - #[test] - fn module_search_path_priority() -> anyhow::Result<()> { - let TestCase { - db, - src, - site_packages, - temp_dir: _temp_dir, - .. - } = create_resolver()?; - - let foo_src = src.join("foo.py"); - let foo_site_packages = site_packages.join("foo.py"); - - std::fs::write(&foo_src, "")?; - std::fs::write(&foo_site_packages, "")?; - - let foo_module = resolve_module(&db, &ModuleName::new("foo").unwrap())?.unwrap(); - - assert_eq!(&src, foo_module.path(&db)?.root().path()); - assert_eq!(&foo_src, &*db.file_path(foo_module.path(&db)?.file())); - - assert_eq!(Some(foo_module), path_to_module(&db, &foo_src)?); - assert_eq!(None, path_to_module(&db, &foo_site_packages)?); - - Ok(()) - } - - #[test] - #[cfg(target_family = "unix")] - fn symlink() -> anyhow::Result<()> { - let TestCase { - db, - src, - temp_dir: _temp_dir, - .. 
- } = create_resolver()?; - - let foo = src.join("foo.py"); - let bar = src.join("bar.py"); - - std::fs::write(&foo, "")?; - std::os::unix::fs::symlink(&foo, &bar)?; - - let foo_module = resolve_module(&db, &ModuleName::new("foo").unwrap())?.unwrap(); - let bar_module = resolve_module(&db, &ModuleName::new("bar").unwrap())?.unwrap(); - - assert_ne!(foo_module, bar_module); - - assert_eq!(&src, foo_module.path(&db)?.root().path()); - assert_eq!(&foo, &*db.file_path(foo_module.path(&db)?.file())); - - // Bar has a different name but it should point to the same file. - - assert_eq!(&src, bar_module.path(&db)?.root().path()); - assert_eq!(foo_module.path(&db)?.file(), bar_module.path(&db)?.file()); - assert_eq!(&foo, &*db.file_path(bar_module.path(&db)?.file())); - - assert_eq!(Some(foo_module), path_to_module(&db, &foo)?); - assert_eq!(Some(bar_module), path_to_module(&db, &bar)?); - - Ok(()) - } - - #[test] - fn resolve_dependency() -> anyhow::Result<()> { - let TestCase { - src, - db, - temp_dir: _temp_dir, - .. - } = create_resolver()?; - - let foo_dir = src.join("foo"); - let foo_path = foo_dir.join("__init__.py"); - let bar_path = foo_dir.join("bar.py"); - - std::fs::create_dir(&foo_dir)?; - std::fs::write(foo_path, "from .bar import test")?; - std::fs::write(bar_path, "test = 'Hello world'")?; - - let foo_module = resolve_module(&db, &ModuleName::new("foo").unwrap())?.unwrap(); - let bar_module = resolve_module(&db, &ModuleName::new("foo.bar").unwrap())?.unwrap(); - - // `from . import bar` in `foo/__init__.py` resolves to `foo` - assert_eq!( - ModuleName::new("foo"), - foo_module.resolve_dependency( - &db, - &Dependency::Relative { - level: NonZeroU32::new(1).unwrap(), - module: None, - } - )? - ); - - // `from baz import bar` in `foo/__init__.py` should resolve to `baz.py` - assert_eq!( - ModuleName::new("baz"), - foo_module - .resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz").unwrap()))? 
- ); - - // from .bar import test in `foo/__init__.py` should resolve to `foo/bar.py` - assert_eq!( - ModuleName::new("foo.bar"), - foo_module.resolve_dependency( - &db, - &Dependency::Relative { - level: NonZeroU32::new(1).unwrap(), - module: ModuleName::new("bar") - } - )? - ); - - // from .. import test in `foo/__init__.py` resolves to `` which is not a module - assert_eq!( - None, - foo_module.resolve_dependency( - &db, - &Dependency::Relative { - level: NonZeroU32::new(2).unwrap(), - module: None - } - )? - ); - - // `from . import test` in `foo/bar.py` resolves to `foo` - assert_eq!( - ModuleName::new("foo"), - bar_module.resolve_dependency( - &db, - &Dependency::Relative { - level: NonZeroU32::new(1).unwrap(), - module: None - } - )? - ); - - // `from baz import test` in `foo/bar.py` resolves to `baz` - assert_eq!( - ModuleName::new("baz"), - bar_module - .resolve_dependency(&db, &Dependency::Module(ModuleName::new("baz").unwrap()))? - ); - - // `from .baz import test` in `foo/bar.py` resolves to `foo.baz`. - assert_eq!( - ModuleName::new("foo.baz"), - bar_module.resolve_dependency( - &db, - &Dependency::Relative { - level: NonZeroU32::new(1).unwrap(), - module: ModuleName::new("baz") - } - )? 
- ); - - Ok(()) - } -} diff --git a/crates/red_knot/src/parse.rs b/crates/red_knot/src/parse.rs deleted file mode 100644 index 393625b3ae..0000000000 --- a/crates/red_knot/src/parse.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::ops::{Deref, DerefMut}; -use std::sync::Arc; - -use ruff_python_ast::ModModule; -use ruff_python_parser::Parsed; - -use crate::cache::KeyValueCache; -use crate::db::{QueryResult, SourceDb}; -use crate::files::FileId; -use crate::source::source_text; - -#[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn parse(db: &dyn SourceDb, file_id: FileId) -> QueryResult>> { - let jar = db.jar()?; - - jar.parsed.get(&file_id, |file_id| { - let source = source_text(db, *file_id)?; - - Ok(Arc::new(ruff_python_parser::parse_unchecked_source( - source.text(), - source.kind().into(), - ))) - }) -} - -#[derive(Debug, Default)] -pub struct ParsedStorage(KeyValueCache>>); - -impl Deref for ParsedStorage { - type Target = KeyValueCache>>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for ParsedStorage { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} diff --git a/crates/red_knot/src/program/check.rs b/crates/red_knot/src/program/check.rs index 872b52e9f7..22633ad9a3 100644 --- a/crates/red_knot/src/program/check.rs +++ b/crates/red_knot/src/program/check.rs @@ -1,413 +1,28 @@ -use rayon::{current_num_threads, yield_local}; -use rustc_hash::FxHashSet; +use ruff_db::vfs::VfsFile; +use salsa::Cancelled; -use crate::db::{Database, QueryError, QueryResult}; -use crate::files::FileId; use crate::lint::{lint_semantic, lint_syntax, Diagnostics}; -use crate::module::{file_to_module, resolve_module}; use crate::program::Program; -use crate::semantic::{semantic_index, Dependency}; impl Program { /// Checks all open files in the workspace and its dependencies. 
#[tracing::instrument(level = "debug", skip_all)] - pub fn check(&self, mode: ExecutionMode) -> QueryResult> { - self.cancelled()?; + pub fn check(&self) -> Result, Cancelled> { + self.with_db(|db| { + let mut result = Vec::new(); + for open_file in db.workspace.open_files() { + result.extend_from_slice(&db.check_file(open_file)); + } - let mut context = CheckContext::new(self); - - match mode { - ExecutionMode::SingleThreaded => SingleThreadedExecutor.run(&mut context)?, - ExecutionMode::ThreadPool => ThreadPoolExecutor.run(&mut context)?, - }; - - Ok(context.finish()) + result + }) } - #[tracing::instrument(level = "debug", skip(self, context))] - fn check_file(&self, file: FileId, context: &CheckFileContext) -> QueryResult { - self.cancelled()?; - - let index = semantic_index(self, file)?; - let dependencies = index.symbol_table().dependencies(); - - if !dependencies.is_empty() { - let module = file_to_module(self, file)?; - - // TODO scheduling all dependencies here is wasteful if we don't infer any types on them - // but I think that's unlikely, so it is okay? - // Anyway, we need to figure out a way to retrieve the dependencies of a module - // from the persistent cache. So maybe it should be a separate query after all. - for dependency in dependencies { - let dependency_name = match dependency { - Dependency::Module(name) => Some(name.clone()), - Dependency::Relative { .. } => match &module { - Some(module) => module.resolve_dependency(self, dependency)?, - None => None, - }, - }; - - if let Some(dependency_name) = dependency_name { - // TODO We may want to have a different check functions for non-first-party - // files because we only need to index them and not check them. - // Supporting non-first-party code also requires supporting typing stubs. - if let Some(dependency) = resolve_module(self, &dependency_name)? 
{ - if dependency.path(self)?.root().kind().is_first_party() { - context.schedule_dependency(dependency.path(self)?.file()); - } - } - } - } - } - + #[tracing::instrument(level = "debug", skip(self))] + fn check_file(&self, file: VfsFile) -> Diagnostics { let mut diagnostics = Vec::new(); - - if self.workspace().is_file_open(file) { - diagnostics.extend_from_slice(&lint_syntax(self, file)?); - diagnostics.extend_from_slice(&lint_semantic(self, file)?); - } - - Ok(Diagnostics::from(diagnostics)) + diagnostics.extend_from_slice(lint_syntax(self, file)); + diagnostics.extend_from_slice(lint_semantic(self, file)); + Diagnostics::from(diagnostics) } } - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ExecutionMode { - SingleThreaded, - ThreadPool, -} - -/// Context that stores state information about the entire check operation. -struct CheckContext<'a> { - /// IDs of the files that have been queued for checking. - /// - /// Used to avoid queuing the same file twice. - scheduled_files: FxHashSet, - - /// Reference to the program that is checked. - program: &'a Program, - - /// The aggregated diagnostics - diagnostics: Vec, -} - -impl<'a> CheckContext<'a> { - fn new(program: &'a Program) -> Self { - Self { - scheduled_files: FxHashSet::default(), - program, - diagnostics: Vec::new(), - } - } - - /// Returns the tasks to check all open files in the workspace. - fn check_open_files(&mut self) -> Vec { - self.scheduled_files - .extend(self.program.workspace().open_files()); - - self.program - .workspace() - .open_files() - .map(|file_id| CheckOpenFileTask { file_id }) - .collect() - } - - /// Returns the task to check a dependency. 
- fn check_dependency(&mut self, file_id: FileId) -> Option { - if self.scheduled_files.insert(file_id) { - Some(CheckDependencyTask { file_id }) - } else { - None - } - } - - /// Pushes the result for a single file check operation - fn push_diagnostics(&mut self, diagnostics: &Diagnostics) { - self.diagnostics.extend_from_slice(diagnostics); - } - - /// Returns a reference to the program that is being checked. - fn program(&self) -> &'a Program { - self.program - } - - /// Creates a task context that is used to check a single file. - fn task_context<'b, S>(&self, dependency_scheduler: &'b S) -> CheckTaskContext<'a, 'b, S> - where - S: ScheduleDependency, - { - CheckTaskContext { - program: self.program, - dependency_scheduler, - } - } - - fn finish(self) -> Vec { - self.diagnostics - } -} - -/// Trait that abstracts away how a dependency of a file gets scheduled for checking. -trait ScheduleDependency { - /// Schedules the file with the given ID for checking. - fn schedule(&self, file_id: FileId); -} - -impl ScheduleDependency for T -where - T: Fn(FileId), -{ - fn schedule(&self, file_id: FileId) { - let f = self; - f(file_id); - } -} - -/// Context that is used to run a single file check task. -/// -/// The task is generic over `S` because it is passed across thread boundaries and -/// we don't want to add the requirement that [`ScheduleDependency`] must be [`Send`]. -struct CheckTaskContext<'a, 'scheduler, S> -where - S: ScheduleDependency, -{ - dependency_scheduler: &'scheduler S, - program: &'a Program, -} - -impl<'a, 'scheduler, S> CheckTaskContext<'a, 'scheduler, S> -where - S: ScheduleDependency, -{ - fn as_file_context(&self) -> CheckFileContext<'scheduler> { - CheckFileContext { - dependency_scheduler: self.dependency_scheduler, - } - } -} - -/// Context passed when checking a single file. -/// -/// This is a trimmed down version of [`CheckTaskContext`] with the type parameter `S` erased -/// to avoid monomorphization of [`Program:check_file`]. 
-struct CheckFileContext<'a> { - dependency_scheduler: &'a dyn ScheduleDependency, -} - -impl<'a> CheckFileContext<'a> { - fn schedule_dependency(&self, file_id: FileId) { - self.dependency_scheduler.schedule(file_id); - } -} - -#[derive(Debug)] -enum CheckFileTask { - OpenFile(CheckOpenFileTask), - Dependency(CheckDependencyTask), -} - -impl CheckFileTask { - /// Runs the task and returns the results for checking this file. - fn run(&self, context: &CheckTaskContext) -> QueryResult - where - S: ScheduleDependency, - { - match self { - Self::OpenFile(task) => task.run(context), - Self::Dependency(task) => task.run(context), - } - } - - fn file_id(&self) -> FileId { - match self { - CheckFileTask::OpenFile(task) => task.file_id, - CheckFileTask::Dependency(task) => task.file_id, - } - } -} - -/// Task to check an open file. - -#[derive(Debug)] -struct CheckOpenFileTask { - file_id: FileId, -} - -impl CheckOpenFileTask { - fn run(&self, context: &CheckTaskContext) -> QueryResult - where - S: ScheduleDependency, - { - context - .program - .check_file(self.file_id, &context.as_file_context()) - } -} - -/// Task to check a dependency file. -#[derive(Debug)] -struct CheckDependencyTask { - file_id: FileId, -} - -impl CheckDependencyTask { - fn run(&self, context: &CheckTaskContext) -> QueryResult - where - S: ScheduleDependency, - { - context - .program - .check_file(self.file_id, &context.as_file_context()) - } -} - -/// Executor that schedules the checking of individual program files. -trait CheckExecutor { - fn run(self, context: &mut CheckContext) -> QueryResult<()>; -} - -/// Executor that runs all check operations on the current thread. -/// -/// The executor does not schedule dependencies for checking. -/// The main motivation for scheduling dependencies -/// in a multithreaded environment is to parse and index the dependencies concurrently. 
-/// However, that doesn't make sense in a single threaded environment, because the dependencies then compute -/// with checking the open files. Checking dependencies in a single threaded environment is more likely -/// to hurt performance because we end up analyzing files in their entirety, even if we only need to type check parts of them. -#[derive(Debug, Default)] -struct SingleThreadedExecutor; - -impl CheckExecutor for SingleThreadedExecutor { - fn run(self, context: &mut CheckContext) -> QueryResult<()> { - let mut queue = context.check_open_files(); - - let noop_schedule_dependency = |_| {}; - - while let Some(file) = queue.pop() { - context.program().cancelled()?; - - let task_context = context.task_context(&noop_schedule_dependency); - context.push_diagnostics(&file.run(&task_context)?); - } - - Ok(()) - } -} - -/// Executor that runs the check operations on a thread pool. -/// -/// The executor runs each check operation as its own task using a thread pool. -/// -/// Other than [`SingleThreadedExecutor`], this executor schedules dependencies for checking. It -/// even schedules dependencies for checking when the thread pool size is 1 for a better debugging experience. -#[derive(Debug, Default)] -struct ThreadPoolExecutor; - -impl CheckExecutor for ThreadPoolExecutor { - fn run(self, context: &mut CheckContext) -> QueryResult<()> { - let num_threads = current_num_threads(); - let single_threaded = num_threads == 1; - let span = tracing::trace_span!("ThreadPoolExecutor::run", num_threads); - let _ = span.enter(); - - let mut queue: Vec<_> = context - .check_open_files() - .into_iter() - .map(CheckFileTask::OpenFile) - .collect(); - - let (sender, receiver) = if single_threaded { - // Use an unbounded queue for single threaded execution to prevent deadlocks - // when a single file schedules multiple dependencies. 
- crossbeam::channel::unbounded() - } else { - // Use a bounded queue to apply backpressure when the orchestration thread isn't able to keep - // up processing messages from the worker threads. - crossbeam::channel::bounded(num_threads) - }; - - let schedule_sender = sender.clone(); - let schedule_dependency = move |file_id| { - schedule_sender - .send(ThreadPoolMessage::ScheduleDependency(file_id)) - .unwrap(); - }; - - let result = rayon::in_place_scope(|scope| { - let mut pending = 0usize; - - loop { - context.program().cancelled()?; - - // 1. Try to get a queued message to ensure that we have always remaining space in the channel to prevent blocking the worker threads. - // 2. Try to process a queued file - // 3. If there's no queued file wait for the next incoming message. - // 4. Exit if there are no more messages and no senders. - let message = if let Ok(message) = receiver.try_recv() { - message - } else if let Some(task) = queue.pop() { - pending += 1; - - let task_context = context.task_context(&schedule_dependency); - let sender = sender.clone(); - let task_span = tracing::trace_span!( - parent: &span, - "CheckFileTask::run", - file_id = task.file_id().as_u32(), - ); - - scope.spawn(move |_| { - task_span.in_scope(|| match task.run(&task_context) { - Ok(result) => { - sender.send(ThreadPoolMessage::Completed(result)).unwrap(); - } - Err(err) => sender.send(ThreadPoolMessage::Errored(err)).unwrap(), - }); - }); - - // If this is a single threaded rayon thread pool, yield the current thread - // or we never start processing the work items. 
- if single_threaded { - yield_local(); - } - - continue; - } else if let Ok(message) = receiver.recv() { - message - } else { - break; - }; - - match message { - ThreadPoolMessage::ScheduleDependency(dependency) => { - if let Some(task) = context.check_dependency(dependency) { - queue.push(CheckFileTask::Dependency(task)); - } - } - ThreadPoolMessage::Completed(diagnostics) => { - context.push_diagnostics(&diagnostics); - pending -= 1; - - if pending == 0 && queue.is_empty() { - break; - } - } - ThreadPoolMessage::Errored(err) => { - return Err(err); - } - } - } - - Ok(()) - }); - - result - } -} - -#[derive(Debug)] -enum ThreadPoolMessage { - ScheduleDependency(FileId), - Completed(Diagnostics), - Errored(QueryError), -} diff --git a/crates/red_knot/src/program/mod.rs b/crates/red_knot/src/program/mod.rs index 69ac836160..92ab5a5a42 100644 --- a/crates/red_knot/src/program/mod.rs +++ b/crates/red_knot/src/program/mod.rs @@ -1,30 +1,36 @@ -use std::collections::hash_map::Entry; -use std::path::{Path, PathBuf}; +use std::panic::{RefUnwindSafe, UnwindSafe}; use std::sync::Arc; -use rustc_hash::FxHashMap; +use salsa::{Cancelled, Database}; -use crate::db::{ - Database, Db, DbRuntime, DbWithJar, HasJar, HasJars, JarsStorage, LintDb, LintJar, - ParallelDatabase, QueryResult, SemanticDb, SemanticJar, Snapshot, SourceDb, SourceJar, Upcast, -}; -use crate::files::{FileId, Files}; +use red_knot_module_resolver::{Db as ResolverDb, Jar as ResolverJar}; +use red_knot_python_semantic::{Db as SemanticDb, Jar as SemanticJar}; +use ruff_db::file_system::{FileSystem, FileSystemPathBuf}; +use ruff_db::vfs::{Vfs, VfsFile, VfsPath}; +use ruff_db::{Db as SourceDb, Jar as SourceJar, Upcast}; + +use crate::db::{Db, Jar}; use crate::Workspace; -pub mod check; +mod check; -#[derive(Debug)] +#[salsa::db(SourceJar, ResolverJar, SemanticJar, Jar)] pub struct Program { - jars: JarsStorage, - files: Files, + storage: salsa::Storage, + vfs: Vfs, + fs: Arc, workspace: Workspace, } impl Program { 
- pub fn new(workspace: Workspace) -> Self { + pub fn new(workspace: Workspace, file_system: Fs) -> Self + where + Fs: FileSystem + 'static + Send + Sync + RefUnwindSafe, + { Self { - jars: JarsStorage::default(), - files: Files::default(), + storage: salsa::Storage::default(), + vfs: Vfs::default(), + fs: Arc::new(file_system), workspace, } } @@ -33,30 +39,11 @@ impl Program { where I: IntoIterator, { - let mut aggregated_changes = AggregatedChanges::default(); - - aggregated_changes.extend(changes.into_iter().map(|change| FileChange { - id: self.files.intern(&change.path), - kind: change.kind, - })); - - let (source, semantic, lint) = self.jars_mut(); - for change in aggregated_changes.iter() { - semantic.module_resolver.remove_module_by_file(change.id); - semantic.semantic_indices.remove(&change.id); - source.sources.remove(&change.id); - source.parsed.remove(&change.id); - // TODO: remove all dependent modules as well - semantic.type_store.remove_module(change.id); - lint.lint_syntax.remove(&change.id); - lint.lint_semantic.remove(&change.id); + for change in changes { + VfsFile::touch_path(self, &VfsPath::file_system(change.path)); } } - pub fn files(&self) -> &Files { - &self.files - } - pub fn workspace(&self) -> &Workspace { &self.workspace } @@ -64,28 +51,18 @@ impl Program { pub fn workspace_mut(&mut self) -> &mut Workspace { &mut self.workspace } -} -impl SourceDb for Program { - fn file_id(&self, path: &Path) -> FileId { - self.files.intern(path) - } - - fn file_path(&self, file_id: FileId) -> Arc { - self.files.path(file_id) + #[allow(clippy::unnecessary_wraps)] + fn with_db(&self, f: F) -> Result + where + F: FnOnce(&Program) -> T + UnwindSafe, + { + // TODO: Catch in `Cancelled::catch` + // See https://salsa.zulipchat.com/#narrow/stream/145099-general/topic/How.20to.20use.20.60Cancelled.3A.3Acatch.60 + Ok(f(self)) } } -impl DbWithJar for Program {} - -impl SemanticDb for Program {} - -impl DbWithJar for Program {} - -impl LintDb for Program {} - 
-impl DbWithJar for Program {} - impl Upcast for Program { fn upcast(&self) -> &(dyn SemanticDb + 'static) { self @@ -98,178 +75,57 @@ impl Upcast for Program { } } -impl Upcast for Program { - fn upcast(&self) -> &(dyn LintDb + 'static) { +impl Upcast for Program { + fn upcast(&self) -> &(dyn ResolverDb + 'static) { self } } -impl Db for Program {} +impl ResolverDb for Program {} -impl Database for Program { - fn runtime(&self) -> &DbRuntime { - self.jars.runtime() +impl SemanticDb for Program {} + +impl SourceDb for Program { + fn file_system(&self) -> &dyn FileSystem { + &*self.fs } - fn runtime_mut(&mut self) -> &mut DbRuntime { - self.jars.runtime_mut() + fn vfs(&self) -> &Vfs { + &self.vfs } } -impl ParallelDatabase for Program { - fn snapshot(&self) -> Snapshot { - Snapshot::new(Self { - jars: self.jars.snapshot(), - files: self.files.snapshot(), +impl Database for Program {} + +impl Db for Program {} + +impl salsa::ParallelDatabase for Program { + fn snapshot(&self) -> salsa::Snapshot { + salsa::Snapshot::new(Self { + storage: self.storage.snapshot(), + vfs: self.vfs.snapshot(), + fs: self.fs.clone(), workspace: self.workspace.clone(), }) } } -impl HasJars for Program { - type Jars = (SourceJar, SemanticJar, LintJar); - - fn jars(&self) -> QueryResult<&Self::Jars> { - self.jars.jars() - } - - fn jars_mut(&mut self) -> &mut Self::Jars { - self.jars.jars_mut() - } -} - -impl HasJar for Program { - fn jar(&self) -> QueryResult<&SourceJar> { - Ok(&self.jars()?.0) - } - - fn jar_mut(&mut self) -> &mut SourceJar { - &mut self.jars_mut().0 - } -} - -impl HasJar for Program { - fn jar(&self) -> QueryResult<&SemanticJar> { - Ok(&self.jars()?.1) - } - - fn jar_mut(&mut self) -> &mut SemanticJar { - &mut self.jars_mut().1 - } -} - -impl HasJar for Program { - fn jar(&self) -> QueryResult<&LintJar> { - Ok(&self.jars()?.2) - } - - fn jar_mut(&mut self) -> &mut LintJar { - &mut self.jars_mut().2 - } -} - #[derive(Clone, Debug)] pub struct FileWatcherChange { - path: 
PathBuf, + path: FileSystemPathBuf, + #[allow(unused)] kind: FileChangeKind, } impl FileWatcherChange { - pub fn new(path: PathBuf, kind: FileChangeKind) -> Self { + pub fn new(path: FileSystemPathBuf, kind: FileChangeKind) -> Self { Self { path, kind } } } -#[derive(Copy, Clone, Debug)] -struct FileChange { - id: FileId, - kind: FileChangeKind, -} - -impl FileChange { - fn file_id(self) -> FileId { - self.id - } - - fn kind(self) -> FileChangeKind { - self.kind - } -} - #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum FileChangeKind { Created, Modified, Deleted, } - -#[derive(Default, Debug)] -struct AggregatedChanges { - changes: FxHashMap, -} - -impl AggregatedChanges { - fn add(&mut self, change: FileChange) { - match self.changes.entry(change.file_id()) { - Entry::Occupied(mut entry) => { - let merged = entry.get_mut(); - - match (merged, change.kind()) { - (FileChangeKind::Created, FileChangeKind::Deleted) => { - // Deletion after creations means that ruff never saw the file. - entry.remove(); - } - (FileChangeKind::Created, FileChangeKind::Modified) => { - // No-op, for ruff, modifying a file that it doesn't yet know that it exists is still considered a creation. - } - - (FileChangeKind::Modified, FileChangeKind::Created) => { - // Uhh, that should probably not happen. Continue considering it a modification. - } - - (FileChangeKind::Modified, FileChangeKind::Deleted) => { - *entry.get_mut() = FileChangeKind::Deleted; - } - - (FileChangeKind::Deleted, FileChangeKind::Created) => { - *entry.get_mut() = FileChangeKind::Modified; - } - - (FileChangeKind::Deleted, FileChangeKind::Modified) => { - // That's weird, but let's consider it a modification. - *entry.get_mut() = FileChangeKind::Modified; - } - - (FileChangeKind::Created, FileChangeKind::Created) - | (FileChangeKind::Modified, FileChangeKind::Modified) - | (FileChangeKind::Deleted, FileChangeKind::Deleted) => { - // No-op transitions. Some of them should be impossible but we handle them anyway. 
- } - } - } - Entry::Vacant(entry) => { - entry.insert(change.kind()); - } - } - } - - fn extend(&mut self, changes: I) - where - I: IntoIterator, - { - let iter = changes.into_iter(); - let (lower, _) = iter.size_hint(); - self.changes.reserve(lower); - - for change in iter { - self.add(change); - } - } - - fn iter(&self) -> impl Iterator + '_ { - self.changes.iter().map(|(id, kind)| FileChange { - id: *id, - kind: *kind, - }) - } -} diff --git a/crates/red_knot/src/semantic.rs b/crates/red_knot/src/semantic.rs deleted file mode 100644 index be4753be96..0000000000 --- a/crates/red_knot/src/semantic.rs +++ /dev/null @@ -1,881 +0,0 @@ -use std::num::NonZeroU32; - -use ruff_python_ast as ast; -use ruff_python_ast::visitor::source_order::SourceOrderVisitor; -use ruff_python_ast::AstNode; - -use crate::ast_ids::{NodeKey, TypedNodeKey}; -use crate::cache::KeyValueCache; -use crate::db::{QueryResult, SemanticDb, SemanticJar}; -use crate::files::FileId; -use crate::module::Module; -use crate::parse::parse; -pub(crate) use definitions::Definition; -use definitions::{ImportDefinition, ImportFromDefinition}; -pub(crate) use flow_graph::ConstrainedDefinition; -use flow_graph::{FlowGraph, FlowGraphBuilder, FlowNodeId, ReachableDefinitionsIterator}; -use red_knot_module_resolver::ModuleName; -use ruff_index::{newtype_index, IndexVec}; -use rustc_hash::FxHashMap; -use std::ops::{Deref, DerefMut}; -use std::sync::Arc; -pub(crate) use symbol_table::{Dependency, SymbolId}; -use symbol_table::{ScopeId, ScopeKind, SymbolFlags, SymbolTable, SymbolTableBuilder}; -pub(crate) use types::{infer_definition_type, infer_symbol_public_type, Type, TypeStore}; - -mod definitions; -mod flow_graph; -mod symbol_table; -mod types; - -#[tracing::instrument(level = "debug", skip(db))] -pub fn semantic_index(db: &dyn SemanticDb, file_id: FileId) -> QueryResult> { - let jar: &SemanticJar = db.jar()?; - - jar.semantic_indices.get(&file_id, |_| { - let parsed = parse(db.upcast(), file_id)?; - 
Ok(Arc::from(SemanticIndex::from_ast(parsed.syntax()))) - }) -} - -#[tracing::instrument(level = "debug", skip(db))] -pub fn resolve_global_symbol( - db: &dyn SemanticDb, - module: Module, - name: &str, -) -> QueryResult> { - let file_id = module.path(db)?.file(); - let symbol_table = &semantic_index(db, file_id)?.symbol_table; - let Some(symbol_id) = symbol_table.root_symbol_id_by_name(name) else { - return Ok(None); - }; - Ok(Some(GlobalSymbolId { file_id, symbol_id })) -} - -#[newtype_index] -pub struct ExpressionId; - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct GlobalSymbolId { - pub(crate) file_id: FileId, - pub(crate) symbol_id: SymbolId, -} - -#[derive(Debug)] -pub struct SemanticIndex { - symbol_table: SymbolTable, - flow_graph: FlowGraph, - expressions: FxHashMap, - expressions_by_id: IndexVec, -} - -impl SemanticIndex { - pub fn from_ast(module: &ast::ModModule) -> Self { - let root_scope_id = SymbolTable::root_scope_id(); - let mut indexer = SemanticIndexer { - symbol_table_builder: SymbolTableBuilder::new(), - flow_graph_builder: FlowGraphBuilder::new(), - scopes: vec![ScopeState { - scope_id: root_scope_id, - current_flow_node_id: FlowGraph::start(), - }], - expressions: FxHashMap::default(), - expressions_by_id: IndexVec::default(), - current_definition: None, - }; - indexer.visit_body(&module.body); - indexer.finish() - } - - fn resolve_expression_id<'a>( - &self, - ast: &'a ast::ModModule, - expression_id: ExpressionId, - ) -> ast::AnyNodeRef<'a> { - let node_key = self.expressions_by_id[expression_id]; - node_key - .resolve(ast.as_any_node_ref()) - .expect("node to resolve") - } - - /// Return an iterator over all definitions of `symbol_id` reachable from `use_expr`. The value - /// of `symbol_id` in `use_expr` must originate from one of the iterated definitions (or from - /// an external reassignment of the name outside of this scope). 
- pub fn reachable_definitions( - &self, - symbol_id: SymbolId, - use_expr: &ast::Expr, - ) -> ReachableDefinitionsIterator { - let expression_id = self.expression_id(use_expr); - ReachableDefinitionsIterator::new( - &self.flow_graph, - symbol_id, - self.flow_graph.for_expr(expression_id), - ) - } - - pub fn expression_id(&self, expression: &ast::Expr) -> ExpressionId { - self.expressions[&NodeKey::from_node(expression.into())] - } - - pub fn symbol_table(&self) -> &SymbolTable { - &self.symbol_table - } -} - -#[derive(Debug)] -struct ScopeState { - scope_id: ScopeId, - current_flow_node_id: FlowNodeId, -} - -#[derive(Debug)] -struct SemanticIndexer { - symbol_table_builder: SymbolTableBuilder, - flow_graph_builder: FlowGraphBuilder, - scopes: Vec, - /// the definition whose target(s) we are currently walking - current_definition: Option, - expressions: FxHashMap, - expressions_by_id: IndexVec, -} - -impl SemanticIndexer { - pub(crate) fn finish(mut self) -> SemanticIndex { - let SemanticIndexer { - flow_graph_builder, - symbol_table_builder, - .. 
- } = self; - self.expressions.shrink_to_fit(); - self.expressions_by_id.shrink_to_fit(); - SemanticIndex { - flow_graph: flow_graph_builder.finish(), - symbol_table: symbol_table_builder.finish(), - expressions: self.expressions, - expressions_by_id: self.expressions_by_id, - } - } - - fn set_current_flow_node(&mut self, new_flow_node_id: FlowNodeId) { - let scope_state = self.scopes.last_mut().expect("scope stack is never empty"); - scope_state.current_flow_node_id = new_flow_node_id; - } - - fn current_flow_node(&self) -> FlowNodeId { - self.scopes - .last() - .expect("scope stack is never empty") - .current_flow_node_id - } - - fn add_or_update_symbol(&mut self, identifier: &str, flags: SymbolFlags) -> SymbolId { - self.symbol_table_builder - .add_or_update_symbol(self.cur_scope(), identifier, flags) - } - - fn add_or_update_symbol_with_def( - &mut self, - identifier: &str, - definition: Definition, - ) -> SymbolId { - let symbol_id = self.add_or_update_symbol(identifier, SymbolFlags::IS_DEFINED); - self.symbol_table_builder - .add_definition(symbol_id, definition.clone()); - let new_flow_node_id = - self.flow_graph_builder - .add_definition(symbol_id, definition, self.current_flow_node()); - self.set_current_flow_node(new_flow_node_id); - symbol_id - } - - fn push_scope( - &mut self, - name: &str, - kind: ScopeKind, - definition: Option, - defining_symbol: Option, - ) -> ScopeId { - let scope_id = self.symbol_table_builder.add_child_scope( - self.cur_scope(), - name, - kind, - definition, - defining_symbol, - ); - self.scopes.push(ScopeState { - scope_id, - current_flow_node_id: FlowGraph::start(), - }); - scope_id - } - - fn pop_scope(&mut self) -> ScopeId { - self.scopes - .pop() - .expect("Scope stack should never be empty") - .scope_id - } - - fn cur_scope(&self) -> ScopeId { - self.scopes - .last() - .expect("Scope stack should never be empty") - .scope_id - } - - fn record_scope_for_node(&mut self, node_key: NodeKey, scope_id: ScopeId) { - 
self.symbol_table_builder - .record_scope_for_node(node_key, scope_id); - } - - fn insert_constraint(&mut self, expr: &ast::Expr) { - let node_key = NodeKey::from_node(expr.into()); - let expression_id = self.expressions[&node_key]; - let constraint = self - .flow_graph_builder - .add_constraint(self.current_flow_node(), expression_id); - self.set_current_flow_node(constraint); - } - - fn with_type_params( - &mut self, - name: &str, - params: &Option>, - definition: Option, - defining_symbol: Option, - nested: impl FnOnce(&mut Self) -> ScopeId, - ) -> ScopeId { - if let Some(type_params) = params { - self.push_scope(name, ScopeKind::Annotation, definition, defining_symbol); - for type_param in &type_params.type_params { - let name = match type_param { - ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name, - ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => name, - ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => name, - }; - self.add_or_update_symbol(name, SymbolFlags::IS_DEFINED); - } - } - let scope_id = nested(self); - if params.is_some() { - self.pop_scope(); - } - scope_id - } -} - -impl SourceOrderVisitor<'_> for SemanticIndexer { - fn visit_expr(&mut self, expr: &ast::Expr) { - let node_key = NodeKey::from_node(expr.into()); - let expression_id = self.expressions_by_id.push(node_key); - - let flow_expression_id = self - .flow_graph_builder - .record_expr(self.current_flow_node()); - debug_assert_eq!(expression_id, flow_expression_id); - - let symbol_expression_id = self - .symbol_table_builder - .record_expression(self.cur_scope()); - - debug_assert_eq!(expression_id, symbol_expression_id); - - self.expressions.insert(node_key, expression_id); - - match expr { - ast::Expr::Name(ast::ExprName { id, ctx, .. 
}) => { - let flags = match ctx { - ast::ExprContext::Load => SymbolFlags::IS_USED, - ast::ExprContext::Store => SymbolFlags::IS_DEFINED, - ast::ExprContext::Del => SymbolFlags::IS_DEFINED, - ast::ExprContext::Invalid => SymbolFlags::empty(), - }; - self.add_or_update_symbol(id, flags); - if flags.contains(SymbolFlags::IS_DEFINED) { - if let Some(curdef) = self.current_definition.clone() { - self.add_or_update_symbol_with_def(id, curdef); - } - } - ast::visitor::source_order::walk_expr(self, expr); - } - ast::Expr::Named(node) => { - debug_assert!(self.current_definition.is_none()); - self.current_definition = - Some(Definition::NamedExpr(TypedNodeKey::from_node(node))); - // TODO walrus in comprehensions is implicitly nonlocal - self.visit_expr(&node.target); - self.current_definition = None; - self.visit_expr(&node.value); - } - ast::Expr::If(ast::ExprIf { - body, test, orelse, .. - }) => { - // TODO detect statically known truthy or falsy test (via type inference, not naive - // AST inspection, so we can't simplify here, need to record test expression in CFG - // for later checking) - - self.visit_expr(test); - - let if_branch = self.flow_graph_builder.add_branch(self.current_flow_node()); - - self.set_current_flow_node(if_branch); - self.insert_constraint(test); - self.visit_expr(body); - - let post_body = self.current_flow_node(); - - self.set_current_flow_node(if_branch); - self.visit_expr(orelse); - - let post_else = self - .flow_graph_builder - .add_phi(self.current_flow_node(), post_body); - - self.set_current_flow_node(post_else); - } - _ => { - ast::visitor::source_order::walk_expr(self, expr); - } - } - } - - fn visit_stmt(&mut self, stmt: &ast::Stmt) { - // TODO need to capture more definition statements here - match stmt { - ast::Stmt::ClassDef(node) => { - let node_key = TypedNodeKey::from_node(node); - let def = Definition::ClassDef(node_key.clone()); - let symbol_id = self.add_or_update_symbol_with_def(&node.name, def.clone()); - for decorator in 
&node.decorator_list { - self.visit_decorator(decorator); - } - let scope_id = self.with_type_params( - &node.name, - &node.type_params, - Some(def.clone()), - Some(symbol_id), - |indexer| { - if let Some(arguments) = &node.arguments { - indexer.visit_arguments(arguments); - } - let scope_id = indexer.push_scope( - &node.name, - ScopeKind::Class, - Some(def.clone()), - Some(symbol_id), - ); - indexer.visit_body(&node.body); - indexer.pop_scope(); - scope_id - }, - ); - self.record_scope_for_node(*node_key.erased(), scope_id); - } - ast::Stmt::FunctionDef(node) => { - let node_key = TypedNodeKey::from_node(node); - let def = Definition::FunctionDef(node_key.clone()); - let symbol_id = self.add_or_update_symbol_with_def(&node.name, def.clone()); - for decorator in &node.decorator_list { - self.visit_decorator(decorator); - } - let scope_id = self.with_type_params( - &node.name, - &node.type_params, - Some(def.clone()), - Some(symbol_id), - |indexer| { - indexer.visit_parameters(&node.parameters); - for expr in &node.returns { - indexer.visit_annotation(expr); - } - let scope_id = indexer.push_scope( - &node.name, - ScopeKind::Function, - Some(def.clone()), - Some(symbol_id), - ); - indexer.visit_body(&node.body); - indexer.pop_scope(); - scope_id - }, - ); - self.record_scope_for_node(*node_key.erased(), scope_id); - } - ast::Stmt::Import(ast::StmtImport { names, .. }) => { - for alias in names { - let symbol_name = if let Some(asname) = &alias.asname { - asname.id.as_str() - } else { - alias.name.id.split('.').next().unwrap() - }; - - let module = ModuleName::new(&alias.name.id).unwrap(); - - let def = Definition::Import(ImportDefinition { - module: module.clone(), - }); - self.add_or_update_symbol_with_def(symbol_name, def); - self.symbol_table_builder - .add_dependency(Dependency::Module(module)); - } - } - ast::Stmt::ImportFrom(ast::StmtImportFrom { - module, - names, - level, - .. 
- }) => { - let module = module.as_ref().and_then(|m| ModuleName::new(&m.id)); - - for alias in names { - let symbol_name = if let Some(asname) = &alias.asname { - asname.id.as_str() - } else { - alias.name.id.as_str() - }; - let def = Definition::ImportFrom(ImportFromDefinition { - module: module.clone(), - name: alias.name.id.clone(), - level: *level, - }); - self.add_or_update_symbol_with_def(symbol_name, def); - } - - let dependency = if let Some(module) = module { - match NonZeroU32::new(*level) { - Some(level) => Dependency::Relative { - level, - module: Some(module), - }, - None => Dependency::Module(module), - } - } else { - Dependency::Relative { - level: NonZeroU32::new(*level) - .expect("Import without a module to have a level > 0"), - module, - } - }; - - self.symbol_table_builder.add_dependency(dependency); - } - ast::Stmt::Assign(node) => { - debug_assert!(self.current_definition.is_none()); - self.visit_expr(&node.value); - self.current_definition = - Some(Definition::Assignment(TypedNodeKey::from_node(node))); - for expr in &node.targets { - self.visit_expr(expr); - } - - self.current_definition = None; - } - ast::Stmt::If(node) => { - // TODO detect statically known truthy or falsy test (via type inference, not naive - // AST inspection, so we can't simplify here, need to record test expression in CFG - // for later checking) - - // we visit the if "test" condition first regardless - self.visit_expr(&node.test); - - // create branch node: does the if test pass or not? - let if_branch = self.flow_graph_builder.add_branch(self.current_flow_node()); - - // visit the body of the `if` clause - self.set_current_flow_node(if_branch); - self.insert_constraint(&node.test); - self.visit_body(&node.body); - - // Flow node for the last if/elif condition branch; represents the "no branch - // taken yet" possibility (where "taking a branch" means that the condition in an - // if or elif evaluated to true and control flow went into that clause). 
- let mut prior_branch = if_branch; - - // Flow node for the state after the prior if/elif/else clause; represents "we have - // taken one of the branches up to this point." Initially set to the post-if-clause - // state, later will be set to the phi node joining that possible path with the - // possibility that we took a later if/elif/else clause instead. - let mut post_prior_clause = self.current_flow_node(); - - // Flag to mark if the final clause is an "else" -- if so, that means the "match no - // clauses" path is not possible, we have to go through one of the clauses. - let mut last_branch_is_else = false; - - for clause in &node.elif_else_clauses { - if let Some(test) = &clause.test { - self.visit_expr(test); - // This is an elif clause. Create a new branch node. Its predecessor is the - // previous branch node, because we can only take one branch in an entire - // if/elif/else chain, so if we take this branch, it can only be because we - // didn't take the previous one. - prior_branch = self.flow_graph_builder.add_branch(prior_branch); - self.set_current_flow_node(prior_branch); - self.insert_constraint(test); - } else { - // This is an else clause. No need to create a branch node; there's no - // branch here, if we haven't taken any previous branch, we definitely go - // into the "else" clause. - self.set_current_flow_node(prior_branch); - last_branch_is_else = true; - } - self.visit_elif_else_clause(clause); - // Update `post_prior_clause` to a new phi node joining the possibility that we - // took any of the previous branches with the possibility that we took the one - // just visited. - post_prior_clause = self - .flow_graph_builder - .add_phi(self.current_flow_node(), post_prior_clause); - } - - if !last_branch_is_else { - // Final branch was not an "else", which means it's possible we took zero - // branches in the entire if/elif chain, so we need one more phi node to join - // the "no branches taken" possibility. 
- post_prior_clause = self - .flow_graph_builder - .add_phi(post_prior_clause, prior_branch); - } - - // Onward, with current flow node set to our final Phi node. - self.set_current_flow_node(post_prior_clause); - } - _ => { - ast::visitor::source_order::walk_stmt(self, stmt); - } - } - } -} - -#[derive(Debug, Default)] -pub struct SemanticIndexStorage(KeyValueCache>); - -impl Deref for SemanticIndexStorage { - type Target = KeyValueCache>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for SemanticIndexStorage { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -#[cfg(test)] -mod tests { - use crate::semantic::symbol_table::{Symbol, SymbolIterator}; - use ruff_python_ast as ast; - use ruff_python_ast::ModModule; - use ruff_python_parser::{Mode, Parsed}; - - use super::{Definition, ScopeKind, SemanticIndex, SymbolId}; - - fn parse(code: &str) -> Parsed { - ruff_python_parser::parse_unchecked(code, Mode::Module) - .try_into_module() - .unwrap() - } - - fn names(it: SymbolIterator) -> Vec<&str> - where - I: Iterator, - { - let mut symbols: Vec<_> = it.map(Symbol::name).collect(); - symbols.sort_unstable(); - symbols - } - - #[test] - fn empty() { - let parsed = parse(""); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()).len(), 0); - } - - #[test] - fn simple() { - let parsed = parse("x"); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["x"]); - assert_eq!( - table - .definitions(table.root_symbol_id_by_name("x").unwrap()) - .len(), - 0 - ); - } - - #[test] - fn annotation_only() { - let parsed = parse("x: int"); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["int", "x"]); - // TODO record definition - } - - #[test] - fn import() { - let parsed = parse("import foo"); - let table = 
SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["foo"]); - assert_eq!( - table - .definitions(table.root_symbol_id_by_name("foo").unwrap()) - .len(), - 1 - ); - } - - #[test] - fn import_sub() { - let parsed = parse("import foo.bar"); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["foo"]); - } - - #[test] - fn import_as() { - let parsed = parse("import foo.bar as baz"); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["baz"]); - } - - #[test] - fn import_from() { - let parsed = parse("from bar import foo"); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["foo"]); - assert_eq!( - table - .definitions(table.root_symbol_id_by_name("foo").unwrap()) - .len(), - 1 - ); - assert!( - table.root_symbol_id_by_name("foo").is_some_and(|sid| { - let s = sid.symbol(&table); - s.is_defined() || !s.is_used() - }), - "symbols that are defined get the defined flag" - ); - } - - #[test] - fn assign() { - let parsed = parse("x = foo"); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["foo", "x"]); - assert_eq!( - table - .definitions(table.root_symbol_id_by_name("x").unwrap()) - .len(), - 1 - ); - assert!( - table.root_symbol_id_by_name("foo").is_some_and(|sid| { - let s = sid.symbol(&table); - !s.is_defined() && s.is_used() - }), - "a symbol used but not defined in a scope should have only the used flag" - ); - } - - #[test] - fn class_scope() { - let parsed = parse( - " - class C: - x = 1 - y = 2 - ", - ); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["C", "y"]); - let scopes = table.root_child_scope_ids(); - assert_eq!(scopes.len(), 1); - let c_scope = scopes[0].scope(&table); - 
assert_eq!(c_scope.kind(), ScopeKind::Class); - assert_eq!(c_scope.name(), "C"); - assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]); - assert_eq!( - table - .definitions(table.root_symbol_id_by_name("C").unwrap()) - .len(), - 1 - ); - } - - #[test] - fn func_scope() { - let parsed = parse( - " - def func(): - x = 1 - y = 2 - ", - ); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["func", "y"]); - let scopes = table.root_child_scope_ids(); - assert_eq!(scopes.len(), 1); - let func_scope = scopes[0].scope(&table); - assert_eq!(func_scope.kind(), ScopeKind::Function); - assert_eq!(func_scope.name(), "func"); - assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]); - assert_eq!( - table - .definitions(table.root_symbol_id_by_name("func").unwrap()) - .len(), - 1 - ); - } - - #[test] - fn dupes() { - let parsed = parse( - " - def func(): - x = 1 - def func(): - y = 2 - ", - ); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["func"]); - let scopes = table.root_child_scope_ids(); - assert_eq!(scopes.len(), 2); - let func_scope_1 = scopes[0].scope(&table); - let func_scope_2 = scopes[1].scope(&table); - assert_eq!(func_scope_1.kind(), ScopeKind::Function); - assert_eq!(func_scope_1.name(), "func"); - assert_eq!(func_scope_2.kind(), ScopeKind::Function); - assert_eq!(func_scope_2.name(), "func"); - assert_eq!(names(table.symbols_for_scope(scopes[0])), vec!["x"]); - assert_eq!(names(table.symbols_for_scope(scopes[1])), vec!["y"]); - assert_eq!( - table - .definitions(table.root_symbol_id_by_name("func").unwrap()) - .len(), - 2 - ); - } - - #[test] - fn generic_func() { - let parsed = parse( - " - def func[T](): - x = 1 - ", - ); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["func"]); - let scopes = table.root_child_scope_ids(); - 
assert_eq!(scopes.len(), 1); - let ann_scope_id = scopes[0]; - let ann_scope = ann_scope_id.scope(&table); - assert_eq!(ann_scope.kind(), ScopeKind::Annotation); - assert_eq!(ann_scope.name(), "func"); - assert_eq!(names(table.symbols_for_scope(ann_scope_id)), vec!["T"]); - let scopes = table.child_scope_ids_of(ann_scope_id); - assert_eq!(scopes.len(), 1); - let func_scope_id = scopes[0]; - let func_scope = func_scope_id.scope(&table); - assert_eq!(func_scope.kind(), ScopeKind::Function); - assert_eq!(func_scope.name(), "func"); - assert_eq!(names(table.symbols_for_scope(func_scope_id)), vec!["x"]); - } - - #[test] - fn generic_class() { - let parsed = parse( - " - class C[T]: - x = 1 - ", - ); - let table = SemanticIndex::from_ast(parsed.syntax()).symbol_table; - assert_eq!(names(table.root_symbols()), vec!["C"]); - let scopes = table.root_child_scope_ids(); - assert_eq!(scopes.len(), 1); - let ann_scope_id = scopes[0]; - let ann_scope = ann_scope_id.scope(&table); - assert_eq!(ann_scope.kind(), ScopeKind::Annotation); - assert_eq!(ann_scope.name(), "C"); - assert_eq!(names(table.symbols_for_scope(ann_scope_id)), vec!["T"]); - assert!( - table - .symbol_by_name(ann_scope_id, "T") - .is_some_and(|s| s.is_defined() && !s.is_used()), - "type parameters are defined by the scope that introduces them" - ); - let scopes = table.child_scope_ids_of(ann_scope_id); - assert_eq!(scopes.len(), 1); - let func_scope_id = scopes[0]; - let func_scope = func_scope_id.scope(&table); - assert_eq!(func_scope.kind(), ScopeKind::Class); - assert_eq!(func_scope.name(), "C"); - assert_eq!(names(table.symbols_for_scope(func_scope_id)), vec!["x"]); - } - - #[test] - fn reachability_trivial() { - let parsed = parse("x = 1; x"); - let ast = parsed.syntax(); - let index = SemanticIndex::from_ast(ast); - let table = &index.symbol_table; - let x_sym = table - .root_symbol_id_by_name("x") - .expect("x symbol should exist"); - let ast::Stmt::Expr(ast::StmtExpr { value: x_use, .. 
}) = &ast.body[1] else { - panic!("should be an expr") - }; - let x_defs: Vec<_> = index - .reachable_definitions(x_sym, x_use) - .map(|constrained_definition| constrained_definition.definition) - .collect(); - assert_eq!(x_defs.len(), 1); - let Definition::Assignment(node_key) = &x_defs[0] else { - panic!("def should be an assignment") - }; - let Some(def_node) = node_key.resolve(ast.into()) else { - panic!("node key should resolve") - }; - let ast::Expr::NumberLiteral(ast::ExprNumberLiteral { - value: ast::Number::Int(num), - .. - }) = &*def_node.value - else { - panic!("should be a number literal") - }; - assert_eq!(*num, 1); - } - - #[test] - fn expression_scope() { - let parsed = parse("x = 1;\ndef test():\n y = 4"); - let ast = parsed.syntax(); - let index = SemanticIndex::from_ast(ast); - let table = &index.symbol_table; - - let x_sym = table - .root_symbol_by_name("x") - .expect("x symbol should exist"); - - let x_stmt = ast.body[0].as_assign_stmt().unwrap(); - - let x_id = index.expression_id(&x_stmt.targets[0]); - - assert_eq!(table.scope_of_expression(x_id).kind(), ScopeKind::Module); - assert_eq!(table.scope_id_of_expression(x_id), x_sym.scope_id()); - - let def = ast.body[1].as_function_def_stmt().unwrap(); - let y_stmt = def.body[0].as_assign_stmt().unwrap(); - let y_id = index.expression_id(&y_stmt.targets[0]); - - assert_eq!(table.scope_of_expression(y_id).kind(), ScopeKind::Function); - } -} diff --git a/crates/red_knot/src/semantic/definitions.rs b/crates/red_knot/src/semantic/definitions.rs deleted file mode 100644 index 112e9d03b9..0000000000 --- a/crates/red_knot/src/semantic/definitions.rs +++ /dev/null @@ -1,52 +0,0 @@ -use crate::ast_ids::TypedNodeKey; -use red_knot_module_resolver::ModuleName; -use ruff_python_ast as ast; -use ruff_python_ast::name::Name; - -// TODO storing TypedNodeKey for definitions means we have to search to find them again in the AST; -// this is at best O(log n). 
If looking up definitions is a bottleneck we should look for -// alternatives here. -// TODO intern Definitions in SymbolTable and reference using IDs? -#[derive(Clone, Debug)] -pub enum Definition { - // For the import cases, we don't need reference to any arbitrary AST subtrees (annotations, - // RHS), and referencing just the import statement node is imprecise (a single import statement - // can assign many symbols, we'd have to re-search for the one we care about), so we just copy - // the small amount of information we need from the AST. - Import(ImportDefinition), - ImportFrom(ImportFromDefinition), - ClassDef(TypedNodeKey), - FunctionDef(TypedNodeKey), - Assignment(TypedNodeKey), - AnnotatedAssignment(TypedNodeKey), - NamedExpr(TypedNodeKey), - /// represents the implicit initial definition of every name as "unbound" - Unbound, - // TODO with statements, except handlers, function args... -} - -#[derive(Clone, Debug)] -pub struct ImportDefinition { - pub module: ModuleName, -} - -#[derive(Clone, Debug)] -pub struct ImportFromDefinition { - pub module: Option, - pub name: Name, - pub level: u32, -} - -impl ImportFromDefinition { - pub fn module(&self) -> Option<&ModuleName> { - self.module.as_ref() - } - - pub fn name(&self) -> &Name { - &self.name - } - - pub fn level(&self) -> u32 { - self.level - } -} diff --git a/crates/red_knot/src/semantic/flow_graph.rs b/crates/red_knot/src/semantic/flow_graph.rs deleted file mode 100644 index 6277dba085..0000000000 --- a/crates/red_knot/src/semantic/flow_graph.rs +++ /dev/null @@ -1,270 +0,0 @@ -use super::symbol_table::SymbolId; -use crate::semantic::{Definition, ExpressionId}; -use ruff_index::{newtype_index, IndexVec}; -use std::iter::FusedIterator; -use std::ops::Range; - -#[newtype_index] -pub struct FlowNodeId; - -#[derive(Debug)] -pub(crate) enum FlowNode { - Start, - Definition(DefinitionFlowNode), - Branch(BranchFlowNode), - Phi(PhiFlowNode), - Constraint(ConstraintFlowNode), -} - -/// A point in control flow 
where a symbol is defined -#[derive(Debug)] -pub(crate) struct DefinitionFlowNode { - symbol_id: SymbolId, - definition: Definition, - predecessor: FlowNodeId, -} - -/// A branch in control flow -#[derive(Debug)] -pub(crate) struct BranchFlowNode { - predecessor: FlowNodeId, -} - -/// A join point where control flow paths come together -#[derive(Debug)] -pub(crate) struct PhiFlowNode { - first_predecessor: FlowNodeId, - second_predecessor: FlowNodeId, -} - -/// A branch test which may apply constraints to a symbol's type -#[derive(Debug)] -pub(crate) struct ConstraintFlowNode { - predecessor: FlowNodeId, - test_expression: ExpressionId, -} - -#[derive(Debug)] -pub struct FlowGraph { - flow_nodes_by_id: IndexVec, - expression_map: IndexVec, -} - -impl FlowGraph { - pub fn start() -> FlowNodeId { - FlowNodeId::from_usize(0) - } - - pub fn for_expr(&self, expr: ExpressionId) -> FlowNodeId { - self.expression_map[expr] - } -} - -#[derive(Debug)] -pub(crate) struct FlowGraphBuilder { - flow_graph: FlowGraph, -} - -impl FlowGraphBuilder { - pub(crate) fn new() -> Self { - let mut graph = FlowGraph { - flow_nodes_by_id: IndexVec::default(), - expression_map: IndexVec::default(), - }; - graph.flow_nodes_by_id.push(FlowNode::Start); - Self { flow_graph: graph } - } - - pub(crate) fn add(&mut self, node: FlowNode) -> FlowNodeId { - self.flow_graph.flow_nodes_by_id.push(node) - } - - pub(crate) fn add_definition( - &mut self, - symbol_id: SymbolId, - definition: Definition, - predecessor: FlowNodeId, - ) -> FlowNodeId { - self.add(FlowNode::Definition(DefinitionFlowNode { - symbol_id, - definition, - predecessor, - })) - } - - pub(crate) fn add_branch(&mut self, predecessor: FlowNodeId) -> FlowNodeId { - self.add(FlowNode::Branch(BranchFlowNode { predecessor })) - } - - pub(crate) fn add_phi( - &mut self, - first_predecessor: FlowNodeId, - second_predecessor: FlowNodeId, - ) -> FlowNodeId { - self.add(FlowNode::Phi(PhiFlowNode { - first_predecessor, - second_predecessor, - 
})) - } - - pub(crate) fn add_constraint( - &mut self, - predecessor: FlowNodeId, - test_expression: ExpressionId, - ) -> FlowNodeId { - self.add(FlowNode::Constraint(ConstraintFlowNode { - predecessor, - test_expression, - })) - } - - pub(super) fn record_expr(&mut self, node_id: FlowNodeId) -> ExpressionId { - self.flow_graph.expression_map.push(node_id) - } - - pub(super) fn finish(mut self) -> FlowGraph { - self.flow_graph.flow_nodes_by_id.shrink_to_fit(); - self.flow_graph.expression_map.shrink_to_fit(); - self.flow_graph - } -} - -/// A definition, and the set of constraints between a use and the definition -#[derive(Debug, Clone)] -pub struct ConstrainedDefinition { - pub definition: Definition, - pub constraints: Vec, -} - -/// A flow node and the constraints we passed through to reach it -#[derive(Debug)] -struct FlowState { - node_id: FlowNodeId, - constraints_range: Range, -} - -#[derive(Debug)] -pub struct ReachableDefinitionsIterator<'a> { - flow_graph: &'a FlowGraph, - symbol_id: SymbolId, - pending: Vec, - constraints: Vec, -} - -impl<'a> ReachableDefinitionsIterator<'a> { - pub fn new(flow_graph: &'a FlowGraph, symbol_id: SymbolId, start_node_id: FlowNodeId) -> Self { - Self { - flow_graph, - symbol_id, - pending: vec![FlowState { - node_id: start_node_id, - constraints_range: 0..0, - }], - constraints: vec![], - } - } -} - -impl<'a> Iterator for ReachableDefinitionsIterator<'a> { - type Item = ConstrainedDefinition; - - fn next(&mut self) -> Option { - let FlowState { - mut node_id, - mut constraints_range, - } = self.pending.pop()?; - self.constraints.truncate(constraints_range.end + 1); - loop { - match &self.flow_graph.flow_nodes_by_id[node_id] { - FlowNode::Start => { - // constraints on unbound are irrelevant - return Some(ConstrainedDefinition { - definition: Definition::Unbound, - constraints: vec![], - }); - } - FlowNode::Definition(def_node) => { - if def_node.symbol_id == self.symbol_id { - return Some(ConstrainedDefinition { - 
definition: def_node.definition.clone(), - constraints: self.constraints[constraints_range].to_vec(), - }); - } - node_id = def_node.predecessor; - } - FlowNode::Branch(branch_node) => { - node_id = branch_node.predecessor; - } - FlowNode::Phi(phi_node) => { - self.pending.push(FlowState { - node_id: phi_node.first_predecessor, - constraints_range: constraints_range.clone(), - }); - node_id = phi_node.second_predecessor; - } - FlowNode::Constraint(constraint_node) => { - node_id = constraint_node.predecessor; - self.constraints.push(constraint_node.test_expression); - constraints_range.end += 1; - } - } - } - } -} - -impl<'a> FusedIterator for ReachableDefinitionsIterator<'a> {} - -impl std::fmt::Display for FlowGraph { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - writeln!(f, "flowchart TD")?; - for (id, node) in self.flow_nodes_by_id.iter_enumerated() { - write!(f, " id{}", id.as_u32())?; - match node { - FlowNode::Start => writeln!(f, r"[\Start/]")?, - FlowNode::Definition(def_node) => { - writeln!(f, r"(Define symbol {})", def_node.symbol_id.as_u32())?; - writeln!( - f, - r" id{}-->id{}", - def_node.predecessor.as_u32(), - id.as_u32() - )?; - } - FlowNode::Branch(branch_node) => { - writeln!(f, r"{{Branch}}")?; - writeln!( - f, - r" id{}-->id{}", - branch_node.predecessor.as_u32(), - id.as_u32() - )?; - } - FlowNode::Phi(phi_node) => { - writeln!(f, r"((Phi))")?; - writeln!( - f, - r" id{}-->id{}", - phi_node.second_predecessor.as_u32(), - id.as_u32() - )?; - writeln!( - f, - r" id{}-->id{}", - phi_node.first_predecessor.as_u32(), - id.as_u32() - )?; - } - FlowNode::Constraint(constraint_node) => { - writeln!(f, r"((Constraint))")?; - writeln!( - f, - r" id{}-->id{}", - constraint_node.predecessor.as_u32(), - id.as_u32() - )?; - } - } - } - Ok(()) - } -} diff --git a/crates/red_knot/src/semantic/symbol_table.rs b/crates/red_knot/src/semantic/symbol_table.rs deleted file mode 100644 index 9bca6ce0b8..0000000000 --- 
a/crates/red_knot/src/semantic/symbol_table.rs +++ /dev/null @@ -1,560 +0,0 @@ -#![allow(dead_code)] - -use std::hash::{Hash, Hasher}; -use std::iter::{Copied, DoubleEndedIterator, FusedIterator}; -use std::num::NonZeroU32; - -use bitflags::bitflags; -use hashbrown::hash_map::{Keys, RawEntryMut}; -use red_knot_module_resolver::ModuleName; -use rustc_hash::{FxHashMap, FxHasher}; - -use ruff_index::{newtype_index, IndexVec}; -use ruff_python_ast::name::Name; - -use crate::ast_ids::NodeKey; -use crate::semantic::{Definition, ExpressionId}; - -type Map = hashbrown::HashMap; - -#[newtype_index] -pub struct ScopeId; - -impl ScopeId { - pub fn scope(self, table: &SymbolTable) -> &Scope { - &table.scopes_by_id[self] - } -} - -#[newtype_index] -pub struct SymbolId; - -impl SymbolId { - pub fn symbol(self, table: &SymbolTable) -> &Symbol { - &table.symbols_by_id[self] - } -} - -#[derive(Copy, Clone, Debug, PartialEq)] -pub enum ScopeKind { - Module, - Annotation, - Class, - Function, -} - -#[derive(Debug)] -pub struct Scope { - name: Name, - kind: ScopeKind, - parent: Option, - children: Vec, - /// the definition (e.g. class or function) that created this scope - definition: Option, - /// the symbol (e.g. class or function) that owns this scope - defining_symbol: Option, - /// symbol IDs, hashed by symbol name - symbols_by_name: Map, -} - -impl Scope { - pub fn name(&self) -> &str { - self.name.as_str() - } - - pub fn kind(&self) -> ScopeKind { - self.kind - } - - pub fn definition(&self) -> Option { - self.definition.clone() - } - - pub fn defining_symbol(&self) -> Option { - self.defining_symbol - } -} - -#[derive(Debug)] -pub(crate) enum Kind { - FreeVar, - CellVar, - CellVarAssigned, - ExplicitGlobal, - ImplicitGlobal, -} - -bitflags! 
{ - #[derive(Copy,Clone,Debug)] - pub struct SymbolFlags: u8 { - const IS_USED = 1 << 0; - const IS_DEFINED = 1 << 1; - /// TODO: This flag is not yet set by anything - const MARKED_GLOBAL = 1 << 2; - /// TODO: This flag is not yet set by anything - const MARKED_NONLOCAL = 1 << 3; - } -} - -#[derive(Debug)] -pub struct Symbol { - name: Name, - flags: SymbolFlags, - scope_id: ScopeId, - // kind: Kind, -} - -impl Symbol { - pub fn name(&self) -> &str { - self.name.as_str() - } - - pub fn scope_id(&self) -> ScopeId { - self.scope_id - } - - /// Is the symbol used in its containing scope? - pub fn is_used(&self) -> bool { - self.flags.contains(SymbolFlags::IS_USED) - } - - /// Is the symbol defined in its containing scope? - pub fn is_defined(&self) -> bool { - self.flags.contains(SymbolFlags::IS_DEFINED) - } - - // TODO: implement Symbol.kind 2-pass analysis to categorize as: free-var, cell-var, - // explicit-global, implicit-global and implement Symbol.kind by modifying the preorder - // traversal code -} - -#[derive(Debug, Clone)] -pub enum Dependency { - Module(ModuleName), - Relative { - level: NonZeroU32, - module: Option, - }, -} - -/// Table of all symbols in all scopes for a module. -#[derive(Debug)] -pub struct SymbolTable { - scopes_by_id: IndexVec, - symbols_by_id: IndexVec, - /// the definitions for each symbol - defs: FxHashMap>, - /// map of AST node (e.g. class/function def) to sub-scope it creates - scopes_by_node: FxHashMap, - /// Maps expressions to their enclosing scope. 
- expression_scopes: IndexVec, - /// dependencies of this module - dependencies: Vec, -} - -impl SymbolTable { - pub fn dependencies(&self) -> &[Dependency] { - &self.dependencies - } - - pub const fn root_scope_id() -> ScopeId { - ScopeId::from_usize(0) - } - - pub fn root_scope(&self) -> &Scope { - &self.scopes_by_id[SymbolTable::root_scope_id()] - } - - pub fn symbol_ids_for_scope(&self, scope_id: ScopeId) -> Copied> { - self.scopes_by_id[scope_id].symbols_by_name.keys().copied() - } - - pub fn symbols_for_scope( - &self, - scope_id: ScopeId, - ) -> SymbolIterator>> { - SymbolIterator { - table: self, - ids: self.symbol_ids_for_scope(scope_id), - } - } - - pub fn root_symbol_ids(&self) -> Copied> { - self.symbol_ids_for_scope(SymbolTable::root_scope_id()) - } - - pub fn root_symbols(&self) -> SymbolIterator>> { - self.symbols_for_scope(SymbolTable::root_scope_id()) - } - - pub fn child_scope_ids_of(&self, scope_id: ScopeId) -> &[ScopeId] { - &self.scopes_by_id[scope_id].children - } - - pub fn child_scopes_of(&self, scope_id: ScopeId) -> ScopeIterator<&[ScopeId]> { - ScopeIterator { - table: self, - ids: self.child_scope_ids_of(scope_id), - } - } - - pub fn root_child_scope_ids(&self) -> &[ScopeId] { - self.child_scope_ids_of(SymbolTable::root_scope_id()) - } - - pub fn root_child_scopes(&self) -> ScopeIterator<&[ScopeId]> { - self.child_scopes_of(SymbolTable::root_scope_id()) - } - - pub fn symbol_id_by_name(&self, scope_id: ScopeId, name: &str) -> Option { - let scope = &self.scopes_by_id[scope_id]; - let hash = SymbolTable::hash_name(name); - let name = Name::new(name); - Some( - *scope - .symbols_by_name - .raw_entry() - .from_hash(hash, |symid| self.symbols_by_id[*symid].name == name)? 
- .0, - ) - } - - pub fn symbol_by_name(&self, scope_id: ScopeId, name: &str) -> Option<&Symbol> { - Some(&self.symbols_by_id[self.symbol_id_by_name(scope_id, name)?]) - } - - pub fn root_symbol_id_by_name(&self, name: &str) -> Option { - self.symbol_id_by_name(SymbolTable::root_scope_id(), name) - } - - pub fn root_symbol_by_name(&self, name: &str) -> Option<&Symbol> { - self.symbol_by_name(SymbolTable::root_scope_id(), name) - } - - pub fn scope_id_of_symbol(&self, symbol_id: SymbolId) -> ScopeId { - self.symbols_by_id[symbol_id].scope_id - } - - pub fn scope_of_symbol(&self, symbol_id: SymbolId) -> &Scope { - &self.scopes_by_id[self.scope_id_of_symbol(symbol_id)] - } - - pub fn scope_id_of_expression(&self, expression: ExpressionId) -> ScopeId { - self.expression_scopes[expression] - } - - pub fn scope_of_expression(&self, expr_id: ExpressionId) -> &Scope { - &self.scopes_by_id[self.scope_id_of_expression(expr_id)] - } - - pub fn parent_scopes( - &self, - scope_id: ScopeId, - ) -> ScopeIterator + '_> { - ScopeIterator { - table: self, - ids: std::iter::successors(Some(scope_id), |scope| self.scopes_by_id[*scope].parent), - } - } - - pub fn parent_scope(&self, scope_id: ScopeId) -> Option { - self.scopes_by_id[scope_id].parent - } - - pub fn scope_id_for_node(&self, node_key: &NodeKey) -> ScopeId { - self.scopes_by_node[node_key] - } - - pub fn definitions(&self, symbol_id: SymbolId) -> &[Definition] { - self.defs - .get(&symbol_id) - .map(std::vec::Vec::as_slice) - .unwrap_or_default() - } - - pub fn all_definitions(&self) -> impl Iterator + '_ { - self.defs - .iter() - .flat_map(|(sym_id, defs)| defs.iter().map(move |def| (*sym_id, def))) - } - - fn hash_name(name: &str) -> u64 { - let mut hasher = FxHasher::default(); - name.hash(&mut hasher); - hasher.finish() - } -} - -pub struct SymbolIterator<'a, I> { - table: &'a SymbolTable, - ids: I, -} - -impl<'a, I> Iterator for SymbolIterator<'a, I> -where - I: Iterator, -{ - type Item = &'a Symbol; - - fn next(&mut 
self) -> Option { - let id = self.ids.next()?; - Some(&self.table.symbols_by_id[id]) - } - - fn size_hint(&self) -> (usize, Option) { - self.ids.size_hint() - } -} - -impl<'a, I> FusedIterator for SymbolIterator<'a, I> where - I: Iterator + FusedIterator -{ -} - -impl<'a, I> DoubleEndedIterator for SymbolIterator<'a, I> -where - I: Iterator + DoubleEndedIterator, -{ - fn next_back(&mut self) -> Option { - let id = self.ids.next_back()?; - Some(&self.table.symbols_by_id[id]) - } -} - -// TODO maybe get rid of this and just do all data access via methods on ScopeId? -pub struct ScopeIterator<'a, I> { - table: &'a SymbolTable, - ids: I, -} - -/// iterate (`ScopeId`, `Scope`) pairs for given `ScopeId` iterator -impl<'a, I> Iterator for ScopeIterator<'a, I> -where - I: Iterator, -{ - type Item = (ScopeId, &'a Scope); - - fn next(&mut self) -> Option { - let id = self.ids.next()?; - Some((id, &self.table.scopes_by_id[id])) - } - - fn size_hint(&self) -> (usize, Option) { - self.ids.size_hint() - } -} - -impl<'a, I> FusedIterator for ScopeIterator<'a, I> where I: Iterator + FusedIterator {} - -impl<'a, I> DoubleEndedIterator for ScopeIterator<'a, I> -where - I: Iterator + DoubleEndedIterator, -{ - fn next_back(&mut self) -> Option { - let id = self.ids.next_back()?; - Some((id, &self.table.scopes_by_id[id])) - } -} - -#[derive(Debug)] -pub(super) struct SymbolTableBuilder { - symbol_table: SymbolTable, -} - -impl SymbolTableBuilder { - pub(super) fn new() -> Self { - let mut table = SymbolTable { - scopes_by_id: IndexVec::new(), - symbols_by_id: IndexVec::new(), - defs: FxHashMap::default(), - scopes_by_node: FxHashMap::default(), - expression_scopes: IndexVec::new(), - dependencies: Vec::new(), - }; - table.scopes_by_id.push(Scope { - name: Name::new(""), - kind: ScopeKind::Module, - parent: None, - children: Vec::new(), - definition: None, - defining_symbol: None, - symbols_by_name: Map::default(), - }); - Self { - symbol_table: table, - } - } - - pub(super) fn 
finish(self) -> SymbolTable { - let mut symbol_table = self.symbol_table; - symbol_table.scopes_by_id.shrink_to_fit(); - symbol_table.symbols_by_id.shrink_to_fit(); - symbol_table.defs.shrink_to_fit(); - symbol_table.scopes_by_node.shrink_to_fit(); - symbol_table.expression_scopes.shrink_to_fit(); - symbol_table.dependencies.shrink_to_fit(); - symbol_table - } - - pub(super) fn add_or_update_symbol( - &mut self, - scope_id: ScopeId, - name: &str, - flags: SymbolFlags, - ) -> SymbolId { - let hash = SymbolTable::hash_name(name); - let scope = &mut self.symbol_table.scopes_by_id[scope_id]; - let name = Name::new(name); - - let entry = scope - .symbols_by_name - .raw_entry_mut() - .from_hash(hash, |existing| { - self.symbol_table.symbols_by_id[*existing].name == name - }); - - match entry { - RawEntryMut::Occupied(entry) => { - if let Some(symbol) = self.symbol_table.symbols_by_id.get_mut(*entry.key()) { - symbol.flags.insert(flags); - }; - *entry.key() - } - RawEntryMut::Vacant(entry) => { - let id = self.symbol_table.symbols_by_id.push(Symbol { - name, - flags, - scope_id, - }); - entry.insert_with_hasher(hash, id, (), |symid| { - SymbolTable::hash_name(&self.symbol_table.symbols_by_id[*symid].name) - }); - id - } - } - } - - pub(super) fn add_definition(&mut self, symbol_id: SymbolId, definition: Definition) { - self.symbol_table - .defs - .entry(symbol_id) - .or_default() - .push(definition); - } - - pub(super) fn add_child_scope( - &mut self, - parent_scope_id: ScopeId, - name: &str, - kind: ScopeKind, - definition: Option, - defining_symbol: Option, - ) -> ScopeId { - let new_scope_id = self.symbol_table.scopes_by_id.push(Scope { - name: Name::new(name), - kind, - parent: Some(parent_scope_id), - children: Vec::new(), - definition, - defining_symbol, - symbols_by_name: Map::default(), - }); - let parent_scope = &mut self.symbol_table.scopes_by_id[parent_scope_id]; - parent_scope.children.push(new_scope_id); - new_scope_id - } - - pub(super) fn 
record_scope_for_node(&mut self, node_key: NodeKey, scope_id: ScopeId) { - self.symbol_table.scopes_by_node.insert(node_key, scope_id); - } - - pub(super) fn add_dependency(&mut self, dependency: Dependency) { - self.symbol_table.dependencies.push(dependency); - } - - /// Records the scope for the current expression - pub(super) fn record_expression(&mut self, scope: ScopeId) -> ExpressionId { - self.symbol_table.expression_scopes.push(scope) - } -} - -#[cfg(test)] -mod tests { - use super::{ScopeKind, SymbolFlags, SymbolTable, SymbolTableBuilder}; - - #[test] - fn insert_same_name_symbol_twice() { - let mut builder = SymbolTableBuilder::new(); - let root_scope_id = SymbolTable::root_scope_id(); - let symbol_id_1 = - builder.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::IS_DEFINED); - let symbol_id_2 = builder.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::IS_USED); - let table = builder.finish(); - - assert_eq!(symbol_id_1, symbol_id_2); - assert!(symbol_id_1.symbol(&table).is_used(), "flags must merge"); - assert!(symbol_id_1.symbol(&table).is_defined(), "flags must merge"); - } - - #[test] - fn insert_different_named_symbols() { - let mut builder = SymbolTableBuilder::new(); - let root_scope_id = SymbolTable::root_scope_id(); - let symbol_id_1 = builder.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty()); - let symbol_id_2 = builder.add_or_update_symbol(root_scope_id, "bar", SymbolFlags::empty()); - - assert_ne!(symbol_id_1, symbol_id_2); - } - - #[test] - fn add_child_scope_with_symbol() { - let mut builder = SymbolTableBuilder::new(); - let root_scope_id = SymbolTable::root_scope_id(); - let foo_symbol_top = - builder.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty()); - let c_scope = builder.add_child_scope(root_scope_id, "C", ScopeKind::Class, None, None); - let foo_symbol_inner = builder.add_or_update_symbol(c_scope, "foo", SymbolFlags::empty()); - - assert_ne!(foo_symbol_top, foo_symbol_inner); - } - - #[test] 
- fn scope_from_id() { - let table = SymbolTableBuilder::new().finish(); - let root_scope_id = SymbolTable::root_scope_id(); - let scope = root_scope_id.scope(&table); - - assert_eq!(scope.name.as_str(), ""); - assert_eq!(scope.kind, ScopeKind::Module); - } - - #[test] - fn symbol_from_id() { - let mut builder = SymbolTableBuilder::new(); - let root_scope_id = SymbolTable::root_scope_id(); - let foo_symbol_id = - builder.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty()); - let table = builder.finish(); - let symbol = foo_symbol_id.symbol(&table); - - assert_eq!(symbol.name(), "foo"); - } - - #[test] - fn bigger_symbol_table() { - let mut builder = SymbolTableBuilder::new(); - let root_scope_id = SymbolTable::root_scope_id(); - let foo_symbol_id = - builder.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty()); - builder.add_or_update_symbol(root_scope_id, "bar", SymbolFlags::empty()); - builder.add_or_update_symbol(root_scope_id, "baz", SymbolFlags::empty()); - builder.add_or_update_symbol(root_scope_id, "qux", SymbolFlags::empty()); - let table = builder.finish(); - - let foo_symbol_id_2 = table - .root_symbol_id_by_name("foo") - .expect("foo symbol to be found"); - - assert_eq!(foo_symbol_id_2, foo_symbol_id); - } -} diff --git a/crates/red_knot/src/semantic/types.rs b/crates/red_knot/src/semantic/types.rs deleted file mode 100644 index 74960c4b50..0000000000 --- a/crates/red_knot/src/semantic/types.rs +++ /dev/null @@ -1,1111 +0,0 @@ -#![allow(dead_code)] -use crate::ast_ids::NodeKey; -use crate::db::{QueryResult, SemanticDb, SemanticJar}; -use crate::files::FileId; -use crate::module::Module; -use crate::semantic::{ - resolve_global_symbol, semantic_index, GlobalSymbolId, ScopeId, ScopeKind, SymbolId, -}; -use crate::{FxDashMap, FxIndexSet}; -use ruff_index::{newtype_index, IndexVec}; -use ruff_python_ast as ast; -use rustc_hash::FxHashMap; - -pub(crate) mod infer; - -pub(crate) use infer::{infer_definition_type, 
infer_symbol_public_type}; -use red_knot_module_resolver::ModuleName; -use ruff_python_ast::name::Name; - -/// unique ID for a type -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -pub enum Type { - /// the dynamic type: a statically-unknown set of values - Any, - /// the empty set of values - Never, - /// unknown type (no annotation) - /// equivalent to Any, or to object in strict mode - Unknown, - /// name is not bound to any value - Unbound, - /// the None object (TODO remove this in favor of Instance(types.NoneType) - None, - /// a specific function object - Function(FunctionTypeId), - /// a specific module object - Module(ModuleTypeId), - /// a specific class object - Class(ClassTypeId), - /// the set of Python objects with the given class in their __class__'s method resolution order - Instance(ClassTypeId), - Union(UnionTypeId), - Intersection(IntersectionTypeId), - IntLiteral(i64), - // TODO protocols, callable types, overloads, generics, type vars -} - -impl Type { - fn display<'a>(&'a self, store: &'a TypeStore) -> DisplayType<'a> { - DisplayType { ty: self, store } - } - - pub const fn is_unbound(&self) -> bool { - matches!(self, Type::Unbound) - } - - pub const fn is_unknown(&self) -> bool { - matches!(self, Type::Unknown) - } - - pub fn get_member(&self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { - match self { - Type::Any => Ok(Some(Type::Any)), - Type::Never => todo!("attribute lookup on Never type"), - Type::Unknown => Ok(Some(Type::Unknown)), - Type::Unbound => todo!("attribute lookup on Unbound type"), - Type::None => todo!("attribute lookup on None type"), - Type::Function(_) => todo!("attribute lookup on Function type"), - Type::Module(module_id) => module_id.get_member(db, name), - Type::Class(class_id) => class_id.get_class_member(db, name), - Type::Instance(_) => { - // TODO MRO? 
get_own_instance_member, get_instance_member - todo!("attribute lookup on Instance type") - } - Type::Union(union_id) => { - let jar: &SemanticJar = db.jar()?; - let _todo_union_ref = jar.type_store.get_union(*union_id); - // TODO perform the get_member on each type in the union - // TODO return the union of those results - // TODO if any of those results is `None` then include Unknown in the result union - todo!("attribute lookup on Union type") - } - Type::Intersection(_) => { - // TODO perform the get_member on each type in the intersection - // TODO return the intersection of those results - todo!("attribute lookup on Intersection type") - } - Type::IntLiteral(_) => { - // TODO raise error - Ok(Some(Type::Unknown)) - } - } - } - - // when this is fully fleshed out, it will use the db arg and may return QueryError - #[allow(clippy::unnecessary_wraps)] - pub fn resolve_bin_op( - &self, - _db: &dyn SemanticDb, - op: ast::Operator, - right_ty: Type, - ) -> QueryResult { - match self { - Type::Any => Ok(Type::Any), - Type::Unknown => Ok(Type::Unknown), - Type::IntLiteral(n) => { - match right_ty { - Type::IntLiteral(m) => { - match op { - ast::Operator::Add => Ok(n - .checked_add(m) - .map(Type::IntLiteral) - // TODO builtins.int - .unwrap_or(Type::Unknown)), - ast::Operator::Sub => Ok(n - .checked_sub(m) - .map(Type::IntLiteral) - // TODO builtins.int - .unwrap_or(Type::Unknown)), - ast::Operator::Mult => Ok(n - .checked_mul(m) - .map(Type::IntLiteral) - // TODO builtins.int - .unwrap_or(Type::Unknown)), - ast::Operator::Div => Ok(n - .checked_div(m) - .map(Type::IntLiteral) - // TODO builtins.int - .unwrap_or(Type::Unknown)), - ast::Operator::Mod => Ok(n - .checked_rem(m) - .map(Type::IntLiteral) - // TODO division by zero error - .unwrap_or(Type::Unknown)), - _ => todo!("complete binop op support for IntLiteral"), - } - } - _ => todo!("complete binop right_ty support for IntLiteral"), - } - } - _ => todo!("complete binop support"), - } - } -} - -impl From for 
Type { - fn from(id: FunctionTypeId) -> Self { - Type::Function(id) - } -} - -impl From for Type { - fn from(id: UnionTypeId) -> Self { - Type::Union(id) - } -} - -impl From for Type { - fn from(id: IntersectionTypeId) -> Self { - Type::Intersection(id) - } -} - -// TODO: currently calling `get_function` et al and holding on to the `FunctionTypeRef` will lock a -// shard of this dashmap, for as long as you hold the reference. This may be a problem. We could -// switch to having all the arenas hold Arc, or we could see if we can split up ModuleTypeStore, -// and/or give it inner mutability and finer-grained internal locking. -#[derive(Debug, Default)] -pub struct TypeStore { - modules: FxDashMap, -} - -impl TypeStore { - pub fn remove_module(&mut self, file_id: FileId) { - self.modules.remove(&file_id); - } - - pub fn cache_symbol_public_type(&self, symbol: GlobalSymbolId, ty: Type) { - self.add_or_get_module(symbol.file_id) - .symbol_types - .insert(symbol.symbol_id, ty); - } - - pub fn cache_node_type(&self, file_id: FileId, node_key: NodeKey, ty: Type) { - self.add_or_get_module(file_id) - .node_types - .insert(node_key, ty); - } - - pub fn get_cached_symbol_public_type(&self, symbol: GlobalSymbolId) -> Option { - self.try_get_module(symbol.file_id)? - .symbol_types - .get(&symbol.symbol_id) - .copied() - } - - pub fn get_cached_node_type(&self, file_id: FileId, node_key: &NodeKey) -> Option { - self.try_get_module(file_id)? 
- .node_types - .get(node_key) - .copied() - } - - fn add_or_get_module(&self, file_id: FileId) -> ModuleStoreRefMut { - self.modules - .entry(file_id) - .or_insert_with(|| ModuleTypeStore::new(file_id)) - } - - fn get_module(&self, file_id: FileId) -> ModuleStoreRef { - self.try_get_module(file_id).expect("module should exist") - } - - fn try_get_module(&self, file_id: FileId) -> Option { - self.modules.get(&file_id) - } - - fn add_function_type( - &self, - file_id: FileId, - name: &str, - symbol_id: SymbolId, - scope_id: ScopeId, - decorators: Vec, - ) -> FunctionTypeId { - self.add_or_get_module(file_id) - .add_function(name, symbol_id, scope_id, decorators) - } - - fn add_function( - &self, - file_id: FileId, - name: &str, - symbol_id: SymbolId, - scope_id: ScopeId, - decorators: Vec, - ) -> Type { - Type::Function(self.add_function_type(file_id, name, symbol_id, scope_id, decorators)) - } - - fn add_class_type( - &self, - file_id: FileId, - name: &str, - scope_id: ScopeId, - bases: Vec, - ) -> ClassTypeId { - self.add_or_get_module(file_id) - .add_class(name, scope_id, bases) - } - - fn add_class(&self, file_id: FileId, name: &str, scope_id: ScopeId, bases: Vec) -> Type { - Type::Class(self.add_class_type(file_id, name, scope_id, bases)) - } - - /// add "raw" union type with exactly given elements - fn add_union_type(&self, file_id: FileId, elems: &[Type]) -> UnionTypeId { - self.add_or_get_module(file_id).add_union(elems) - } - - /// add union with normalization; may not return a `UnionType` - fn add_union(&self, file_id: FileId, elems: &[Type]) -> Type { - let mut flattened = Vec::with_capacity(elems.len()); - for ty in elems { - match ty { - Type::Union(union_id) => flattened.extend(union_id.elements(self)), - _ => flattened.push(*ty), - } - } - // TODO don't add identical unions - // TODO de-duplicate union elements - match flattened[..] 
{ - [] => Type::Never, - [ty] => ty, - _ => Type::Union(self.add_union_type(file_id, &flattened)), - } - } - - /// add "raw" intersection type with exactly given elements - fn add_intersection_type( - &self, - file_id: FileId, - positive: &[Type], - negative: &[Type], - ) -> IntersectionTypeId { - self.add_or_get_module(file_id) - .add_intersection(positive, negative) - } - - /// add intersection with normalization; may not return an `IntersectionType` - fn add_intersection(&self, file_id: FileId, positive: &[Type], negative: &[Type]) -> Type { - let mut pos_flattened = Vec::with_capacity(positive.len()); - let mut neg_flattened = Vec::with_capacity(negative.len()); - for ty in positive { - match ty { - Type::Intersection(intersection_id) => { - pos_flattened.extend(intersection_id.positive(self)); - neg_flattened.extend(intersection_id.negative(self)); - } - _ => pos_flattened.push(*ty), - } - } - for ty in negative { - match ty { - Type::Intersection(intersection_id) => { - pos_flattened.extend(intersection_id.negative(self)); - neg_flattened.extend(intersection_id.positive(self)); - } - _ => neg_flattened.push(*ty), - } - } - // TODO don't add identical intersections - // TODO deduplicate intersection elements - // TODO maintain DNF form (union of intersections) - match (&pos_flattened[..], &neg_flattened[..]) { - ([], []) => Type::Any, // TODO should be object - ([ty], []) => *ty, - (pos, neg) => Type::Intersection(self.add_intersection_type(file_id, pos, neg)), - } - } - - fn get_function(&self, id: FunctionTypeId) -> FunctionTypeRef { - FunctionTypeRef { - module_store: self.get_module(id.file_id), - function_id: id.func_id, - } - } - - fn get_class(&self, id: ClassTypeId) -> ClassTypeRef { - ClassTypeRef { - module_store: self.get_module(id.file_id), - class_id: id.class_id, - } - } - - fn get_union(&self, id: UnionTypeId) -> UnionTypeRef { - UnionTypeRef { - module_store: self.get_module(id.file_id), - union_id: id.union_id, - } - } - - fn 
get_intersection(&self, id: IntersectionTypeId) -> IntersectionTypeRef { - IntersectionTypeRef { - module_store: self.get_module(id.file_id), - intersection_id: id.intersection_id, - } - } -} - -type ModuleStoreRef<'a> = dashmap::mapref::one::Ref<'a, FileId, ModuleTypeStore>; - -type ModuleStoreRefMut<'a> = dashmap::mapref::one::RefMut<'a, FileId, ModuleTypeStore>; - -#[derive(Debug)] -pub(crate) struct FunctionTypeRef<'a> { - module_store: ModuleStoreRef<'a>, - function_id: ModuleFunctionTypeId, -} - -impl<'a> std::ops::Deref for FunctionTypeRef<'a> { - type Target = FunctionType; - - fn deref(&self) -> &Self::Target { - self.module_store.get_function(self.function_id) - } -} - -#[derive(Debug)] -pub(crate) struct ClassTypeRef<'a> { - module_store: ModuleStoreRef<'a>, - class_id: ModuleClassTypeId, -} - -impl<'a> std::ops::Deref for ClassTypeRef<'a> { - type Target = ClassType; - - fn deref(&self) -> &Self::Target { - self.module_store.get_class(self.class_id) - } -} - -#[derive(Debug)] -pub(crate) struct UnionTypeRef<'a> { - module_store: ModuleStoreRef<'a>, - union_id: ModuleUnionTypeId, -} - -impl<'a> std::ops::Deref for UnionTypeRef<'a> { - type Target = UnionType; - - fn deref(&self) -> &Self::Target { - self.module_store.get_union(self.union_id) - } -} - -#[derive(Debug)] -pub(crate) struct IntersectionTypeRef<'a> { - module_store: ModuleStoreRef<'a>, - intersection_id: ModuleIntersectionTypeId, -} - -impl<'a> std::ops::Deref for IntersectionTypeRef<'a> { - type Target = IntersectionType; - - fn deref(&self) -> &Self::Target { - self.module_store.get_intersection(self.intersection_id) - } -} - -#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] -pub struct FunctionTypeId { - file_id: FileId, - func_id: ModuleFunctionTypeId, -} - -impl FunctionTypeId { - fn function(self, db: &dyn SemanticDb) -> QueryResult { - let jar: &SemanticJar = db.jar()?; - Ok(jar.type_store.get_function(self)) - } - - pub(crate) fn name(self, db: &dyn SemanticDb) -> QueryResult { - 
Ok(self.function(db)?.name().into()) - } - - pub(crate) fn global_symbol(self, db: &dyn SemanticDb) -> QueryResult { - Ok(GlobalSymbolId { - file_id: self.file(), - symbol_id: self.symbol(db)?, - }) - } - - pub(crate) fn file(self) -> FileId { - self.file_id - } - - pub(crate) fn symbol(self, db: &dyn SemanticDb) -> QueryResult { - let FunctionType { symbol_id, .. } = *self.function(db)?; - Ok(symbol_id) - } - - pub(crate) fn get_containing_class( - self, - db: &dyn SemanticDb, - ) -> QueryResult> { - let index = semantic_index(db, self.file_id)?; - let table = index.symbol_table(); - let FunctionType { symbol_id, .. } = *self.function(db)?; - let scope_id = symbol_id.symbol(table).scope_id(); - let scope = scope_id.scope(table); - if !matches!(scope.kind(), ScopeKind::Class) { - return Ok(None); - }; - let Some(def) = scope.definition() else { - return Ok(None); - }; - let Some(symbol_id) = scope.defining_symbol() else { - return Ok(None); - }; - let Type::Class(class) = infer_definition_type( - db, - GlobalSymbolId { - file_id: self.file_id, - symbol_id, - }, - def, - )? - else { - return Ok(None); - }; - Ok(Some(class)) - } - - pub(crate) fn has_decorator( - self, - db: &dyn SemanticDb, - decorator_symbol: GlobalSymbolId, - ) -> QueryResult { - for deco_ty in self.function(db)?.decorators() { - let Type::Function(deco_func) = deco_ty else { - continue; - }; - if deco_func.global_symbol(db)? 
== decorator_symbol { - return Ok(true); - } - } - Ok(false) - } -} - -#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] -pub struct ModuleTypeId { - module: Module, - file_id: FileId, -} - -impl ModuleTypeId { - fn module(self, db: &dyn SemanticDb) -> QueryResult { - let jar: &SemanticJar = db.jar()?; - Ok(jar.type_store.add_or_get_module(self.file_id).downgrade()) - } - - pub(crate) fn name(self, db: &dyn SemanticDb) -> QueryResult { - self.module.name(db) - } - - fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { - if let Some(symbol_id) = resolve_global_symbol(db, self.module, name)? { - Ok(Some(infer_symbol_public_type(db, symbol_id)?)) - } else { - Ok(None) - } - } -} - -#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] -pub struct ClassTypeId { - file_id: FileId, - class_id: ModuleClassTypeId, -} - -impl ClassTypeId { - fn class(self, db: &dyn SemanticDb) -> QueryResult { - let jar: &SemanticJar = db.jar()?; - Ok(jar.type_store.get_class(self)) - } - - pub(crate) fn name(self, db: &dyn SemanticDb) -> QueryResult { - Ok(self.class(db)?.name().into()) - } - - pub(crate) fn get_super_class_member( - self, - db: &dyn SemanticDb, - name: &Name, - ) -> QueryResult> { - // TODO we should linearize the MRO instead of doing this recursively - let class = self.class(db)?; - for base in class.bases() { - if let Type::Class(base) = base { - if let Some(own_member) = base.get_own_class_member(db, name)? { - return Ok(Some(own_member)); - } - if let Some(base_member) = base.get_super_class_member(db, name)? { - return Ok(Some(base_member)); - } - } - } - Ok(None) - } - - fn get_own_class_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { - // TODO: this should distinguish instance-only members (e.g. `x: int`) and not return them - let ClassType { scope_id, .. 
} = *self.class(db)?; - let index = semantic_index(db, self.file_id)?; - if let Some(symbol_id) = index.symbol_table().symbol_id_by_name(scope_id, name) { - Ok(Some(infer_symbol_public_type( - db, - GlobalSymbolId { - file_id: self.file_id, - symbol_id, - }, - )?)) - } else { - Ok(None) - } - } - - /// Get own class member or fall back to super-class member. - fn get_class_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { - self.get_own_class_member(db, name) - .or_else(|_| self.get_super_class_member(db, name)) - } - - // TODO: get_own_instance_member, get_instance_member -} - -#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] -pub struct UnionTypeId { - file_id: FileId, - union_id: ModuleUnionTypeId, -} - -impl UnionTypeId { - pub fn elements(self, type_store: &TypeStore) -> Vec { - let union = type_store.get_union(self); - union.elements.iter().copied().collect() - } -} - -#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] -pub struct IntersectionTypeId { - file_id: FileId, - intersection_id: ModuleIntersectionTypeId, -} - -impl IntersectionTypeId { - pub fn positive(self, type_store: &TypeStore) -> Vec { - let intersection = type_store.get_intersection(self); - intersection.positive.iter().copied().collect() - } - - pub fn negative(self, type_store: &TypeStore) -> Vec { - let intersection = type_store.get_intersection(self); - intersection.negative.iter().copied().collect() - } -} - -#[newtype_index] -struct ModuleFunctionTypeId; - -#[newtype_index] -struct ModuleClassTypeId; - -#[newtype_index] -struct ModuleUnionTypeId; - -#[newtype_index] -struct ModuleIntersectionTypeId; - -#[derive(Debug)] -struct ModuleTypeStore { - file_id: FileId, - /// arena of all function types defined in this module - functions: IndexVec, - /// arena of all class types defined in this module - classes: IndexVec, - /// arenda of all union types created in this module - unions: IndexVec, - /// arena of all intersection types created in this module - intersections: 
IndexVec, - /// cached public types of symbols in this module - symbol_types: FxHashMap, - /// cached types of AST nodes in this module - node_types: FxHashMap, -} - -impl ModuleTypeStore { - fn new(file_id: FileId) -> Self { - Self { - file_id, - functions: IndexVec::default(), - classes: IndexVec::default(), - unions: IndexVec::default(), - intersections: IndexVec::default(), - symbol_types: FxHashMap::default(), - node_types: FxHashMap::default(), - } - } - - fn add_function( - &mut self, - name: &str, - symbol_id: SymbolId, - scope_id: ScopeId, - decorators: Vec, - ) -> FunctionTypeId { - let func_id = self.functions.push(FunctionType { - name: Name::new(name), - symbol_id, - scope_id, - decorators, - }); - FunctionTypeId { - file_id: self.file_id, - func_id, - } - } - - fn add_class(&mut self, name: &str, scope_id: ScopeId, bases: Vec) -> ClassTypeId { - let class_id = self.classes.push(ClassType { - name: Name::new(name), - scope_id, - // TODO: if no bases are given, that should imply [object] - bases, - }); - ClassTypeId { - file_id: self.file_id, - class_id, - } - } - - fn add_union(&mut self, elems: &[Type]) -> UnionTypeId { - let union_id = self.unions.push(UnionType { - elements: elems.iter().copied().collect(), - }); - UnionTypeId { - file_id: self.file_id, - union_id, - } - } - - fn add_intersection(&mut self, positive: &[Type], negative: &[Type]) -> IntersectionTypeId { - let intersection_id = self.intersections.push(IntersectionType { - positive: positive.iter().copied().collect(), - negative: negative.iter().copied().collect(), - }); - IntersectionTypeId { - file_id: self.file_id, - intersection_id, - } - } - - fn get_function(&self, func_id: ModuleFunctionTypeId) -> &FunctionType { - &self.functions[func_id] - } - - fn get_class(&self, class_id: ModuleClassTypeId) -> &ClassType { - &self.classes[class_id] - } - - fn get_union(&self, union_id: ModuleUnionTypeId) -> &UnionType { - &self.unions[union_id] - } - - fn get_intersection(&self, 
intersection_id: ModuleIntersectionTypeId) -> &IntersectionType { - &self.intersections[intersection_id] - } -} - -#[derive(Copy, Clone, Debug)] -struct DisplayType<'a> { - ty: &'a Type, - store: &'a TypeStore, -} - -impl std::fmt::Display for DisplayType<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.ty { - Type::Any => f.write_str("Any"), - Type::Never => f.write_str("Never"), - Type::Unknown => f.write_str("Unknown"), - Type::Unbound => f.write_str("Unbound"), - Type::None => f.write_str("None"), - Type::Module(module_id) => { - // NOTE: something like this?: "" - todo!("{module_id:?}") - } - // TODO functions and classes should display using a fully qualified name - Type::Class(class_id) => { - f.write_str("Literal[")?; - f.write_str(self.store.get_class(*class_id).name())?; - f.write_str("]") - } - Type::Instance(class_id) => f.write_str(self.store.get_class(*class_id).name()), - Type::Function(func_id) => f.write_str(self.store.get_function(*func_id).name()), - Type::Union(union_id) => self - .store - .get_module(union_id.file_id) - .get_union(union_id.union_id) - .display(f, self.store), - Type::Intersection(int_id) => self - .store - .get_module(int_id.file_id) - .get_intersection(int_id.intersection_id) - .display(f, self.store), - Type::IntLiteral(n) => write!(f, "Literal[{n}]"), - } - } -} - -#[derive(Debug)] -pub(crate) struct ClassType { - /// Name of the class at definition - name: Name, - /// `ScopeId` of the class body - scope_id: ScopeId, - /// Types of all class bases - bases: Vec, -} - -impl ClassType { - fn name(&self) -> &str { - self.name.as_str() - } - - fn bases(&self) -> &[Type] { - self.bases.as_slice() - } -} - -#[derive(Debug)] -pub(crate) struct FunctionType { - /// name of the function at definition - name: Name, - /// symbol which this function is a definition of - symbol_id: SymbolId, - /// scope of this function's body - scope_id: ScopeId, - /// types of all decorators on this function - 
decorators: Vec, -} - -impl FunctionType { - fn name(&self) -> &str { - self.name.as_str() - } - - fn scope_id(&self) -> ScopeId { - self.scope_id - } - - pub(crate) fn decorators(&self) -> &[Type] { - self.decorators.as_slice() - } -} - -#[derive(Debug)] -pub(crate) struct UnionType { - // the union type includes values in any of these types - elements: FxIndexSet, -} - -impl UnionType { - fn display(&self, f: &mut std::fmt::Formatter<'_>, store: &TypeStore) -> std::fmt::Result { - let (int_literals, other_types): (Vec, Vec) = self - .elements - .iter() - .copied() - .partition(|ty| matches!(ty, Type::IntLiteral(_))); - let mut first = true; - if !int_literals.is_empty() { - f.write_str("Literal[")?; - let mut nums: Vec = int_literals - .into_iter() - .filter_map(|ty| { - if let Type::IntLiteral(n) = ty { - Some(n) - } else { - None - } - }) - .collect(); - nums.sort_unstable(); - for num in nums { - if !first { - f.write_str(", ")?; - } - write!(f, "{num}")?; - first = false; - } - f.write_str("]")?; - } - for ty in other_types { - if !first { - f.write_str(" | ")?; - }; - first = false; - write!(f, "{}", ty.display(store))?; - } - Ok(()) - } -} - -// Negation types aren't expressible in annotations, and are most likely to arise from type -// narrowing along with intersections (e.g. `if not isinstance(...)`), so we represent them -// directly in intersections rather than as a separate type. This sacrifices some efficiency in the -// case where a Not appears outside an intersection (unclear when that could even happen, but we'd -// have to represent it as a single-element intersection if it did) in exchange for better -// efficiency in the within-intersection case. 
-#[derive(Debug)] -pub(crate) struct IntersectionType { - // the intersection type includes only values in all of these types - positive: FxIndexSet, - // the intersection type does not include any value in any of these types - negative: FxIndexSet, -} - -impl IntersectionType { - fn display(&self, f: &mut std::fmt::Formatter<'_>, store: &TypeStore) -> std::fmt::Result { - let mut first = true; - for (neg, ty) in self - .positive - .iter() - .map(|ty| (false, ty)) - .chain(self.negative.iter().map(|ty| (true, ty))) - { - if !first { - f.write_str(" & ")?; - }; - first = false; - if neg { - f.write_str("~")?; - }; - write!(f, "{}", ty.display(store))?; - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::Type; - use std::path::Path; - - use crate::files::Files; - use crate::semantic::symbol_table::SymbolTableBuilder; - use crate::semantic::{FileId, ScopeId, SymbolFlags, SymbolTable, TypeStore}; - use crate::FxIndexSet; - - struct TestCase { - store: TypeStore, - files: Files, - file_id: FileId, - root_scope: ScopeId, - } - - fn create_test() -> TestCase { - let files = Files::default(); - let file_id = files.intern(Path::new("/foo")); - TestCase { - store: TypeStore::default(), - files, - file_id, - root_scope: SymbolTable::root_scope_id(), - } - } - - fn assert_union_elements(store: &TypeStore, union: Type, elements: &[Type]) { - let Type::Union(union_id) = union else { - panic!("should be a union") - }; - - assert_eq!( - store.get_union(union_id).elements, - elements.iter().copied().collect::>() - ); - } - - fn assert_intersection_elements( - store: &TypeStore, - intersection: Type, - positive: &[Type], - negative: &[Type], - ) { - let Type::Intersection(intersection_id) = intersection else { - panic!("should be a intersection") - }; - - assert_eq!( - store.get_intersection(intersection_id).positive, - positive.iter().copied().collect::>() - ); - assert_eq!( - store.get_intersection(intersection_id).negative, - negative.iter().copied().collect::>() - ); - 
} - - #[test] - fn add_class() { - let TestCase { - store, - file_id, - root_scope, - .. - } = create_test(); - - let id = store.add_class_type(file_id, "C", root_scope, Vec::new()); - assert_eq!(store.get_class(id).name(), "C"); - let inst = Type::Instance(id); - assert_eq!(format!("{}", inst.display(&store)), "C"); - } - - #[test] - fn add_function() { - let TestCase { - store, - file_id, - root_scope, - .. - } = create_test(); - - let mut builder = SymbolTableBuilder::new(); - let func_symbol = builder.add_or_update_symbol( - SymbolTable::root_scope_id(), - "func", - SymbolFlags::IS_DEFINED, - ); - builder.finish(); - - let id = store.add_function_type( - file_id, - "func", - func_symbol, - root_scope, - vec![Type::Unknown], - ); - assert_eq!(store.get_function(id).name(), "func"); - assert_eq!(store.get_function(id).decorators(), vec![Type::Unknown]); - let func = Type::Function(id); - assert_eq!(format!("{}", func.display(&store)), "func"); - } - - #[test] - fn add_union() { - let TestCase { - store, - file_id, - root_scope, - .. - } = create_test(); - - let c1 = store.add_class_type(file_id, "C1", root_scope, Vec::new()); - let c2 = store.add_class_type(file_id, "C2", root_scope, Vec::new()); - let elems = vec![Type::Instance(c1), Type::Instance(c2)]; - let id = store.add_union_type(file_id, &elems); - let union = Type::Union(id); - - assert_union_elements(&store, union, &elems); - assert_eq!(format!("{}", union.display(&store)), "C1 | C2"); - } - - #[test] - fn add_intersection() { - let TestCase { - store, - file_id, - root_scope, - .. 
- } = create_test(); - - let c1 = store.add_class_type(file_id, "C1", root_scope, Vec::new()); - let c2 = store.add_class_type(file_id, "C2", root_scope, Vec::new()); - let c3 = store.add_class_type(file_id, "C3", root_scope, Vec::new()); - let pos = vec![Type::Instance(c1), Type::Instance(c2)]; - let neg = vec![Type::Instance(c3)]; - let id = store.add_intersection_type(file_id, &pos, &neg); - let intersection = Type::Intersection(id); - - assert_intersection_elements(&store, intersection, &pos, &neg); - assert_eq!(format!("{}", intersection.display(&store)), "C1 & C2 & ~C3"); - } - - #[test] - fn flatten_union_zero_elements() { - let TestCase { store, file_id, .. } = create_test(); - - let ty = store.add_union(file_id, &[]); - - assert!(matches!(ty, Type::Never), "{ty:?} should be Never"); - } - - #[test] - fn flatten_union_one_element() { - let TestCase { store, file_id, .. } = create_test(); - - let ty = store.add_union(file_id, &[Type::None]); - - assert!(matches!(ty, Type::None), "{ty:?} should be None"); - } - - #[test] - fn flatten_nested_union() { - let TestCase { store, file_id, .. } = create_test(); - - let l1 = Type::IntLiteral(1); - let l2 = Type::IntLiteral(2); - let u1 = store.add_union(file_id, &[l1, l2]); - let u2 = store.add_union(file_id, &[u1, Type::None]); - - assert_union_elements(&store, u2, &[l1, l2, Type::None]); - } - - #[test] - fn flatten_intersection_zero_elements() { - let TestCase { store, file_id, .. } = create_test(); - - let ty = store.add_intersection(file_id, &[], &[]); - - // TODO should be object, not Any - assert!(matches!(ty, Type::Any), "{ty:?} should be object"); - } - - #[test] - fn flatten_intersection_one_positive_element() { - let TestCase { store, file_id, .. } = create_test(); - - let ty = store.add_intersection(file_id, &[Type::None], &[]); - - assert!(matches!(ty, Type::None), "{ty:?} should be None"); - } - - #[test] - fn flatten_intersection_one_negative_element() { - let TestCase { store, file_id, .. 
} = create_test(); - - let ty = store.add_intersection(file_id, &[], &[Type::None]); - - assert_intersection_elements(&store, ty, &[], &[Type::None]); - } - - #[test] - fn flatten_nested_intersection() { - let TestCase { - store, - file_id, - root_scope, - .. - } = create_test(); - - let c1 = Type::Instance(store.add_class_type(file_id, "C1", root_scope, vec![])); - let c2 = Type::Instance(store.add_class_type(file_id, "C2", root_scope, vec![])); - let c1sub = Type::Instance(store.add_class_type(file_id, "C1sub", root_scope, vec![c1])); - let i1 = store.add_intersection(file_id, &[c1, c2], &[c1sub]); - let i2 = store.add_intersection(file_id, &[i1, Type::None], &[]); - - assert_intersection_elements(&store, i2, &[c1, c2, Type::None], &[c1sub]); - } -} diff --git a/crates/red_knot/src/semantic/types/infer.rs b/crates/red_knot/src/semantic/types/infer.rs deleted file mode 100644 index af68e00a6e..0000000000 --- a/crates/red_knot/src/semantic/types/infer.rs +++ /dev/null @@ -1,764 +0,0 @@ -#![allow(dead_code)] - -use red_knot_module_resolver::ModuleName; -use ruff_python_ast as ast; -use ruff_python_ast::AstNode; -use std::fmt::Debug; - -use crate::db::{QueryResult, SemanticDb, SemanticJar}; - -use crate::module::resolve_module; -use crate::parse::parse; -use crate::semantic::types::{ModuleTypeId, Type}; -use crate::semantic::{ - resolve_global_symbol, semantic_index, ConstrainedDefinition, Definition, GlobalSymbolId, - ImportDefinition, ImportFromDefinition, -}; -use crate::FileId; - -// FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`. -/// Resolve the public-facing type for a symbol (the type seen by other scopes: other modules, or -/// nested functions). Because calls to nested functions and imports can occur anywhere in control -/// flow, this type must be conservative and consider all definitions of the symbol that could -/// possibly be seen by another scope. 
Currently we take the most conservative approach, which is -/// the union of all definitions. We may be able to narrow this in future to eliminate definitions -/// which can't possibly (or at least likely) be seen by any other scope, so that e.g. we could -/// infer `Literal["1"]` instead of `Literal[1] | Literal["1"]` for `x` in `x = x; x = str(x);`. -#[tracing::instrument(level = "trace", skip(db))] -pub fn infer_symbol_public_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult { - let index = semantic_index(db, symbol.file_id)?; - let defs = index.symbol_table().definitions(symbol.symbol_id).to_vec(); - let jar: &SemanticJar = db.jar()?; - - if let Some(ty) = jar.type_store.get_cached_symbol_public_type(symbol) { - return Ok(ty); - } - - let ty = infer_type_from_definitions(db, symbol, defs.iter().cloned())?; - - jar.type_store.cache_symbol_public_type(symbol, ty); - - // TODO record dependencies - Ok(ty) -} - -/// Infer type of a symbol as union of the given `Definitions`. -fn infer_type_from_definitions( - db: &dyn SemanticDb, - symbol: GlobalSymbolId, - definitions: T, -) -> QueryResult -where - T: Debug + IntoIterator, -{ - infer_type_from_constrained_definitions( - db, - symbol, - definitions - .into_iter() - .map(|definition| ConstrainedDefinition { - definition, - constraints: vec![], - }), - ) -} - -/// Infer type of a symbol as union of the given `ConstrainedDefinitions`. 
-fn infer_type_from_constrained_definitions( - db: &dyn SemanticDb, - symbol: GlobalSymbolId, - constrained_definitions: T, -) -> QueryResult -where - T: IntoIterator, -{ - let jar: &SemanticJar = db.jar()?; - let mut tys = constrained_definitions - .into_iter() - .map(|def| infer_constrained_definition_type(db, symbol, def.clone())) - .peekable(); - if let Some(first) = tys.next() { - if tys.peek().is_some() { - Ok(jar.type_store.add_union( - symbol.file_id, - &Iterator::chain(std::iter::once(first), tys).collect::>>()?, - )) - } else { - first - } - } else { - Ok(Type::Unknown) - } -} - -/// Infer type for a ConstrainedDefinition (intersection of the definition type and the -/// constraints) -#[tracing::instrument(level = "trace", skip(db))] -pub fn infer_constrained_definition_type( - db: &dyn SemanticDb, - symbol: GlobalSymbolId, - constrained_definition: ConstrainedDefinition, -) -> QueryResult { - let ConstrainedDefinition { - definition, - constraints, - } = constrained_definition; - let index = semantic_index(db, symbol.file_id)?; - let parsed = parse(db.upcast(), symbol.file_id)?; - let mut intersected_types = vec![infer_definition_type(db, symbol, definition)?]; - for constraint in constraints { - if let Some(constraint_type) = infer_constraint_type( - db, - symbol, - index.resolve_expression_id(parsed.syntax(), constraint), - )? 
{ - intersected_types.push(constraint_type); - } - } - let jar: &SemanticJar = db.jar()?; - Ok(jar - .type_store - .add_intersection(symbol.file_id, &intersected_types, &[])) -} - -/// Infer a type for a Definition -#[tracing::instrument(level = "trace", skip(db))] -pub fn infer_definition_type( - db: &dyn SemanticDb, - symbol: GlobalSymbolId, - definition: Definition, -) -> QueryResult { - let jar: &SemanticJar = db.jar()?; - let type_store = &jar.type_store; - let file_id = symbol.file_id; - - match definition { - Definition::Unbound => Ok(Type::Unbound), - Definition::Import(ImportDefinition { - module: module_name, - }) => { - if let Some(module) = resolve_module(db, &module_name)? { - Ok(Type::Module(ModuleTypeId { module, file_id })) - } else { - Ok(Type::Unknown) - } - } - Definition::ImportFrom(ImportFromDefinition { - module, - name, - level, - }) => { - // TODO relative imports - assert!(matches!(level, 0)); - let module_name = - ModuleName::new(module.as_ref().expect("TODO relative imports")).unwrap(); - let Some(module) = resolve_module(db, &module_name)? else { - return Ok(Type::Unknown); - }; - - if let Some(remote_symbol) = resolve_global_symbol(db, module, &name)? 
{ - infer_symbol_public_type(db, remote_symbol) - } else { - Ok(Type::Unknown) - } - } - Definition::ClassDef(node_key) => { - if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) { - Ok(ty) - } else { - let parsed = parse(db.upcast(), file_id)?; - let ast = parsed.syntax(); - let index = semantic_index(db, file_id)?; - let node = node_key.resolve_unwrap(ast.as_any_node_ref()); - - let mut bases = Vec::with_capacity(node.bases().len()); - - for base in node.bases() { - bases.push(infer_expr_type(db, file_id, base)?); - } - let scope_id = index.symbol_table().scope_id_for_node(node_key.erased()); - let ty = type_store.add_class(file_id, &node.name.id, scope_id, bases); - type_store.cache_node_type(file_id, *node_key.erased(), ty); - Ok(ty) - } - } - Definition::FunctionDef(node_key) => { - if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) { - Ok(ty) - } else { - let parsed = parse(db.upcast(), file_id)?; - let ast = parsed.syntax(); - let index = semantic_index(db, file_id)?; - let node = node_key - .resolve(ast.as_any_node_ref()) - .expect("node key should resolve"); - - let decorator_tys = node - .decorator_list - .iter() - .map(|decorator| infer_expr_type(db, file_id, &decorator.expression)) - .collect::>()?; - let scope_id = index.symbol_table().scope_id_for_node(node_key.erased()); - let ty = type_store.add_function( - file_id, - &node.name.id, - symbol.symbol_id, - scope_id, - decorator_tys, - ); - type_store.cache_node_type(file_id, *node_key.erased(), ty); - Ok(ty) - } - } - Definition::Assignment(node_key) => { - let parsed = parse(db.upcast(), file_id)?; - let ast = parsed.syntax(); - let node = node_key.resolve_unwrap(ast.as_any_node_ref()); - // TODO handle unpacking assignment - infer_expr_type(db, file_id, &node.value) - } - Definition::AnnotatedAssignment(node_key) => { - let parsed = parse(db.upcast(), file_id)?; - let ast = parsed.syntax(); - let node = 
node_key.resolve_unwrap(ast.as_any_node_ref()); - // TODO actually look at the annotation - let Some(value) = &node.value else { - return Ok(Type::Unknown); - }; - // TODO handle unpacking assignment - infer_expr_type(db, file_id, value) - } - Definition::NamedExpr(node_key) => { - let parsed = parse(db.upcast(), file_id)?; - let ast = parsed.syntax(); - let node = node_key.resolve_unwrap(ast.as_any_node_ref()); - infer_expr_type(db, file_id, &node.value) - } - } -} - -/// Return the type that the given constraint (an expression from a control-flow test) requires the -/// given symbol to have. For example, returns the Type "~None" as the constraint type if given the -/// symbol ID for x and the expression ID for `x is not None`. Returns (Rust) None if the given -/// expression applies no constraints on the given symbol. -#[tracing::instrument(level = "trace", skip(db))] -fn infer_constraint_type( - db: &dyn SemanticDb, - symbol_id: GlobalSymbolId, - // TODO this should preferably take an &ast::Expr instead of AnyNodeRef - expression: ast::AnyNodeRef, -) -> QueryResult> { - let file_id = symbol_id.file_id; - let index = semantic_index(db, file_id)?; - let jar: &SemanticJar = db.jar()?; - let symbol_name = symbol_id.symbol_id.symbol(&index.symbol_table).name(); - // TODO narrowing attributes - // TODO narrowing dict keys - // TODO isinstance, ==/!=, type(...), literals, bools... - match expression { - ast::AnyNodeRef::ExprCompare(ast::ExprCompare { - left, - ops, - comparators, - .. - }) => { - // TODO chained comparisons - match left.as_ref() { - ast::Expr::Name(ast::ExprName { id, .. }) if id == symbol_name => match ops[0] { - ast::CmpOp::Is | ast::CmpOp::IsNot => { - Ok(match infer_expr_type(db, file_id, &comparators[0])? 
{ - Type::None => Some(Type::None), - _ => None, - } - .map(|ty| { - if matches!(ops[0], ast::CmpOp::IsNot) { - jar.type_store.add_intersection(file_id, &[], &[ty]) - } else { - ty - } - })) - } - _ => Ok(None), - }, - _ => Ok(None), - } - } - _ => Ok(None), - } -} - -/// Infer type of the given expression. -fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> QueryResult { - // TODO cache the resolution of the type on the node - let index = semantic_index(db, file_id)?; - match expr { - ast::Expr::NoneLiteral(_) => Ok(Type::None), - ast::Expr::NumberLiteral(ast::ExprNumberLiteral { value, .. }) => { - match value { - ast::Number::Int(n) => { - // TODO support big int literals - Ok(n.as_i64().map(Type::IntLiteral).unwrap_or(Type::Unknown)) - } - // TODO builtins.float or builtins.complex - _ => Ok(Type::Unknown), - } - } - ast::Expr::Name(name) => { - // TODO look up in the correct scope, don't assume global - if let Some(symbol_id) = index.symbol_table().root_symbol_id_by_name(&name.id) { - infer_type_from_constrained_definitions( - db, - GlobalSymbolId { file_id, symbol_id }, - index.reachable_definitions(symbol_id, expr), - ) - } else { - Ok(Type::Unknown) - } - } - ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { - let value_type = infer_expr_type(db, file_id, value)?; - let attr_name = &attr.id; - value_type - .get_member(db, attr_name) - .map(|ty| ty.unwrap_or(Type::Unknown)) - } - ast::Expr::BinOp(ast::ExprBinOp { - left, op, right, .. - }) => { - let left_ty = infer_expr_type(db, file_id, left)?; - let right_ty = infer_expr_type(db, file_id, right)?; - // TODO add reverse bin op support if right <: left - left_ty.resolve_bin_op(db, *op, right_ty) - } - ast::Expr::Named(ast::ExprNamed { value, .. }) => infer_expr_type(db, file_id, value), - ast::Expr::If(ast::ExprIf { body, orelse, .. 
}) => { - // TODO detect statically known truthy or falsy test - let body_ty = infer_expr_type(db, file_id, body)?; - let else_ty = infer_expr_type(db, file_id, orelse)?; - let jar: &SemanticJar = db.jar()?; - Ok(jar.type_store.add_union(file_id, &[body_ty, else_ty])) - } - _ => todo!("expression type resolution for {:?}", expr), - } -} - -#[cfg(test)] -mod tests { - - use red_knot_module_resolver::ModuleName; - use ruff_python_ast::name::Name; - use std::path::PathBuf; - - use crate::db::tests::TestDb; - use crate::db::{HasJar, SemanticJar}; - use crate::module::{resolve_module, set_module_search_paths, ModuleResolutionInputs}; - use crate::semantic::{infer_symbol_public_type, resolve_global_symbol, Type}; - - // TODO with virtual filesystem we shouldn't have to write files to disk for these - // tests - - struct TestCase { - temp_dir: tempfile::TempDir, - db: TestDb, - - src: PathBuf, - } - - fn create_test() -> std::io::Result { - let temp_dir = tempfile::tempdir()?; - - let src = temp_dir.path().join("src"); - std::fs::create_dir(&src)?; - let src = src.canonicalize()?; - - let search_paths = ModuleResolutionInputs { - extra_paths: vec![], - workspace_root: src.clone(), - site_packages: None, - custom_typeshed: None, - }; - - let mut db = TestDb::default(); - set_module_search_paths(&mut db, search_paths); - - Ok(TestCase { temp_dir, db, src }) - } - - fn write_to_path(case: &TestCase, relative_path: &str, contents: &str) -> anyhow::Result<()> { - let path = case.src.join(relative_path); - std::fs::write(path, contents)?; - Ok(()) - } - - fn get_public_type( - case: &TestCase, - module_name: &str, - variable_name: &str, - ) -> anyhow::Result { - let db = &case.db; - let module = - resolve_module(db, &ModuleName::new(module_name).unwrap())?.expect("Module to exist"); - let symbol = resolve_global_symbol(db, module, variable_name)?.expect("symbol to exist"); - - Ok(infer_symbol_public_type(db, symbol)?) 
- } - - fn assert_public_type( - case: &TestCase, - module_name: &str, - variable_name: &str, - type_name: &str, - ) -> anyhow::Result<()> { - let ty = get_public_type(case, module_name, variable_name)?; - - let jar = HasJar::::jar(&case.db)?; - assert_eq!(format!("{}", ty.display(&jar.type_store)), type_name); - Ok(()) - } - - #[test] - fn follow_import_to_class() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path(&case, "a.py", "from b import C as D; E = D")?; - write_to_path(&case, "b.py", "class C: pass")?; - - assert_public_type(&case, "a", "E", "Literal[C]") - } - - #[test] - fn resolve_base_class_by_name() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "mod.py", - " - class Base: pass - class Sub(Base): pass - ", - )?; - - let ty = get_public_type(&case, "mod", "Sub")?; - - let Type::Class(class_id) = ty else { - panic!("Sub is not a Class") - }; - let jar = HasJar::::jar(&case.db)?; - let base_names: Vec<_> = jar - .type_store - .get_class(class_id) - .bases() - .iter() - .map(|base_ty| format!("{}", base_ty.display(&jar.type_store))) - .collect(); - - assert_eq!(base_names, vec!["Literal[Base]"]); - - Ok(()) - } - - #[test] - fn resolve_method() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "mod.py", - " - class C: - def f(self): pass - ", - )?; - - let ty = get_public_type(&case, "mod", "C")?; - - let Type::Class(class_id) = ty else { - panic!("C is not a Class"); - }; - - let member_ty = class_id - .get_own_class_member(&case.db, &Name::new_static("f")) - .expect("C.f to resolve"); - - let Some(Type::Function(func_id)) = member_ty else { - panic!("C.f is not a Function"); - }; - - let jar = HasJar::::jar(&case.db)?; - let function = jar.type_store.get_function(func_id); - assert_eq!(function.name(), "f"); - - Ok(()) - } - - #[test] - fn resolve_module_member() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path(&case, "a.py", "import b; D = 
b.C")?; - write_to_path(&case, "b.py", "class C: pass")?; - - assert_public_type(&case, "a", "D", "Literal[C]") - } - - #[test] - fn resolve_literal() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path(&case, "a.py", "x = 1")?; - - assert_public_type(&case, "a", "x", "Literal[1]") - } - - #[test] - fn resolve_union() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - if flag: - x = 1 - else: - x = 2 - ", - )?; - - assert_public_type(&case, "a", "x", "Literal[1, 2]") - } - - #[test] - fn resolve_visible_def() -> anyhow::Result<()> { - let case = create_test()?; - write_to_path(&case, "a.py", "y = 1; y = 2; x = y")?; - assert_public_type(&case, "a", "x", "Literal[2]") - } - - #[test] - fn join_paths() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - y = 1 - y = 2 - if flag: - y = 3 - x = y - ", - )?; - - assert_public_type(&case, "a", "x", "Literal[2, 3]") - } - - #[test] - fn maybe_unbound() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - if flag: - y = 1 - x = y - ", - )?; - - assert_public_type(&case, "a", "x", "Literal[1] | Unbound") - } - - #[test] - fn if_elif_else() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - y = 1 - y = 2 - if flag: - y = 3 - elif flag2: - y = 4 - else: - r = y - y = 5 - s = y - x = y - ", - )?; - - assert_public_type(&case, "a", "x", "Literal[3, 4, 5]")?; - assert_public_type(&case, "a", "r", "Literal[2]")?; - assert_public_type(&case, "a", "s", "Literal[5]") - } - - #[test] - fn if_elif() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - y = 1 - y = 2 - if flag: - y = 3 - elif flag2: - y = 4 - x = y - ", - )?; - - assert_public_type(&case, "a", "x", "Literal[2, 3, 4]") - } - - #[test] - fn literal_int_arithmetic() -> anyhow::Result<()> { - let case = create_test()?; - - 
write_to_path( - &case, - "a.py", - " - a = 2 + 1 - b = a - 4 - c = a * b - d = c / 3 - e = 5 % 3 - ", - )?; - - assert_public_type(&case, "a", "a", "Literal[3]")?; - assert_public_type(&case, "a", "b", "Literal[-1]")?; - assert_public_type(&case, "a", "c", "Literal[-3]")?; - assert_public_type(&case, "a", "d", "Literal[-1]")?; - assert_public_type(&case, "a", "e", "Literal[2]") - } - - #[test] - fn walrus() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - x = (y := 1) + 1 - ", - )?; - - assert_public_type(&case, "a", "x", "Literal[2]")?; - assert_public_type(&case, "a", "y", "Literal[1]") - } - - #[test] - fn ifexpr() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - x = 1 if flag else 2 - ", - )?; - - assert_public_type(&case, "a", "x", "Literal[1, 2]") - } - - #[test] - fn ifexpr_walrus() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - y = z = 0 - x = (y := 1) if flag else (z := 2) - a = y - b = z - ", - )?; - - assert_public_type(&case, "a", "x", "Literal[1, 2]")?; - assert_public_type(&case, "a", "a", "Literal[0, 1]")?; - assert_public_type(&case, "a", "b", "Literal[0, 2]") - } - - #[test] - fn ifexpr_walrus_2() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - y = 0 - (y := 1) if flag else (y := 2) - a = y - ", - )?; - - assert_public_type(&case, "a", "a", "Literal[1, 2]") - } - - #[test] - fn ifexpr_nested() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - x = 1 if flag else 2 if flag2 else 3 - ", - )?; - - assert_public_type(&case, "a", "x", "Literal[1, 2, 3]") - } - - #[test] - fn none() -> anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - x = 1 if flag else None - ", - )?; - - assert_public_type(&case, "a", "x", "Literal[1] | None") - } - - #[test] - fn narrow_none() -> 
anyhow::Result<()> { - let case = create_test()?; - - write_to_path( - &case, - "a.py", - " - x = 1 if flag else None - y = 0 - if x is not None: - y = x - z = y - ", - )?; - - // TODO normalization of unions and intersections: this type is technically correct but - // begging for normalization - assert_public_type(&case, "a", "z", "Literal[0] | Literal[1] | None & ~None") - } -} diff --git a/crates/red_knot/src/source.rs b/crates/red_knot/src/source.rs deleted file mode 100644 index f82e5de6e0..0000000000 --- a/crates/red_knot/src/source.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::ops::{Deref, DerefMut}; -use std::sync::Arc; - -use ruff_notebook::Notebook; -use ruff_python_ast::PySourceType; - -use crate::cache::KeyValueCache; -use crate::db::{QueryResult, SourceDb}; -use crate::files::FileId; - -#[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn source_text(db: &dyn SourceDb, file_id: FileId) -> QueryResult { - let jar = db.jar()?; - let sources = &jar.sources; - - sources.get(&file_id, |file_id| { - let path = db.file_path(*file_id); - - let source_text = std::fs::read_to_string(&path).unwrap_or_else(|err| { - tracing::error!("Failed to read file '{path:?}: {err}'. Falling back to empty text"); - String::new() - }); - - let python_ty = PySourceType::from(&path); - - let kind = match python_ty { - PySourceType::Python => { - SourceKind::Python(Arc::from(source_text)) - } - PySourceType::Stub => SourceKind::Stub(Arc::from(source_text)), - PySourceType::Ipynb => { - let notebook = Notebook::from_source_code(&source_text).unwrap_or_else(|err| { - // TODO should this be changed to never fail? - // or should we instead add a diagnostic somewhere? But what would we return in this case? - tracing::error!( - "Failed to parse notebook '{path:?}: {err}'. 
Falling back to an empty notebook" - ); - Notebook::from_source_code("").unwrap() - }); - - SourceKind::IpyNotebook(Arc::new(notebook)) - } - }; - - Ok(Source { kind }) - }) -} - -#[derive(Debug, Clone, PartialEq)] -pub enum SourceKind { - Python(Arc), - Stub(Arc), - IpyNotebook(Arc), -} - -impl<'a> From<&'a SourceKind> for PySourceType { - fn from(value: &'a SourceKind) -> Self { - match value { - SourceKind::Python(_) => PySourceType::Python, - SourceKind::Stub(_) => PySourceType::Stub, - SourceKind::IpyNotebook(_) => PySourceType::Ipynb, - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct Source { - kind: SourceKind, -} - -impl Source { - pub fn python>>(source: T) -> Self { - Self { - kind: SourceKind::Python(source.into()), - } - } - pub fn kind(&self) -> &SourceKind { - &self.kind - } - - pub fn text(&self) -> &str { - match &self.kind { - SourceKind::Python(text) => text, - SourceKind::Stub(text) => text, - SourceKind::IpyNotebook(notebook) => notebook.source_code(), - } - } -} - -#[derive(Debug, Default)] -pub struct SourceStorage(pub(crate) KeyValueCache); - -impl Deref for SourceStorage { - type Target = KeyValueCache; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for SourceStorage { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} diff --git a/crates/red_knot/src/watch.rs b/crates/red_knot/src/watch.rs index 5c8ee3fb27..bfc32f7f7f 100644 --- a/crates/red_knot/src/watch.rs +++ b/crates/red_knot/src/watch.rs @@ -1,10 +1,10 @@ use std::path::Path; +use crate::program::{FileChangeKind, FileWatcherChange}; use anyhow::Context; use notify::event::{CreateKind, RemoveKind}; use notify::{recommended_watcher, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; - -use crate::program::{FileChangeKind, FileWatcherChange}; +use ruff_db::file_system::FileSystemPath; pub struct FileWatcher { watcher: RecommendedWatcher, @@ -50,7 +50,12 @@ impl FileWatcher { for path in event.paths { if path.is_file() { 
- changes.push(FileWatcherChange::new(path, change_kind)); + if let Some(fs_path) = FileSystemPath::from_std_path(&path) { + changes.push(FileWatcherChange::new( + fs_path.to_path_buf(), + change_kind, + )); + } } } diff --git a/crates/red_knot_module_resolver/src/resolver.rs b/crates/red_knot_module_resolver/src/resolver.rs index 33f7281cf1..d01f4148c7 100644 --- a/crates/red_knot_module_resolver/src/resolver.rs +++ b/crates/red_knot_module_resolver/src/resolver.rs @@ -55,8 +55,8 @@ pub(crate) fn resolve_module_query<'db>( /// Resolves the module for the given path. /// /// Returns `None` if the path is not a module locatable via `sys.path`. -#[tracing::instrument(level = "debug", skip(db))] -pub fn path_to_module(db: &dyn Db, path: &VfsPath) -> Option { +#[allow(unused)] +pub(crate) fn path_to_module(db: &dyn Db, path: &VfsPath) -> Option { // It's not entirely clear on first sight why this method calls `file_to_module` instead of // it being the other way round, considering that the first thing that `file_to_module` does // is to retrieve the file's path. @@ -73,7 +73,6 @@ pub fn path_to_module(db: &dyn Db, path: &VfsPath) -> Option { /// /// Returns `None` if the file is not a module locatable via `sys.path`. 
#[salsa::tracked] -#[allow(unused)] pub(crate) fn file_to_module(db: &dyn Db, file: VfsFile) -> Option { let _span = tracing::trace_span!("file_to_module", ?file).entered(); @@ -367,7 +366,6 @@ impl PackageKind { #[cfg(test)] mod tests { - use ruff_db::file_system::{FileSystemPath, FileSystemPathBuf}; use ruff_db::vfs::{system_path_to_file, VfsFile, VfsPath}; diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 5e055bd9f7..cb55587646 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -59,7 +59,7 @@ pub(crate) fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId<'_> { /// Returns the symbol with the given name in `file`'s public scope or `None` if /// no symbol with the given name exists. -pub fn public_symbol<'db>( +pub(crate) fn public_symbol<'db>( db: &'db dyn Db, file: VfsFile, name: &str, @@ -72,7 +72,7 @@ pub fn public_symbol<'db>( /// The symbol tables for an entire file. #[derive(Debug)] -pub struct SemanticIndex<'db> { +pub(crate) struct SemanticIndex<'db> { /// List of all symbol tables in this file, indexed by scope. symbol_tables: IndexVec>>, diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index e0116a6a7b..c6016f5933 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -47,8 +47,9 @@ pub(crate) fn public_symbol_ty<'db>(db: &'db dyn Db, symbol: PublicSymbolId<'db> inference.symbol_ty(symbol.scoped_symbol_id(db)) } -/// Shorthand for `public_symbol_ty` that takes a symbol name instead of a [`PublicSymbolId`]. -pub fn public_symbol_ty_by_name<'db>( +/// Shorthand for [`public_symbol_ty()`] that takes a symbol name instead of a [`PublicSymbolId`]. 
+#[allow(unused)] +pub(crate) fn public_symbol_ty_by_name<'db>( db: &'db dyn Db, file: VfsFile, name: &str, diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index f66c1b7114..fb7a39c4bd 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1,4 +1,5 @@ use rustc_hash::FxHashMap; +use std::borrow::Cow; use std::sync::Arc; use red_knot_module_resolver::resolve_module; @@ -487,35 +488,38 @@ impl<'db> TypeInferenceBuilder<'db> { match ctx { ExprContext::Load => { - if let Some(symbol_id) = self - .index - .symbol_table(self.file_scope_id) - .symbol_id_by_name(id) - { - self.local_definition_ty(symbol_id) - } else { - let ancestors = self.index.ancestor_scopes(self.file_scope_id).skip(1); + let ancestors = self.index.ancestor_scopes(self.file_scope_id); - for (ancestor_id, _) in ancestors { - // TODO: Skip over class scopes unless the they are a immediately-nested type param scope. - // TODO: Support built-ins + for (ancestor_id, _) in ancestors { + // TODO: Skip over class scopes unless the they are a immediately-nested type param scope. 
+ // TODO: Support built-ins + let (symbol_table, ancestor_scope) = if ancestor_id == self.file_scope_id { + (Cow::Borrowed(&self.symbol_table), None) + } else { let ancestor_scope = ancestor_id.to_scope_id(self.db, self.file_id); - let symbol_table = symbol_table(self.db, ancestor_scope); + ( + Cow::Owned(symbol_table(self.db, ancestor_scope)), + Some(ancestor_scope), + ) + }; - if let Some(symbol_id) = symbol_table.symbol_id_by_name(id) { - let symbol = symbol_table.symbol(symbol_id); + if let Some(symbol_id) = symbol_table.symbol_id_by_name(id) { + let symbol = symbol_table.symbol(symbol_id); - if !symbol.is_defined() { - continue; - } - - let types = infer_types(self.db, ancestor_scope); - return types.symbol_ty(symbol_id); + if !symbol.is_defined() { + continue; } + + return if let Some(ancestor_scope) = ancestor_scope { + let types = infer_types(self.db, ancestor_scope); + types.symbol_ty(symbol_id) + } else { + self.local_definition_ty(symbol_id) + }; } - Type::Unknown } + Type::Unknown } ExprContext::Del => Type::None, ExprContext::Invalid => Type::Unknown, diff --git a/crates/ruff_db/src/file_system/os.rs b/crates/ruff_db/src/file_system/os.rs index cdf7ceb25a..057334c5b7 100644 --- a/crates/ruff_db/src/file_system/os.rs +++ b/crates/ruff_db/src/file_system/os.rs @@ -2,6 +2,7 @@ use filetime::FileTime; use crate::file_system::{FileSystem, FileSystemPath, FileType, Metadata, Result}; +#[derive(Default)] pub struct OsFileSystem; impl OsFileSystem { diff --git a/crates/ruff_db/src/source.rs b/crates/ruff_db/src/source.rs index ab044721cc..321311a1d0 100644 --- a/crates/ruff_db/src/source.rs +++ b/crates/ruff_db/src/source.rs @@ -1,9 +1,9 @@ +use countme::Count; +use ruff_source_file::LineIndex; use salsa::DebugWithDb; use std::ops::Deref; use std::sync::Arc; -use ruff_source_file::LineIndex; - use crate::vfs::VfsFile; use crate::Db; @@ -16,6 +16,7 @@ pub fn source_text(db: &dyn Db, file: VfsFile) -> SourceText { SourceText { inner: Arc::from(content), + 
count: Count::new(), } } @@ -35,6 +36,7 @@ pub fn line_index(db: &dyn Db, file: VfsFile) -> LineIndex { #[derive(Clone, Eq, PartialEq)] pub struct SourceText { inner: Arc, + count: Count, } impl SourceText { diff --git a/crates/ruff_db/src/vfs.rs b/crates/ruff_db/src/vfs.rs index f9ca06eb6f..4725f3aa50 100644 --- a/crates/ruff_db/src/vfs.rs +++ b/crates/ruff_db/src/vfs.rs @@ -104,6 +104,7 @@ impl Vfs { /// /// The operation always succeeds even if the path doesn't exist on disk, isn't accessible or if the path points to a directory. /// In these cases, a file with status [`FileStatus::Deleted`] is returned. + #[tracing::instrument(level = "debug", skip(self, db))] fn file_system(&self, db: &dyn Db, path: &FileSystemPath) -> VfsFile { *self .inner @@ -135,6 +136,7 @@ impl Vfs { /// Looks up a vendored file by its path. Returns `Some` if a vendored file for the given path /// exists and `None` otherwise. + #[tracing::instrument(level = "debug", skip(self, db))] fn vendored(&self, db: &dyn Db, path: &VendoredPath) -> Option { let file = match self .inner