[Hexagon] Add hexagon user DMA intrins for tensorization.
nverke committed Jan 6, 2023
1 parent 123f1f5 commit 7515de3
Showing 2 changed files with 119 additions and 28 deletions.
123 changes: 100 additions & 23 deletions python/tvm/tir/tensor_intrin/hexagon.py
@@ -20,12 +20,53 @@
from .. import TensorIntrin


-def generate_dot_product_32x4_u8u8i32(mem_scope="global"):
def generate_dma_load_intrin(
size: int,
dtype: str,
):
"""Generator of dma_load intrins"""

@T.prim_func
def dma_load_desc(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (size), dtype, offset_factor=1, scope="global")
C = T.match_buffer(c, (size), dtype, offset_factor=1, scope="global.vtcm")
with T.block("root"):
T.reads(A[0:size])
T.writes(C[0:size])
for i in T.serial(size):
with T.block("load"):
vii = T.axis.remap("S", [i])
C[vii] = A[vii]

@T.prim_func
def dma_load_impl(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (size), dtype, offset_factor=1, scope="global")
C = T.match_buffer(c, (size), dtype, offset_factor=1, scope="global.vtcm")
with T.block("root"):
T.reads(A[0:size])
T.writes(C[0:size])
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.dma_copy",
0,
T.address_of(C[0], dtype="handle"),
T.address_of(A[0], dtype="handle"),
size,
0,
dtype="int32",
)
)
T.evaluate(T.tvm_call_packed("device_api.hexagon.dma_wait", 0, 0, dtype="int32"))

return dma_load_desc, dma_load_impl


def generate_dot_product_32x4_u8u8i32(mem_scopes={"reads": "global", "write": "global"}):
@T.prim_func
def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
-        A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
-        B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope)
-        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
@@ -37,9 +78,9 @@ def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None

@T.prim_func
def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
-        A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
-        B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope)
-        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
@@ -62,12 +103,12 @@ def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None
return dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy


-def generate_dot_product_32x4_u8i8i32(mem_scope="global"):
def generate_dot_product_32x4_u8i8i32(mem_scopes={"reads": "global", "write": "global"}):
@T.prim_func
def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
-        A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
-        B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope)
-        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
@@ -79,9 +120,9 @@ def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None

@T.prim_func
def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
-        A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
-        B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope)
-        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
@@ -104,12 +145,12 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None
return dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy


-def generate_dot_product_32x2_i16i16i32(mem_scope="global"):
def generate_dot_product_32x2_i16i16i32(mem_scopes={"reads": "global", "write": "global"}):
@T.prim_func
def dot_product_32x2_i16i16i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
-        A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope)
-        B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope)
-        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
with T.block("root"):
T.reads(C[0:32], A[0:2], B[0:32, 0:2])
T.writes(C[0:32])
@@ -121,9 +162,9 @@ def dot_product_32x2_i16i16i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None

@T.prim_func
def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
-        A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope)
-        B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope)
-        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
with T.block("root"):
T.reads(C[0:32], A[0:2], B[0:32, 0:2])
T.writes(C[0:32])
@@ -159,7 +200,43 @@ def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> None
TensorIntrin.register(VDMPY_i16i16i32_INTRIN, *generate_dot_product_32x2_i16i16i32())

VRMPY_u8u8i32_VTCM_INTRIN = "dot_32x4_u8u8i32_vtcm_vrmpy"
-TensorIntrin.register(VRMPY_u8u8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8u8i32("global.vtcm"))
TensorIntrin.register(
VRMPY_u8u8i32_VTCM_INTRIN,
*generate_dot_product_32x4_u8u8i32({"reads": "global.vtcm", "write": "global.vtcm"}),
)

VRMPY_u8u8i32_VTCM_READS_INTRIN = "dot_32x4_u8u8i32_vtcm_reads_vrmpy"
TensorIntrin.register(
VRMPY_u8u8i32_VTCM_READS_INTRIN,
*generate_dot_product_32x4_u8u8i32({"reads": "global.vtcm", "write": "global"}),
)

VRMPY_u8i8i32_VTCM_INTRIN = "dot_32x4_u8i8i32_vtcm_vrmpy"
-TensorIntrin.register(VRMPY_u8i8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8i8i32("global.vtcm"))
TensorIntrin.register(
VRMPY_u8i8i32_VTCM_INTRIN,
*generate_dot_product_32x4_u8i8i32({"reads": "global.vtcm", "write": "global.vtcm"}),
)

DMA_READ_1_u8 = "dma_read_1_u8"
TensorIntrin.register(DMA_READ_1_u8, *generate_dma_load_intrin(1, "uint8"))

DMA_READ_1_i8 = "dma_read_1_i8"
TensorIntrin.register(DMA_READ_1_i8, *generate_dma_load_intrin(1, "int8"))

DMA_READ_128_u8 = "dma_read_128_u8"
TensorIntrin.register(DMA_READ_128_u8, *generate_dma_load_intrin(128, "uint8"))

DMA_READ_128_i8 = "dma_read_128_i8"
TensorIntrin.register(DMA_READ_128_i8, *generate_dma_load_intrin(128, "int8"))

DMA_READ_1024_u8 = "dma_read_1024_u8"
TensorIntrin.register(DMA_READ_1024_u8, *generate_dma_load_intrin(1024, "uint8"))

DMA_READ_1024_i8 = "dma_read_1024_i8"
TensorIntrin.register(DMA_READ_1024_i8, *generate_dma_load_intrin(1024, "int8"))

DMA_READ_4096_u8 = "dma_read_4096_u8"
TensorIntrin.register(DMA_READ_4096_u8, *generate_dma_load_intrin(4096, "uint8"))

DMA_READ_4096_i8 = "dma_read_4096_i8"
TensorIntrin.register(DMA_READ_4096_i8, *generate_dma_load_intrin(4096, "int8"))
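
The bandwidth test below exercises these intrinsics by splitting a VTCM copy loop into 128-element tiles and tensorizing the inner loop with DMA_READ_128_i8. As a rough illustration of the same flow outside that test, here is a minimal sketch assuming a plain uint8 copy workload staged through "global.vtcm" via cache_read (the workload, names, and sizes are illustrative, not part of this commit):

import tvm
from tvm import te, tir
from tvm.tir.tensor_intrin.hexagon import DMA_READ_128_u8

size = 1024  # assumed to be a multiple of the 128-element intrinsic

# A trivial copy workload; cache_read below creates the global -> global.vtcm staging block.
a = te.placeholder((size,), dtype="uint8", name="A")
b = te.compute((size,), lambda i: a[i], name="B")
sch = tir.Schedule(te.create_prim_func([a, b]))

copy_block = sch.cache_read(sch.get_block("B"), 0, "global.vtcm")
(loop,) = sch.get_loops(copy_block)
_, inner = sch.split(loop, [None, 128])  # inner loop now covers exactly 128 elements
sch.tensorize(inner, DMA_READ_128_u8)

Assuming the staging block lines up with the descriptor (same dtype and scopes), each 128-element tile of the copy then lowers to the dma_copy/dma_wait packed calls in dma_load_impl, i.e. one synchronous user-DMA transfer per tile.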
24 changes: 19 additions & 5 deletions tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
@@ -21,6 +21,7 @@

import tvm
from tvm.script import tir as T
from tvm.tir.tensor_intrin.hexagon import DMA_READ_128_i8

from .infrastructure import get_hexagon_target

@@ -30,6 +31,7 @@
"Test bandwidth with buffer size {}MB... \n"
" -Base: {} GBps \n -Vectorized: {} GBps\n"
" -Vectorized and Parallelized: {} GBps\n"
" -Sync DMA: {} GBps\n"
" -Single DMA Copy: {} GBps\n"
)

@@ -104,8 +106,8 @@ def evaluate(hexagon_session, sch, size):
)

# These are reduced for CI but number=100 and repeat=10 does a good job of removing noise.
-    number = 1
-    repeat = 1
number = 10
repeat = 10

timer = module.time_evaluator(
"__tvm_main__", hexagon_session.device, number=number, repeat=repeat
@@ -123,15 +125,18 @@ class TestMatMulVec:

# Removed most of these to speedup CI.
size = tvm.testing.parameter(
-        # 10 * KB,
128,
256,
1024,
10 * KB,
# 20 * KB,
# 40 * KB,
# 80 * KB,
# 160 * KB,
# 320 * KB,
640 * KB,
# MB,
-        # 2 * MB,
2 * MB,
# 3 * MB,
# 4 * MB,
# 8 * MB, # Only works on 8gen1 HDKs
@@ -169,14 +174,23 @@ def test_bandwidth(self, hexagon_session, size, outer_split, unroll_split, vecto
sch.parallel(vbo_a)
parallel_gbps = evaluate(hexagon_session, sch, size)

# Run with some basic unroll and vectorize scheduling and parallelization.
sch = tvm.tir.Schedule(memcopy_operator(size))
block = sch.get_block("A_global.vtcm")
loops = sch.get_loops(block)
_, inner = sch.split(loops[0], [None, 128])
sch.tensorize(inner, DMA_READ_128_i8)
# print(sch.mod.script())
sync_dma_gbps = evaluate(hexagon_session, sch, size)

# Run using a single dma copy to transfer the data.
sch = tvm.tir.Schedule(single_dma_operator(size))
single_dma_gbps = evaluate(hexagon_session, sch, size)

mbs = round(size / MB, 2)
print(
TEST_OUTPUT_TEMPLATE.format(
-                mbs, base_gpbs, vectorize_gbps, parallel_gbps, single_dma_gbps
mbs, base_gpbs, vectorize_gbps, parallel_gbps, sync_dma_gbps, single_dma_gbps
)
)

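Since generate_dma_load_intrin is parameterized over size and dtype, variants beyond the ones registered in this commit can be added the same way; a hypothetical example (the 2048-element name below is illustrative, not part of the commit):

from tvm.tir import TensorIntrin
from tvm.tir.tensor_intrin.hexagon import generate_dma_load_intrin

# Hypothetical extra registration following the same pattern as the commit.
DMA_READ_2048_u8 = "dma_read_2048_u8"
TensorIntrin.register(DMA_READ_2048_u8, *generate_dma_load_intrin(2048, "uint8"))

The commit registers only 8-bit dtypes, for which the element count passed to device_api.hexagon.dma_copy equals the byte count, so 8-bit dtypes are the safest assumption for new sizes as well.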
