Skip to content

Commit

Permalink
Basic lowering generic function definitions. (#5016)
Browse files Browse the repository at this point in the history
Resolve the specific type for the callee, to lower the proper specific
function called.
  • Loading branch information
alinas authored Mar 4, 2025
1 parent 92b3e61 commit 4e21c0c
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 21 deletions.
21 changes: 17 additions & 4 deletions toolchain/lower/file_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
// TODO: Consider tracking whether the function has been used, and only
// lowering it if it's needed.

// TODO nit: add is_symbolic() to type_id to forward to
// type_id.AsConstantId().is_symbolic(). Update call below too.
auto get_llvm_type = [&](SemIR::TypeId type_id) -> llvm::Type* {
if (!type_id.has_value()) {
return nullptr;
}
return GetType(SemIR::GetTypeInSpecific(sem_ir(), specific_id, type_id));
};

const auto return_info =
SemIR::ReturnTypeInfo::ForFunction(sem_ir(), function, specific_id);
CARBON_CHECK(return_info.is_valid(), "Should not lower invalid functions.");
Expand All @@ -229,8 +238,7 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
auto param_patterns =
sem_ir().inst_blocks().GetOrEmpty(function.param_patterns_id);

auto* return_type =
return_info.type_id.has_value() ? GetType(return_info.type_id) : nullptr;
auto* return_type = get_llvm_type(return_info.type_id);

llvm::SmallVector<llvm::Type*> param_types;
// TODO: Consider either storing `param_inst_ids` somewhere so that we can
Expand Down Expand Up @@ -259,6 +267,10 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
}
auto param_type_id =
SemIR::GetTypeInSpecific(sem_ir(), specific_id, param_pattern.type_id);
CARBON_CHECK(
!param_type_id.AsConstantId().is_symbolic(),
"Found symbolic type id after resolution when lowering type {0}.",
param_pattern.type_id);
switch (auto value_rep = SemIR::ValueRepr::ForType(sem_ir(), param_type_id);
value_rep.kind) {
case SemIR::ValueRepr::Unknown:
Expand All @@ -268,7 +280,8 @@ auto FileContext::BuildFunctionDecl(SemIR::FunctionId function_id,
case SemIR::ValueRepr::Copy:
case SemIR::ValueRepr::Custom:
case SemIR::ValueRepr::Pointer:
param_types.push_back(GetType(value_rep.type_id));
auto* param_types_to_add = get_llvm_type(value_rep.type_id);
param_types.push_back(param_types_to_add);
param_inst_ids.push_back(param_pattern_id);
break;
}
Expand Down Expand Up @@ -349,7 +362,7 @@ auto FileContext::BuildFunctionBody(SemIR::FunctionId function_id,
CARBON_DCHECK(!body_block_ids.empty(),
"No function body blocks found during lowering.");

FunctionContext function_lowering(*this, llvm_function,
FunctionContext function_lowering(*this, llvm_function, specific_id,
BuildDISubprogram(function, llvm_function),
vlog_stream_);

Expand Down
2 changes: 2 additions & 0 deletions toolchain/lower/function_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ namespace Carbon::Lower {

FunctionContext::FunctionContext(FileContext& file_context,
llvm::Function* function,
SemIR::SpecificId specific_id,
llvm::DISubprogram* di_subprogram,
llvm::raw_ostream* vlog_stream)
: file_context_(&file_context),
function_(function),
specific_id_(specific_id),
builder_(file_context.llvm_context(), llvm::ConstantFolder(),
Inserter(file_context.inst_namer())),
di_subprogram_(di_subprogram),
Expand Down
7 changes: 7 additions & 0 deletions toolchain/lower/function_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace Carbon::Lower {
class FunctionContext {
public:
explicit FunctionContext(FileContext& file_context, llvm::Function* function,
SemIR::SpecificId specific_id,
llvm::DISubprogram* di_subprogram,
llvm::raw_ostream* vlog_stream);

Expand All @@ -45,6 +46,8 @@ class FunctionContext {

// Returns a value for the given instruction.
auto GetValue(SemIR::InstId inst_id) -> llvm::Value* {
// TODO: if(specific_id_.has_value()) may need to update inst_id first.

// All builtins are types, with the same empty lowered value.
if (SemIR::IsSingletonInstId(inst_id)) {
return GetTypeAsValue();
Expand Down Expand Up @@ -130,6 +133,7 @@ class FunctionContext {
}
auto llvm_module() -> llvm::Module& { return file_context_->llvm_module(); }
auto llvm_function() -> llvm::Function& { return *function_; }
auto specific_id() -> SemIR::SpecificId { return specific_id_; }
auto builder() -> llvm::IRBuilderBase& { return builder_; }
auto sem_ir() -> const SemIR::File& { return file_context_->sem_ir(); }

Expand Down Expand Up @@ -174,6 +178,9 @@ class FunctionContext {
// The IR function we're generating.
llvm::Function* function_;

// The specific id, if the function is a specific.
SemIR::SpecificId specific_id_;

// Builder for creating code in this function. The insertion point is held at
// the location of the current SemIR instruction.
llvm::IRBuilder<llvm::ConstantFolder, Inserter> builder_;
Expand Down
9 changes: 6 additions & 3 deletions toolchain/lower/handle_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,8 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
llvm::ArrayRef<SemIR::InstId> arg_ids =
context.sem_ir().inst_blocks().Get(inst.args_id);

auto callee_function =
SemIR::GetCalleeFunction(context.sem_ir(), inst.callee_id);
auto callee_function = SemIR::GetCalleeFunction(
context.sem_ir(), inst.callee_id, context.specific_id());
CARBON_CHECK(callee_function.function_id.has_value());

if (auto builtin_kind = context.sem_ir()
Expand All @@ -442,7 +442,10 @@ auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,

std::vector<llvm::Value*> args;

if (SemIR::ReturnTypeInfo::ForType(context.sem_ir(), inst.type_id)
auto inst_type_id = SemIR::GetTypeInSpecific(
context.sem_ir(), callee_function.resolved_specific_id, inst.type_id);

if (SemIR::ReturnTypeInfo::ForType(context.sem_ir(), inst_type_id)
.has_return_slot()) {
args.push_back(context.GetValue(arg_ids.back()));
arg_ids = arg_ids.drop_back();
Expand Down
42 changes: 30 additions & 12 deletions toolchain/lower/testdata/function/generic/call_basic.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ fn H[T:! type](x: T) {
}

fn G[T:! type](x: T) -> T {
// TODO: the call below is crashing because proper type resolution to
// use the G specific, not the G generic is not done yet.
H(x);
// TODO: Call crashes, see TODO in FunctionContext::GetValue()
// H(T);
return x;
}
Expand Down Expand Up @@ -72,17 +72,29 @@ fn M() {
// CHECK:STDOUT:
// CHECK:STDOUT: define i32 @_CG.Main.b88d1103f417c6d4(i32 %x) !dbg !22 {
// CHECK:STDOUT: entry:
// CHECK:STDOUT: ret i32 %x, !dbg !23
// CHECK:STDOUT: call void @_CH.Main.b88d1103f417c6d4(i32 %x), !dbg !23
// CHECK:STDOUT: ret i32 %x, !dbg !24
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define void @_CF.Main.66be507887ceee78(double %x) !dbg !24 {
// CHECK:STDOUT: define void @_CF.Main.66be507887ceee78(double %x) !dbg !25 {
// CHECK:STDOUT: entry:
// CHECK:STDOUT: ret void, !dbg !25
// CHECK:STDOUT: ret void, !dbg !26
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define double @_CG.Main.66be507887ceee78(double %x) !dbg !26 {
// CHECK:STDOUT: define double @_CG.Main.66be507887ceee78(double %x) !dbg !27 {
// CHECK:STDOUT: entry:
// CHECK:STDOUT: ret double %x, !dbg !27
// CHECK:STDOUT: call void @_CH.Main.66be507887ceee78(double %x), !dbg !28
// CHECK:STDOUT: ret double %x, !dbg !29
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define void @_CH.Main.b88d1103f417c6d4(i32 %x) !dbg !30 {
// CHECK:STDOUT: entry:
// CHECK:STDOUT: ret void, !dbg !31
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define void @_CH.Main.66be507887ceee78(double %x) !dbg !32 {
// CHECK:STDOUT: entry:
// CHECK:STDOUT: ret void, !dbg !33
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: ; uselistorder directives
Expand Down Expand Up @@ -116,8 +128,14 @@ fn M() {
// CHECK:STDOUT: !20 = distinct !DISubprogram(name: "F", linkageName: "_CF.Main.b88d1103f417c6d4", scope: null, file: !3, line: 11, type: !5, spFlags: DISPFlagDefinition, unit: !2)
// CHECK:STDOUT: !21 = !DILocation(line: 11, column: 1, scope: !20)
// CHECK:STDOUT: !22 = distinct !DISubprogram(name: "G", linkageName: "_CG.Main.b88d1103f417c6d4", scope: null, file: !3, line: 17, type: !5, spFlags: DISPFlagDefinition, unit: !2)
// CHECK:STDOUT: !23 = !DILocation(line: 21, column: 3, scope: !22)
// CHECK:STDOUT: !24 = distinct !DISubprogram(name: "F", linkageName: "_CF.Main.66be507887ceee78", scope: null, file: !3, line: 11, type: !5, spFlags: DISPFlagDefinition, unit: !2)
// CHECK:STDOUT: !25 = !DILocation(line: 11, column: 1, scope: !24)
// CHECK:STDOUT: !26 = distinct !DISubprogram(name: "G", linkageName: "_CG.Main.66be507887ceee78", scope: null, file: !3, line: 17, type: !5, spFlags: DISPFlagDefinition, unit: !2)
// CHECK:STDOUT: !27 = !DILocation(line: 21, column: 3, scope: !26)
// CHECK:STDOUT: !23 = !DILocation(line: 18, column: 3, scope: !22)
// CHECK:STDOUT: !24 = !DILocation(line: 21, column: 3, scope: !22)
// CHECK:STDOUT: !25 = distinct !DISubprogram(name: "F", linkageName: "_CF.Main.66be507887ceee78", scope: null, file: !3, line: 11, type: !5, spFlags: DISPFlagDefinition, unit: !2)
// CHECK:STDOUT: !26 = !DILocation(line: 11, column: 1, scope: !25)
// CHECK:STDOUT: !27 = distinct !DISubprogram(name: "G", linkageName: "_CG.Main.66be507887ceee78", scope: null, file: !3, line: 17, type: !5, spFlags: DISPFlagDefinition, unit: !2)
// CHECK:STDOUT: !28 = !DILocation(line: 18, column: 3, scope: !27)
// CHECK:STDOUT: !29 = !DILocation(line: 21, column: 3, scope: !27)
// CHECK:STDOUT: !30 = distinct !DISubprogram(name: "H", linkageName: "_CH.Main.b88d1103f417c6d4", scope: null, file: !3, line: 14, type: !5, spFlags: DISPFlagDefinition, unit: !2)
// CHECK:STDOUT: !31 = !DILocation(line: 14, column: 1, scope: !30)
// CHECK:STDOUT: !32 = distinct !DISubprogram(name: "H", linkageName: "_CH.Main.66be507887ceee78", scope: null, file: !3, line: 14, type: !5, spFlags: DISPFlagDefinition, unit: !2)
// CHECK:STDOUT: !33 = !DILocation(line: 14, column: 1, scope: !32)
7 changes: 6 additions & 1 deletion toolchain/sem_ir/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@

namespace Carbon::SemIR {

auto GetCalleeFunction(const File& sem_ir, InstId callee_id) -> CalleeFunction {
auto GetCalleeFunction(const File& sem_ir, InstId callee_id,
SpecificId specific_id) -> CalleeFunction {
CalleeFunction result = {.function_id = FunctionId::None,
.enclosing_specific_id = SpecificId::None,
.resolved_specific_id = SpecificId::None,
.self_type_id = InstId::None,
.self_id = InstId::None,
.is_error = false};
if (specific_id.has_value()) {
callee_id = sem_ir.constant_values().GetInstIdIfValid(
GetConstantValueInSpecific(sem_ir, specific_id, callee_id));
}

if (auto specific_function =
sem_ir.insts().TryGetAs<SpecificFunction>(callee_id)) {
Expand Down
4 changes: 3 additions & 1 deletion toolchain/sem_ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ struct CalleeFunction {
};

// Returns information for the function corresponding to callee_id.
auto GetCalleeFunction(const File& sem_ir, InstId callee_id) -> CalleeFunction;
auto GetCalleeFunction(const File& sem_ir, InstId callee_id,
SpecificId specific_id = SpecificId::None)
-> CalleeFunction;

} // namespace Carbon::SemIR

Expand Down

0 comments on commit 4e21c0c

Please sign in to comment.