Skip to content

Commit

Permalink
revert changes not needed
Browse files Browse the repository at this point in the history
  • Loading branch information
nverke committed Jan 7, 2023
1 parent 7515de3 commit dcc2b1a
Showing 1 changed file with 23 additions and 35 deletions.
58 changes: 23 additions & 35 deletions python/tvm/tir/tensor_intrin/hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ def dma_load_impl(a: T.handle, c: T.handle) -> None:
return dma_load_desc, dma_load_impl


def generate_dot_product_32x4_u8u8i32(mem_scopes={"reads": "global", "write": "global"}):
def generate_dot_product_32x4_u8u8i32(mem_scope="global"):
@T.prim_func
def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope)
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
Expand All @@ -78,9 +78,9 @@ def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None

@T.prim_func
def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope)
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
Expand All @@ -103,12 +103,12 @@ def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non
return dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy


def generate_dot_product_32x4_u8i8i32(mem_scopes={"reads": "global", "write": "global"}):
def generate_dot_product_32x4_u8i8i32(mem_scope="global"):
@T.prim_func
def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope)
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
Expand All @@ -120,9 +120,9 @@ def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None

@T.prim_func
def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope)
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
Expand All @@ -145,12 +145,12 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non
return dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy


def generate_dot_product_32x2_i16i16i32(mem_scopes={"reads": "global", "write": "global"}):
def generate_dot_product_32x2_i16i16i32(mem_scope="global"):
@T.prim_func
def dot_product_32x2_i16i16i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope)
B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope)
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
with T.block("root"):
T.reads(C[0:32], A[0:2], B[0:32, 0:2])
T.writes(C[0:32])
Expand All @@ -162,9 +162,9 @@ def dot_product_32x2_i16i16i32_desc(a: T.handle, b: T.handle, c: T.handle) -> No

@T.prim_func
def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scopes["reads"])
B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scopes["reads"])
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scopes["write"])
A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope)
B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope)
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
with T.block("root"):
T.reads(C[0:32], A[0:2], B[0:32, 0:2])
T.writes(C[0:32])
Expand Down Expand Up @@ -200,22 +200,10 @@ def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> N
TensorIntrin.register(VDMPY_i16i16i32_INTRIN, *generate_dot_product_32x2_i16i16i32())

VRMPY_u8u8i32_VTCM_INTRIN = "dot_32x4_u8u8i32_vtcm_vrmpy"
TensorIntrin.register(
VRMPY_u8u8i32_VTCM_INTRIN,
*generate_dot_product_32x4_u8u8i32({"reads": "global.vtcm", "write": "global.vtcm"}),
)

VRMPY_u8u8i32_VTCM_READS_INTRIN = "dot_32x4_u8u8i32_vtcm_reads_vrmpy"
TensorIntrin.register(
VRMPY_u8u8i32_VTCM_READS_INTRIN,
*generate_dot_product_32x4_u8u8i32({"reads": "global.vtcm", "write": "global"}),
)
TensorIntrin.register(VRMPY_u8u8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8u8i32("global.vtcm"))

VRMPY_u8i8i32_VTCM_INTRIN = "dot_32x4_u8i8i32_vtcm_vrmpy"
TensorIntrin.register(
VRMPY_u8i8i32_VTCM_INTRIN,
*generate_dot_product_32x4_u8i8i32({"reads": "global.vtcm", "write": "global.vtcm"}),
)
TensorIntrin.register(VRMPY_u8i8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8i8i32("global.vtcm"))

DMA_READ_1_u8 = "dma_read_1_u8"
TensorIntrin.register(DMA_READ_1_u8, *generate_dma_load_intrin(1, "uint8"))
Expand Down

0 comments on commit dcc2b1a

Please sign in to comment.