This repository contains the RTL for an 8-bit integer GEMM (General Matrix Multiply) accelerator that integrates into the SNAX core, together with the Chisel sources and unit tests for the accelerator. It is written in Chisel 5.0.0 and connected to the SNAX accelerator RISC-V manager core through SystemVerilog.
The microarchitecture of the GEMM accelerator is shown below. The GEMM array has meshRow rows and meshCol columns of tiles. Each tile implements a vector dot product of length tileSize.
This repository contains three GEMM versions: Base GEMM, Block GEMM, and Batch Stride GEMM.
Base GEMM implements General Matrix Multiply: C = (A - a) * (B - b), where a and b are scalar values. Base GEMM is parameterized and its parameters are listed below. These parameters can be configured in main/scala/gemm/Parameter.scala.
Parameter | Meaning |
dataWidthA | Input matrix data width (integer) |
dataWidthC | Output matrix data width (integer) |
dataWidthAccum | Accumulator data width (integer) |
meshRow | The number of rows of the GEMM array |
meshCol | The number of columns of the GEMM array |
tileSize | The vector length of the dot product computed by each tile |
The size of matrix A is (meshRow, tileSize) and the size of matrix B is (tileSize, meshCol). The size of result matrix C is (meshRow, meshCol). To get the right results, matrix A should be arranged in row-major order and matrix B should be arranged in column-major order. The result matrix C is arranged in row-major order by Base GEMM. The data layout of the three matrices is shown below. The number in each small block indicates the address, and the way the address increases shows how the data is arranged in memory.
Each row of the GEMM array takes in the corresponding row of matrix A, i.e., all the tiles in the ith row take in the ith row of matrix A. And each column of the GEMM array takes in the corresponding column of matrix B, i.e., all the tiles in the jth column take in the jth column of matrix B. This process is shown in the figure below in detail.
The Base GEMM also has an accumulation option. If the signal accumulate is one, the accumulator register retains its current result. If the signal accumulate is zero, the accumulator register clears itself.
This option is useful for using the base GEMM for software-managed block matrix multiplication.
The Base GEMM function definition pseudocode is shown below.
bool gemm(int8_t *A, int8_t *B, int32_t *C, bool accumulate);
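As a rough software picture of what one Base GEMM step computes, the sketch below models C = (A - a) * (B - b) on flat buffers with the layouts described above (A row-major, B column-major, C row-major). It is a hypothetical golden model for illustration, not the repository's Scala model: the name `base_gemm` and its arguments are assumptions, and 32-bit overflow of the accumulator is not modelled.

```python
def base_gemm(A, B, acc, a=0, b=0, accumulate=False,
              mesh_row=8, tile_size=8, mesh_col=8):
    """Software model of one Base GEMM step on flat buffers.

    A   : mesh_row * tile_size values, row-major    (A[m][k] at m*tile_size + k)
    B   : tile_size * mesh_col values, column-major (B[k][n] at n*tile_size + k)
    acc : mesh_row * mesh_col accumulators, row-major, updated in place
    """
    for m in range(mesh_row):
        for n in range(mesh_col):
            # Each tile computes one dot product of length tile_size.
            dot = sum((A[m * tile_size + k] - a) * (B[n * tile_size + k] - b)
                      for k in range(tile_size))
            idx = m * mesh_col + n
            # accumulate=True keeps the previous result, otherwise it is cleared.
            acc[idx] = acc[idx] + dot if accumulate else dot
    return acc


# 2x2x2 example: A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]] stored column-major.
acc = base_gemm([1, 2, 3, 4], [5, 7, 6, 8], [0] * 4,
                mesh_row=2, tile_size=2, mesh_col=2)
print(acc)  # [19, 22, 43, 50]
```

Calling the function a second time on the same `acc` with `accumulate=True` adds the new partial result to the stored one, which is the behaviour software-managed block matrix multiplication relies on.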
The Block GEMM is built on the Base GEMM. It uses block matrix multiplication to multiply matrices that are larger than the physical GEMM array.
It takes the M, K, and N configurations as the size of the Block GEMM.
In this case, the size of matrix A is (M' = M * meshRow, K' = K * tileSize) and the size of matrix B is (K' = K * tileSize, N' = N * meshCol). The size of result matrix C is (M' = M * meshRow, N' = N * meshCol), as shown below.
Take a Block GEMM with hardware parameters meshRow == tileSize == meshCol == 8 as an example. To do the block matrix multiplication C(32,24) = A(32,16) * B(16,24), i.e., M' = 32, K' = 16, N' = 24, the M, K, and N should be configured as 4, 2, and 3 respectively. The block matrix multiplication is illustrated in the picture below. With:
- Mu (M unrolling) = meshRow = 8
- Ku (K unrolling) = tileSize = 8
- Nu (N unrolling) = meshCol = 8
The pseudocode is:
    for (m1 = 0 to M'/Mu - 1)
      for (n1 = 0 to N'/Nu - 1)
        for (k1 = 0 to K'/Ku - 1)
          parallel_for (m2 = 0 to Mu - 1)
            parallel_for (n2 = 0 to Nu - 1)
              parallel_for (k2 = 0 to Ku - 1)
                C[m1*Mu+m2][n1*Nu+n2] += (A[m1*Mu+m2][k1*Ku+k2] - a) * (B[k1*Ku+k2][n1*Nu+n2] - b);
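The loop nest above can be checked in software against a plain matrix multiplication. The following is a minimal sketch on nested Python lists; the function name `block_gemm` is hypothetical, and the three inner loops run sequentially here although the hardware unrolls them spatially.

```python
def block_gemm(A, B, a=0, b=0, Mu=8, Ku=8, Nu=8):
    """Tiled (A - a) * (B - b); every dimension must be a multiple of its tile size."""
    M_, K_, N_ = len(A), len(B), len(B[0])
    assert M_ % Mu == 0 and K_ % Ku == 0 and N_ % Nu == 0
    C = [[0] * N_ for _ in range(M_)]
    for m1 in range(M_ // Mu):          # temporal loops over M, N and K blocks
        for n1 in range(N_ // Nu):
            for k1 in range(K_ // Ku):
                # The three loops below correspond to the parallel_for loops.
                for m2 in range(Mu):
                    for n2 in range(Nu):
                        for k2 in range(Ku):
                            m, n, k = m1 * Mu + m2, n1 * Nu + n2, k1 * Ku + k2
                            C[m][n] += (A[m][k] - a) * (B[k][n] - b)
    return C


# C(32,24) = A(32,16) * B(16,24) with 8x8x8 tiles, i.e. M = 4, K = 2, N = 3.
A = [[(3 * m + k) % 7 - 3 for k in range(16)] for m in range(32)]
B = [[(5 * k + n) % 11 - 5 for n in range(24)] for k in range(16)]
C = block_gemm(A, B)
print(len(C), len(C[0]))  # 32 24
```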
To get the right results, the input matrices must have the right layout in memory. The addresses of matrices A, B, and C must also be given when configuring the Block GEMM. In the current version, the data of A and B should be stored in memory with a stride of 1.
Take a Block GEMM with hardware parameters meshRow == tileSize == meshCol == 8 as an example. To do the block matrix multiplication C(24,40) = A(24,32) * B(32,40), i.e., M' = 24, K' = 32, N' = 40, the M, K, and N should be configured as 3, 4, and 5 respectively, as shown below.
To get the right results for M = 3, K = 4, and N = 5, matrix A should be arranged as M3 K4 M8 K8, B should be arranged as N5 K4 N8 K8, and C should be arranged as M3 N5 M8 N8. Take matrix A for example: the innermost dimension is 8 elements in the K dimension, then 8 rows in the M dimension, then 4 big columns in the K dimension, and the outermost dimension is 3 big rows in the M dimension.
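The layout strings above can be read as nested index formulas. The sketch below computes the element offset (not the byte offset) of a logical matrix element in each blocked layout; the helper names are hypothetical and the matrices are assumed to be densely packed.

```python
def offset_a(m, k, K, mesh_row=8, tile_size=8):
    """Element offset of A[m][k] in the M, K, meshRow, tileSize layout."""
    m1, m2 = divmod(m, mesh_row)    # block row, row inside the block
    k1, k2 = divmod(k, tile_size)   # block column, column inside the block
    return ((m1 * K + k1) * mesh_row + m2) * tile_size + k2


def offset_b(k, n, K, tile_size=8, mesh_col=8):
    """Element offset of B[k][n] in the N, K, meshCol, tileSize layout."""
    n1, n2 = divmod(n, mesh_col)
    k1, k2 = divmod(k, tile_size)
    return ((n1 * K + k1) * mesh_col + n2) * tile_size + k2


def offset_c(m, n, N, mesh_row=8, mesh_col=8):
    """Element offset of C[m][n] in the M, N, meshRow, meshCol layout."""
    m1, m2 = divmod(m, mesh_row)
    n1, n2 = divmod(n, mesh_col)
    return ((m1 * N + n1) * mesh_row + m2) * mesh_col + n2


# M = 3, K = 4, N = 5 with 8x8x8 tiles: A is 24x32.
print(offset_a(0, 8, K=4))   # 64: first element of the second sub-matrix in block row 0
print(offset_a(8, 0, K=4))   # 256: first element of the second block row
```

For C, the byte offset would additionally be scaled by the output element size, e.g. 4 bytes for int32 results.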
As a further illustration, consider a Block GEMM with hardware parameters meshRow == tileSize == meshCol == 2 and a block matrix multiplication with M = 2, K = 2, and N = 2. The two figures below illustrate the detailed data layout in memory. The continuous lines in the first picture are the conceptual addresses in memory. The concrete layout in memory for each element is shown in the second picture.
The Block GEMM function C wrapper pseudocode is shown below.
bool blockGemm(int M, int K, int N, int8_t *A, int8_t *B, int32_t *C);
Signals | Signal name in generated SV | Width | Dir | Description |
M_i | i_io_ctrl_M_i | 8 | In | The M size setting |
K_i | i_io_ctrl_K_i | 8 | In | The K size setting |
N_i | i_io_ctrl_N_i | 8 | In | The N size setting |
ptr_addr_a_i | io_ctrl_ptr_addr_a_i | 32 | In | The address of the matrix A. It points to the first data element of matrix A. |
ptr_addr_b_i | io_ctrl_ptr_addr_b_i | 32 | In | The address of the matrix B. It points to the first data element of matrix B. |
ptr_addr_c_i | io_ctrl_ptr_addr_c_i | 32 | In | The address of the matrix C. It points to the first data element of matrix C. |
start_do_i | io_ctrl_start_do_i | 1 | In | Start signal. When start_do_i is asserted, all the configuration signals (M_i, K_i, N_i, ptr_addr_a_i, ptr_addr_b_i and ptr_addr_c_i) must be ready. |
data_valid_i | io_ctrl_data_valid_i | 1 | In | Input data valid signal |
addr_a_o | io_ctrl_addr_a_o | 32 | Out | The address of the sub-matrix of A. It points to the first data element of the sub-matrix of A. |
addr_b_o | io_ctrl_addr_b_o | 32 | Out | The address of the sub-matrix of B. It points to the first data element of the sub-matrix of B |
addr_c_o | io_ctrl_addr_c_o | 32 | Out | The address of the sub-matrix of C. It points to the first data element of the sub-matrix of C. |
busy_o | io_ctrl_busy_o | 1 | Out | Indicating if Block Gemm is busy. When busy_o is asserted, the Block Gemm doesn't take in any new start_do_i signal or configuration. |
a_i | io_data_a_i | meshRow * tileSize * input data width | In | A big bus containing all the data elements of the input (sub-)matrix A |
b_i | io_data_b_i | tileSize * meshCol * input data width | In | A big bus containing all the data elements of the input (sub-)matrix B |
c_o | io_data_c_o | meshRow * meshCol * output data width | Out | A big bus containing all the data elements of the output (sub-)matrix C |
subtraction_a_i | io_ctrl_subtraction_a_i | 8 | In | The scalar value for A to subtract |
subtraction_b_i | io_ctrl_subtraction_b_i | 8 | In | The scalar value for B to subtract |
The Batch Stride GEMM is built on the Block GEMM. It performs multiple block matrix multiplications. The pseudocode is:
    for (batch = 0; batch < Batch; batch++) {
      C[batch] = (A[batch] - a) * (B[batch] - b); // Each is a blockGemm
    }
Besides M, K, and N, it takes an extra Batch configuration (B). It also takes nine extra stride configurations, i.e., strideInnermostA, strideInnermostB, strideInnermostC, ldA, ldB, ldC, strideA, strideB, and strideC, for computing the address of each sub-matrix in the block matrix multiplication and the start address of the matrices for each batch.
The meaning of each stride configuration is listed below.
Stride configuration | Meaning |
strideInnermostA | the incremental address for a new submatrix of each block row of A |
strideInnermostB | the incremental address for a new submatrix of each block column of B |
strideInnermostC | the incremental address for a new submatrix of each block row of C |
ldA | the incremental address for a new block row of A |
ldB | the incremental address for a new block column of B |
ldC | the incremental address for a new block row of C |
strideA | the incremental address of matrix A for each batch |
strideB | the incremental address of matrix B for each batch |
strideC | the incremental address of matrix C for each batch |
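To make the table concrete, the sketch below derives one possible set of stride values. It is only an illustration under stated assumptions: the matrices are densely packed in the blocked layouts described for Block GEMM, addresses are byte addresses, inputs are int8, and outputs are int32. The function name `dense_strides` is hypothetical, and other memory layouts need different values.

```python
def dense_strides(M, K, N, mesh_row=8, tile_size=8, mesh_col=8,
                  in_bytes=1, out_bytes=4):
    """Byte strides for densely packed matrices (an assumption, not a requirement)."""
    sub_a = mesh_row * tile_size * in_bytes    # size of one sub-matrix of A
    sub_b = tile_size * mesh_col * in_bytes    # size of one sub-matrix of B
    sub_c = mesh_row * mesh_col * out_bytes    # size of one sub-matrix of C
    return {
        "strideInnermostA": sub_a, "ldA": K * sub_a, "strideA": M * K * sub_a,
        "strideInnermostB": sub_b, "ldB": K * sub_b, "strideB": N * K * sub_b,
        "strideInnermostC": sub_c, "ldC": N * sub_c, "strideC": M * N * sub_c,
    }


s = dense_strides(M=3, K=4, N=5)
print(s["strideInnermostA"], s["ldA"], s["strideA"])  # 64 256 768
```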
Batch Stride GEMM performs Batch Block GEMMs of the same size. The computation flow is shown in the following picture.
The pseudocode is:
    for (b = 0; b < B; b++) {
      for (m = 0; m < M; m++) {
        for (n = 0; n < N; n++) {
          for (k = 0; k < K; k++) {
            addr_a = start_addr_a + b * strideA + m * ldA + k * strideInnermostA
            addr_b = start_addr_b + b * strideB + n * ldB + k * strideInnermostB
            addr_c = start_addr_c + b * strideC + m * ldC + n * strideInnermostC
            C[addr_c] = baseGemm(addr_a, addr_b)
          }
        }
      }
    }
To better illustrate the meaning of the nine stride configurations, an address table with an auxiliary picture is shown below.
Computation process | Address generation for the submatrix of A | Address generation for the submatrix of B | Address generation for the submatrix of C |
b=0, m=0, k=0, n=0 | start address of A11 | start address of B11 | start address of C11 |
b=0, m=0, k=1, n=0 | start address of A12 (ptr_addr_a_i + k * strideInnermostA) | start address of B21 (ptr_addr_b_i + k * strideInnermostB) | start address of C11 |
b=0, m=0, k=0, n=1 | start address of A11 | start address of B12 (ptr_addr_b_i + ldB, ldB = the incremental address for a new block column of B) | start address of C12 (ptr_addr_c_i + n * strideInnermostC) |
b=0, m=0, k=1, n=1 | start address of A12(ptr_addr_a_i + k * strideInnermostA) | start address of B12 (ptr_addr_b_i + ldB + strideInnermostB) | start address of C12 (ptr_addr_c_i + n * strideInnermostC) |
... | ... | ... | ... |
b=1, m=0, k=0, n=0 | start address of A11' for batch 1 (ptr_addr_a_i + strideA, strideA = the incremental address of matrix A for each batch) | start address of B11' for batch 1 (ptr_addr_b_i + strideB, strideB = the incremental address of matrix B for each batch) | start address of C11' for batch 1 (ptr_addr_c_i + strideC, strideC = the incremental address of matrix C for each batch) |
... | ... | ... | ... |
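The rows of the address table can be reproduced with a small generator that follows the address formulas from the pseudocode. This is a sketch for illustration: the function name, the dictionary keys for the base addresses, and the stride values in the example are made up, and the loop order (k innermost) is an assumption taken from the pseudocode rather than from the RTL.

```python
def submatrix_addresses(batch, M, K, N, base, s):
    """Yield (b, m, n, k, addr_a, addr_b, addr_c) for every Base GEMM step.

    base : dict with the start addresses under the keys "a", "b" and "c"
    s    : dict with the nine stride values, keyed by their names
    """
    for b in range(batch):
        for m in range(M):
            for n in range(N):
                for k in range(K):
                    addr_a = base["a"] + b * s["strideA"] + m * s["ldA"] + k * s["strideInnermostA"]
                    addr_b = base["b"] + b * s["strideB"] + n * s["ldB"] + k * s["strideInnermostB"]
                    addr_c = base["c"] + b * s["strideC"] + m * s["ldC"] + n * s["strideInnermostC"]
                    yield b, m, n, k, addr_a, addr_b, addr_c


# Example values only; real strides depend on how the matrices are stored.
strides = {"strideInnermostA": 64, "ldA": 128, "strideA": 256,
           "strideInnermostB": 64, "ldB": 128, "strideB": 256,
           "strideInnermostC": 256, "ldC": 512, "strideC": 1024}
start = {"a": 0x0000, "b": 0x1000, "c": 0x2000}
for row in list(submatrix_addresses(2, 2, 2, 2, start, strides))[:3]:
    print(row)
```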
The data layout of Batch Stride GEMM is the same as that of Block GEMM, except that there is one more outermost dimension, B.
int Batch, int M, int N, int K,
const T*A, int strideInnermostA, int ldA, int strideA,
const T* B, int strideInnermostB, int ldB, int strideB,
const T* C, int strideInnermostC, int ldC, int strideC,
In addition to the ports already listed for Block GEMM, Batch Stride GEMM has some extra ports, which are listed below.
Signals | Signal name in generated SV | Width | Dir | Description |
B_i | io_ctrl_B_i | 8 | In | The Batch size setting |
strideInnermostA_i | io_ctrl_strideInnermostA_i | 32 | In | the address stride for a new submatrix of each block row of A |
strideInnermostB_i | io_ctrl_strideInnermostB_i | 32 | In | the address stride for a new submatrix of each block column of B |
strideInnermostC_i | io_ctrl_strideInnermostC_i | 32 | In | the address stride for a new submatrix of each block row of C |
ldA_i | io_ctrl_ldA_i | 32 | In | the address stride for a new block row of matrix A |
ldB_i | io_ctrl_ldB_i | 32 | In | the address stride for a new block column of matrix B |
ldC_i | io_ctrl_ldC_i | 32 | In | the address stride for a new block row of matrix C |
strideA_i | io_ctrl_strideA_i | 32 | In | the address stride for matrix A in different batch |
strideB_i | io_ctrl_strideB_i | 32 | In | the address stride for matrix B in different batch |
strideC_i | io_ctrl_strideC_i | 32 | In | the address stride for matrix C in different batch |
This module adds support for 8, 16, or 32 TCDM ports, each 64 bits wide. There is no functional difference from the software side. In hardware, the GEMM stalls the input and the computation until the output is finished.
Besides the ports already in Batch Stride GEMM, this module adds another multi-cycle output port.
Signals | Signal name in generated SV | Width | Dir | Description |
multi_stage_c_o | io_data_multi_stage_c_o | TCDMWritePorts * GemmConstant.TCDMDataWidth | Out | multi cycle output port |
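A back-of-the-envelope estimate of how many cycles the multi-cycle output needs per result sub-matrix follows from the port width above. The sketch assumes an 8x8 array with 32-bit outputs and 64-bit TCDM ports; the function name is hypothetical and the actual RTL timing may differ.

```python
def output_cycles(tcdm_write_ports, mesh_row=8, mesh_col=8,
                  data_width_c=32, tcdm_data_width=64):
    """Cycles needed to write one result sub-matrix over the TCDM write ports."""
    result_bits = mesh_row * mesh_col * data_width_c
    port_bits = tcdm_write_ports * tcdm_data_width
    return -(-result_bits // port_bits)  # ceiling division


for ports in (8, 16, 32):
    print(ports, output_cycles(ports))  # 8 -> 4, 16 -> 2, 32 -> 1
```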
This wrapper module connects the GEMM accelerator to the SNAX platform. It has two data registers that temporarily store input data for the GEMM accelerator when the TCDM memory is not ready because of data contention.
Signals | Width | Direction | Description |
clk_i | 1 | Input | Clock signal |
rst_ni | 1 | Input | Reset signal, active low |
snax_qvalid_i | 1 | Input | Input valid signal for SNAX accelerator CSR write(configuration) / read port |
snax_qready_o | 1 | Output | Output ready signal for SNAX accelerator CSR write(configuration) / read port |
snax_req_i | acc_req_t(decided by SNAX parameters) | Input | SNAX accelerator CSR write(configuration) / read port |
snax_resp_o | acc_rsp_t(decided by SNAX parameters) | Output | Output response signal for SNAX CSR read |
snax_pvalid_o | 1 | Output | Output valid signal for SNAX CSR read |
snax_pready_i | 1 | Input | Input ready signal for SNAX CSR read |
snax_tcdm_req_o | tcdm_req_t[SnaxTcdmPorts-1:0](decided by SNAX parameters) | Output | Output request signal for tcdm |
snax_tcdm_rsp_i | tcdm_rsp_t[SnaxTcdmPorts-1:0](decided by SNAX parameters) | Input | Input response signal for tcdm |
A more detailed description of the interface between SNAX and the GEMM accelerator can be found at SNAX platform. For acc_req_t and acc_rsp_t, please refer to reqrsp_interface. For tcdm_req_t and tcdm_rsp_t, please refer to mem_interface.
Address | CSR name | Notes |
0x3c0 | MatrixSizeCsr | Matrix size configuration, MatrixSizeCsr[31:24] = Batch, MatrixSizeCsr[23:16] = M, MatrixSizeCsr[15:8] = K, MatrixSizeCsr[7:0] = N |
0x3c1 | AddrACsr | Start address of matrix A |
0x3c2 | AddrBCsr | Start address of matrix B |
0x3c3 | AddrCCsr | Start address of matrix C |
0x3c4 | StrideInnermostACsr | The incremental address for a new submatrix of each block row of A |
0x3c5 | StrideInnermostBCsr | The incremental address for a new submatrix of each block column of B |
0x3c6 | StrideInnermostCCsr | The incremental address for a new submatrix of each block row of C |
0x3c7 | LdACsr | The incremental address for a new block row of A |
0x3c8 | LdBCsr | The incremental address for a new block column of B |
0x3c9 | LdCCsr | The incremental address for a new block row of C |
0x3ca | StrideACsr | The incremental address of matrix A for each batch |
0x3cb | StrideBCsr | The incremental address of matrix B for each batch |
0x3cc | StrideCCsr | The incremental address of matrix C for each batch |
0x3cd | PerfCounterCsr | Performance counter for GEMM busy cycles |
0x3ce | SubtractionCsr | The two subtraction values, SubtractionCsr[7:0] = a, SubtractionCsr[15:8] = b |
0x3cf | StateCsr | State CSR indicating GEMM start and busy status |
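The packed CSRs in the table can be assembled on the host side as sketched below. The bit fields follow the table; the helper names are hypothetical, and treating a and b as signed 8-bit two's complement values is an assumption based on the int8 input data type.

```python
def pack_matrix_size_csr(batch, m, k, n):
    """MatrixSizeCsr: [31:24] = Batch, [23:16] = M, [15:8] = K, [7:0] = N."""
    for name, value in (("batch", batch), ("m", m), ("k", k), ("n", n)):
        if not 0 <= value <= 0xFF:
            raise ValueError(f"{name}={value} does not fit in 8 bits")
    return (batch << 24) | (m << 16) | (k << 8) | n


def pack_subtraction_csr(a, b):
    """SubtractionCsr: [7:0] = a, [15:8] = b (assumed signed 8-bit values)."""
    for name, value in (("a", a), ("b", b)):
        if not -128 <= value <= 127:
            raise ValueError(f"{name}={value} is not an int8")
    return ((b & 0xFF) << 8) | (a & 0xFF)


print(hex(pack_matrix_size_csr(batch=1, m=3, k=4, n=5)))  # 0x1030405
print(hex(pack_subtraction_csr(a=-1, b=2)))               # 0x2ff
```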
This repository also contains some unit tests for each version of GEMM. The unit tests are also written in Chisel. Firstly, random input matrices and random size configurations are generated. Then these matrices and configurations are fed into the GEMM accelerator. After the computation, the output result of the GEMM accelerator will be compared with the golden model in Scala.
There is also a corner case test for Block GEMM and Batch Stride GEMM to check that they work in extreme configurations, such as M == K == N == 1 for Block GEMM.
In the current unit tests, we only test the GEMM with int8 as the input data type and int32 as the output data type. The GEMM array size is also fixed.
Follow this quickstart to set up the Chisel environment and run the unit tests of the GEMM accelerator.
Run following script:
To run all the Gemm accelerator tests, use:
sbt test -java-home /Library/Java/JavaVirtualMachines/openjdk-11.jdk/Contents/Home
To run specific test, use:
sbt "testOnly gemm.${chisel_test_name}" -java-home /Library/Java/JavaVirtualMachines/openjdk-11.jdk/Contents/Home
where chisel_test_name is the class name of the specific test. For instance, use:
sbt "testOnly gemm.GemmArrayRandomTest" -java-home /Library/Java/JavaVirtualMachines/openjdk-11.jdk/Contents/Home
to only run the random data and random matrix size test for Base Gemm.
sbt "testOnly gemm.BlockGemmRandomTest" -java-home /Library/Java/JavaVirtualMachines/openjdk-11.jdk/Contents/Home
to only run the random data and random matrix size(M,K,N) for Block Gemm.
Note: running the test for a Chisel module also generates the corresponding SystemVerilog file, which checks that the SystemVerilog file can be generated correctly.
To generate the corresponding SystemVerilog file for a specific Chisel module, use:
sbt "runMain gemm.${chisel_module_name}" -java-home /Library/Java/JavaVirtualMachines/openjdk-11.jdk/Contents/Home
For instance, to generate the system verilog file for Base Gemm, use:
sbt "runMain gemm.GemmArray" -java-home /Library/Java/JavaVirtualMachines/openjdk-11.jdk/Contents/Home
We learned a lot from Huawei DaVinci when developing the Gemm Accelerator.