[TKW] Bug: Undistributed chained gemm doesn't write half of output #381

Open · GMNGeoffrey (Contributor) opened this issue Jan 10, 2025 · 0 comments
This looks kind of similar to #374, but I think it's actually not. If you do a chained gemm into a tensor with a dimension that isn't spread over workgroups or tiles, only half the output gets written. The following test (available in GMNGeoffrey@0fef747413), reduced from the flash attention 2 backward pass, fails:

```python
def testReproWriteAlongUnconstrainedDimension():
    shape = (16, 32, 16, 16)
    q_seq_len, v_head_dim, qk_head_dim, kv_seq_len = shape
    mfma_variant = MMAType.F32_16x16x16_F16
    M = tkl.sym.M
    N = tkl.sym.N
    K1 = tkl.sym.K1
    K2 = tkl.sym.K2

    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    constraints: list[tkw.Constraint] = [
        tkw.WorkgroupConstraint(K2, K2, 0),
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(1, 1, 1),
            mma_type=mfma_variant,
        )
    ]

    @tkw.wave(constraints)
    def attention_bwd(
        q: tkl.Memory[M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
        k: tkl.Memory[K2, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
        do: tkl.Memory[N, M, GLOBAL_ADDRESS_SPACE, tkl.f16],
        dv: tkl.Memory[K2, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
    ):
        dv_init = tkl.Register[K2, N, tkl.f32](0.0)

        k_j = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        q_i = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)

        s_acc = tkl.Register[M, K2, tkl.f32](0.0)
        s_ij = tkw.mma(q_i, k_j, s_acc)
        s_ij = tkw.permute(s_ij, [K2, M])

        do_i = tkw.read(do, elements_per_thread=LOAD_ELEMS_PER_THREAD)
        dv_j = tkw.mma(tkw.cast(s_ij, tkl.f16), do_i, dv_init)
            
        tkw.write(dv_j, dv, elements_per_thread=STORE_ELEMS_PER_THREAD)

    hyperparams = {
        LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
        STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
        M: q_seq_len,
        N: v_head_dim,
        K1: qk_head_dim,
        K2: kv_seq_len,
    }

    hyperparams.update(get_default_scheduling_params())
    config = get_default_run_config()
    # config["print_ir_after_all"] = True
    compile_config = {
        "waves_per_eu": 2,
        "denorm_fp_math_f32": "preserve-sign",
        "print_ir_after": ["first", "last", "set_node_indices", "expand_graph"],
        # "print_ir_before": ["first", "set_node_indices"],
        "print_signature": True,
        "print_pretty_mlir": True,
        "print_indices": True,
    }

    with tk.gen.TestLaunchContext(
        hyperparams,
        canonicalize=True,
        run=True,
        run_bench=False,
        run_config=config,
        compile_config=compile_config,
        schedule=False,
        use_scheduling_barriers=enable_scheduling_barriers,
    ):
        
        torch.manual_seed(0)
        q = torch.full((q_seq_len, qk_head_dim), 0.1, device=get_default_device(), dtype=torch.float16)
        k = torch.full((kv_seq_len, qk_head_dim), 0.2, device=get_default_device(), dtype=torch.float16)
        do = torch.full((q_seq_len, v_head_dim), 0.3, device=get_default_device(), dtype=torch.float16)

        s_ref = torch.matmul(q, k.transpose(-1, -2))
        dv_ref = torch.matmul(s_ref.transpose(-1, -2), do)

        dv = device_zeros(kv_seq_len, v_head_dim)
        mb_bwd = attention_bwd(
            q,
            k,
            do.transpose(-1, -2),
            dv,
        )
            
        if dump_generated_mlir:
            filename = f"wave_chained_gemm_undistributed_{'x'.join(map(str, shape))}.mlir"
            with open(filename, "w") as f:
                f.write(mb_bwd.module_op.get_asm())
            print(f"IR dumped to {filename}")

        assert_close(dv, dv_ref.to(torch.float32), atol=1e-3, rtol=1e-4)
```

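For orientation, this is the reference math the kernel chains together, with the shapes the test uses (a minimal torch-only sketch; N is the dimension no constraint distributes, and f32 is used here just to keep the sketch CPU-friendly):

```python
import torch

# Shapes from the test: M = q_seq_len, N = v_head_dim, K1 = qk_head_dim, K2 = kv_seq_len.
M, N, K1, K2 = 16, 32, 16, 16

q = torch.randn(M, K1)
k = torch.randn(K2, K1)
do = torch.randn(M, N)

s = q @ k.transpose(-1, -2)        # first gemm:   S = Q @ K^T, shape (M, K2)
dv = s.transpose(-1, -2) @ do      # chained gemm: dV = S^T @ dO, shape (K2, N)
assert dv.shape == (K2, N)         # all N = 32 columns of dv carry data
```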
Compiling the test produces this fx trace (each per-node index entry appears to read as `start : size : stride` per dimension):

```
After set_node_indices:
region_0 [root]:
graph():
    %q :  [num_users=1] = placeholder[target=q]
    %k :  [num_users=1] = placeholder[target=k]
    %do :  [num_users=1] = placeholder[target=do]
    %dv :  [num_users=1] = placeholder[target=dv]
    %register :  [num_users=1] = [register](args = ((K2, N), f32, 0.0), kwargs = {})
    %read :  [num_users=1] = [read](args = (%k, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %read_1 :  [num_users=1] = [read](args = (%q, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %register_1 :  [num_users=1] = [register](args = ((M, K2), f32, 0.0), kwargs = {})
    %mma :  [num_users=1] = [mma](args = (%read_1, %read, %register_1, None), kwargs = {})
    %permute :  [num_users=1] = [permute](args = (%mma, [K2, M]), kwargs = {})
    %read_2 :  [num_users=1] = [read](args = (%do, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %cast :  [num_users=1] = [cast](args = (%permute, f16), kwargs = {})
    %mma_1 :  [num_users=1] = [mma](args = (%cast, %read_2, %register, None), kwargs = {})
    %write :  [num_users=0] = [write](args = (%mma_1, %dv, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    return None


q: None
k: None
do: None
dv: None
register: {K2: $WG0*K2 + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: Mod($T0, 16) : 1 : 1}
read: {K2: $WG0*K2 + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}
read_1: {M: Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}
register_1: {M: 4*floor((Mod($T0, 64))/16) : 4 : 16, K2: $WG0*K2 + Mod($T0, 16) : 1 : 1}
mma: {M: Piecewise((Mod($T0, 16), ~$MMA_ACC), (4*floor((Mod($T0, 64))/16), $MMA_ACC)) : Piecewise((1, ~$MMA_ACC), (4, $MMA_ACC)) : Piecewise((1, ~$MMA_ACC), (16, $MMA_ACC)), K1: 4*floor((Mod($T0, 64))/16) : 4 : 1, K2: $WG0*K2 + Mod($T0, 16) : 1 : 1}
permute: {K2: $WG0*K2 + Mod($T0, 16) : 1 : 16, M: 4*floor((Mod($T0, 64))/16) : 4 : 1}
read_2: {N: Mod($T0, 16) : 1 : 1, M: 4*floor((Mod($T0, 64))/16) : 4 : 1}
cast: {K2: $WG0*K2 + Mod($T0, 16) : 1 : 16, M: 4*floor((Mod($T0, 64))/16) : 4 : 1}
mma_1: {K2: $WG0*K2 + Piecewise((Mod($T0, 16), ~$MMA_ACC), (4*floor((Mod($T0, 64))/16), $MMA_ACC)) : Piecewise((1, ~$MMA_ACC), (4, $MMA_ACC)) : Piecewise((1, ~$MMA_ACC), (16, $MMA_ACC)), M: 4*floor((Mod($T0, 64))/16) : 4 : 1, N: Mod($T0, 16) : 1 : 1}
write: {K2: $WG0*K2 + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: Mod($T0, 16) : 1 : 1}
output: {}


After expand_graph:
region_0 [root]:
graph():
    %q :  [num_users=1] = placeholder[target=q]
    %k :  [num_users=1] = placeholder[target=k]
    %do :  [num_users=1] = placeholder[target=do]
    %dv :  [num_users=1] = placeholder[target=dv]
    %register_K2:0_M:0 :  [num_users=1] = [register](args = ((K2, N), f32, 0.0), kwargs = {})
    %read_K2:0_M:0_K1:0 :  [num_users=1] = [read](args = (%k, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %read_K2:0_M:0_K1:0 :  [num_users=1] = [read](args = (%q, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %register_K2:0_M:0_K1:0 :  [num_users=1] = [register](args = ((M, K2), f32, 0.0), kwargs = {})
    %mma_K2:0_M:0_K1:0 :  [num_users=1] = [mma](args = (%read_K2:0_M:0_K1:0, %read_K2:0_M:0_K1:0, %register_K2:0_M:0_K1:0, None), kwargs = {})
    %permute_K2:0_M:0 :  [num_users=1] = [permute](args = (%mma_K2:0_M:0_K1:0, [K2, M]), kwargs = {})
    %read_K2:0_M:0 :  [num_users=1] = [read](args = (%do, LOAD_ELEMS_PER_THREAD, None, (), None), kwargs = {})
    %cast_K2:0_M:0 :  [num_users=1] = [cast](args = (%permute_K2:0_M:0, f16), kwargs = {})
    %mma_K2:0_M:0 :  [num_users=1] = [mma](args = (%cast_K2:0_M:0, %read_K2:0_M:0, %register_K2:0_M:0, None), kwargs = {})
    %write_K2:0 :  [num_users=0] = [write](args = (%mma_K2:0_M:0, %dv, STORE_ELEMS_PER_THREAD, None, ()), kwargs = {})
    return None


q: None
k: None
do: None
dv: None
register_K2:0_M:0: {K2: $WG0*K2 + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: Mod($T0, 16) : 1 : 1}
read_K2:0_M:0_K1:0: {K2: $WG0*K2 + Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}
read_K2:0_M:0_K1:0: {M: Mod($T0, 16) : 1 : 1, K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}
register_K2:0_M:0_K1:0: {M: 4*floor((Mod($T0, 64))/16) : 4 : 16, K2: $WG0*K2 + Mod($T0, 16) : 1 : 1}
mma_K2:0_M:0_K1:0: {M: Piecewise((Mod($T0, 16), ~$MMA_ACC), (4*floor((Mod($T0, 64))/16), $MMA_ACC)) : Piecewise((1, ~$MMA_ACC), (4, $MMA_ACC)) : Piecewise((1, ~$MMA_ACC), (16, $MMA_ACC)), K1: 4*floor((Mod($T0, 64))/16) : 4 : 1, K2: $WG0*K2 + Mod($T0, 16) : 1 : 1}
permute_K2:0_M:0: {K2: $WG0*K2 + Mod($T0, 16) : 1 : 16, M: 4*floor((Mod($T0, 64))/16) : 4 : 1}
read_K2:0_M:0: {N: Mod($T0, 16) : 1 : 1, M: 4*floor((Mod($T0, 64))/16) : 4 : 1}
cast_K2:0_M:0: {K2: $WG0*K2 + Mod($T0, 16) : 1 : 16, M: 4*floor((Mod($T0, 64))/16) : 4 : 1}
mma_K2:0_M:0: {K2: $WG0*K2 + Piecewise((Mod($T0, 16), ~$MMA_ACC), (4*floor((Mod($T0, 64))/16), $MMA_ACC)) : Piecewise((1, ~$MMA_ACC), (4, $MMA_ACC)) : Piecewise((1, ~$MMA_ACC), (16, $MMA_ACC)), M: 4*floor((Mod($T0, 64))/16) : 4 : 1, N: Mod($T0, 16) : 1 : 1}
write_K2:0: {K2: $WG0*K2 + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: Mod($T0, 16) : 1 : 1}
output: {}
```

and this IR (some SSA variables renamed for readability):

```mlir
 #translation = #iree_codegen.translation_info<pipeline = None workgroup_size = [64, 1, 1] subgroup_size = 64, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "2", "denormal-fp-math-f32" = "preserve-sign"}}>
module attributes {transform.with_named_sequence} {
  stream.executable private @attention_bwd {
    stream.executable.export public @attention_bwd workgroups() -> (index, index, index) {
      %c1 = arith.constant 1 : index
      stream.return %c1, %c1, %c1 : index, index, index
    }
    builtin.module {
      func.func @attention_bwd(%arg_q: !stream.binding, %arg_k: !stream.binding, %arg_do: !stream.binding, %arg_dv: !stream.binding) attributes {translation_info = #translation} {
        %c3 = arith.constant 3 : index
        %c2 = arith.constant 2 : index
        %c64 = arith.constant 64 : index
        %c1 = arith.constant 1 : index
        %c4 = arith.constant 4 : index
        %c16 = arith.constant 16 : index
        %c0 = arith.constant 0 : index
        %c0_4vf32 = arith.constant dense<0.000000e+00> : vector<4xf32>
        %workgroup_id_0 = stream.dispatch.workgroup.id[0] : index
        %thread_id_x = gpu.thread_id  x
        %k = stream.binding.subspan %arg_k[%c0] : !stream.binding -> memref<16x16xf16, strided<[16, 1], offset: ?>>
        %1 = arith.muli %workgroup_id_0, %c16 overflow<nsw, nuw> : index
        %2 = arith.remsi %thread_id_x, %c16 : index
        %3 = arith.addi %2, %1 overflow<nsw, nuw> : index
        %4 = arith.remsi %thread_id_x, %c64 : index
        %5 = arith.divsi %4, %c16 : index
        %6 = arith.muli %5, %c4 overflow<nsw, nuw> : index
        %7 = vector.load %k[%3, %6] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
        %q = stream.binding.subspan %arg_q[%c0] : !stream.binding -> memref<16x16xf16, strided<[16, 1], offset: ?>>
        %9 = vector.load %q[%2, %6] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
        %10 = amdgpu.mfma %9 * %7 + %c0_4vf32 {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
        %do = stream.binding.subspan %arg_do[%c0] : !stream.binding -> memref<32x16xf16, strided<[16, 1], offset: ?>>
        %12 = vector.load %do[%2, %6] : memref<32x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
        %13 = arith.truncf %10 : vector<4xf32> to vector<4xf16>
        %14 = amdgpu.mfma %13 * %12 + %c0_4vf32 {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
        %15 = vector.extract_strided_slice %14 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
        %dv = stream.binding.subspan %arg_dv[%c0] : !stream.binding -> memref<16x32xf32, strided<[32, 1], offset: ?>>
        %17 = arith.addi %1, %6 overflow<nsw, nuw> : index
        vector.store %15, %dv[%17, %2] : memref<16x32xf32, strided<[32, 1], offset: ?>>, vector<1xf32>
        %18 = vector.extract_strided_slice %14 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
        %19 = arith.addi %17, %c1 overflow<nsw, nuw> : index
        vector.store %18, %dv[%19, %2] : memref<16x32xf32, strided<[32, 1], offset: ?>>, vector<1xf32>
        %20 = vector.extract_strided_slice %14 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
        %21 = arith.addi %17, %c2 overflow<nsw, nuw> : index
        vector.store %20, %dv[%21, %2] : memref<16x32xf32, strided<[32, 1], offset: ?>>, vector<1xf32>
        %22 = vector.extract_strided_slice %14 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
        %23 = arith.addi %17, %c3 overflow<nsw, nuw> : index
        vector.store %22, %dv[%23, %2] : memref<16x32xf32, strided<[32, 1], offset: ?>>, vector<1xf32>
        return
      }
    }
  }
}
```

There's only one block (workgroup), so for thread index x the writes to dv cover [(x/16)*4 : (x/16)*4 + 4, x % 16]. But dv is K2xN = 16x32 and the second (column) index maxes out at 15, so half of the N dimension never gets written. Looking at it another way, dv has 16*32 = 512 elements, but each of the 64 threads writes only 4, so it's impossible for all of them to be filled in. Interestingly, if I add a no-op distribution over the N dimension (`tkw.WorkgroupConstraint(N, N, 1)`), then we get the correct number of writes and the test passes.
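To sanity-check that arithmetic, here's a minimal standalone sketch (plain Python, no TKW; the per-thread indexing is transcribed from the `vector.store` ops in the IR above) that enumerates which elements of dv actually get written:

```python
# Per the generated IR: with one 64-thread workgroup ($WG0 = 0), thread t
# stores 4 consecutive rows starting at ((t % 64) // 16) * 4, all in
# column t % 16 of the 16x32 dv buffer.
K2, N = 16, 32

written = set()
for t in range(64):
    row_start = ((t % 64) // 16) * 4
    col = t % 16
    for row in range(row_start, row_start + 4):
        written.add((row, col))

print(f"{len(written)} of {K2 * N} elements written")  # 256 of 512
untouched = sorted({c for c in range(N) if all((r, c) not in written for r in range(K2))})
print(untouched)  # [16, 17, ..., 31]
```

Columns 16 through 31 of dv are never stored to, matching the halved output.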
