From 4e21c0c1fc81519506add1e3ba68c514d21c4a0a Mon Sep 17 00:00:00 2001 From: Alina Sbirlea Date: Tue, 4 Mar 2025 10:41:29 -0800 Subject: [PATCH] Basic lowering generic function definitions. (#5016) Resolve the specific type for the callee, to lower the proper specific function called. --- toolchain/lower/file_context.cpp | 21 ++++++++-- toolchain/lower/function_context.cpp | 2 + toolchain/lower/function_context.h | 7 ++++ toolchain/lower/handle_call.cpp | 9 ++-- .../function/generic/call_basic.carbon | 42 +++++++++++++------ toolchain/sem_ir/function.cpp | 7 +++- toolchain/sem_ir/function.h | 4 +- 7 files changed, 71 insertions(+), 21 deletions(-) diff --git a/toolchain/lower/file_context.cpp b/toolchain/lower/file_context.cpp index 5156ec729402c..a085c29ee2cf1 100644 --- a/toolchain/lower/file_context.cpp +++ b/toolchain/lower/file_context.cpp @@ -219,6 +219,15 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id, // TODO: Consider tracking whether the function has been used, and only // lowering it if it's needed. + // TODO nit: add is_symbolic() to type_id to forward to + // type_id.AsConstantId().is_symbolic(). Update call below too. + auto get_llvm_type = [&](SemIR::TypeId type_id) -> llvm::Type* { + if (!type_id.has_value()) { + return nullptr; + } + return GetType(SemIR::GetTypeInSpecific(sem_ir(), specific_id, type_id)); + }; + const auto return_info = SemIR::ReturnTypeInfo::ForFunction(sem_ir(), function, specific_id); CARBON_CHECK(return_info.is_valid(), "Should not lower invalid functions."); @@ -229,8 +238,7 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id, auto param_patterns = sem_ir().inst_blocks().GetOrEmpty(function.param_patterns_id); - auto* return_type = - return_info.type_id.has_value() ? GetType(return_info.type_id) : nullptr; + auto* return_type = get_llvm_type(return_info.type_id); llvm::SmallVector param_types; // TODO: Consider either storing `param_inst_ids` somewhere so that we can @@ -259,6 +267,10 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id, } auto param_type_id = SemIR::GetTypeInSpecific(sem_ir(), specific_id, param_pattern.type_id); + CARBON_CHECK( + !param_type_id.AsConstantId().is_symbolic(), + "Found symbolic type id after resolution when lowering type {0}.", + param_pattern.type_id); switch (auto value_rep = SemIR::ValueRepr::ForType(sem_ir(), param_type_id); value_rep.kind) { case SemIR::ValueRepr::Unknown: @@ -268,7 +280,8 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id, case SemIR::ValueRepr::Copy: case SemIR::ValueRepr::Custom: case SemIR::ValueRepr::Pointer: - param_types.push_back(GetType(value_rep.type_id)); + auto* param_types_to_add = get_llvm_type(value_rep.type_id); + param_types.push_back(param_types_to_add); param_inst_ids.push_back(param_pattern_id); break; } @@ -349,7 +362,7 @@ auto FileContext::BuildFunctionBody(SemIR::FunctionId function_id, CARBON_DCHECK(!body_block_ids.empty(), "No function body blocks found during lowering."); - FunctionContext function_lowering(*this, llvm_function, + FunctionContext function_lowering(*this, llvm_function, specific_id, BuildDISubprogram(function, llvm_function), vlog_stream_); diff --git a/toolchain/lower/function_context.cpp b/toolchain/lower/function_context.cpp index 2ae307314e806..cb5aa4b7e537a 100644 --- a/toolchain/lower/function_context.cpp +++ b/toolchain/lower/function_context.cpp @@ -12,10 +12,12 @@ namespace Carbon::Lower { FunctionContext::FunctionContext(FileContext& file_context, llvm::Function* function, + SemIR::SpecificId specific_id, llvm::DISubprogram* di_subprogram, llvm::raw_ostream* vlog_stream) : file_context_(&file_context), function_(function), + specific_id_(specific_id), builder_(file_context.llvm_context(), llvm::ConstantFolder(), Inserter(file_context.inst_namer())), di_subprogram_(di_subprogram), diff --git a/toolchain/lower/function_context.h b/toolchain/lower/function_context.h index d0ec940d5e77f..2780fa81cc3b4 100644 --- a/toolchain/lower/function_context.h +++ b/toolchain/lower/function_context.h @@ -19,6 +19,7 @@ namespace Carbon::Lower { class FunctionContext { public: explicit FunctionContext(FileContext& file_context, llvm::Function* function, + SemIR::SpecificId specific_id, llvm::DISubprogram* di_subprogram, llvm::raw_ostream* vlog_stream); @@ -45,6 +46,8 @@ class FunctionContext { // Returns a value for the given instruction. auto GetValue(SemIR::InstId inst_id) -> llvm::Value* { + // TODO: if(specific_id_.has_value()) may need to update inst_id first. + // All builtins are types, with the same empty lowered value. if (SemIR::IsSingletonInstId(inst_id)) { return GetTypeAsValue(); @@ -130,6 +133,7 @@ class FunctionContext { } auto llvm_module() -> llvm::Module& { return file_context_->llvm_module(); } auto llvm_function() -> llvm::Function& { return *function_; } + auto specific_id() -> SemIR::SpecificId { return specific_id_; } auto builder() -> llvm::IRBuilderBase& { return builder_; } auto sem_ir() -> const SemIR::File& { return file_context_->sem_ir(); } @@ -174,6 +178,9 @@ class FunctionContext { // The IR function we're generating. llvm::Function* function_; + // The specific id, if the function is a specific. + SemIR::SpecificId specific_id_; + // Builder for creating code in this function. The insertion point is held at // the location of the current SemIR instruction. llvm::IRBuilder builder_; diff --git a/toolchain/lower/handle_call.cpp b/toolchain/lower/handle_call.cpp index 31b0913f8e40a..c76e6ea4df07d 100644 --- a/toolchain/lower/handle_call.cpp +++ b/toolchain/lower/handle_call.cpp @@ -424,8 +424,8 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id, llvm::ArrayRef arg_ids = context.sem_ir().inst_blocks().Get(inst.args_id); - auto callee_function = - SemIR::GetCalleeFunction(context.sem_ir(), inst.callee_id); + auto callee_function = SemIR::GetCalleeFunction( + context.sem_ir(), inst.callee_id, context.specific_id()); CARBON_CHECK(callee_function.function_id.has_value()); if (auto builtin_kind = context.sem_ir() @@ -442,7 +442,10 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id, std::vector args; - if (SemIR::ReturnTypeInfo::ForType(context.sem_ir(), inst.type_id) + auto inst_type_id = SemIR::GetTypeInSpecific( + context.sem_ir(), callee_function.resolved_specific_id, inst.type_id); + + if (SemIR::ReturnTypeInfo::ForType(context.sem_ir(), inst_type_id) .has_return_slot()) { args.push_back(context.GetValue(arg_ids.back())); arg_ids = arg_ids.drop_back(); diff --git a/toolchain/lower/testdata/function/generic/call_basic.carbon b/toolchain/lower/testdata/function/generic/call_basic.carbon index 5cc998a769aba..21fd9e71cccbb 100644 --- a/toolchain/lower/testdata/function/generic/call_basic.carbon +++ b/toolchain/lower/testdata/function/generic/call_basic.carbon @@ -15,8 +15,8 @@ fn H[T:! type](x: T) { } fn G[T:! type](x: T) -> T { - // TODO: the call below is crashing because proper type resolution to - // use the G specific, not the G generic is not done yet. + H(x); + // TODO: Call crashes, see TODO in FunctionContext::GetValue() // H(T); return x; } @@ -72,17 +72,29 @@ fn M() { // CHECK:STDOUT: // CHECK:STDOUT: define i32 @_CG.Main.b88d1103f417c6d4(i32 %x) !dbg !22 { // CHECK:STDOUT: entry: -// CHECK:STDOUT: ret i32 %x, !dbg !23 +// CHECK:STDOUT: call void @_CH.Main.b88d1103f417c6d4(i32 %x), !dbg !23 +// CHECK:STDOUT: ret i32 %x, !dbg !24 // CHECK:STDOUT: } // CHECK:STDOUT: -// CHECK:STDOUT: define void @_CF.Main.66be507887ceee78(double %x) !dbg !24 { +// CHECK:STDOUT: define void @_CF.Main.66be507887ceee78(double %x) !dbg !25 { // CHECK:STDOUT: entry: -// CHECK:STDOUT: ret void, !dbg !25 +// CHECK:STDOUT: ret void, !dbg !26 // CHECK:STDOUT: } // CHECK:STDOUT: -// CHECK:STDOUT: define double @_CG.Main.66be507887ceee78(double %x) !dbg !26 { +// CHECK:STDOUT: define double @_CG.Main.66be507887ceee78(double %x) !dbg !27 { // CHECK:STDOUT: entry: -// CHECK:STDOUT: ret double %x, !dbg !27 +// CHECK:STDOUT: call void @_CH.Main.66be507887ceee78(double %x), !dbg !28 +// CHECK:STDOUT: ret double %x, !dbg !29 +// CHECK:STDOUT: } +// CHECK:STDOUT: +// CHECK:STDOUT: define void @_CH.Main.b88d1103f417c6d4(i32 %x) !dbg !30 { +// CHECK:STDOUT: entry: +// CHECK:STDOUT: ret void, !dbg !31 +// CHECK:STDOUT: } +// CHECK:STDOUT: +// CHECK:STDOUT: define void @_CH.Main.66be507887ceee78(double %x) !dbg !32 { +// CHECK:STDOUT: entry: +// CHECK:STDOUT: ret void, !dbg !33 // CHECK:STDOUT: } // CHECK:STDOUT: // CHECK:STDOUT: ; uselistorder directives @@ -116,8 +128,14 @@ fn M() { // CHECK:STDOUT: !20 = distinct !DISubprogram(name: "F", linkageName: "_CF.Main.b88d1103f417c6d4", scope: null, file: !3, line: 11, type: !5, spFlags: DISPFlagDefinition, unit: !2) // CHECK:STDOUT: !21 = !DILocation(line: 11, column: 1, scope: !20) // CHECK:STDOUT: !22 = distinct !DISubprogram(name: "G", linkageName: "_CG.Main.b88d1103f417c6d4", scope: null, file: !3, line: 17, type: !5, spFlags: DISPFlagDefinition, unit: !2) -// CHECK:STDOUT: !23 = !DILocation(line: 21, column: 3, scope: !22) -// CHECK:STDOUT: !24 = distinct !DISubprogram(name: "F", linkageName: "_CF.Main.66be507887ceee78", scope: null, file: !3, line: 11, type: !5, spFlags: DISPFlagDefinition, unit: !2) -// CHECK:STDOUT: !25 = !DILocation(line: 11, column: 1, scope: !24) -// CHECK:STDOUT: !26 = distinct !DISubprogram(name: "G", linkageName: "_CG.Main.66be507887ceee78", scope: null, file: !3, line: 17, type: !5, spFlags: DISPFlagDefinition, unit: !2) -// CHECK:STDOUT: !27 = !DILocation(line: 21, column: 3, scope: !26) +// CHECK:STDOUT: !23 = !DILocation(line: 18, column: 3, scope: !22) +// CHECK:STDOUT: !24 = !DILocation(line: 21, column: 3, scope: !22) +// CHECK:STDOUT: !25 = distinct !DISubprogram(name: "F", linkageName: "_CF.Main.66be507887ceee78", scope: null, file: !3, line: 11, type: !5, spFlags: DISPFlagDefinition, unit: !2) +// CHECK:STDOUT: !26 = !DILocation(line: 11, column: 1, scope: !25) +// CHECK:STDOUT: !27 = distinct !DISubprogram(name: "G", linkageName: "_CG.Main.66be507887ceee78", scope: null, file: !3, line: 17, type: !5, spFlags: DISPFlagDefinition, unit: !2) +// CHECK:STDOUT: !28 = !DILocation(line: 18, column: 3, scope: !27) +// CHECK:STDOUT: !29 = !DILocation(line: 21, column: 3, scope: !27) +// CHECK:STDOUT: !30 = distinct !DISubprogram(name: "H", linkageName: "_CH.Main.b88d1103f417c6d4", scope: null, file: !3, line: 14, type: !5, spFlags: DISPFlagDefinition, unit: !2) +// CHECK:STDOUT: !31 = !DILocation(line: 14, column: 1, scope: !30) +// CHECK:STDOUT: !32 = distinct !DISubprogram(name: "H", linkageName: "_CH.Main.66be507887ceee78", scope: null, file: !3, line: 14, type: !5, spFlags: DISPFlagDefinition, unit: !2) +// CHECK:STDOUT: !33 = !DILocation(line: 14, column: 1, scope: !32) diff --git a/toolchain/sem_ir/function.cpp b/toolchain/sem_ir/function.cpp index ecca6e18cea32..872de836bf341 100644 --- a/toolchain/sem_ir/function.cpp +++ b/toolchain/sem_ir/function.cpp @@ -10,13 +10,18 @@ namespace Carbon::SemIR { -auto GetCalleeFunction(const File& sem_ir, InstId callee_id) -> CalleeFunction { +auto GetCalleeFunction(const File& sem_ir, InstId callee_id, + SpecificId specific_id) -> CalleeFunction { CalleeFunction result = {.function_id = FunctionId::None, .enclosing_specific_id = SpecificId::None, .resolved_specific_id = SpecificId::None, .self_type_id = InstId::None, .self_id = InstId::None, .is_error = false}; + if (specific_id.has_value()) { + callee_id = sem_ir.constant_values().GetInstIdIfValid( + GetConstantValueInSpecific(sem_ir, specific_id, callee_id)); + } if (auto specific_function = sem_ir.insts().TryGetAs(callee_id)) { diff --git a/toolchain/sem_ir/function.h b/toolchain/sem_ir/function.h index 1133031956928..b3e7c32952a23 100644 --- a/toolchain/sem_ir/function.h +++ b/toolchain/sem_ir/function.h @@ -115,7 +115,9 @@ struct CalleeFunction { }; // Returns information for the function corresponding to callee_id. -auto GetCalleeFunction(const File& sem_ir, InstId callee_id) -> CalleeFunction; +auto GetCalleeFunction(const File& sem_ir, InstId callee_id, + SpecificId specific_id = SpecificId::None) + -> CalleeFunction; } // namespace Carbon::SemIR