diff --git a/crates/red_knot_python_semantic/src/ast_node_ref.rs b/crates/red_knot_python_semantic/src/ast_node_ref.rs index 667b8f6363..ad4b03e9af 100644 --- a/crates/red_knot_python_semantic/src/ast_node_ref.rs +++ b/crates/red_knot_python_semantic/src/ast_node_ref.rs @@ -31,17 +31,13 @@ use ruff_db::parsed::ParsedModule; /// This means that changes to expressions in other scopes don't invalidate the expression's id, giving /// us some form of scope-stable identity for expressions. Only queries accessing the node field /// run on every AST change. All other queries only run when the expression's identity changes. -/// -/// The one exception to this is if it is known that all queries tacking the tracked struct -/// as argument or returning it as part of their result are known to access the node field. -/// Marking the field tracked is then unnecessary. #[derive(Clone)] pub struct AstNodeRef { /// Owned reference to the node's [`ParsedModule`]. /// /// The node's reference is guaranteed to remain valid as long as it's enclosing /// [`ParsedModule`] is alive. - _parsed: ParsedModule, + parsed: ParsedModule, /// Pointer to the referenced node. node: std::ptr::NonNull, @@ -59,7 +55,7 @@ impl AstNodeRef { /// the invariant `node belongs to parsed` is upheld. pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self { Self { - _parsed: parsed, + parsed, node: std::ptr::NonNull::from(node), } } @@ -89,17 +85,44 @@ where } } -impl PartialEq for AstNodeRef { +impl PartialEq for AstNodeRef +where + T: PartialEq, +{ fn eq(&self, other: &Self) -> bool { - self.node.eq(&other.node) + if self.parsed == other.parsed { + // Comparing the pointer addresses is sufficient to determine equality + // if the parsed are the same. + self.node.eq(&other.node) + } else { + // Otherwise perform a deep comparison. + self.node().eq(other.node()) + } } } -impl Eq for AstNodeRef {} +impl Eq for AstNodeRef where T: Eq {} -impl Hash for AstNodeRef { +impl Hash for AstNodeRef +where + T: Hash, +{ fn hash(&self, state: &mut H) { - self.node.hash(state); + self.node().hash(state); + } +} + +#[allow(unsafe_code)] +unsafe impl salsa::Update for AstNodeRef { + unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool { + let old_ref = &mut (*old_pointer); + + if old_ref.parsed == new_value.parsed && old_ref.node.eq(&new_value.node) { + false + } else { + *old_ref = new_value; + true + } } } @@ -133,7 +156,7 @@ mod tests { let stmt_cloned = &cloned.syntax().body[0]; let cloned_node = unsafe { AstNodeRef::new(cloned.clone(), stmt_cloned) }; - assert_ne!(node1, cloned_node); + assert_eq!(node1, cloned_node); let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python); let other = ParsedModule::new(other_raw); diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 830ad0ffd4..adc1367560 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -439,7 +439,7 @@ impl DefinitionCategory { /// [`DefinitionKind`] fields in salsa tracked structs should be tracked (attributed with `#[tracked]`) /// because the kind is a thin wrapper around [`AstNodeRef`]. See the [`AstNodeRef`] documentation /// for an in-depth explanation of why this is necessary. -#[derive(Clone, Debug, Hash)] +#[derive(Clone, Debug)] pub enum DefinitionKind<'db> { Import(AstNodeRef), ImportFrom(ImportFromDefinitionKind), @@ -559,7 +559,7 @@ impl<'db> From>> for TargetKind<'db> { } } -#[derive(Clone, Debug, Hash)] +#[derive(Clone, Debug)] #[allow(dead_code)] pub struct MatchPatternDefinitionKind { pattern: AstNodeRef, @@ -577,7 +577,7 @@ impl MatchPatternDefinitionKind { } } -#[derive(Clone, Debug, Hash)] +#[derive(Clone, Debug)] pub struct ComprehensionDefinitionKind { iterable: AstNodeRef, target: AstNodeRef, @@ -603,7 +603,7 @@ impl ComprehensionDefinitionKind { } } -#[derive(Clone, Debug, Hash)] +#[derive(Clone, Debug)] pub struct ImportFromDefinitionKind { node: AstNodeRef, alias_index: usize, @@ -619,7 +619,7 @@ impl ImportFromDefinitionKind { } } -#[derive(Clone, Debug, Hash)] +#[derive(Clone, Debug)] pub struct AssignmentDefinitionKind<'db> { target: TargetKind<'db>, value: AstNodeRef, @@ -645,7 +645,7 @@ impl<'db> AssignmentDefinitionKind<'db> { } } -#[derive(Clone, Debug, Hash)] +#[derive(Clone, Debug)] pub struct WithItemDefinitionKind { node: AstNodeRef, target: AstNodeRef, @@ -666,7 +666,7 @@ impl WithItemDefinitionKind { } } -#[derive(Clone, Debug, Hash)] +#[derive(Clone, Debug)] pub struct ForStmtDefinitionKind<'db> { target: TargetKind<'db>, iterable: AstNodeRef, @@ -697,7 +697,7 @@ impl<'db> ForStmtDefinitionKind<'db> { } } -#[derive(Clone, Debug, Hash)] +#[derive(Clone, Debug)] pub struct ExceptHandlerDefinitionKind { handler: AstNodeRef, is_star: bool, diff --git a/crates/red_knot_python_semantic/src/unpack.rs b/crates/red_knot_python_semantic/src/unpack.rs index 05a396d4dc..ad6eac413d 100644 --- a/crates/red_knot_python_semantic/src/unpack.rs +++ b/crates/red_knot_python_semantic/src/unpack.rs @@ -36,6 +36,7 @@ pub(crate) struct Unpack<'db> { /// expression is `(a, b)`. #[no_eq] #[return_ref] + #[tracked] pub(crate) target: AstNodeRef, /// The ingredient representing the value expression of the unpacking. For example, in