Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow using named consts in pattern types #136284

Merged
merged 2 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions compiler/rustc_ast_lowering/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,12 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
});
}

fn visit_pattern_type_pattern(&mut self, p: &'hir hir::Pat<'hir>) {
self.visit_pat(p)
fn visit_pattern_type_pattern(&mut self, pat: &'hir hir::TyPat<'hir>) {
self.insert(pat.span, pat.hir_id, Node::TyPat(pat));

self.with_parent(pat.hir_id, |this| {
intravisit::walk_ty_pat(this, pat);
});
}

fn visit_precise_capturing_arg(
Expand Down
4 changes: 3 additions & 1 deletion compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
}
}
}
TyKind::Pat(ty, pat) => hir::TyKind::Pat(self.lower_ty(ty, itctx), self.lower_pat(pat)),
TyKind::Pat(ty, pat) => {
hir::TyKind::Pat(self.lower_ty(ty, itctx), self.lower_ty_pat(pat))
}
TyKind::MacCall(_) => {
span_bug!(t.span, "`TyKind::MacCall` should have been expanded by now")
}
Expand Down
80 changes: 78 additions & 2 deletions compiler/rustc_ast_lowering/src/pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use rustc_ast::ptr::P;
use rustc_ast::*;
use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_hir as hir;
use rustc_hir::def::Res;
use rustc_hir::def::{DefKind, Res};
use rustc_middle::span_bug;
use rustc_span::source_map::{Spanned, respan};
use rustc_span::{Ident, Span};
use rustc_span::{Ident, Span, kw};

use super::errors::{
ArbitraryExpressionInPattern, ExtraDoubleDot, MisplacedDoubleDot, SubTupleBinding,
Expand Down Expand Up @@ -429,4 +429,80 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
};
self.arena.alloc(hir::PatExpr { hir_id: self.lower_node_id(expr.id), span, kind })
}

pub(crate) fn lower_ty_pat(&mut self, pattern: &Pat) -> &'hir hir::TyPat<'hir> {
self.arena.alloc(self.lower_ty_pat_mut(pattern))
}

fn lower_ty_pat_mut(&mut self, mut pattern: &Pat) -> hir::TyPat<'hir> {
// loop here to avoid recursion
let pat_hir_id = self.lower_node_id(pattern.id);
let node = loop {
match &pattern.kind {
PatKind::Range(e1, e2, Spanned { node: end, .. }) => {
let mut lower_expr = |e: &Expr| -> &_ {
let kind = if let ExprKind::Path(qself, path) = &e.kind {
hir::ConstArgKind::Path(self.lower_qpath(
e.id,
qself,
path,
ParamMode::Optional,
AllowReturnTypeNotation::No,
ImplTraitContext::Disallowed(ImplTraitPosition::Path),
None,
))
} else {
let node_id = self.next_node_id();
let def_id = self.create_def(
self.current_hir_id_owner.def_id,
node_id,
kw::Empty,
DefKind::AnonConst,
e.span,
);
let hir_id = self.lower_node_id(node_id);
let ac = self.arena.alloc(hir::AnonConst {
def_id,
hir_id,
body: self.lower_const_body(pattern.span, Some(e)),
span: self.lower_span(pattern.span),
});
hir::ConstArgKind::Anon(ac)
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is pretty much mgce? 🤔 as in, this also adds consts in the type system.

Can you delegate this to lower_const_arg or whatever instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but at that point I should modify the ast, and I'd prefer to do that in other PRs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a FIXME then 😁

self.arena.alloc(hir::ConstArg { hir_id: self.next_id(), kind })
};
break hir::TyPatKind::Range(
e1.as_deref().map(|e| lower_expr(e)),
e2.as_deref().map(|e| lower_expr(e)),
self.lower_range_end(end, e2.is_some()),
);
}
// return inner to be processed in next loop
PatKind::Paren(inner) => pattern = inner,
PatKind::MacCall(_) => panic!("{:?} shouldn't exist here", pattern.span),
PatKind::Err(guar) => break hir::TyPatKind::Err(*guar),
PatKind::Deref(..)
| PatKind::Box(..)
| PatKind::Or(..)
| PatKind::Struct(..)
| PatKind::TupleStruct(..)
| PatKind::Tuple(..)
| PatKind::Ref(..)
| PatKind::Expr(..)
| PatKind::Guard(..)
| PatKind::Slice(_)
| PatKind::Ident(..)
| PatKind::Path(..)
| PatKind::Wild
| PatKind::Never
| PatKind::Rest => {
break hir::TyPatKind::Err(
self.dcx().span_err(pattern.span, "pattern not supported in pattern types"),
);
}
}
};

hir::TyPat { hir_id: pat_hir_id, kind: node, span: self.lower_span(pattern.span) }
}
}
4 changes: 4 additions & 0 deletions compiler/rustc_driver_impl/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ impl<'tcx> pprust_hir::PpAnn for HirIdentifiedAnn<'tcx> {
s.s.space();
s.synth_comment(format!("pat hir_id: {}", pat.hir_id));
}
pprust_hir::AnnNode::TyPat(pat) => {
s.s.space();
s.synth_comment(format!("ty pat hir_id: {}", pat.hir_id));
}
pprust_hir::AnnNode::Arm(arm) => {
s.s.space();
s.synth_comment(format!("arm hir_id: {}", arm.hir_id));
Expand Down
21 changes: 20 additions & 1 deletion compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,14 @@ impl<'hir> Block<'hir> {
}
}

