From 2a8471191ebce43b454dbbae25aedec194808ced Mon Sep 17 00:00:00 2001 From: Chayim Refael Friedman Date: Thu, 2 Jan 2025 19:05:07 +0200 Subject: [PATCH] Support the new `CoercePointee` derive --- .../builtin_derive_macro.rs | 128 +++- .../hir-def/src/macro_expansion_tests/mod.rs | 38 +- .../hir-expand/src/builtin/derive_macro.rs | 626 ++++++++++++++++-- .../ide-assists/src/handlers/move_bounds.rs | 2 +- .../crates/ide/src/expand_macro.rs | 8 +- .../crates/intern/src/symbol/symbols.rs | 1 + .../crates/syntax/src/ast/make.rs | 20 +- .../crates/test-utils/src/minicore.rs | 9 + 8 files changed, 768 insertions(+), 64 deletions(-) diff --git a/src/tools/rust-analyzer/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs b/src/tools/rust-analyzer/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs index 163211fea5262..c31d322132897 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs @@ -2,7 +2,7 @@ use expect_test::expect; -use crate::macro_expansion_tests::check; +use crate::macro_expansion_tests::{check, check_errors}; #[test] fn test_copy_expand_simple() { @@ -16,7 +16,7 @@ struct Foo; #[derive(Copy)] struct Foo; -impl < > $crate::marker::Copy for Foo< > where {}"#]], +impl <> $crate::marker::Copy for Foo< > where {}"#]], ); } @@ -40,7 +40,7 @@ macro Copy {} #[derive(Copy)] struct Foo; -impl < > $crate::marker::Copy for Foo< > where {}"#]], +impl <> $crate::marker::Copy for Foo< > where {}"#]], ); } @@ -225,14 +225,14 @@ enum Bar { Bar, } -impl < > $crate::default::Default for Foo< > where { +impl <> $crate::default::Default for Foo< > where { fn default() -> Self { Foo { field1: $crate::default::Default::default(), field2: $crate::default::Default::default(), } } } -impl < > $crate::default::Default for Bar< > where { +impl <> $crate::default::Default for Bar< > where { fn default() -> Self { Bar::Bar } @@ -260,7 +260,7 @@ enum Command { Jump, } -impl < > $crate::cmp::PartialEq for Command< > where { +impl <> $crate::cmp::PartialEq for Command< > where { fn eq(&self , other: &Self ) -> bool { match (self , other) { (Command::Move { @@ -273,7 +273,7 @@ impl < > $crate::cmp::PartialEq for Command< > where { } } } -impl < > $crate::cmp::Eq for Command< > where {}"#]], +impl <> $crate::cmp::Eq for Command< > where {}"#]], ); } @@ -298,7 +298,7 @@ enum Command { Jump, } -impl < > $crate::cmp::PartialEq for Command< > where { +impl <> $crate::cmp::PartialEq for Command< > where { fn eq(&self , other: &Self ) -> bool { match (self , other) { (Command::Move { @@ -311,7 +311,7 @@ impl < > $crate::cmp::PartialEq for Command< > where { } } } -impl < > $crate::cmp::Eq for Command< > where {}"#]], +impl <> $crate::cmp::Eq for Command< > where {}"#]], ); } @@ -335,7 +335,7 @@ enum Command { Jump, } -impl < > $crate::cmp::PartialOrd for Command< > where { +impl <> $crate::cmp::PartialOrd for Command< > where { fn partial_cmp(&self , other: &Self ) -> $crate::option::Option::Option<$crate::cmp::Ordering> { match $crate::intrinsics::discriminant_value(self ).partial_cmp(&$crate::intrinsics::discriminant_value(other)) { $crate::option::Option::Some($crate::cmp::Ordering::Equal)=> { @@ -370,7 +370,7 @@ impl < > $crate::cmp::PartialOrd for Command< > where { } } } -impl < > $crate::cmp::Ord for Command< > where { +impl <> $crate::cmp::Ord for Command< > where { fn cmp(&self , other: &Self ) -> $crate::cmp::Ordering { match $crate::intrinsics::discriminant_value(self ).cmp(&$crate::intrinsics::discriminant_value(other)) { $crate::cmp::Ordering::Equal=> { @@ -432,7 +432,7 @@ struct Foo { z: (i32, u64), } -impl < > $crate::hash::Hash for Foo< > where { +impl <> $crate::hash::Hash for Foo< > where { fn hash(&self , ra_expand_state: &mut H) { match self { Foo { @@ -470,7 +470,7 @@ enum Command { Jump, } -impl < > $crate::hash::Hash for Command< > where { +impl <> $crate::hash::Hash for Command< > where { fn hash(&self , ra_expand_state: &mut H) { $crate::mem::discriminant(self ).hash(ra_expand_state); match self { @@ -516,7 +516,7 @@ enum Command { Jump, } -impl < > $crate::fmt::Debug for Command< > where { +impl <> $crate::fmt::Debug for Command< > where { fn fmt(&self , f: &mut $crate::fmt::Formatter) -> $crate::fmt::Result { match self { Command::Move { @@ -578,7 +578,7 @@ enum HideAndShowEnum { } } -impl < > $crate::fmt::Debug for HideAndShow< > where { +impl <> $crate::fmt::Debug for HideAndShow< > where { fn fmt(&self , f: &mut $crate::fmt::Formatter) -> $crate::fmt::Result { match self { HideAndShow { @@ -588,7 +588,7 @@ impl < > $crate::fmt::Debug for HideAndShow< > where { } } } -impl < > $crate::fmt::Debug for HideAndShowEnum< > where { +impl <> $crate::fmt::Debug for HideAndShowEnum< > where { fn fmt(&self , f: &mut $crate::fmt::Formatter) -> $crate::fmt::Result { match self { HideAndShowEnum::AlwaysShow { @@ -640,17 +640,109 @@ enum Bar { Bar, } -impl < > $crate::default::Default for Foo< > where { +impl <> $crate::default::Default for Foo< > where { fn default() -> Self { Foo { field1: $crate::default::Default::default(), field4: $crate::default::Default::default(), } } } -impl < > $crate::default::Default for Bar< > where { +impl <> $crate::default::Default for Bar< > where { fn default() -> Self { Bar::Bar } }"##]], ); } + +#[test] +fn coerce_pointee_expansion() { + check( + r#" +//- minicore: coerce_pointee + +use core::marker::CoercePointee; + +pub trait Trait {} + +#[derive(CoercePointee)] +#[repr(transparent)] +pub struct Foo<'a, T: ?Sized + Trait, #[pointee] U: ?Sized, const N: u32>(T) +where + U: Trait + ToString;"#, + expect![[r#" + +use core::marker::CoercePointee; + +pub trait Trait {} + +#[derive(CoercePointee)] +#[repr(transparent)] +pub struct Foo<'a, T: ?Sized + Trait, #[pointee] U: ?Sized, const N: u32>(T) +where + U: Trait + ToString; +impl $crate::ops::DispatchFromDyn> for Foo where U: Trait +ToString, T: Trait<__S>, __S: ?Sized, __S: Trait<__S> +ToString, U: ::core::marker::Unsize<__S>, T:?Sized+Trait, U:?Sized, {} +impl $crate::ops::CoerceUnsized> for Foo where U: Trait +ToString, T: Trait<__S>, __S: ?Sized, __S: Trait<__S> +ToString, U: ::core::marker::Unsize<__S>, T:?Sized+Trait, U:?Sized, {}"#]], + ); +} + +#[test] +fn coerce_pointee_errors() { + check_errors( + r#" +//- minicore: coerce_pointee + +use core::marker::CoercePointee; + +#[derive(CoercePointee)] +enum Enum {} + +#[derive(CoercePointee)] +struct Struct1; + +#[derive(CoercePointee)] +struct Struct2(); + +#[derive(CoercePointee)] +struct Struct3 {} + +#[derive(CoercePointee)] +struct Struct4(T); + +#[derive(CoercePointee)] +#[repr(transparent)] +struct Struct5(i32); + +#[derive(CoercePointee)] +#[repr(transparent)] +struct Struct6<#[pointee] T: ?Sized, #[pointee] U: ?Sized>(T, U); + +#[derive(CoercePointee)] +#[repr(transparent)] +struct Struct7(T, U); + +#[derive(CoercePointee)] +#[repr(transparent)] +struct Struct8<#[pointee] T, U: ?Sized>(T); + +#[derive(CoercePointee)] +#[repr(transparent)] +struct Struct9(T); + +#[derive(CoercePointee)] +#[repr(transparent)] +struct Struct9<#[pointee] T, U>(T) where T: ?Sized; +"#, + expect![[r#" + 35..72: `CoercePointee` can only be derived on `struct`s + 74..114: `CoercePointee` can only be derived on `struct`s with at least one field + 116..158: `CoercePointee` can only be derived on `struct`s with at least one field + 160..202: `CoercePointee` can only be derived on `struct`s with at least one field + 204..258: `CoercePointee` can only be derived on `struct`s with `#[repr(transparent)]` + 260..326: `CoercePointee` can only be derived on `struct`s that are generic over at least one type + 328..439: only one type parameter can be marked as `#[pointee]` when deriving `CoercePointee` traits + 441..530: exactly one generic type parameter must be marked as `#[pointee]` to derive `CoercePointee` traits + 532..621: `derive(CoercePointee)` requires `T` to be marked `?Sized` + 623..690: `derive(CoercePointee)` requires `T` to be marked `?Sized`"#]], + ); +} diff --git a/src/tools/rust-analyzer/crates/hir-def/src/macro_expansion_tests/mod.rs b/src/tools/rust-analyzer/crates/hir-def/src/macro_expansion_tests/mod.rs index f129358946d34..5b9ffdf37beda 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/macro_expansion_tests/mod.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/macro_expansion_tests/mod.rs @@ -16,14 +16,16 @@ mod proc_macros; use std::{iter, ops::Range, sync}; +use base_db::SourceDatabase; use expect_test::Expect; use hir_expand::{ db::ExpandDatabase, proc_macro::{ProcMacro, ProcMacroExpander, ProcMacroExpansionError, ProcMacroKind}, span_map::SpanMapRef, - InFile, MacroFileId, MacroFileIdExt, + InFile, MacroCallKind, MacroFileId, MacroFileIdExt, }; use intern::Symbol; +use itertools::Itertools; use span::{Edition, Span}; use stdx::{format_to, format_to_acc}; use syntax::{ @@ -44,6 +46,36 @@ use crate::{ AdtId, AsMacroCall, Lookup, ModuleDefId, }; +#[track_caller] +fn check_errors(ra_fixture: &str, expect: Expect) { + let db = TestDB::with_files(ra_fixture); + let krate = db.fetch_test_crate(); + let def_map = db.crate_def_map(krate); + let errors = def_map + .modules() + .flat_map(|module| module.1.scope.all_macro_calls()) + .filter_map(|macro_call| { + let errors = db.parse_macro_expansion_error(macro_call)?; + let errors = errors.err.as_ref()?.render_to_string(&db); + let macro_loc = db.lookup_intern_macro_call(macro_call); + let ast_id = match macro_loc.kind { + MacroCallKind::FnLike { ast_id, .. } => ast_id.map(|it| it.erase()), + MacroCallKind::Derive { ast_id, .. } => ast_id.map(|it| it.erase()), + MacroCallKind::Attr { ast_id, .. } => ast_id.map(|it| it.erase()), + }; + let ast = db + .parse(ast_id.file_id.file_id().expect("macros inside macros are not supported")) + .syntax_node(); + let ast_id_map = db.ast_id_map(ast_id.file_id); + let node = ast_id_map.get_erased(ast_id.value).to_node(&ast); + Some((node.text_range(), errors)) + }) + .sorted_unstable_by_key(|(range, _)| range.start()) + .format_with("\n", |(range, err), format| format(&format_args!("{range:?}: {err}"))) + .to_string(); + expect.assert_eq(&errors); +} + #[track_caller] fn check(ra_fixture: &str, mut expect: Expect) { let extra_proc_macros = vec![( @@ -245,7 +277,9 @@ fn pretty_print_macro_expansion( let mut res = String::new(); let mut prev_kind = EOF; let mut indent_level = 0; - for token in iter::successors(expn.first_token(), |t| t.next_token()) { + for token in iter::successors(expn.first_token(), |t| t.next_token()) + .take_while(|token| token.text_range().start() < expn.text_range().end()) + { let curr_kind = token.kind(); let space = match (prev_kind, curr_kind) { _ if prev_kind.is_trivia() || curr_kind.is_trivia() => "", diff --git a/src/tools/rust-analyzer/crates/hir-expand/src/builtin/derive_macro.rs b/src/tools/rust-analyzer/crates/hir-expand/src/builtin/derive_macro.rs index e083e0ddca038..4510a593af4da 100644 --- a/src/tools/rust-analyzer/crates/hir-expand/src/builtin/derive_macro.rs +++ b/src/tools/rust-analyzer/crates/hir-expand/src/builtin/derive_macro.rs @@ -1,9 +1,10 @@ //! Builtin derives. use intern::sym; -use itertools::izip; +use itertools::{izip, Itertools}; +use parser::SyntaxKind; use rustc_hash::FxHashSet; -use span::{MacroCallId, Span}; +use span::{MacroCallId, Span, SyntaxContextId}; use stdx::never; use syntax_bridge::DocCommentDesugarMode; use tracing::debug; @@ -16,8 +17,12 @@ use crate::{ span_map::ExpansionSpanMap, tt, ExpandError, ExpandResult, }; -use syntax::ast::{ - self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds, +use syntax::{ + ast::{ + self, edit_in_place::GenericParamsOwnerEdit, make, AstNode, FieldList, HasAttrs, + HasGenericArgs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds, + }, + ted, }; macro_rules! register_builtin { @@ -67,13 +72,15 @@ register_builtin! { Ord => ord_expand, PartialOrd => partial_ord_expand, Eq => eq_expand, - PartialEq => partial_eq_expand + PartialEq => partial_eq_expand, + CoercePointee => coerce_pointee_expand } pub fn find_builtin_derive(ident: &name::Name) -> Option { BuiltinDeriveExpander::find_by_name(ident) } +#[derive(Clone)] enum VariantShape { Struct(Vec), Tuple(usize), @@ -147,6 +154,7 @@ impl VariantShape { } } +#[derive(Clone)] enum AdtShape { Struct(VariantShape), Enum { variants: Vec<(tt::Ident, VariantShape)>, default_variant: Option }, @@ -197,30 +205,38 @@ impl AdtShape { } } +#[derive(Clone)] struct BasicAdtInfo { name: tt::Ident, shape: AdtShape, /// first field is the name, and /// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param. /// third fields is where bounds, if any - param_types: Vec<(tt::TopSubtree, Option, Option)>, + param_types: Vec, where_clause: Vec, associated_types: Vec, } +#[derive(Clone)] +struct AdtParam { + name: tt::TopSubtree, + /// `None` if this is a type parameter. + const_ty: Option, + bounds: Option, +} + +// FIXME: This whole thing needs a refactor. Each derive requires its special values, and the result is a mess. fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result { - let (parsed, tm) = &syntax_bridge::token_tree_to_syntax_node( - tt, - syntax_bridge::TopEntryPoint::MacroItems, - parser::Edition::CURRENT_FIXME, - ); - let macro_items = ast::MacroItems::cast(parsed.syntax_node()) - .ok_or_else(|| ExpandError::other(call_site, "invalid item definition"))?; - let item = - macro_items.items().next().ok_or_else(|| ExpandError::other(call_site, "no item found"))?; - let adt = &ast::Adt::cast(item.syntax().clone()) - .ok_or_else(|| ExpandError::other(call_site, "expected struct, enum or union"))?; - let (name, generic_param_list, where_clause, shape) = match adt { + let (adt, tm) = to_adt_syntax(tt, call_site)?; + parse_adt_from_syntax(&adt, &tm, call_site) +} + +fn parse_adt_from_syntax( + adt: &ast::Adt, + tm: &span::SpanMap, + call_site: Span, +) -> Result { + let (name, generic_param_list, where_clause, shape) = match &adt { ast::Adt::Struct(it) => ( it.name(), it.generic_param_list(), @@ -291,7 +307,7 @@ fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result None, }; - let ty = if let ast::TypeOrConstParam::Const(param) = param { + let const_ty = if let ast::TypeOrConstParam::Const(param) = param { let ty = param .ty() .map(|ty| { @@ -309,7 +325,7 @@ fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result Result Result<(ast::Adt, span::SpanMap), ExpandError> { + let (parsed, tm) = syntax_bridge::token_tree_to_syntax_node( + tt, + syntax_bridge::TopEntryPoint::MacroItems, + parser::Edition::CURRENT_FIXME, + ); + let macro_items = ast::MacroItems::cast(parsed.syntax_node()) + .ok_or_else(|| ExpandError::other(call_site, "invalid item definition"))?; + let item = + macro_items.items().next().ok_or_else(|| ExpandError::other(call_site, "no item found"))?; + let adt = ast::Adt::cast(item.syntax().clone()) + .ok_or_else(|| ExpandError::other(call_site, "expected struct, enum or union"))?; + Ok((adt, tm)) +} + fn name_to_token( call_site: Span, token_map: &ExpansionSpanMap, @@ -426,38 +460,64 @@ fn expand_simple_derive( ) } }; + ExpandResult::ok(expand_simple_derive_with_parsed( + invoc_span, + info, + trait_path, + make_trait_body, + true, + tt::TopSubtree::empty(tt::DelimSpan::from_single(invoc_span)), + )) +} + +fn expand_simple_derive_with_parsed( + invoc_span: Span, + info: BasicAdtInfo, + trait_path: tt::TopSubtree, + make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::TopSubtree, + constrain_to_trait: bool, + extra_impl_params: tt::TopSubtree, +) -> tt::TopSubtree { let trait_body = make_trait_body(&info); let mut where_block: Vec<_> = info.where_clause.into_iter().map(|w| quote! {invoc_span => #w , }).collect(); let (params, args): (Vec<_>, Vec<_>) = info .param_types .into_iter() - .map(|(ident, param_ty, bound)| { - let ident_ = ident.clone(); - if let Some(b) = bound { - let ident = ident.clone(); - where_block.push(quote! {invoc_span => #ident : #b , }); - } - if let Some(ty) = param_ty { - (quote! {invoc_span => const #ident : #ty , }, quote! {invoc_span => #ident_ , }) + .map(|param| { + let ident = param.name; + if let Some(b) = param.bounds { + let ident2 = ident.clone(); + where_block.push(quote! {invoc_span => #ident2 : #b , }); + } + if let Some(ty) = param.const_ty { + let ident2 = ident.clone(); + (quote! {invoc_span => const #ident : #ty , }, quote! {invoc_span => #ident2 , }) } else { let bound = trait_path.clone(); - (quote! {invoc_span => #ident : #bound , }, quote! {invoc_span => #ident_ , }) + let ident2 = ident.clone(); + let param = if constrain_to_trait { + quote! {invoc_span => #ident : #bound , } + } else { + quote! {invoc_span => #ident , } + }; + (param, quote! {invoc_span => #ident2 , }) } }) .unzip(); - where_block.extend(info.associated_types.iter().map(|it| { - let it = it.clone(); - let bound = trait_path.clone(); - quote! {invoc_span => #it : #bound , } - })); + if constrain_to_trait { + where_block.extend(info.associated_types.iter().map(|it| { + let it = it.clone(); + let bound = trait_path.clone(); + quote! {invoc_span => #it : #bound , } + })); + } let name = info.name; - let expanded = quote! {invoc_span => - impl < ##params > #trait_path for #name < ##args > where ##where_block { #trait_body } - }; - ExpandResult::ok(expanded) + quote! {invoc_span => + impl < ##params #extra_impl_params > #trait_path for #name < ##args > where ##where_block { #trait_body } + } } fn copy_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult { @@ -871,3 +931,493 @@ fn partial_ord_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult ExpandResult { + let (adt, _span_map) = match to_adt_syntax(tt, span) { + Ok(it) => it, + Err(err) => { + return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err); + } + }; + let adt = adt.clone_for_update(); + let ast::Adt::Struct(strukt) = &adt else { + return ExpandResult::new( + tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), + ExpandError::other(span, "`CoercePointee` can only be derived on `struct`s"), + ); + }; + let has_at_least_one_field = strukt + .field_list() + .map(|it| match it { + ast::FieldList::RecordFieldList(it) => it.fields().next().is_some(), + ast::FieldList::TupleFieldList(it) => it.fields().next().is_some(), + }) + .unwrap_or(false); + if !has_at_least_one_field { + return ExpandResult::new( + tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), + ExpandError::other( + span, + "`CoercePointee` can only be derived on `struct`s with at least one field", + ), + ); + } + let is_repr_transparent = strukt.attrs().any(|attr| { + attr.as_simple_call().is_some_and(|(name, tt)| { + name == "repr" + && tt.syntax().children_with_tokens().any(|it| { + it.into_token().is_some_and(|it| { + it.kind() == SyntaxKind::IDENT && it.text() == "transparent" + }) + }) + }) + }); + if !is_repr_transparent { + return ExpandResult::new( + tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), + ExpandError::other( + span, + "`CoercePointee` can only be derived on `struct`s with `#[repr(transparent)]`", + ), + ); + } + let type_params = strukt + .generic_param_list() + .into_iter() + .flat_map(|generics| { + generics.generic_params().filter_map(|param| match param { + ast::GenericParam::TypeParam(param) => Some(param), + _ => None, + }) + }) + .collect_vec(); + if type_params.is_empty() { + return ExpandResult::new( + tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), + ExpandError::other( + span, + "`CoercePointee` can only be derived on `struct`s that are generic over at least one type", + ), + ); + } + let (pointee_param, pointee_param_idx) = if type_params.len() == 1 { + // Regardless of the only type param being designed as `#[pointee]` or not, we can just use it as such. + (type_params[0].clone(), 0) + } else { + let mut pointees = type_params.iter().cloned().enumerate().filter(|(_, param)| { + param.attrs().any(|attr| { + let is_pointee = attr.as_simple_atom().is_some_and(|name| name == "pointee"); + if is_pointee { + // Remove the `#[pointee]` attribute so it won't be present in the generated + // impls (where we cannot resolve it). + ted::remove(attr.syntax()); + } + is_pointee + }) + }); + match (pointees.next(), pointees.next()) { + (Some((pointee_idx, pointee)), None) => (pointee, pointee_idx), + (None, _) => { + return ExpandResult::new( + tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), + ExpandError::other( + span, + "exactly one generic type parameter must be marked \ + as `#[pointee]` to derive `CoercePointee` traits", + ), + ) + } + (Some(_), Some(_)) => { + return ExpandResult::new( + tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), + ExpandError::other( + span, + "only one type parameter can be marked as `#[pointee]` \ + when deriving `CoercePointee` traits", + ), + ) + } + } + }; + let (Some(struct_name), Some(pointee_param_name)) = (strukt.name(), pointee_param.name()) + else { + return ExpandResult::new( + tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), + ExpandError::other(span, "invalid item"), + ); + }; + + { + let mut pointee_has_maybe_sized_bound = false; + if let Some(bounds) = pointee_param.type_bound_list() { + pointee_has_maybe_sized_bound |= bounds.bounds().any(is_maybe_sized_bound); + } + if let Some(where_clause) = strukt.where_clause() { + pointee_has_maybe_sized_bound |= where_clause.predicates().any(|pred| { + let Some(ast::Type::PathType(ty)) = pred.ty() else { return false }; + let is_not_pointee = ty.path().is_none_or(|path| { + let is_pointee = path + .as_single_name_ref() + .is_some_and(|name| name.text() == pointee_param_name.text()); + !is_pointee + }); + if is_not_pointee { + return false; + } + pred.type_bound_list() + .is_some_and(|bounds| bounds.bounds().any(is_maybe_sized_bound)) + }) + } + if !pointee_has_maybe_sized_bound { + return ExpandResult::new( + tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), + ExpandError::other( + span, + format!("`derive(CoercePointee)` requires `{pointee_param_name}` to be marked `?Sized`"), + ), + ); + } + } + + const ADDED_PARAM: &str = "__S"; + + let where_clause = strukt.get_or_create_where_clause(); + + { + let mut new_predicates = Vec::new(); + + // # Rewrite generic parameter bounds + // For each bound `U: ..` in `struct`, make a new bound with `__S` in place of `#[pointee]` + // Example: + // ``` + // struct< + // U: Trait, + // #[pointee] T: Trait + ?Sized, + // V: Trait> ... + // ``` + // ... generates this `impl` generic parameters + // ``` + // impl< + // U: Trait, + // T: Trait + ?Sized, + // V: Trait + // > + // where + // U: Trait<__S>, + // __S: Trait<__S> + ?Sized, + // V: Trait<__S> ... + // ``` + for param in &type_params { + let Some(param_name) = param.name() else { continue }; + if let Some(bounds) = param.type_bound_list() { + // If the target type is the pointee, duplicate the bound as whole. + // Otherwise, duplicate only bounds that mention the pointee. + let is_pointee = param_name.text() == pointee_param_name.text(); + let new_bounds = bounds + .bounds() + .map(|bound| bound.clone_subtree().clone_for_update()) + .filter(|bound| { + bound.ty().is_some_and(|ty| { + substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM) + || is_pointee + }) + }); + let new_bounds_target = if is_pointee { + make::name_ref(ADDED_PARAM) + } else { + make::name_ref(¶m_name.text()) + }; + new_predicates.push( + make::where_pred( + make::ty_path(make::path_from_segments( + [make::path_segment(new_bounds_target)], + false, + )), + new_bounds, + ) + .clone_for_update(), + ); + } + } + + // # Rewrite `where` clauses + // + // Move on to `where` clauses. + // Example: + // ``` + // struct MyPointer<#[pointee] T, ..> + // where + // U: Trait + Trait, + // Companion: Trait, + // T: Trait + ?Sized, + // { .. } + // ``` + // ... will have a impl prelude like so + // ``` + // impl<..> .. + // where + // U: Trait + Trait, + // U: Trait<__S>, + // Companion: Trait, + // Companion<__S>: Trait<__S>, + // T: Trait + ?Sized, + // __S: Trait<__S> + ?Sized, + // ``` + // + // We should also write a few new `where` bounds from `#[pointee] T` to `__S` + // as well as any bound that indirectly involves the `#[pointee] T` type. + for predicate in where_clause.predicates() { + let predicate = predicate.clone_subtree().clone_for_update(); + let Some(pred_target) = predicate.ty() else { continue }; + + // If the target type references the pointee, duplicate the bound as whole. + // Otherwise, duplicate only bounds that mention the pointee. + if substitute_type_in_bound( + pred_target.clone(), + &pointee_param_name.text(), + ADDED_PARAM, + ) { + if let Some(bounds) = predicate.type_bound_list() { + for bound in bounds.bounds() { + if let Some(ty) = bound.ty() { + substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM); + } + } + } + + new_predicates.push(predicate); + } else if let Some(bounds) = predicate.type_bound_list() { + let new_bounds = bounds + .bounds() + .map(|bound| bound.clone_subtree().clone_for_update()) + .filter(|bound| { + bound.ty().is_some_and(|ty| { + substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM) + }) + }); + new_predicates.push(make::where_pred(pred_target, new_bounds).clone_for_update()); + } + } + + for new_predicate in new_predicates { + where_clause.add_predicate(new_predicate); + } + } + + { + // # Add `Unsize<__S>` bound to `#[pointee]` at the generic parameter location + // + // Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it. + where_clause.add_predicate( + make::where_pred( + make::ty_path(make::path_from_segments( + [make::path_segment(make::name_ref(&pointee_param_name.text()))], + false, + )), + [make::type_bound(make::ty_path(make::path_from_segments( + [ + make::path_segment(make::name_ref("core")), + make::path_segment(make::name_ref("marker")), + make::generic_ty_path_segment( + make::name_ref("Unsize"), + [make::type_arg(make::ty_path(make::path_from_segments( + [make::path_segment(make::name_ref(ADDED_PARAM))], + false, + ))) + .into()], + ), + ], + true, + )))], + ) + .clone_for_update(), + ); + } + + let self_for_traits = { + // Replace the `#[pointee]` with `__S`. + let mut type_param_idx = 0; + let self_params_for_traits = strukt + .generic_param_list() + .into_iter() + .flat_map(|params| params.generic_params()) + .filter_map(|param| { + Some(match param { + ast::GenericParam::ConstParam(param) => { + ast::GenericArg::ConstArg(make::expr_const_value(¶m.name()?.text())) + } + ast::GenericParam::LifetimeParam(param) => { + make::lifetime_arg(param.lifetime()?).into() + } + ast::GenericParam::TypeParam(param) => { + let name = if pointee_param_idx == type_param_idx { + make::name_ref(ADDED_PARAM) + } else { + make::name_ref(¶m.name()?.text()) + }; + type_param_idx += 1; + make::type_arg(make::ty_path(make::path_from_segments( + [make::path_segment(name)], + false, + ))) + .into() + } + }) + }); + let self_for_traits = make::path_from_segments( + [make::generic_ty_path_segment( + make::name_ref(&struct_name.text()), + self_params_for_traits, + )], + false, + ) + .clone_for_update(); + self_for_traits + }; + + let mut span_map = span::SpanMap::empty(); + // One span for them all. + span_map.push(adt.syntax().text_range().end(), span); + + let self_for_traits = syntax_bridge::syntax_node_to_token_tree( + self_for_traits.syntax(), + &span_map, + span, + DocCommentDesugarMode::ProcMacro, + ); + let info = match parse_adt_from_syntax(&adt, &span_map, span) { + Ok(it) => it, + Err(err) => { + return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err) + } + }; + + let self_for_traits2 = self_for_traits.clone(); + let krate = dollar_crate(span); + let krate2 = krate.clone(); + let dispatch_from_dyn = expand_simple_derive_with_parsed( + span, + info.clone(), + quote! {span => #krate2::ops::DispatchFromDyn<#self_for_traits2> }, + |_adt| quote! {span => }, + false, + quote! {span => __S }, + ); + let coerce_unsized = expand_simple_derive_with_parsed( + span, + info, + quote! {span => #krate::ops::CoerceUnsized<#self_for_traits> }, + |_adt| quote! {span => }, + false, + quote! {span => __S }, + ); + return ExpandResult::ok(quote! {span => #dispatch_from_dyn #coerce_unsized }); + + fn is_maybe_sized_bound(bound: ast::TypeBound) -> bool { + if bound.question_mark_token().is_none() { + return false; + } + let Some(ast::Type::PathType(ty)) = bound.ty() else { + return false; + }; + let Some(path) = ty.path() else { + return false; + }; + return segments_eq(&path, &["Sized"]) + || segments_eq(&path, &["core", "marker", "Sized"]) + || segments_eq(&path, &["std", "marker", "Sized"]); + + fn segments_eq(path: &ast::Path, expected: &[&str]) -> bool { + path.segments().zip_longest(expected.iter().copied()).all(|value| { + value.both().is_some_and(|(segment, expected)| { + segment.name_ref().is_some_and(|name| name.text() == expected) + }) + }) + } + } + + /// Returns true if any substitution was performed. + fn substitute_type_in_bound(ty: ast::Type, param_name: &str, replacement: &str) -> bool { + return match ty { + ast::Type::ArrayType(ty) => { + ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement)) + } + ast::Type::DynTraitType(ty) => go_bounds(ty.type_bound_list(), param_name, replacement), + ast::Type::FnPtrType(ty) => any_long( + ty.param_list() + .into_iter() + .flat_map(|params| params.params().filter_map(|param| param.ty())) + .chain(ty.ret_type().and_then(|it| it.ty())), + |ty| substitute_type_in_bound(ty, param_name, replacement), + ), + ast::Type::ForType(ty) => { + ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement)) + } + ast::Type::ImplTraitType(ty) => { + go_bounds(ty.type_bound_list(), param_name, replacement) + } + ast::Type::ParenType(ty) => { + ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement)) + } + ast::Type::PathType(ty) => ty.path().is_some_and(|path| { + if path.as_single_name_ref().is_some_and(|name| name.text() == param_name) { + ted::replace( + path.syntax(), + make::path_from_segments( + [make::path_segment(make::name_ref(replacement))], + false, + ) + .clone_for_update() + .syntax(), + ); + return true; + } + + any_long( + path.segments() + .filter_map(|segment| segment.generic_arg_list()) + .flat_map(|it| it.generic_args()) + .filter_map(|generic_arg| match generic_arg { + ast::GenericArg::TypeArg(ty) => ty.ty(), + _ => None, + }), + |ty| substitute_type_in_bound(ty, param_name, replacement), + ) + }), + ast::Type::PtrType(ty) => { + ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement)) + } + ast::Type::RefType(ty) => { + ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement)) + } + ast::Type::SliceType(ty) => { + ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement)) + } + ast::Type::TupleType(ty) => { + any_long(ty.fields(), |ty| substitute_type_in_bound(ty, param_name, replacement)) + } + ast::Type::InferType(_) | ast::Type::MacroType(_) | ast::Type::NeverType(_) => false, + }; + + fn go_bounds( + bounds: Option, + param_name: &str, + replacement: &str, + ) -> bool { + bounds.is_some_and(|bounds| { + any_long(bounds.bounds(), |bound| { + bound + .ty() + .is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement)) + }) + }) + } + + /// Like [`Iterator::any()`], but not short-circuiting. + fn any_long bool>(iter: I, mut f: F) -> bool { + let mut result = false; + iter.for_each(|item| result |= f(item)); + result + } + } +} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs index 1dd376ac3fd53..5101d8fa0a9e3 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs @@ -78,7 +78,7 @@ pub(crate) fn move_bounds_to_where_clause( fn build_predicate(param: ast::TypeParam) -> Option { let path = make::ext::ident_path(¶m.name()?.syntax().to_string()); - let predicate = make::where_pred(path, param.type_bound_list()?.bounds()); + let predicate = make::where_pred(make::ty_path(path), param.type_bound_list()?.bounds()); Some(predicate.clone_for_update()) } diff --git a/src/tools/rust-analyzer/crates/ide/src/expand_macro.rs b/src/tools/rust-analyzer/crates/ide/src/expand_macro.rs index f642db6a71ef4..e028c5ff0cb47 100644 --- a/src/tools/rust-analyzer/crates/ide/src/expand_macro.rs +++ b/src/tools/rust-analyzer/crates/ide/src/expand_macro.rs @@ -574,7 +574,7 @@ struct Foo {} "#, expect![[r#" Clone - impl < >core::clone::Clone for Foo< >where { + impl <>core::clone::Clone for Foo< >where { fn clone(&self) -> Self { match self { Foo{} @@ -600,7 +600,7 @@ struct Foo {} "#, expect![[r#" Copy - impl < >core::marker::Copy for Foo< >where{}"#]], + impl <>core::marker::Copy for Foo< >where{}"#]], ); } @@ -615,7 +615,7 @@ struct Foo {} "#, expect![[r#" Copy - impl < >core::marker::Copy for Foo< >where{}"#]], + impl <>core::marker::Copy for Foo< >where{}"#]], ); check( r#" @@ -626,7 +626,7 @@ struct Foo {} "#, expect![[r#" Clone - impl < >core::clone::Clone for Foo< >where { + impl <>core::clone::Clone for Foo< >where { fn clone(&self) -> Self { match self { Foo{} diff --git a/src/tools/rust-analyzer/crates/intern/src/symbol/symbols.rs b/src/tools/rust-analyzer/crates/intern/src/symbol/symbols.rs index c15751e7c680f..66b8900109c2b 100644 --- a/src/tools/rust-analyzer/crates/intern/src/symbol/symbols.rs +++ b/src/tools/rust-analyzer/crates/intern/src/symbol/symbols.rs @@ -361,6 +361,7 @@ define_symbols! { partial_ord, PartialEq, PartialOrd, + CoercePointee, path, Pending, phantom_data, diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs index 76b39c3b73f30..f77ca7ff068a5 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs @@ -336,6 +336,24 @@ pub fn path_segment(name_ref: ast::NameRef) -> ast::PathSegment { ast_from_text(&format!("type __ = {name_ref};")) } +/// Type and expressions/patterns path differ in whether they require `::` before generic arguments. +/// Type paths allow them but they are often omitted, while expression/pattern paths require them. +pub fn generic_ty_path_segment( + name_ref: ast::NameRef, + generic_args: impl IntoIterator, +) -> ast::PathSegment { + let mut generic_args = generic_args.into_iter(); + let first_generic_arg = generic_args.next(); + quote! { + PathSegment { + #name_ref + GenericArgList { + [<] #first_generic_arg #([,] " " #generic_args)* [>] + } + } + } +} + pub fn path_segment_ty(type_ref: ast::Type, trait_ref: Option) -> ast::PathSegment { let text = match trait_ref { Some(trait_ref) => format!("fn f(x: <{type_ref} as {trait_ref}>) {{}}"), @@ -814,7 +832,7 @@ pub fn match_arm_list(arms: impl IntoIterator) -> ast::Mat } pub fn where_pred( - path: ast::Path, + path: ast::Type, bounds: impl IntoIterator, ) -> ast::WherePred { let bounds = bounds.into_iter().join(" + "); diff --git a/src/tools/rust-analyzer/crates/test-utils/src/minicore.rs b/src/tools/rust-analyzer/crates/test-utils/src/minicore.rs index 99dfabe174eeb..4a2346193b491 100644 --- a/src/tools/rust-analyzer/crates/test-utils/src/minicore.rs +++ b/src/tools/rust-analyzer/crates/test-utils/src/minicore.rs @@ -17,6 +17,7 @@ //! builtin_impls: //! cell: copy, drop //! clone: sized +//! coerce_pointee: derive, sized, unsize, coerce_unsized, dispatch_from_dyn //! coerce_unsized: unsize //! concat: //! copy: clone @@ -157,6 +158,14 @@ pub mod marker { type Discriminant; } // endregion:discriminant + + // region:coerce_pointee + #[rustc_builtin_macro(CoercePointee, attributes(pointee))] + #[allow_internal_unstable(dispatch_from_dyn, coerce_unsized, unsize)] + pub macro CoercePointee($item:item) { + /* compiler built-in */ + } + // endregion:coerce_pointee } // region:default