Compare commits
1 Commits
dhruv/para
...
jack/loop-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a38a18e2d3 |
@@ -886,6 +886,8 @@ class GenericClass[T]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _(x: list[str]):
|
||||
# TODO: This fails because we are not propagating GenericClass's generic context into the
|
||||
# Callable that we create for it.
|
||||
# revealed: (x: list[T@GenericClass], y: list[T@GenericClass]) -> GenericClass[T@GenericClass]
|
||||
reveal_type(into_callable(GenericClass))
|
||||
# revealed: ty_extensions.GenericContext[T@GenericClass]
|
||||
@@ -893,10 +895,15 @@ def _(x: list[str]):
|
||||
|
||||
# revealed: (x: list[T@GenericClass], y: list[T@GenericClass]) -> GenericClass[T@GenericClass]
|
||||
reveal_type(accepts_callable(GenericClass))
|
||||
# revealed: ty_extensions.GenericContext[T@GenericClass]
|
||||
# TODO: revealed: ty_extensions.GenericContext[T@GenericClass]
|
||||
# revealed: None
|
||||
reveal_type(generic_context(accepts_callable(GenericClass)))
|
||||
|
||||
# revealed: GenericClass[str]
|
||||
# TODO: revealed: GenericClass[str]
|
||||
# TODO: no errors
|
||||
# revealed: GenericClass[T@GenericClass]
|
||||
# error: [invalid-argument-type]
|
||||
# error: [invalid-argument-type]
|
||||
reveal_type(accepts_callable(GenericClass)(x, x))
|
||||
```
|
||||
|
||||
|
||||
@@ -800,78 +800,3 @@ def f(x: int, y: str):
|
||||
|
||||
reveal_type(infer_paramspec(f)) # revealed: (x: int, y: str) -> None
|
||||
```
|
||||
|
||||
## Generic context preservation through `ParamSpec` decorators
|
||||
|
||||
When a generic function is decorated with a `ParamSpec`-based decorator, the generic context of the
|
||||
decorated function should be preserved. This allows type inference to work correctly when calling the
|
||||
decorated function.
|
||||
|
||||
Regression test for <https://github.com/astral-sh/ty/issues/2336>
|
||||
|
||||
### Basic
|
||||
|
||||
```py
|
||||
from typing import Callable
|
||||
from ty_extensions import generic_context
|
||||
|
||||
def decorator[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
||||
return func
|
||||
|
||||
@decorator
|
||||
def identity[T](value: T) -> T:
|
||||
return value
|
||||
|
||||
@decorator
|
||||
def pair[T, U](first: T, second: U) -> tuple[T, U]:
|
||||
return (first, second)
|
||||
|
||||
# revealed: ty_extensions.GenericContext[T@identity]
|
||||
reveal_type(generic_context(identity))
|
||||
# revealed: ty_extensions.GenericContext[T@pair, U@pair]
|
||||
reveal_type(generic_context(pair))
|
||||
|
||||
reveal_type(identity(1)) # revealed: Literal[1]
|
||||
reveal_type(identity("hello")) # revealed: Literal["hello"]
|
||||
|
||||
reveal_type(pair(1, "a")) # revealed: tuple[Literal[1], Literal["a"]]
|
||||
reveal_type(pair("x", 2.5)) # revealed: tuple[Literal["x"], float]
|
||||
```
|
||||
|
||||
### Chained decorators with generic functions
|
||||
|
||||
```py
|
||||
from typing import Callable
|
||||
|
||||
def decorator1[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
||||
return func
|
||||
|
||||
def decorator2[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
||||
return func
|
||||
|
||||
@decorator1
|
||||
@decorator2
|
||||
def chained_generic[T](value: T) -> T:
|
||||
return value
|
||||
|
||||
reveal_type(chained_generic(42)) # revealed: Literal[42]
|
||||
reveal_type(chained_generic("test")) # revealed: Literal["test"]
|
||||
```
|
||||
|
||||
### Generic method decoration
|
||||
|
||||
```py
|
||||
from typing import Callable
|
||||
|
||||
def method_decorator[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
||||
return func
|
||||
|
||||
class Container:
|
||||
@method_decorator
|
||||
def generic_method[T](self, value: T) -> T:
|
||||
return value
|
||||
|
||||
c = Container()
|
||||
reveal_type(c.generic_method(100)) # revealed: Literal[100]
|
||||
reveal_type(c.generic_method([1, 2, 3])) # revealed: list[Unknown | int]
|
||||
```
|
||||
|
||||
@@ -127,3 +127,64 @@ class NotBoolable:
|
||||
while NotBoolable():
|
||||
...
|
||||
```
|
||||
|
||||
## Backwards control flow
|
||||
|
||||
```py
|
||||
i = 0
|
||||
reveal_type(i) # revealed: Literal[0]
|
||||
while i < 1_000_000:
|
||||
reveal_type(i) # revealed: int
|
||||
i += 1
|
||||
reveal_type(i) # revealed: int
|
||||
reveal_type(i) # revealed: int
|
||||
|
||||
# TODO: None of these should need to be raised to `int`. Loop control flow analysis should take the
|
||||
# loop condition into account.
|
||||
i = 0
|
||||
reveal_type(i) # revealed: Literal[0]
|
||||
while i < 2:
|
||||
# TODO: Should be Literal[0, 1].
|
||||
reveal_type(i) # revealed: int
|
||||
i += 1
|
||||
# TODO: Should be Literal[1, 2].
|
||||
reveal_type(i) # revealed: int
|
||||
# TODO: Should be Literal[2].
|
||||
reveal_type(i) # revealed: int
|
||||
```
|
||||
|
||||
```py
|
||||
def random() -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
reveal_type(i) # revealed: Literal[0, 1, 2]
|
||||
if random():
|
||||
i = 1
|
||||
else:
|
||||
i = "break"
|
||||
break
|
||||
# To get here we must take the `i = 1` branch above.
|
||||
reveal_type(i) # revealed: Literal[1]
|
||||
if random():
|
||||
i = 2
|
||||
reveal_type(i) # revealed: Literal[1, 2]
|
||||
reveal_type(i) # revealed: Literal["break"]
|
||||
|
||||
i = 0
|
||||
while random():
|
||||
if random():
|
||||
reveal_type(i) # revealed: Literal[0, 1, 2, 3]
|
||||
i = 1
|
||||
reveal_type(i) # revealed: Literal[1]
|
||||
while random():
|
||||
if random():
|
||||
reveal_type(i) # revealed: Literal[1, 0, 2, 3]
|
||||
i = 2
|
||||
reveal_type(i) # revealed: Literal[2]
|
||||
if random():
|
||||
reveal_type(i) # revealed: Literal[1, 2, 0, 3]
|
||||
i = 3
|
||||
reveal_type(i) # revealed: Literal[3]
|
||||
```
|
||||
|
||||
@@ -229,6 +229,9 @@ pub(crate) struct SemanticIndex<'db> {
|
||||
/// Map from a standalone expression to its [`Expression`] ingredient.
|
||||
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
|
||||
|
||||
/// Map from loop-header definitions to their constituent definitions.
|
||||
loop_header_definitions: FxHashMap<Definition<'db>, Vec<Definition<'db>>>,
|
||||
|
||||
/// Map from nodes that create a scope to the scope they create.
|
||||
scopes_by_node: FxHashMap<NodeWithScopeKey, FileScopeId>,
|
||||
|
||||
@@ -319,6 +322,15 @@ impl<'db> SemanticIndex<'db> {
|
||||
self.scope_ids_by_scope.iter().copied()
|
||||
}
|
||||
|
||||
pub(crate) fn loop_header_definitions(
|
||||
&self,
|
||||
definition: Definition<'db>,
|
||||
) -> Option<&[Definition<'db>]> {
|
||||
self.loop_header_definitions
|
||||
.get(&definition)
|
||||
.map(|defs| defs.as_slice())
|
||||
}
|
||||
|
||||
pub(crate) fn symbol_is_global_in_scope(
|
||||
&self,
|
||||
symbol: ScopedSymbolId,
|
||||
|
||||
@@ -141,4 +141,10 @@ pub(crate) mod node_key {
|
||||
Self(NodeKey::from_node(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ExpressionNodeKey> for NodeKey {
|
||||
fn from(value: ExpressionNodeKey) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,14 +19,15 @@ use ty_module_resolver::{ModuleName, resolve_module};
|
||||
|
||||
use crate::ast_node_ref::AstNodeRef;
|
||||
use crate::node_key::NodeKey;
|
||||
use crate::semantic_index::ast_ids::AstIdsBuilder;
|
||||
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
|
||||
use crate::semantic_index::ast_ids::{AstIdsBuilder, ScopedUseId};
|
||||
use crate::semantic_index::definition::{
|
||||
AnnotatedAssignmentDefinitionNodeRef, AssignmentDefinitionNodeRef,
|
||||
ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionNodeKey,
|
||||
DefinitionNodeRef, Definitions, ExceptHandlerDefinitionNodeRef, ForStmtDefinitionNodeRef,
|
||||
ImportDefinitionNodeRef, ImportFromDefinitionNodeRef, ImportFromSubmoduleDefinitionNodeRef,
|
||||
MatchPatternDefinitionNodeRef, StarImportDefinitionNodeRef, WithItemDefinitionNodeRef,
|
||||
ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionKind,
|
||||
DefinitionNodeKey, DefinitionNodeRef, Definitions, ExceptHandlerDefinitionNodeRef,
|
||||
ForStmtDefinitionNodeRef, ImportDefinitionNodeRef, ImportFromDefinitionNodeRef,
|
||||
ImportFromSubmoduleDefinitionNodeRef, LoopHeaderDefinitionKind, MatchPatternDefinitionNodeRef,
|
||||
StarImportDefinitionNodeRef, WithItemDefinitionNodeRef,
|
||||
};
|
||||
use crate::semantic_index::expression::{Expression, ExpressionKind};
|
||||
use crate::semantic_index::place::{PlaceExpr, PlaceTableBuilder, ScopedPlaceId};
|
||||
@@ -43,6 +44,7 @@ use crate::semantic_index::scope::{
|
||||
};
|
||||
use crate::semantic_index::scope::{Scope, ScopeId, ScopeKind, ScopeLaziness};
|
||||
use crate::semantic_index::symbol::{ScopedSymbolId, Symbol};
|
||||
use crate::semantic_index::use_def::Bindings;
|
||||
use crate::semantic_index::use_def::{
|
||||
EnclosingSnapshotKey, FlowSnapshot, ScopedEnclosingSnapshotId, UseDefMapBuilder,
|
||||
};
|
||||
@@ -53,22 +55,35 @@ use crate::{Db, Program};
|
||||
|
||||
mod except_handlers;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct LoopUse {
|
||||
place: ScopedPlaceId,
|
||||
use_id: ScopedUseId,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct Loop {
|
||||
/// Flow states at each `break` in the current loop.
|
||||
break_states: Vec<FlowSnapshot>,
|
||||
uses: Vec<LoopUse>,
|
||||
defined_places: FxHashSet<ScopedPlaceId>,
|
||||
}
|
||||
|
||||
impl Loop {
|
||||
fn push_break(&mut self, state: FlowSnapshot) {
|
||||
self.break_states.push(state);
|
||||
}
|
||||
|
||||
fn record_definition(&mut self, place: ScopedPlaceId) {
|
||||
self.defined_places.insert(place);
|
||||
}
|
||||
}
|
||||
|
||||
struct ScopeInfo {
|
||||
file_scope_id: FileScopeId,
|
||||
/// Current loop state; None if we are not currently visiting a loop
|
||||
current_loop: Option<Loop>,
|
||||
condition_place_uses: Option<FxHashSet<ScopedPlaceId>>,
|
||||
}
|
||||
|
||||
pub(super) struct SemanticIndexBuilder<'db, 'ast> {
|
||||
@@ -109,6 +124,7 @@ pub(super) struct SemanticIndexBuilder<'db, 'ast> {
|
||||
scopes_by_expression: ExpressionsScopeMapBuilder,
|
||||
definitions_by_node: FxHashMap<DefinitionNodeKey, Definitions<'db>>,
|
||||
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
|
||||
loop_header_definitions: FxHashMap<Definition<'db>, Vec<Definition<'db>>>,
|
||||
imported_modules: FxHashSet<ModuleName>,
|
||||
seen_submodule_imports: FxHashSet<String>,
|
||||
/// Hashset of all [`FileScopeId`]s that correspond to [generator functions].
|
||||
@@ -147,6 +163,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
|
||||
scopes_by_node: FxHashMap::default(),
|
||||
definitions_by_node: FxHashMap::default(),
|
||||
expressions_by_node: FxHashMap::default(),
|
||||
loop_header_definitions: FxHashMap::default(),
|
||||
|
||||
seen_submodule_imports: FxHashSet::default(),
|
||||
imported_modules: FxHashSet::default(),
|
||||
@@ -256,8 +273,17 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
|
||||
|
||||
/// Pop a loop, replacing with the previous saved outer loop, if any.
|
||||
fn pop_loop(&mut self, outer_loop: Option<Loop>) -> Loop {
|
||||
std::mem::replace(&mut self.current_scope_info_mut().current_loop, outer_loop)
|
||||
.expect("pop_loop() should not be called without a prior push_loop()")
|
||||
let inner_loop = std::mem::take(&mut self.current_scope_info_mut().current_loop)
|
||||
.expect("pop_loop() should not be called without a prior push_loop()");
|
||||
let merged_outer = outer_loop.map(|mut outer| {
|
||||
outer.uses.extend(inner_loop.uses.iter().cloned());
|
||||
outer
|
||||
.defined_places
|
||||
.extend(inner_loop.defined_places.iter().copied());
|
||||
outer
|
||||
});
|
||||
self.current_scope_info_mut().current_loop = merged_outer;
|
||||
inner_loop
|
||||
}
|
||||
|
||||
fn current_loop_mut(&mut self) -> Option<&mut Loop> {
|
||||
@@ -308,6 +334,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
|
||||
self.scope_stack.push(ScopeInfo {
|
||||
file_scope_id,
|
||||
current_loop: None,
|
||||
condition_place_uses: None,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -656,6 +683,9 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
|
||||
place: ScopedPlaceId,
|
||||
definition_node: impl Into<DefinitionNodeRef<'ast, 'db>> + std::fmt::Debug + Copy,
|
||||
) -> Definition<'db> {
|
||||
if let Some(current_loop) = self.current_loop_mut() {
|
||||
current_loop.record_definition(place);
|
||||
}
|
||||
let (definition, num_definitions) = self.push_additional_definition(place, definition_node);
|
||||
debug_assert_eq!(
|
||||
num_definitions, 1,
|
||||
@@ -755,6 +785,40 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
|
||||
(definition, num_definitions)
|
||||
}
|
||||
|
||||
fn add_loop_header_definition(
|
||||
&mut self,
|
||||
place: ScopedPlaceId,
|
||||
loop_node: &'ast ast::StmtWhile,
|
||||
definitions: Vec<Definition<'db>>,
|
||||
seed_definitions: Vec<Definition<'db>>,
|
||||
bindings: Bindings,
|
||||
seed_bindings: Bindings,
|
||||
) -> Definition<'db> {
|
||||
let kind = DefinitionKind::LoopHeader(LoopHeaderDefinitionKind::new(
|
||||
AstNodeRef::new(self.module, loop_node),
|
||||
definitions,
|
||||
seed_definitions,
|
||||
bindings,
|
||||
seed_bindings,
|
||||
));
|
||||
let is_reexported = kind.is_reexported();
|
||||
let definition = Definition::new(
|
||||
self.db,
|
||||
self.file,
|
||||
self.current_scope(),
|
||||
place,
|
||||
kind,
|
||||
is_reexported,
|
||||
);
|
||||
|
||||
self.add_entry_for_definition_key(DefinitionNodeKey::from_node_key(NodeKey::from_node(
|
||||
loop_node,
|
||||
)))
|
||||
.push(definition);
|
||||
|
||||
definition
|
||||
}
|
||||
|
||||
fn record_expression_narrowing_constraint(
|
||||
&mut self,
|
||||
predicate_node: &ast::Expr,
|
||||
@@ -1318,6 +1382,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
|
||||
scopes: self.scopes,
|
||||
definitions_by_node: self.definitions_by_node,
|
||||
expressions_by_node: self.expressions_by_node,
|
||||
loop_header_definitions: self.loop_header_definitions,
|
||||
scope_ids_by_scope: self.scope_ids_by_scope,
|
||||
ast_ids,
|
||||
scopes_by_expression: self.scopes_by_expression.build(),
|
||||
@@ -1341,6 +1406,96 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
|
||||
self.source_text
|
||||
.get_or_init(|| source_text(self.db, self.file))
|
||||
}
|
||||
|
||||
fn record_use(&mut self, place_id: ScopedPlaceId, expr_node_key: ExpressionNodeKey) {
|
||||
if let ScopedPlaceId::Symbol(symbol_id) = place_id {
|
||||
self.mark_symbol_used(symbol_id);
|
||||
}
|
||||
let use_id = self.current_ast_ids().record_use(expr_node_key);
|
||||
self.current_use_def_map_mut()
|
||||
.record_use(place_id, use_id, expr_node_key.into());
|
||||
if let Some(condition_place_uses) = &mut self.current_scope_info_mut().condition_place_uses
|
||||
{
|
||||
condition_place_uses.insert(place_id);
|
||||
}
|
||||
if let Some(current_loop) = self.current_loop_mut() {
|
||||
current_loop.uses.push(LoopUse {
|
||||
place: place_id,
|
||||
use_id,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn create_loop_header_definitions(
|
||||
&mut self,
|
||||
loop_node: &'ast ast::StmtWhile,
|
||||
loop_state: &Loop,
|
||||
pre_loop: &FlowSnapshot,
|
||||
post_body: &FlowSnapshot,
|
||||
) {
|
||||
let mut used_places = FxHashSet::default();
|
||||
for loop_use in &loop_state.uses {
|
||||
used_places.insert(loop_use.place);
|
||||
}
|
||||
|
||||
let scope_id = self.current_scope();
|
||||
for place in loop_state.defined_places.iter() {
|
||||
if !used_places.contains(place) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let pre_loop_binding_ids = pre_loop.binding_ids_for_place_excluding_unbound(*place);
|
||||
let seed_bindings = pre_loop.bindings_for_place(*place);
|
||||
let loop_bindings = self
|
||||
.current_use_def_map_mut()
|
||||
.merge_bindings(seed_bindings.clone(), post_body.bindings_for_place(*place));
|
||||
let mut seed_definitions = self
|
||||
.current_use_def_map()
|
||||
.definitions_for_place_in_snapshot(pre_loop, *place);
|
||||
let mut definitions = seed_definitions.clone();
|
||||
definitions.extend(
|
||||
self.current_use_def_map()
|
||||
.definitions_for_place_in_snapshot(post_body, *place),
|
||||
);
|
||||
definitions.sort();
|
||||
definitions.dedup();
|
||||
seed_definitions.sort();
|
||||
seed_definitions.dedup();
|
||||
|
||||
if definitions.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let header_definition = self.add_loop_header_definition(
|
||||
*place,
|
||||
loop_node,
|
||||
definitions.clone(),
|
||||
seed_definitions,
|
||||
loop_bindings,
|
||||
seed_bindings,
|
||||
);
|
||||
let header_definition_id = self.use_def_maps[scope_id]
|
||||
.register_definition_with_bindings(
|
||||
header_definition,
|
||||
pre_loop.bindings_for_place(*place),
|
||||
pre_loop.declarations_for_place(*place),
|
||||
);
|
||||
self.loop_header_definitions
|
||||
.insert(header_definition, definitions);
|
||||
|
||||
if pre_loop_binding_ids.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
for loop_use in loop_state.uses.iter().filter(|use_| use_.place == *place) {
|
||||
self.current_use_def_map_mut().replace_use_bindings(
|
||||
loop_use.use_id,
|
||||
&pre_loop_binding_ids,
|
||||
header_definition_id,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
|
||||
@@ -1924,41 +2079,66 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
|
||||
|
||||
self.in_type_checking_block = is_outer_block_in_type_checking;
|
||||
}
|
||||
ast::Stmt::While(ast::StmtWhile {
|
||||
test,
|
||||
body,
|
||||
orelse,
|
||||
range: _,
|
||||
node_index: _,
|
||||
}) => {
|
||||
ast::Stmt::While(stmt_while) => {
|
||||
let ast::StmtWhile {
|
||||
test,
|
||||
body,
|
||||
orelse,
|
||||
range: _,
|
||||
node_index: _,
|
||||
} = stmt_while;
|
||||
self.current_scope_info_mut()
|
||||
.condition_place_uses
|
||||
.replace(FxHashSet::default());
|
||||
let outer_loop = self.push_loop();
|
||||
let has_outer_loop = outer_loop.is_some();
|
||||
self.visit_expr(test);
|
||||
let condition_place_uses = self
|
||||
.current_scope_info_mut()
|
||||
.condition_place_uses
|
||||
.take()
|
||||
.unwrap();
|
||||
|
||||
let pre_loop = self.flow_snapshot();
|
||||
let predicate = self.record_expression_narrowing_constraint(test);
|
||||
self.record_reachability_constraint(predicate);
|
||||
let predicate = self.build_predicate(test);
|
||||
let predicate_id = self.add_predicate(predicate);
|
||||
self.record_narrowing_constraint_id(predicate_id);
|
||||
self.record_reachability_constraint_id(predicate_id);
|
||||
|
||||
let outer_loop = self.push_loop();
|
||||
self.visit_body(body);
|
||||
let this_loop = self.pop_loop(outer_loop);
|
||||
let post_body = self.flow_snapshot();
|
||||
if !has_outer_loop {
|
||||
self.create_loop_header_definitions(
|
||||
stmt_while, &this_loop, &pre_loop, &post_body,
|
||||
);
|
||||
}
|
||||
|
||||
// We execute the `else` branch once the condition evaluates to false. This could
|
||||
// happen without ever executing the body, if the condition is false the first time
|
||||
// it's tested. Or it could happen if a _later_ evaluation of the condition yields
|
||||
// false. So we merge in the pre-loop state here into the post-body state:
|
||||
|
||||
self.flow_merge(pre_loop);
|
||||
self.flow_merge(pre_loop.clone());
|
||||
|
||||
// The `else` branch can only be reached if the loop condition *can* be false. To
|
||||
// model this correctly, we need a second copy of the while condition constraint,
|
||||
// since the first and later evaluations might produce different results. We would
|
||||
// otherwise simplify `predicate AND ~predicate` to `False`.
|
||||
let later_predicate_id = self.current_use_def_map_mut().add_predicate(predicate);
|
||||
let later_reachability_constraint = self
|
||||
.current_reachability_constraints_mut()
|
||||
.add_atom(later_predicate_id);
|
||||
self.record_negated_reachability_constraint(later_reachability_constraint);
|
||||
|
||||
self.record_negated_narrowing_constraint(predicate);
|
||||
let condition_depends_on_loop = condition_place_uses
|
||||
.iter()
|
||||
.any(|place| pre_loop.has_new_bindings_for_place(&post_body, *place));
|
||||
if condition_depends_on_loop {
|
||||
self.record_ambiguous_reachability();
|
||||
} else {
|
||||
let later_predicate_id =
|
||||
self.current_use_def_map_mut().add_predicate(predicate);
|
||||
let later_reachability_constraint = self
|
||||
.current_reachability_constraints_mut()
|
||||
.add_atom(later_predicate_id);
|
||||
self.record_negated_reachability_constraint(later_reachability_constraint);
|
||||
self.record_negated_narrowing_constraint(predicate);
|
||||
}
|
||||
|
||||
self.visit_body(orelse);
|
||||
|
||||
@@ -2469,12 +2649,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
|
||||
let place_id = self.add_place(place_expr);
|
||||
|
||||
if is_use {
|
||||
if let ScopedPlaceId::Symbol(symbol_id) = place_id {
|
||||
self.mark_symbol_used(symbol_id);
|
||||
}
|
||||
let use_id = self.current_ast_ids().record_use(expr);
|
||||
self.current_use_def_map_mut()
|
||||
.record_use(place_id, use_id, node_key);
|
||||
self.record_use(place_id, expr.into());
|
||||
}
|
||||
|
||||
if is_definition {
|
||||
|
||||
@@ -13,6 +13,7 @@ use crate::node_key::NodeKey;
|
||||
use crate::semantic_index::place::ScopedPlaceId;
|
||||
use crate::semantic_index::scope::{FileScopeId, ScopeId};
|
||||
use crate::semantic_index::symbol::ScopedSymbolId;
|
||||
use crate::semantic_index::use_def::Bindings;
|
||||
use crate::unpack::{Unpack, UnpackPosition};
|
||||
|
||||
/// A definition of a place.
|
||||
@@ -753,6 +754,7 @@ pub enum DefinitionKind<'db> {
|
||||
TypeVar(AstNodeRef<ast::TypeParamTypeVar>),
|
||||
ParamSpec(AstNodeRef<ast::TypeParamParamSpec>),
|
||||
TypeVarTuple(AstNodeRef<ast::TypeParamTypeVarTuple>),
|
||||
LoopHeader(LoopHeaderDefinitionKind<'db>),
|
||||
}
|
||||
|
||||
impl DefinitionKind<'_> {
|
||||
@@ -835,6 +837,7 @@ impl DefinitionKind<'_> {
|
||||
DefinitionKind::TypeVarTuple(type_var_tuple) => {
|
||||
type_var_tuple.node(module).name.range()
|
||||
}
|
||||
DefinitionKind::LoopHeader(loop_header) => loop_header.node(module).range(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -880,6 +883,7 @@ impl DefinitionKind<'_> {
|
||||
DefinitionKind::TypeVar(type_var) => type_var.node(module).range(),
|
||||
DefinitionKind::ParamSpec(param_spec) => param_spec.node(module).range(),
|
||||
DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.node(module).range(),
|
||||
DefinitionKind::LoopHeader(loop_header) => loop_header.node(module).range(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -935,7 +939,8 @@ impl DefinitionKind<'_> {
|
||||
| DefinitionKind::WithItem(_)
|
||||
| DefinitionKind::MatchPattern(_)
|
||||
| DefinitionKind::ImportFromSubmodule(_)
|
||||
| DefinitionKind::ExceptHandler(_) => DefinitionCategory::Binding,
|
||||
| DefinitionKind::ExceptHandler(_)
|
||||
| DefinitionKind::LoopHeader(_) => DefinitionCategory::Binding,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1211,9 +1216,62 @@ impl ExceptHandlerDefinitionKind {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, get_size2::GetSize)]
|
||||
pub struct LoopHeaderDefinitionKind<'db> {
|
||||
node: AstNodeRef<ast::StmtWhile>,
|
||||
definitions: Vec<Definition<'db>>,
|
||||
seed_definitions: Vec<Definition<'db>>,
|
||||
bindings: Bindings,
|
||||
seed_bindings: Bindings,
|
||||
}
|
||||
|
||||
impl<'db> LoopHeaderDefinitionKind<'db> {
|
||||
pub(crate) fn new(
|
||||
node: AstNodeRef<ast::StmtWhile>,
|
||||
definitions: Vec<Definition<'db>>,
|
||||
seed_definitions: Vec<Definition<'db>>,
|
||||
bindings: Bindings,
|
||||
seed_bindings: Bindings,
|
||||
) -> Self {
|
||||
Self {
|
||||
node,
|
||||
definitions,
|
||||
seed_definitions,
|
||||
bindings,
|
||||
seed_bindings,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn node<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtWhile {
|
||||
self.node.node(module)
|
||||
}
|
||||
|
||||
pub(crate) fn definitions(&self) -> &[Definition<'db>] {
|
||||
&self.definitions
|
||||
}
|
||||
|
||||
pub(crate) fn seed_definitions(&self) -> &[Definition<'db>] {
|
||||
&self.seed_definitions
|
||||
}
|
||||
|
||||
pub(crate) fn bindings(&self) -> &Bindings {
|
||||
&self.bindings
|
||||
}
|
||||
|
||||
pub(crate) fn seed_bindings(&self) -> &Bindings {
|
||||
&self.seed_bindings
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, salsa::Update, get_size2::GetSize)]
|
||||
pub(crate) struct DefinitionNodeKey(NodeKey);
|
||||
|
||||
impl DefinitionNodeKey {
|
||||
pub(crate) fn from_node_key(node_key: NodeKey) -> Self {
|
||||
Self(node_key)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ast::Alias> for DefinitionNodeKey {
|
||||
fn from(node: &ast::Alias) -> Self {
|
||||
Self(NodeKey::from_node(node))
|
||||
@@ -1280,6 +1338,12 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ast::StmtWhile> for DefinitionNodeKey {
|
||||
fn from(node: &ast::StmtWhile) -> Self {
|
||||
Self(NodeKey::from_node(node))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ast::Parameter> for DefinitionNodeKey {
|
||||
fn from(node: &ast::Parameter) -> Self {
|
||||
Self(NodeKey::from_node(node))
|
||||
|
||||
@@ -111,6 +111,13 @@ impl NarrowingConstraintsBuilder {
|
||||
) -> ScopedNarrowingConstraint {
|
||||
self.lists.intersect(a, b)
|
||||
}
|
||||
|
||||
pub(crate) fn iter_predicates(
|
||||
&self,
|
||||
set: ScopedNarrowingConstraint,
|
||||
) -> NarrowingConstraintsIterator<'_> {
|
||||
self.lists.iter_set_reverse(set).copied()
|
||||
}
|
||||
}
|
||||
|
||||
// Iteration
|
||||
@@ -142,12 +149,5 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
impl NarrowingConstraintsBuilder {
|
||||
pub(crate) fn iter_predicates(
|
||||
&self,
|
||||
set: ScopedNarrowingConstraint,
|
||||
) -> NarrowingConstraintsIterator<'_> {
|
||||
self.lists.iter_set_reverse(set).copied()
|
||||
}
|
||||
}
|
||||
// Test-only impl removed; use the main impl above.
|
||||
}
|
||||
|
||||
@@ -261,8 +261,9 @@ use crate::semantic_index::reachability_constraints::{
|
||||
};
|
||||
use crate::semantic_index::scope::{FileScopeId, ScopeKind, ScopeLaziness};
|
||||
use crate::semantic_index::symbol::ScopedSymbolId;
|
||||
pub(crate) use crate::semantic_index::use_def::place_state::Bindings;
|
||||
use crate::semantic_index::use_def::place_state::{
|
||||
Bindings, Declarations, EnclosingSnapshot, LiveBindingsIterator, LiveDeclaration,
|
||||
Declarations, EnclosingSnapshot, LiveBindingsIterator, LiveDeclaration,
|
||||
LiveDeclarationsIterator, PlaceState, PreviousDefinitions, ScopedDefinitionId,
|
||||
};
|
||||
use crate::semantic_index::{EnclosingSnapshotResult, SemanticIndex};
|
||||
@@ -364,6 +365,13 @@ impl<'db> UseDefMap<'db> {
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn bindings_from_snapshot<'a>(
|
||||
&'a self,
|
||||
bindings: &'a Bindings,
|
||||
) -> BindingWithConstraintsIterator<'a, 'db> {
|
||||
self.bindings_iterator(bindings, BoundnessAnalysis::BasedOnUnboundVisibility)
|
||||
}
|
||||
|
||||
pub(crate) fn applicable_constraints(
|
||||
&self,
|
||||
constraint_key: ConstraintKey,
|
||||
@@ -825,6 +833,52 @@ pub(super) struct FlowSnapshot {
|
||||
reachability: ScopedReachabilityConstraintId,
|
||||
}
|
||||
|
||||
impl FlowSnapshot {
|
||||
pub(super) fn has_new_bindings_for_place(
|
||||
&self,
|
||||
other: &FlowSnapshot,
|
||||
place: ScopedPlaceId,
|
||||
) -> bool {
|
||||
let (self_ids, other_ids) = match place {
|
||||
ScopedPlaceId::Symbol(symbol) => (
|
||||
self.symbol_states[symbol].bindings().binding_ids(),
|
||||
other.symbol_states[symbol].bindings().binding_ids(),
|
||||
),
|
||||
ScopedPlaceId::Member(member) => (
|
||||
self.member_states[member].bindings().binding_ids(),
|
||||
other.member_states[member].bindings().binding_ids(),
|
||||
),
|
||||
};
|
||||
other_ids.iter().any(|id| !self_ids.contains(id))
|
||||
}
|
||||
|
||||
pub(super) fn bindings_for_place(&self, place: ScopedPlaceId) -> Bindings {
|
||||
match place {
|
||||
ScopedPlaceId::Symbol(symbol) => self.symbol_states[symbol].bindings().clone(),
|
||||
ScopedPlaceId::Member(member) => self.member_states[member].bindings().clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn binding_ids_for_place_excluding_unbound(
|
||||
&self,
|
||||
place: ScopedPlaceId,
|
||||
) -> Vec<ScopedDefinitionId> {
|
||||
let mut ids = match place {
|
||||
ScopedPlaceId::Symbol(symbol) => self.symbol_states[symbol].bindings().binding_ids(),
|
||||
ScopedPlaceId::Member(member) => self.member_states[member].bindings().binding_ids(),
|
||||
};
|
||||
ids.retain(|id| *id != ScopedDefinitionId::UNBOUND);
|
||||
ids
|
||||
}
|
||||
|
||||
pub(super) fn declarations_for_place(&self, place: ScopedPlaceId) -> Declarations {
|
||||
match place {
|
||||
ScopedPlaceId::Symbol(symbol) => self.symbol_states[symbol].declarations().clone(),
|
||||
ScopedPlaceId::Member(member) => self.member_states[member].declarations().clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A snapshot of the state of a single symbol (e.g. `obj`) and all of its associated members
|
||||
/// (e.g. `obj.attr`, `obj["key"]`).
|
||||
pub(super) struct SingleSymbolSnapshot {
|
||||
@@ -1248,6 +1302,107 @@ impl<'db> UseDefMapBuilder<'db> {
|
||||
self.record_node_reachability(node_key);
|
||||
}
|
||||
|
||||
pub(super) fn merge_use_with_snapshot(
|
||||
&mut self,
|
||||
use_id: ScopedUseId,
|
||||
snapshot: &FlowSnapshot,
|
||||
place: ScopedPlaceId,
|
||||
predicate: Option<ScopedPredicateId>,
|
||||
) {
|
||||
let mut bindings = self.bindings_by_use[use_id].clone();
|
||||
let mut backedge_bindings = snapshot.bindings_for_place(place);
|
||||
if let Some(predicate) = predicate {
|
||||
if predicate != ScopedPredicateId::ALWAYS_TRUE
|
||||
&& predicate != ScopedPredicateId::ALWAYS_FALSE
|
||||
{
|
||||
backedge_bindings
|
||||
.record_narrowing_constraint(&mut self.narrowing_constraints, predicate.into());
|
||||
}
|
||||
}
|
||||
bindings.merge(
|
||||
backedge_bindings,
|
||||
&mut self.narrowing_constraints,
|
||||
&mut self.reachability_constraints,
|
||||
);
|
||||
self.bindings_by_use[use_id] = bindings;
|
||||
}
|
||||
|
||||
pub(super) fn register_definition(
|
||||
&mut self,
|
||||
definition: Definition<'db>,
|
||||
) -> ScopedDefinitionId {
|
||||
self.all_definitions
|
||||
.push(DefinitionState::Defined(definition))
|
||||
}
|
||||
|
||||
pub(super) fn register_definition_with_bindings(
|
||||
&mut self,
|
||||
definition: Definition<'db>,
|
||||
bindings: Bindings,
|
||||
declarations: Declarations,
|
||||
) -> ScopedDefinitionId {
|
||||
let def_id = self
|
||||
.all_definitions
|
||||
.push(DefinitionState::Defined(definition));
|
||||
self.bindings_by_definition.insert(definition, bindings);
|
||||
self.declarations_by_binding
|
||||
.insert(definition, declarations);
|
||||
def_id
|
||||
}
|
||||
|
||||
pub(super) fn replace_use_with_definition(
|
||||
&mut self,
|
||||
use_id: ScopedUseId,
|
||||
definition_id: ScopedDefinitionId,
|
||||
) {
|
||||
self.bindings_by_use[use_id].replace_definition(definition_id);
|
||||
}
|
||||
|
||||
pub(super) fn replace_use_bindings(
|
||||
&mut self,
|
||||
use_id: ScopedUseId,
|
||||
from: &[ScopedDefinitionId],
|
||||
definition_id: ScopedDefinitionId,
|
||||
) {
|
||||
self.bindings_by_use[use_id].replace_definitions(
|
||||
from,
|
||||
definition_id,
|
||||
&mut self.narrowing_constraints,
|
||||
&mut self.reachability_constraints,
|
||||
);
|
||||
}
|
||||
|
||||
pub(super) fn merge_bindings(&mut self, mut bindings: Bindings, other: Bindings) -> Bindings {
|
||||
bindings.merge(
|
||||
other,
|
||||
&mut self.narrowing_constraints,
|
||||
&mut self.reachability_constraints,
|
||||
);
|
||||
bindings
|
||||
}
|
||||
|
||||
pub(super) fn definitions_for_place_in_snapshot(
|
||||
&self,
|
||||
snapshot: &FlowSnapshot,
|
||||
place: ScopedPlaceId,
|
||||
) -> Vec<Definition<'db>> {
|
||||
let binding_ids = match place {
|
||||
ScopedPlaceId::Symbol(symbol) => {
|
||||
snapshot.symbol_states[symbol].bindings().binding_ids()
|
||||
}
|
||||
ScopedPlaceId::Member(member) => {
|
||||
snapshot.member_states[member].bindings().binding_ids()
|
||||
}
|
||||
};
|
||||
binding_ids
|
||||
.into_iter()
|
||||
.filter_map(|binding_id| match self.all_definitions[binding_id] {
|
||||
DefinitionState::Defined(definition) => Some(definition),
|
||||
DefinitionState::Undefined | DefinitionState::Deleted => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(super) fn record_node_reachability(&mut self, node_key: NodeKey) {
|
||||
self.node_reachability.insert(node_key, self.reachability);
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ use crate::semantic_index::reachability_constraints::{
|
||||
/// A newtype-index for a definition in a particular scope.
|
||||
#[newtype_index]
|
||||
#[derive(Ord, PartialOrd, get_size2::GetSize)]
|
||||
pub(super) struct ScopedDefinitionId;
|
||||
pub(crate) struct ScopedDefinitionId;
|
||||
|
||||
impl ScopedDefinitionId {
|
||||
/// A special ID that is used to describe an implicit start-of-scope state. When
|
||||
@@ -74,7 +74,7 @@ impl ScopedDefinitionId {
|
||||
/// Live declarations for a single place at some point in control flow, with their
|
||||
/// corresponding reachability constraints.
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
|
||||
pub(super) struct Declarations {
|
||||
pub(crate) struct Declarations {
|
||||
/// A list of live declarations for this place, sorted by their `ScopedDefinitionId`
|
||||
live_declarations: SmallVec<[LiveDeclaration; 2]>,
|
||||
}
|
||||
@@ -206,7 +206,7 @@ impl EnclosingSnapshot {
|
||||
/// Live bindings for a single place at some point in control flow. Each live binding comes
|
||||
/// with a set of narrowing constraints and a reachability constraint.
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
|
||||
pub(super) struct Bindings {
|
||||
pub(crate) struct Bindings {
|
||||
/// The narrowing constraint applicable to the "unbound" binding, if we need access to it even
|
||||
/// when it's not visible. This happens in class scopes, where local name bindings are not visible
|
||||
/// to nested scopes, but we still need to know what narrowing constraints were applied to the
|
||||
@@ -222,12 +222,109 @@ impl Bindings {
|
||||
.unwrap_or(self.live_bindings[0].narrowing_constraint)
|
||||
}
|
||||
|
||||
pub(super) fn has_defined_bindings(&self) -> bool {
|
||||
self.live_bindings
|
||||
.iter()
|
||||
.any(|binding| !binding.binding.is_unbound())
|
||||
}
|
||||
|
||||
pub(super) fn from_single(
|
||||
binding: ScopedDefinitionId,
|
||||
narrowing_constraint: ScopedNarrowingConstraint,
|
||||
reachability_constraint: ScopedReachabilityConstraintId,
|
||||
) -> Self {
|
||||
Self {
|
||||
unbound_narrowing_constraint: None,
|
||||
live_bindings: smallvec![LiveBinding {
|
||||
binding,
|
||||
narrowing_constraint,
|
||||
reachability_constraint,
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn representative_narrowing_constraint(&self) -> ScopedNarrowingConstraint {
|
||||
self.live_bindings
|
||||
.first()
|
||||
.map(|binding| binding.narrowing_constraint)
|
||||
.unwrap_or_else(ScopedNarrowingConstraint::empty)
|
||||
}
|
||||
|
||||
pub(super) fn retain_bindings(
|
||||
&mut self,
|
||||
mut predicate: impl FnMut(ScopedDefinitionId) -> bool,
|
||||
) {
|
||||
self.live_bindings
|
||||
.retain(|binding| predicate(binding.binding));
|
||||
}
|
||||
|
||||
pub(super) fn binding_ids(&self) -> Vec<ScopedDefinitionId> {
|
||||
self.live_bindings
|
||||
.iter()
|
||||
.map(|binding| binding.binding)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(super) fn finish(&mut self, reachability_constraints: &mut ReachabilityConstraintsBuilder) {
|
||||
self.live_bindings.shrink_to_fit();
|
||||
for binding in &self.live_bindings {
|
||||
reachability_constraints.mark_used(binding.reachability_constraint);
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn replace_definition(&mut self, binding: ScopedDefinitionId) {
|
||||
for live_binding in &mut self.live_bindings {
|
||||
live_binding.binding = binding;
|
||||
}
|
||||
self.live_bindings
|
||||
.sort_by(|left, right| left.binding.cmp(&right.binding));
|
||||
}
|
||||
|
||||
pub(super) fn replace_definitions(
|
||||
&mut self,
|
||||
from: &[ScopedDefinitionId],
|
||||
replacement: ScopedDefinitionId,
|
||||
narrowing_constraints: &mut NarrowingConstraintsBuilder,
|
||||
reachability_constraints: &mut ReachabilityConstraintsBuilder,
|
||||
) {
|
||||
if from.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut changed = false;
|
||||
for live_binding in &mut self.live_bindings {
|
||||
if from.contains(&live_binding.binding) {
|
||||
live_binding.binding = replacement;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return;
|
||||
}
|
||||
|
||||
self.live_bindings
|
||||
.sort_by(|left, right| left.binding.cmp(&right.binding));
|
||||
|
||||
let mut merged: SmallVec<[LiveBinding; 2]> = SmallVec::new();
|
||||
for binding in std::mem::take(&mut self.live_bindings) {
|
||||
match merged.last_mut() {
|
||||
Some(last) if last.binding == binding.binding => {
|
||||
last.narrowing_constraint = narrowing_constraints.intersect_constraints(
|
||||
last.narrowing_constraint,
|
||||
binding.narrowing_constraint,
|
||||
);
|
||||
last.reachability_constraint = reachability_constraints.add_or_constraint(
|
||||
last.reachability_constraint,
|
||||
binding.reachability_constraint,
|
||||
);
|
||||
}
|
||||
_ => merged.push(binding),
|
||||
}
|
||||
}
|
||||
|
||||
self.live_bindings = merged;
|
||||
}
|
||||
}
|
||||
|
||||
/// One of the live bindings for a single place at some point in control flow.
|
||||
@@ -291,6 +388,26 @@ impl Bindings {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn add_narrowing_constraint_from(
|
||||
&mut self,
|
||||
narrowing_constraints: &mut NarrowingConstraintsBuilder,
|
||||
constraint: ScopedNarrowingConstraint,
|
||||
) {
|
||||
let predicates: Vec<_> = narrowing_constraints.iter_predicates(constraint).collect();
|
||||
let mut unbound_constraint = self.unbound_narrowing_constraint;
|
||||
for predicate in predicates {
|
||||
if let Some(existing) = unbound_constraint {
|
||||
unbound_constraint =
|
||||
Some(narrowing_constraints.add_predicate_to_constraint(existing, predicate));
|
||||
}
|
||||
for binding in &mut self.live_bindings {
|
||||
binding.narrowing_constraint = narrowing_constraints
|
||||
.add_predicate_to_constraint(binding.narrowing_constraint, predicate);
|
||||
}
|
||||
}
|
||||
self.unbound_narrowing_constraint = unbound_constraint;
|
||||
}
|
||||
|
||||
/// Add given reachability constraint to all live bindings.
|
||||
pub(super) fn record_reachability_constraint(
|
||||
&mut self,
|
||||
|
||||
@@ -18,6 +18,8 @@ use ruff_db::parsed::parsed_module;
|
||||
use ruff_python_ast as ast;
|
||||
use ruff_python_ast::name::Name;
|
||||
use ruff_text_size::{Ranged, TextRange};
|
||||
use rustc_hash::FxHashMap;
|
||||
use salsa::plumbing::{AsId, Id};
|
||||
use smallvec::{SmallVec, smallvec};
|
||||
use ty_module_resolver::{KnownModule, Module, ModuleName, resolve_module};
|
||||
|
||||
@@ -158,8 +160,59 @@ pub fn check_types(db: &dyn Db, file: File) -> Vec<Diagnostic> {
|
||||
diagnostics
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static LOOP_HEADER_OVERRIDE: RefCell<FxHashMap<Id, Type<'static>>> =
|
||||
RefCell::new(FxHashMap::default());
|
||||
}
|
||||
|
||||
fn loop_header_override<'db>(definition: Definition<'db>) -> Option<Type<'db>> {
|
||||
let id = definition.as_id();
|
||||
LOOP_HEADER_OVERRIDE.with(|cell| cell.borrow().get(&id).copied().map(restore_type_lifetime))
|
||||
}
|
||||
|
||||
pub(crate) fn with_loop_header_override<'db, R>(
|
||||
definition: Definition<'db>,
|
||||
ty: Type<'db>,
|
||||
f: impl FnOnce() -> R,
|
||||
) -> R {
|
||||
let id = definition.as_id();
|
||||
LOOP_HEADER_OVERRIDE.with(|cell| {
|
||||
let mut overrides = cell.borrow_mut();
|
||||
let previous = overrides.insert(id, erase_type_lifetime(ty));
|
||||
drop(overrides);
|
||||
|
||||
let result = f();
|
||||
|
||||
let mut overrides = cell.borrow_mut();
|
||||
match previous {
|
||||
Some(previous) => {
|
||||
overrides.insert(id, previous);
|
||||
}
|
||||
None => {
|
||||
overrides.remove(&id);
|
||||
}
|
||||
}
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
fn erase_type_lifetime<'db>(ty: Type<'db>) -> Type<'static> {
|
||||
// SAFETY: `Type` is a copyable, db-backed handle; we only use this within
|
||||
// a single thread to provide a temporary loop-header override.
|
||||
unsafe { std::mem::transmute::<Type<'db>, Type<'static>>(ty) }
|
||||
}
|
||||
|
||||
fn restore_type_lifetime<'db>(ty: Type<'static>) -> Type<'db> {
|
||||
// SAFETY: This is the inverse of `erase_type_lifetime` and is only used
|
||||
// within the dynamic scope of a loop-header override.
|
||||
unsafe { std::mem::transmute::<Type<'static>, Type<'db>>(ty) }
|
||||
}
|
||||
|
||||
/// Infer the type of a binding.
|
||||
pub(crate) fn binding_type<'db>(db: &'db dyn Db, definition: Definition<'db>) -> Type<'db> {
|
||||
if let Some(override_type) = loop_header_override(definition) {
|
||||
return override_type;
|
||||
}
|
||||
let inference = infer_definition_types(db, definition);
|
||||
inference.binding_type(definition)
|
||||
}
|
||||
@@ -11792,7 +11845,9 @@ impl<'db> UnionType<'db> {
|
||||
elements
|
||||
.into_iter()
|
||||
.fold(
|
||||
UnionBuilder::new(db).cycle_recovery(true),
|
||||
UnionBuilder::new(db)
|
||||
.cycle_recovery(true)
|
||||
.recursively_defined(RecursivelyDefined::Yes),
|
||||
|builder, element| builder.add(element.into()),
|
||||
)
|
||||
.build()
|
||||
|
||||
@@ -1377,7 +1377,8 @@ mod resolve_definition {
|
||||
| DefinitionKind::ExceptHandler(_)
|
||||
| DefinitionKind::TypeVar(_)
|
||||
| DefinitionKind::ParamSpec(_)
|
||||
| DefinitionKind::TypeVarTuple(_) => {
|
||||
| DefinitionKind::TypeVarTuple(_)
|
||||
| DefinitionKind::LoopHeader(_) => {
|
||||
// Not yet implemented
|
||||
return Err(());
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ use crate::semantic_index::ast_ids::{HasScopedUseId, ScopedUseId};
|
||||
use crate::semantic_index::definition::{
|
||||
AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, ComprehensionDefinitionKind,
|
||||
Definition, DefinitionKind, DefinitionNodeKey, DefinitionState, ExceptHandlerDefinitionKind,
|
||||
ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind,
|
||||
ForStmtDefinitionKind, LoopHeaderDefinitionKind, TargetKind, WithItemDefinitionKind,
|
||||
};
|
||||
use crate::semantic_index::expression::{Expression, ExpressionKind};
|
||||
use crate::semantic_index::narrowing_constraints::ConstraintKey;
|
||||
@@ -121,7 +121,7 @@ use crate::types::{
|
||||
TypeQualifiers, TypeVarBoundOrConstraints, TypeVarBoundOrConstraintsEvaluation,
|
||||
TypeVarDefaultEvaluation, TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance,
|
||||
TypedDictType, UnionBuilder, UnionType, UnionTypeInstance, binding_type, infer_scope_types,
|
||||
todo_type,
|
||||
todo_type, with_loop_header_override,
|
||||
};
|
||||
use crate::types::{CallableTypes, overrides};
|
||||
use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic};
|
||||
@@ -1501,6 +1501,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
DefinitionKind::TypeVarTuple(node) => {
|
||||
self.infer_typevartuple_definition(node.node(self.module()), definition);
|
||||
}
|
||||
DefinitionKind::LoopHeader(loop_header) => {
|
||||
self.infer_loop_header_definition(&loop_header, definition);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1530,6 +1533,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
DefinitionKind::Assignment(assignment) => {
|
||||
self.infer_assignment_deferred(assignment.value(self.module()));
|
||||
}
|
||||
DefinitionKind::LoopHeader(_) => {}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@@ -5384,6 +5388,66 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||
add.insert(self, target_ty);
|
||||
}
|
||||
|
||||
fn infer_loop_header_definition(
|
||||
&mut self,
|
||||
loop_header: &LoopHeaderDefinitionKind<'db>,
|
||||
definition: Definition<'db>,
|
||||
) {
|
||||
const MAX_LOOP_HEADER_ITERATIONS: usize = 8;
|
||||
let max_iterations = MAX_LOOP_HEADER_ITERATIONS;
|
||||
let should_widen = true;
|
||||
|
||||
let add = self.add_binding(loop_header.node(self.module()).into(), definition);
|
||||
let mut seed_union = UnionBuilder::new(self.db());
|
||||
for seed_definition in loop_header.seed_definitions() {
|
||||
seed_union = seed_union.add(binding_type(self.db(), *seed_definition));
|
||||
}
|
||||
let mut current = if seed_union.is_empty() {
|
||||
Type::unknown()
|
||||
} else {
|
||||
seed_union.build()
|
||||
};
|
||||
|
||||
let mut changed = false;
|
||||
for _ in 0..max_iterations {
|
||||
let previous_multi = self.multi_inference_state;
|
||||
self.multi_inference_state = MultiInferenceState::Overwrite;
|
||||
let next = with_loop_header_override(definition, current, || {
|
||||
let mut union = UnionBuilder::new(self.db());
|
||||
for definition in loop_header.definitions() {
|
||||
let inferred = match definition.kind(self.db()) {
|
||||
DefinitionKind::Assignment(assignment) => self
|
||||
.infer_assignment_definition_impl(
|
||||
&assignment,
|
||||
*definition,
|
||||
TypeContext::default(),
|
||||
),
|
||||
DefinitionKind::AugmentedAssignment(augmented_assignment) => {
|
||||
self.infer_augment_assignment(augmented_assignment.node(self.module()))
|
||||
}
|
||||
_ => binding_type(self.db(), *definition),
|
||||
};
|
||||
union = union.add(inferred);
|
||||
}
|
||||
union.build()
|
||||
});
|
||||
self.multi_inference_state = previous_multi;
|
||||
|
||||
if next == current {
|
||||
changed = false;
|
||||
break;
|
||||
}
|
||||
changed = true;
|
||||
current = next;
|
||||
}
|
||||
|
||||
if should_widen && changed {
|
||||
current = KnownClass::Int.to_instance(self.db());
|
||||
}
|
||||
|
||||
add.insert(self, current);
|
||||
}
|
||||
|
||||
fn infer_assignment_definition_impl(
|
||||
&mut self,
|
||||
assignment: &AssignmentDefinitionKind<'db>,
|
||||
|
||||
@@ -188,13 +188,9 @@ impl<'db> CallableSignature<'db> {
|
||||
{
|
||||
Some(CallableSignature::from_overloads(
|
||||
callable.signatures(db).iter().map(|signature| Signature {
|
||||
generic_context: GenericContext::merge_optional(
|
||||
db,
|
||||
signature.generic_context,
|
||||
self_signature.generic_context.map(|context| {
|
||||
type_mapping.update_signature_generic_context(db, context)
|
||||
}),
|
||||
),
|
||||
generic_context: self_signature.generic_context.map(|context| {
|
||||
type_mapping.update_signature_generic_context(db, context)
|
||||
}),
|
||||
definition: signature.definition,
|
||||
parameters: if signature.parameters().is_top() {
|
||||
signature.parameters().clone()
|
||||
@@ -418,11 +414,7 @@ impl<'db> CallableSignature<'db> {
|
||||
db,
|
||||
CallableSignature::from_overloads(other_signatures.iter().map(
|
||||
|signature| {
|
||||
Signature::new_generic(
|
||||
signature.generic_context,
|
||||
signature.parameters().clone(),
|
||||
Type::unknown(),
|
||||
)
|
||||
Signature::new(signature.parameters().clone(), Type::unknown())
|
||||
},
|
||||
)),
|
||||
CallableTypeKind::ParamSpecValue,
|
||||
@@ -454,11 +446,7 @@ impl<'db> CallableSignature<'db> {
|
||||
db,
|
||||
CallableSignature::from_overloads(self_signatures.iter().map(
|
||||
|signature| {
|
||||
Signature::new_generic(
|
||||
signature.generic_context,
|
||||
signature.parameters().clone(),
|
||||
Type::unknown(),
|
||||
)
|
||||
Signature::new(signature.parameters().clone(), Type::unknown())
|
||||
},
|
||||
)),
|
||||
CallableTypeKind::ParamSpecValue,
|
||||
@@ -1095,11 +1083,7 @@ impl<'db> Signature<'db> {
|
||||
let upper = Type::Callable(CallableType::new(
|
||||
db,
|
||||
CallableSignature::from_overloads(other.overloads.iter().map(|signature| {
|
||||
Signature::new_generic(
|
||||
signature.generic_context,
|
||||
signature.parameters().clone(),
|
||||
Type::unknown(),
|
||||
)
|
||||
Signature::new(signature.parameters().clone(), Type::unknown())
|
||||
})),
|
||||
CallableTypeKind::ParamSpecValue,
|
||||
));
|
||||
@@ -1355,8 +1339,7 @@ impl<'db> Signature<'db> {
|
||||
(Some(self_bound_typevar), None) => {
|
||||
let upper = Type::Callable(CallableType::new(
|
||||
db,
|
||||
CallableSignature::single(Signature::new_generic(
|
||||
other.generic_context,
|
||||
CallableSignature::single(Signature::new(
|
||||
other.parameters.clone(),
|
||||
Type::unknown(),
|
||||
)),
|
||||
@@ -1375,8 +1358,7 @@ impl<'db> Signature<'db> {
|
||||
(None, Some(other_bound_typevar)) => {
|
||||
let lower = Type::Callable(CallableType::new(
|
||||
db,
|
||||
CallableSignature::single(Signature::new_generic(
|
||||
self.generic_context,
|
||||
CallableSignature::single(Signature::new(
|
||||
self.parameters.clone(),
|
||||
Type::unknown(),
|
||||
)),
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use self::schedule::spawn_main_loop;
|
||||
use crate::PositionEncoding;
|
||||
use crate::capabilities::{ResolvedClientCapabilities, server_capabilities};
|
||||
use crate::session::{ClientName, InitializationOptions, Session, warn_about_unknown_options};
|
||||
use crate::session::{InitializationOptions, Session, warn_about_unknown_options};
|
||||
use anyhow::Context;
|
||||
use lsp_server::Connection;
|
||||
use lsp_types::{ClientCapabilities, InitializeParams, MessageType, Url};
|
||||
@@ -47,7 +47,6 @@ impl Server {
|
||||
initialization_options,
|
||||
capabilities: client_capabilities,
|
||||
workspace_folders,
|
||||
client_info,
|
||||
..
|
||||
} = serde_json::from_value(init_value)
|
||||
.context("Failed to deserialize initialization parameters")?;
|
||||
@@ -66,7 +65,6 @@ impl Server {
|
||||
tracing::error!("Failed to deserialize initialization options: {error}");
|
||||
}
|
||||
|
||||
tracing::debug!("Client info: {client_info:#?}");
|
||||
tracing::debug!("Initialization options: {initialization_options:#?}");
|
||||
|
||||
let resolved_client_capabilities = ResolvedClientCapabilities::new(&client_capabilities);
|
||||
@@ -157,7 +155,6 @@ impl Server {
|
||||
workspace_urls,
|
||||
initialization_options,
|
||||
native_system,
|
||||
ClientName::from(client_info),
|
||||
in_test,
|
||||
)?,
|
||||
})
|
||||
|
||||
@@ -119,12 +119,9 @@ pub(super) fn request(req: server::Request) -> Task {
|
||||
.unwrap_or_else(|err| {
|
||||
tracing::error!("Encountered error when routing request with ID {id}: {err}");
|
||||
|
||||
Task::sync(move |session, client| {
|
||||
Task::sync(move |_session, client| {
|
||||
if matches!(err.code, ErrorCode::InternalError) {
|
||||
client.show_error_message(format!(
|
||||
"ty failed to handle a request from the editor. {}",
|
||||
session.client_name().log_guidance()
|
||||
));
|
||||
client.show_error_message("ty failed to handle a request from the editor. Check the logs for more details.");
|
||||
}
|
||||
|
||||
respond_silent_error(
|
||||
@@ -178,12 +175,11 @@ pub(super) fn notification(notif: server::Notification) -> Task {
|
||||
}
|
||||
.unwrap_or_else(|err| {
|
||||
tracing::error!("Encountered error when routing notification: {err}");
|
||||
Task::sync(move |session, client| {
|
||||
Task::sync(move |_session, client| {
|
||||
if matches!(err.code, ErrorCode::InternalError) {
|
||||
client.show_error_message(format!(
|
||||
"ty failed to handle a notification from the editor. {}",
|
||||
session.client_name().log_guidance()
|
||||
));
|
||||
client.show_error_message(
|
||||
"ty failed to handle a notification from the editor. Check the logs for more details."
|
||||
);
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -197,7 +193,7 @@ where
|
||||
Ok(Task::sync(move |session, client: &Client| {
|
||||
let _span = tracing::debug_span!("request", %id, method = R::METHOD).entered();
|
||||
let result = R::run(session, client, params);
|
||||
respond::<R>(&id, result, client, session.client_name().log_guidance());
|
||||
respond::<R>(&id, result, client);
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -221,7 +217,6 @@ where
|
||||
// SAFETY: The `snapshot` is safe to move across the unwind boundary because it is not used
|
||||
// after unwinding.
|
||||
let snapshot = AssertUnwindSafe(session.snapshot_session());
|
||||
let log_guidance = snapshot.0.client_name().log_guidance();
|
||||
|
||||
Box::new(move |client| {
|
||||
let _span = tracing::debug_span!("request", %id, method = R::METHOD).entered();
|
||||
@@ -243,7 +238,7 @@ where
|
||||
let snapshot = snapshot;
|
||||
R::handle_request(&id, snapshot.0, client, params);
|
||||
}) {
|
||||
panic_response::<R>(&id, client, &error, retry, log_guidance);
|
||||
panic_response::<R>(&id, client, &error, retry);
|
||||
}
|
||||
})
|
||||
}))
|
||||
@@ -289,7 +284,6 @@ where
|
||||
|
||||
let path = document.notebook_or_file_path();
|
||||
let db = session.project_db(path).clone();
|
||||
let log_guidance = document.client_name().log_guidance();
|
||||
|
||||
Box::new(move |client| {
|
||||
let _span = tracing::debug_span!("request", %id, method = R::METHOD).entered();
|
||||
@@ -312,7 +306,7 @@ where
|
||||
R::handle_request(&id, &db, document, client, params);
|
||||
});
|
||||
}) {
|
||||
panic_response::<R>(&id, client, &error, retry, log_guidance);
|
||||
panic_response::<R>(&id, client, &error, retry);
|
||||
}
|
||||
})
|
||||
}))
|
||||
@@ -323,7 +317,6 @@ fn panic_response<R>(
|
||||
client: &Client,
|
||||
error: &PanicError,
|
||||
request: Option<lsp_server::Request>,
|
||||
log_guidance: &str,
|
||||
) where
|
||||
R: traits::RetriableRequestHandler,
|
||||
{
|
||||
@@ -353,7 +346,6 @@ fn panic_response<R>(
|
||||
error: anyhow!("request handler {error}"),
|
||||
}),
|
||||
client,
|
||||
log_guidance,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -366,10 +358,7 @@ fn sync_notification_task<N: traits::SyncNotificationHandler>(
|
||||
let _span = tracing::debug_span!("notification", method = N::METHOD).entered();
|
||||
if let Err(err) = N::run(session, client, params) {
|
||||
tracing::error!("An error occurred while running {id}: {err}");
|
||||
client.show_error_message(format!(
|
||||
"ty encountered a problem. {}",
|
||||
session.client_name().log_guidance()
|
||||
));
|
||||
client.show_error_message("ty encountered a problem. Check the logs for more details.");
|
||||
|
||||
return;
|
||||
}
|
||||
@@ -401,8 +390,6 @@ where
|
||||
return Box::new(|_| {});
|
||||
};
|
||||
|
||||
let log_guidance = snapshot.client_name().log_guidance();
|
||||
|
||||
Box::new(move |client| {
|
||||
let _span = tracing::debug_span!("notification", method = N::METHOD).entered();
|
||||
|
||||
@@ -412,14 +399,18 @@ where
|
||||
Ok(result) => result,
|
||||
Err(panic) => {
|
||||
tracing::error!("An error occurred while running {id}: {panic}");
|
||||
client.show_error_message(format!("ty encountered a panic. {log_guidance}"));
|
||||
client.show_error_message(
|
||||
"ty encountered a panic. Check the logs for more details.",
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = result {
|
||||
tracing::error!("An error occurred while running {id}: {err}");
|
||||
client.show_error_message(format!("ty encountered a problem. {log_guidance}"));
|
||||
client.show_error_message(
|
||||
"ty encountered a problem. Check the logs for more details.",
|
||||
);
|
||||
}
|
||||
})
|
||||
}))
|
||||
@@ -458,13 +449,12 @@ fn respond<Req>(
|
||||
id: &RequestId,
|
||||
result: Result<<<Req as RequestHandler>::RequestType as Request>::Result>,
|
||||
client: &Client,
|
||||
log_guidance: &str,
|
||||
) where
|
||||
Req: RequestHandler,
|
||||
{
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("An error occurred with request ID {id}: {err}");
|
||||
client.show_error_message(format!("ty encountered a problem. {log_guidance}"));
|
||||
client.show_error_message("ty encountered a problem. Check the logs for more details.");
|
||||
}
|
||||
client.respond(id, result);
|
||||
}
|
||||
|
||||
@@ -116,10 +116,7 @@ pub(super) trait BackgroundDocumentRequestHandler: RetriableRequestHandler {
|
||||
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("An error occurred with request ID {id}: {err}");
|
||||
client.show_error_message(format!(
|
||||
"ty encountered a problem. {}",
|
||||
snapshot.client_name().log_guidance()
|
||||
));
|
||||
client.show_error_message("ty encountered a problem. Check the logs for more details.");
|
||||
}
|
||||
|
||||
client.respond(id, result);
|
||||
@@ -156,10 +153,7 @@ pub(super) trait BackgroundRequestHandler: RetriableRequestHandler {
|
||||
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("An error occurred with request ID {id}: {err}");
|
||||
client.show_error_message(format!(
|
||||
"ty encountered a problem. {}",
|
||||
snapshot.client_name().log_guidance()
|
||||
));
|
||||
client.show_error_message("ty encountered a problem. Check the logs for more details.");
|
||||
}
|
||||
|
||||
client.respond(id, result);
|
||||
|
||||
@@ -13,7 +13,7 @@ use lsp_types::request::{
|
||||
WorkspaceDiagnosticRequest,
|
||||
};
|
||||
use lsp_types::{
|
||||
ClientInfo, DiagnosticRegistrationOptions, DiagnosticServerCapabilities,
|
||||
DiagnosticRegistrationOptions, DiagnosticServerCapabilities,
|
||||
DidChangeWatchedFilesRegistrationOptions, FileSystemWatcher, Registration, RegistrationParams,
|
||||
TextDocumentContentChangeEvent, Unregistration, UnregistrationParams, Url,
|
||||
};
|
||||
@@ -106,9 +106,6 @@ pub(crate) struct Session {
|
||||
/// Registrations is a set of LSP methods that have been dynamically registered with the
|
||||
/// client.
|
||||
registrations: HashSet<String>,
|
||||
|
||||
/// The name of the client (editor) that connected to this server.
|
||||
client_name: ClientName,
|
||||
}
|
||||
|
||||
/// LSP State for a Project
|
||||
@@ -144,7 +141,6 @@ impl Session {
|
||||
workspace_urls: Vec<Url>,
|
||||
initialization_options: InitializationOptions,
|
||||
native_system: Arc<dyn System + 'static + Send + Sync + RefUnwindSafe>,
|
||||
client_name: ClientName,
|
||||
in_test: bool,
|
||||
) -> crate::Result<Self> {
|
||||
let index = Arc::new(Index::new());
|
||||
@@ -172,7 +168,6 @@ impl Session {
|
||||
suspended_workspace_diagnostics_request: None,
|
||||
revision: 0,
|
||||
registrations: HashSet::new(),
|
||||
client_name,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -537,8 +532,8 @@ impl Session {
|
||||
);
|
||||
|
||||
client.show_error_message(format!(
|
||||
"Failed to load project for workspace {url}. {}",
|
||||
self.client_name.log_guidance(),
|
||||
"Failed to load project for workspace {url}. \
|
||||
Please refer to the logs for more details.",
|
||||
));
|
||||
|
||||
let db_with_default_settings = ProjectMetadata::from_options(
|
||||
@@ -824,7 +819,6 @@ impl Session {
|
||||
.unwrap_or_else(|| Arc::new(WorkspaceSettings::default())),
|
||||
position_encoding: self.position_encoding,
|
||||
document: document_handle,
|
||||
client_name: self.client_name,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -843,7 +837,6 @@ impl Session {
|
||||
in_test: self.in_test,
|
||||
resolved_client_capabilities: self.resolved_client_capabilities,
|
||||
revision: self.revision,
|
||||
client_name: self.client_name,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -983,10 +976,6 @@ impl Session {
|
||||
pub(crate) fn position_encoding(&self) -> PositionEncoding {
|
||||
self.position_encoding
|
||||
}
|
||||
|
||||
pub(crate) fn client_name(&self) -> ClientName {
|
||||
self.client_name
|
||||
}
|
||||
}
|
||||
|
||||
/// A guard that holds the only reference to the index and allows modifying it.
|
||||
@@ -1036,7 +1025,6 @@ pub(crate) struct DocumentSnapshot {
|
||||
workspace_settings: Arc<WorkspaceSettings>,
|
||||
position_encoding: PositionEncoding,
|
||||
document: DocumentHandle,
|
||||
client_name: ClientName,
|
||||
}
|
||||
|
||||
impl DocumentSnapshot {
|
||||
@@ -1083,10 +1071,6 @@ impl DocumentSnapshot {
|
||||
pub(crate) fn notebook_or_file_path(&self) -> &AnySystemPath {
|
||||
self.document.notebook_or_file_path()
|
||||
}
|
||||
|
||||
pub(crate) fn client_name(&self) -> ClientName {
|
||||
self.client_name
|
||||
}
|
||||
}
|
||||
|
||||
/// An immutable snapshot of the current state of [`Session`].
|
||||
@@ -1097,7 +1081,6 @@ pub(crate) struct SessionSnapshot {
|
||||
resolved_client_capabilities: ResolvedClientCapabilities,
|
||||
in_test: bool,
|
||||
revision: u64,
|
||||
client_name: ClientName,
|
||||
|
||||
/// IMPORTANT: It's important that the databases come last, or at least,
|
||||
/// after any `Arc` that we try to extract or mutate in-place using `Arc::into_inner`
|
||||
@@ -1139,42 +1122,6 @@ impl SessionSnapshot {
|
||||
pub(crate) fn revision(&self) -> u64 {
|
||||
self.revision
|
||||
}
|
||||
|
||||
pub(crate) fn client_name(&self) -> ClientName {
|
||||
self.client_name
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the client (editor) that's connected to the language server.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub(crate) enum ClientName {
|
||||
Zed,
|
||||
Other,
|
||||
}
|
||||
|
||||
impl From<Option<ClientInfo>> for ClientName {
|
||||
fn from(info: Option<ClientInfo>) -> Self {
|
||||
match info {
|
||||
Some(info) if matches!(info.name.as_str(), "Zed") => ClientName::Zed,
|
||||
_ => ClientName::Other,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientName {
|
||||
/// Returns editor-specific guidance for finding logs.
|
||||
///
|
||||
/// Different editors have different ways to access language server logs, so we provide tailored
|
||||
/// instructions based on the connected client.
|
||||
pub(crate) fn log_guidance(self) -> &'static str {
|
||||
match self {
|
||||
ClientName::Zed => {
|
||||
"Please refer to the logs for more details \
|
||||
(command palette: `dev: open language server logs`)."
|
||||
}
|
||||
ClientName::Other => "Please refer to the logs for more details.",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
|
||||
Reference in New Issue
Block a user