diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 9fc0c20..7b806f5 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -62,4 +62,7 @@ jobs: pytest ../cgra/translate/VectorCGRAKingMeshRTL_test.py -xvs --tb=short --test-verilog --dump-vtb --dump-vcd pytest ../tile/translate/TileRTL_test.py -xvs --tb=short --test-verilog --dump-vtb --dump-vcd pytest ../cgra/translate/CGRASeparateCrossbarRTL_test.py -xvs --tb=short --test-verilog --dump-vtb --dump-vcd + pytest ../cgra/translate/CGRAMemBottomRTL_matmul_2x2_test.py -xvs --tb=short --test-verilog --dump-vtb --dump-vcd + pytest ../cgra/translate/CGRAMemRightAndBottomRTL_matmul_2x2_test.py -xvs --tb=short + pytest ../cgra/translate/CGRAMemRightAndBottomRTL_matmul_2x2_test.py -xvs --tb=short --test-verilog --dump-vtb --dump-vcd diff --git a/cgra/CGRAMemBottomRTL.py b/cgra/CGRAMemBottomRTL.py new file mode 100644 index 0000000..1130e8f --- /dev/null +++ b/cgra/CGRAMemBottomRTL.py @@ -0,0 +1,116 @@ +""" +========================================================================= +CGRAMemBottomRTL.py +========================================================================= +The scratchpad memory is connected to the bottom (first row) tiles. 
+
+Author : Cheng Tan
+  Date : Nov 18, 2024
+"""
+
+from pymtl3 import *
+from ..lib.ifcs import SendIfcRTL, RecvIfcRTL
+from ..noc.CrossbarRTL import CrossbarRTL
+from ..noc.ChannelRTL import ChannelRTL
+from ..tile.TileRTL import TileRTL
+from ..lib.opt_type import *
+from ..lib.common import *
+from ..mem.data.DataMemRTL import DataMemRTL
+from ..mem.data.DataMemCL import DataMemCL
+from ..fu.single.MemUnitRTL import MemUnitRTL
+from ..fu.single.AdderRTL import AdderRTL
+from ..fu.flexible.FlexibleFuRTL import FlexibleFuRTL
+
+class CGRAMemBottomRTL(Component):
+  """width x height mesh of TileRTL; only the first-row tiles reach the shared DataMemRTL."""
+  def construct(s, DataType, PredicateType, CtrlType, width, height,
+                ctrl_mem_size, data_mem_size, num_ctrl, total_steps,
+                FunctionUnit, FuList, preload_data = None,
+                preload_const = None):
+    # preload_const: optional per-tile constant lists; defaults to [DataType(0, 0)] per tile.
+    s.num_tiles = width * height
+    s.num_mesh_ports = 4
+    AddrType = mk_bits(clog2(ctrl_mem_size))
+
+    # Interfaces: one write-address and one write-option port per tile.
+    s.recv_waddr = [RecvIfcRTL(AddrType) for _ in range(s.num_tiles)]
+    s.recv_wopt = [RecvIfcRTL(CtrlType) for _ in range(s.num_tiles)]
+
+    # Components
+    if preload_const is None:
+      preload_const = [[DataType(0, 0)] for _ in range(s.num_tiles)]
+    s.tile = [TileRTL(DataType, PredicateType, CtrlType,
+                      ctrl_mem_size, data_mem_size, num_ctrl,
+                      total_steps, 4, 2, s.num_mesh_ports,
+                      s.num_mesh_ports, Fu = FunctionUnit,
+                      FuList = FuList, const_list = preload_const[i])
+              for i in range(s.num_tiles)]
+    s.data_mem = DataMemRTL(DataType, data_mem_size, width, width,
+                            preload_data)  # One rd/wr port per first-row tile (indexed by i % width below).
+
+    s.send_data = [SendIfcRTL(DataType) for _ in range(height - 1)]
+
+    # Connections
+    for i in range(s.num_tiles):  # Row-major numbering: row = i // width, column = i % width.
+      s.recv_waddr[i] //= s.tile[i].recv_waddr
+      s.recv_wopt[i] //= s.tile[i].recv_wopt
+
+      if i // width > 0:
+        s.tile[i].send_data[PORT_SOUTH] //= s.tile[i-width].recv_data[PORT_NORTH]
+
+      if i // width < height - 1:
+        s.tile[i].send_data[PORT_NORTH] //= s.tile[i+width].recv_data[PORT_SOUTH]
+
+      if i % width > 0:
+        s.tile[i].send_data[PORT_WEST] //= s.tile[i-1].recv_data[PORT_EAST]
+
+      if i % width < width - 1:
+        s.tile[i].send_data[PORT_EAST] //= s.tile[i+1].recv_data[PORT_WEST]
+
+      if i // width == 0:  # First row: no neighbour to the south, tie the port off.
+        s.tile[i].send_data[PORT_SOUTH].rdy //= 0
+        s.tile[i].recv_data[PORT_SOUTH].en //= 0
+        s.tile[i].recv_data[PORT_SOUTH].msg //= DataType(0, 0)
+
+      if i // width == height - 1:  # Last row: no neighbour to the north.
+        s.tile[i].send_data[PORT_NORTH].rdy //= 0
+        s.tile[i].recv_data[PORT_NORTH].en //= 0
+        s.tile[i].recv_data[PORT_NORTH].msg //= DataType(0, 0)
+
+      if i % width == 0:  # First column: no neighbour to the west.
+        s.tile[i].send_data[PORT_WEST].rdy //= 0
+        s.tile[i].recv_data[PORT_WEST].en //= 0
+        s.tile[i].recv_data[PORT_WEST].msg //= DataType(0, 0)
+
+      if i % width == width - 1:
+        if i // width != 0:
+          # Exposes the east send port of each right-most tile (except the
+          # one on the first row) as the top-level s.send_data[row - 1].
+          s.tile[i].send_data[PORT_EAST] //= s.send_data[i // width - 1]
+          s.tile[i].recv_data[PORT_EAST].en //= 0
+          s.tile[i].recv_data[PORT_EAST].msg //= DataType(0, 0)
+        else:
+          s.tile[i].send_data[PORT_EAST].rdy //= 0
+          s.tile[i].recv_data[PORT_EAST].en //= 0
+          s.tile[i].recv_data[PORT_EAST].msg //= DataType(0, 0)
+
+      if i // width == 0:
+        s.tile[i].to_mem_raddr //= s.data_mem.recv_raddr[i % width]
+        s.tile[i].from_mem_rdata //= s.data_mem.send_rdata[i % width]
+        s.tile[i].to_mem_waddr //= s.data_mem.recv_waddr[i % width]
+        s.tile[i].to_mem_wdata //= s.data_mem.recv_wdata[i % width]
+      else:
+        s.tile[i].to_mem_raddr.rdy //= 0
+        s.tile[i].from_mem_rdata.en //= 0
+        s.tile[i].from_mem_rdata.msg //= DataType(0, 0)
+        s.tile[i].to_mem_waddr.rdy //= 0
+        s.tile[i].to_mem_wdata.rdy //= 0
+
+  # Line trace
+  def line_trace(s):
+    # One "[tileN]: ..." entry per tile (the tile trace followed by its
+    # ctrl_mem trace), then the trace of the data memory.
+    res = "||\n".join([(("[tile" + str(i) + "]: ") + x.line_trace() + x.ctrl_mem.line_trace())
+                       for (i,x) in enumerate(s.tile)])
+    res += "\n :: Mem [" + s.data_mem.line_trace() + "] \n"
+    return res
diff --git a/cgra/CGRAMemRightAndBottomRTL.py b/cgra/CGRAMemRightAndBottomRTL.py new file mode 100644 index 0000000..5b600e5 --- 
/dev/null
+++ b/cgra/CGRAMemRightAndBottomRTL.py
@@ -0,0 +1,128 @@
+"""
+=========================================================================
+CGRAMemRightAndBottomRTL.py
+=========================================================================
+Two scratchpad memories are connected to the bottom (first row) and the
+last column (except the one on the first row) tiles. For example, in a
+3x3 CGRA, the bottom 3 tiles are connected to the south SPM while right-
+most 2 tiles (from top to bottom) are connected to the east SPM.
+
+Author : Cheng Tan
+  Date : Nov 19, 2024
+"""
+
+from pymtl3 import *
+from ..lib.ifcs import SendIfcRTL, RecvIfcRTL
+from ..noc.CrossbarRTL import CrossbarRTL
+from ..noc.ChannelRTL import ChannelRTL
+from ..tile.TileRTL import TileRTL
+from ..lib.opt_type import *
+from ..lib.common import *
+from ..mem.data.DataMemRTL import DataMemRTL
+from ..mem.data.DataMemCL import DataMemCL
+from ..fu.single.MemUnitRTL import MemUnitRTL
+from ..fu.single.AdderRTL import AdderRTL
+from ..fu.flexible.FlexibleFuRTL import FlexibleFuRTL
+
+class CGRAMemRightAndBottomRTL(Component):
+  """width x height mesh of TileRTL with a south SPM (first row) and an east SPM (last column, rows >= 1)."""
+  def construct(s, DataType, PredicateType, CtrlType, width, height,
+                ctrl_mem_size, data_mem_size, num_ctrl, total_steps,
+                FunctionUnit, FuList, preload_data = None,
+                preload_const = None):
+    # preload_const: optional per-tile constant lists; defaults to [DataType(0, 0)] per tile.
+    s.num_tiles = width * height
+    s.num_mesh_ports = 4
+    AddrType = mk_bits(clog2(ctrl_mem_size))
+
+    # Interfaces: one write-address and one write-option port per tile.
+    s.recv_waddr = [RecvIfcRTL(AddrType) for _ in range(s.num_tiles)]
+    s.recv_wopt = [RecvIfcRTL(CtrlType) for _ in range(s.num_tiles)]
+
+    # Components
+    if preload_const is None:
+      preload_const = [[DataType(0, 0)] for _ in range(s.num_tiles)]
+    s.tile = [TileRTL(DataType, PredicateType, CtrlType,
+                      ctrl_mem_size, data_mem_size, num_ctrl,
+                      total_steps, 4, 2, s.num_mesh_ports,
+                      s.num_mesh_ports, Fu = FunctionUnit,
+                      FuList = FuList, const_list = preload_const[i])
+              for i in range(s.num_tiles)]
+
+    s.data_mem_south = DataMemRTL(DataType, data_mem_size,
+                                  rd_ports = width, wr_ports = width,
+                                  preload_data = preload_data)
+
+    s.data_mem_east = DataMemRTL(DataType, data_mem_size,
+                                 rd_ports = height - 1,
+                                 wr_ports = height - 1,
+                                 preload_data = None)
+
+    # Connections
+    for i in range(s.num_tiles):  # Row-major numbering: row = i // width, column = i % width.
+      s.recv_waddr[i] //= s.tile[i].recv_waddr
+      s.recv_wopt[i] //= s.tile[i].recv_wopt
+
+      if i // width > 0:
+        s.tile[i].send_data[PORT_SOUTH] //= s.tile[i-width].recv_data[PORT_NORTH]
+
+      if i // width < height - 1:
+        s.tile[i].send_data[PORT_NORTH] //= s.tile[i+width].recv_data[PORT_SOUTH]
+
+      if i % width > 0:
+        s.tile[i].send_data[PORT_WEST] //= s.tile[i-1].recv_data[PORT_EAST]
+
+      if i % width < width - 1:
+        s.tile[i].send_data[PORT_EAST] //= s.tile[i+1].recv_data[PORT_WEST]
+
+      if i // width == 0:  # First row: no neighbour to the south, tie the port off.
+        s.tile[i].send_data[PORT_SOUTH].rdy //= 0
+        s.tile[i].recv_data[PORT_SOUTH].en //= 0
+        s.tile[i].recv_data[PORT_SOUTH].msg //= DataType(0, 0)
+
+      if i // width == height - 1:  # Last row: no neighbour to the north.
+        s.tile[i].send_data[PORT_NORTH].rdy //= 0
+        s.tile[i].recv_data[PORT_NORTH].en //= 0
+        s.tile[i].recv_data[PORT_NORTH].msg //= DataType(0, 0)
+
+      if i % width == 0:  # First column: no neighbour to the west.
+        s.tile[i].send_data[PORT_WEST].rdy //= 0
+        s.tile[i].recv_data[PORT_WEST].en //= 0
+        s.tile[i].recv_data[PORT_WEST].msg //= DataType(0, 0)
+
+      if i % width == width - 1:
+        # Every last-column tile has its east mesh port tied off; rows >= 1
+        # reach the east SPM through their memory ports instead (see below).
+        s.tile[i].send_data[PORT_EAST].rdy //= 0
+        s.tile[i].recv_data[PORT_EAST].en //= 0
+        s.tile[i].recv_data[PORT_EAST].msg //= DataType(0, 0)
+
+      if i // width == 0:
+        # Connects the bottom tiles to the south SPM.
+        s.tile[i].to_mem_raddr //= s.data_mem_south.recv_raddr[i % width]
+        s.tile[i].from_mem_rdata //= s.data_mem_south.send_rdata[i % width]
+        s.tile[i].to_mem_waddr //= s.data_mem_south.recv_waddr[i % width]
+        s.tile[i].to_mem_wdata //= s.data_mem_south.recv_wdata[i % width]
+      elif i % width == width - 1:
+        # Connects the right-most tiles (except the bottom ones) to the east
+        # SPM.
+        s.tile[i].to_mem_raddr //= s.data_mem_east.recv_raddr[i // width - 1]
+        s.tile[i].from_mem_rdata //= s.data_mem_east.send_rdata[i // width - 1]
+        s.tile[i].to_mem_waddr //= s.data_mem_east.recv_waddr[i // width - 1]
+        s.tile[i].to_mem_wdata //= s.data_mem_east.recv_wdata[i // width - 1]
+      else:
+        s.tile[i].to_mem_raddr.rdy //= 0
+        s.tile[i].from_mem_rdata.en //= 0
+        s.tile[i].from_mem_rdata.msg //= DataType(0, 0)
+        s.tile[i].to_mem_waddr.rdy //= 0
+        s.tile[i].to_mem_wdata.rdy //= 0
+
+  # Line trace
+  def line_trace(s):
+    # One "[tileN]: ..." entry per tile (the tile trace followed by its
+    # ctrl_mem trace), then the traces of the south and east memories.
+    res = "||\n".join([(("[tile" + str(i) + "]: ") + x.line_trace() + x.ctrl_mem.line_trace())
+                       for (i,x) in enumerate(s.tile)])
+    res += "\n :: SouthMem [" + s.data_mem_south.line_trace() + "] \n"
+    res += "\n :: EastMem [" + s.data_mem_east.line_trace() + "] \n"
+    return res
diff --git a/cgra/CGRARTL.py b/cgra/CGRARTL.py index f3508a8..989bfd8 100644 --- a/cgra/CGRARTL.py +++ b/cgra/CGRARTL.py @@ -40,7 +40,10 @@ def construct( s, DataType, PredicateType, CtrlType, width, height, s.tile = [ TileRTL( DataType, PredicateType, CtrlType, ctrl_mem_size, data_mem_size, num_ctrl, total_steps, 4, 2, s.num_mesh_ports, - s.num_mesh_ports, const_list = preload_const[i] ) + s.num_mesh_ports, + Fu = FunctionUnit, + FuList = FuList, + const_list = preload_const[i] ) for i in range( s.num_tiles ) ] s.data_mem = DataMemRTL( DataType, data_mem_size, height, height, preload_data ) diff --git a/cgra/translate/CGRAMemBottomRTL_matmul_2x2_test.py 
b/cgra/translate/CGRAMemBottomRTL_matmul_2x2_test.py
new file mode 100644
index 0000000..a1bc3d1
--- /dev/null
+++ b/cgra/translate/CGRAMemBottomRTL_matmul_2x2_test.py
@@ -0,0 +1,310 @@
+"""
+==========================================================================
+CGRAMemBottomRTL_matmul_2x2_test.py
+==========================================================================
+Translation for 3x2 CGRA. The provided test is only used for a 2x2 matmul.
+
+Author : Cheng Tan
+  Date : Oct 14, 2024
+"""
+
+from pymtl3 import *
+from pymtl3.passes.backends.verilog import VerilogTranslationPass
+from pymtl3.stdlib.test_utils import (run_sim,
+                                      config_model_with_cmdline_opts)
+
+from ...lib.test_srcs import TestSrcRTL
+from ...lib.test_sinks import TestSinkRTL
+from ...lib.opt_type import *
+from ...lib.messages import *
+from ...fu.flexible.FlexibleFuRTL import FlexibleFuRTL
+from ...fu.single.AdderRTL import AdderRTL
+from ...fu.single.MemUnitRTL import MemUnitRTL
+from ...fu.single.MulRTL import MulRTL
+from ...fu.single.SelRTL import SelRTL
+from ...fu.single.ShifterRTL import ShifterRTL
+from ...fu.single.LogicRTL import LogicRTL
+from ...fu.single.PhiRTL import PhiRTL
+from ...fu.single.CompRTL import CompRTL
+from ...fu.double.SeqMulAdderRTL import SeqMulAdderRTL
+from ...fu.single.BranchRTL import BranchRTL
+from ..CGRAMemBottomRTL import CGRAMemBottomRTL
+
+#-------------------------------------------------------------------------
+# Test harness
+#-------------------------------------------------------------------------
+
+kMaxCycles = 20
+
+class TestHarness(Component):
+  """Drives per-tile ctrl sources into the DUT and checks its send_data ports with sinks."""
+  def construct(s, DUT, FunctionUnit, fu_list, DataType, PredicateType,
+                CtrlType, width, height, ctrl_mem_size, data_mem_size,
+                src_opt, ctrl_waddr, preload_data, preload_const,
+                sink_out):
+
+    s.height = height
+    s.num_tiles = width * height
+    AddrType = mk_bits(clog2(ctrl_mem_size))
+
+    s.src_opt = [TestSrcRTL(CtrlType, src_opt[i])
+                 for i in range(s.num_tiles)]
+    s.ctrl_waddr = [TestSrcRTL(AddrType, ctrl_waddr[i])
+                    for i in range(s.num_tiles)]
+
+    s.dut = DUT(DataType, PredicateType, CtrlType, width, height,
+                ctrl_mem_size, data_mem_size, kMaxCycles,
+                kMaxCycles, FunctionUnit, fu_list, preload_data,
+                preload_const)
+
+    s.sink_out = [TestSinkRTL(DataType, sink_out[i])
+                  for i in range(height - 1)]
+
+    for i in range(height - 1):
+      connect(s.dut.send_data[i], s.sink_out[i].recv)
+
+    for i in range(s.num_tiles):
+      connect(s.src_opt[i].send, s.dut.recv_wopt[i])
+      connect(s.ctrl_waddr[i].send, s.dut.recv_waddr[i])
+
+  def done(s):
+    # Finished once every attached sink reports done(). s.sink_out holds
+    # exactly the height - 1 sinks created in construct, so iterating the
+    # list visits the same sinks as indexing 0 .. s.height - 2.
+    return all(sink.done() for sink in s.sink_out)
+
+  def line_trace(s):
+    return s.dut.line_trace()
+
+def run_sim( test_harness, max_cycles = kMaxCycles ):
+  # NOTE: this local definition shadows the run_sim imported from pymtl3.stdlib.test_utils.
+  test_harness.apply( DefaultPassGroup() )
+
+  # Run simulation
+  ncycles = 0
+  print()
+  print("{}:{}".format( ncycles, test_harness.line_trace()))
+  while not test_harness.done() and ncycles < max_cycles:  # Bounded so a stuck DUT fails instead of hanging.
+    test_harness.sim_tick()
+    ncycles += 1
+    print("----------------------------------------------------")
+    print("{}:{}".format( ncycles, test_harness.line_trace()))
+
+  # Check timeout
+  assert ncycles < max_cycles
+
+  test_harness.sim_tick()
+  test_harness.sim_tick()
+  test_harness.sim_tick()
+
+def test_CGRA_systolic(cmdline_opts):
+  num_tile_inports = 4
+  num_tile_outports = 4
+  num_xbar_inports = 6
+  num_xbar_outports = 8
+  ctrl_mem_size = 10
+  width = 2
+  height = 3
+  RouteType = mk_bits(clog2(num_xbar_inports + 1))
+  AddrType = mk_bits(clog2(ctrl_mem_size))
+  num_tiles = width * height
+  num_fu_in = 4
+  DUT = CGRAMemBottomRTL
+  FunctionUnit = FlexibleFuRTL
+  FuList = [SeqMulAdderRTL, AdderRTL, MulRTL, LogicRTL, ShifterRTL, PhiRTL, CompRTL, BranchRTL, MemUnitRTL]
+  DataType = mk_data(32, 1)
+  PredicateType = mk_predicate(1, 1)
+  # FuList = [ SeqMulAdderRTL, AdderRTL, MulRTL, LogicRTL, ShifterRTL, PhiRTL, CompRTL, BranchRTL, MemUnitRTL ]
+  # DataType = mk_data(16, 1)
+  CtrlType = mk_ctrl(num_fu_in, 
num_xbar_inports, num_xbar_outports) + FuInType = mk_bits(clog2( num_fu_in + 1)) + pickRegister = [FuInType(x + 1) for x in range(num_fu_in)] + + src_opt = [ + # On tile 0 ([0, 0]). + [CtrlType(OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + ], + # On tile 1 ([0, 1]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + ], + # On tile 2 ([1, 0]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_MUL_CONST, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_MUL_CONST, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_MUL_CONST, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + ], + # On tile 3 ([1, 1]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(3), RouteType(0)]), + CtrlType( OPT_MUL_CONST_ADD, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(3), RouteType(0)]), + CtrlType( OPT_MUL_CONST_ADD, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(3), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + + ], + # On tile 4 ([2, 0]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_MUL_CONST, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_MUL_CONST, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + + ], + # On tile 5 ([2, 1]). 
+             [CtrlType( OPT_NAH, b1(0), pickRegister, [
+              RouteType(0), RouteType(0), RouteType(0), RouteType(0),
+              RouteType(2), RouteType(0), RouteType(3), RouteType(0)]),
+              CtrlType( OPT_NAH, b1(0), pickRegister, [
+              RouteType(0), RouteType(0), RouteType(0), RouteType(0),
+              RouteType(2), RouteType(0), RouteType(3), RouteType(0)]),
+              CtrlType( OPT_NAH, b1(0), pickRegister, [
+              RouteType(0), RouteType(0), RouteType(0), RouteType(0),
+              RouteType(2), RouteType(0), RouteType(3), RouteType(0)]),
+              CtrlType( OPT_MUL_CONST_ADD, b1(0), pickRegister, [
+              RouteType(0), RouteType(0), RouteType(0), RouteType(5),
+              RouteType(2), RouteType(0), RouteType(3), RouteType(0)]),
+              CtrlType( OPT_MUL_CONST_ADD, b1(0), pickRegister, [
+              RouteType(0), RouteType(0), RouteType(0), RouteType(5),
+              RouteType(2), RouteType(0), RouteType(3), RouteType(0)]),
+
+              CtrlType( OPT_NAH, b1(0), pickRegister, [
+              RouteType(2), RouteType(0), RouteType(0), RouteType(0),
+              RouteType(2), RouteType(0), RouteType(0), RouteType(0)]),
+              CtrlType( OPT_NAH, b1(0), pickRegister, [
+              RouteType(2), RouteType(0), RouteType(0), RouteType(0),
+              RouteType(2), RouteType(0), RouteType(0), RouteType(0)]),
+
+             ]
+            ]
+
+  preload_mem = [DataType(1, 1), DataType(2, 1), DataType(3, 1),
+                 DataType(4, 1)]
+  preload_const = [
+                   # The offset address used for loading input activation.
+                   # We use a shared data memory here, indicating global address
+                   # space. Users can make each tile have its own address space.
+
+                   # The last one is not useful for the first column, which is just
+                   # to make the length aligned.
+                   [DataType(0, 1), DataType(1, 1), DataType(0, 0)],
+                   # The first one is not useful for the second column, which is just
+                   # to make the length aligned.
+                   [DataType(0, 0), DataType(2, 1), DataType(3, 1)],
+
+                   # Preloads weights. 3 items to align with the above const length.
+                   # Duplication exists as the iter of the const queue automatically
+                   # increments.
+                   [DataType(2, 1), DataType(2, 1), DataType(2, 1)],
+                   [DataType(4, 1), DataType(4, 1), DataType(4, 1)],
+                   [DataType(6, 1), DataType(6, 1), DataType(6, 1)],
+                   [DataType(8, 1), DataType(8, 1), DataType(8, 1)]]
+
+  data_mem_size = len(preload_mem)
+
+  """
+  1 3      2 6     14 20
+       x        =
+  2 4      4 8     30 44
+  """
+  sink_out = [[DataType(14, 1), DataType(20, 1)], [DataType(30, 1),
+               DataType(44, 1)]]
+
+  # When the max iterations are larger than the number of control signals,
+  # enough ctrl_waddr needs to be provided to make execution (i.e., ctrl
+  # read) continue.
+  ctrl_waddr = [[AddrType(0), AddrType(1), AddrType(2), AddrType(3),
+                 AddrType(4), AddrType(5)] for _ in range(num_tiles)]
+
+  th = TestHarness(DUT, FunctionUnit, FuList, DataType, PredicateType,
+                   CtrlType, width, height, ctrl_mem_size, data_mem_size,
+                   src_opt, ctrl_waddr, preload_mem, preload_const,
+                   sink_out)
+
+  th.elaborate()
+  th.dut.set_metadata(VerilogTranslationPass.explicit_module_name,
+                      f'CGRARTL')
+  # th.dut.set_metadata( VerilogVerilatorImportPass.vl_Wno_list,
+  #                   ['UNSIGNED', 'UNOPTFLAT', 'WIDTH', 'WIDTHCONCAT',
+  #                    'ALWCOMBORDER'] )
+  th = config_model_with_cmdline_opts(th, cmdline_opts, duts=['dut'])
+
+  run_sim(th)
+
diff --git a/cgra/translate/CGRAMemRightAndBottomRTL_matmul_2x2_test.py b/cgra/translate/CGRAMemRightAndBottomRTL_matmul_2x2_test.py
new file mode 100644
index 0000000..2a96527
--- /dev/null
+++ b/cgra/translate/CGRAMemRightAndBottomRTL_matmul_2x2_test.py
@@ -0,0 +1,415 @@
+"""
+==========================================================================
+CGRAMemRightAndBottomRTL_matmul_2x2_test.py
+==========================================================================
+Translation for 3x3 CGRA. The provided test is only used for a 2x2 matmul.
+
+Author : Cheng Tan
+  Date : Nov 19, 2024
+"""
+
+from pymtl3 import *
+from pymtl3.passes.backends.verilog import VerilogTranslationPass
+from pymtl3.stdlib.test_utils import (run_sim,
+                                      config_model_with_cmdline_opts)
+
+from ...lib.test_srcs import TestSrcRTL
+from ...lib.test_sinks import TestSinkRTL
+from ...lib.opt_type import *
+from ...lib.messages import *
+from ...fu.flexible.FlexibleFuRTL import FlexibleFuRTL
+from ...fu.single.AdderRTL import AdderRTL
+from ...fu.single.MemUnitRTL import MemUnitRTL
+from ...fu.single.MulRTL import MulRTL
+from ...fu.single.SelRTL import SelRTL
+from ...fu.single.ShifterRTL import ShifterRTL
+from ...fu.single.LogicRTL import LogicRTL
+from ...fu.single.PhiRTL import PhiRTL
+from ...fu.single.CompRTL import CompRTL
+from ...fu.double.SeqMulAdderRTL import SeqMulAdderRTL
+from ...fu.single.BranchRTL import BranchRTL
+from ..CGRAMemRightAndBottomRTL import CGRAMemRightAndBottomRTL
+
+#-------------------------------------------------------------------------
+# Test harness
+#-------------------------------------------------------------------------
+
+kMaxCycles = 12
+
+class TestHarness(Component):
+  """Drives per-tile ctrl sources into the DUT; results are read from the DUT's east data memory."""
+  def construct(s, DUT, FunctionUnit, fu_list, DataType, PredicateType,
+                CtrlType, width, height, ctrl_mem_size, data_mem_size,
+                src_opt, ctrl_waddr, preload_data, preload_const,
+                expected_out):
+
+    s.DataType = DataType
+    s.expected_out = expected_out
+    s.num_tiles = width * height
+    AddrType = mk_bits(clog2(ctrl_mem_size))
+
+    s.src_opt = [TestSrcRTL(CtrlType, src_opt[i])
+                 for i in range(s.num_tiles)]
+    s.ctrl_waddr = [TestSrcRTL(AddrType, ctrl_waddr[i])
+                    for i in range(s.num_tiles)]
+
+    s.dut = DUT(DataType, PredicateType, CtrlType, width, height,
+                ctrl_mem_size, data_mem_size, kMaxCycles,
+                kMaxCycles, FunctionUnit, fu_list, preload_data,
+                preload_const)
+
+    # No TestSinkRTL is attached in this harness: nothing is connected to
+    # a send_data port of the DUT.
+
+    # Results are instead checked by inspecting the register file of the
+    # DUT's east data memory (see done() and check_parity() below).
+
+    for i in range(s.num_tiles):
+      connect(s.src_opt[i].send, s.dut.recv_wopt[i])
+      connect(s.ctrl_waddr[i].send, s.dut.recv_waddr[i])
+
+  # Simulation terminates if the output memory contains
+  # not less than the expected number of outputs.
+  def done(s):
+    # A word counts as an output once it differs from DataType(0, 0); this
+    # presumably relies on the east memory starting zeroed and on no
+    # expected result being DataType(0, 0) itself -- TODO confirm.
+    empty = s.DataType(0, 0)
+    num_valid_out = sum(1 for data in s.dut.data_mem_east.reg_file.regs
+                        if data != empty)
+    return num_valid_out >= len(s.expected_out)
+
+  # Checks the output parity.
+  def check_parity(s):
+    # Compares the first len(expected_out) words of the east data memory.
+    regs = s.dut.data_mem_east.reg_file.regs
+    return not any(expected != regs[i]
+                   for i, expected in enumerate(s.expected_out))
+
+  def line_trace(s):
+    return s.dut.line_trace()
+
+def run_sim(test_harness, enable_verification_pymtl,
+            max_cycles = kMaxCycles):
+  # NOTE: this local definition shadows the run_sim imported from pymtl3.stdlib.test_utils.
+  test_harness.apply( DefaultPassGroup() )
+
+  # Run simulation
+  ncycles = 0
+  print()
+  print("{}:{}".format( ncycles, test_harness.line_trace()))
+  if enable_verification_pymtl:
+    while not test_harness.done() and ncycles < max_cycles:  # Bounded so a stuck DUT fails instead of hanging.
+      test_harness.sim_tick()
+      ncycles += 1
+      print("----------------------------------------------------")
+      print("{}:{}".format( ncycles, test_harness.line_trace()))
+
+    # Checks timeout first, so a timeout is not misreported as a mismatch.
+    assert ncycles < max_cycles
+
+    # Checks the output parity.
+    assert test_harness.check_parity()
+  else:
+    while ncycles < max_cycles:
+      test_harness.sim_tick()
+      ncycles += 1
+      print("----------------------------------------------------")
+      print("{}:{}".format( ncycles, test_harness.line_trace()))
+
+  test_harness.sim_tick()
+  test_harness.sim_tick()
+  test_harness.sim_tick()
+
+def test_CGRA_systolic(cmdline_opts):
+  num_tile_inports = 4
+  num_tile_outports = 4
+  num_xbar_inports = 6
+  num_xbar_outports = 8
+  ctrl_mem_size = 8
+  width = 3
+  height = 3
+  RouteType = mk_bits(clog2(num_xbar_inports + 1))
+  AddrType = mk_bits(clog2(ctrl_mem_size))
+  num_tiles = width * height
+  num_fu_in = 4
+  DUT = CGRAMemRightAndBottomRTL
+  FunctionUnit = FlexibleFuRTL
+  FuList = [SeqMulAdderRTL, AdderRTL, MulRTL, LogicRTL, ShifterRTL, PhiRTL, CompRTL, BranchRTL, MemUnitRTL]
+  DataType = mk_data(32, 1)
+  PredicateType = mk_predicate(1, 1)
+  # FuList = [ SeqMulAdderRTL, AdderRTL, MulRTL, LogicRTL, ShifterRTL, PhiRTL, CompRTL, BranchRTL, MemUnitRTL ]
+  # DataType = mk_data(16, 1)
+  CtrlType = mk_ctrl(num_fu_in, num_xbar_inports, num_xbar_outports)
+  FuInType = mk_bits(clog2( num_fu_in + 1))
+  pickRegister = [FuInType(x + 1) for x in range(num_fu_in)]
+
+  src_opt = [
+             # On tile 0 ([0, 0]).
+ [CtrlType(OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + ], + # On tile 1 ([0, 1]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_LD_CONST, b1(0), pickRegister, [ + RouteType(5), RouteType(0), RouteType(0), RouteType(0), + RouteType(0), RouteType(0), RouteType(0), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + ], + + # On tile 2 ([0, 2]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + ], + + # On tile 3 ([1, 0]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_MUL_CONST, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_MUL_CONST, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_MUL_CONST, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + ], + # On tile 4 ([1, 1]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(3), RouteType(0)]), + CtrlType( OPT_MUL_CONST_ADD, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(3), RouteType(0)]), + CtrlType( OPT_MUL_CONST_ADD, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(3), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + + ], + + # On tile 5 ([1, 2]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(3), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_STR_CONST, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(3), RouteType(0), RouteType(0), RouteType(0)]), + + CtrlType( OPT_STR_CONST, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + ], + + # On tile 6 ([2, 0]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_MUL_CONST, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_MUL_CONST, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + + ], + + # On tile 7 ([2, 1]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(3), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(3), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(3), RouteType(0)]), + CtrlType( OPT_MUL_CONST_ADD, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(3), RouteType(0)]), + CtrlType( OPT_MUL_CONST_ADD, b1(0), pickRegister, [ + RouteType(0), RouteType(0), RouteType(0), RouteType(5), + RouteType(2), RouteType(0), RouteType(3), RouteType(0)]), + + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + ], + + # On tile 8 ([2, 2]). 
+ [CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_NAH, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(3), RouteType(0), RouteType(0), RouteType(0)]), + + CtrlType( OPT_STR_CONST, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(3), RouteType(0), RouteType(0), RouteType(0)]), + CtrlType( OPT_STR_CONST, b1(0), pickRegister, [ + RouteType(2), RouteType(0), RouteType(0), RouteType(0), + RouteType(2), RouteType(0), RouteType(0), RouteType(0)]), + ], + ] + + preload_mem = [DataType(1, 1), DataType(2, 1), DataType(3, 1), + DataType(4, 1)] + preload_const = [ + # The offset address used for loading input activation. + # We use a shared data memory here, indicating global address + # space. Users can make each tile have its own address space. + + # The last one is not useful for the first column, which is just + # to make the length aligned. + [DataType(0, 1), DataType(1, 1), DataType(0, 0)], + # The first one is not useful for the second column, which is just + # to make the length aligned. + [DataType(0, 0), DataType(2, 1), DataType(3, 1)], + # The third column is not actually necessary to perform activation + # loading nor storing parameters. + [DataType(0, 0), DataType(0, 0), DataType(0, 0)], + + # Preloads weights. 3 items to align with the above const length. + # Duplication exists as the iter of the const queue automatically + # increments. 
+ [DataType(2, 1), DataType(2, 1), DataType(2, 1)], + [DataType(4, 1), DataType(4, 1), DataType(4, 1)], + # The third column (except the bottom one) is used to store the + # accumulated results. + [DataType(0, 1), DataType(2, 1), DataType(0, 0)], + + [DataType(6, 1), DataType(6, 1), DataType(6, 1)], + [DataType(8, 1), DataType(8, 1), DataType(8, 1)], + # The third column (except the bottom one) is used to store the + # accumulated results. + [DataType(1, 1), DataType(3, 1), DataType(0, 0)]] + + data_mem_size = len(preload_mem) + + """ + 1 3 2 6 14 20 + x = + 2 4 4 8 30 44 + """ + expected_out = [DataType(14, 1), DataType(30, 1), DataType(20, 1), + DataType(44, 1)] + + # When the max iterations are larger than the number of control signals, + # enough ctrl_waddr needs to be provided to make execution (i.e., ctrl + # read) continue. + ctrl_waddr = [[AddrType(0), AddrType(1), AddrType(2), AddrType(3), + AddrType(4), AddrType(5)] for _ in range(num_tiles)] + + th = TestHarness(DUT, FunctionUnit, FuList, DataType, PredicateType, + CtrlType, width, height, ctrl_mem_size, data_mem_size, + src_opt, ctrl_waddr, preload_mem, preload_const, + expected_out) + + th.elaborate() + th.dut.set_metadata(VerilogTranslationPass.explicit_module_name, + f'CGRAMemRightAndBottomRTL') + # th.dut.set_metadata( VerilogVerilatorImportPass.vl_Wno_list, + # ['UNSIGNED', 'UNOPTFLAT', 'WIDTH', 'WIDTHCONCAT', + # 'ALWCOMBORDER'] ) + th = config_model_with_cmdline_opts(th, cmdline_opts, duts=['dut']) + + enable_verification_pymtl = not (cmdline_opts['test_verilog'] or \ + cmdline_opts['dump_vcd'] or \ + cmdline_opts['dump_vtb']) + run_sim(th, enable_verification_pymtl) + diff --git a/cgra/translate/CGRARTL_test.py b/cgra/translate/CGRARTL_test.py index 7360904..0723450 100644 --- a/cgra/translate/CGRARTL_test.py +++ b/cgra/translate/CGRARTL_test.py @@ -37,17 +37,20 @@ num_xbar_inports = 6 num_xbar_outports = 8 ctrl_mem_size = 6 -width = 2 -height = 2 +width = 4 +height = 4 RouteType = mk_bits( 
clog2( num_xbar_inports + 1 ) ) AddrType = mk_bits( clog2( ctrl_mem_size ) ) num_tiles = width * height -data_mem_size = 8 +data_mem_size = 2 num_fu_in = 4 DUT = CGRARTL FunctionUnit = FlexibleFuRTL -FuList = [ SeqMulAdderRTL, MemUnitRTL ]#AdderRTL, MulRTL, LogicRTL, ShifterRTL, PhiRTL, CompRTL, BranchRTL, MemUnitRTL ] -DataType = mk_data( 32, 1 ) +FuList = [ AdderRTL, MemUnitRTL ]# SeqMulAdderRTL, AdderRTL, MulRTL, LogicRTL, ShifterRTL, PhiRTL, CompRTL, BranchRTL, MemUnitRTL ] +# Parameterizes channel bandwidth as around 2-byte per cycle. +payload_nbits = 16 # 32 +predicate_nbits = 1 +DataType = mk_data( payload_nbits, predicate_nbits ) PredicateType = mk_predicate( 1, 1 ) # FuList = [ SeqMulAdderRTL, AdderRTL, MulRTL, LogicRTL, ShifterRTL, PhiRTL, CompRTL, BranchRTL, MemUnitRTL ] # DataType = mk_data( 16, 1 ) diff --git a/cgra/translate/CGRATemplateRTL_test.py b/cgra/translate/CGRATemplateRTL_test.py index f900f86..f634f65 100644 --- a/cgra/translate/CGRATemplateRTL_test.py +++ b/cgra/translate/CGRATemplateRTL_test.py @@ -187,7 +187,8 @@ def test_cgra_universal( cmdline_opts, paramCGRA = None): DUT = CGRATemplateRTL FunctionUnit = FlexibleFuRTL # FuList = [ SeqMulAdderRTL, MemUnitRTL ]#AdderRTL, MulRTL, LogicRTL, ShifterRTL, PhiRTL, CompRTL, BranchRTL, MemUnitRTL ] - FuList = [ PhiRTL, AdderRTL, ShifterRTL, MemUnitRTL, SelRTL, CompRTL, SeqMulAdderRTL, RetRTL, MulRTL, LogicRTL, BranchRTL ] + # FuList = [ PhiRTL, AdderRTL, ShifterRTL, MemUnitRTL, SelRTL, CompRTL, SeqMulAdderRTL, RetRTL, MulRTL, LogicRTL, BranchRTL ] + FuList = [ PhiRTL, AdderRTL, ShifterRTL, MemUnitRTL, SelRTL, CompRTL, SeqMulAdderRTL, MulRTL, LogicRTL, BranchRTL ] DataType = mk_data( 32, 1 ) PredicateType = mk_predicate( 1, 1 ) # DataType = mk_data( 16, 1 ) diff --git a/fu/single/MemUnitRTL.py b/fu/single/MemUnitRTL.py index cf088ee..8bf3bf5 100644 --- a/fu/single/MemUnitRTL.py +++ b/fu/single/MemUnitRTL.py @@ -118,6 +118,7 @@ def comb_logic(): s.send_out[0].en @= s.recv_opt.en 
s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate + # LD_CONST indicates the address is a const. elif s.recv_opt.msg.ctrl == OPT_LD_CONST: for i in range( num_inports): s.recv_in[i].rdy @= b1( 0 ) @@ -130,9 +131,8 @@ def comb_logic(): # Const's predicate will always be true. s.send_out[0].msg.predicate @= b1( 1 ) - # TODO: and -> & elif s.recv_opt.msg.ctrl == OPT_STR: - s.send_out[0].en @= s.from_mem_rdata.en & s.recv_in[s.in0_idx].en & s.recv_in[s.in1_idx].en + # s.send_out[0].en @= s.from_mem_rdata.en & s.recv_in[s.in0_idx].en & s.recv_in[s.in1_idx].en s.recv_in[s.in0_idx].rdy @= s.to_mem_waddr.rdy s.recv_in[s.in1_idx].rdy @= s.to_mem_wdata.rdy # s.to_mem_waddr.msg @= AddrType( s.recv_in[0].msg.payload ) @@ -140,15 +140,36 @@ def comb_logic(): s.to_mem_waddr.en @= s.recv_in[s.in0_idx].en s.to_mem_wdata.msg @= s.recv_in[s.in1_idx].msg s.to_mem_wdata.en @= s.recv_in[s.in1_idx].en + + # `send_out` is meaningless for store operation. s.send_out[0].en @= b1( 0 ) - s.send_out[0].msg @= s.from_mem_rdata.msg - s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \ - s.recv_in[s.in1_idx].msg.predicate + s.send_out[0].msg @= s.to_mem_wdata.msg + # s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \ + # s.recv_in[s.in1_idx].msg.predicate + s.send_out[0].msg.predicate @= b1(0) if s.recv_opt.en & ( (s.recv_in_count[s.in0_idx] == 0) | \ (s.recv_in_count[s.in1_idx] == 0) ): s.recv_in[s.in0_idx].rdy @= b1( 0 ) s.recv_in[s.in1_idx].rdy @= b1( 0 ) - s.send_out[0].msg.predicate @= b1( 0 ) + # s.send_out[0].msg.predicate @= b1( 0 ) + + # STR_CONST indicates the address is a const. + elif s.recv_opt.msg.ctrl == OPT_STR_CONST: + s.recv_const.rdy @= s.to_mem_waddr.rdy + + # Only needs one input register to indicate the storing data. 
+ s.recv_in[s.in0_idx].rdy @= s.to_mem_wdata.rdy + s.to_mem_waddr.msg @= AddrType( s.recv_const.msg.payload[0:AddrType.nbits] ) + s.to_mem_waddr.en @= s.recv_const.en + s.to_mem_wdata.msg @= s.recv_in[s.in0_idx].msg + s.to_mem_wdata.en @= s.recv_in[s.in0_idx].en + + # `send_out` is meaningless for store operation. + s.send_out[0].en @= b1( 0 ) + s.send_out[0].msg @= s.to_mem_wdata.msg + s.send_out[0].msg.predicate @= b1(0) + if s.recv_opt.en & (s.recv_in_count[s.in0_idx] == 0): + s.recv_in[s.in0_idx].rdy @= b1( 0 ) else: for j in range( num_outports ): diff --git a/lib/opt_type.py b/lib/opt_type.py index d2b3c3a..80f2e27 100644 --- a/lib/opt_type.py +++ b/lib/opt_type.py @@ -39,6 +39,7 @@ OPT_PHI_CONST = Bits6( 32 ) OPT_SEL = Bits6( 27 ) OPT_LD_CONST = Bits6( 28 ) +OPT_STR_CONST = Bits6( 58 ) OPT_MUL_ADD = Bits6( 18 ) OPT_MUL_CONST = Bits6( 29 ) OPT_MUL_CONST_ADD = Bits6( 30 ) @@ -72,7 +73,7 @@ OPT_SYMBOL_DICT = { OPT_START : "(start)", - OPT_NAH : "( )", + OPT_NAH : "(NAH)", OPT_PAS : "(->)", OPT_ADD : "(+)", OPT_ADD_CONST : "(+')", @@ -97,6 +98,7 @@ OPT_PHI_CONST : "(ph')", OPT_SEL : "(sel)", OPT_LD_CONST : "(ldcst)", + OPT_STR_CONST : "(strcst)", OPT_MUL_ADD : "(* +)", OPT_MUL_CONST : "(*')", OPT_MUL_CONST_ADD : "(*' +)", diff --git a/mem/ctrl/CtrlMemCL.py b/mem/ctrl/CtrlMemCL.py index 3ac5a46..2d7a965 100644 --- a/mem/ctrl/CtrlMemCL.py +++ b/mem/ctrl/CtrlMemCL.py @@ -52,8 +52,10 @@ def update_signal(): @update_ff def update_raddr(): + if s.times < TimeType( total_ctrl_steps ): s.times <<= s.times + TimeType( 1 ) + if s.send_ctrl.rdy: if zext(s.cur + 1, PCType) == PCType( ctrl_count_per_iter ): s.cur <<= AddrType( 0 ) @@ -63,5 +65,5 @@ def update_raddr(): def line_trace( s ): out_str = "||".join([ str(data) for data in s.sram ]) - return f'[{out_str}] : {s.send_ctrl.msg}' + return f'[{out_str}] : {OPT_SYMBOL_DICT[s.send_ctrl.msg.ctrl]}' diff --git a/mem/data/DataMemRTL.py b/mem/data/DataMemRTL.py index 8600105..9929cea 100644 --- a/mem/data/DataMemRTL.py +++ 
b/mem/data/DataMemRTL.py @@ -50,9 +50,9 @@ def update_read_without_init(): s.reg_file.wen[i] @= s.recv_wdata[i].en & s.recv_waddr[i].en else: - s.preloadData = [ DataType( 0 ) for _ in range( data_mem_size ) ] + s.preloadData = [ Wire( DataType ) for _ in range( data_mem_size ) ] for i in range( len( preload_data ) ): - s.preloadData[ i ] = preload_data[i] + s.preloadData[ i ] //= preload_data[i] @update def update_read_with_init(): @@ -72,7 +72,7 @@ def update_read_with_init(): if s.recv_waddr[i].en == b1(1): s.reg_file.waddr[i] @= s.recv_waddr[i].msg s.reg_file.wdata[i] @= s.recv_wdata[i].msg - s.reg_file.wen[i] @= s.recv_wdata[i].en and s.recv_waddr[i].en + s.reg_file.wen[i] @= s.recv_wdata[i].en & s.recv_waddr[i].en # Connections @@ -96,13 +96,15 @@ def update_signal(): s.recv_waddr[i].rdy @= Bits1( 1 ) s.recv_wdata[i].rdy @= Bits1( 1 ) - def line_trace( s ): - recv_str = "|".join([ str(data.msg) for data in s.recv_wdata ]) - out_str = "|".join([ str(data) for data in s.reg_file.regs ]) - send_str = "|".join([ str(data.msg) for data in s.send_rdata ]) - # return f'{recv_str} : [{out_str}] : {send_str} initWrites: {s.initWrites}' + def line_trace(s): + recv_raddr_str = "recv_read_addr: " + "|".join([str(data.msg) for data in s.recv_raddr]) + recv_waddr_str = "recv_write_addr: " + "|".join([str(data.msg) for data in s.recv_waddr]) + recv_wdata_str = "recv_write_data: " + "|".join([str(data.msg) for data in s.recv_wdata]) + content_str = "content: " + "|".join([str(data) for data in s.reg_file.regs]) + send_rdata_str = "send_read_data: " + "|".join([str(data.msg) for data in s.send_rdata]) + return f'{recv_raddr_str} || {recv_waddr_str} || {recv_wdata_str} || [{content_str}] || {send_rdata_str}' + # return f'DataMem: {recv_str} : [{out_str}] : {send_str} initWrites: {s.initWrites}' # return s.reg_file.line_trace() # return f'<{s.reg_file.wen[0]}>{s.reg_file.waddr[0]}:{s.reg_file.wdata[0]}|{s.reg_file.raddr[0]}:{s.reg_file.rdata[0]}' - rf_trace = 
f'<{s.reg_file.wen[0]}>{s.reg_file.waddr[0]}:{s.reg_file.wdata[0]}|{s.reg_file.raddr[0]}:{s.reg_file.rdata[0]}' - - return f'[{s.recv_wdata[0].en & s.recv_waddr[0].en}]{s.recv_waddr[0]}<{s.recv_wdata[0]}({rf_trace}){s.recv_raddr[0]}>{s.send_rdata[0]}' + # rf_trace = f'<{s.reg_file.wen[0]}>{s.reg_file.waddr[0]}:{s.reg_file.wdata[0]}|{s.reg_file.raddr[0]}:{s.reg_file.rdata[0]}' + # return f'[{s.recv_wdata[0].en & s.recv_waddr[0].en}]{s.recv_waddr[0]}<{s.recv_wdata[0]}({rf_trace}){s.recv_raddr[0]}>{s.send_rdata[0]}' diff --git a/systolic/test/SystolicCL_test.py b/systolic/test/SystolicCL_test.py index b951fc1..366b52e 100644 --- a/systolic/test/SystolicCL_test.py +++ b/systolic/test/SystolicCL_test.py @@ -25,6 +25,8 @@ import os +kMaxCycles = 6 + #------------------------------------------------------------------------- # Test harness #------------------------------------------------------------------------- @@ -43,7 +45,7 @@ def construct( s, DUT, FunctionUnit, FuList, DataType, PredicateType, s.dut = DUT( FunctionUnit, FuList, DataType, PredicateType, CtrlType, width, height, ctrl_mem_size, data_mem_size, - len( src_opt[0] ), 0, src_opt, + len( src_opt[0] ), kMaxCycles, src_opt, preload_data, preload_const ) for i in range( height-1 ): @@ -52,10 +54,9 @@ def construct( s, DUT, FunctionUnit, FuList, DataType, PredicateType, def line_trace( s ): return s.dut.line_trace() -def run_sim( test_harness, max_cycles=6 ): +def run_sim( test_harness, max_cycles = kMaxCycles ): test_harness.elaborate() test_harness.apply( DefaultPassGroup() ) - test_harness.sim_reset() # Run simulation ncycles = 0 @@ -106,7 +107,9 @@ def test_systolic_2x2(): FuInType = mk_bits( clog2( num_fu_in + 1 ) ) pickRegister = [ FuInType( x+1 ) for x in range( num_fu_in ) ] - src_opt = [[CtrlType( OPT_LD_CONST, b1( 0 ), pickRegister, [ + src_opt = [ + # Tile 0: + [CtrlType( OPT_LD_CONST, b1( 0 ), pickRegister, [ RouteType(5), RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(0), 
RouteType(0), RouteType(0)] ), CtrlType( OPT_LD_CONST, b1( 0 ), pickRegister, [ @@ -119,6 +122,7 @@ def test_systolic_2x2(): RouteType(5), RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(0)] ), ], + # Tile 1: [CtrlType( OPT_NAH, b1( 0 ), pickRegister, [ RouteType(5), RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(0)] ), @@ -135,6 +139,7 @@ def test_systolic_2x2(): RouteType(5), RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(0)] ), ], + # Tile 2: [CtrlType( OPT_NAH, b1( 0 ), pickRegister, [ RouteType(2), RouteType(0), RouteType(0), RouteType(0), RouteType(2), RouteType(0), RouteType(0), RouteType(0)] ), @@ -148,6 +153,7 @@ def test_systolic_2x2(): RouteType(2), RouteType(0), RouteType(0), RouteType(5), RouteType(2), RouteType(0), RouteType(0), RouteType(0)] ), ], + # Tile 3: [CtrlType( OPT_NAH, b1( 0 ), pickRegister, [ RouteType(2), RouteType(0), RouteType(0), RouteType(0), RouteType(2), RouteType(0), RouteType(0), RouteType(0)] ), @@ -161,6 +167,7 @@ def test_systolic_2x2(): RouteType(2), RouteType(0), RouteType(0), RouteType(5), RouteType(2), RouteType(0), RouteType(3), RouteType(0)] ), ], + # Tile 4: [CtrlType( OPT_NAH, b1( 0 ), pickRegister, [ RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(2), RouteType(0), RouteType(0), RouteType(0)] ), @@ -174,6 +181,7 @@ def test_systolic_2x2(): RouteType(0), RouteType(0), RouteType(0), RouteType(5), RouteType(2), RouteType(0), RouteType(0), RouteType(0)] ), ], + # Tile 5: [CtrlType( OPT_NAH, b1( 0 ), pickRegister, [ RouteType(0), RouteType(0), RouteType(0), RouteType(0), RouteType(2), RouteType(0), RouteType(3), RouteType(0)] ),