
Commit

Fixing comments
Giuseppe Rossini committed Nov 20, 2020
1 parent 40664b0 commit d54e73a
Showing 1 changed file with 75 additions and 29 deletions.
104 changes: 75 additions & 29 deletions python/tvm/topi/arm_cpu/tensor_intrin.py
@@ -23,17 +23,26 @@

def gemm_4x4_int8_int8_int32(M, N, K, unroll, in_type):
"""
- Use integer ARM v8 instructions in order to produce a block c of 4x4 elements
- given two 4xK blocks a and b' (where b' is a Kx4 block transposed). The final
- result is c = a*b (where '*' indicates the matrix product)
- Every row of the matrix c is obtained (for uint8) by a sequence of
- umull -> uadalp -> umull2 -> uadalp
- The block size is constrained by the number of registers available in arvm8. This
- function returns a TensorIntrin that can be used to tensorize
- a schedule.
+ Int8 4x4 matrix multiplication and accumulation using a sequence of
+ umull -> uadalp -> umull2 -> uadalp instructions. This function
+ takes two arrays of int8 data type A[4][K] and B[4][K], and produces
+ a 4x4 matrix which is equal to A*B'.
+ The pseudo code is as follows.
+ .. code-block:: c
+     void gemm_4x4_int8_int8_int32(int8 A[4][K], int8 B[4][K], int32 C[4][4]){
+         for (int i = 0; i < 4; i++){
+             for (int j = 0; j < 4; j++){
+                 for (int k = 0; k < K; k++){
+                     C[i][j] += A[i][k] * B[j][k]
+                 }
+             }
+         }
+     }
+ Notes:
+     * The tiling strategy is picked to maximize register usage.
Parameters
----------
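For reference, the computation the new docstring above describes reduces to the following NumPy sketch. This is illustrative only (the names are made up and it is not part of the patch or of TVM's API): an int8 4xK block A times the transpose of an int8 4xK block B, accumulated in int32.

import numpy as np

def gemm_4x4_reference(A, B):
    # A: int8, shape (4, K); B: int8, shape (4, K), stored so that B' is K x 4.
    # Returns C = A * B' in int32, i.e. C[i][j] = sum_k A[i][k] * B[j][k].
    return A.astype(np.int32) @ B.astype(np.int32).T

A = np.random.randint(0, 127, size=(4, 64), dtype=np.int8)
B = np.random.randint(0, 127, size=(4, 64), dtype=np.int8)
C = gemm_4x4_reference(A, B)   # int32, shape (4, 4)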
@@ -115,7 +124,7 @@ def uadalp(a, b):
)

def umull(a, b):
"""Multiply long (lower part)
"""Multiply long (higher part)
Parameters:
----------
@@ -130,15 +139,15 @@ def umull(a, b):
----------
c = (a0*b0, a1*b1, a2*b2, a3*b3, a4*b4, a5*b5, a6*b6, a7*b7)
"""
- a_low = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", a)
- b_low = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", b)
+ a_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", a)
+ b_high = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", b)
c = tvm.tir.call_llvm_pure_intrin(
"int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_low, b_low
"int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_high, b_high
)
return c

def umull2(a, b):
"""Multiply long (uppoer part)
"""Multiply long (lower part)
Parameters:
----------
@@ -153,10 +162,10 @@ def umull2(a, b):
----------
c = (a8*b8, a9*b9, a10*b10, a11*b11, a12*b12, a13*b13, a14*b14, a15*b15)
"""
- a_high = tvm.tir.call_intrin("int8x8", "tir.vectorlow", a)
- b_high = tvm.tir.call_intrin("int8x8", "tir.vectorlow", b)
+ a_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", a)
+ b_low = tvm.tir.call_intrin("int8x8", "tir.vectorlow", b)
c = tvm.tir.call_llvm_pure_intrin(
"int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_high, b_high
"int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_low, b_low
)
return c
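Taken together with the element lists in their docstrings (c = a0*b0 ... a7*b7 for umull, c = a8*b8 ... a15*b15 for umull2), the two helpers above can be modelled in NumPy as below. This is only a sketch of the lane arithmetic, not of the underlying umull_intrin LLVM call or of TVM's tir.vectorhigh / tir.vectorlow lowering, whose high/low naming is what this commit aligns the comments with.

import numpy as np

def umull_model(a, b):
    # Widening multiply of lanes 0..7 of two int8x16 vectors -> int16x8.
    return a[:8].astype(np.int16) * b[:8].astype(np.int16)

def umull2_model(a, b):
    # Widening multiply of lanes 8..15 of two int8x16 vectors -> int16x8.
    return a[8:].astype(np.int16) * b[8:].astype(np.int16)

a = np.arange(16, dtype=np.int8)
b = np.full(16, 3, dtype=np.int8)
assert umull_model(a, b).tolist() == [3 * i for i in range(8)]
assert umull2_model(a, b).tolist() == [3 * i for i in range(8, 16)]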

@@ -180,28 +189,60 @@ def addp(a, b):
"int32x4", addp_intrin, tvm.tir.const(2, "uint32"), a, b
)

- def accumulation_loop(M, N, ins, acc, i):
- a0 = ins[0].vload([i, 0, 0], dtype_vec)
+ def accumulation_loop(M, N, ins, acc, tile_idx):
+ """Internal tile accumulation. This function
+ takes two arrays of int8 data type A[tile_idx][4][16] and B[tile_idx][4][16], and produces
+ a 4x4 matrix which is equal to A*B' and accumulates it into C[4][4]
+ The pseudo code is as follows.
+ .. code-block:: c
+     void gemm_4x4_int8_int8_int32(int8 A[tile_idx][4][K],
+                                   int8 B[tile_idx][4][K],
+                                   int32 C[4][4]){
+         for (int i = 0; i < 4; i++){
+             for (int j = 0; j < 4; j++){
+                 for (int k = 0; k < 16; k++){
+                     C[i][j] += A[tile_idx][i][k] * B[tile_idx][j][k]
+                 }
+             }
+         }
+     }
+ Notes:
+     * The tiling strategy is picked to maximize register usage.
+ Parameters:
+ ----------
+ M: number of total rows of the output matrix
+ N: number of total columns of the output matrix
+ ins: input buffers
+ acc: bank of register accumulators
+ tile_idx: index of a sub-tile of A and B in A[tile_idx][:][:] and B[tile_idx][:][:].
+     Please note that 0 <= tile_idx <= K//16
+ """
+ a0 = ins[0].vload([tile_idx, 0, 0], dtype_vec)
a1 = tvm.tir.const(0, "int8x16")
if M > 1:
- a1 = ins[0].vload([i, 1, 0], dtype_vec)
+ a1 = ins[0].vload([tile_idx, 1, 0], dtype_vec)
a2 = tvm.tir.const(0, "int8x16")
if M > 2:
- a2 = ins[0].vload([i, 2, 0], dtype_vec)
+ a2 = ins[0].vload([tile_idx, 2, 0], dtype_vec)
a3 = tvm.tir.const(0, "int8x16")
if M > 3:
- a3 = ins[0].vload([i, 3, 0], dtype_vec)
+ a3 = ins[0].vload([tile_idx, 3, 0], dtype_vec)

- b0 = ins[1].vload([i, 0, 0], dtype_vec)
+ b0 = ins[1].vload([tile_idx, 0, 0], dtype_vec)
b1 = tvm.tir.const(0, "int8x16")
if N > 1:
- b1 = ins[1].vload([i, 1, 0], dtype_vec)
+ b1 = ins[1].vload([tile_idx, 1, 0], dtype_vec)
b2 = tvm.tir.const(0, "int8x16")
if N > 2:
- b2 = ins[1].vload([i, 2, 0], dtype_vec)
+ b2 = ins[1].vload([tile_idx, 2, 0], dtype_vec)
b3 = tvm.tir.const(0, "int8x16")
if N > 3:
- b3 = ins[1].vload([i, 3, 0], dtype_vec)
+ b3 = ins[1].vload([tile_idx, 3, 0], dtype_vec)

# First half
# Lower part of a0 * {b0,b1,b2,b3}
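The multiply/accumulate sequence announced by the two comments above sits in the folded part of the diff. As the new accumulation_loop docstring describes it, one call is roughly equivalent to this NumPy sketch (a sketch only, with illustrative names, assuming ins[0] and ins[1] hold the tiled A and B data):

import numpy as np

def accumulation_loop_model(A_tiles, B_tiles, C, tile_idx):
    # A_tiles, B_tiles: int8 arrays of shape (K // 16, 4, 16); C: int32 of shape (4, 4).
    a = A_tiles[tile_idx].astype(np.int32)   # the 4x16 sub-tile of A
    b = B_tiles[tile_idx].astype(np.int32)   # the 4x16 sub-tile of B
    C += a @ b.T                             # C[i][j] += sum_k a[i][k] * b[j][k]

Iterating tile_idx over range(K // 16) and accumulating into the same C reproduces the full 4x4 result of the outer gemm docstring.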
@@ -312,10 +353,15 @@ def _instr():
accumulation_loop(M, N, ins, acc, i)

# Final accumulations
- # acc[i] contains the partial sums of a[i, 0:K].*b[0,0:K], let's call them (a,b,c,d)
- # acc[i+1] contains the partial sums of a[i, 0:K].*b[1,0:K], let's call them (e,f,g,h)
- # acc[i+2] contains the partial sums of a[i, 0:K].*b[2,0:K], let's call them (i,j,k,l)
- # acc[i+3] contains the partial sums of a[i, 0:K].*b[3,0:K], let's call them (m,n,o,p)
+ # acc[4*r + c] contains the partial accumulations of element C[r][c]
+ #
+ # In particular:
+ # acc[4*r] contains the partial sums of a[r, 0:K].*b[0, 0:K] -> (a,b,c,d)
+ # acc[4*r+1] contains the partial sums of a[r, 0:K].*b[1, 0:K] -> (e,f,g,h)
+ # acc[4*r+2] contains the partial sums of a[r, 0:K].*b[2, 0:K] -> (i,j,k,l)
+ # acc[4*r+3] contains the partial sums of a[r, 0:K].*b[3, 0:K] -> (m,n,o,p)
+ #
+ # Please note that 0 <= r, c < 4

acc[0] = addp(acc[0], acc[1]) # (a+b, c+d, e+f, g+h)
acc[1] = addp(acc[2], acc[3]) # (i+j, k+l, m+n, o+p)
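The two addp lines above, plus one more pairwise-add step of the same kind (presumably in the folded lines), reduce the four partial-sum registers of a row to its final four int32 values. The lane bookkeeping can be checked with a small NumPy model of the pairwise add described in the comments; this models the comment, not the NEON instruction itself.

import numpy as np

def addp_model(x, y):
    # Pairwise add, as in the comments above: (x0+x1, x2+x3, y0+y1, y2+y3).
    return np.array([x[0] + x[1], x[2] + x[3], y[0] + y[1], y[2] + y[3]], dtype=np.int32)

# acc[4*r .. 4*r+3] hold (a,b,c,d), (e,f,g,h), (i,j,k,l), (m,n,o,p) for row r.
acc0 = np.array([1, 2, 3, 4], dtype=np.int32)      # (a, b, c, d)
acc1 = np.array([5, 6, 7, 8], dtype=np.int32)      # (e, f, g, h)
acc2 = np.array([9, 10, 11, 12], dtype=np.int32)   # (i, j, k, l)
acc3 = np.array([13, 14, 15, 16], dtype=np.int32)  # (m, n, o, p)

t0 = addp_model(acc0, acc1)   # (a+b, c+d, e+f, g+h)
t1 = addp_model(acc2, acc3)   # (i+j, k+l, m+n, o+p)
row = addp_model(t0, t1)      # (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p) -> C[r][0..3]
assert row.tolist() == [10, 26, 42, 58]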