#[derive(Debug, Clone, Copy, HashStable_Generic)]
pub struct TyPat<'hir> {
#[stable_hasher(ignore)]
pub hir_id: HirId,
pub kind: TyPatKind<'hir>,
pub span: Span,
}

#[derive(Debug, Clone, Copy, HashStable_Generic)]
pub struct Pat<'hir> {
#[stable_hasher(ignore)]
Expand Down Expand Up @@ -1591,6 +1599,15 @@ pub enum PatExprKind<'hir> {
Path(QPath<'hir>),
}

#[derive(Debug, Clone, Copy, HashStable_Generic)]
pub enum TyPatKind<'hir> {
/// A range pattern (e.g., `1..=2` or `1..2`).
Range(Option<&'hir ConstArg<'hir>>, Option<&'hir ConstArg<'hir>>, RangeEnd),

/// A placeholder for a pattern that wasn't well formed in some way.
Err(ErrorGuaranteed),
}

#[derive(Debug, Clone, Copy, HashStable_Generic)]
pub enum PatKind<'hir> {
/// Represents a wildcard pattern (i.e., `_`).
Expand Down Expand Up @@ -3345,7 +3362,7 @@ pub enum TyKind<'hir, Unambig = ()> {
/// Placeholder for a type that has failed to be defined.
Err(rustc_span::ErrorGuaranteed),
/// Pattern types (`pattern_type!(u32 is 1..)`)
Pat(&'hir Ty<'hir>, &'hir Pat<'hir>),
Pat(&'hir Ty<'hir>, &'hir TyPat<'hir>),
/// `TyKind::Infer` means the type should be inferred instead of it having been
/// specified. This can appear anywhere in a type.
///
Expand Down Expand Up @@ -4331,6 +4348,7 @@ pub enum Node<'hir> {
AssocItemConstraint(&'hir AssocItemConstraint<'hir>),
TraitRef(&'hir TraitRef<'hir>),
OpaqueTy(&'hir OpaqueTy<'hir>),
TyPat(&'hir TyPat<'hir>),
Pat(&'hir Pat<'hir>),
PatField(&'hir PatField<'hir>),
/// Needed as its own node with its own HirId for tracking
Expand Down Expand Up @@ -4393,6 +4411,7 @@ impl<'hir> Node<'hir> {
| Node::Block(..)
| Node::Ctor(..)
| Node::Pat(..)
| Node::TyPat(..)
| Node::PatExpr(..)
| Node::Arm(..)
| Node::LetStmt(..)
Expand Down
18 changes: 14 additions & 4 deletions compiler/rustc_hir/src/intravisit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,8 @@ pub trait Visitor<'v>: Sized {
fn visit_expr_field(&mut self, field: &'v ExprField<'v>) -> Self::Result {
walk_expr_field(self, field)
}
fn visit_pattern_type_pattern(&mut self, _p: &'v Pat<'v>) {
// Do nothing. Only a few visitors need to know the details of the pattern type,
// and they opt into it. All other visitors will just choke on our fake patterns
// because they aren't in a body.
fn visit_pattern_type_pattern(&mut self, p: &'v TyPat<'v>) -> Self::Result {
walk_ty_pat(self, p)
}
fn visit_generic_param(&mut self, p: &'v GenericParam<'v>) -> Self::Result {
walk_generic_param(self, p)
Expand Down Expand Up @@ -702,6 +700,18 @@ pub fn walk_arm<'v, V: Visitor<'v>>(visitor: &mut V, arm: &'v Arm<'v>) -> V::Res
visitor.visit_expr(arm.body)
}

pub fn walk_ty_pat<'v, V: Visitor<'v>>(visitor: &mut V, pattern: &'v TyPat<'v>) -> V::Result {
try_visit!(visitor.visit_id(pattern.hir_id));
match pattern.kind {
TyPatKind::Range(lower_bound, upper_bound, _) => {
visit_opt!(visitor, visit_const_arg_unambig, lower_bound);
visit_opt!(visitor, visit_const_arg_unambig, upper_bound);
}
TyPatKind::Err(_) => (),
}
V::Result::output()
}

pub fn walk_pat<'v, V: Visitor<'v>>(visitor: &mut V, pattern: &'v Pat<'v>) -> V::Result {
try_visit!(visitor.visit_id(pattern.hir_id));
match pattern.kind {
Expand Down
3 changes: 0 additions & 3 deletions compiler/rustc_hir_analysis/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,6 @@ hir_analysis_paren_sugar_attribute = the `#[rustc_paren_sugar]` attribute is a t
hir_analysis_parenthesized_fn_trait_expansion =
parenthesized trait syntax expands to `{$expanded_type}`
hir_analysis_pattern_type_non_const_range = range patterns must have constant range start and end
hir_analysis_pattern_type_wild_pat = wildcard patterns are not permitted for pattern types
.label = this type is the same as the inner type without a pattern
hir_analysis_placeholder_not_allowed_item_signatures = the placeholder `_` is not allowed within types on item signatures for {$kind}
.label = not allowed in type signatures
hir_analysis_precise_capture_self_alias = `Self` can't be captured in `use<...>` precise captures list, since it is an alias
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_hir_analysis/src/collect/resolve_bound_vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -831,8 +831,8 @@ impl<'a, 'tcx> Visitor<'tcx> for BoundVarContext<'a, 'tcx> {
}

#[instrument(level = "debug", skip(self))]
fn visit_pattern_type_pattern(&mut self, p: &'tcx hir::Pat<'tcx>) {
intravisit::walk_pat(self, p)
fn visit_pattern_type_pattern(&mut self, p: &'tcx hir::TyPat<'tcx>) {
intravisit::walk_ty_pat(self, p)
}

#[instrument(level = "debug", skip(self))]
Expand Down
21 changes: 16 additions & 5 deletions compiler/rustc_hir_analysis/src/collect/type_of.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ use crate::hir_ty_lowering::HirTyLowerer;

mod opaque;

fn anon_const_type_of<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> Ty<'tcx> {
fn anon_const_type_of<'tcx>(icx: &ItemCtxt<'tcx>, def_id: LocalDefId) -> Ty<'tcx> {
use hir::*;
use rustc_middle::ty::Ty;
let tcx = icx.tcx;
let hir_id = tcx.local_def_id_to_hir_id(def_id);

let node = tcx.hir_node(hir_id);
Expand Down Expand Up @@ -54,7 +55,7 @@ fn anon_const_type_of<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> Ty<'tcx> {
hir_id: arg_hir_id,
kind: ConstArgKind::Anon(&AnonConst { hir_id: anon_hir_id, .. }),
..
}) if anon_hir_id == hir_id => const_arg_anon_type_of(tcx, arg_hir_id, span),
}) if anon_hir_id == hir_id => const_arg_anon_type_of(icx, arg_hir_id, span),

