Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tf32 type printing #9

Merged
merged 1 commit into from
Sep 21, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1240,10 +1240,10 @@ void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOL
void CodeGenCUDA::PrintWmmaScope(const runtime::StorageScope& scope, DataType t,
const VarNode* variable, std::ostream& os) {
std::stringstream type;
PrintType(t, type);
std::string shape_str = fragment_shapes.at(variable);
if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
type.str(std::string());
if (scope.tag == ".tf32") {
type << "nvcuda::wmma::precision::tf32";
} else if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
if (t.is_int()) {
if (t.bits() == 4) {
type << "nvcuda::wmma::experimental::precision::s4";
Expand All @@ -1259,6 +1259,8 @@ void CodeGenCUDA::PrintWmmaScope(const runtime::StorageScope& scope, DataType t,
LOG(FATAL) << "Unhandled integer type for wmma fragment!";
}
}
} else {
PrintType(t, type);
}
if (scope.rank == runtime::StorageRank::kWMMAMatrixA) {
need_mma_h_ = true;
Expand Down