forked from taichi-dev/taichi
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[llvm] [aot] CUDA-AOT PR taichi-dev#2: Implemented AOT Module Loader …
…for LLVM-CUDA backend
- Loading branch information
1 parent
41c9736
commit f77d75c
Showing
4 changed files
with
94 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#include "taichi/backends/cuda/aot_module_loader_impl.h" | ||
#include "taichi/llvm/llvm_aot_module_loader.h" | ||
|
||
#include "taichi/llvm/llvm_offline_cache.h" | ||
#include "taichi/llvm/llvm_program.h" | ||
#include "taichi/backends/cuda/codegen_cuda.h" | ||
|
||
namespace taichi { | ||
namespace lang { | ||
namespace { | ||
|
||
class AotModuleImpl : public LlvmAotModule { | ||
public: | ||
explicit AotModuleImpl(const cpu::AotModuleParams ¶ms) | ||
: LlvmAotModule(params.module_path, params.program) { | ||
} | ||
|
||
private: | ||
FunctionType convert_module_to_function( | ||
const std::string &name, | ||
LlvmOfflineCache::KernelCacheData &&loaded) override { | ||
Arch arch = program_->config->arch; | ||
TI_ASSERT(arch == Arch::cuda); | ||
auto *tlctx = program_->get_llvm_context(arch); | ||
|
||
const auto &tasks = loaded.offloaded_task_list; | ||
std::vector<OffloadedTask> offloaded_tasks; | ||
offloaded_tasks.reserve(tasks.size()); | ||
for (const auto &t : tasks) { | ||
OffloadedTask ot{/*codegen=*/nullptr}; | ||
ot.name = t.name; | ||
ot.block_dim = t.block_dim; | ||
ot.grid_dim = t.grid_dim; | ||
offloaded_tasks.push_back(std::move(ot)); | ||
} | ||
|
||
CUDAModuleToFunctionConverter converter{tlctx, program_}; | ||
return converter.convert(name, loaded.args, std::move(loaded.owned_module), | ||
std::move(offloaded_tasks)); | ||
} | ||
|
||
std::unique_ptr<aot::KernelTemplate> make_new_kernel_template( | ||
const std::string &name) override { | ||
TI_NOT_IMPLEMENTED; | ||
return nullptr; | ||
} | ||
|
||
std::unique_ptr<aot::Field> make_new_field(const std::string &name) override { | ||
TI_NOT_IMPLEMENTED; | ||
return nullptr; | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace cuda { | ||
|
||
std::unique_ptr<aot::Module> make_aot_module(std::any mod_params) { | ||
auto mod = std::make_unique<AotModuleImpl>( | ||
std::any_cast<const AotModuleParams &>(mod_params)); | ||
return mod; | ||
} | ||
|
||
} // namespace cuda | ||
} // namespace lang | ||
} // namespace taichi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#pragma once | ||
|
||
#include "taichi/aot/module_loader.h" | ||
|
||
namespace taichi { | ||
namespace lang { | ||
|
||
class LlvmProgramImpl; | ||
|
||
namespace cuda { | ||
|
||
struct TI_DLL_EXPORT AotModuleParams { | ||
std::string module_path; | ||
LlvmProgramImpl *program{nullptr}; | ||
}; | ||
|
||
TI_DLL_EXPORT std::unique_ptr<aot::Module> make_aot_module(std::any mod_params); | ||
|
||
} // namespace cuda | ||
} // namespace lang | ||
} // namespace taichi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters