Compare commits

..

5 Commits

Author SHA1 Message Date
Zanie Blue
e1c6e2e26a Add test case for missing VIRTUAL_ENV 2025-12-17 15:27:47 -06:00
Bhuminjay Soni
52849a5e68 [syntax-errors] Annotated name cannot be global (#20868)
## Summary

<!-- What's the purpose of the change? What does it do, and why? -->
This PR implements a new semantic syntax error where annotated name
can't be global
example
```
x: int = 1

def f():
    global x
    x: str = "foo"  # SyntaxError: annotated name 'x' can't be global
 ```

## Test Plan

<!-- How was it tested? -->
I have written tests as directed in #17412

---------

Signed-off-by: 11happy <soni5happy@gmail.com>
Signed-off-by: 11happy <bhuminjaysoni@gmail.com>
Co-authored-by: Brent Westbrook <brentrwestbrook@gmail.com>
2025-12-17 08:39:47 -05:00
David Peter
2a61fe2353 [ty] Handle field specifier functions that accept **kwargs and recognize metaclass-based transformers as instances of DataclassInstance (#22018)
## Summary

This contains two bug fixes:

- [Handle field specifier functions that accept
`**kwargs`](ad6918d505)
- [Recognize metaclass-based transformers as instances of
`DataclassInstance`](1a8e29b23c)

closes https://github.com/astral-sh/ty/issues/1987

## Test Plan

* New Markdown tests
* Made sure that the example in 1987 checks without errors
2025-12-17 14:22:16 +01:00
Alex Waygood
764ad8b29b [ty] Improve disambiguation of types in many cases (#22019) 2025-12-17 11:41:07 +00:00
mahiro
85af715880 Fix playground Share button showing "Copied!" before clipboard copy completes (#21942)
Co-authored-by: Micha Reiser <micha@reiser.io>
2025-12-17 12:16:01 +01:00
33 changed files with 815 additions and 275 deletions

View File

@@ -1,5 +1,6 @@
use glob::PatternError;
use ruff_notebook::{Notebook, NotebookError};
use rustc_hash::FxHashMap;
use std::panic::RefUnwindSafe;
use std::sync::{Arc, Mutex};
@@ -20,18 +21,47 @@ use super::walk_directory::WalkDirectoryBuilder;
///
/// ## Warning
/// Don't use this system for production code. It's intended for testing only.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct TestSystem {
inner: Arc<dyn WritableSystem + RefUnwindSafe + Send + Sync>,
/// Environment variable overrides. If a key is present here, it takes precedence
/// over the inner system's environment variables.
env_overrides: Arc<Mutex<FxHashMap<String, Option<String>>>>,
}
impl Clone for TestSystem {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
env_overrides: self.env_overrides.clone(),
}
}
}
impl TestSystem {
pub fn new(inner: impl WritableSystem + RefUnwindSafe + Send + Sync + 'static) -> Self {
Self {
inner: Arc::new(inner),
env_overrides: Arc::new(Mutex::new(FxHashMap::default())),
}
}
/// Sets an environment variable override. This takes precedence over the inner system.
pub fn set_env_var(&self, name: impl Into<String>, value: impl Into<String>) {
self.env_overrides
.lock()
.unwrap()
.insert(name.into(), Some(value.into()));
}
/// Removes an environment variable override, making it appear as not set.
pub fn remove_env_var(&self, name: impl Into<String>) {
self.env_overrides
.lock()
.unwrap()
.insert(name.into(), None);
}
/// Returns the [`InMemorySystem`].
///
/// ## Panics
@@ -147,6 +177,18 @@ impl System for TestSystem {
self.system().case_sensitivity()
}
fn env_var(&self, name: &str) -> std::result::Result<String, std::env::VarError> {
// Check overrides first
if let Some(override_value) = self.env_overrides.lock().unwrap().get(name) {
return match override_value {
Some(value) => Ok(value.clone()),
None => Err(std::env::VarError::NotPresent),
};
}
// Fall back to inner system
self.system().env_var(name)
}
fn dyn_clone(&self) -> Box<dyn System> {
Box::new(self.clone())
}
@@ -156,6 +198,7 @@ impl Default for TestSystem {
fn default() -> Self {
Self {
inner: Arc::new(InMemorySystem::default()),
env_overrides: Arc::new(Mutex::new(FxHashMap::default())),
}
}
}

View File

@@ -0,0 +1,38 @@
a: int = 1
def f1():
global a
a: str = "foo" # error
b: int = 1
def outer():
def inner():
global b
b: str = "nested" # error
c: int = 1
def f2():
global c
c: list[str] = [] # error
d: int = 1
def f3():
global d
d: str # error
e: int = 1
def f4():
e: str = "happy" # okay
global f
f: int = 1 # okay
g: int = 1
global g # error
class C:
x: str
global x # error
class D:
global x # error
x: str

View File

@@ -1001,6 +1001,7 @@ mod tests {
#[test_case(Path::new("write_to_debug.py"), PythonVersion::PY310)]
#[test_case(Path::new("invalid_expression.py"), PythonVersion::PY312)]
#[test_case(Path::new("global_parameter.py"), PythonVersion::PY310)]
#[test_case(Path::new("annotated_global.py"), PythonVersion::PY314)]
fn test_semantic_errors(path: &Path, python_version: PythonVersion) -> Result<()> {
let snapshot = format!(
"semantic_syntax_error_{}_{}",

View File

@@ -0,0 +1,74 @@
---
source: crates/ruff_linter/src/linter.rs
---
invalid-syntax: annotated name `a` can't be global
--> resources/test/fixtures/semantic_errors/annotated_global.py:4:5
|
2 | def f1():
3 | global a
4 | a: str = "foo" # error
| ^
5 |
6 | b: int = 1
|
invalid-syntax: annotated name `b` can't be global
--> resources/test/fixtures/semantic_errors/annotated_global.py:10:9
|
8 | def inner():
9 | global b
10 | b: str = "nested" # error
| ^
11 |
12 | c: int = 1
|
invalid-syntax: annotated name `c` can't be global
--> resources/test/fixtures/semantic_errors/annotated_global.py:15:5
|
13 | def f2():
14 | global c
15 | c: list[str] = [] # error
| ^
16 |
17 | d: int = 1
|
invalid-syntax: annotated name `d` can't be global
--> resources/test/fixtures/semantic_errors/annotated_global.py:20:5
|
18 | def f3():
19 | global d
20 | d: str # error
| ^
21 |
22 | e: int = 1
|
invalid-syntax: annotated name `g` can't be global
--> resources/test/fixtures/semantic_errors/annotated_global.py:29:1
|
27 | f: int = 1 # okay
28 |
29 | g: int = 1
| ^
30 | global g # error
|
invalid-syntax: annotated name `x` can't be global
--> resources/test/fixtures/semantic_errors/annotated_global.py:33:5
|
32 | class C:
33 | x: str
| ^
34 | global x # error
|
invalid-syntax: annotated name `x` can't be global
--> resources/test/fixtures/semantic_errors/annotated_global.py:38:5
|
36 | class D:
37 | global x # error
38 | x: str
| ^
|

View File

@@ -272,7 +272,9 @@ impl SemanticSyntaxChecker {
fn check_annotation<Ctx: SemanticSyntaxContext>(stmt: &ast::Stmt, ctx: &Ctx) {
match stmt {
Stmt::AnnAssign(ast::StmtAnnAssign { annotation, .. }) => {
Stmt::AnnAssign(ast::StmtAnnAssign {
target, annotation, ..
}) => {
if ctx.python_version() > PythonVersion::PY313 {
// test_ok valid_annotation_py313
// # parse_options: {"target-version": "3.13"}
@@ -297,6 +299,18 @@ impl SemanticSyntaxChecker {
};
visitor.visit_expr(annotation);
}
if let Expr::Name(ast::ExprName { id, .. }) = target.as_ref() {
if let Some(global_stmt) = ctx.global(id.as_str()) {
let global_start = global_stmt.start();
if !ctx.in_module_scope() || target.start() < global_start {
Self::add_error(
ctx,
SemanticSyntaxErrorKind::AnnotatedGlobal(id.to_string()),
target.range(),
);
}
}
}
}
Stmt::FunctionDef(ast::StmtFunctionDef {
type_params,

View File

@@ -2703,3 +2703,51 @@ fn pythonpath_multiple_dirs_is_respected() -> anyhow::Result<()> {
Ok(())
}
/// Test behavior when `VIRTUAL_ENV` is set but points to a non-existent path.
#[test]
fn missing_virtual_env() -> anyhow::Result<()> {
let working_venv_package1_path = if cfg!(windows) {
"project/.venv/Lib/site-packages/package1/__init__.py"
} else {
"project/.venv/lib/python3.13/site-packages/package1/__init__.py"
};
let case = CliTest::with_files([
(
"project/test.py",
r#"
from package1 import WorkingVenv
"#,
),
(
"project/.venv/pyvenv.cfg",
r#"
home = ./
"#,
),
(
working_venv_package1_path,
r#"
class WorkingVenv: ...
"#,
),
])?;
assert_cmd_snapshot!(case.command()
.current_dir(case.root().join("project"))
.env("VIRTUAL_ENV", case.root().join("nonexistent-venv")), @r"
success: false
exit_code: 2
----- stdout -----
----- stderr -----
ty failed
Cause: Failed to discover local Python environment
Cause: Invalid `VIRTUAL_ENV` environment variable `<temp_dir>/nonexistent-venv`: does not point to a directory on disk
Cause: No such file or directory (os error 2)
");
Ok(())
}

View File

@@ -1208,7 +1208,7 @@ def _(flag: bool):
reveal_type(C1.y) # revealed: int | str
C1.y = 100
# error: [invalid-assignment] "Object of type `Literal["problematic"]` is not assignable to attribute `y` on type `<class 'C1'> | <class 'C1'>`"
# error: [invalid-assignment] "Object of type `Literal["problematic"]` is not assignable to attribute `y` on type `<class 'mdtest_snippet.<locals of function '_'>.C1 @ src/mdtest_snippet.py:3'> | <class 'mdtest_snippet.<locals of function '_'>.C1 @ src/mdtest_snippet.py:8'>`"
C1.y = "problematic"
class C2:

View File

@@ -643,6 +643,91 @@ reveal_type(Person.__init__) # revealed: (self: Person, name: str) -> None
Person(name="Alice")
```
### Field specifiers using `**kwargs`
Some field specifiers may use `**kwargs` to pass through standard parameters like `default`,
`default_factory`, `init`, `kw_only`, and `alias`. This section tests that all these parameters work
correctly when passed via `**kwargs` for all three kinds of transformers.
#### Function-based transformer
```py
from typing import Any
from typing_extensions import dataclass_transform
def field(**kwargs: Any) -> Any: ...
@dataclass_transform(field_specifiers=(field,))
def create_model[T](cls: type[T]) -> type[T]:
return cls
@create_model
class Person:
id: int = field(init=False)
name: str
age: int = field(default=0)
tags: list[str] = field(default_factory=list)
email: str = field(kw_only=True)
internal_notes: str = field(alias="notes")
# revealed: (self: Person, name: str, age: int = ..., tags: list[str] = ..., notes: str, *, email: str) -> None
reveal_type(Person.__init__)
Person("Alice", 30, [], "some notes", email="alice@example.com")
Person("Bob", email="bob@example.com", notes="other notes")
```
#### Metaclass-based transformer
```py
from typing import Any
from typing_extensions import dataclass_transform
def field(**kwargs: Any) -> Any: ...
@dataclass_transform(field_specifiers=(field,))
class ModelMeta(type): ...
class ModelBase(metaclass=ModelMeta): ...
class Person(ModelBase):
id: int = field(init=False)
name: str
age: int = field(default=0)
tags: list[str] = field(default_factory=list)
email: str = field(kw_only=True)
internal_notes: str = field(alias="notes")
# revealed: (self: Person, name: str, age: int = ..., tags: list[str] = ..., notes: str, *, email: str) -> None
reveal_type(Person.__init__)
Person("Alice", 30, [], "some notes", email="alice@example.com")
Person("Bob", email="bob@example.com", notes="other notes")
```
#### Base-class-based transformer
```py
from typing import Any
from typing_extensions import dataclass_transform
def field(**kwargs: Any) -> Any: ...
@dataclass_transform(field_specifiers=(field,))
class ModelBase: ...
class Person(ModelBase):
id: int = field(init=False)
name: str
age: int = field(default=0)
tags: list[str] = field(default_factory=list)
email: str = field(kw_only=True)
internal_notes: str = field(alias="notes")
# revealed: (self: Person, name: str, age: int = ..., tags: list[str] = ..., notes: str, *, email: str) -> None
reveal_type(Person.__init__)
Person("Alice", 30, [], "some notes", email="alice@example.com")
Person("Bob", email="bob@example.com", notes="other notes")
```
### Support for `alias`
The `alias` parameter in field specifiers allows providing an alternative name for the parameter in
@@ -868,4 +953,83 @@ reveal_type(t.key) # revealed: int
reveal_type(t.name) # revealed: str
```
## `__dataclass_fields__` and `DataclassInstance` protocol
Classes created via `dataclass_transform` should have `__dataclass_fields__` and
`__dataclass_params__` attributes, allowing them to satisfy the `DataclassInstance` protocol. This
enables use of `dataclasses.fields`, `dataclasses.asdict`, `dataclasses.replace`, etc.
### Function-based transformer
```py
from dataclasses import fields, asdict, replace, Field
from typing import dataclass_transform, Any
@dataclass_transform()
def create_model[T](cls: type[T]) -> type[T]:
return cls
@create_model
class Person:
name: str
age: int
p = Person("Alice", 30)
reveal_type(Person.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(p.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(fields(Person)) # revealed: tuple[Field[Any], ...]
reveal_type(asdict(p)) # revealed: dict[str, Any]
reveal_type(replace(p, name="Bob")) # revealed: Person
```
### Metaclass-based transformer
```py
from dataclasses import fields, asdict, replace, Field
from typing import dataclass_transform, Any
@dataclass_transform()
class ModelMeta(type): ...
class ModelBase(metaclass=ModelMeta): ...
class Person(ModelBase):
name: str
age: int
p = Person("Alice", 30)
reveal_type(Person.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(p.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(fields(Person)) # revealed: tuple[Field[Any], ...]
reveal_type(asdict(p)) # revealed: dict[str, Any]
reveal_type(replace(p, name="Bob")) # revealed: Person
```
### Base-class-based transformer
```py
from dataclasses import fields, asdict, replace, Field
from typing import dataclass_transform, Any
@dataclass_transform()
class ModelBase: ...
class Person(ModelBase):
name: str
age: int
p = Person("Alice", 30)
reveal_type(Person.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(p.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(fields(Person)) # revealed: tuple[Field[Any], ...]
reveal_type(asdict(p)) # revealed: dict[str, Any]
reveal_type(replace(p, name="Bob")) # revealed: Person
```
[`typing.dataclass_transform`]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform

View File

@@ -195,3 +195,52 @@ class C:
c = C()
c.square("hello") # error: [invalid-argument-type]
```
## Types with the same name but from different files
`module.py`:
```py
class Foo: ...
def needs_a_foo(x: Foo): ...
```
`main.py`:
```py
from module import needs_a_foo
class Foo: ...
needs_a_foo(Foo()) # error: [invalid-argument-type]
```
## TypeVars with bounds that have the same name but are from different files
In this case, using fully qualified names is *not* necessary.
```toml
[environment]
python-version = "3.12"
```
`module.py`:
```py
class Foo: ...
def needs_a_foo(x: Foo): ...
```
`main.py`:
```py
from module import needs_a_foo
class Foo: ...
def f[T: Foo](x: T) -> T:
needs_a_foo(x) # error: [invalid-argument-type]
return x
```

View File

@@ -393,7 +393,7 @@ else:
# revealed: (<class 'B'>, <class 'X'>, <class 'Y'>, <class 'O'>, <class 'object'>) | (<class 'B'>, <class 'Y'>, <class 'X'>, <class 'O'>, <class 'object'>)
reveal_mro(B)
# error: 12 [unsupported-base] "Unsupported class base with type `<class 'B'> | <class 'B'>`"
# error: 12 [unsupported-base] "Unsupported class base with type `<class 'mdtest_snippet.B @ src/mdtest_snippet.py:25'> | <class 'mdtest_snippet.B @ src/mdtest_snippet.py:28'>`"
class Z(A, B): ...
reveal_mro(Z) # revealed: (<class 'Z'>, Unknown, <class 'object'>)

View File

@@ -37,7 +37,7 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/attribute_as
# Diagnostics
```
error[invalid-assignment]: Object of type `Literal[1]` is not assignable to attribute `attr` on type `<class 'C1'> | <class 'C1'>`
error[invalid-assignment]: Object of type `Literal[1]` is not assignable to attribute `attr` on type `<class 'mdtest_snippet.<locals of function '_'>.C1 @ src/mdtest_snippet.py:3'> | <class 'mdtest_snippet.<locals of function '_'>.C1 @ src/mdtest_snippet.py:7'>`
--> src/mdtest_snippet.py:11:5
|
10 | # TODO: The error message here could be improved to explain why the assignment fails.

View File

@@ -0,0 +1,53 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: invalid_argument_type.md - Invalid argument type diagnostics - TypeVars with bounds that have the same name but are from different files
mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/invalid_argument_type.md
---
# Python source files
## module.py
```
1 | class Foo: ...
2 |
3 | def needs_a_foo(x: Foo): ...
```
## main.py
```
1 | from module import needs_a_foo
2 |
3 | class Foo: ...
4 |
5 | def f[T: Foo](x: T) -> T:
6 | needs_a_foo(x) # error: [invalid-argument-type]
7 | return x
```
# Diagnostics
```
error[invalid-argument-type]: Argument to function `needs_a_foo` is incorrect
--> src/main.py:6:17
|
5 | def f[T: Foo](x: T) -> T:
6 | needs_a_foo(x) # error: [invalid-argument-type]
| ^ Expected `Foo`, found `T@f`
7 | return x
|
info: Function defined here
--> src/module.py:3:5
|
1 | class Foo: ...
2 |
3 | def needs_a_foo(x: Foo): ...
| ^^^^^^^^^^^ ------ Parameter declared here
|
info: rule `invalid-argument-type` is enabled by default
```

View File

@@ -0,0 +1,51 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: invalid_argument_type.md - Invalid argument type diagnostics - Types with the same name but from different files
mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/invalid_argument_type.md
---
# Python source files
## module.py
```
1 | class Foo: ...
2 |
3 | def needs_a_foo(x: Foo): ...
```
## main.py
```
1 | from module import needs_a_foo
2 |
3 | class Foo: ...
4 |
5 | needs_a_foo(Foo()) # error: [invalid-argument-type]
```
# Diagnostics
```
error[invalid-argument-type]: Argument to function `needs_a_foo` is incorrect
--> src/main.py:5:13
|
3 | class Foo: ...
4 |
5 | needs_a_foo(Foo()) # error: [invalid-argument-type]
| ^^^^^ Expected `module.Foo`, found `main.Foo`
|
info: Function defined here
--> src/module.py:3:5
|
1 | class Foo: ...
2 |
3 | def needs_a_foo(x: Foo): ...
| ^^^^^^^^^^^ ------ Parameter declared here
|
info: rule `invalid-argument-type` is enabled by default
```

View File

@@ -615,7 +615,7 @@ impl<'db> PropertyInstanceType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.is_equivalent_to_impl(db, other, inferable, &IsEquivalentVisitor::default())
}
@@ -624,7 +624,7 @@ impl<'db> PropertyInstanceType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
let getter_equivalence = if let Some(getter) = self.getter(db) {
@@ -1486,7 +1486,7 @@ impl<'db> Type<'db> {
self,
db: &'db dyn Db,
target: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> Type<'db> {
self.filter_union(db, |elem| {
!elem
@@ -1952,7 +1952,7 @@ impl<'db> Type<'db> {
///
/// See [`TypeRelation::Subtyping`] for more details.
pub(crate) fn is_subtype_of(self, db: &'db dyn Db, target: Type<'db>) -> bool {
self.when_subtype_of(db, target, &InferableTypeVars::None)
self.when_subtype_of(db, target, InferableTypeVars::None)
.is_always_satisfied(db)
}
@@ -1960,7 +1960,7 @@ impl<'db> Type<'db> {
self,
db: &'db dyn Db,
target: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.has_relation_to(db, target, inferable, TypeRelation::Subtyping)
}
@@ -1974,7 +1974,7 @@ impl<'db> Type<'db> {
db: &'db dyn Db,
target: Type<'db>,
assuming: ConstraintSet<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.has_relation_to(
db,
@@ -1988,7 +1988,7 @@ impl<'db> Type<'db> {
///
/// See `TypeRelation::Assignability` for more details.
pub fn is_assignable_to(self, db: &'db dyn Db, target: Type<'db>) -> bool {
self.when_assignable_to(db, target, &InferableTypeVars::None)
self.when_assignable_to(db, target, InferableTypeVars::None)
.is_always_satisfied(db)
}
@@ -1998,7 +1998,7 @@ impl<'db> Type<'db> {
/// a constraint set and lets `satisfied_by_all_typevars` perform existential vs universal
/// reasoning depending on inferable typevars.
pub fn is_constraint_set_assignable_to(self, db: &'db dyn Db, target: Type<'db>) -> bool {
self.when_constraint_set_assignable_to(db, target, &InferableTypeVars::None)
self.when_constraint_set_assignable_to(db, target, InferableTypeVars::None)
.is_always_satisfied(db)
}
@@ -2006,7 +2006,7 @@ impl<'db> Type<'db> {
self,
db: &'db dyn Db,
target: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.has_relation_to(db, target, inferable, TypeRelation::Assignability)
}
@@ -2015,7 +2015,7 @@ impl<'db> Type<'db> {
self,
db: &'db dyn Db,
target: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.has_relation_to(
db,
@@ -2036,7 +2036,7 @@ impl<'db> Type<'db> {
other: Type<'db>,
) -> bool {
self_ty
.has_relation_to(db, other, &InferableTypeVars::None, TypeRelation::Redundancy)
.has_relation_to(db, other, InferableTypeVars::None, TypeRelation::Redundancy)
.is_always_satisfied(db)
}
@@ -2051,35 +2051,24 @@ impl<'db> Type<'db> {
self,
db: &'db dyn Db,
target: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
) -> ConstraintSet<'db> {
#[salsa::tracked(cycle_initial=has_relation_to_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn has_relation_to_tracked<'db>(
db: &'db dyn Db,
source: Type<'db>,
target: Type<'db>,
inferable: InferableTypeVars<'db>,
relation: TypeRelation<'db>,
) -> ConstraintSet<'db> {
source.has_relation_to_impl(
db,
target,
&inferable,
relation,
&HasRelationToVisitor::default(),
&IsDisjointVisitor::default(),
)
}
has_relation_to_tracked(db, self, target, inferable.clone(), relation)
self.has_relation_to_impl(
db,
target,
inferable,
relation,
&HasRelationToVisitor::default(),
&IsDisjointVisitor::default(),
)
}
fn has_relation_to_impl(
self,
db: &'db dyn Db,
target: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -3104,7 +3093,7 @@ impl<'db> Type<'db> {
///
/// [equivalent to]: https://typing.python.org/en/latest/spec/glossary.html#term-equivalent
pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
self.when_equivalent_to(db, other, &InferableTypeVars::None)
self.when_equivalent_to(db, other, InferableTypeVars::None)
.is_always_satisfied(db)
}
@@ -3112,7 +3101,7 @@ impl<'db> Type<'db> {
self,
db: &'db dyn Db,
other: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.is_equivalent_to_impl(db, other, inferable, &IsEquivalentVisitor::default())
}
@@ -3121,7 +3110,7 @@ impl<'db> Type<'db> {
self,
db: &'db dyn Db,
other: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
if self == other {
@@ -3235,7 +3224,7 @@ impl<'db> Type<'db> {
/// This function aims to have no false positives, but might return wrong
/// `false` answers in some cases.
pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool {
self.when_disjoint_from(db, other, &InferableTypeVars::None)
self.when_disjoint_from(db, other, InferableTypeVars::None)
.is_always_satisfied(db)
}
@@ -3243,7 +3232,7 @@ impl<'db> Type<'db> {
self,
db: &'db dyn Db,
other: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.is_disjoint_from_impl(
db,
@@ -3258,7 +3247,7 @@ impl<'db> Type<'db> {
self,
db: &'db dyn Db,
other: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db> {
@@ -3266,7 +3255,7 @@ impl<'db> Type<'db> {
db: &'db dyn Db,
protocol: ProtocolInstanceType<'db>,
other: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db> {
@@ -8669,19 +8658,6 @@ fn is_redundant_with_cycle_initial<'db>(
true
}
fn has_relation_to_cycle_initial<'db>(
_db: &'db dyn Db,
_id: salsa::Id,
_source: Type<'db>,
_target: Type<'db>,
_inferable: InferableTypeVars<'db>,
_relation: TypeRelation<'db>,
) -> ConstraintSet<'db> {
// For recursive types, optimistically assume the relation holds.
// This matches the fallback behavior of HasRelationToVisitor::default().
ConstraintSet::from(true)
}
fn apply_specialization_cycle_recover<'db>(
db: &'db dyn Db,
cycle: &salsa::Cycle,
@@ -12312,7 +12288,7 @@ impl<'db> BoundMethodType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -12346,7 +12322,7 @@ impl<'db> BoundMethodType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
self.function(db)
@@ -12562,7 +12538,7 @@ impl<'db> CallableType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -12587,7 +12563,7 @@ impl<'db> CallableType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
if self == other {
@@ -12653,7 +12629,7 @@ impl<'db> CallableTypes<'db> {
self,
db: &'db dyn Db,
other: CallableType<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -12742,7 +12718,7 @@ impl<'db> KnownBoundMethodType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -12847,7 +12823,7 @@ impl<'db> KnownBoundMethodType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
match (self, other) {
@@ -14096,7 +14072,7 @@ impl<'db> UnionType<'db> {
self,
db: &'db dyn Db,
other: Self,
_inferable: &InferableTypeVars<'db>,
_inferable: InferableTypeVars<'_, 'db>,
_visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
if self == other {
@@ -14300,7 +14276,7 @@ impl<'db> IntersectionType<'db> {
self,
db: &'db dyn Db,
other: Self,
_inferable: &InferableTypeVars<'db>,
_inferable: InferableTypeVars<'_, 'db>,
_visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
if self == other {

View File

@@ -21,7 +21,6 @@ use rustc_hash::{FxHashMap, FxHashSet};
use smallvec::{SmallVec, smallvec, smallvec_inline};
use super::{Argument, CallArguments, CallError, CallErrorKind, InferContext, Signature, Type};
use crate::Program;
use crate::db::Db;
use crate::dunder_all::dunder_all_names;
use crate::module_resolver::KnownModule;
@@ -52,6 +51,7 @@ use crate::types::{
enums, list_members, todo_type,
};
use crate::unpack::EvaluationMode;
use crate::{DisplaySettings, Program};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
@@ -214,7 +214,7 @@ impl<'db> Bindings<'db> {
}
}
self.evaluate_known_cases(db, dataclass_field_specifiers);
self.evaluate_known_cases(db, argument_types, dataclass_field_specifiers);
// In order of precedence:
//
@@ -337,7 +337,12 @@ impl<'db> Bindings<'db> {
/// Evaluates the return type of certain known callables, where we have special-case logic to
/// determine the return type in a way that isn't directly expressible in the type system.
fn evaluate_known_cases(&mut self, db: &'db dyn Db, dataclass_field_specifiers: &[Type<'db>]) {
fn evaluate_known_cases(
&mut self,
db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>,
dataclass_field_specifiers: &[Type<'db>],
) {
let to_bool = |ty: &Option<Type<'_>>, default: bool| -> bool {
if let Some(Type::BooleanLiteral(value)) = ty {
*value
@@ -666,25 +671,32 @@ impl<'db> Bindings<'db> {
if dataclass_field_specifiers.contains(&function)
|| function_type.is_known(db, KnownFunction::Field) =>
{
let has_default_value = overload
.parameter_type_by_name("default", false)
.is_ok_and(|ty| ty.is_some())
|| overload
.parameter_type_by_name("default_factory", false)
.is_ok_and(|ty| ty.is_some())
|| overload
.parameter_type_by_name("factory", false)
.is_ok_and(|ty| ty.is_some());
// Helper to get the type of a keyword argument by name. We first try to get it from
// the parameter binding (for explicit parameters), and then fall back to checking the
// call site arguments (for field-specifier functions that use a `**kwargs` parameter,
// instead of specifying `init`, `default` etc. explicitly).
let get_argument_type = |name, fallback_to_default| -> Option<Type<'db>> {
if let Ok(ty) =
overload.parameter_type_by_name(name, fallback_to_default)
{
return ty;
}
argument_types.iter().find_map(|(arg, ty)| {
if matches!(arg, Argument::Keyword(arg_name) if arg_name == name) {
ty
} else {
None
}
})
};
let init = overload
.parameter_type_by_name("init", true)
.unwrap_or(None);
let kw_only = overload
.parameter_type_by_name("kw_only", true)
.unwrap_or(None);
let alias = overload
.parameter_type_by_name("alias", true)
.unwrap_or(None);
let has_default_value = get_argument_type("default", false).is_some()
|| get_argument_type("default_factory", false).is_some()
|| get_argument_type("factory", false).is_some();
let init = get_argument_type("init", true);
let kw_only = get_argument_type("kw_only", true);
let alias = get_argument_type("alias", true);
// `dataclasses.field` and field-specifier functions of commonly used
// libraries like `pydantic`, `attrs`, and `SQLAlchemy` all return
@@ -737,7 +749,7 @@ impl<'db> Bindings<'db> {
Some(KnownFunction::IsEquivalentTo) => {
if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() {
let constraints =
ty_a.when_equivalent_to(db, *ty_b, &InferableTypeVars::None);
ty_a.when_equivalent_to(db, *ty_b, InferableTypeVars::None);
let tracked = TrackedConstraintSet::new(db, constraints);
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::ConstraintSet(tracked),
@@ -748,7 +760,7 @@ impl<'db> Bindings<'db> {
Some(KnownFunction::IsSubtypeOf) => {
if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() {
let constraints =
ty_a.when_subtype_of(db, *ty_b, &InferableTypeVars::None);
ty_a.when_subtype_of(db, *ty_b, InferableTypeVars::None);
let tracked = TrackedConstraintSet::new(db, constraints);
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::ConstraintSet(tracked),
@@ -759,7 +771,7 @@ impl<'db> Bindings<'db> {
Some(KnownFunction::IsAssignableTo) => {
if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() {
let constraints =
ty_a.when_assignable_to(db, *ty_b, &InferableTypeVars::None);
ty_a.when_assignable_to(db, *ty_b, InferableTypeVars::None);
let tracked = TrackedConstraintSet::new(db, constraints);
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::ConstraintSet(tracked),
@@ -770,7 +782,7 @@ impl<'db> Bindings<'db> {
Some(KnownFunction::IsDisjointFrom) => {
if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() {
let constraints =
ty_a.when_disjoint_from(db, *ty_b, &InferableTypeVars::None);
ty_a.when_disjoint_from(db, *ty_b, InferableTypeVars::None);
let tracked = TrackedConstraintSet::new(db, constraints);
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::ConstraintSet(tracked),
@@ -1254,7 +1266,7 @@ impl<'db> Bindings<'db> {
db,
*ty_b,
tracked.constraints(db),
&InferableTypeVars::None,
InferableTypeVars::None,
);
let tracked = TrackedConstraintSet::new(db, result);
overload.set_return_type(Type::KnownInstance(
@@ -1314,7 +1326,7 @@ impl<'db> Bindings<'db> {
let result = tracked
.constraints(db)
.satisfied_by_all_typevars(db, &InferableTypeVars::One(std::sync::Arc::new(inferable.clone())));
.satisfied_by_all_typevars(db, InferableTypeVars::One(&inferable));
overload.set_return_type(Type::BooleanLiteral(result));
}
@@ -1732,7 +1744,7 @@ impl<'db> CallableBinding<'db> {
.annotated_type()
.unwrap_or(Type::unknown());
if argument_type
.when_assignable_to(db, parameter_type, &overload.inferable_typevars)
.when_assignable_to(db, parameter_type, overload.inferable_typevars)
.is_always_satisfied(db)
{
is_argument_assignable_to_any_overload = true;
@@ -1986,7 +1998,7 @@ impl<'db> CallableBinding<'db> {
.when_equivalent_to(
db,
current_parameter_type,
&overload.inferable_typevars,
overload.inferable_typevars,
)
.is_always_satisfied(db)
{
@@ -2134,7 +2146,7 @@ impl<'db> CallableBinding<'db> {
.when_equivalent_to(
db,
first_overload_return_type,
&overload.inferable_typevars,
overload.inferable_typevars,
)
.is_always_satisfied(db)
})
@@ -2925,7 +2937,7 @@ struct ArgumentTypeChecker<'a, 'db> {
return_ty: Type<'db>,
errors: &'a mut Vec<BindingError<'db>>,
inferable_typevars: InferableTypeVars<'db>,
inferable_typevars: InferableTypeVars<'db, 'db>,
specialization: Option<Specialization<'db>>,
}
@@ -2997,7 +3009,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
.zip(self.call_expression_tcx.annotation);
self.inferable_typevars = generic_context.inferable_typevars(self.db);
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars.clone());
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
// Prefer the declared type of generic classes.
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
@@ -3169,7 +3181,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// constraint set that we get from this assignability check, instead of inferring and
// building them in an earlier separate step.
if argument_type
.when_assignable_to(self.db, expected_ty, &self.inferable_typevars)
.when_assignable_to(self.db, expected_ty, self.inferable_typevars)
.is_never_satisfied(self.db)
{
let positional = matches!(argument, Argument::Positional | Argument::Synthetic)
@@ -3442,7 +3454,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
.when_assignable_to(
self.db,
KnownClass::Str.to_instance(self.db),
&self.inferable_typevars,
self.inferable_typevars,
)
.is_always_satisfied(self.db)
{
@@ -3507,7 +3519,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
fn finish(
self,
) -> (
InferableTypeVars<'db>,
InferableTypeVars<'db, 'db>,
Option<Specialization<'db>>,
Type<'db>,
) {
@@ -3580,7 +3592,7 @@ pub(crate) struct Binding<'db> {
return_ty: Type<'db>,
/// The inferable typevars in this signature.
inferable_typevars: InferableTypeVars<'db>,
inferable_typevars: InferableTypeVars<'db, 'db>,
/// The specialization that was inferred from the argument types, if the callable is generic.
specialization: Option<Specialization<'db>>,
@@ -3784,7 +3796,7 @@ impl<'db> Binding<'db> {
fn snapshot(&self) -> BindingSnapshot<'db> {
BindingSnapshot {
return_ty: self.return_ty,
inferable_typevars: self.inferable_typevars.clone(),
inferable_typevars: self.inferable_typevars,
specialization: self.specialization,
argument_matches: self.argument_matches.clone(),
parameter_tys: self.parameter_tys.clone(),
@@ -3834,7 +3846,7 @@ impl<'db> Binding<'db> {
#[derive(Clone, Debug)]
struct BindingSnapshot<'db> {
return_ty: Type<'db>,
inferable_typevars: InferableTypeVars<'db>,
inferable_typevars: InferableTypeVars<'db, 'db>,
specialization: Option<Specialization<'db>>,
argument_matches: Box<[MatchedArgument<'db>]>,
parameter_tys: Box<[Option<Type<'db>>]>,
@@ -3874,7 +3886,7 @@ impl<'db> CallableBindingSnapshot<'db> {
// ... and update the snapshot with the current state of the binding.
snapshot.return_ty = binding.return_ty;
snapshot.inferable_typevars = binding.inferable_typevars.clone();
snapshot.inferable_typevars = binding.inferable_typevars;
snapshot.specialization = binding.specialization;
snapshot
.argument_matches
@@ -4156,8 +4168,13 @@ impl<'db> BindingError<'db> {
return;
};
let provided_ty_display = provided_ty.display(context.db());
let expected_ty_display = expected_ty.display(context.db());
let display_settings = DisplaySettings::from_possibly_ambiguous_types(
context.db(),
[provided_ty, expected_ty],
);
let provided_ty_display =
provided_ty.display_with(context.db(), display_settings.clone());
let expected_ty_display = expected_ty.display_with(context.db(), display_settings);
let mut diag = builder.into_diagnostic(format_args!(
"Argument{} is incorrect",

View File

@@ -600,7 +600,7 @@ impl<'db> ClassType<'db> {
/// Return `true` if `other` is present in this class's MRO.
pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: ClassType<'db>) -> bool {
self.when_subclass_of(db, other, &InferableTypeVars::None)
self.when_subclass_of(db, other, InferableTypeVars::None)
.is_always_satisfied(db)
}
@@ -608,7 +608,7 @@ impl<'db> ClassType<'db> {
self,
db: &'db dyn Db,
other: ClassType<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.has_relation_to_impl(
db,
@@ -624,7 +624,7 @@ impl<'db> ClassType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -676,7 +676,7 @@ impl<'db> ClassType<'db> {
self,
db: &'db dyn Db,
other: ClassType<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
if self == other {
@@ -736,7 +736,7 @@ impl<'db> ClassType<'db> {
.is_disjoint_from(
db,
other_alias.specialization(db),
&InferableTypeVars::None,
InferableTypeVars::None,
)
.is_always_satisfied(db)
}
@@ -2277,7 +2277,11 @@ impl<'db> ClassLiteral<'db> {
specialization: Option<Specialization<'db>>,
name: &str,
) -> Member<'db> {
if self.dataclass_params(db).is_some() {
// Check if this class is dataclass-like (either via @dataclass or via dataclass_transform)
if matches!(
CodeGeneratorKind::from_class(db, self, specialization),
Some(CodeGeneratorKind::DataclassLike(_))
) {
if name == "__dataclass_fields__" {
// Make this class look like a subclass of the `DataClassInstance` protocol
return Member {

View File

@@ -346,7 +346,7 @@ impl<'db> ConstraintSet<'db> {
pub(crate) fn satisfied_by_all_typevars(
self,
db: &'db dyn Db,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> bool {
self.node.satisfied_by_all_typevars(db, inferable)
}
@@ -1311,7 +1311,7 @@ impl<'db> Node<'db> {
fn satisfied_by_all_typevars(
self,
db: &'db dyn Db,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> bool {
match self {
Node::AlwaysTrue => return true,

View File

@@ -76,14 +76,15 @@ impl<'db> DisplaySettings<'db> {
}
#[must_use]
pub fn from_possibly_ambiguous_types(
db: &'db dyn Db,
types: impl IntoIterator<Item = Type<'db>>,
) -> Self {
pub fn from_possibly_ambiguous_types<I, T>(db: &'db dyn Db, types: I) -> Self
where
I: IntoIterator<Item = T>,
T: Into<Type<'db>>,
{
let collector = AmbiguousClassCollector::default();
for ty in types {
collector.visit_type(db, ty);
collector.visit_type(db, ty.into());
}
Self {
@@ -422,6 +423,8 @@ impl<'db> super::visitor::TypeVisitor<'db> for AmbiguousClassCollector<'db> {
inner: Protocol::FromClass(class),
..
}) => return self.visit_type(db, Type::from(class)),
// no need to recurse into TypeVar bounds/constraints
Type::TypeVar(_) => return,
_ => {}
}
@@ -439,7 +442,7 @@ impl<'db> Type<'db> {
pub fn display(self, db: &'db dyn Db) -> DisplayType<'db> {
DisplayType {
ty: self,
settings: DisplaySettings::default(),
settings: DisplaySettings::from_possibly_ambiguous_types(db, [self]),
db,
}
}

View File

@@ -1108,7 +1108,7 @@ impl<'db> FunctionType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -1133,7 +1133,7 @@ impl<'db> FunctionType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
if self.normalized(db) == other.normalized(db) {

View File

@@ -2,7 +2,7 @@ use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::fmt::Display;
use itertools::Itertools;
use itertools::{Either, Itertools};
use ruff_python_ast as ast;
use rustc_hash::{FxHashMap, FxHashSet};
@@ -119,90 +119,50 @@ pub(crate) fn typing_self<'db>(
)
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum InferableTypeVars<'db> {
#[derive(Clone, Copy, Debug)]
pub(crate) enum InferableTypeVars<'a, 'db> {
None,
One(std::sync::Arc<FxHashSet<BoundTypeVarIdentity<'db>>>),
One(&'a FxHashSet<BoundTypeVarIdentity<'db>>),
Two(
std::sync::Arc<InferableTypeVars<'db>>,
std::sync::Arc<InferableTypeVars<'db>>,
&'a InferableTypeVars<'a, 'db>,
&'a InferableTypeVars<'a, 'db>,
),
}
impl<'db> std::hash::Hash for InferableTypeVars<'db> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
InferableTypeVars::None => {
0u8.hash(state);
}
InferableTypeVars::One(set) => {
1u8.hash(state);
// Sort the identities for deterministic hashing.
let mut items: Vec<_> = set.iter().copied().collect();
items.sort_unstable();
for item in items {
item.hash(state);
}
}
InferableTypeVars::Two(left, right) => {
2u8.hash(state);
left.hash(state);
right.hash(state);
}
}
}
}
impl<'db> BoundTypeVarInstance<'db> {
pub(crate) fn is_inferable(
self,
db: &'db dyn Db,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> bool {
match inferable {
InferableTypeVars::None => false,
InferableTypeVars::One(typevars) => typevars.contains(&self.identity(db)),
InferableTypeVars::Two(left, right) => {
self.is_inferable(db, left) || self.is_inferable(db, right)
self.is_inferable(db, *left) || self.is_inferable(db, *right)
}
}
}
}
impl<'db> InferableTypeVars<'db> {
pub(crate) fn merge(self, other: InferableTypeVars<'db>) -> Self {
impl<'a, 'db> InferableTypeVars<'a, 'db> {
pub(crate) fn merge(&'a self, other: &'a InferableTypeVars<'a, 'db>) -> Self {
match (self, other) {
(InferableTypeVars::None, other) | (other, InferableTypeVars::None) => other,
(left, right) => InferableTypeVars::Two(std::sync::Arc::new(left), std::sync::Arc::new(right)),
(InferableTypeVars::None, other) | (other, InferableTypeVars::None) => *other,
_ => InferableTypeVars::Two(self, other),
}
}
// This is not an IntoIterator implementation because I have no desire to try to name the
// iterator type.
pub(crate) fn iter(&self) -> impl Iterator<Item = BoundTypeVarIdentity<'db>> + '_ {
enum Iter<'a, 'db> {
Empty,
One(std::collections::hash_set::Iter<'a, BoundTypeVarIdentity<'db>>),
Two(Box<dyn Iterator<Item = BoundTypeVarIdentity<'db>> + 'a>),
}
impl<'db> Iterator for Iter<'_, 'db> {
type Item = BoundTypeVarIdentity<'db>;
fn next(&mut self) -> Option<Self::Item> {
match self {
Iter::Empty => None,
Iter::One(iter) => iter.next().copied(),
Iter::Two(iter) => iter.next(),
}
}
}
pub(crate) fn iter(self) -> impl Iterator<Item = BoundTypeVarIdentity<'db>> {
match self {
InferableTypeVars::None => Iter::Empty,
InferableTypeVars::One(typevars) => Iter::One(typevars.iter()),
InferableTypeVars::None => Either::Left(Either::Left(std::iter::empty())),
InferableTypeVars::One(typevars) => Either::Right(typevars.iter().copied()),
InferableTypeVars::Two(left, right) => {
Iter::Two(Box::new(left.iter().chain(right.iter())))
let chained: Box<dyn Iterator<Item = BoundTypeVarIdentity<'db>>> =
Box::new(left.iter().chain(right.iter()));
Either::Left(Either::Right(chained))
}
}
}
@@ -212,7 +172,7 @@ impl<'db> InferableTypeVars<'db> {
pub(crate) fn display(&self, db: &'db dyn Db) -> impl Display {
fn find_typevars<'db>(
result: &mut FxHashSet<BoundTypeVarIdentity<'db>>,
inferable: &InferableTypeVars<'db>,
inferable: &InferableTypeVars<'_, 'db>,
) {
match inferable {
InferableTypeVars::None => {}
@@ -342,7 +302,7 @@ impl<'db> GenericContext<'db> {
/// In this example, `method`'s generic context binds `Self` and `T`, but its inferable set
/// also includes `A@C`. This is needed because at each call site, we need to infer the
/// specialized class instance type whose method is being invoked.
pub(crate) fn inferable_typevars(self, db: &'db dyn Db) -> InferableTypeVars<'db> {
pub(crate) fn inferable_typevars(self, db: &'db dyn Db) -> InferableTypeVars<'db, 'db> {
#[derive(Default)]
struct CollectTypeVars<'db> {
typevars: RefCell<FxHashSet<BoundTypeVarIdentity<'db>>>,
@@ -389,8 +349,10 @@ impl<'db> GenericContext<'db> {
visitor.typevars.into_inner()
}
// Wrap the Salsa-cached FxHashSet in an Arc.
InferableTypeVars::One(std::sync::Arc::new(inferable_typevars_inner(db, self).clone()))
// This ensures that salsa caches the FxHashSet, not the InferableTypeVars that wraps it.
// (That way InferableTypeVars can contain references, and doesn't need to impl
// salsa::Update.)
InferableTypeVars::One(inferable_typevars_inner(db, self))
}
pub(crate) fn variables(
@@ -807,7 +769,7 @@ fn is_subtype_in_invariant_position<'db>(
derived_materialization: MaterializationKind,
base_type: &Type<'db>,
base_materialization: MaterializationKind,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> {
@@ -885,7 +847,7 @@ fn has_relation_in_invariant_position<'db>(
derived_materialization: Option<MaterializationKind>,
base_type: &Type<'db>,
base_materialization: Option<MaterializationKind>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -1271,7 +1233,7 @@ impl<'db> Specialization<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -1346,7 +1308,7 @@ impl<'db> Specialization<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.is_disjoint_from_impl(
db,
@@ -1361,7 +1323,7 @@ impl<'db> Specialization<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db> {
@@ -1420,7 +1382,7 @@ impl<'db> Specialization<'db> {
self,
db: &'db dyn Db,
other: Specialization<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
if self.materialization_kind(db) != other.materialization_kind(db) {
@@ -1521,7 +1483,7 @@ impl<'db> PartialSpecialization<'_, 'db> {
/// specialization of a generic function.
pub(crate) struct SpecializationBuilder<'db> {
db: &'db dyn Db,
inferable: InferableTypeVars<'db>,
inferable: InferableTypeVars<'db, 'db>,
types: FxHashMap<BoundTypeVarIdentity<'db>, Type<'db>>,
}
@@ -1530,7 +1492,7 @@ pub(crate) struct SpecializationBuilder<'db> {
pub(crate) type TypeVarAssignment<'db> = (BoundTypeVarIdentity<'db>, TypeVarVariance, Type<'db>);
impl<'db> SpecializationBuilder<'db> {
pub(crate) fn new(db: &'db dyn Db, inferable: InferableTypeVars<'db>) -> Self {
pub(crate) fn new(db: &'db dyn Db, inferable: InferableTypeVars<'db, 'db>) -> Self {
Self {
db,
inferable,
@@ -1558,7 +1520,7 @@ impl<'db> SpecializationBuilder<'db> {
Self {
db: self.db,
inferable: self.inferable.clone(),
inferable: self.inferable,
types,
}
}
@@ -1725,8 +1687,8 @@ impl<'db> SpecializationBuilder<'db> {
//
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to
// specialize `T` to `int`, and so ignore the `None`.
let actual = actual.filter_disjoint_elements(self.db, formal, &self.inferable);
let formal = formal.filter_disjoint_elements(self.db, actual, &self.inferable);
let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
let formal = formal.filter_disjoint_elements(self.db, actual, self.inferable);
match (formal, actual) {
// TODO: We haven't implemented a full unification solver yet. If typevars appear in
@@ -1789,7 +1751,7 @@ impl<'db> SpecializationBuilder<'db> {
let assignable_elements =
(union_formal.elements(self.db).iter()).filter(|ty| {
actual
.when_subtype_of(self.db, **ty, &self.inferable)
.when_subtype_of(self.db, **ty, self.inferable)
.is_always_satisfied(self.db)
});
if assignable_elements.exactly_one().is_ok() {
@@ -1824,7 +1786,7 @@ impl<'db> SpecializationBuilder<'db> {
// The recursive call to `infer_map_impl` may succeed even if the actual type is
// not assignable to the formal element.
if !actual
.when_assignable_to(self.db, *formal_element, &self.inferable)
.when_assignable_to(self.db, *formal_element, self.inferable)
.is_never_satisfied(self.db)
{
found_matching_element = true;
@@ -1848,12 +1810,12 @@ impl<'db> SpecializationBuilder<'db> {
}
(Type::TypeVar(bound_typevar), ty) | (ty, Type::TypeVar(bound_typevar))
if bound_typevar.is_inferable(self.db, &self.inferable) =>
if bound_typevar.is_inferable(self.db, self.inferable) =>
{
match bound_typevar.typevar(self.db).bound_or_constraints(self.db) {
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
if !ty
.when_assignable_to(self.db, bound, &self.inferable)
.when_assignable_to(self.db, bound, self.inferable)
.is_always_satisfied(self.db)
{
return Err(SpecializationError::MismatchedBound {
@@ -1874,7 +1836,7 @@ impl<'db> SpecializationBuilder<'db> {
for constraint in constraints.elements(self.db) {
if ty
.when_assignable_to(self.db, *constraint, &self.inferable)
.when_assignable_to(self.db, *constraint, self.inferable)
.is_always_satisfied(self.db)
{
self.add_type_mapping(bound_typevar, *constraint, polarity, f);
@@ -1985,7 +1947,7 @@ impl<'db> SpecializationBuilder<'db> {
.when_constraint_set_assignable_to(
self.db,
formal_callable,
&self.inferable,
self.inferable,
);
self.add_type_mappings_from_constraint_set(formal, when, &mut f);
} else {
@@ -1994,7 +1956,7 @@ impl<'db> SpecializationBuilder<'db> {
.when_constraint_set_assignable_to_signatures(
self.db,
formal_callable,
&self.inferable,
self.inferable,
);
self.add_type_mappings_from_constraint_set(formal, when, &mut f);
}

View File

@@ -7650,7 +7650,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
annotation.filter_disjoint_elements(
self.db(),
Type::homogeneous_tuple(self.db(), Type::unknown()),
&inferable,
inferable,
)
});
@@ -7835,7 +7835,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// `collection_ty` is `list`.
let tcx = tcx.map(|annotation| {
let collection_ty = collection_class.to_instance(self.db());
annotation.filter_disjoint_elements(self.db(), collection_ty, &inferable)
annotation.filter_disjoint_elements(self.db(), collection_ty, inferable)
});
// Extract the annotated type of `T`, if provided.
@@ -7986,7 +7986,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
annotation.filter_disjoint_elements(
self.db(),
collection_class.to_instance(self.db()),
&InferableTypeVars::None,
InferableTypeVars::None,
)
});
@@ -11851,7 +11851,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
match typevar.typevar(db).bound_or_constraints(db) {
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
if provided_type
.when_assignable_to(db, bound, &InferableTypeVars::None)
.when_assignable_to(db, bound, InferableTypeVars::None)
.is_never_satisfied(db)
{
let node = get_node(index);
@@ -11882,7 +11882,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.when_assignable_to(
db,
constraints.as_type(db),
&InferableTypeVars::None,
InferableTypeVars::None,
)
.is_never_satisfied(db)
{

View File

@@ -127,7 +127,7 @@ impl<'db> Type<'db> {
self,
db: &'db dyn Db,
protocol: ProtocolInstanceType<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -405,7 +405,7 @@ impl<'db> NominalInstanceType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -438,7 +438,7 @@ impl<'db> NominalInstanceType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
match (self.0, other.0) {
@@ -460,7 +460,7 @@ impl<'db> NominalInstanceType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db> {
@@ -706,7 +706,7 @@ impl<'db> ProtocolInstanceType<'db> {
.satisfies_protocol(
db,
protocol,
&InferableTypeVars::None,
InferableTypeVars::None,
TypeRelation::Subtyping,
&HasRelationToVisitor::default(),
&IsDisjointVisitor::default(),
@@ -771,7 +771,7 @@ impl<'db> ProtocolInstanceType<'db> {
self,
db: &'db dyn Db,
other: Self,
_inferable: &InferableTypeVars<'db>,
_inferable: InferableTypeVars<'_, 'db>,
_visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
if self == other {
@@ -793,7 +793,7 @@ impl<'db> ProtocolInstanceType<'db> {
self,
_db: &'db dyn Db,
_other: Self,
_inferable: &InferableTypeVars<'db>,
_inferable: InferableTypeVars<'_, 'db>,
_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> {
ConstraintSet::from(false)

View File

@@ -273,7 +273,7 @@ impl<'db> ProtocolInterface<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -694,7 +694,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
&self,
db: &'db dyn Db,
other: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db> {
@@ -719,7 +719,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
&self,
db: &'db dyn Db,
other: Type<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,

View File

@@ -304,7 +304,7 @@ impl<'db> CallableSignature<'db> {
&self,
db: &'db dyn Db,
other: &Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.has_relation_to_impl(
db,
@@ -320,7 +320,7 @@ impl<'db> CallableSignature<'db> {
&self,
db: &'db dyn Db,
other: &Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -362,7 +362,7 @@ impl<'db> CallableSignature<'db> {
&self,
db: &'db dyn Db,
other: &Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.has_relation_to_impl(
db,
@@ -380,7 +380,7 @@ impl<'db> CallableSignature<'db> {
db: &'db dyn Db,
self_signatures: &[Signature<'db>],
other_signatures: &[Signature<'db>],
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -551,7 +551,7 @@ impl<'db> CallableSignature<'db> {
&self,
db: &'db dyn Db,
other: &Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
match (self.overloads.as_slice(), other.overloads.as_slice()) {
@@ -966,7 +966,7 @@ impl<'db> Signature<'db> {
}
}
fn inferable_typevars(&self, db: &'db dyn Db) -> InferableTypeVars<'db> {
fn inferable_typevars(&self, db: &'db dyn Db) -> InferableTypeVars<'db, 'db> {
match self.generic_context {
Some(generic_context) => generic_context.inferable_typevars(db),
None => InferableTypeVars::None,
@@ -980,7 +980,7 @@ impl<'db> Signature<'db> {
&self,
db: &'db dyn Db,
other: &Signature<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
// If either signature is generic, their typevars should also be considered inferable when
@@ -997,11 +997,11 @@ impl<'db> Signature<'db> {
// return t
let self_inferable = self.inferable_typevars(db);
let other_inferable = other.inferable_typevars(db);
let inferable = inferable.clone().merge(self_inferable.clone());
let inferable = inferable.merge(other_inferable.clone());
let inferable = inferable.merge(&self_inferable);
let inferable = inferable.merge(&other_inferable);
// `inner` will create a constraint set that references these newly inferable typevars.
let when = self.is_equivalent_to_inner(db, other, &inferable, visitor);
let when = self.is_equivalent_to_inner(db, other, inferable, visitor);
// But the caller does not need to consider those extra typevars. Whatever constraint set
// we produce, we reduce it back down to the inferable set that the caller asked about.
@@ -1014,7 +1014,7 @@ impl<'db> Signature<'db> {
&self,
db: &'db dyn Db,
other: &Signature<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
let mut result = ConstraintSet::from(true);
@@ -1100,7 +1100,7 @@ impl<'db> Signature<'db> {
&self,
db: &'db dyn Db,
other: &CallableSignature<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
// If this signature is a paramspec, bind it to the entire overloaded other callable.
if let Some(self_bound_typevar) = self.parameters.as_paramspec()
@@ -1143,7 +1143,7 @@ impl<'db> Signature<'db> {
&self,
db: &'db dyn Db,
other: &Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> ConstraintSet<'db> {
self.has_relation_to_impl(
db,
@@ -1160,7 +1160,7 @@ impl<'db> Signature<'db> {
&self,
db: &'db dyn Db,
other: &Signature<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -1179,14 +1179,14 @@ impl<'db> Signature<'db> {
// return t
let self_inferable = self.inferable_typevars(db);
let other_inferable = other.inferable_typevars(db);
let inferable = inferable.clone().merge(self_inferable.clone());
let inferable = inferable.merge(other_inferable.clone());
let inferable = inferable.merge(&self_inferable);
let inferable = inferable.merge(&other_inferable);
// `inner` will create a constraint set that references these newly inferable typevars.
let when = self.has_relation_to_inner(
db,
other,
&inferable,
inferable,
relation,
relation_visitor,
disjointness_visitor,
@@ -1203,7 +1203,7 @@ impl<'db> Signature<'db> {
&self,
db: &'db dyn Db,
other: &Signature<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,

View File

@@ -205,7 +205,7 @@ impl<'db> SubclassOfType<'db> {
self,
db: &'db dyn Db,
other: SubclassOfType<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -247,7 +247,7 @@ impl<'db> SubclassOfType<'db> {
self,
db: &'db dyn Db,
other: Self,
_inferable: &InferableTypeVars<'db>,
_inferable: InferableTypeVars<'_, 'db>,
_visitor: &IsDisjointVisitor<'db>,
) -> ConstraintSet<'db> {
match (self.subclass_of, other.subclass_of) {

View File

@@ -273,7 +273,7 @@ impl<'db> TupleType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -292,7 +292,7 @@ impl<'db> TupleType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db> {
@@ -309,7 +309,7 @@ impl<'db> TupleType<'db> {
self,
db: &'db dyn Db,
other: Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
self.tuple(db)
@@ -502,7 +502,7 @@ impl<'db> FixedLengthTuple<Type<'db>> {
&self,
db: &'db dyn Db,
other: &Tuple<Type<'db>>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -589,7 +589,7 @@ impl<'db> FixedLengthTuple<Type<'db>> {
&self,
db: &'db dyn Db,
other: &Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
ConstraintSet::from(self.0.len() == other.0.len()).and(db, || {
@@ -908,7 +908,7 @@ impl<'db> VariableLengthTuple<Type<'db>> {
&self,
db: &'db dyn Db,
other: &Tuple<Type<'db>>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -1111,7 +1111,7 @@ impl<'db> VariableLengthTuple<Type<'db>> {
&self,
db: &'db dyn Db,
other: &Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
self.variable
@@ -1340,7 +1340,7 @@ impl<'db> Tuple<Type<'db>> {
&self,
db: &'db dyn Db,
other: &Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -1369,7 +1369,7 @@ impl<'db> Tuple<Type<'db>> {
&self,
db: &'db dyn Db,
other: &Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
match (self, other) {
@@ -1389,7 +1389,7 @@ impl<'db> Tuple<Type<'db>> {
&self,
db: &'db dyn Db,
other: &Self,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db> {
@@ -1409,7 +1409,7 @@ impl<'db> Tuple<Type<'db>> {
db: &'db dyn Db,
a: impl IntoIterator<Item = &'s Type<'db>>,
b: impl IntoIterator<Item = &'s Type<'db>>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
) -> ConstraintSet<'db>

View File

@@ -129,7 +129,7 @@ impl<'db> TypedDictType<'db> {
self,
db: &'db dyn Db,
target: TypedDictType<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
relation: TypeRelation<'db>,
relation_visitor: &HasRelationToVisitor<'db>,
disjointness_visitor: &IsDisjointVisitor<'db>,
@@ -314,7 +314,7 @@ impl<'db> TypedDictType<'db> {
self,
db: &'db dyn Db,
other: TypedDictType<'db>,
inferable: &InferableTypeVars<'db>,
inferable: InferableTypeVars<'_, 'db>,
visitor: &IsEquivalentVisitor<'db>,
) -> ConstraintSet<'db> {
// TODO: `closed` and `extra_items` support will go here. Until then we don't look at the

View File

@@ -1,5 +1,6 @@
use anyhow::Result;
use lsp_types::{Position, notification::ShowMessage, request::RegisterCapability};
use lsp_types::notification::ShowMessage;
use lsp_types::{Position, request::RegisterCapability};
use ruff_db::system::SystemPath;
use serde_json::Value;
use ty_server::{ClientOptions, DiagnosticMode};
@@ -474,3 +475,20 @@ fn register_multiple_capabilities() -> Result<()> {
Ok(())
}
/// Tests that the server doesn't panic when `VIRTUAL_ENV` points to a non-existent directory.
///
/// See: <https://github.com/astral-sh/ty/issues/2031>
#[test]
fn missing_virtual_env_does_not_panic() -> Result<()> {
let workspace_root = SystemPath::new("project");
// This should not panic even though VIRTUAL_ENV points to a non-existent path
let _server = TestServerBuilder::new()?
.with_workspace(workspace_root, None)?
.with_env_var("VIRTUAL_ENV", "/nonexistent/virtual/env/path")
.build()
.wait_until_workspaces_are_initialized();
Ok(())
}

View File

@@ -209,6 +209,7 @@ impl TestServer {
test_context: TestContext,
capabilities: ClientCapabilities,
initialization_options: Option<ClientOptions>,
env_vars: Vec<(String, String)>,
) -> Self {
setup_tracing();
@@ -219,11 +220,16 @@ impl TestServer {
// Create OS system with the test directory as cwd
let os_system = OsSystem::new(test_context.root());
// Create test system and set environment variable overrides
let test_system = Arc::new(TestSystem::new(os_system));
for (name, value) in env_vars {
test_system.set_env_var(name, value);
}
// Start the server in a separate thread
let server_thread = std::thread::spawn(move || {
// TODO: This should probably be configurable to test concurrency issues
let worker_threads = NonZeroUsize::new(1).unwrap();
let test_system = Arc::new(TestSystem::new(os_system));
match Server::new(worker_threads, server_connection, test_system, true) {
Ok(server) => {
@@ -1052,6 +1058,7 @@ pub(crate) struct TestServerBuilder {
workspaces: Vec<(WorkspaceFolder, Option<ClientOptions>)>,
initialization_options: Option<ClientOptions>,
client_capabilities: ClientCapabilities,
env_vars: Vec<(String, String)>,
}
impl TestServerBuilder {
@@ -1082,6 +1089,7 @@ impl TestServerBuilder {
test_context: TestContext::new()?,
initialization_options: None,
client_capabilities,
env_vars: Vec::new(),
})
}
@@ -1091,6 +1099,16 @@ impl TestServerBuilder {
self
}
/// Set an environment variable for the test server's system.
pub(crate) fn with_env_var(
mut self,
name: impl Into<String>,
value: impl Into<String>,
) -> Self {
self.env_vars.push((name.into(), value.into()));
self
}
/// Add a workspace to the test server with the given root path and options.
///
/// This option will be used to respond to the `workspace/configuration` request that the
@@ -1237,6 +1255,7 @@ impl TestServerBuilder {
self.test_context,
self.client_capabilities,
self.initialization_options,
self.env_vars,
)
}
}

View File

@@ -16,15 +16,11 @@ export default function Chrome() {
const [theme, setTheme] = useTheme();
const handleShare = useCallback(() => {
const handleShare = useCallback(async () => {
if (settings == null || pythonSource == null) {
return;
}
persist(settings, pythonSource).catch((error) =>
// eslint-disable-next-line no-console
console.error(`Failed to share playground: ${error}`),
);
await persist(settings, pythonSource);
}, [pythonSource, settings]);
if (initPromise.current == null) {

View File

@@ -21,7 +21,7 @@ export default function Header({
version: string | null;
onChangeTheme: (theme: Theme) => void;
onReset?(): void;
onShare: () => void;
onShare: () => Promise<void>;
}) {
return (
<div

View File

@@ -1,17 +1,23 @@
import { useEffect, useState } from "react";
import AstralButton from "./AstralButton";
export default function ShareButton({ onShare }: { onShare: () => void }) {
const [copied, setCopied] = useState(false);
type ShareStatus = "initial" | "copying" | "copied";
export default function ShareButton({
onShare,
}: {
onShare: () => Promise<void>;
}) {
const [status, setStatus] = useState<ShareStatus>("initial");
useEffect(() => {
if (copied) {
const timeout = setTimeout(() => setCopied(false), 2000);
if (status === "copied") {
const timeout = setTimeout(() => setStatus("initial"), 2000);
return () => clearTimeout(timeout);
}
}, [copied]);
}, [status]);
return copied ? (
return status === "copied" ? (
<AstralButton
type="button"
className="relative flex-none leading-6 py-1.5 px-3 cursor-auto dark:shadow-copied"
@@ -28,10 +34,17 @@ export default function ShareButton({ onShare }: { onShare: () => void }) {
<AstralButton
type="button"
className="relative flex-none leading-6 py-1.5 px-3 shadow-xs disabled:opacity-50"
disabled={copied}
onClick={() => {
setCopied(true);
onShare();
disabled={status === "copying"}
onClick={async () => {
setStatus("copying");
try {
await onShare();
setStatus("copied");
} catch (error) {
// eslint-disable-next-line no-console
console.error("Failed to share playground", error);
setStatus("initial");
}
}}
>
<span

View File

@@ -48,14 +48,11 @@ export default function Playground() {
usePersistLocally(files);
const handleShare = useCallback(() => {
const handleShare = useCallback(async () => {
const serialized = serializeFiles(files);
if (serialized != null) {
persist(serialized).catch((error) => {
// eslint-disable-next-line no-console
console.error("Failed to share playground", error);
});
await persist(serialized);
}
}, [files]);