Skip to content

Commit

Permalink
Use the .Self symbolic binding as 1st arg of WhereExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
josh11b committed Sep 26, 2024
1 parent 044ad96 commit c55078f
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 23 deletions.
3 changes: 2 additions & 1 deletion toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1404,7 +1404,8 @@ auto TryEvalInstInContext(EvalContext& eval_context, SemIR::InstId inst_id,
case CARBON_KIND(SemIR::WhereExpr typed_inst): {
// TODO: This currently ignores the requirements and just produces the
// left-hand type argument to the `where`.
return eval_context.GetConstantValue(typed_inst.lhs_id);
return eval_context.GetConstantValue(
eval_context.insts().Get(typed_inst.period_self_id).type_id());
}

// `not true` -> `false`, `not false` -> `true`.
Expand Down
11 changes: 6 additions & 5 deletions toolchain/check/handle_where.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ auto HandleParseNode(Context& context, Parse::WhereOperandId node_id) -> bool {
// TODO: Validate that `self_type_id` represents a facet type. Only facet
// types may have `where` restrictions.

// FIXME: Should we save the self_id here instead?
context.node_stack().Push(node_id, self_type_id);

// Introduce a name scope so that we can remove the `.Self` entry we are
// adding to name lookup at the end of the `where` expression.
context.scope_stack().Push();
Expand All @@ -45,6 +42,10 @@ auto HandleParseNode(Context& context, Parse::WhereOperandId node_id) -> bool {
// Shouldn't have any names in newly created scope.
CARBON_CHECK(!existing.is_valid());

// Save the `.Self` symbolic binding on the node stack. It will become the
// first argument to the `WhereExpr` instruction.
context.node_stack().Push(node_id, inst_id);

// Going to put each requirement on `args_type_info_stack`, so we can have an
// inst block with the varying number of requirements but keeping other
// instructions on the current inst block from the `inst_block_stack()`.
Expand Down Expand Up @@ -104,12 +105,12 @@ auto HandleParseNode(Context& context, Parse::WhereExprId node_id) -> bool {
// Remove `PeriodSelf` from name lookup, undoing the `Push` done for the
// `WhereOperand`.
context.scope_stack().Pop();
SemIR::TypeId lhs_type_id =
SemIR::InstId period_self_id =
context.node_stack().Pop<Parse::NodeKind::WhereOperand>();
SemIR::InstBlockId requirements_id = context.args_type_info_stack().Pop();
context.AddInstAndPush<SemIR::WhereExpr>(
node_id, {.type_id = SemIR::TypeId::TypeType,
.lhs_id = lhs_type_id,
.period_self_id = period_self_id,
.requirements_id = requirements_id});
return true;
}
Expand Down
3 changes: 1 addition & 2 deletions toolchain/check/node_stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ class NodeStack {
case Parse::NodeKind::ShortCircuitOperandOr:
case Parse::NodeKind::StructField:
case Parse::NodeKind::StructTypeField:
case Parse::NodeKind::WhereOperand:
return Id::KindFor<SemIR::InstId>();
case Parse::NodeKind::IfCondition:
case Parse::NodeKind::IfExprIf:
Expand All @@ -449,8 +450,6 @@ class NodeStack {
case Parse::NodeKind::DefaultLibrary:
case Parse::NodeKind::LibraryName:
return Id::KindFor<SemIR::LibraryNameId>();
case Parse::NodeKind::WhereOperand:
return Id::KindFor<SemIR::TypeId>();
case Parse::NodeKind::ArrayExprSemi:
case Parse::NodeKind::BuiltinName:
case Parse::NodeKind::ClassIntroducer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl bool as I where .T = bool {}
// CHECK:STDOUT: %.Self.ref: %.1 = name_ref .Self, %.Self [symbolic = constants.%.Self]
// CHECK:STDOUT: %T.ref: %.2 = name_ref T, @I.%.loc11 [template = constants.%.3]
// CHECK:STDOUT: %bool.make_type.loc16_27: init type = call constants.%Bool() [template = bool]
// CHECK:STDOUT: %.loc16_16: type = where_expr %.1 [template = constants.%.1] {
// CHECK:STDOUT: %.loc16_16: type = where_expr %.Self [template = constants.%.1] {
// CHECK:STDOUT: requirement_rewrite %T.ref, %bool.make_type.loc16_27
// CHECK:STDOUT: }
// CHECK:STDOUT: }
Expand Down
12 changes: 6 additions & 6 deletions toolchain/check/testdata/where_expr/constraints.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ fn NotEmptyStruct() {
// CHECK:STDOUT: %.Self.ref: %.2 = name_ref .Self, %.Self [symbolic = constants.%.Self.1]
// CHECK:STDOUT: %Member.ref: %.3 = name_ref Member, @I.%.loc7 [template = constants.%.4]
// CHECK:STDOUT: %.loc11_33: %.8 = struct_literal ()
// CHECK:STDOUT: %.loc11_16: type = where_expr %.2 [template = constants.%.2] {
// CHECK:STDOUT: %.loc11_16: type = where_expr %.Self [template = constants.%.2] {
// CHECK:STDOUT: requirement_rewrite %Member.ref, %.loc11_33
// CHECK:STDOUT: }
// CHECK:STDOUT: %T.param: %.2 = param T, runtime_param<invalid>
Expand All @@ -155,7 +155,7 @@ fn NotEmptyStruct() {
// CHECK:STDOUT: %.Self: %.2 = bind_symbolic_name .Self 0 [symbolic = constants.%.Self.1]
// CHECK:STDOUT: %.Self.ref: %.2 = name_ref .Self, %.Self [symbolic = constants.%.Self.1]
// CHECK:STDOUT: %.loc13_37: %.5 = tuple_literal ()
// CHECK:STDOUT: %.loc13_21: type = where_expr %.2 [template = constants.%.2] {
// CHECK:STDOUT: %.loc13_21: type = where_expr %.Self [template = constants.%.2] {
// CHECK:STDOUT: requirement_equivalent %.Self.ref, %.loc13_37
// CHECK:STDOUT: }
// CHECK:STDOUT: %U.param: %.2 = param U, runtime_param<invalid>
Expand All @@ -168,7 +168,7 @@ fn NotEmptyStruct() {
// CHECK:STDOUT: %.Self: %.1 = bind_symbolic_name .Self 0 [symbolic = constants.%.Self.2]
// CHECK:STDOUT: %.Self.ref: %.1 = name_ref .Self, %.Self [symbolic = constants.%.Self.2]
// CHECK:STDOUT: %I.ref: type = name_ref I, file.%I.decl [template = constants.%.2]
// CHECK:STDOUT: %.loc15: type = where_expr %.1 [template = constants.%.1] {
// CHECK:STDOUT: %.loc15: type = where_expr %.Self [template = constants.%.1] {
// CHECK:STDOUT: requirement_impls %.Self.ref, %I.ref
// CHECK:STDOUT: }
// CHECK:STDOUT: %V.param: %.1 = param V, runtime_param<invalid>
Expand All @@ -186,7 +186,7 @@ fn NotEmptyStruct() {
// CHECK:STDOUT: %Member.ref: %.3 = name_ref Member, @I.%.loc7 [template = constants.%.4]
// CHECK:STDOUT: %.Self.ref.loc17_50: %.2 = name_ref .Self, %.Self [symbolic = constants.%.Self.1]
// CHECK:STDOUT: %Second.ref.loc17_50: %.6 = name_ref Second, @I.%.loc8 [template = constants.%.7]
// CHECK:STDOUT: %.loc17: type = where_expr %.2 [template = constants.%.2] {
// CHECK:STDOUT: %.loc17: type = where_expr %.Self [template = constants.%.2] {
// CHECK:STDOUT: requirement_impls %Second.ref.loc17_20, %I.ref.loc17_34
// CHECK:STDOUT: requirement_rewrite %Member.ref, %Second.ref.loc17_50
// CHECK:STDOUT: }
Expand Down Expand Up @@ -319,7 +319,7 @@ fn NotEmptyStruct() {
// CHECK:STDOUT: %.Self.ref: %.1 = name_ref .Self, %.Self [symbolic = constants.%.Self]
// CHECK:STDOUT: %Member.ref: %.3 = name_ref Member, imports.%import_ref.8 [template = constants.%.4]
// CHECK:STDOUT: %.loc8_39: i32 = int_literal 2 [template = constants.%.5]
// CHECK:STDOUT: %.loc8_23: type = where_expr %.1 [template = constants.%.1] {
// CHECK:STDOUT: %.loc8_23: type = where_expr %.Self [template = constants.%.1] {
// CHECK:STDOUT: requirement_rewrite %Member.ref, %.loc8_39
// CHECK:STDOUT: }
// CHECK:STDOUT: %X.param: %.1 = param X, runtime_param<invalid>
Expand Down Expand Up @@ -440,7 +440,7 @@ fn NotEmptyStruct() {
// CHECK:STDOUT: %.Self: %.3 = bind_symbolic_name .Self 0 [symbolic = constants.%.Self]
// CHECK:STDOUT: %.Self.ref: %.3 = name_ref .Self, %.Self [symbolic = constants.%.Self]
// CHECK:STDOUT: %.loc26_38: %.1 = struct_literal ()
// CHECK:STDOUT: %.loc26_22: type = where_expr %.3 [template = constants.%.3] {
// CHECK:STDOUT: %.loc26_22: type = where_expr %.Self [template = constants.%.3] {
// CHECK:STDOUT: requirement_equivalent %.Self.ref, %.loc26_38
// CHECK:STDOUT: }
// CHECK:STDOUT: %Y.param: %.3 = param Y, runtime_param<invalid>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class D {
// CHECK:STDOUT: %.Self: %.1 = bind_symbolic_name .Self 0 [symbolic = constants.%.Self.1]
// CHECK:STDOUT: %.Self.ref: %.1 = name_ref .Self, %.Self [symbolic = constants.%.Self.1]
// CHECK:STDOUT: %.loc8_37: %.4 = tuple_literal ()
// CHECK:STDOUT: %.loc8_21: type = where_expr %.1 [template = constants.%.1] {
// CHECK:STDOUT: %.loc8_21: type = where_expr %.Self [template = constants.%.1] {
// CHECK:STDOUT: requirement_equivalent %.Self.ref, %.loc8_37
// CHECK:STDOUT: }
// CHECK:STDOUT: %T.param: %.1 = param T, runtime_param<invalid>
Expand All @@ -137,7 +137,7 @@ class D {
// CHECK:STDOUT: %.Self.ref: %.1 = name_ref .Self, %.Self [symbolic = constants.%.Self.1]
// CHECK:STDOUT: %Member.ref: %.2 = name_ref Member, @I.%.loc5 [template = constants.%.3]
// CHECK:STDOUT: %.loc10_40: %.5 = struct_literal ()
// CHECK:STDOUT: %.loc10_23: type = where_expr %.1 [template = constants.%.1] {
// CHECK:STDOUT: %.loc10_23: type = where_expr %.Self [template = constants.%.1] {
// CHECK:STDOUT: requirement_rewrite %Member.ref, %.loc10_40
// CHECK:STDOUT: }
// CHECK:STDOUT: %U.param: %.1 = param U, runtime_param<invalid>
Expand All @@ -149,7 +149,7 @@ class D {
// CHECK:STDOUT: %.Self: type = bind_symbolic_name .Self 0 [symbolic = constants.%.Self.2]
// CHECK:STDOUT: %.Self.ref: type = name_ref .Self, %.Self [symbolic = constants.%.Self.2]
// CHECK:STDOUT: %I.ref: type = name_ref I, file.%I.decl [template = constants.%.1]
// CHECK:STDOUT: %.loc12: type = where_expr type [template = type] {
// CHECK:STDOUT: %.loc12: type = where_expr %.Self [template = type] {
// CHECK:STDOUT: requirement_impls %.Self.ref, %I.ref
// CHECK:STDOUT: }
// CHECK:STDOUT: %V.param: type = param V, runtime_param<invalid>
Expand Down Expand Up @@ -227,7 +227,7 @@ class D {
// CHECK:STDOUT: %.Self.ref: %.1 = name_ref .Self, %.Self [symbolic = constants.%.Self]
// CHECK:STDOUT: %Mismatch.ref: <error> = name_ref Mismatch, <error> [template = <error>]
// CHECK:STDOUT: %.loc12_44: %.5 = struct_literal ()
// CHECK:STDOUT: %.loc12_25: type = where_expr %.1 [template = constants.%.1] {
// CHECK:STDOUT: %.loc12_25: type = where_expr %.Self [template = constants.%.1] {
// CHECK:STDOUT: requirement_rewrite %Mismatch.ref, %.loc12_44
// CHECK:STDOUT: }
// CHECK:STDOUT: %W.param: %.1 = param W, runtime_param<invalid>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ fn NotGenericF(U: I where .T = {}) {}
// CHECK:STDOUT: %.Self.ref: %.1 = name_ref .Self, %.Self [symbolic = constants.%.Self]
// CHECK:STDOUT: %T.ref: %.2 = name_ref T, @I.%.loc11 [template = constants.%.3]
// CHECK:STDOUT: %.loc14_33: %.5 = struct_literal ()
// CHECK:STDOUT: %.loc14_21: type = where_expr %.1 [template = constants.%.1] {
// CHECK:STDOUT: %.loc14_21: type = where_expr %.Self [template = constants.%.1] {
// CHECK:STDOUT: requirement_rewrite %T.ref, %.loc14_33
// CHECK:STDOUT: }
// CHECK:STDOUT: %U.param: %.1 = param U, runtime_param0
Expand Down
3 changes: 2 additions & 1 deletion toolchain/sem_ir/file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,8 @@ static auto StringifyTypeExprImpl(const SemIR::File& outer_sem_ir,
if (step.index == 0) {
out << "<where restriction on ";
steps.push_back(step.Next());
push_inst_id(sem_ir.types().GetInstId(inst.lhs_id));
TypeId type_id = sem_ir.insts().Get(inst.period_self_id).type_id();
push_inst_id(sem_ir.types().GetInstId(type_id));
// TODO: also output restrictions from the inst block
// inst.requirements_id
} else {
Expand Down
2 changes: 1 addition & 1 deletion toolchain/sem_ir/formatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ class FormatterImpl {
}

auto FormatInstRHS(WhereExpr inst) -> void {
FormatArgs(inst.lhs_id);
FormatArgs(inst.period_self_id);
FormatTrailingBlock(inst.requirements_id);
}

Expand Down
4 changes: 3 additions & 1 deletion toolchain/sem_ir/typed_insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -1144,7 +1144,9 @@ struct WhereExpr {
.constant_kind = InstConstantKind::Conditional});

TypeId type_id;
TypeId lhs_id;
// This is the `.Self` symbolic binding. Its type matches the left type
// argument of the `where`.
InstId period_self_id;
InstBlockId requirements_id;
};

Expand Down

0 comments on commit c55078f

Please sign in to comment.