Skip to content

Commit

Permalink
Update spirv tests for changes to sub-byte lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
krzysz00 committed Jan 17, 2025
1 parent a745991 commit ea624f8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
19 changes: 9 additions & 10 deletions compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matvec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,18 @@ hal.executable @i4_dequant_unit_matmul_f16 {

// CHECK-DAG: %[[CSTVEC4XI32_255:.+]] = spirv.Constant dense<255> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC4XI32_0:.+]] = spirv.Constant dense<0> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC2XI32_4:.+]] = spirv.Constant dense<4> : vector<2xi32>
// CHECK-DAG: %[[CSTVEC2XI32_15:.+]] = spirv.Constant dense<15> : vector<2xi32>
// CHECK-DAG: %[[CSTVEC4XI32_0_4:.+]] = spirv.Constant dense<[0, 4, 0, 4]> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC4XI32_15__16:.+]] = spirv.Constant dense<[15, -16, 15, -16]> : vector<4xi32>

// CHECK: spirv.mlir.loop

// Load the quantized weight and get 8xi4 out of it.
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xi32>
// CHECK: %[[SHUF01:.+]] = spirv.VectorShuffle [0 : i32, 1 : i32] %[[LOAD]], %[[LOAD]] : vector<4xi32>, vector<4xi32> -> vector<2xi32>
// CHECK: %[[LOW4:.+]] = spirv.BitwiseAnd %[[SHUF01]], %[[CSTVEC2XI32_15]] : vector<2xi32>
// CHECK: %[[HIGH4:.+]] = spirv.ShiftRightLogical %[[SHUF01]], %[[CSTVEC2XI32_4]] : vector<2xi32>, vector<2xi32>
// CHECK: %[[LOW4HIGH4:.+]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[LOW4]], %[[HIGH4]] : vector<2xi32>, {{.*}} -> vector<4xi32>
// CHECK: %[[LOW4HIGH4_ZEROUPPER:.+]] = spirv.BitwiseAnd %[[LOW4HIGH4]], %[[CSTVEC4XI32_255]] : vector<4xi32>
// CHECK: %[[SHUF0011:.+]] = spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32] %[[SHUF01]], %[[SHUF01]] : vector<2xi32>, vector<2xi32> -> vector<4xi32>
// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[SHUF0011]], %[[CSTVEC4XI32_15__16]] : vector<4xi32>
// CHECK: %[[SHIFTED:.+]] = spirv.ShiftRightLogical %[[MASKED]], %[[CSTVEC4XI32_0_4]] : vector<4xi32>, vector<4xi32>
// CHECK: %[[LOW4HIGH4_ZEROUPPER:.+]] = spirv.BitwiseAnd %[[SHIFTED]], %[[CSTVEC4XI32_255]] : vector<4xi32>

// CHECK: %[[SHUF23:.+]] = spirv.VectorShuffle [2 : i32, 3 : i32] %[[LOAD:.+]], %[[LOAD:.+]] : vector<4xi32>, vector<4xi32> -> vector<2xi32>

Expand Down Expand Up @@ -186,8 +186,8 @@ hal.executable @i4_dequant_matvec_f16_subgroup_64 {
// CHECK-DAG: %[[C0:.+]] = spirv.Constant 0 : i32
// CHECK-DAG: %[[CSTVEC4XF16_1:.+]] = spirv.Constant dense<1.000000e+00> : vector<4xf16>
// CHECK-DAG: %[[CSTVEC4XI32_255:.+]] = spirv.Constant dense<255> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC2XI32_4:.+]] = spirv.Constant dense<4> : vector<2xi32>
// CHECK-DAG: %[[CSTVEC2XI32_15:.+]] = spirv.Constant dense<15> : vector<2xi32>
// CHECK-DAG: %[[CSTVEC2XI32_1:.+]] = spirv.Constant dense<[0, 4, 0, 4]> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC2XI32_2:.+]] = spirv.Constant dense<[15, -16, 15, -16]> : vector<4xi32>

// CHECK: %[[WIDX:.+]] = spirv.CompositeExtract %{{.*}}[0 : i32] : vector<3xi32>
// CHECK: %[[PCPTR:.+]] = spirv.AccessChain %{{.*}}[{{.*}}, %[[C0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>, i32, i32
Expand All @@ -209,8 +209,7 @@ hal.executable @i4_dequant_matvec_f16_subgroup_64 {
// CHECK: %[[ACCESS:.+]] = spirv.AccessChain %[[RADDR]][{{.*}}, %[[OFFSET]]] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: spirv.Load "StorageBuffer" %[[ACCESS]] : i32

// CHECK: spirv.ShiftRightLogical %{{.*}}, %[[CSTVEC2XI32_4]] : vector<2xi32>, vector<2xi32>
// CHECK: spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %{{.*}} : vector<2xi32>, vector<2xi32> -> vector<4xi32>
// CHECK: spirv.ShiftRightLogical %{{.*}}, %[[CSTVEC2XI32_1]] : vector<4xi32>, vector<4xi32>
// CHECK: spirv.BitwiseAnd %{{.*}}, %[[CSTVEC4XI32_255]] : vector<4xi32>

// CHECK: spirv.ConvertUToF %{{.+}} : vector<4xi32> to vector<4xf16>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ hal.executable @i4_dequant {

// CHECK-LABEL: spirv.func @i4_dequant()

// CHECK: spirv.VectorShuffle [0 : i32, 1 : i32] {{.*}} : vector<4xi32>, vector<4xi32> -> vector<2xi32>
// CHECK: spirv.BitwiseAnd
// CHECK: spirv.ShiftRightLogical
// CHECK: spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32]
// CHECK: spirv.BitwiseAnd
// CHECK: %[[BYTE1:.+]] = spirv.VectorShuffle [0 : i32, 1 : i32] {{.*}} : vector<4xi32>, vector<4xi32> -> vector<2xi32>
// CHECK: %[[COPIED:.+]] = spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32] %[[BYTE1]], %[[BYTE1]] : vector<2xi32>, vector<2xi32> -> vector<4xi32>
// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[COPIED]]
// CHECK: %[[SHIFTED:.+]] = spirv.ShiftRightLogical %[[MASKED]]
// CHECK: %[[ZEROUPPER:.+]] = spirv.BitwiseAnd %[[SHIFTED]]
// CHECK: spirv.VectorShuffle [2 : i32, 3 : i32] {{.*}} : vector<4xi32>, vector<4xi32> -> vector<2xi32>
// CHECK-COUNT-3: spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32]
// CHECK-COUNT-3: spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32]

// CHECK-COUNT-4: spirv.ConvertUToF {{.+}} : vector<4xi32> to vector<4xf32>
// CHECK-COUNT-4: spirv.FSub
Expand Down

0 comments on commit ea624f8

Please sign in to comment.