Skip to content

Commit

Permalink
replace hipblasLtComputeType_t with hipblasComputeType_t
Browse files Browse the repository at this point in the history
  • Loading branch information
jichangjichang committed Jan 15, 2024
1 parent 95131d6 commit 3aad0d8
Show file tree
Hide file tree
Showing 33 changed files with 176 additions and 232 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Full documentation for hipBLASLt is available at [rocm.docs.amd.com/projects/hip
### Changes

* Replaced `hipblasDatatype_t` with `hipDataType`
* Replaced `hipblasLtComputeType_t` with `hipblasComputeType_t`

### Removals

Expand Down
10 changes: 5 additions & 5 deletions clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,11 +751,11 @@ try
if(arg.d_type == HIPBLASLT_DATATYPE_INVALID)
throw std::invalid_argument("Invalid value for --d_type " + d_type);

bool is_f16 = arg.a_type == HIP_R_16F || arg.a_type == HIP_R_16BF;
bool is_f32 = arg.a_type == HIP_R_32F;
arg.compute_type = compute_type == "" ? (HIPBLASLT_COMPUTE_F32)
: string_to_hipblaslt_computetype(compute_type);
if(arg.compute_type == static_cast<hipblasLtComputeType_t>(0))
bool is_f16 = arg.a_type == HIP_R_16F || arg.a_type == HIP_R_16BF;
bool is_f32 = arg.a_type == HIP_R_32F;
arg.compute_type
= compute_type == "" ? (HIPBLAS_COMPUTE_32F) : string_to_hipblas_computetype(compute_type);
if(arg.compute_type == static_cast<hipblasComputeType_t>(0))
throw std::invalid_argument("Invalid value for --compute_type " + compute_type);

if(string_to_hip_datatype(bias_type) == HIPBLASLT_DATATYPE_INVALID && bias_type != ""
Expand Down
6 changes: 3 additions & 3 deletions clients/benchmarks/client_groupedgemm_fixed_mk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ int test_hipblaslt(hipDataType in_datatype,
in_datatype,
out_datatype,
out_datatype,
HIPBLASLT_COMPUTE_F32,
HIPBLAS_COMPUTE_32F,
heuristicResult));

std::vector<int> validIdx;
Expand All @@ -795,7 +795,7 @@ int test_hipblaslt(hipDataType in_datatype,
in_datatype,
out_datatype,
out_datatype,
HIPBLASLT_COMPUTE_F32);
HIPBLAS_COMPUTE_32F);

std::cout << "index, transAB, M, N, K, lda, ldb, ldc, stride_a, stride_b, "
"stride_c, batch_count, alpha, beta, bias, activationType"
Expand Down Expand Up @@ -844,7 +844,7 @@ int test_hipblaslt(hipDataType in_datatype,
in_datatype,
out_datatype,
out_datatype,
HIPBLASLT_COMPUTE_F32};
HIPBLAS_COMPUTE_32F};

