diff --git a/vta/include/vta/hw_spec.h b/vta/include/vta/hw_spec.h index 5d105ab6d940e..4fc33e2047275 100644 --- a/vta/include/vta/hw_spec.h +++ b/vta/include/vta/hw_spec.h @@ -35,6 +35,8 @@ extern "C" { #define VTA_OUT_WIDTH (1 << VTA_LOG_OUT_WIDTH) /*! Accumulator data type width */ #define VTA_ACC_WIDTH (1 << VTA_LOG_ACC_WIDTH) +/*! Accumulator truncation bits */ +#define VTA_ACC_TRUC_BITS 24 /*! log2 of ALU data type width */ #define VTA_LOG_ALU_WIDTH (VTA_LOG_ACC_WIDTH - 1) /*! ALU data type width */ diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index 4e1a676f95dd1..05345a96f0de5 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -19,6 +19,9 @@ ['batch', 'height', 'width', 'in_filter', 'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) +_SCHEDULE_STR_MAP = {} + + def find_schedules(layer, vt_only=False, best_only=False): """ Returns a schedule for a given a layer. @@ -414,6 +417,11 @@ def _traverse(op): else: pad_data = None wrkld = _get_workload(data, pad_data, kernel, output) + + if wrkld in _SCHEDULE_STR_MAP and planStr is None: + planStr = _SCHEDULE_STR_MAP[wrkld] + logging.info("Apply pre-cached schedule for %s->%s", str(wrkld) , planStr) + if planStr: matchObj = re.match( r'b(\d+)_oc(\d+)_ic(\d+)_h(\d+)_w(\d+)_oct(\d+)_ht(\d+)', planStr) b_factor = int(matchObj.group(1)) diff --git a/vta/src/sim/sim_driver.cc b/vta/src/sim/sim_driver.cc index b9cbd22c8adbd..0a9d2c2318d19 100644 --- a/vta/src/sim/sim_driver.cc +++ b/vta/src/sim/sim_driver.cc @@ -423,6 +423,14 @@ class Device { } } + int32_t IntTrunc(int32_t value, int32_t bits) { + if (bits >= 32) return value; + int leftbits = (32 - bits); + value = value & ((1 << bits) -1); + value = (value << leftbits) >> leftbits; + return value; + } + void RunGEMM(const VTAGemInsn* op) { if (!op->reset_reg) { prof_->gemm_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn); @@ -452,6 +460,7 @@ class Device { sum += inp.GetSigned(i * VTA_BLOCK_IN + k) * wgt.GetSigned(j * VTA_BLOCK_IN + k); + sum = IntTrunc(sum, VTA_ACC_TRUC_BITS); } acc.SetSigned(acc_offset, sum); } @@ -540,11 +549,13 @@ class Device { BitPacker dst(acc_.BeginPtr(dst_index)); BitPacker src(acc_.BeginPtr(src_index)); for (int k = 0; k < VTA_BLOCK_OUT; ++k) { + int32_t value; if (use_imm) { - dst.SetSigned(k, func(dst.GetSigned(k), op->imm)); + value = func(dst.GetSigned(k), op->imm); } else { - dst.SetSigned(k, func(dst.GetSigned(k), src.GetSigned(k))); + value = func(dst.GetSigned(k), src.GetSigned(k)); } + dst.SetSigned(k, IntTrunc(value, VTA_ACC_TRUC_BITS)); } } } diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_conv2d.py index 91ac6b5f6f264..b98d6696a46f4 100644 --- a/vta/tests/python/integration/test_benchmark_topi_conv2d.py +++ b/vta/tests/python/integration/test_benchmark_topi_conv2d.py @@ -238,7 +238,9 @@ def get_ref_data(): return a_np, w_np, b_np def verify(s, check_correctness): - mod = vta.build(s, [data, kernel_arg, bias, coeff, res], "ext_dev", + mod = vta.build(s, + [data, kernel_arg, bias, coeff, res], + "ext_dev", env.target_host, name="conv2d") temp = util.tempdir()