Wrap inline PTX as utility for Hopper matmul #3860

Open · wants to merge 20 commits into main
Conversation

@zasdfgbnm (Collaborator) commented Feb 10, 2025

This PR is a follow-up to #3844 that further wraps all instructions used by the Hopper matmul as utility functions. The trick here is that some PTX instructions require certain operands to be immediate numbers, so those operands must be template parameters rather than function parameters. To support this, I extended our current codegen implementation for kir::Asm.

kir::Asm::AsmOptions is extended with a new option specifying which inputs are always immediates.
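
For context, the immediate requirement comes from the PTX "n" constraint, which only accepts compile-time constants. A minimal standalone sketch of the pattern (illustrative names; the generated wgmmaWait below follows the same shape):

// The "n" constraint demands an integral constant expression, so the operand
// is baked into the template instantiation rather than passed as a runtime
// function argument.
template <int64_t count>
__device__ __inline__ void waitGroupSketch() {
  asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(count):"memory");
}
// waitGroupSketch<0>();  // OK: 0 is a compile-time constant.
// A plain `int64_t count` function parameter would not satisfy "n".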

Example kernel:
__device__ __inline__ void wgmmaFence() {
  asm volatile("wgmma.fence.sync.aligned;\n");
}
__device__ __inline__ void fenceAsyncProxy() {
  asm volatile("fence.proxy.async;\n");
}
template <int in3, int in4, int in5, int in6>
__device__ __inline__ void wgmmaM64N256K16Half(Array<float, 128, 1>& out0, uint64_t in0, uint64_t in1, bool in2) {
  asm volatile(
    "{\n"
    "  .reg .pred p0; \n"
    "  setp.ne.b32 p0, %130, 0;\n"
    "  wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, %128, %129, p0, %131, %132, %133, %134;\n"
    "}\n"
    :"+f"(out0[0]),
     "+f"(out0[1]),
     "+f"(out0[2]),
     "+f"(out0[3]),
     "+f"(out0[4]),
     "+f"(out0[5]),
     "+f"(out0[6]),
     "+f"(out0[7]),
     "+f"(out0[8]),
     "+f"(out0[9]),
     "+f"(out0[10]),
     "+f"(out0[11]),
     "+f"(out0[12]),
     "+f"(out0[13]),
     "+f"(out0[14]),
     "+f"(out0[15]),
     "+f"(out0[16]),
     "+f"(out0[17]),
     "+f"(out0[18]),
     "+f"(out0[19]),
     "+f"(out0[20]),
     "+f"(out0[21]),
     "+f"(out0[22]),
     "+f"(out0[23]),
     "+f"(out0[24]),
     "+f"(out0[25]),
     "+f"(out0[26]),
     "+f"(out0[27]),
     "+f"(out0[28]),
     "+f"(out0[29]),
     "+f"(out0[30]),
     "+f"(out0[31]),
     "+f"(out0[32]),
     "+f"(out0[33]),
     "+f"(out0[34]),
     "+f"(out0[35]),
     "+f"(out0[36]),
     "+f"(out0[37]),
     "+f"(out0[38]),
     "+f"(out0[39]),
     "+f"(out0[40]),
     "+f"(out0[41]),
     "+f"(out0[42]),
     "+f"(out0[43]),
     "+f"(out0[44]),
     "+f"(out0[45]),
     "+f"(out0[46]),
     "+f"(out0[47]),
     "+f"(out0[48]),
     "+f"(out0[49]),
     "+f"(out0[50]),
     "+f"(out0[51]),
     "+f"(out0[52]),
     "+f"(out0[53]),
     "+f"(out0[54]),
     "+f"(out0[55]),
     "+f"(out0[56]),
     "+f"(out0[57]),
     "+f"(out0[58]),
     "+f"(out0[59]),
     "+f"(out0[60]),
     "+f"(out0[61]),
     "+f"(out0[62]),
     "+f"(out0[63]),
     "+f"(out0[64]),
     "+f"(out0[65]),
     "+f"(out0[66]),
     "+f"(out0[67]),
     "+f"(out0[68]),
     "+f"(out0[69]),
     "+f"(out0[70]),
     "+f"(out0[71]),
     "+f"(out0[72]),
     "+f"(out0[73]),
     "+f"(out0[74]),
     "+f"(out0[75]),
     "+f"(out0[76]),
     "+f"(out0[77]),
     "+f"(out0[78]),
     "+f"(out0[79]),
     "+f"(out0[80]),
     "+f"(out0[81]),
     "+f"(out0[82]),
     "+f"(out0[83]),
     "+f"(out0[84]),
     "+f"(out0[85]),
     "+f"(out0[86]),
     "+f"(out0[87]),
     "+f"(out0[88]),
     "+f"(out0[89]),
     "+f"(out0[90]),
     "+f"(out0[91]),
     "+f"(out0[92]),
     "+f"(out0[93]),
     "+f"(out0[94]),
     "+f"(out0[95]),
     "+f"(out0[96]),
     "+f"(out0[97]),
     "+f"(out0[98]),
     "+f"(out0[99]),
     "+f"(out0[100]),
     "+f"(out0[101]),
     "+f"(out0[102]),
     "+f"(out0[103]),
     "+f"(out0[104]),
     "+f"(out0[105]),
     "+f"(out0[106]),
     "+f"(out0[107]),
     "+f"(out0[108]),
     "+f"(out0[109]),
     "+f"(out0[110]),
     "+f"(out0[111]),
     "+f"(out0[112]),
     "+f"(out0[113]),
     "+f"(out0[114]),
     "+f"(out0[115]),
     "+f"(out0[116]),
     "+f"(out0[117]),
     "+f"(out0[118]),
     "+f"(out0[119]),
     "+f"(out0[120]),
     "+f"(out0[121]),
     "+f"(out0[122]),
     "+f"(out0[123]),
     "+f"(out0[124]),
     "+f"(out0[125]),
     "+f"(out0[126]),
     "+f"(out0[127])
    :"l"(in0),
     "l"(in1),
     "r"((uint32_t)(in2)),
     "n"(in3),
     "n"(in4),
     "n"(in5),
     "n"(in6)
  );
}
__device__ __inline__ void wgmmaCommit() {
  asm volatile("wgmma.commit_group.sync.aligned;\n");
}
template <int64_t in0>
__device__ __inline__ void wgmmaWait() {
  asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(in0):"memory");
}
__device__ __inline__ void stmatrix4(uint32_t in0, Array<uint32_t, 4, 1> in1) {
  asm volatile(
    "stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
    :
    :"r"(in0),
     "r"(in1[0]),
     "r"(in1[1]),
     "r"(in1[2]),
     "r"(in1[3])
  );
}
__device__ __inline__ void cpAsyncBulkCommitGroup() {
  asm volatile("cp.async.bulk.commit_group;\n");
}
template <int64_t in0>
__device__ __inline__ void cpAsyncBulkWaitGroup() {
  asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(in0):"memory");
}
__global__ void __cluster_dims__(2, 1, 1) nvfuser_none_f0_c0_r0_g0(Tensor<__half, 3, 3> T0, Tensor<__half, 3, 3> T1, const __grid_constant__ TensorMap var0, const __grid_constant__ TensorMap var1, const __grid_constant__ TensorMap var2, Tensor<__half, 2, 2> T3) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  nvfuser_index_t i3;
  i3 = ceilDiv(T0.logical_size[0LL], 32);
  nvfuser_index_t i4;
  i4 = -3 + i3;
  const TensorMap* ptr5;
  ptr5 = &var0;
  nvfuser_index_t i6;
  i6 = 256 * ((nvfuser_index_t)blockIdx.x);
  __half* T5 = reinterpret_cast<__half*>(array + smem_offset + 32896);
  uint32_t i7;
  i7 = toSmem(T5);
  const TensorMap* ptr8;
  ptr8 = &var1;
  nvfuser_index_t i9;
  i9 = 128 * ((nvfuser_index_t)blockIdx.y);
  __half* T4 = reinterpret_cast<__half*>(array + smem_offset + 128);
  uint32_t i10;
  i10 = toSmem(T4);
  uint32_t i11;
  i11 = i10 + (4096 * ((nvfuser_index_t)threadIdx.y));
  nvfuser_index_t i12;
  i12 = ((((nvfuser_index_t)threadIdx.x) / 32) * 16) + ((((nvfuser_index_t)threadIdx.x) % 32) % 16);
  __half* T7 = reinterpret_cast<__half*>(array + smem_offset + 0);
  uint32_t i13;
  i13 = toSmem(T7) + (32768 * ((nvfuser_index_t)threadIdx.y));
  const TensorMap* ptr14;
  ptr14 = &var2;
  nvfuser_index_t i15;
  i15 = (64 * ((nvfuser_index_t)threadIdx.y)) + i9;
  bool b16;
  b16 = ((nvfuser_index_t)threadIdx.x) < 32ULL;
  bool b17;
  b17 = ((nvfuser_index_t)threadIdx.y) == 0ULL;
  Array<float, 128, 1> T2;
  ((*reinterpret_cast<Array<float, 128, 1>*>(&T2[0]))).set(0);
  wgmmaFence();
  fenceAsyncProxy();
  uint64_t* T8 = reinterpret_cast<uint64_t*>(array + smem_offset + 0);
  #pragma unroll
  for(nvfuser_index_t i18 = 0; i18 < 4; ++i18) {
    if (((Hopper::electSync(4294967295U) && b16) && b17)) {
      mbarrier::init(toSmem((&T8[i18])), 2U);
    }
  }
  __syncthreads();
  #pragma unroll 3
  for(nvfuser_index_t i19 = 0; i19 < 3; ++i19) {
    nvfuser_index_t i20;
    i20 = 32 * i19;
    uint32_t i21;
    i21 = i7 + (16384 * i19);
    uint32_t i22;
    i22 = i10 + (8192 * i19);
    if (((Hopper::electSync(4294967295U) && b16) && b17)) {
      mbarrier::arriveExpectTX(toSmem((&T8[i19])), 16384U);
      #pragma unroll
      for(nvfuser_index_t i23 = 0; i23 < 4; ++i23) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, (Array<nvfuser_index_t, 2, 1>{(i6 + (64 * i23)), i20}), toSmem((&T8[i19])) }), (i21 + (4096 * i23)));
      }
      mbarrier::arriveExpectTX(toSmem((&T8[i19])), 8192U);
      #pragma unroll
      for(nvfuser_index_t i24 = 0; i24 < 2; ++i24) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr8, (Array<nvfuser_index_t, 2, 1>{(i9 + (64 * i24)), i20}), toSmem((&T8[i19])) }), (i22 + (4096 * i24)));
      }
    }
  }
  #pragma unroll 3
  for(nvfuser_index_t i25 = 0; i25 < i4; ++i25) {
    nvfuser_index_t i26;
    i26 = 96 + (32 * i25);
    nvfuser_index_t i27;
    i27 = (3 + i25) % 4;
    uint32_t i28;
    i28 = i7 + (16384 * i27);
    uint32_t i29;
    i29 = i10 + (8192 * i27);
    nvfuser_index_t i30;
    i30 = i25 % 4;
    uint32_t i31;
    i31 = i11 + (8192 * i30);
    uint32_t i32;
    i32 = i7 + (16384 * i30);
    if (((Hopper::electSync(4294967295U) && b16) && b17)) {
      mbarrier::arriveExpectTX(toSmem((&T8[((3LL + i25) % 4)])), 16384U);
      #pragma unroll
      for(nvfuser_index_t i23 = 0; i23 < 4; ++i23) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, (Array<nvfuser_index_t, 2, 1>{(i6 + (64 * i23)), i26}), toSmem((&T8[((3LL + i25) % 4)])) }), (i28 + (4096 * i23)));
      }
      mbarrier::arriveExpectTX(toSmem((&T8[((3LL + i25) % 4)])), 8192U);
      #pragma unroll
      for(nvfuser_index_t i24 = 0; i24 < 2; ++i24) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr8, (Array<nvfuser_index_t, 2, 1>{(i9 + (64 * i24)), i26}), toSmem((&T8[((3LL + i25) % 4)])) }), (i29 + (4096 * i24)));
      }
    }
    mbarrier::waitParity(toSmem((&T8[(i25 % 4)])), (uint32_t)(((i25 / 4) % 2)));
    #pragma unroll
    for(nvfuser_index_t i33 = 0; i33 < 2; ++i33) {
      nvfuser_index_t i34;
      i34 = 2048 * i33;
      uint32_t i35;
      i35 = i31 + i34;
      uint32_t i36;
      i36 = i32 + i34;
      wgmmaFence();
      wgmmaM64N256K16Half<1, 1, 1, 1>((*reinterpret_cast<Array<float, 128, 1>*>(&T2[0])), (4611686293305294848ULL | ((262143ULL & (uint64_t)(i35)) >> 4ULL)), (4611686293322072064ULL | ((262143ULL & (uint64_t)(i36)) >> 4ULL)), true);
    }
    __syncthreads();
    wgmmaCommit();
    wgmmaWait<0LL>();
  }
  #pragma unroll 3
  for(nvfuser_index_t i37 = (i3 - 3); i37 < i3; ++i37) {
    nvfuser_index_t i38;
    i38 = i37 % 4;
    uint32_t i39;
    i39 = i11 + (8192 * i38);
    uint32_t i40;
    i40 = i7 + (16384 * i38);
    mbarrier::waitParity(toSmem((&T8[(i37 % 4)])), (uint32_t)(((i37 / 4) % 2)));
    #pragma unroll
    for(nvfuser_index_t i33 = 0; i33 < 2; ++i33) {
      nvfuser_index_t i41;
      i41 = 2048 * i33;
      uint32_t i42;
      i42 = i39 + i41;
      uint32_t i43;
      i43 = i40 + i41;
      wgmmaFence();
      wgmmaM64N256K16Half<1, 1, 1, 1>((*reinterpret_cast<Array<float, 128, 1>*>(&T2[0])), (4611686293305294848ULL | ((262143ULL & (uint64_t)(i42)) >> 4ULL)), (4611686293322072064ULL | ((262143ULL & (uint64_t)(i43)) >> 4ULL)), true);
    }
    __syncthreads();
  }
  #pragma unroll
  for(nvfuser_index_t i44 = 0; i44 < 4; ++i44) {
    if (((Hopper::electSync(4294967295U) && b16) && b17)) {
      mbarrier::inval(toSmem((&T8[i44])));
    }
  }
  wgmmaWait<0LL>();
  Array<__half, 128, 8> T6;
  #pragma unroll
  for(nvfuser_index_t i45 = 0; i45 < 32; ++i45) {
    nvfuser_index_t i46;
    i46 = 4 * i45;
    #pragma unroll
    for(nvfuser_index_t i47 = 0; i47 < 2; ++i47) {
      nvfuser_index_t i48;
      i48 = i46 + (2 * i47);
      #pragma unroll
      for(nvfuser_index_t i49 = 0; i49 < 2; ++i49) {
        nvfuser_index_t i50;
        i50 = i48 + i49;
        T6[i50]
           = __float2half(T2[i50]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i51 = 0; i51 < 16; ++i51) {
    stmatrix4((uint32_t)((toSmem(T7) + ((((nvfuser_index_t)threadIdx.y) * 32768) + (((i51 / 4) * 8192) + ((i12 * 128) + (((((((nvfuser_index_t)threadIdx.x) % 32) / 16) + ((i51 % 4) * 2)) ^ (i12 % 8)) * 16)))))), (*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T6[(8 * i51)])));
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i52 = 0; i52 < 4; ++i52) {
    fenceAsyncProxy();
    if ((Hopper::electSync(4294967295U) && b16)) {
      Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr14, (Array<nvfuser_index_t, 2, 1>{(i6 + (64 * i52)), i15}) }), (i13 + (8192 * i52)));
    }
  }
  cpAsyncBulkCommitGroup();
  cpAsyncBulkWaitGroup<0LL>();
}


github-actions bot commented Feb 10, 2025

Review updated until commit fa55b99

Description

  • Wrap inline PTX as utility for Hopper matmul

  • Add support for immediate inputs in PTX utilities

  • Enhance utility name generation for wgmma instructions

  • Update PTX utility generation in multiple files


Changes walkthrough 📝

Relevant files (Enhancement)

codegen.cpp (csrc/codegen.cpp): Enhance PTX utility generation with immediate inputs (+47/-13)

  • Added lambda to determine type or index type for PTX utility generation
  • Updated utility function signature generation to handle immediate inputs
  • Modified PTX code generation to include immediate inputs in function calls

inline_ptx.cpp (csrc/device_lower/pass/inline_ptx.cpp): Add immediate inputs to PTX options (+17/-4)

  • Updated PTX options to include immediate inputs for various operations

kernel_ir.cpp (csrc/kernel_ir.cpp): Improve utility name generation and immediate input handling (+46/-4)

  • Updated constraintsAndInputs to handle immediate inputs
  • Enhanced utility name generation for wgmma instructions using regex

kernel_ir.h (csrc/kernel_ir.h): Add immediate_inputs to AsmOptions (+1/-0)

  • Added immediate_inputs to AsmOptions struct

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 No relevant tests
⚡ Recommended focus areas for review

Possible Issue

The getTypeOrIndexType function is defined twice, once inside handle and once outside. This could lead to confusion and potential bugs.

    auto getTypeOrIndexType = [](Val* value) {
      if (auto ti = dynamic_cast<kir::TensorIndex*>(value)) {
        if (isPointerType(ti->index()->dtype())) {
          return ti->index()->dtype();
        }
      }
      return value->dtype();
    };
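
One way to remove the duplication (a sketch only; it assumes both copies are identical, and the free-function form is illustrative rather than what the PR does) is to hoist the lambda into a single file-local helper:

    // Illustrative refactor: one shared helper instead of two identical lambdas.
    namespace {
    auto getTypeOrIndexType(Val* value) {
      // For tensor indices whose index expression has a pointer type, report
      // the type of that index expression; otherwise report the value's type.
      if (auto ti = dynamic_cast<kir::TensorIndex*>(value)) {
        if (isPointerType(ti->index()->dtype())) {
          return ti->index()->dtype();
        }
      }
      return value->dtype();
    }
    } // namespace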
Regex Complexity

The regex patterns used to match PTX instructions for utility names are quite specific and may not cover all cases. Consider adding more comprehensive tests to ensure all valid PTX instructions are matched correctly.

      // Half
      std::regex pattern(
          R"(wgmma\.mma_async\.sync\.aligned\.(m\d+n\d+k\d+)\.f32\.f16\.f16)");
      std::smatch match;
      if (std::regex_match(code, match, pattern)) {
        std::string extracted = match[1];
        std::transform(
            extracted.begin(), extracted.end(), extracted.begin(), ::toupper);
        return "wgmma" + extracted + "Half";
      }
    }
    {
      // BFloat16
      std::regex pattern(
          R"(wgmma\.mma_async\.sync\.aligned\.(m\d+n\d+k\d+)\.f32\.bf16\.bf16)");
      std::smatch match;
      if (std::regex_match(code, match, pattern)) {
        std::string extracted = match[1];
        std::transform(
            extracted.begin(), extracted.end(), extracted.begin(), ::toupper);
        return "wgmma" + extracted + "BF16";
      }
    }
Documentation

The new immediate_inputs field in AsmOptions should be documented to explain its purpose and usage.

    std::unordered_set<int64_t> immediate_inputs = {};
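
For illustration, a doc comment along these lines could cover it (a sketch; the exact wording is an assumption, not what the PR necessarily adds):

    // Positions of asm inputs that must be lowered as immediates: such inputs
    // are emitted with the "n" constraint and become template parameters of
    // the wrapped PTX utility function instead of regular function parameters.
    std::unordered_set<int64_t> immediate_inputs = {};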

@zasdfgbnm (Collaborator, Author) commented:

!test

@zasdfgbnm (Collaborator, Author) commented:

!test

zasdfgbnm marked this pull request as ready for review February 10, 2025 22:15