Llama-3-8B f16 fails to compile to vmfb #17226
Comments
lol I'm guessing something is multiplying by a dynamic dimension (sentinel -1) without checking :P |
(to reproduce we'll need the batch_llama_3_8B.mlir file, or the entire contents of the |
Yeah I'll upload it here, accidentally submitted the issue before uploading it here :) |
@benvanik I uploaded a zip that has the batch_llama_3_8B.mlir file |
It looks like it failed in SetEncoding (or related passes). @pashu123 given that you want to get more involved in these tasks, would you like to triage the issue when you're available? |
@aviator19941 Do we need to cherry-pick a commit or check out a branch? On the main branch I am noticing this
|
We need to cherry-pick #17182 for the first command to work. |
Here's the minimal repro https://gist.github.com/pashu123/45fe64caa21cfdfa9890698660184a44 |
This is failing in the |
@aviator19941 The failure is due to https://gist.github.com/pashu123/020217a35f1c643ed03b169ce41f68d9 (embedding kernel). It has a cast from fp16 -> fp32. Please double-check that it's a full fp16 model. Also, could you post how to obtain the IRs? |
A possible optimization to consider:
can be replaced by
i.e., we don't need to cast the entire embedding matrix; we just cast what we want out of the matrix. The above repro takes forever to compile on the CPU backend. However, when we apply the optimization, we don't get the error: |
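For concreteness, a sketch of what the replacement could look like at the linalg level (an illustration with placeholder value names %table, %ids, and %init, using the shapes from the repro IR in the next comment; not the exact snippet from the gist): the extract reads the f16 table directly, and only the gathered element is widened to f32, so the 128256x4096 f32 copy of the table is never materialized.
%gathered = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%ids : tensor<4x?xi64>) outs(%init : tensor<4x?x4096xf32>) {
^bb0(%in: i64, %out: f32):
  %row = arith.index_cast %in : i64 to index
  %col = linalg.index 2 : index
  // Gather from the original f16 table and cast only the extracted element.
  %elem = tensor.extract %table[%row, %col] : tensor<128256x4096xf16>
  %ext = arith.extf %elem : f16 to f32
  linalg.yield %ext : f32
} -> tensor<4x?x4096xf32>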
It does not compile because vector.gather is lowered to a lot of vector ops, which should be fixed. The other issue is that we have two generic ops that are not fused in TileAndFuse, because there is no operand dependency between the two generic ops. That should be fixed before sending it to codegen. I don't have a good solution so far; perhaps we should just disable the fusion for this kind of case. @MaheshRavishankar do you have any suggestions? func.func @decode_bs4$async_dispatch_0_generic_4xDx4096_i64xf32() {
%c0 = arith.constant 0 : index
%c32_i64 = arith.constant 32 : i64
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = arith.extui %0 : i32 to i64
%3 = arith.extui %1 : i32 to i64
%4 = arith.shli %3, %c32_i64 : i64
%5 = arith.ori %2, %4 : i64
%6 = arith.index_castui %5 : i64 to index
%7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128256x4096xf16>>
%8 = flow.dispatch.workload.ordinal %6, 0 : index
%9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4x?xi64>>{%8}
%10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>{%8}
%11 = flow.dispatch.tensor.load %7, offsets = [0, 0], sizes = [128256, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128256x4096xf16>> -> tensor<128256x4096xf16>
%12 = flow.dispatch.tensor.load %9, offsets = [0, 0], sizes = [4, %8], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4x?xi64>>{%8} -> tensor<4x?xi64>
%13 = tensor.empty(%8) : tensor<4x?x4096xf32>
%14 = tensor.empty() : tensor<128256x4096xf32>
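// Producer generic: widens the entire 128256x4096 f16 embedding table to f32.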
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%11 : tensor<128256x4096xf16>) outs(%14 : tensor<128256x4096xf32>) {
^bb0(%in: f16, %out: f32):
%17 = arith.extf %in : f16 to f32
linalg.yield %17 : f32
} -> tensor<128256x4096xf32>
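// Consumer generic: gathers rows from %15. Its only use of %15 is the tensor.extract inside the region (not an ins operand), so TileAndFuse sees no dependency between the two generics.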
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12 : tensor<4x?xi64>) outs(%13 : tensor<4x?x4096xf32>) {
^bb0(%in: i64, %out: f32):
%17 = arith.index_cast %in : i64 to index
%18 = linalg.index 2 : index
%extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
linalg.yield %extracted : f32
} -> tensor<4x?x4096xf32>
flow.dispatch.tensor.store %16, %10, offsets = [0, 0, 0], sizes = [4, %8, 4096], strides = [1, 1, 1] : tensor<4x?x4096xf32> -> !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>{%8}
return
} |
When I try to set the activation and attention dtypes to fp16 here, I run into
because it is trying to multiply |
In order to obtain the IRs: |
I think I can add the fix for this. It is required to enable the full fp16-precision model. |
@aviator19941 You can get the latest fp16 IR from It's able to generate the .vmfb for the llvm-cpu backend with the command |
You need to cherry-pick #17247 |
I think there are still action items in this issue; the look-up table fusion is scaring me. We should fix that at least. The tile sizes for vector.gather are also problematic: it will be fully unrolled, which looks really bad. |
I never intended to close the issue; I don't know if it got closed automatically. Yes, for the mixed precision case in which we have activations represented as f32, we still have action items to do. |
Confirmed that the fusion is not expected; @MaheshRavishankar will fix it. For the gather codegen issue, @pashu123 could you create an input case for the generic op and see what's happening? I'm expecting that some dimensions would be collapsed, and the next issue could be tile-size selection. #17227 could help, but there could be other issues remaining on the table.
|
Well, I said the fusion is not expected because I wasn't looking at it properly. It is expected, and I think it is probably what you want at the dispatch level. If we don't fuse this, we will materialize the whole widened tensor. The real issue, though, is that the op shouldn't be lowered this way. A better representation of this would be to do
That should fix one of the issues Hanhan mentioned. If we can fix the frontend to do this, that would be best. If not, we should just write an ad-hoc pattern that does this kind of fusion. There is really nothing structured about this to generalize; it is just a specific pattern that works around (WAR) a front-end lowering issue. |
@MaheshRavishankar, does this sound reasonable to add to torch-mlir? |
Added here: llvm/torch-mlir#3277 |
Not sure why this keeps closing |
FYI, if you can make the torch embedding lookup good, that is best. But also I carved this out for a potential special op: it would be trivial to write a custom op at the frontend that expanded to whatever linalg you want. |
@pashu123 put a "fixes" command in a commit message and now anyone who has write access to the repo will close it when they merge in that commit to their forks of whatever :P |
With that said, moving the cast across the embedding lookup is a common optimization. I'm a bit worried that the default path on this generates basically unusable code, though. |
That's fair, but we just don't represent gathers well. And if we clone the quantization into all the dispatches that use it (as we do now, under the current understanding of the best way to handle dequantization), none of the transformations can actually fuse and generate this code. The producer-consumer dependency only materializes from within the body of the consumer; nothing accounts for that, and it just falls off the cliff. |
Why is Github unable to prevent actions on forks from spamming main repos... Seems like a big anti-feature. |
@aviator19941 Do you have instructions on how to run llama3 for the IREE backend? |
The more I think about this, the more it might be worth just doing the fusion of
to
as a one-off canonicalization for now, so we don't fall off a cliff. It might be hard to make it future-proof, but more examples will help. |
Agreed on handling this even if not generalized, as it's pretty catastrophic to clone embeddings. I think the more durable fix may be proper propagation: we should sink any exts down / hoist truncs up across memcpy-like ops (such as this gather or a scatter). With the current logic we may be in a better situation, but we still want to ensure we don't materialize ext/trunc dispatches unless absolutely required. |
The conversion of complex type wasn't supported or checked; the support and required tests were added. Fixes: iree-org/iree#17226 (comment)
Noting that this issue also occurs with some other models. In the SHARK-TestSuite, the onnx/models/RAFT_vaiq_int8 model encounters a similar issue. To reproduce, set up the test suite and run
with an up-to-date torch-mlir and iree build. |
One of our fp16 models is failing to compile with what looks like the same issue. Reproducer:
invocation: fails with:
the
|
(oof, yeah, we never want exts/truncs/etc. on their own! Good catches, and something that should be in the top-10 things-to-triage checklist for any model before doing deeper performance/memory work) |
Here's the IR dump before // -----// IR Dump Before FusionPreprocessingPass (iree-dispatch-creation-fusion-preprocessing) //----- //
util.func public @main_dispatch_0_elementwise_1x30522x128_f16xf32(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @main_dispatch_0_elementwise_1x30522x128_f16xf32(%input0: tensor<1x30522x128xf16>, %input1: tensor<1x128xi32>, %input2: tensor<128xf32>, %input3: tensor<128xf32>) -> (%output0: tensor<1x128x128xf32>)"}} {
%c0 = arith.constant 0 : index
%0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<1x30522x128xf16>
%1 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<1x128xi32>
%2 = hal.tensor.import %arg2 "input2" : !hal.buffer_view -> tensor<128xf32>
%3 = hal.tensor.import %arg3 "input3" : !hal.buffer_view -> tensor<128xf32>
%collapsed = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<1x30522x128xf16> into tensor<30522x128xf16>
%4 = tensor.empty() : tensor<30522x128xf32>
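// Same structure as the llama repro above: first the whole 30522x128 f16 table is widened to f32...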
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<30522x128xf16>) outs(%4 : tensor<30522x128xf32>) {
^bb0(%in: f16, %out: f32):
%9 = arith.extf %in : f16 to f32
linalg.yield %9 : f32
} -> tensor<30522x128xf32>
%expanded = tensor.expand_shape %5 [[0, 1], [2]] output_shape [1, 30522, 128] : tensor<30522x128xf32> into tensor<1x30522x128xf32>
%collapsed_0 = tensor.collapse_shape %1 [[0, 1]] : tensor<1x128xi32> into tensor<128xi32>
%6 = tensor.empty() : tensor<128x128xf32>
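// ...then the gather-like generic extracts from the widened table (%expanded) inside its region, again with no operand edge between the two generics.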
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_0, %2, %3 : tensor<128xi32>, tensor<128xf32>, tensor<128xf32>) outs(%6 : tensor<128x128xf32>) {
^bb0(%in: i32, %in_2: f32, %in_3: f32, %out: f32):
%9 = linalg.index 1 : index
%10 = arith.index_cast %in : i32 to index
%extracted = tensor.extract %expanded[%c0, %10, %9] : tensor<1x30522x128xf32>
%11 = arith.mulf %extracted, %in_2 : f32
%12 = arith.addf %11, %in_3 : f32
linalg.yield %12 : f32
} -> tensor<128x128xf32>
%expanded_1 = tensor.expand_shape %7 [[0, 1], [2]] output_shape [1, 128, 128] : tensor<128x128xf32> into tensor<1x128x128xf32>
%8 = hal.tensor.export %expanded_1 "output0" : tensor<1x128x128xf32> -> !hal.buffer_view
util.return %8 : !hal.buffer_view
} |
thanks for the pointers! Adding
|
They aren't on their own. It's within the dispatch. |
@c-rhodes I think the problem is that the pattern is very sensitive to the state of the IR, so small changes will cause the pattern to succeed/fail (e.g. running |
Ah ok, I didn't realise it was already being run.
this snippet is the relevant input
|
I forgot to mention the model is Bert tiny (fp16). I've been going wider on models in our testing recently and have hit quite a few issues that I'm currently working through, particularly with fp16 models. |
I think it makes sense to run `FusionPreprocessingPass` before `ElementwiseOpFusionPass` because it helps put the IR in a better state for fusion (e.g. interchanging `linalg.generic` indexing maps). But also, reshapes have been propagated to the edges of the program, which allows the `GatherFusionPattern` to be more effective. Fixes compilation error from #17226 (comment). --------- Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
batch_llama_3_8B.zip
What happened?
When trying to compile this MLIR file, I get the shared memory error below:
Steps to reproduce your issue
What component(s) does this issue relate to?
No response
Version information
f2746b4
Additional context
No response