From 9e9400194c236bb1dd931f3aa64faf3db3d99eb7 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 21 Sep 2022 11:47:43 -0700 Subject: [PATCH] add tf32 type printing --- src/target/source/codegen_cuda.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 5c9aa266a91f..3cc7acba877c 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1240,10 +1240,10 @@ void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOL void CodeGenCUDA::PrintWmmaScope(const runtime::StorageScope& scope, DataType t, const VarNode* variable, std::ostream& os) { std::stringstream type; - PrintType(t, type); std::string shape_str = fragment_shapes.at(variable); - if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) { - type.str(std::string()); + if (scope.tag == ".tf32") { + type << "nvcuda::wmma::precision::tf32"; + } else if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) { if (t.is_int()) { if (t.bits() == 4) { type << "nvcuda::wmma::experimental::precision::s4"; @@ -1259,6 +1259,8 @@ void CodeGenCUDA::PrintWmmaScope(const runtime::StorageScope& scope, DataType t, LOG(FATAL) << "Unhandled integer type for wmma fragment!"; } } + } else { + PrintType(t, type); } if (scope.rank == runtime::StorageRank::kWMMAMatrixA) { need_mma_h_ = true;