// step 1: set problem to {Ms, {sum of N, 1, 1, 1, ...}, Ks}
CHECK_HIPBLASLT_ERROR(groupedGemm.setProblem(m,
Expand Down
2 changes: 1 addition & 1 deletion clients/common/hipblaslt_arguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ void Arguments::init()
b_type = HIP_R_16F;
c_type = HIP_R_16F;
d_type = HIP_R_16F;
compute_type = HIPBLASLT_COMPUTE_F32;
compute_type = HIPBLAS_COMPUTE_32F;
scale_type = HIP_R_32F;

initialization = hipblaslt_initialization::hpl;
Expand Down
2 changes: 1 addition & 1 deletion clients/gtest/matmul_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ namespace
{
name << hip_datatype_to_string(arg.a_type) << hip_datatype_to_string(arg.b_type)
<< hip_datatype_to_string(arg.c_type) << hip_datatype_to_string(arg.d_type)
<< hipblaslt_computetype_to_string(arg.compute_type);
<< hipblas_computetype_to_string(arg.compute_type);

if(arg.activation_type != hipblaslt_activation_type::none)
{
Expand Down
12 changes: 6 additions & 6 deletions clients/include/hipblaslt_arguments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ struct Arguments
int32_t solution_index;
int32_t requested_solution_num;

hipDataType a_type;
hipDataType b_type;
hipDataType c_type;
hipDataType d_type;
hipblasLtComputeType_t compute_type;
hipDataType scale_type;
hipDataType a_type;
hipDataType b_type;
hipDataType c_type;
hipDataType d_type;
hipblasComputeType_t compute_type;
hipDataType scale_type;

hipblaslt_initialization initialization;

Expand Down
14 changes: 7 additions & 7 deletions clients/include/hipblaslt_common.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ Datatypes:
bf16_r: 14
f8_r: 1000
bf8_r: 1001
- hipblasLtComputeType_t:
- hipblasComputeType_t:
bases: [ c_int ]
attr:
c_f32_r: 300
c_xf32_r: 301
c_f64_r: 302
c_i32_r: 303
c_f32_fast_f16_r: 304
c_f32_r: 2
c_f32_fast_f16_r: 4
c_xf32_r: 6
c_f64_r: 7
c_i32_r: 9
- { half: f16_r }
- hipblaslt_initialization:
bases: [ c_int ]
Expand Down Expand Up @@ -172,7 +172,7 @@ Arguments:
- b_type: hipDataType
- c_type: hipDataType
- d_type: hipDataType
- compute_type: hipblasLtComputeType_t
- compute_type: hipblasComputeType_t
- scale_type: hipDataType
- initialization: hipblaslt_initialization
- gpu_arch: c_char*4
Expand Down
2 changes: 1 addition & 1 deletion clients/include/hipblaslt_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ struct hipblaslt_test_invalid
<< " b: " << hip_datatype_to_string(arg.b_type)
<< " c: " << hip_datatype_to_string(arg.c_type)
<< " d: " << hip_datatype_to_string(arg.d_type)
<< " compute:" << hipblaslt_computetype_to_string(arg.compute_type)
<< " compute:" << hipblas_computetype_to_string(arg.compute_type)
<< std::endl;
hipblaslt_abort();
#endif
Expand Down
2 changes: 1 addition & 1 deletion clients/include/testing_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1782,7 +1782,7 @@ void testing_matmul(const Arguments& arg)
// For the xf32 xdl math op, cast type of A/B from float to xfloat32 .
if constexpr(std::is_same<TiA, float>{} && std::is_same<TiB, float>{}
&& std::is_same<To, float>{} && std::is_same<Tc, float>{})
if(arg.compute_type == HIPBLASLT_COMPUTE_F32_FAST_XF32)
if(arg.compute_type == HIPBLAS_COMPUTE_32F_FAST_TF32)
{
for(int i = 0; i < gemm_count; i++)
{
Expand Down
56 changes: 28 additions & 28 deletions clients/include/type_dispatch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,85 +97,85 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)

if(arg.d_type == To)
{
if(TiA == To && TiB == To && To == HIP_R_16F && Tc == HIPBLASLT_COMPUTE_F32)
if(TiA == To && TiB == To && To == HIP_R_16F && Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblasLtHalf, hipblasLtHalf, hipblasLtHalf, float>{}(arg);
}
else if(TiA == To && TiB == To && To == HIP_R_16BF && Tc == HIPBLASLT_COMPUTE_F32)
else if(TiA == To && TiB == To && To == HIP_R_16BF && Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hip_bfloat16, hip_bfloat16, hip_bfloat16, float>{}(arg);
}
else if(TiA == To && TiB == To && To == HIP_R_32F
&& (Tc == HIPBLASLT_COMPUTE_F32 || Tc == HIPBLASLT_COMPUTE_F32_FAST_XF32))
&& (Tc == HIPBLAS_COMPUTE_32F || Tc == HIPBLAS_COMPUTE_32F_FAST_TF32))
{
return TEST<float, float, float, float>{}(arg);
}
else if(TiA == To && TiB == To && To == HIP_R_64F && (Tc == HIPBLASLT_COMPUTE_F64))
else if(TiA == To && TiB == To && To == HIP_R_64F && (Tc == HIPBLAS_COMPUTE_64F))
{
return TEST<double, double, double, double>{}(arg);
}
else if(TiA == HIP_R_16F && TiB == HIP_R_16F && To == HIP_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblasLtHalf, hipblasLtHalf, float, float>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, float, float>{}(arg);
}
else if(TiA == HIP_R_8F_E5M2_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_bf8_fnuz, hipblaslt_f8_fnuz, float, float>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E5M2_FNUZ && To == HIP_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_f8_fnuz, hipblaslt_bf8_fnuz, float, float>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_16F
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, hipblasLtHalf, float>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_16BF
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, hipblasLtBfloat16, float>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_8F_E4M3_FNUZ
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, float>{}(arg);
}
else if(TiA == HIP_R_8F_E5M2_FNUZ && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_16F
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_bf8_fnuz, hipblaslt_f8_fnuz, hipblasLtHalf, float>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E5M2_FNUZ && To == HIP_R_16F
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_f8_fnuz, hipblaslt_bf8_fnuz, hipblasLtHalf, float>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_8F_E5M2_FNUZ && To == HIP_R_16BF
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_f8_fnuz, hipblaslt_bf8_fnuz, hipblasLtBfloat16, float>{}(arg);
}
/*
else if(Ti == HIP_R_8I && To == HIP_R_8I && Tc == HIPBLASLT_COMPUTE_I32)
else if(Ti == HIP_R_8I && To == HIP_R_8I && Tc == HIPBLAS_COMPUTE_32I)
{
return TEST<hipblasLtInt8, hipblasLtInt8, int32_t>{}(arg);
}
*/
else if(TiA == HIP_R_8I && To == HIP_R_32I && Tc == HIPBLASLT_COMPUTE_I32)
else if(TiA == HIP_R_8I && To == HIP_R_32I && Tc == HIPBLAS_COMPUTE_32I)
{
return TEST<hipblasLtInt8, hipblasLtInt8, int32_t, int32_t>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_8F_E4M3_FNUZ
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
{
return TEST<hipblaslt_f8_fnuz,
hipblasLtHalf,
Expand All @@ -184,7 +184,7 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
hipblasLtHalf>{}(arg);
}
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_8F_E4M3_FNUZ
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
{
return TEST<hipblasLtHalf,
hipblaslt_f8_fnuz,
Expand All @@ -193,29 +193,29 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
hipblasLtHalf>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_16F
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
{
return TEST<hipblaslt_f8_fnuz, hipblasLtHalf, hipblasLtHalf, float, hipblasLtHalf>{}(
arg);
}
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_16F
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
{
return TEST<hipblasLtHalf, hipblaslt_f8_fnuz, hipblasLtHalf, float, hipblasLtHalf>{}(
arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
{
return TEST<hipblaslt_f8_fnuz, hipblasLtHalf, float, float, hipblasLtHalf>{}(arg);
}
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32_FAST_F16)
&& Tc == HIPBLAS_COMPUTE_32F_FAST_16F)
{
return TEST<hipblasLtHalf, hipblaslt_f8_fnuz, float, float, hipblasLtHalf>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_8F_E4M3_FNUZ
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_f8_fnuz,
hipblasLtHalf,
Expand All @@ -224,7 +224,7 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
hipblaslt_f8_fnuz>{}(arg);
}
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_8F_E4M3_FNUZ
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblasLtHalf,
hipblaslt_f8_fnuz,
Expand All @@ -233,7 +233,7 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
hipblaslt_f8_fnuz>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_16F
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_f8_fnuz,
hipblasLtHalf,
Expand All @@ -242,7 +242,7 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
hipblaslt_f8_fnuz>{}(arg);
}
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_16F
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblasLtHalf,
hipblaslt_f8_fnuz,
Expand All @@ -251,12 +251,12 @@ auto hipblaslt_matmul_dispatch(const Arguments& arg)
hipblaslt_f8_fnuz>{}(arg);
}
else if(TiA == HIP_R_8F_E4M3_FNUZ && TiB == HIP_R_16F && To == HIP_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblaslt_f8_fnuz, hipblasLtHalf, float, float, hipblaslt_f8_fnuz>{}(arg);
}
else if(TiA == HIP_R_16F && TiB == HIP_R_8F_E4M3_FNUZ && To == HIP_R_32F
&& Tc == HIPBLASLT_COMPUTE_F32)
&& Tc == HIPBLAS_COMPUTE_32F)
{
return TEST<hipblasLtHalf, hipblaslt_f8_fnuz, float, float, hipblaslt_f8_fnuz>{}(arg);
}
Expand Down
8 changes: 4 additions & 4 deletions clients/include/utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ class hipblaslt_local_matmul_descr
hipblasStatus_t m_status = HIPBLAS_STATUS_NOT_INITIALIZED;

public:
hipblaslt_local_matmul_descr(hipblasOperation_t opA,
hipblasOperation_t opB,
hipblasLtComputeType_t compute_type,
hipDataType scale_type)
hipblaslt_local_matmul_descr(hipblasOperation_t opA,
hipblasOperation_t opB,
hipblasComputeType_t compute_type,
hipDataType scale_type)
{
this->m_status = hipblasLtMatmulDescCreate(&this->m_descr, compute_type, scale_type);

Expand Down
2 changes: 1 addition & 1 deletion clients/samples/gemm/sample_hipblaslt_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void simpleGemm(hipblasLtHandle_t handle,
}

hipblasLtMatmulDesc_t matmul;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIP_R_32F));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t)));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
Expand Down
10 changes: 2 additions & 8 deletions clients/samples/gemm/sample_hipblaslt_gemm_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,8 @@ void simpleGemmExt(hipblasLtHandle_t handle,
{
hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(max_workspace_size);
hipblaslt_ext::Gemm gemm(handle,
trans_a,
trans_b,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLASLT_COMPUTE_F32);
hipblaslt_ext::Gemm gemm(
handle, trans_a, trans_b, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIPBLAS_COMPUTE_32F);

hipblaslt_ext::GemmEpilogue
epilogue; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,8 @@ void simpleGemmAlphaVecExt(hipblasLtHandle_t handle,
{
hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(max_workspace_size);
hipblaslt_ext::Gemm gemm(handle,
trans_a,
trans_b,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLASLT_COMPUTE_F32);
hipblaslt_ext::Gemm gemm(
handle, trans_a, trans_b, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIP_R_16F, HIPBLAS_COMPUTE_32F);

hipblaslt_ext::GemmEpilogue
epilogue; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
Expand Down
Loading

0 comments on commit 3aad0d8

Please sign in to comment.