// Anon consts outside the type system.
Node::Expr(&Expr { kind: ExprKind::InlineAsm(asm), .. })
Expand Down Expand Up @@ -138,18 +139,28 @@ fn anon_const_type_of<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> Ty<'tcx> {
}
}

fn const_arg_anon_type_of<'tcx>(tcx: TyCtxt<'tcx>, arg_hir_id: HirId, span: Span) -> Ty<'tcx> {
fn const_arg_anon_type_of<'tcx>(icx: &ItemCtxt<'tcx>, arg_hir_id: HirId, span: Span) -> Ty<'tcx> {
use hir::*;
use rustc_middle::ty::Ty;

let tcx = icx.tcx;

match tcx.parent_hir_node(arg_hir_id) {
// Array length const arguments do not have `type_of` fed as there is never a corresponding
// generic parameter definition.
Node::Ty(&hir::Ty { kind: TyKind::Array(_, ref constant), .. })
| Node::Expr(&Expr { kind: ExprKind::Repeat(_, ref constant), .. })
if constant.hir_id == arg_hir_id =>
{
return tcx.types.usize;
tcx.types.usize
}

Node::TyPat(pat) => {
let hir::TyKind::Pat(ty, p) = tcx.parent_hir_node(pat.hir_id).expect_ty().kind else {
bug!()
};
assert_eq!(p.hir_id, pat.hir_id);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we use the same HirId for the TyKind::Pat and its pattern?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't. This assertion is just x == child(parent(x)) and mostly a leftover from getting the PR to this state

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

icx.lower_ty(ty)
}

// This is not a `bug!` as const arguments in path segments that did not resolve to anything
Expand Down Expand Up @@ -344,7 +355,7 @@ pub(super) fn type_of(tcx: TyCtxt<'_>, def_id: LocalDefId) -> ty::EarlyBinder<'_
tcx.typeck(def_id).node_type(hir_id)
}

Node::AnonConst(_) => anon_const_type_of(tcx, def_id),
Node::AnonConst(_) => anon_const_type_of(&icx, def_id),

Node::ConstBlock(_) => {
let args = ty::GenericArgs::identity_for_item(tcx, def_id.to_def_id());
Expand Down
7 changes: 0 additions & 7 deletions compiler/rustc_hir_analysis/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1605,13 +1605,6 @@ pub(crate) struct OpaqueCapturesHigherRankedLifetime {
pub bad_place: &'static str,
}

#[derive(Diagnostic)]
#[diag(hir_analysis_pattern_type_non_const_range)]
pub(crate) struct NonConstRange {
#[primary_span]
pub span: Span,
}

#[derive(Subdiagnostic)]
pub(crate) enum InvalidReceiverTyHint {
#[note(hir_analysis_invalid_receiver_ty_help_weak_note)]
Expand Down
7 changes: 0 additions & 7 deletions compiler/rustc_hir_analysis/src/errors/pattern_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,6 @@ use rustc_macros::Diagnostic;
use rustc_middle::ty::Ty;
use rustc_span::Span;

#[derive(Diagnostic)]
#[diag(hir_analysis_pattern_type_wild_pat)]
pub(crate) struct WildPatTy {
#[primary_span]
pub span: Span,
}

#[derive(Diagnostic)]
#[diag(hir_analysis_invalid_base_type)]
pub(crate) struct InvalidBaseType<'tcx> {
Expand Down
Loading