diff --git a/Cargo.lock b/Cargo.lock index 0a261bf8a3..fc03d1e3cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3172,6 +3172,7 @@ name = "ruff_python_importer" version = "0.0.0" dependencies = [ "anyhow", + "insta", "ruff_diagnostics", "ruff_python_ast", "ruff_python_codegen", diff --git a/crates/ruff_python_importer/Cargo.toml b/crates/ruff_python_importer/Cargo.toml index 7d1d39ebc2..96070a2400 100644 --- a/crates/ruff_python_importer/Cargo.toml +++ b/crates/ruff_python_importer/Cargo.toml @@ -22,6 +22,7 @@ ruff_text_size = { workspace = true } anyhow = { workspace = true } [dev-dependencies] +insta = { workspace = true } [features] diff --git a/crates/ruff_python_importer/src/insertion.rs b/crates/ruff_python_importer/src/insertion.rs index 2a9fbb0813..d9a0db538e 100644 --- a/crates/ruff_python_importer/src/insertion.rs +++ b/crates/ruff_python_importer/src/insertion.rs @@ -128,6 +128,57 @@ impl<'a> Insertion<'a> { } } + /// Create an [`Insertion`] to insert an additional member to import + /// into a `from import member1, member2, ...` statement. + /// + /// For example, given the following code: + /// + /// ```python + /// """Hello, world!""" + /// + /// from collections import Counter + /// + /// + /// def foo(): + /// pass + /// ``` + /// + /// The insertion returned will begin after `Counter` but before the + /// newline terminator. Callers can then call [`Insertion::into_edit`] + /// with the additional member to add. A comma delimiter is handled + /// automatically. + /// + /// The statement itself is assumed to be at the top-level of the module. + /// + /// This returns `None` when `stmt` isn't a `from ... import ...` + /// statement. + pub fn existing_import(stmt: &Stmt, tokens: &Tokens) -> Option> { + let Stmt::ImportFrom(ref import_from) = *stmt else { + return None; + }; + if let Some(at) = import_from.names.last().map(Ranged::end) { + return Some(Insertion::inline(", ", at, "")); + } + // Our AST can deal with partial `from ... import` + // statements, so we might not have any members + // yet. In this case, we don't need the comma. + // + // ... however, unless we can be certain that + // inserting this name leads to a valid AST, we + // give up. + let at = import_from.end(); + if !matches!( + tokens + .before(at) + .last() + .map(ruff_python_parser::Token::kind), + Some(TokenKind::Import) + ) { + return None; + } + Some(Insertion::inline(" ", at, "")) + } + /// Create an [`Insertion`] to insert (e.g.) an import statement at the start of a given /// block, along with a prefix and suffix to use for the insertion. /// @@ -314,7 +365,7 @@ mod tests { use ruff_python_codegen::Stylist; use ruff_python_parser::parse_module; use ruff_source_file::LineEnding; - use ruff_text_size::TextSize; + use ruff_text_size::{Ranged, TextSize}; use super::Insertion; @@ -473,4 +524,286 @@ if True: Insertion::indented("", TextSize::from(9), "\n", " ") ); } + + #[test] + fn existing_import_works() { + fn snapshot(content: &str, member: &str) -> String { + let parsed = parse_module(content).unwrap(); + let edit = Insertion::existing_import(parsed.suite().first().unwrap(), parsed.tokens()) + .unwrap() + .into_edit(member); + let insert_text = edit.content().expect("edit should be non-empty"); + + let mut content = content.to_string(); + content.replace_range(edit.range().to_std_range(), insert_text); + content + } + + let source = r#" +from collections import Counter +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import Counter, defaultdict + ", + ); + + let source = r#" +from collections import Counter, OrderedDict +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import Counter, OrderedDict, defaultdict + ", + ); + + let source = r#" +from collections import (Counter) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @"from collections import (Counter, defaultdict)", + ); + + let source = r#" +from collections import (Counter, OrderedDict) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @"from collections import (Counter, OrderedDict, defaultdict)", + ); + + let source = r#" +from collections import (Counter,) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @"from collections import (Counter, defaultdict,)", + ); + + let source = r#" +from collections import (Counter, OrderedDict,) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @"from collections import (Counter, OrderedDict, defaultdict,)", + ); + + let source = r#" +from collections import ( + Counter +) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import ( + Counter, defaultdict + ) + ", + ); + + let source = r#" +from collections import ( + Counter, +) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import ( + Counter, defaultdict, + ) + ", + ); + + let source = r#" +from collections import ( + Counter, + OrderedDict +) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import ( + Counter, + OrderedDict, defaultdict + ) + ", + ); + + let source = r#" +from collections import ( + Counter, + OrderedDict, +) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import ( + Counter, + OrderedDict, defaultdict, + ) + ", + ); + + let source = r#" +from collections import \ + Counter +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import \ + Counter, defaultdict + ", + ); + + let source = r#" +from collections import \ + Counter, OrderedDict +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import \ + Counter, OrderedDict, defaultdict + ", + ); + + let source = r#" +from collections import \ + Counter, \ + OrderedDict +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import \ + Counter, \ + OrderedDict, defaultdict + ", + ); + + /* + from collections import ( + Collector # comment + ) + + from collections import ( + Collector, # comment + ) + + from collections import ( + Collector # comment + , + ) + + from collections import ( + Collector + # comment + , + ) + */ + + let source = r#" +from collections import ( + Counter # comment +) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import ( + Counter, defaultdict # comment + ) + ", + ); + + let source = r#" +from collections import ( + Counter, # comment +) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import ( + Counter, defaultdict, # comment + ) + ", + ); + + let source = r#" +from collections import ( + Counter # comment + , +) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import ( + Counter, defaultdict # comment + , + ) + ", + ); + + let source = r#" +from collections import ( + Counter + # comment + , +) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import ( + Counter, defaultdict + # comment + , + ) + ", + ); + + let source = r#" +from collections import ( + # comment 1 + Counter # comment 2 + # comment 3 +) +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @r" + from collections import ( + # comment 1 + Counter, defaultdict # comment 2 + # comment 3 + ) + ", + ); + + let source = r#" +from collections import Counter # comment +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @"from collections import Counter, defaultdict # comment", + ); + + let source = r#" +from collections import Counter, OrderedDict # comment +"#; + insta::assert_snapshot!( + snapshot(source, "defaultdict"), + @"from collections import Counter, OrderedDict, defaultdict # comment", + ); + } }