diff --git a/crates/ty_ide/src/goto.rs b/crates/ty_ide/src/goto.rs index 59b4c50d57..5d9acb98d1 100644 --- a/crates/ty_ide/src/goto.rs +++ b/crates/ty_ide/src/goto.rs @@ -10,6 +10,7 @@ use ruff_db::parsed::ParsedModuleRef; use ruff_python_ast::{self as ast, AnyNodeRef}; use ruff_python_parser::TokenKind; use ruff_text_size::{Ranged, TextRange, TextSize}; +use ty_python_semantic::ImportAliasResolution; use ty_python_semantic::types::Type; use ty_python_semantic::types::definitions_for_keyword_argument; use ty_python_semantic::{ @@ -172,12 +173,16 @@ impl GotoTarget<'_> { /// Gets the navigation ranges for this goto target. /// If a stub mapper is provided, definitions from stub files will be mapped to - /// their corresponding source file implementations. + /// their corresponding source file implementations. The `alias_resolution` + /// parameter controls whether import aliases (i.e. "x" in "from a import b as x") are + /// resolved or returned as is. We want to resolve them in some cases (like + /// "goto declaration") but not in others (like find references or rename). pub(crate) fn get_definition_targets( &self, file: ruff_db::files::File, db: &dyn crate::Db, stub_mapper: Option<&StubMapper>, + alias_resolution: ImportAliasResolution, ) -> Option { use crate::NavigationTarget; use ruff_python_ast as ast; @@ -229,10 +234,14 @@ impl GotoTarget<'_> { GotoTarget::ImportSymbolAlias { alias, import_from, .. } => { - // Handle both original names and alias names in `from x import y as z` statements let symbol_name = alias.name.as_str(); - let definitions = - definitions_for_imported_symbol(db, file, import_from, symbol_name); + let definitions = definitions_for_imported_symbol( + db, + file, + import_from, + symbol_name, + alias_resolution, + ); definitions_to_navigation_targets(db, stub_mapper, definitions) } @@ -254,12 +263,18 @@ impl GotoTarget<'_> { // Handle import aliases (offset within 'z' in "import x.y as z") GotoTarget::ImportModuleAlias { alias } => { - // For import aliases, navigate to the module being aliased - // This only applies to regular import statements like "import x.y as z" - let full_module_name = alias.name.as_str(); - - // Try to resolve the module - resolve_module_to_navigation_target(db, full_module_name) + if alias_resolution == ImportAliasResolution::ResolveAliases { + let full_module_name = alias.name.as_str(); + // Try to resolve the module + resolve_module_to_navigation_target(db, full_module_name) + } else { + let alias_range = alias.asname.as_ref().unwrap().range; + Some(crate::NavigationTargets::single(NavigationTarget { + file, + focus_range: alias_range, + full_range: alias.range(), + })) + } } // Handle keyword arguments in call expressions diff --git a/crates/ty_ide/src/goto_declaration.rs b/crates/ty_ide/src/goto_declaration.rs index 9f916ccbc4..be7aea6022 100644 --- a/crates/ty_ide/src/goto_declaration.rs +++ b/crates/ty_ide/src/goto_declaration.rs @@ -3,6 +3,7 @@ use crate::{Db, NavigationTargets, RangedValue}; use ruff_db::files::{File, FileRange}; use ruff_db::parsed::parsed_module; use ruff_text_size::{Ranged, TextSize}; +use ty_python_semantic::ImportAliasResolution; /// Navigate to the declaration of a symbol. /// @@ -17,7 +18,12 @@ pub fn goto_declaration( let module = parsed_module(db, file).load(db); let goto_target = find_goto_target(&module, offset)?; - let declaration_targets = goto_target.get_definition_targets(file, db, None)?; + let declaration_targets = goto_target.get_definition_targets( + file, + db, + None, + ImportAliasResolution::ResolveAliases, + )?; Some(RangedValue { range: FileRange::new(file, goto_target.range()), diff --git a/crates/ty_ide/src/goto_definition.rs b/crates/ty_ide/src/goto_definition.rs index 49b462e13a..9491fc1c3a 100644 --- a/crates/ty_ide/src/goto_definition.rs +++ b/crates/ty_ide/src/goto_definition.rs @@ -4,6 +4,7 @@ use crate::{Db, NavigationTargets, RangedValue}; use ruff_db::files::{File, FileRange}; use ruff_db::parsed::parsed_module; use ruff_text_size::{Ranged, TextSize}; +use ty_python_semantic::ImportAliasResolution; /// Navigate to the definition of a symbol. /// @@ -22,7 +23,12 @@ pub fn goto_definition( // Create a StubMapper to map from stub files to source files let stub_mapper = StubMapper::new(db); - let definition_targets = goto_target.get_definition_targets(file, db, Some(&stub_mapper))?; + let definition_targets = goto_target.get_definition_targets( + file, + db, + Some(&stub_mapper), + ImportAliasResolution::ResolveAliases, + )?; Some(RangedValue { range: FileRange::new(file, goto_target.range()), diff --git a/crates/ty_ide/src/goto_references.rs b/crates/ty_ide/src/goto_references.rs index 40d5ea8de9..cacf8cf412 100644 --- a/crates/ty_ide/src/goto_references.rs +++ b/crates/ty_ide/src/goto_references.rs @@ -688,34 +688,30 @@ cls = MyClass .source( "utils.py", " -def helper_function(x): +def func(x): return x * 2 ", ) .source( "module.py", " -from utils import helper_function +from utils import func def process_data(data): - return helper_function(data) - -def double_process(data): - result = helper_function(data) - return helper_function(result) + return func(data) ", ) .source( "app.py", " -from utils import helper_function +from utils import func class DataProcessor: def __init__(self): - self.multiplier = helper_function + self.multiplier = func def process(self, value): - return helper_function(value) + return func(value) ", ) .build(); @@ -724,37 +720,35 @@ class DataProcessor: info[references]: Reference 1 --> utils.py:2:5 | - 2 | def helper_function(x): - | ^^^^^^^^^^^^^^^ + 2 | def func(x): + | ^^^^ 3 | return x * 2 | info[references]: Reference 2 - --> module.py:5:12 + --> module.py:2:19 | + 2 | from utils import func + | ^^^^ + 3 | 4 | def process_data(data): - 5 | return helper_function(data) - | ^^^^^^^^^^^^^^^ - 6 | - 7 | def double_process(data): | info[references]: Reference 3 - --> module.py:8:14 + --> module.py:5:12 | - 7 | def double_process(data): - 8 | result = helper_function(data) - | ^^^^^^^^^^^^^^^ - 9 | return helper_function(result) + 4 | def process_data(data): + 5 | return func(data) + | ^^^^ | info[references]: Reference 4 - --> module.py:9:12 + --> app.py:2:19 | - 7 | def double_process(data): - 8 | result = helper_function(data) - 9 | return helper_function(result) - | ^^^^^^^^^^^^^^^ + 2 | from utils import func + | ^^^^ + 3 | + 4 | class DataProcessor: | info[references]: Reference 5 @@ -762,8 +756,8 @@ class DataProcessor: | 4 | class DataProcessor: 5 | def __init__(self): - 6 | self.multiplier = helper_function - | ^^^^^^^^^^^^^^^ + 6 | self.multiplier = func + | ^^^^ 7 | 8 | def process(self, value): | @@ -772,8 +766,8 @@ class DataProcessor: --> app.py:9:16 | 8 | def process(self, value): - 9 | return helper_function(value) - | ^^^^^^^^^^^^^^^ + 9 | return func(value) + | ^^^^ | "); } @@ -855,4 +849,49 @@ def process_model(): | "); } + + #[test] + fn test_import_alias_references_should_not_resolve_to_original() { + let test = CursorTest::builder() + .source( + "original.py", + " +def func(): + pass + +func() +", + ) + .source( + "importer.py", + " +from original import func as func_alias + +func_alias() +", + ) + .build(); + + // When finding references to the alias, we should NOT find references + // to the original function in the original module + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> importer.py:2:30 + | + 2 | from original import func as func_alias + | ^^^^^^^^^^ + 3 | + 4 | func_alias() + | + + info[references]: Reference 2 + --> importer.py:4:1 + | + 2 | from original import func as func_alias + 3 | + 4 | func_alias() + | ^^^^^^^^^^ + | + "); + } } diff --git a/crates/ty_ide/src/lib.rs b/crates/ty_ide/src/lib.rs index 12d69167a8..71ffe5bc8d 100644 --- a/crates/ty_ide/src/lib.rs +++ b/crates/ty_ide/src/lib.rs @@ -12,6 +12,7 @@ mod hover; mod inlay_hints; mod markup; mod references; +mod rename; mod selection_range; mod semantic_tokens; mod signature_help; @@ -29,6 +30,7 @@ pub use hover::hover; pub use inlay_hints::inlay_hints; pub use markup::MarkupKind; pub use references::ReferencesMode; +pub use rename::{can_rename, rename}; pub use selection_range::selection_range; pub use semantic_tokens::{ SemanticToken, SemanticTokenModifier, SemanticTokenType, SemanticTokens, semantic_tokens, diff --git a/crates/ty_ide/src/references.rs b/crates/ty_ide/src/references.rs index c84bf94a72..13b0150183 100644 --- a/crates/ty_ide/src/references.rs +++ b/crates/ty_ide/src/references.rs @@ -19,6 +19,7 @@ use ruff_python_ast::{ visitor::source_order::{SourceOrderVisitor, TraversalSignal}, }; use ruff_text_size::{Ranged, TextRange}; +use ty_python_semantic::ImportAliasResolution; /// Mode for references search behavior #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -27,8 +28,10 @@ pub enum ReferencesMode { References, /// Find all references but skip the declaration ReferencesSkipDeclaration, - /// Find references for rename operations (behavior differs for imported symbols) + /// Find references for rename operations, limited to current file only Rename, + /// Find references for multi-file rename operations (searches across all files) + RenameMultiFile, /// Find references for document highlights (limits search to current file) DocumentHighlights, } @@ -42,7 +45,14 @@ pub(crate) fn references( mode: ReferencesMode, ) -> Option> { // Get the definitions for the symbol at the cursor position - let target_definitions_nav = goto_target.get_definition_targets(file, db, None)?; + + // When finding references, do not resolve any local aliases. + let target_definitions_nav = goto_target.get_definition_targets( + file, + db, + None, + ImportAliasResolution::PreserveAliases, + )?; let target_definitions: Vec = target_definitions_nav.into_iter().collect(); // Extract the target text from the goto target for fast comparison @@ -60,7 +70,12 @@ pub(crate) fn references( ); // Check if we should search across files based on the mode - let search_across_files = !matches!(mode, ReferencesMode::DocumentHighlights); + let search_across_files = matches!( + mode, + ReferencesMode::References + | ReferencesMode::ReferencesSkipDeclaration + | ReferencesMode::RenameMultiFile + ); // Check if the symbol is potentially visible outside of this module if search_across_files && is_symbol_externally_visible(goto_target) { @@ -211,6 +226,17 @@ impl<'a> SourceOrderVisitor<'a> for LocalReferencesFinder<'a> { self.check_identifier_reference(rest_name); } } + AnyNodeRef::Alias(alias) if self.should_include_declaration() => { + // Handle import alias declarations + if let Some(asname) = &alias.asname { + self.check_identifier_reference(asname); + } + // Only check the original name if it matches our target text + // This is for cases where we're renaming the imported symbol name itself + if alias.name.id == self.target_text { + self.check_identifier_reference(&alias.name); + } + } _ => {} } @@ -231,6 +257,7 @@ impl LocalReferencesFinder<'_> { ReferencesMode::References | ReferencesMode::DocumentHighlights | ReferencesMode::Rename + | ReferencesMode::RenameMultiFile ) } @@ -259,21 +286,21 @@ impl LocalReferencesFinder<'_> { let offset = covering_node.node().start(); if let Some(goto_target) = GotoTarget::from_covering_node(covering_node, offset) { - // Use the range of the covering node (the identifier) rather than the goto target - // This ensures we highlight just the identifier, not the entire expression - let range = covering_node.node().range(); - // Get the definitions for this goto target - if let Some(current_definitions_nav) = - goto_target.get_definition_targets(self.file, self.db, None) - { + if let Some(current_definitions_nav) = goto_target.get_definition_targets( + self.file, + self.db, + None, + ImportAliasResolution::PreserveAliases, + ) { let current_definitions: Vec = current_definitions_nav.into_iter().collect(); // Check if any of the current definitions match our target definitions if self.navigation_targets_match(¤t_definitions) { // Determine if this is a read or write reference let kind = self.determine_reference_kind(covering_node); - let target = ReferenceTarget::new(self.file, range, kind); + let target = + ReferenceTarget::new(self.file, covering_node.node().range(), kind); self.references.push(target); } } diff --git a/crates/ty_ide/src/rename.rs b/crates/ty_ide/src/rename.rs new file mode 100644 index 0000000000..58b44578f8 --- /dev/null +++ b/crates/ty_ide/src/rename.rs @@ -0,0 +1,641 @@ +use crate::goto::find_goto_target; +use crate::references::{ReferencesMode, references}; +use crate::{Db, ReferenceTarget}; +use ruff_db::files::File; +use ruff_text_size::{Ranged, TextSize}; +use ty_python_semantic::ImportAliasResolution; + +/// Returns the range of the symbol if it can be renamed, None if not. +pub fn can_rename(db: &dyn Db, file: File, offset: TextSize) -> Option { + let parsed = ruff_db::parsed::parsed_module(db, file); + let module = parsed.load(db); + + // Get the definitions for the symbol at the offset + let goto_target = find_goto_target(&module, offset)?; + + // Don't allow renaming of import module components + if matches!( + goto_target, + crate::goto::GotoTarget::ImportModuleComponent { .. } + ) { + return None; + } + + let current_file_in_project = is_file_in_project(db, file); + + if let Some(definition_targets) = + goto_target.get_definition_targets(file, db, None, ImportAliasResolution::PreserveAliases) + { + for target in &definition_targets { + let target_file = target.file(); + + // If definition is outside the project, refuse rename + if !is_file_in_project(db, target_file) { + return None; + } + + // If current file is not in project and any definition is outside current file, refuse rename + if !current_file_in_project && target_file != file { + return None; + } + } + } else { + // No definition targets found. This happens for keywords, so refuse rename + return None; + } + + Some(goto_target.range()) +} + +/// Perform a rename operation on the symbol at the given position. +/// Returns all locations that need to be updated with the new name. +pub fn rename( + db: &dyn Db, + file: File, + offset: TextSize, + new_name: &str, +) -> Option> { + let parsed = ruff_db::parsed::parsed_module(db, file); + let module = parsed.load(db); + + // Get the definitions for the symbol at the offset + let goto_target = find_goto_target(&module, offset)?; + + // Clients shouldn't call us with an empty new name, but just in case... + if new_name.is_empty() { + return None; + } + + // Determine if we should do a multi-file rename or single-file rename + // based on whether the current file is part of the project + let current_file_in_project = is_file_in_project(db, file); + + // Choose the appropriate rename mode: + // - If current file is in project, do multi-file rename + // - If current file is not in project, limit to single-file rename + let rename_mode = if current_file_in_project { + ReferencesMode::RenameMultiFile + } else { + ReferencesMode::Rename + }; + + // Find all references that need to be renamed + references(db, file, &goto_target, rename_mode) +} + +/// Helper function to check if a file is included in the project. +fn is_file_in_project(db: &dyn Db, file: File) -> bool { + db.project().files(db).contains(&file) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{CursorTest, IntoDiagnostic, cursor_test}; + use insta::assert_snapshot; + use ruff_db::diagnostic::{Annotation, Diagnostic, DiagnosticId, LintName, Severity, Span}; + use ruff_db::files::FileRange; + use ruff_text_size::Ranged; + + impl CursorTest { + fn prepare_rename(&self) -> String { + let Some(range) = can_rename(&self.db, self.cursor.file, self.cursor.offset) else { + return "Cannot rename".to_string(); + }; + + format!("Can rename symbol at range {range:?}") + } + + fn rename(&self, new_name: &str) -> String { + let Some(rename_results) = + rename(&self.db, self.cursor.file, self.cursor.offset, new_name) + else { + return "Cannot rename".to_string(); + }; + + if rename_results.is_empty() { + return "No locations to rename".to_string(); + } + + // Create a single diagnostic with multiple annotations + let rename_diagnostic = RenameResultSet { + locations: rename_results + .into_iter() + .map(|ref_item| FileRange::new(ref_item.file(), ref_item.range())) + .collect(), + }; + + self.render_diagnostics([rename_diagnostic]) + } + } + + struct RenameResultSet { + locations: Vec, + } + + impl IntoDiagnostic for RenameResultSet { + fn into_diagnostic(self) -> Diagnostic { + let mut main = Diagnostic::new( + DiagnosticId::Lint(LintName::of("rename")), + Severity::Info, + format!("Rename symbol (found {} locations)", self.locations.len()), + ); + + // Add the first location as primary annotation (the symbol being renamed) + if let Some(first_location) = self.locations.first() { + main.annotate(Annotation::primary( + Span::from(first_location.file()).with_range(first_location.range()), + )); + + // Add remaining locations as secondary annotations + for location in &self.locations[1..] { + main.annotate(Annotation::secondary( + Span::from(location.file()).with_range(location.range()), + )); + } + } + + main + } + } + + #[test] + fn test_prepare_rename_parameter() { + let test = cursor_test( + " +def func(value: int) -> int: + value *= 2 + return value + +value = 0 +", + ); + + assert_snapshot!(test.prepare_rename(), @"Can rename symbol at range 10..15"); + } + + #[test] + fn test_rename_parameter() { + let test = cursor_test( + " +def func(value: int) -> int: + value *= 2 + return value + +func(value=42) +", + ); + + assert_snapshot!(test.rename("number"), @r" + info[rename]: Rename symbol (found 4 locations) + --> main.py:2:10 + | + 2 | def func(value: int) -> int: + | ^^^^^ + 3 | value *= 2 + | ----- + 4 | return value + | ----- + 5 | + 6 | func(value=42) + | ----- + | + "); + } + + #[test] + fn test_rename_function() { + let test = cursor_test( + " +def func(): + pass + +result1 = func() +x = func +", + ); + + assert_snapshot!(test.rename("calculate"), @r" + info[rename]: Rename symbol (found 3 locations) + --> main.py:2:5 + | + 2 | def func(): + | ^^^^ + 3 | pass + 4 | + 5 | result1 = func() + | ---- + 6 | x = func + | ---- + | + "); + } + + #[test] + fn test_rename_class() { + let test = cursor_test( + " +class MyClass: + def __init__(self): + pass + +obj1 = MyClass() +cls = MyClass +", + ); + + assert_snapshot!(test.rename("MyNewClass"), @r" + info[rename]: Rename symbol (found 3 locations) + --> main.py:2:7 + | + 2 | class MyClass: + | ^^^^^^^ + 3 | def __init__(self): + 4 | pass + 5 | + 6 | obj1 = MyClass() + | ------- + 7 | cls = MyClass + | ------- + | + "); + } + + #[test] + fn test_rename_invalid_name() { + let test = cursor_test( + " +def func(): + pass +", + ); + + assert_snapshot!(test.rename(""), @"Cannot rename"); + assert_snapshot!(test.rename("valid_name"), @r" + info[rename]: Rename symbol (found 1 locations) + --> main.py:2:5 + | + 2 | def func(): + | ^^^^ + 3 | pass + | + "); + } + + #[test] + fn test_multi_file_function_rename() { + let test = CursorTest::builder() + .source( + "utils.py", + " +def func(x): + return x * 2 +", + ) + .source( + "module.py", + " +from utils import func + +def test(data): + return func(data) +", + ) + .source( + "app.py", + " +from utils import helper_function + +class DataProcessor: + def __init__(self): + self.multiplier = helper_function + + def process(self, value): + return helper_function(value) +", + ) + .build(); + + assert_snapshot!(test.rename("utility_function"), @r" + info[rename]: Rename symbol (found 3 locations) + --> utils.py:2:5 + | + 2 | def func(x): + | ^^^^ + 3 | return x * 2 + | + ::: module.py:2:19 + | + 2 | from utils import func + | ---- + 3 | + 4 | def test(data): + 5 | return func(data) + | ---- + | + "); + } + + #[test] + fn test_cannot_rename_import_module_component() { + // Test that we cannot rename parts of module names in import statements + let test = cursor_test( + " +import os.path +x = os.path.join('a', 'b') +", + ); + + assert_snapshot!(test.prepare_rename(), @"Cannot rename"); + } + + #[test] + fn test_cannot_rename_from_import_module_component() { + // Test that we cannot rename parts of module names in from import statements + let test = cursor_test( + " +from os.path import join +result = join('a', 'b') +", + ); + + assert_snapshot!(test.prepare_rename(), @"Cannot rename"); + } + + #[test] + fn test_cannot_rename_external_file() { + // This test verifies that we cannot rename a symbol when it's defined in a file + // that's outside the project (like a standard library function) + let test = cursor_test( + " +import os +x = os.path.join('a', 'b') +", + ); + + assert_snapshot!(test.prepare_rename(), @"Cannot rename"); + } + + #[test] + fn test_rename_alias_at_import_statement() { + let test = CursorTest::builder() + .source( + "utils.py", + " +def test(): pass +", + ) + .source( + "main.py", + " +from utils import test as test_alias +result = test_alias() +", + ) + .build(); + + assert_snapshot!(test.rename("new_alias"), @r" + info[rename]: Rename symbol (found 2 locations) + --> main.py:2:27 + | + 2 | from utils import test as test_alias + | ^^^^^^^^^^ + 3 | result = test_alias() + | ---------- + | + "); + } + + #[test] + fn test_rename_alias_at_usage_site() { + // Test renaming an alias when the cursor is on the alias in the usage statement + let test = CursorTest::builder() + .source( + "utils.py", + " +def test(): pass +", + ) + .source( + "main.py", + " +from utils import test as test_alias +result = test_alias() +", + ) + .build(); + + assert_snapshot!(test.rename("new_alias"), @r" + info[rename]: Rename symbol (found 2 locations) + --> main.py:2:27 + | + 2 | from utils import test as test_alias + | ^^^^^^^^^^ + 3 | result = test_alias() + | ---------- + | + "); + } + + #[test] + fn test_rename_across_import_chain_with_mixed_aliases() { + // Test renaming a symbol that's imported across multiple files with mixed alias patterns + // File 1 (source.py): defines the original function + // File 2 (middle.py): imports without alias from source.py + // File 3 (consumer.py): imports with alias from middle.py + let test = CursorTest::builder() + .source( + "source.py", + " +def original_function(): + return 'Hello from source' +", + ) + .source( + "middle.py", + " +from source import original_function + +def wrapper(): + return original_function() + +result = original_function() +", + ) + .source( + "consumer.py", + " +from middle import original_function as func_alias + +def process(): + return func_alias() + +value1 = func_alias() +", + ) + .build(); + + assert_snapshot!(test.rename("renamed_function"), @r" + info[rename]: Rename symbol (found 5 locations) + --> source.py:2:5 + | + 2 | def original_function(): + | ^^^^^^^^^^^^^^^^^ + 3 | return 'Hello from source' + | + ::: consumer.py:2:20 + | + 2 | from middle import original_function as func_alias + | ----------------- + 3 | + 4 | def process(): + | + ::: middle.py:2:20 + | + 2 | from source import original_function + | ----------------- + 3 | + 4 | def wrapper(): + 5 | return original_function() + | ----------------- + 6 | + 7 | result = original_function() + | ----------------- + | + "); + } + + #[test] + fn test_rename_alias_in_import_chain() { + let test = CursorTest::builder() + .source( + "file1.py", + " +def func1(): pass +", + ) + .source( + "file2.py", + " +from file1 import func1 as func2 + +func2() +", + ) + .source( + "file3.py", + " +from file2 import func2 + +class App: + def run(self): + return func2() +", + ) + .build(); + + assert_snapshot!(test.rename("new_util_name"), @r" + info[rename]: Rename symbol (found 4 locations) + --> file3.py:2:19 + | + 2 | from file2 import func2 + | ^^^^^ + 3 | + 4 | class App: + 5 | def run(self): + 6 | return func2() + | ----- + | + ::: file2.py:2:28 + | + 2 | from file1 import func1 as func2 + | ----- + 3 | + 4 | func2() + | ----- + | + "); + } + + #[test] + fn test_cannot_rename_keyword() { + // Test that we cannot rename Python keywords like "None" + let test = cursor_test( + " +def process_value(value): + if value is None: + return 'empty' + return str(value) +", + ); + + assert_snapshot!(test.prepare_rename(), @"Cannot rename"); + } + + #[test] + fn test_cannot_rename_builtin_type() { + // Test that we cannot rename Python builtin types like "int" + let test = cursor_test( + " +def convert_to_number(value): + return int(value) +", + ); + + assert_snapshot!(test.prepare_rename(), @"Cannot rename"); + } + + #[test] + fn test_rename_keyword_argument() { + // Test renaming a keyword argument and its corresponding parameter + let test = cursor_test( + " +def func(x, y=5): + return x + y + +result = func(10, y=20) +", + ); + + assert_snapshot!(test.rename("z"), @r" + info[rename]: Rename symbol (found 3 locations) + --> main.py:2:13 + | + 2 | def func(x, y=5): + | ^ + 3 | return x + y + | - + 4 | + 5 | result = func(10, y=20) + | - + | + "); + } + + #[test] + fn test_rename_parameter_with_keyword_argument() { + // Test renaming a parameter and its corresponding keyword argument + let test = cursor_test( + " +def func(x, y=5): + return x + y + +result = func(10, y=20) +", + ); + + assert_snapshot!(test.rename("z"), @r" + info[rename]: Rename symbol (found 3 locations) + --> main.py:2:13 + | + 2 | def func(x, y=5): + | ^ + 3 | return x + y + | - + 4 | + 5 | result = func(10, y=20) + | - + | + "); + } +} diff --git a/crates/ty_python_semantic/src/lib.rs b/crates/ty_python_semantic/src/lib.rs index 5cd424c82a..321d7b99a7 100644 --- a/crates/ty_python_semantic/src/lib.rs +++ b/crates/ty_python_semantic/src/lib.rs @@ -18,8 +18,8 @@ pub use python_platform::PythonPlatform; pub use semantic_model::{Completion, CompletionKind, HasType, NameKind, SemanticModel}; pub use site_packages::{PythonEnvironment, SitePackagesPaths, SysPrefixPathOrigin}; pub use types::ide_support::{ - ResolvedDefinition, definitions_for_attribute, definitions_for_imported_symbol, - definitions_for_name, map_stub_definition, + ImportAliasResolution, ResolvedDefinition, definitions_for_attribute, + definitions_for_imported_symbol, definitions_for_name, map_stub_definition, }; pub use util::diagnostics::add_inferred_python_version_hint_to_diagnostic; diff --git a/crates/ty_python_semantic/src/types/ide_support.rs b/crates/ty_python_semantic/src/types/ide_support.rs index 27e1a0fcdb..b88270a43a 100644 --- a/crates/ty_python_semantic/src/types/ide_support.rs +++ b/crates/ty_python_semantic/src/types/ide_support.rs @@ -20,7 +20,7 @@ use ruff_python_ast::name::Name; use ruff_text_size::{Ranged, TextRange}; use rustc_hash::FxHashSet; -pub use resolve_definition::{ResolvedDefinition, map_stub_definition}; +pub use resolve_definition::{ImportAliasResolution, ResolvedDefinition, map_stub_definition}; use resolve_definition::{find_symbol_in_scope, resolve_definition}; pub(crate) fn all_declarations_and_bindings<'db>( @@ -517,7 +517,12 @@ pub fn definitions_for_name<'db>( let mut resolved_definitions = Vec::new(); for definition in &all_definitions { - let resolved = resolve_definition(db, *definition, Some(name_str)); + let resolved = resolve_definition( + db, + *definition, + Some(name_str), + ImportAliasResolution::ResolveAliases, + ); resolved_definitions.extend(resolved); } @@ -528,7 +533,14 @@ pub fn definitions_for_name<'db>( }; find_symbol_in_scope(db, builtins_scope, name_str) .into_iter() - .flat_map(|def| resolve_definition(db, def, Some(name_str))) + .flat_map(|def| { + resolve_definition( + db, + def, + Some(name_str), + ImportAliasResolution::ResolveAliases, + ) + }) .collect() } else { resolved_definitions @@ -577,7 +589,12 @@ pub fn definitions_for_attribute<'db>( if let Some(module_file) = module_literal.module(db).file(db) { let module_scope = global_scope(db, module_file); for def in find_symbol_in_scope(db, module_scope, name_str) { - resolved.extend(resolve_definition(db, def, Some(name_str))); + resolved.extend(resolve_definition( + db, + def, + Some(name_str), + ImportAliasResolution::ResolveAliases, + )); } } continue; @@ -613,7 +630,12 @@ pub fn definitions_for_attribute<'db>( // Check declarations first for decl in use_def.all_reachable_symbol_declarations(place_id) { if let Some(def) = decl.declaration.definition() { - resolved.extend(resolve_definition(db, def, Some(name_str))); + resolved.extend(resolve_definition( + db, + def, + Some(name_str), + ImportAliasResolution::ResolveAliases, + )); break 'scopes; } } @@ -621,7 +643,12 @@ pub fn definitions_for_attribute<'db>( // If no declarations found, check bindings for binding in use_def.all_reachable_symbol_bindings(place_id) { if let Some(def) = binding.binding.definition() { - resolved.extend(resolve_definition(db, def, Some(name_str))); + resolved.extend(resolve_definition( + db, + def, + Some(name_str), + ImportAliasResolution::ResolveAliases, + )); break 'scopes; } } @@ -640,7 +667,12 @@ pub fn definitions_for_attribute<'db>( // Check declarations first for decl in use_def.all_reachable_member_declarations(place_id) { if let Some(def) = decl.declaration.definition() { - resolved.extend(resolve_definition(db, def, Some(name_str))); + resolved.extend(resolve_definition( + db, + def, + Some(name_str), + ImportAliasResolution::ResolveAliases, + )); break 'scopes; } } @@ -648,7 +680,12 @@ pub fn definitions_for_attribute<'db>( // If no declarations found, check bindings for binding in use_def.all_reachable_member_bindings(place_id) { if let Some(def) = binding.binding.definition() { - resolved.extend(resolve_definition(db, def, Some(name_str))); + resolved.extend(resolve_definition( + db, + def, + Some(name_str), + ImportAliasResolution::ResolveAliases, + )); break 'scopes; } } @@ -715,11 +752,15 @@ pub fn definitions_for_keyword_argument<'db>( /// Find the definitions for a symbol imported via `from x import y as z` statement. /// This function handles the case where the cursor is on the original symbol name `y`. /// Returns the same definitions as would be found for the alias `z`. +/// The `alias_resolution` parameter controls whether symbols imported with local import +/// aliases (like "x" in "from a import b as x") are resolved to their targets or kept +/// as aliases. pub fn definitions_for_imported_symbol<'db>( db: &'db dyn Db, file: File, import_node: &ast::StmtImportFrom, symbol_name: &str, + alias_resolution: ImportAliasResolution, ) -> Vec> { let mut visited = FxHashSet::default(); resolve_definition::resolve_from_import_definitions( @@ -728,6 +769,7 @@ pub fn definitions_for_imported_symbol<'db>( import_node, symbol_name, &mut visited, + alias_resolution, ) } @@ -821,6 +863,15 @@ mod resolve_definition { //! "resolved definitions". This is done recursively to find the original //! definition targeted by the import. + /// Controls whether local import aliases should be resolved to their targets or returned as-is. + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum ImportAliasResolution { + /// Resolve import aliases to their original definitions + ResolveAliases, + /// Keep import aliases as-is, don't resolve to original definitions + PreserveAliases, + } + use indexmap::IndexSet; use ruff_db::files::{File, FileRange}; use ruff_db::parsed::{ParsedModuleRef, parsed_module}; @@ -865,9 +916,16 @@ mod resolve_definition { db: &'db dyn Db, definition: Definition<'db>, symbol_name: Option<&str>, + alias_resolution: ImportAliasResolution, ) -> Vec> { let mut visited = FxHashSet::default(); - let resolved = resolve_definition_recursive(db, definition, &mut visited, symbol_name); + let resolved = resolve_definition_recursive( + db, + definition, + &mut visited, + symbol_name, + alias_resolution, + ); // If resolution failed, return the original definition as fallback if resolved.is_empty() { @@ -883,6 +941,7 @@ mod resolve_definition { definition: Definition<'db>, visited: &mut FxHashSet>, symbol_name: Option<&str>, + alias_resolution: ImportAliasResolution, ) -> Vec> { // Prevent infinite recursion if there are circular imports if visited.contains(&definition) { @@ -928,7 +987,14 @@ mod resolve_definition { // For `ImportFrom`, we need to resolve the original imported symbol name // (alias.name), not the local alias (symbol_name) - resolve_from_import_definitions(db, file, import_node, &alias.name, visited) + resolve_from_import_definitions( + db, + file, + import_node, + &alias.name, + visited, + alias_resolution, + ) } // For star imports, try to resolve to the specific symbol being accessed @@ -939,7 +1005,14 @@ mod resolve_definition { // If we have a symbol name, use the helper to resolve it in the target module if let Some(symbol_name) = symbol_name { - resolve_from_import_definitions(db, file, import_node, symbol_name, visited) + resolve_from_import_definitions( + db, + file, + import_node, + symbol_name, + visited, + alias_resolution, + ) } else { // No symbol context provided, can't resolve star import Vec::new() @@ -958,7 +1031,21 @@ mod resolve_definition { import_node: &ast::StmtImportFrom, symbol_name: &str, visited: &mut FxHashSet>, + alias_resolution: ImportAliasResolution, ) -> Vec> { + if alias_resolution == ImportAliasResolution::PreserveAliases { + for alias in &import_node.names { + if let Some(asname) = &alias.asname { + if asname.as_str() == symbol_name { + return vec![ResolvedDefinition::FileWithRange(FileRange::new( + file, + asname.range, + ))]; + } + } + } + } + // Resolve the target module file let module_file = { // Resolve the module being imported from (handles both relative and absolute imports) @@ -987,7 +1074,13 @@ mod resolve_definition { } else { let mut resolved_definitions = Vec::new(); for def in definitions_in_module { - let resolved = resolve_definition_recursive(db, def, visited, Some(symbol_name)); + let resolved = resolve_definition_recursive( + db, + def, + visited, + Some(symbol_name), + alias_resolution, + ); resolved_definitions.extend(resolved); } resolved_definitions diff --git a/crates/ty_server/src/capabilities.rs b/crates/ty_server/src/capabilities.rs index 62d87c109c..8e73cdde83 100644 --- a/crates/ty_server/src/capabilities.rs +++ b/crates/ty_server/src/capabilities.rs @@ -1,11 +1,11 @@ use lsp_types::{ ClientCapabilities, CompletionOptions, DeclarationCapability, DiagnosticOptions, DiagnosticServerCapabilities, HoverProviderCapability, InlayHintOptions, - InlayHintServerCapabilities, MarkupKind, OneOf, SelectionRangeProviderCapability, - SemanticTokensFullOptions, SemanticTokensLegend, SemanticTokensOptions, - SemanticTokensServerCapabilities, ServerCapabilities, SignatureHelpOptions, - TextDocumentSyncCapability, TextDocumentSyncKind, TextDocumentSyncOptions, - TypeDefinitionProviderCapability, WorkDoneProgressOptions, + InlayHintServerCapabilities, MarkupKind, OneOf, RenameOptions, + SelectionRangeProviderCapability, SemanticTokensFullOptions, SemanticTokensLegend, + SemanticTokensOptions, SemanticTokensServerCapabilities, ServerCapabilities, + SignatureHelpOptions, TextDocumentSyncCapability, TextDocumentSyncKind, + TextDocumentSyncOptions, TypeDefinitionProviderCapability, WorkDoneProgressOptions, }; use crate::PositionEncoding; @@ -289,6 +289,10 @@ pub(crate) fn server_capabilities( definition_provider: Some(OneOf::Left(true)), declaration_provider: Some(DeclarationCapability::Simple(true)), references_provider: Some(OneOf::Left(true)), + rename_provider: Some(OneOf::Right(RenameOptions { + prepare_provider: Some(true), + work_done_progress_options: WorkDoneProgressOptions::default(), + })), document_highlight_provider: Some(OneOf::Left(true)), hover_provider: Some(HoverProviderCapability::Simple(true)), signature_help_provider: Some(SignatureHelpOptions { diff --git a/crates/ty_server/src/server/api.rs b/crates/ty_server/src/server/api.rs index 51a7d30505..af963853df 100644 --- a/crates/ty_server/src/server/api.rs +++ b/crates/ty_server/src/server/api.rs @@ -80,6 +80,12 @@ pub(super) fn request(req: server::Request) -> Task { requests::SignatureHelpRequestHandler::METHOD => background_document_request_task::< requests::SignatureHelpRequestHandler, >(req, BackgroundSchedule::Worker), + requests::PrepareRenameRequestHandler::METHOD => background_document_request_task::< + requests::PrepareRenameRequestHandler, + >(req, BackgroundSchedule::Worker), + requests::RenameRequestHandler::METHOD => background_document_request_task::< + requests::RenameRequestHandler, + >(req, BackgroundSchedule::Worker), requests::CompletionRequestHandler::METHOD => background_document_request_task::< requests::CompletionRequestHandler, >( diff --git a/crates/ty_server/src/server/api/requests.rs b/crates/ty_server/src/server/api/requests.rs index f584018564..0be690a38e 100644 --- a/crates/ty_server/src/server/api/requests.rs +++ b/crates/ty_server/src/server/api/requests.rs @@ -8,6 +8,8 @@ mod goto_references; mod goto_type_definition; mod hover; mod inlay_hints; +mod prepare_rename; +mod rename; mod selection_range; mod semantic_tokens; mod semantic_tokens_range; @@ -26,6 +28,8 @@ pub(super) use goto_references::ReferencesRequestHandler; pub(super) use goto_type_definition::GotoTypeDefinitionRequestHandler; pub(super) use hover::HoverRequestHandler; pub(super) use inlay_hints::InlayHintRequestHandler; +pub(super) use prepare_rename::PrepareRenameRequestHandler; +pub(super) use rename::RenameRequestHandler; pub(super) use selection_range::SelectionRangeRequestHandler; pub(super) use semantic_tokens::SemanticTokensRequestHandler; pub(super) use semantic_tokens_range::SemanticTokensRangeRequestHandler; diff --git a/crates/ty_server/src/server/api/requests/prepare_rename.rs b/crates/ty_server/src/server/api/requests/prepare_rename.rs new file mode 100644 index 0000000000..c63f06bfa7 --- /dev/null +++ b/crates/ty_server/src/server/api/requests/prepare_rename.rs @@ -0,0 +1,60 @@ +use std::borrow::Cow; + +use lsp_types::request::PrepareRenameRequest; +use lsp_types::{PrepareRenameResponse, TextDocumentPositionParams, Url}; +use ruff_db::source::{line_index, source_text}; +use ty_ide::can_rename; +use ty_project::ProjectDatabase; + +use crate::document::{PositionExt, ToRangeExt}; +use crate::server::api::traits::{ + BackgroundDocumentRequestHandler, RequestHandler, RetriableRequestHandler, +}; +use crate::session::DocumentSnapshot; +use crate::session::client::Client; + +pub(crate) struct PrepareRenameRequestHandler; + +impl RequestHandler for PrepareRenameRequestHandler { + type RequestType = PrepareRenameRequest; +} + +impl BackgroundDocumentRequestHandler for PrepareRenameRequestHandler { + fn document_url(params: &TextDocumentPositionParams) -> Cow { + Cow::Borrowed(¶ms.text_document.uri) + } + + fn run_with_snapshot( + db: &ProjectDatabase, + snapshot: &DocumentSnapshot, + _client: &Client, + params: TextDocumentPositionParams, + ) -> crate::server::Result> { + if snapshot + .workspace_settings() + .is_language_services_disabled() + { + return Ok(None); + } + + let Some(file) = snapshot.file(db) else { + return Ok(None); + }; + + let source = source_text(db, file); + let line_index = line_index(db, file); + let offset = params + .position + .to_text_size(&source, &line_index, snapshot.encoding()); + + let Some(range) = can_rename(db, file, offset) else { + return Ok(None); + }; + + let lsp_range = range.to_lsp_range(&source, &line_index, snapshot.encoding()); + + Ok(Some(PrepareRenameResponse::Range(lsp_range))) + } +} + +impl RetriableRequestHandler for PrepareRenameRequestHandler {} diff --git a/crates/ty_server/src/server/api/requests/rename.rs b/crates/ty_server/src/server/api/requests/rename.rs new file mode 100644 index 0000000000..9a8af72df4 --- /dev/null +++ b/crates/ty_server/src/server/api/requests/rename.rs @@ -0,0 +1,83 @@ +use std::borrow::Cow; +use std::collections::HashMap; + +use lsp_types::request::Rename; +use lsp_types::{RenameParams, TextEdit, Url, WorkspaceEdit}; +use ruff_db::source::{line_index, source_text}; +use ty_ide::rename; +use ty_project::ProjectDatabase; + +use crate::document::{PositionExt, ToLink}; +use crate::server::api::traits::{ + BackgroundDocumentRequestHandler, RequestHandler, RetriableRequestHandler, +}; +use crate::session::DocumentSnapshot; +use crate::session::client::Client; + +pub(crate) struct RenameRequestHandler; + +impl RequestHandler for RenameRequestHandler { + type RequestType = Rename; +} + +impl BackgroundDocumentRequestHandler for RenameRequestHandler { + fn document_url(params: &RenameParams) -> Cow { + Cow::Borrowed(¶ms.text_document_position.text_document.uri) + } + + fn run_with_snapshot( + db: &ProjectDatabase, + snapshot: &DocumentSnapshot, + _client: &Client, + params: RenameParams, + ) -> crate::server::Result> { + if snapshot + .workspace_settings() + .is_language_services_disabled() + { + return Ok(None); + } + + let Some(file) = snapshot.file(db) else { + return Ok(None); + }; + + let source = source_text(db, file); + let line_index = line_index(db, file); + let offset = params.text_document_position.position.to_text_size( + &source, + &line_index, + snapshot.encoding(), + ); + + let Some(rename_results) = rename(db, file, offset, ¶ms.new_name) else { + return Ok(None); + }; + + // Group text edits by file + let mut changes: HashMap> = HashMap::new(); + + for reference in rename_results { + if let Some(location) = reference.to_location(db, snapshot.encoding()) { + let edit = TextEdit { + range: location.range, + new_text: params.new_name.clone(), + }; + + changes.entry(location.uri).or_default().push(edit); + } + } + + if changes.is_empty() { + return Ok(None); + } + + Ok(Some(WorkspaceEdit { + changes: Some(changes), + document_changes: None, + change_annotations: None, + })) + } +} + +impl RetriableRequestHandler for RenameRequestHandler {} diff --git a/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization.snap b/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization.snap index 0c00d640c8..d8ebdd77a4 100644 --- a/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization.snap +++ b/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization.snap @@ -31,6 +31,9 @@ expression: initialization_result "documentHighlightProvider": true, "documentSymbolProvider": true, "workspaceSymbolProvider": true, + "renameProvider": { + "prepareProvider": true + }, "declarationProvider": true, "semanticTokensProvider": { "legend": { diff --git a/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization_with_workspace.snap b/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization_with_workspace.snap index 0c00d640c8..d8ebdd77a4 100644 --- a/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization_with_workspace.snap +++ b/crates/ty_server/tests/e2e/snapshots/e2e__initialize__initialization_with_workspace.snap @@ -31,6 +31,9 @@ expression: initialization_result "documentHighlightProvider": true, "documentSymbolProvider": true, "workspaceSymbolProvider": true, + "renameProvider": { + "prepareProvider": true + }, "declarationProvider": true, "semanticTokensProvider": { "legend": {