Llama-3-8B f16 fails to compile to vmfb #17226

Open
aviator19941 opened this issue Apr 29, 2024 · 41 comments · Fixed by llvm/torch-mlir#3269
Labels: bug 🐞 Something isn't working

@aviator19941 (Contributor)

aviator19941 commented Apr 29, 2024

batch_llama_3_8B.zip

What happened?

When trying to compile this mlir file, I get the shared memory error below:

failed to translate executables
failed to translate executables
failed to translate executables
result_llama_3_v4.mlir:352:7: error: 'func.func' op uses -46137344 bytes of shared memory; exceeded the limit of 65536 bytes
      func.func @prefill_bs4$async_dispatch_1_generic_4xDx4096_i64xf32(%arg0: !flow.dispatch.tensor<readonly:tensor<128256x4096xf16>>, %arg1: !flow.dispatch.tensor<readonly:tensor<4x?xi64>>, %arg2: index, %arg3: !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>) {
      ^
result_llama_3_v4.mlir:346:3: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940", ukernels = "none"}>
  flow.executable private @prefill_bs4$async_dispatch_1 {
  ^
result_llama_3_v4.mlir:449:14: error: 'iree_linalg_ext.set_encoding' op unhandled tensor operation
        %7 = linalg.batch_matmul_transpose_b ins(%3, %4 : tensor<4x?x4096xf32>, tensor<4x4096x4096xf16>) outs(%6 : tensor<4x?x4096xf32>) -> tensor<4x?x4096xf32>
             ^
result_llama_3_v4.mlir:440:7: error: 'func.func' op failed to create tensor equivalance classes
      func.func @prefill_bs4$async_dispatch_4_batch_matmul_transpose_b_4xDx4096x4096_f32xf16xf32(%arg0: !flow.dispatch.tensor<readonly:tensor<4x?x4096xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<4x4096x4096xf16>>, %arg2: index, %arg3: !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>) {
      ^
result_llama_3_v4.mlir:434:3: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940", ukernels = "none"}>
  flow.executable private @prefill_bs4$async_dispatch_4 {
  ^
result_llama_3_v4.mlir:504:14: error: 'iree_linalg_ext.set_encoding' op unhandled tensor operation
        %7 = linalg.batch_matmul_transpose_b ins(%3, %4 : tensor<4x?x4096xf32>, tensor<4x1024x4096xf16>) outs(%6 : tensor<4x?x1024xf32>) -> tensor<4x?x1024xf32>
             ^
result_llama_3_v4.mlir:495:7: error: 'func.func' op failed to create tensor equivalance classes
      func.func @prefill_bs4$async_dispatch_7_batch_matmul_transpose_b_4xDx1024x4096_f32xf16xf32(%arg0: !flow.dispatch.tensor<readonly:tensor<4x?x4096xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<4x1024x4096xf16>>, %arg2: index, %arg3: !flow.dispatch.tensor<writeonly:tensor<4x?x1024xf32>>) {
      ^
result_llama_3_v4.mlir:489:3: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940", ukernels = "none"}>
  flow.executable private @prefill_bs4$async_dispatch_7 {
  ^

Steps to reproduce your issue

  1. Cherry pick iree#17182
  2. Cherry pick llvm-project#90141
  3. ../iree-build/tools/iree-compile --mlir-disable-threading --iree-opt-const-eval=false --compile-to=flow ../batch_llama_3_8B.mlir -o result_llama_3.mlir
  4. ../iree-build/tools/iree-compile --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx940 --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode result_llama_3.mlir -o llama_3.vmfb
  5. Error:
failed to translate executables
failed to translate executables
failed to translate executables
result_llama_3_v4.mlir:352:7: error: 'func.func' op uses -46137344 bytes of shared memory; exceeded the limit of 65536 bytes
      func.func @prefill_bs4$async_dispatch_1_generic_4xDx4096_i64xf32(%arg0: !flow.dispatch.tensor<readonly:tensor<128256x4096xf16>>, %arg1: !flow.dispatch.tensor<readonly:tensor<4x?xi64>>, %arg2: index, %arg3: !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>) {
      ^
result_llama_3_v4.mlir:346:3: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940", ukernels = "none"}>
  flow.executable private @prefill_bs4$async_dispatch_1 {
  ^
result_llama_3_v4.mlir:449:14: error: 'iree_linalg_ext.set_encoding' op unhandled tensor operation
        %7 = linalg.batch_matmul_transpose_b ins(%3, %4 : tensor<4x?x4096xf32>, tensor<4x4096x4096xf16>) outs(%6 : tensor<4x?x4096xf32>) -> tensor<4x?x4096xf32>
             ^
result_llama_3_v4.mlir:440:7: error: 'func.func' op failed to create tensor equivalance classes
      func.func @prefill_bs4$async_dispatch_4_batch_matmul_transpose_b_4xDx4096x4096_f32xf16xf32(%arg0: !flow.dispatch.tensor<readonly:tensor<4x?x4096xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<4x4096x4096xf16>>, %arg2: index, %arg3: !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>) {
      ^
result_llama_3_v4.mlir:434:3: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940", ukernels = "none"}>
  flow.executable private @prefill_bs4$async_dispatch_4 {
  ^
result_llama_3_v4.mlir:504:14: error: 'iree_linalg_ext.set_encoding' op unhandled tensor operation
        %7 = linalg.batch_matmul_transpose_b ins(%3, %4 : tensor<4x?x4096xf32>, tensor<4x1024x4096xf16>) outs(%6 : tensor<4x?x1024xf32>) -> tensor<4x?x1024xf32>
             ^
result_llama_3_v4.mlir:495:7: error: 'func.func' op failed to create tensor equivalance classes
      func.func @prefill_bs4$async_dispatch_7_batch_matmul_transpose_b_4xDx1024x4096_f32xf16xf32(%arg0: !flow.dispatch.tensor<readonly:tensor<4x?x4096xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<4x1024x4096xf16>>, %arg2: index, %arg3: !flow.dispatch.tensor<writeonly:tensor<4x?x1024xf32>>) {
      ^
result_llama_3_v4.mlir:489:3: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940", ukernels = "none"}>
  flow.executable private @prefill_bs4$async_dispatch_7 {
  ^

What component(s) does this issue relate to?

No response

Version information

f2746b4

Additional context

No response

@aviator19941 added the bug 🐞 Something isn't working label Apr 29, 2024
@benvanik (Collaborator)

lol I'm guessing something is multiplying by a dynamic dimension (sentinel -1) without checking :P
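
As an illustration of that guess (a hypothetical sketch, not IREE's actual size computation), an allocation estimate that multiplies a dynamic-dimension sentinel straight through goes negative:

# Hypothetical sketch: a dynamic dimension is carried as a negative
# sentinel, and a size estimate that multiplies it through without a
# check produces a negative byte count, which is how an error like
# "uses -46137344 bytes of shared memory" can show up.
DYNAMIC = -1  # placeholder for the '?' in the 4x?x4096 tensor

def naive_alloc_bytes(shape, bytes_per_elem):
    total = bytes_per_elem
    for dim in shape:  # no check for DYNAMIC here
        total *= dim
    return total

print(naive_alloc_bytes([4, DYNAMIC, 4096], 4))  # -65536: a negative "shared memory" size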

@benvanik (Collaborator)

(to reproduce we'll need the batch_llama_3_8B.mlir file, or the entire contents of the @prefill_bs4$async_dispatch_1_generic_4xDx4096_i64xf32 flow.executable/hal.executable op prior to the error)

@aviator19941 (Contributor, Author)

Yeah, I'll upload it here; I accidentally submitted the issue before attaching it :)

@aviator19941 (Contributor, Author)

@benvanik I uploaded a zip that has the batch_llama_3_8B.mlir file

@hanhanW (Contributor)

hanhanW commented Apr 30, 2024

It looks like it failed in SetEncoding (or related passes). @pashu123, given that you want to get more involved in these tasks, would you like to triage the issue when you're available?

@pashu123 (Contributor)

pashu123 commented Apr 30, 2024

@aviator19941 Do we need to cherry-pick some commit or check out a branch? On the main branch I am seeing this:

batch_llama_3_8B.mlir:1003:12: error: 'flow.tensor.reshape' op operand #2 must be variadic of index, but got 'i64'
    %339 = torch.aten.view %333, %338 : !torch.vtensor<[4,?,32,128],f32>, !torch.list<int> -> !torch.vtensor<[4,?,32,64,2],f32>
           ^
batch_llama_3_8B.mlir:1003:12: note: see current operation: %352 = "flow.tensor.reshape"(%331, %305, %351) <{operandSegmentSizes = array<i32: 1, 1, 1>}> : (tensor<4x?x4096xf32>, index, i64) -> tensor<4x?x32x64x2xf32>

@pashu123 (Contributor)

pashu123 commented Apr 30, 2024

We need to cherry-pick #17182 for the first command to work.

@pashu123 (Contributor)

This is failing at // -----// IR Dump After GPUCheckResourceUsage Failed (iree-codegen-gpu-check-resource-usage) //----- //.

@pashu123 (Contributor)

@aviator19941 The failure is due to https://gist.github.com/pashu123/020217a35f1c643ed03b169ce41f68d9 (embedding kernel). It has a cast from fp16 -> fp32. Please double-check that it's a full fp16 model. Also, could you post how to obtain the IRs?

@pashu123 (Contributor)

pashu123 commented Apr 30, 2024

A possible optimization that can be thought of is

 %0 = torch.prims.convert_element_type %arg1, %int6 : !torch.vtensor<[128256,4096],f16>, !torch.int -> !torch.vtensor<[128256,4096],f32>
 %1 = torch.aten.embedding %0, %arg0, %int-1, %false_0, %false : !torch.vtensor<[128256,4096],f32>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f32>
    return %1 : !torch.vtensor<[4,?,4096],f32>

can be replaced by

 %0 = torch.aten.embedding %arg1, %arg0, %int-1, %false_0, %false : !torch.vtensor<[128256,4096],f16>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f16>
 %1 = torch.prims.convert_element_type %0, %int6 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
    return %1 : !torch.vtensor<[4,?,4096],f32>

i.e., we don't need to cast the entire embedding matrix; we just cast what we want out of the matrix. The above repro takes forever to compile on the CPU backend. However, when we apply the optimization, we don't get the error: func.func op uses -46137344 bytes of shared memory; exceeded the limit of 65536 bytes.
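
A quick sanity check of the rewrite (a plain-PyTorch sketch on a small stand-in table, not the sharktank exporter; the real table is 128256x4096): the f16 -> f32 extension is exact and the embedding lookup only indexes rows, so the cast commutes with the gather and the full-table f32 copy is never needed.

import torch
import torch.nn.functional as F

table_f16 = torch.randn(1000, 64, dtype=torch.float16)  # stand-in for the 128256x4096 table
ids = torch.randint(0, 1000, (4, 7))                     # stand-in for the 4x? token ids

cast_then_gather = F.embedding(ids, table_f16.float())   # current lowering: cast the whole table first
gather_then_cast = F.embedding(ids, table_f16).float()   # proposed: gather rows, then cast

assert torch.equal(cast_then_gather, gather_then_cast)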

@hanhanW (Contributor)

hanhanW commented May 1, 2024

It does not compile because vector.gather is lowered to a lot of vector ops, which should be fixed.

The other issue is that we have two generic ops that are not fused in TileAndFuse, because there is no operand dependency between them. It should be fixed before sending it to codegen. I don't have a good solution so far; perhaps we should just disable the fusion for this kind of case. @MaheshRavishankar, do you have any suggestions?

func.func @decode_bs4$async_dispatch_0_generic_4xDx4096_i64xf32() {
  %c0 = arith.constant 0 : index
  %c32_i64 = arith.constant 32 : i64
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = arith.extui %0 : i32 to i64
  %3 = arith.extui %1 : i32 to i64
  %4 = arith.shli %3, %c32_i64 : i64
  %5 = arith.ori %2, %4 : i64
  %6 = arith.index_castui %5 : i64 to index
  %7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128256x4096xf16>>
  %8 = flow.dispatch.workload.ordinal %6, 0 : index
  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4x?xi64>>{%8}
  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>{%8}
  %11 = flow.dispatch.tensor.load %7, offsets = [0, 0], sizes = [128256, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128256x4096xf16>> -> tensor<128256x4096xf16>
  %12 = flow.dispatch.tensor.load %9, offsets = [0, 0], sizes = [4, %8], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4x?xi64>>{%8} -> tensor<4x?xi64>
  %13 = tensor.empty(%8) : tensor<4x?x4096xf32>
  %14 = tensor.empty() : tensor<128256x4096xf32>
  %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%11 : tensor<128256x4096xf16>) outs(%14 : tensor<128256x4096xf32>) {
  ^bb0(%in: f16, %out: f32):
    %17 = arith.extf %in : f16 to f32
    linalg.yield %17 : f32
  } -> tensor<128256x4096xf32>
  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12 : tensor<4x?xi64>) outs(%13 : tensor<4x?x4096xf32>) {
  ^bb0(%in: i64, %out: f32):
    %17 = arith.index_cast %in : i64 to index
    %18 = linalg.index 2 : index
    %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
    linalg.yield %extracted : f32
  } -> tensor<4x?x4096xf32>
  flow.dispatch.tensor.store %16, %10, offsets = [0, 0, 0], sizes = [4, %8, 4096], strides = [1, 1, 1] : tensor<4x?x4096xf32> -> !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>{%8}
  return
}

@aviator19941 (Contributor, Author)

aviator19941 commented May 1, 2024

@aviator19941 The failure is due to https://gist.github.com/pashu123/020217a35f1c643ed03b169ce41f68d9 (embedding kernel). It has a cast from fp16 -> fp32. Please double-check that it's a full fp16 model. Also, could you post how to obtain the IRs?

When I try to set the activation and attention dtypes to fp16 here, I run into

convertScalarToDtype should handle all the types
UNREACHABLE executed at iree/third_party/torch-mlir/lib/Conversion/Utils/Utils.cpp:355!

because it is trying to multiply complex<f16> and complex<f32> (repro). So I think it has to do with some dtype in the model that should be fp16, but is not.

@aviator19941 (Contributor, Author)

@aviator19941 The failure is due to https://gist.github.com/pashu123/020217a35f1c643ed03b169ce41f68d9 (embedding kernel). It has a cast from fp16 -> fp32. Please double-check that it's a full fp16 model. Also, could you post how to obtain the IRs?

In order to obtain the IRs:

  1. set up sharktank - sharktank
  2. rebase/checkout enable_llama3 branch
  3. clone and build llama.cpp - llama.cpp
  4. run export_paged_llm_v1 example - llama3 IR

@pashu123 (Contributor)

pashu123 commented May 1, 2024

@aviator19941 The failure is due to https://gist.github.com/pashu123/020217a35f1c643ed03b169ce41f68d9 (embedding kernel). It has a cast from fp16 -> fp32. Please double-check that it's a full fp16 model. Also, could you post how to obtain the IRs?

When I try to set the activation and attention dtypes to fp16 here, I run into

convertScalarToDtype should handle all the types
UNREACHABLE executed at iree/third_party/torch-mlir/lib/Conversion/Utils/Utils.cpp:355!

because it is trying to multiply complex<f16> and complex<f32> (repro). So I think it has to do with some dtype in the model that should be fp16, but is not.

I think I can add the fix for this. It is required to enable the full Fp16 precision model.

@pashu123 (Contributor)

pashu123 commented May 1, 2024

@aviator19941 You can get the latest fp16 IR via wget https://huggingface.co/prashantk/test_files/resolve/main/batch_llama_v1.mlir?download=true.

It's able to generate the .vmfb for the llvm-cpu backend with the command:
iree-compile -iree-input-type=torch --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host batch_llama_v1.mlir -iree-opt-demote-i64-to-i32 -o llama3.vmfb

@pashu123 (Contributor)

pashu123 commented May 1, 2024

You need to cherry-pick #17247

@hanhanW (Contributor)

hanhanW commented May 1, 2024

I think there are still action items in this issue; the look-up table fusion is scaring me. We should fix that at least. The tile sizes for vector.gather are problematic: they will be fully unrolled, which looks really bad.

@hanhanW reopened this May 1, 2024
@pashu123 (Contributor)

pashu123 commented May 1, 2024

I think there are still action items in this issue; the look-up table fusion is scaring me. We should fix that at least. The tile sizes for vector.gather are problematic: they will be fully unrolled, which looks really bad.

I never intended to close the issue; I don't know if it got closed automatically. Yes, for the mixed precision case in which we have activations represented as f32, we still have action items to do.

@hanhanW (Contributor)

hanhanW commented May 1, 2024

Confirmed that the fusion is not expected. @MaheshRavishankar will fix it.

For the gather codegen issue, @pashu123 could you create an input case for the generic op and see what's happening? I'm expecting that some dimensions would be collapsed, and the next issue could be tile size selection. #17227 could help, but there could be other issues remaining on the table.

  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12 : tensor<4x?xi64>) outs(%13 : tensor<4x?x4096xf32>) {
  ^bb0(%in: i64, %out: f32):
    %17 = arith.index_cast %in : i64 to index
    %18 = linalg.index 2 : index
    %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
    linalg.yield %extracted : f32
  } -> tensor<4x?x4096xf32>

@MaheshRavishankar (Contributor)

Confirmed that the fusion is not expected. @MaheshRavishankar will fix it.

For the gather codegen issue, @pashu123 could you create an input case for the generic op and see what's happening? I'm expecting that some dimensions would be collapsed, and the next issue could be tile size selection. #17227 could help, but there could be other issues remaining on the table.

  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12 : tensor<4x?xi64>) outs(%13 : tensor<4x?x4096xf32>) {
  ^bb0(%in: i64, %out: f32):
    %17 = arith.index_cast %in : i64 to index
    %18 = linalg.index 2 : index
    %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
    linalg.yield %extracted : f32
  } -> tensor<4x?x4096xf32>

Well, I said the fusion is not expected because I wasn't looking at it properly. It is expected, and I think it is probably what you want at the dispatch level. If we don't fuse this, we will materialize a tensor of 128256x4096x4 bytes, which is completely unnecessary.

The real issue, though, is that the op shouldn't be lowered this way. A better representation would be:

%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<4x?xi64>) outs(%3 : tensor<4x?x4096xf32>) {
    ^bb0(%in: i64, %out: f32):
      %9 = arith.index_cast %in : i64 to index
      %10 = linalg.index 2 : index
      %extracted = tensor.extract %5[%9, %10] : tensor<128256x4096xf16>
      %extracted_f32 = arith.extf %extracted : f16 to f32
      linalg.yield %extracted_f32 : f32
    } -> tensor<4x?x4096xf32>

That should fix one of the issues Hanhan mentioned. If we can fix the frontend to do this, that would be best. If not, then we should just write an ad-hoc pattern that does this kind of fusion. There is really nothing structured to generalize here; this is just a specific pattern that works around a front-end lowering issue.
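
For scale (a back-of-the-envelope illustration, assuming the decode case above with batch 4 and one token per sequence), the difference between the two lowerings is roughly:

full_table_f32_bytes = 128256 * 4096 * 4  # f32 copy of the whole vocab table
gathered_rows_bytes = 4 * 1 * 4096 * 4    # f32 result of extending only the gathered rows

print(full_table_f32_bytes / 2**20, "MiB")  # 2004.0 MiB
print(gathered_rows_bytes / 2**10, "KiB")   # 64.0 KiB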

@pashu123 (Contributor)

pashu123 commented May 2, 2024

A possible optimization that can be thought of is

 %0 = torch.prims.convert_element_type %arg1, %int6 : !torch.vtensor<[128256,4096],f16>, !torch.int -> !torch.vtensor<[128256,4096],f32>
 %1 = torch.aten.embedding %0, %arg0, %int-1, %false_0, %false : !torch.vtensor<[128256,4096],f32>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f32>
    return %1 : !torch.vtensor<[4,?,4096],f32>

can be replaced by

 %0 = torch.aten.embedding %arg1, %arg0, %int-1, %false_0, %false : !torch.vtensor<[128256,4096],f16>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f16>
 %1 = torch.prims.convert_element_type %0, %int6 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
    return %1 : !torch.vtensor<[4,?,4096],f32>

i.e., we don't need to cast the entire embedding matrix; we just cast what we want out of the matrix. The above repro takes forever to compile on the CPU backend. However, when we apply the optimization, we don't get the error: func.func op uses -46137344 bytes of shared memory; exceeded the limit of 65536 bytes.

@MaheshRavishankar, does this sound reasonable to add to torch-mlir?

@pashu123 (Contributor)

pashu123 commented May 2, 2024

A possible optimization that can be thought of is

 %0 = torch.prims.convert_element_type %arg1, %int6 : !torch.vtensor<[128256,4096],f16>, !torch.int -> !torch.vtensor<[128256,4096],f32>
 %1 = torch.aten.embedding %0, %arg0, %int-1, %false_0, %false : !torch.vtensor<[128256,4096],f32>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f32>
    return %1 : !torch.vtensor<[4,?,4096],f32>

can be replaced by

 %0 = torch.aten.embedding %arg1, %arg0, %int-1, %false_0, %false : !torch.vtensor<[128256,4096],f16>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f16>
 %1 = torch.prims.convert_element_type %0, %int6 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
    return %1 : !torch.vtensor<[4,?,4096],f32>

i.e., we don't need to cast the entire embedding matrix; we just cast what we want out of the matrix. The above repro takes forever to compile on the CPU backend. However, when we apply the optimization, we don't get the error: func.func op uses -46137344 bytes of shared memory; exceeded the limit of 65536 bytes.

@MaheshRavishankar, does this sound reasonable to add to torch-mlir?

Added here: llvm/torch-mlir#3277

@stellaraccident (Collaborator)

Not sure why this keeps closing

@stellaraccident (Collaborator)

A possible optimization that can be thought of is

 %0 = torch.prims.convert_element_type %arg1, %int6 : !torch.vtensor<[128256,4096],f16>, !torch.int -> !torch.vtensor<[128256,4096],f32>
 %1 = torch.aten.embedding %0, %arg0, %int-1, %false_0, %false : !torch.vtensor<[128256,4096],f32>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f32>
    return %1 : !torch.vtensor<[4,?,4096],f32>

can be replaced by

 %0 = torch.aten.embedding %arg1, %arg0, %int-1, %false_0, %false : !torch.vtensor<[128256,4096],f16>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f16>
 %1 = torch.prims.convert_element_type %0, %int6 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
    return %1 : !torch.vtensor<[4,?,4096],f32>

i.e., we don't need to cast the entire embedding matrix; we just cast what we want out of the matrix. The above repro takes forever to compile on the CPU backend. However, when we apply the optimization, we don't get the error: func.func op uses -46137344 bytes of shared memory; exceeded the limit of 65536 bytes.

@MaheshRavishankar, does this sound reasonable to add to torch-mlir?

FYI, if you can make the torch embedding lookup good, that is best. But also I carved this out for a potential special op: it would be trivial to write a custom op at the frontend that expanded to whatever linalg you want.

@benvanik (Collaborator)

benvanik commented May 3, 2024

Not sure why this keeps closing

@pashu123 put a "fixes" command in a commit message and now anyone who has write access to the repo will close it when they merge in that commit to their forks of whatever :P
aartbik/torch-mlir@8c48135

@stellaraccident (Collaborator)

With that said, moving the cast across the embedding lookup is a common optimization.

I'm a bit worried that the default path on this generates basically unusable code, though.

@MaheshRavishankar (Contributor)

With that said, moving the cast across the embedding lookup is a common optimization.

I'm a bit worried that the default path on this generates basically unusable code, though.

That's fair, but we just don't represent gathers well. And if we clone the quantization into all the dispatches that use it (as we do now, under the current understanding of the best way to handle dequantization), none of the transformations can actually fuse and generate this code. The producer-consumer dependency only materializes from within the body of the consumer; nothing accounts for that, and it just falls off the cliff.

@qedawkins (Contributor)

Not sure why this keeps closing

@pashu123 put a "fixes" command in a commit message and now anyone who has write access to the repo will close it when they merge in that commit to their forks of whatever :P aartbik/torch-mlir@8c48135

Why is Github unable to prevent actions on forks from spamming main repos... Seems like a big anti-feature.

@pashu123 (Contributor)

pashu123 commented May 3, 2024

@aviator19941 Do you have instructions on how to run llama3 for the IREE backend?

@MaheshRavishankar (Contributor)

With that said, moving the cast across the embedding lookup is a common optimization.

I'm a bit worried that the default path on this generates basically unusable code, though.

The more I think about this, the more it seems worth just doing the fusion of

 %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%11 : tensor<128256x4096xf16>) outs(%14 : tensor<128256x4096xf32>) {
  ^bb0(%in: f16, %out: f32):
    %17 = arith.extf %in : f16 to f32
    linalg.yield %17 : f32
  } -> tensor<128256x4096xf32>
  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12 : tensor<4x?xi64>) outs(%13 : tensor<4x?x4096xf32>) {
  ^bb0(%in: i64, %out: f32):
    %17 = arith.index_cast %in : i64 to index
    %18 = linalg.index 2 : index
    %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
    linalg.yield %extracted : f32
  } -> tensor<4x?x4096xf32>

to

%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<4x?xi64>) outs(%3 : tensor<4x?x4096xf32>) {
    ^bb0(%in: i64, %out: f32):
      %9 = arith.index_cast %in : i64 to index
      %10 = linalg.index 2 : index
      %extracted = tensor.extract %5[%9, %10] : tensor<128256x4096xf16>
      %extracted_f32 = arith.extf %extracted : f16 to f32
      linalg.yield %extracted_f32 : f32
    } -> tensor<4x?x4096xf32>

as a one-off canonicalization for now, to not fall off a cliff. It might be hard to make it future-proof, but more examples will help.
@IanWood1, just FYI, this is something for us to discuss (and for you to pick up as a simple task). Please make sure we chat about this next time we sync.

@benvanik (Collaborator)

benvanik commented May 3, 2024

Agreed on handling this even if not generalized, as it's pretty catastrophic to clone embeddings.

I think the more durable fix may be proper propagation: we should sink any exts down / hoist truncs up across memcpy-like ops (such as this gather or a scatter). With the current logic we may be in a better situation, but we still want to ensure we don't materialize ext/trunc dispatches unless absolutely required.

archana-ramalingam pushed a commit to archana-ramalingam/torch-mlir that referenced this issue May 8, 2024
The conversion of complex type wasn't supported or checked; the support
and required tests were added.

Fixes:
iree-org/iree#17226 (comment)
@zjgarvey (Contributor)

Noting that this issue also occurs with some other models. In the SHARK-TestSuite, the onnx/models/RAFT_vaiq_int8 model encounters a similar issue. To reproduce, set up the test suite and run

python run.py --cachedir=/path/to/.cache/ -t onnx/models/RAFT_vaiq_int8/ -m onnx -c /path/to/torch-mlir/build/ -i /path/to/iree-build/ --torchtolinalg

with an up-to-date torch-mlir and iree build.

@c-rhodes (Contributor)

The other issue is that we have two generic ops that are not fused in TileAndFuse, because there is no operand dependency between them. It should be fixed before sending it to codegen. I don't have a good solution so far; perhaps we should just disable the fusion for this kind of case. @MaheshRavishankar, do you have any suggestions?

One of our fp16 models is failing to compile with what looks like the same issue.

reproducer:

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>
func.func @main_dispatch_0_elementwise_1x30522x128_f16xf32(%arg0: tensor<1x30522x128xf16>, %arg1: tensor<1x128xi32>, %arg2: tensor<128xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
  %c0 = arith.constant 0 : index
  %4 = tensor.empty() : tensor<1x128x128xf32>
  %5 = tensor.empty() : tensor<1x30522x128xf32>
  %6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x30522x128xf16>) outs(%5 : tensor<1x30522x128xf32>) {
  ^bb0(%in: f16, %out: f32):
    %8 = arith.extf %in : f16 to f32
    linalg.yield %8 : f32
  } -> tensor<1x30522x128xf32>
  %7 = linalg.generic {indexing_maps = [#map1, #map2, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1, %arg2, %arg3 : tensor<1x128xi32>, tensor<128xf32>, tensor<128xf32>) outs(%4 : tensor<1x128x128xf32>) {
  ^bb0(%in: i32, %in_0: f32, %in_1: f32, %out: f32):
    %8 = linalg.index 2 : index
    %9 = arith.index_cast %in : i32 to index
    %extracted = tensor.extract %6[%c0, %9, %8] : tensor<1x30522x128xf32>
    %10 = arith.mulf %extracted, %in_0 : f32
    %11 = arith.addf %10, %in_1 : f32
    linalg.yield %11 : f32
  } -> tensor<1x128x128xf32>
  return %7 : tensor<1x128x128xf32>
}

invocation: iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=+sve --iree-llvmcpu-enable-scalable-vectorization=true main_dispatch_0_elementwise_1x30522x128_f16xf32.mlir -o /dev/null

fails with:

iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=+sve --iree-llvmcpu-enable-scalable-vectorization=true main_dispatch_0_elementwise_1x30522x128_f16xf32.mlir -o /dev/null
main_dispatch_0_elementwise_1x30522x128_f16xf32.mlir:13:8: error: 'func.func' op exceeded stack allocation limit of 32768 bytes for function. Got 15627264 bytes
  %7 = linalg.generic {indexing_maps = [#map1, #map2, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1, %arg2, %arg3 : tensor<1x128xi32>, tensor<128xf32>, tensor<128xf32>) outs(%4 : tensor<1x128x128xf32>) {
       ^

The tensor.empty isn't eliminated by EliminateEmptyTensorsPass, hence the stack allocation limit is exceeded, but looking at the print-after-all output I noticed the first generic doing the f16 -> f32 extension isn't getting tiled or fused:

// -----// IR Dump Before LLVMCPUTileAndFusePass (iree-llvmcpu-tile-and-fuse) //----- //
func.func @main_dispatch_0_elementwise_1x30522x128_f16xf32_dispatch_0_elementwise_1x30522x128_f16xf32() attributes {translation_info = #iree_codegen.translation_info<CPUDoubleTilingExpert, {enable_loop_peeling}>} {
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x30522x128xf16>>
  %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x128xi32>>
  %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<128xf32>>
  %3 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<128xf32>>
  %4 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(4) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<1x128x128xf32>>
  %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [1, 30522, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x30522x128xf16>> -> tensor<1x30522x128xf16>
  %6 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x128xi32>> -> tensor<1x128xi32>
  %7 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [128], strides = [1] : !flow.dispatch.tensor<readonly:tensor<128xf32>> -> tensor<128xf32>
  %8 = flow.dispatch.tensor.load %3, offsets = [0], sizes = [128], strides = [1] : !flow.dispatch.tensor<readonly:tensor<128xf32>> -> tensor<128xf32>
  %9 = tensor.empty() : tensor<1x128x128xf32>
  %10 = tensor.empty() : tensor<1x30522x128xf32>
  %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5 : tensor<1x30522x128xf16>) outs(%10 : tensor<1x30522x128xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 0], [1, 1, 4], [0, 0, 0], [0, 0, 0]]>} {
  ^bb0(%in: f16, %out: f32):
    %13 = arith.extf %in : f16 to f32
    linalg.yield %13 : f32
  } -> tensor<1x30522x128xf32>
  %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6, %7, %8 : tensor<1x128xi32>, tensor<128xf32>, tensor<128xf32>) outs(%9 : tensor<1x128x128xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 0], [1, 1, 4], [0, 0, 0], [0, 0, 0]]>} {
  ^bb0(%in: i32, %in_0: f32, %in_1: f32, %out: f32):
    %13 = linalg.index 2 : index
    %14 = arith.index_cast %in : i32 to index
    %extracted = tensor.extract %11[%c0, %14, %13] : tensor<1x30522x128xf32>
    %15 = arith.mulf %extracted, %in_0 : f32
    %16 = arith.addf %15, %in_1 : f32
    linalg.yield %16 : f32
  } -> tensor<1x128x128xf32>
  flow.dispatch.tensor.store %12, %4, offsets = [0, 0, 0], sizes = [1, 128, 128], strides = [1, 1, 1] : tensor<1x128x128xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x128x128xf32>>
  return
}

// -----// IR Dump After LLVMCPUTileAndFusePass (iree-llvmcpu-tile-and-fuse) //----- //
func.func @main_dispatch_0_elementwise_1x30522x128_f16xf32_dispatch_0_elementwise_1x30522x128_f16xf32() attributes {translation_info = #iree_codegen.translation_info<CPUDoubleTilingExpert, {enable_loop_peeling}>} {
  %c4 = arith.constant 4 : index
  %c128 = arith.constant 128 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x30522x128xf16>>
  %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x128xi32>>
  %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<128xf32>>
  %3 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<128xf32>>
  %4 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(4) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<1x128x128xf32>>
  %5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [1, 30522, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x30522x128xf16>> -> tensor<1x30522x128xf16>
  %6 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x128xi32>> -> tensor<1x128xi32>
  %7 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [128], strides = [1] : !flow.dispatch.tensor<readonly:tensor<128xf32>> -> tensor<128xf32>
  %8 = flow.dispatch.tensor.load %3, offsets = [0], sizes = [128], strides = [1] : !flow.dispatch.tensor<readonly:tensor<128xf32>> -> tensor<128xf32>
  %9 = tensor.empty() : tensor<1x128x128xf32>
  %10 = tensor.empty() : tensor<1x30522x128xf32>
  %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5 : tensor<1x30522x128xf16>) outs(%10 : tensor<1x30522x128xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 0], [1, 1, 4], [0, 0, 0], [0, 0, 0]]>} {
  ^bb0(%in: f16, %out: f32):
    %13 = arith.extf %in : f16 to f32
    linalg.yield %13 : f32
  } -> tensor<1x30522x128xf32>
  %12 = scf.for %arg0 = %c0 to %c128 step %c1 iter_args(%arg1 = %9) -> (tensor<1x128x128xf32>) {
    %13 = scf.for %arg2 = %c0 to %c128 step %c4 iter_args(%arg3 = %arg1) -> (tensor<1x128x128xf32>) {
      %extracted_slice = tensor.extract_slice %6[0, %arg0] [1, 1] [1, 1] : tensor<1x128xi32> to tensor<1x1xi32>
      %extracted_slice_0 = tensor.extract_slice %7[%arg2] [4] [1] : tensor<128xf32> to tensor<4xf32>
      %extracted_slice_1 = tensor.extract_slice %8[%arg2] [4] [1] : tensor<128xf32> to tensor<4xf32>
      %extracted_slice_2 = tensor.extract_slice %arg3[0, %arg0, %arg2] [1, 1, 4] [1, 1, 1] : tensor<1x128x128xf32> to tensor<1x1x4xf32>
      %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice, %extracted_slice_0, %extracted_slice_1 : tensor<1x1xi32>, tensor<4xf32>, tensor<4xf32>) outs(%extracted_slice_2 : tensor<1x1x4xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 0], [1, 1, 4], [0, 0, 0], [0, 0, 0]]>} {
      ^bb0(%in: i32, %in_3: f32, %in_4: f32, %out: f32):
        %15 = linalg.index 2 : index
        %16 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%15, %arg2)
        %17 = arith.index_cast %in : i32 to index
        %extracted = tensor.extract %11[%c0, %17, %16] : tensor<1x30522x128xf32>
        %18 = arith.mulf %extracted, %in_3 : f32
        %19 = arith.addf %18, %in_4 : f32
        linalg.yield %19 : f32
      } -> tensor<1x1x4xf32>
      %inserted_slice = tensor.insert_slice %14 into %arg3[0, %arg0, %arg2] [1, 1, 4] [1, 1, 1] : tensor<1x1x4xf32> into tensor<1x128x128xf32>
      scf.yield %inserted_slice : tensor<1x128x128xf32>
    }
    scf.yield %13 : tensor<1x128x128xf32>
  }
  flow.dispatch.tensor.store %12, %4, offsets = [0, 0, 0], sizes = [1, 128, 128], strides = [1, 1, 1] : tensor<1x128x128xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x128x128xf32>>
  return
}

@benvanik (Collaborator)

(oof, yeah, never want exts/truncs/etc. on their own! good catches, and something that should be in the top 10 things-to-triage checklist for any model before doing deeper performance/memory work)

@IanWood1 (Contributor)

IanWood1 commented Oct 24, 2024

Here's the IR dump before GatherFusionPattern is applied in FusionPreprocessingPass (from this PR #17341). This is before any reshape propagation, so %expanded is getting in the way. The reshape ops get moved to the edges a bit later on, so maybe this pattern could get added there? That should fuse these 2 linalg.generic ops. But maybe there is a more robust solution.

// -----// IR Dump Before FusionPreprocessingPass (iree-dispatch-creation-fusion-preprocessing) //----- //
util.func public @main_dispatch_0_elementwise_1x30522x128_f16xf32(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @main_dispatch_0_elementwise_1x30522x128_f16xf32(%input0: tensor<1x30522x128xf16>, %input1: tensor<1x128xi32>, %input2: tensor<128xf32>, %input3: tensor<128xf32>) -> (%output0: tensor<1x128x128xf32>)"}} {
  %c0 = arith.constant 0 : index
  %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<1x30522x128xf16>
  %1 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<1x128xi32>
  %2 = hal.tensor.import %arg2 "input2" : !hal.buffer_view -> tensor<128xf32>
  %3 = hal.tensor.import %arg3 "input3" : !hal.buffer_view -> tensor<128xf32>
  %collapsed = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<1x30522x128xf16> into tensor<30522x128xf16>
  %4 = tensor.empty() : tensor<30522x128xf32>
  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<30522x128xf16>) outs(%4 : tensor<30522x128xf32>) {
  ^bb0(%in: f16, %out: f32):
    %9 = arith.extf %in : f16 to f32
    linalg.yield %9 : f32
  } -> tensor<30522x128xf32>
  %expanded = tensor.expand_shape %5 [[0, 1], [2]] output_shape [1, 30522, 128] : tensor<30522x128xf32> into tensor<1x30522x128xf32>
  %collapsed_0 = tensor.collapse_shape %1 [[0, 1]] : tensor<1x128xi32> into tensor<128xi32>
  %6 = tensor.empty() : tensor<128x128xf32>
  %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_0, %2, %3 : tensor<128xi32>, tensor<128xf32>, tensor<128xf32>) outs(%6 : tensor<128x128xf32>) {
  ^bb0(%in: i32, %in_2: f32, %in_3: f32, %out: f32):
    %9 = linalg.index 1 : index
    %10 = arith.index_cast %in : i32 to index
    %extracted = tensor.extract %expanded[%c0, %10, %9] : tensor<1x30522x128xf32>
    %11 = arith.mulf %extracted, %in_2 : f32
    %12 = arith.addf %11, %in_3 : f32
    linalg.yield %12 : f32
  } -> tensor<128x128xf32>
  %expanded_1 = tensor.expand_shape %7 [[0, 1], [2]] output_shape [1, 128, 128] : tensor<128x128xf32> into tensor<1x128x128xf32>
  %8 = hal.tensor.export %expanded_1 "output0" : tensor<1x128x128xf32> -> !hal.buffer_view
  util.return %8 : !hal.buffer_view
}

@c-rhodes (Contributor)

Here's the IR dump before GatherFusionPattern is applied in FusionPreprocessingPass (from this PR #17341). This is before any reshape propagation, so %expanded is getting in the way. The reshape ops get moved to the edges a bit later on, so maybe this pattern could get added there? That should fuse these 2 linalg.generic ops. But maybe there is a more robust solution.

thanks for the pointers! Adding iree-dispatch-creation-preprocessing-pipeline before iree-dispatch-creation-pipeline fixes the issue. Is there a reason this is only used for testing and is this an acceptable fix?

diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.h b/compiler/src/iree/compiler/DispatchCreation/Passes.h
index e129fe654d..9020465fe7 100644
--- a/compiler/src/iree/compiler/DispatchCreation/Passes.h
+++ b/compiler/src/iree/compiler/DispatchCreation/Passes.h
@@ -27,6 +27,8 @@ struct TransformOptions : public PassPipelineOptions<TransformOptions> {};

 void buildDispatchCreationPassPipeline(
     OpPassManager &passManager, const TransformOptions &transformOptions);
+void addDispatchRegionCreationPreprocessingPasses(
+    OpPassManager &passManager);

 //===----------------------------------------------------------------------===//
 // Register all Passes
diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
index 77319e26f7..952b9eb0a9 100644
--- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
+++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp
@@ -277,6 +277,8 @@ void buildIREEVMTransformPassPipeline(
       IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "DispatchCreation");
       if (hooks.beforePhase)
         hooks.beforePhase(IREEVMPipelinePhase::DispatchCreation, passManager);
+      DispatchCreation::addDispatchRegionCreationPreprocessingPasses(
+          passManager);
       DispatchCreation::buildDispatchCreationPassPipeline(
           passManager, dispatchCreationOptions);
       if (hooks.afterPhase)

@MaheshRavishankar (Contributor)

(oof, yeah, never want exts/truncs/etc. on their own! good catches, and something that should be in the top 10 things-to-triage checklist for any model before doing deeper performance/memory work)

They aren't on their own. It's within the dispatch.

@IanWood1 (Contributor)

thanks for the pointers! Adding iree-dispatch-creation-preprocessing-pipeline before iree-dispatch-creation-pipeline fixes the issue. Is there a reason this is only used for testing and is this an acceptable fix?

@c-rhodes addDispatchRegionCreationPreprocessingPasses is run as a part of the DispatchCreation pipeline. Also, FusionPreprocessingPass is poorly named and for some reason not a part of addDispatchRegionCreationPreprocessingPasses.

I think the problem is that the pattern is very sensitive to the state of the IR, so small changes will cause the pattern to succeed/fail (e.g. running DispatchRegionCreationPreprocessingPasses twice). Would you be able to link to the input IR?

@c-rhodes (Contributor)

thanks for the pointers! Adding iree-dispatch-creation-preprocessing-pipeline before iree-dispatch-creation-pipeline fixes the issue. Is there a reason this is only used for testing and is this an acceptable fix?

@c-rhodes addDispatchRegionCreationPreprocessingPasses is run as a part of the DispatchCreation pipeline. Also, FusionPreprocessingPass is poorly named and for some reason not a part of addDispatchRegionCreationPreprocessingPasses.

Ah ok, I didn't realise it was already being run.

I think the problem is that the pattern is very sensitive to the state of the IR, so small changes will cause the pattern to succeed/fail (e.g. running DispatchRegionCreationPreprocessingPasses twice). Would you be able to link to the input IR?

this snippet is the relevant input

func.func @main(%arg0: tensor<1x128xi32>) -> (tensor<1x2xf32>) {
  ...
  %38 = "tosa.const"() <{value = dense_resource<__elided__> : tensor<128xf16>}> : () -> tensor<128xf16>
  %39 = "tosa.const"() <{value = dense_resource<__elided__> : tensor<128xf16>}> : () -> tensor<128xf16>
  %40 = "tosa.const"() <{value = dense_resource<__elided__> : tensor<30522x128xf16>}> : () -> tensor<30522x128xf16>
  %41 = tosa.cast %40 : (tensor<30522x128xf16>) -> tensor<30522x128xf32>
  %42 = tosa.reshape %41 {new_shape = array<i64: 1, 30522, 128>} : (tensor<30522x128xf32>) -> tensor<1x30522x128xf32>
  %43 = tosa.gather %42, %arg0 : (tensor<1x30522x128xf32>, tensor<1x128xi32>) -> tensor<1x128x128xf32>
  %44 = tosa.cast %39 : (tensor<128xf16>) -> tensor<128xf32>
  %45 = tosa.reshape %44 {new_shape = array<i64: 1, 1, 128>} : (tensor<128xf32>) -> tensor<1x1x128xf32>
  %46 = tosa.mul %43, %45 {shift = 0 : i8} : (tensor<1x128x128xf32>, tensor<1x1x128xf32>) -> tensor<1x128x128xf32>
  %47 = tosa.cast %38 : (tensor<128xf16>) -> tensor<128xf32>
  %48 = tosa.reshape %47 {new_shape = array<i64: 1, 1, 128>} : (tensor<128xf32>) -> tensor<1x1x128xf32>
  %49 = tosa.add %46, %48 : (tensor<1x128x128xf32>, tensor<1x1x128xf32>) -> tensor<1x128x128xf32>
  ...
}

@c-rhodes (Contributor)

I forgot to mention the model is Bert tiny (fp16). I've been going wider on models in our testing recently and have hit quite a few issues that I'm currently working through, particularly with fp16 models.

c-rhodes pushed a commit that referenced this issue Oct 29, 2024
I think it makes sense to run `FusionPreprocessingPass` before
`ElementwiseOpFusionPass` because it helps put the IR in a better state
for fusion (e.g. interchanging `linalg.generic` indexing maps). But
also, reshapes have been propagated to the edges of the program, which
allows the `GatherFusionPattern` to be more effective.


Fixes compilation error from
#17226 (comment).

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
Eliasj42 pushed a commit that referenced this issue Oct 31, 2024
I think it makes sense to run `FusionPreprocessingPass` before
`ElementwiseOpFusionPass` because it helps put the IR in a better state
for fusion (e.g. interchanging `linalg.generic` indexing maps). But
also, reshapes have been propagated to the edges of the program, which
allows the `GatherFusionPattern` to be more effective.

Fixes compilation error from
#17226 (comment).

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
Signed-off-by: Elias Joseph <eljoseph@amd.com>
giacs-epic pushed a commit to giacs-epic/iree that referenced this issue Dec 4, 2024
…org#18920)

I think it makes sense to run `FusionPreprocessingPass` before
`ElementwiseOpFusionPass` because it helps put the IR in a better state
for fusion (e.g. interchanging `linalg.generic` indexing maps). But
also, reshapes have been propagated to the edges of the program, which
allows the `GatherFusionPattern` to be more effective.

Fixes compilation error from
iree-org#17226 (comment).

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>