Compare commits
2 Commits
jack/loop-
...
dhruv/para
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec2861bc4f | ||
|
|
990d0a8999 |
@@ -886,8 +886,6 @@ class GenericClass[T]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _(x: list[str]):
|
||||
# TODO: This fails because we are not propagating GenericClass's generic context into the
|
||||
# Callable that we create for it.
|
||||
# revealed: (x: list[T@GenericClass], y: list[T@GenericClass]) -> GenericClass[T@GenericClass]
|
||||
reveal_type(into_callable(GenericClass))
|
||||
# revealed: ty_extensions.GenericContext[T@GenericClass]
|
||||
@@ -895,15 +893,10 @@ def _(x: list[str]):
|
||||
|
||||
# revealed: (x: list[T@GenericClass], y: list[T@GenericClass]) -> GenericClass[T@GenericClass]
|
||||
reveal_type(accepts_callable(GenericClass))
|
||||
# TODO: revealed: ty_extensions.GenericContext[T@GenericClass]
|
||||
# revealed: None
|
||||
# revealed: ty_extensions.GenericContext[T@GenericClass]
|
||||
reveal_type(generic_context(accepts_callable(GenericClass)))
|
||||
|
||||
# TODO: revealed: GenericClass[str]
|
||||
# TODO: no errors
|
||||
# revealed: GenericClass[T@GenericClass]
|
||||
# error: [invalid-argument-type]
|
||||
# error: [invalid-argument-type]
|
||||
# revealed: GenericClass[str]
|
||||
reveal_type(accepts_callable(GenericClass)(x, x))
|
||||
```
|
||||
|
||||
|
||||
@@ -800,3 +800,78 @@ def f(x: int, y: str):
|
||||
|
||||
reveal_type(infer_paramspec(f)) # revealed: (x: int, y: str) -> None
|
||||
```
|
||||
|
||||
## Generic context preservation through `ParamSpec` decorators
|
||||
|
||||
When a generic function is decorated with a `ParamSpec`-based decorator, the generic context of the
|
||||
decorated function should be preserved. This allows type inference to work correctly when calling the
|
||||
decorated function.
|
||||
|
||||
Regression test for <https://github.com/astral-sh/ty/issues/2336>
|
||||
|
||||
### Basic
|
||||
|
||||
```py
|
||||
from typing import Callable
|
||||
from ty_extensions import generic_context
|
||||
|
||||
def decorator[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
||||
return func
|
||||
|
||||
@decorator
|
||||
def identity[T](value: T) -> T:
|
||||
return value
|
||||
|
||||
@decorator
|
||||
def pair[T, U](first: T, second: U) -> tuple[T, U]:
|
||||
return (first, second)
|
||||
|
||||
# revealed: ty_extensions.GenericContext[T@identity]
|
||||
reveal_type(generic_context(identity))
|
||||
# revealed: ty_extensions.GenericContext[T@pair, U@pair]
|
||||
reveal_type(generic_context(pair))
|
||||
|
||||
reveal_type(identity(1)) # revealed: Literal[1]
|
||||
reveal_type(identity("hello")) # revealed: Literal["hello"]
|
||||
|
||||
reveal_type(pair(1, "a")) # revealed: tuple[Literal[1], Literal["a"]]
|
||||
reveal_type(pair("x", 2.5)) # revealed: tuple[Literal["x"], float]
|
||||
```
|
||||
|
||||
### Chained decorators with generic functions
|
||||
|
||||
```py
|
||||
from typing import Callable
|
||||
|
||||
def decorator1[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
||||
return func
|
||||
|
||||
def decorator2[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
||||
return func
|
||||
|
||||
@decorator1
|
||||
@decorator2
|
||||
def chained_generic[T](value: T) -> T:
|
||||
return value
|
||||
|
||||
reveal_type(chained_generic(42)) # revealed: Literal[42]
|
||||
reveal_type(chained_generic("test")) # revealed: Literal["test"]
|
||||
```
|
||||
|
||||
### Generic method decoration
|
||||
|
||||
```py
|
||||
from typing import Callable
|
||||
|
||||
def method_decorator[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
||||
return func
|
||||
|
||||
class Container:
|
||||
@method_decorator
|
||||
def generic_method[T](self, value: T) -> T:
|
||||
return value
|
||||
|
||||
c = Container()
|
||||
reveal_type(c.generic_method(100)) # revealed: Literal[100]
|
||||
reveal_type(c.generic_method([1, 2, 3])) # revealed: list[Unknown | int]
|
||||
```
|
||||
|
||||
@@ -188,9 +188,13 @@ impl<'db> CallableSignature<'db> {
|
||||
{
|
||||
Some(CallableSignature::from_overloads(
|
||||
callable.signatures(db).iter().map(|signature| Signature {
|
||||
generic_context: self_signature.generic_context.map(|context| {
|
||||
type_mapping.update_signature_generic_context(db, context)
|
||||
}),
|
||||
generic_context: GenericContext::merge_optional(
|
||||
db,
|
||||
signature.generic_context,
|
||||
self_signature.generic_context.map(|context| {
|
||||
type_mapping.update_signature_generic_context(db, context)
|
||||
}),
|
||||
),
|
||||
definition: signature.definition,
|
||||
parameters: if signature.parameters().is_top() {
|
||||
signature.parameters().clone()
|
||||
@@ -414,7 +418,11 @@ impl<'db> CallableSignature<'db> {
|
||||
db,
|
||||
CallableSignature::from_overloads(other_signatures.iter().map(
|
||||
|signature| {
|
||||
Signature::new(signature.parameters().clone(), Type::unknown())
|
||||
Signature::new_generic(
|
||||
signature.generic_context,
|
||||
signature.parameters().clone(),
|
||||
Type::unknown(),
|
||||
)
|
||||
},
|
||||
)),
|
||||
CallableTypeKind::ParamSpecValue,
|
||||
@@ -446,7 +454,11 @@ impl<'db> CallableSignature<'db> {
|
||||
db,
|
||||
CallableSignature::from_overloads(self_signatures.iter().map(
|
||||
|signature| {
|
||||
Signature::new(signature.parameters().clone(), Type::unknown())
|
||||
Signature::new_generic(
|
||||
signature.generic_context,
|
||||
signature.parameters().clone(),
|
||||
Type::unknown(),
|
||||
)
|
||||
},
|
||||
)),
|
||||
CallableTypeKind::ParamSpecValue,
|
||||
@@ -1083,7 +1095,11 @@ impl<'db> Signature<'db> {
|
||||
let upper = Type::Callable(CallableType::new(
|
||||
db,
|
||||
CallableSignature::from_overloads(other.overloads.iter().map(|signature| {
|
||||
Signature::new(signature.parameters().clone(), Type::unknown())
|
||||
Signature::new_generic(
|
||||
signature.generic_context,
|
||||
signature.parameters().clone(),
|
||||
Type::unknown(),
|
||||
)
|
||||
})),
|
||||
CallableTypeKind::ParamSpecValue,
|
||||
));
|
||||
@@ -1339,7 +1355,8 @@ impl<'db> Signature<'db> {
|
||||
(Some(self_bound_typevar), None) => {
|
||||
let upper = Type::Callable(CallableType::new(
|
||||
db,
|
||||
CallableSignature::single(Signature::new(
|
||||
CallableSignature::single(Signature::new_generic(
|
||||
other.generic_context,
|
||||
other.parameters.clone(),
|
||||
Type::unknown(),
|
||||
)),
|
||||
@@ -1358,7 +1375,8 @@ impl<'db> Signature<'db> {
|
||||
(None, Some(other_bound_typevar)) => {
|
||||
let lower = Type::Callable(CallableType::new(
|
||||
db,
|
||||
CallableSignature::single(Signature::new(
|
||||
CallableSignature::single(Signature::new_generic(
|
||||
self.generic_context,
|
||||
self.parameters.clone(),
|
||||
Type::unknown(),
|
||||
)),
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use self::schedule::spawn_main_loop;
|
||||
use crate::PositionEncoding;
|
||||
use crate::capabilities::{ResolvedClientCapabilities, server_capabilities};
|
||||
use crate::session::{InitializationOptions, Session, warn_about_unknown_options};
|
||||
use crate::session::{ClientName, InitializationOptions, Session, warn_about_unknown_options};
|
||||
use anyhow::Context;
|
||||
use lsp_server::Connection;
|
||||
use lsp_types::{ClientCapabilities, InitializeParams, MessageType, Url};
|
||||
@@ -47,6 +47,7 @@ impl Server {
|
||||
initialization_options,
|
||||
capabilities: client_capabilities,
|
||||
workspace_folders,
|
||||
client_info,
|
||||
..
|
||||
} = serde_json::from_value(init_value)
|
||||
.context("Failed to deserialize initialization parameters")?;
|
||||
@@ -65,6 +66,7 @@ impl Server {
|
||||
tracing::error!("Failed to deserialize initialization options: {error}");
|
||||
}
|
||||
|
||||
tracing::debug!("Client info: {client_info:#?}");
|
||||
tracing::debug!("Initialization options: {initialization_options:#?}");
|
||||
|
||||
let resolved_client_capabilities = ResolvedClientCapabilities::new(&client_capabilities);
|
||||
@@ -155,6 +157,7 @@ impl Server {
|
||||
workspace_urls,
|
||||
initialization_options,
|
||||
native_system,
|
||||
ClientName::from(client_info),
|
||||
in_test,
|
||||
)?,
|
||||
})
|
||||
|
||||
@@ -119,9 +119,12 @@ pub(super) fn request(req: server::Request) -> Task {
|
||||
.unwrap_or_else(|err| {
|
||||
tracing::error!("Encountered error when routing request with ID {id}: {err}");
|
||||
|
||||
Task::sync(move |_session, client| {
|
||||
Task::sync(move |session, client| {
|
||||
if matches!(err.code, ErrorCode::InternalError) {
|
||||
client.show_error_message("ty failed to handle a request from the editor. Check the logs for more details.");
|
||||
client.show_error_message(format!(
|
||||
"ty failed to handle a request from the editor. {}",
|
||||
session.client_name().log_guidance()
|
||||
));
|
||||
}
|
||||
|
||||
respond_silent_error(
|
||||
@@ -175,11 +178,12 @@ pub(super) fn notification(notif: server::Notification) -> Task {
|
||||
}
|
||||
.unwrap_or_else(|err| {
|
||||
tracing::error!("Encountered error when routing notification: {err}");
|
||||
Task::sync(move |_session, client| {
|
||||
Task::sync(move |session, client| {
|
||||
if matches!(err.code, ErrorCode::InternalError) {
|
||||
client.show_error_message(
|
||||
"ty failed to handle a notification from the editor. Check the logs for more details."
|
||||
);
|
||||
client.show_error_message(format!(
|
||||
"ty failed to handle a notification from the editor. {}",
|
||||
session.client_name().log_guidance()
|
||||
));
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -193,7 +197,7 @@ where
|
||||
Ok(Task::sync(move |session, client: &Client| {
|
||||
let _span = tracing::debug_span!("request", %id, method = R::METHOD).entered();
|
||||
let result = R::run(session, client, params);
|
||||
respond::<R>(&id, result, client);
|
||||
respond::<R>(&id, result, client, session.client_name().log_guidance());
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -217,6 +221,7 @@ where
|
||||
// SAFETY: The `snapshot` is safe to move across the unwind boundary because it is not used
|
||||
// after unwinding.
|
||||
let snapshot = AssertUnwindSafe(session.snapshot_session());
|
||||
let log_guidance = snapshot.0.client_name().log_guidance();
|
||||
|
||||
Box::new(move |client| {
|
||||
let _span = tracing::debug_span!("request", %id, method = R::METHOD).entered();
|
||||
@@ -238,7 +243,7 @@ where
|
||||
let snapshot = snapshot;
|
||||
R::handle_request(&id, snapshot.0, client, params);
|
||||
}) {
|
||||
panic_response::<R>(&id, client, &error, retry);
|
||||
panic_response::<R>(&id, client, &error, retry, log_guidance);
|
||||
}
|
||||
})
|
||||
}))
|
||||
@@ -284,6 +289,7 @@ where
|
||||
|
||||
let path = document.notebook_or_file_path();
|
||||
let db = session.project_db(path).clone();
|
||||
let log_guidance = document.client_name().log_guidance();
|
||||
|
||||
Box::new(move |client| {
|
||||
let _span = tracing::debug_span!("request", %id, method = R::METHOD).entered();
|
||||
@@ -306,7 +312,7 @@ where
|
||||
R::handle_request(&id, &db, document, client, params);
|
||||
});
|
||||
}) {
|
||||
panic_response::<R>(&id, client, &error, retry);
|
||||
panic_response::<R>(&id, client, &error, retry, log_guidance);
|
||||
}
|
||||
})
|
||||
}))
|
||||
@@ -317,6 +323,7 @@ fn panic_response<R>(
|
||||
client: &Client,
|
||||
error: &PanicError,
|
||||
request: Option<lsp_server::Request>,
|
||||
log_guidance: &str,
|
||||
) where
|
||||
R: traits::RetriableRequestHandler,
|
||||
{
|
||||
@@ -346,6 +353,7 @@ fn panic_response<R>(
|
||||
error: anyhow!("request handler {error}"),
|
||||
}),
|
||||
client,
|
||||
log_guidance,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -358,7 +366,10 @@ fn sync_notification_task<N: traits::SyncNotificationHandler>(
|
||||
let _span = tracing::debug_span!("notification", method = N::METHOD).entered();
|
||||
if let Err(err) = N::run(session, client, params) {
|
||||
tracing::error!("An error occurred while running {id}: {err}");
|
||||
client.show_error_message("ty encountered a problem. Check the logs for more details.");
|
||||
client.show_error_message(format!(
|
||||
"ty encountered a problem. {}",
|
||||
session.client_name().log_guidance()
|
||||
));
|
||||
|
||||
return;
|
||||
}
|
||||
@@ -390,6 +401,8 @@ where
|
||||
return Box::new(|_| {});
|
||||
};
|
||||
|
||||
let log_guidance = snapshot.client_name().log_guidance();
|
||||
|
||||
Box::new(move |client| {
|
||||
let _span = tracing::debug_span!("notification", method = N::METHOD).entered();
|
||||
|
||||
@@ -399,18 +412,14 @@ where
|
||||
Ok(result) => result,
|
||||
Err(panic) => {
|
||||
tracing::error!("An error occurred while running {id}: {panic}");
|
||||
client.show_error_message(
|
||||
"ty encountered a panic. Check the logs for more details.",
|
||||
);
|
||||
client.show_error_message(format!("ty encountered a panic. {log_guidance}"));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = result {
|
||||
tracing::error!("An error occurred while running {id}: {err}");
|
||||
client.show_error_message(
|
||||
"ty encountered a problem. Check the logs for more details.",
|
||||
);
|
||||
client.show_error_message(format!("ty encountered a problem. {log_guidance}"));
|
||||
}
|
||||
})
|
||||
}))
|
||||
@@ -449,12 +458,13 @@ fn respond<Req>(
|
||||
id: &RequestId,
|
||||
result: Result<<<Req as RequestHandler>::RequestType as Request>::Result>,
|
||||
client: &Client,
|
||||
log_guidance: &str,
|
||||
) where
|
||||
Req: RequestHandler,
|
||||
{
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("An error occurred with request ID {id}: {err}");
|
||||
client.show_error_message("ty encountered a problem. Check the logs for more details.");
|
||||
client.show_error_message(format!("ty encountered a problem. {log_guidance}"));
|
||||
}
|
||||
client.respond(id, result);
|
||||
}
|
||||
|
||||
@@ -116,7 +116,10 @@ pub(super) trait BackgroundDocumentRequestHandler: RetriableRequestHandler {
|
||||
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("An error occurred with request ID {id}: {err}");
|
||||
client.show_error_message("ty encountered a problem. Check the logs for more details.");
|
||||
client.show_error_message(format!(
|
||||
"ty encountered a problem. {}",
|
||||
snapshot.client_name().log_guidance()
|
||||
));
|
||||
}
|
||||
|
||||
client.respond(id, result);
|
||||
@@ -153,7 +156,10 @@ pub(super) trait BackgroundRequestHandler: RetriableRequestHandler {
|
||||
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("An error occurred with request ID {id}: {err}");
|
||||
client.show_error_message("ty encountered a problem. Check the logs for more details.");
|
||||
client.show_error_message(format!(
|
||||
"ty encountered a problem. {}",
|
||||
snapshot.client_name().log_guidance()
|
||||
));
|
||||
}
|
||||
|
||||
client.respond(id, result);
|
||||
|
||||
@@ -13,7 +13,7 @@ use lsp_types::request::{
|
||||
WorkspaceDiagnosticRequest,
|
||||
};
|
||||
use lsp_types::{
|
||||
DiagnosticRegistrationOptions, DiagnosticServerCapabilities,
|
||||
ClientInfo, DiagnosticRegistrationOptions, DiagnosticServerCapabilities,
|
||||
DidChangeWatchedFilesRegistrationOptions, FileSystemWatcher, Registration, RegistrationParams,
|
||||
TextDocumentContentChangeEvent, Unregistration, UnregistrationParams, Url,
|
||||
};
|
||||
@@ -106,6 +106,9 @@ pub(crate) struct Session {
|
||||
/// Registrations is a set of LSP methods that have been dynamically registered with the
|
||||
/// client.
|
||||
registrations: HashSet<String>,
|
||||
|
||||
/// The name of the client (editor) that connected to this server.
|
||||
client_name: ClientName,
|
||||
}
|
||||
|
||||
/// LSP State for a Project
|
||||
@@ -141,6 +144,7 @@ impl Session {
|
||||
workspace_urls: Vec<Url>,
|
||||
initialization_options: InitializationOptions,
|
||||
native_system: Arc<dyn System + 'static + Send + Sync + RefUnwindSafe>,
|
||||
client_name: ClientName,
|
||||
in_test: bool,
|
||||
) -> crate::Result<Self> {
|
||||
let index = Arc::new(Index::new());
|
||||
@@ -168,6 +172,7 @@ impl Session {
|
||||
suspended_workspace_diagnostics_request: None,
|
||||
revision: 0,
|
||||
registrations: HashSet::new(),
|
||||
client_name,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -532,8 +537,8 @@ impl Session {
|
||||
);
|
||||
|
||||
client.show_error_message(format!(
|
||||
"Failed to load project for workspace {url}. \
|
||||
Please refer to the logs for more details.",
|
||||
"Failed to load project for workspace {url}. {}",
|
||||
self.client_name.log_guidance(),
|
||||
));
|
||||
|
||||
let db_with_default_settings = ProjectMetadata::from_options(
|
||||
@@ -819,6 +824,7 @@ impl Session {
|
||||
.unwrap_or_else(|| Arc::new(WorkspaceSettings::default())),
|
||||
position_encoding: self.position_encoding,
|
||||
document: document_handle,
|
||||
client_name: self.client_name,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -837,6 +843,7 @@ impl Session {
|
||||
in_test: self.in_test,
|
||||
resolved_client_capabilities: self.resolved_client_capabilities,
|
||||
revision: self.revision,
|
||||
client_name: self.client_name,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -976,6 +983,10 @@ impl Session {
|
||||
pub(crate) fn position_encoding(&self) -> PositionEncoding {
|
||||
self.position_encoding
|
||||
}
|
||||
|
||||
pub(crate) fn client_name(&self) -> ClientName {
|
||||
self.client_name
|
||||
}
|
||||
}
|
||||
|
||||
/// A guard that holds the only reference to the index and allows modifying it.
|
||||
@@ -1025,6 +1036,7 @@ pub(crate) struct DocumentSnapshot {
|
||||
workspace_settings: Arc<WorkspaceSettings>,
|
||||
position_encoding: PositionEncoding,
|
||||
document: DocumentHandle,
|
||||
client_name: ClientName,
|
||||
}
|
||||
|
||||
impl DocumentSnapshot {
|
||||
@@ -1071,6 +1083,10 @@ impl DocumentSnapshot {
|
||||
pub(crate) fn notebook_or_file_path(&self) -> &AnySystemPath {
|
||||
self.document.notebook_or_file_path()
|
||||
}
|
||||
|
||||
pub(crate) fn client_name(&self) -> ClientName {
|
||||
self.client_name
|
||||
}
|
||||
}
|
||||
|
||||
/// An immutable snapshot of the current state of [`Session`].
|
||||
@@ -1081,6 +1097,7 @@ pub(crate) struct SessionSnapshot {
|
||||
resolved_client_capabilities: ResolvedClientCapabilities,
|
||||
in_test: bool,
|
||||
revision: u64,
|
||||
client_name: ClientName,
|
||||
|
||||
/// IMPORTANT: It's important that the databases come last, or at least,
|
||||
/// after any `Arc` that we try to extract or mutate in-place using `Arc::into_inner`
|
||||
@@ -1122,6 +1139,42 @@ impl SessionSnapshot {
|
||||
pub(crate) fn revision(&self) -> u64 {
|
||||
self.revision
|
||||
}
|
||||
|
||||
pub(crate) fn client_name(&self) -> ClientName {
|
||||
self.client_name
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the client (editor) that's connected to the language server.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub(crate) enum ClientName {
|
||||
Zed,
|
||||
Other,
|
||||
}
|
||||
|
||||
impl From<Option<ClientInfo>> for ClientName {
|
||||
fn from(info: Option<ClientInfo>) -> Self {
|
||||
match info {
|
||||
Some(info) if matches!(info.name.as_str(), "Zed") => ClientName::Zed,
|
||||
_ => ClientName::Other,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientName {
|
||||
/// Returns editor-specific guidance for finding logs.
|
||||
///
|
||||
/// Different editors have different ways to access language server logs, so we provide tailored
|
||||
/// instructions based on the connected client.
|
||||
pub(crate) fn log_guidance(self) -> &'static str {
|
||||
match self {
|
||||
ClientName::Zed => {
|
||||
"Please refer to the logs for more details \
|
||||
(command palette: `dev: open language server logs`)."
|
||||
}
|
||||
ClientName::Other => "Please refer to the logs for more details.",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
|
||||
Reference in New Issue
Block a user