diff --git a/src/main/scala/lltriscv/core/Core.scala b/src/main/scala/lltriscv/core/Core.scala index cd05e91..79037b5 100644 --- a/src/main/scala/lltriscv/core/Core.scala +++ b/src/main/scala/lltriscv/core/Core.scala @@ -186,7 +186,7 @@ class CoreFrontend(config: CoreConfig) extends Module { iCache.io.flush <> io.iCacheFlush // Fetch - fetch.io.satp := io.satp + fetch.io.asid := io.satp(30, 22) fetch.io.update <> io.predictorUpdate fetch.io.itlb <> itlb.io.request fetch.io.icache <> iCache.io.request diff --git a/src/main/scala/lltriscv/core/DataType.scala b/src/main/scala/lltriscv/core/DataType.scala index 720286b..3c2a037 100644 --- a/src/main/scala/lltriscv/core/DataType.scala +++ b/src/main/scala/lltriscv/core/DataType.scala @@ -27,3 +27,9 @@ object DataType { def asid = UInt(9.W) // 9-bits address space ID def aByte = UInt(8.W) // A Byte } + +object CoreConstant { + val XLEN = 32 + val instructionLength = 4 + val compressInstructionLength = 2 +} diff --git a/src/main/scala/lltriscv/core/execute/Memory.scala b/src/main/scala/lltriscv/core/execute/Memory.scala index 85bdf87..445c4de 100644 --- a/src/main/scala/lltriscv/core/execute/Memory.scala +++ b/src/main/scala/lltriscv/core/execute/Memory.scala @@ -187,6 +187,7 @@ class MemoryExecuteStage extends Module { ) { io.out.bits.error := MemoryErrorCode.misaligned } + chisel3.util.BitPat io.out.bits.rd := inReg.rd io.out.bits.pc := inReg.pc diff --git a/src/main/scala/lltriscv/core/fetch/BranchPredictor.scala b/src/main/scala/lltriscv/core/fetch/BranchPredictor.scala index ece3b9d..0598e28 100644 --- a/src/main/scala/lltriscv/core/fetch/BranchPredictor.scala +++ b/src/main/scala/lltriscv/core/fetch/BranchPredictor.scala @@ -11,7 +11,7 @@ abstract class BranchPredictor(depth: Int) extends Module { val io = IO(new Bundle { val asid = Input(DataType.asid) - val resuest = Flipped(new BranchPredictorRequestIO()) + val request = Flipped(new BranchPredictorRequestIO()) val update = Flipped(new BranchPredictorUpdateIO()) }) } @@ -35,25 +35,25 @@ class TwoBitsBranchPredictor(depth: Int) extends BranchPredictor(depth) { // Output for (i <- 0 until 2) { // Default - io.resuest.out(i) := Mux(io.resuest.in(i).compress, io.resuest.in(i).pc + 2.U, io.resuest.in(i).pc + 4.U) + io.request.out(i) := Mux(io.request.in(i).compress, io.request.in(i).pc + 2.U, io.request.in(i).pc + 4.U) for ( j <- 0 until depth; k <- 0 until 2 ) { - when(io.resuest.in(i).pc === table(j)(k).pc && io.asid === table(j)(k).asid) { + when(io.request.in(i).pc === table(j)(k).pc && io.asid === table(j)(k).asid) { // Output logic table switch(table(j)(k).history) { is(History.NN) { // N - io.resuest.out(i) := Mux(io.resuest.in(i).compress, io.resuest.in(i).pc + 2.U, io.resuest.in(i).pc + 4.U) + io.request.out(i) := Mux(io.request.in(i).compress, io.request.in(i).pc + 2.U, io.request.in(i).pc + 4.U) } is(History.NT) { // T - io.resuest.out(i) := table(j)(k).address + io.request.out(i) := table(j)(k).address } is(History.TN) { // T - io.resuest.out(i) := table(j)(k).address + io.request.out(i) := table(j)(k).address } is(History.TT) { // T - io.resuest.out(i) := table(j)(k).address + io.request.out(i) := table(j)(k).address } } } diff --git a/src/main/scala/lltriscv/core/fetch/Fetch.scala b/src/main/scala/lltriscv/core/fetch/Fetch.scala index 231a766..0712464 100644 --- a/src/main/scala/lltriscv/core/fetch/Fetch.scala +++ b/src/main/scala/lltriscv/core/fetch/Fetch.scala @@ -2,23 +2,29 @@ package lltriscv.core.fetch import chisel3._ import chisel3.util._ -import lltriscv.core.DataType + 
+import lltriscv.core._ import lltriscv.core.decode.DecodeStageEntry -import lltriscv.utils.ChiselUtils._ import lltriscv.core.record.TLBRequestIO -import lltriscv.cache.ICacheLineRequestIO -import lltriscv.utils.CoreUtils import lltriscv.core.execute.MemoryErrorCode +import lltriscv.cache.ICacheLineRequestIO + +import lltriscv.utils.CoreUtils._ +import lltriscv.utils.ChiselUtils._ +import lltriscv.utils.Sv32 + /* * Instruction fetch * - * Instruction fetch is located behind the instruction cache and is mainly used to control PC logic, instruction validation, and branch prediction + * Instruction fetch is located behind the instruction cache and is mainly used to control PC logic, instruction validation and branch prediction. * * Copyright (C) 2024-2025 LoveLonelyTime */ /** Fetch components + * + * InstructionFetcher -> SpeculationStage(InstructionExtender, InstructionPredictor) -> InstructionQueue * * @param cacheLineDepth * Instruction cache line depth @@ -33,14 +39,15 @@ class Fetch(cacheLineDepth: Int, queueDepth: Int, predictorDepth: Int, pcInit: I val io = IO(new Bundle { val itlb = new TLBRequestIO() val icache = new ICacheLineRequestIO(cacheLineDepth) - val out = DecoupledIO(Vec(2, new DecodeStageEntry())) - // Prediction failure and correction PC + val out = DecoupledIO(Vec2(new DecodeStageEntry())) + // Address space ID + val asid = Input(DataType.asid) + // Predictor update interface + val update = Flipped(new BranchPredictorUpdateIO()) + // Correction PC val correctPC = Input(DataType.address) // Recovery interface val recover = Input(Bool()) - - val satp = Input(DataType.operation) - val update = Flipped(new BranchPredictorUpdateIO()) }) private val instructionFetcher = Module(new InstructionFetcher(cacheLineDepth)) @@ -56,7 +63,7 @@ class Fetch(cacheLineDepth: Int, queueDepth: Int, predictorDepth: Int, pcInit: I instructionQueue.io.enq <> speculationStage.io.out instructionQueue.io.deq <> io.out - speculationStage.io.satp := io.satp + speculationStage.io.asid := io.asid speculationStage.io.update <> io.update speculationStage.io.correctPC := io.correctPC @@ -67,239 +74,255 @@ class Fetch(cacheLineDepth: Int, queueDepth: Int, predictorDepth: Int, pcInit: I /** Instruction fetcher * - * Caching two lines at once can perform cross line instructions concatenation. + * Caching two lines at once to connect cross line instructions. + * + * Composed of two independent working state machines: ITLB state machine, ICache state machine. * - * Composed of two independent working state machines:ITLB state machine, ICache state machine. + * ITLB state machine is responsible for caching the PTE corresponding to the current and next cache line. * - * ITLB state machine is responsible for caching the PTE corresponding to the current CacheLine and the PTE corresponding to the next CacheLine. + * ICache state machine is responsible for caching the current and next cache line. * - * ICache state machine is responsible for caching the current and next CacheLine. 
+ * Maximum throughput: 2 32-bit/16-bit instructions per cycle
 *
 * @param cacheLineDepth
- *   cache line depth
+ *   Instruction cache line depth
 */
class InstructionFetcher(cacheLineDepth: Int) extends Module {
  require(cacheLineDepth % 2 == 0, "Instruction Cache line depth must be a multiple of 2")
  require(cacheLineDepth >= 4, "Instruction Cache line depth must be greater than or equal 4")
  val io = IO(new Bundle {
-    val out = Output(Vec(2, new RawInstructionEntry()))
+    val out = Output(Vec2(new RawInstructionEntry()))
+    // The PC of instructions fetching
    val pc = Input(DataType.address)
+    // Instruction TLB request interface
    val itlb = new TLBRequestIO()
+    // Instruction cache request interface
    val icache = new ICacheLineRequestIO(cacheLineDepth)
    val recover = Input(Bool())
  })
-  // Buffer
-  private val itlbWorkReg = RegInit(Vec(2, new ITLBWorkEntry()).zero)
-  private val lineWorkReg = RegInit(Vec(2, new ICacheLineWorkEntry()).zero)
+  // Buffers
+  private val itlbWorkReg = RegInit(Vec2(new ITLBWorkEntry()).zero)
+  private val cacheLineWorkReg = RegInit(Vec2(new ICacheLineWorkEntry()).zero)
+
+  /* The structure of a cache address:
+     CacheLineAddress | CacheLineOffset          | byte offset
+     Remainder        | log2Ceil(cacheLineDepth) | 1
+   */
-  // CacheLineAddress | CacheLineOffset | byte offset
-  // Remainder | log2Ceil(cacheLineDepth) | 1
-  private def getCacheLineAddress(address: UInt) = address(31, log2Ceil(cacheLineDepth) + 1) ## 0.U((log2Ceil(cacheLineDepth) + 1).W)
-  private def getNextCacheLineAddress(address: UInt) = (address(31, log2Ceil(cacheLineDepth) + 1) + 1.U) ## 0.U((log2Ceil(cacheLineDepth) + 1).W)
-  private def getCacheLineOffset(address: UInt) = address(log2Ceil(cacheLineDepth), 1)
+  // Address helper functions
+  private def getCacheLineAddress(address: UInt) =
+    address(CoreConstant.XLEN - 1, log2Ceil(cacheLineDepth) + 1) ## 0.U((log2Ceil(cacheLineDepth) + 1).W)
-  private val lineMatch = VecInit.fill(2)(false.B)
-  for (i <- 0 until 2) lineMatch(i) := lineWorkReg(i).valid && lineWorkReg(i).pc === getCacheLineAddress(io.pc)
-  private val nextLineMatch = VecInit.fill(2)(false.B)
-  for (i <- 0 until 2) nextLineMatch(i) := lineWorkReg(i).valid && lineWorkReg(i).pc === getNextCacheLineAddress(io.pc)
+  private def getNextCacheLineAddress(address: UInt) =
+    (address(CoreConstant.XLEN - 1, log2Ceil(cacheLineDepth) + 1) + 1.U) ## 0.U((log2Ceil(cacheLineDepth) + 1).W)
-  private def isInBoundary(address: UInt) = getCacheLineOffset(address) === (cacheLineDepth - 1).U
+  private def getCacheLineOffset(address: UInt) =
+    address(log2Ceil(cacheLineDepth), 1)
-  // TLB logic
-  private val itlbMatch = VecInit.fill(2)(false.B)
-  for (i <- 0 until 2) itlbMatch(i) := itlbWorkReg(i).valid && itlbWorkReg(i).vpn === getCacheLineAddress(io.pc)(31, 12)
-  private val nextITLBMatch = VecInit.fill(2)(false.B)
-  for (i <- 0 until 2) nextITLBMatch(i) := itlbWorkReg(i).valid && itlbWorkReg(i).vpn === getNextCacheLineAddress(io.pc)(31, 12)
+  private def isInBoundary(address: UInt) =
+    getCacheLineOffset(address) === (cacheLineDepth - 1).U
-  private object ITLBStatus extends ChiselEnum {
-    val idle, req = Value
+  // FSM status
+  private object Status extends ChiselEnum {
+    val idle, request = Value
  }
-  private val itlbStatusReg = RegInit(ITLBStatus.idle)
-  private val itlbQueryAddressReg = RegInit(DataType.address.zeroAsUInt)
-  private val itlbVictim = RegInit(0.U)
-  private val itlbTaskFlag = RegInit(false.B)
-
-  when(itlbStatusReg === ITLBStatus.idle) {
-    when(!itlbMatch.reduceTree(_ || _)) { // ITLB missing
-      itlbQueryAddressReg := getCacheLineAddress(io.pc)
-      itlbVictim := 0.U
-      itlbTaskFlag := true.B
-      itlbStatusReg := ITLBStatus.req
-    }.elsewhen(!nextITLBMatch.reduceTree(_ || _)) { // Next ITLB missing
-      itlbQueryAddressReg := getNextCacheLineAddress(io.pc)
-      itlbVictim := itlbMatch(0) // Victim is another one
-      itlbTaskFlag := true.B
-      itlbStatusReg := ITLBStatus.req
-    }
+  /* ---------------- ITLB FSM start ---------------- */
+
+  // Match current ITLB
+  private val itlbMatch = VecInit2(false.B)
+  itlbMatch.zipWithIndex.foreach { case (item, i) =>
+    item := itlbWorkReg(i).valid && itlbWorkReg(i).vpn === Sv32.getVPN(getCacheLineAddress(io.pc))
+  }
+
+  // Match next ITLB
+  private val nextITLBMatch = VecInit2(false.B)
+  nextITLBMatch.zipWithIndex.foreach { case (item, i) =>
+    item := itlbWorkReg(i).valid && itlbWorkReg(i).vpn === Sv32.getVPN(getNextCacheLineAddress(io.pc))
  }
-  io.itlb.valid := false.B
-  io.itlb.vaddress := 0.U
-  io.itlb.write := false.B
-  when(itlbStatusReg === ITLBStatus.req) {
-    io.itlb.vaddress := itlbQueryAddressReg
-    io.itlb.valid := true.B
-    when(io.itlb.ready) {
-      when(itlbTaskFlag) {
-        // Not consider 4MiB pages
-        itlbWorkReg(itlbVictim).vpn := itlbQueryAddressReg(31, 12)
-        itlbWorkReg(itlbVictim).ppn := io.itlb.paddress(31, 12)
-        itlbWorkReg(itlbVictim).error := io.itlb.error
-        itlbWorkReg(itlbVictim).valid := true.B
+  private val itlbStatusReg = RegInit(Status.idle)
+  private val itlbQueryAddressReg = RegInit(DataType.address.zeroAsUInt)
+  private val itlbVictim = RegInit(0.U) // Victim
+  private val itlbTaskFlag = RegInit(false.B) // Undo
+
+  io.itlb <> new TLBRequestIO().zero
+
+  switch(itlbStatusReg) {
+    is(Status.idle) {
+      when(!itlbMatch.reduceTree(_ || _)) { // ITLB missing
+        itlbQueryAddressReg := getCacheLineAddress(io.pc)
+        itlbVictim := 0.U
+        itlbTaskFlag := true.B
+        itlbStatusReg := Status.request // Request
+      }.elsewhen(!nextITLBMatch.reduceTree(_ || _)) { // Next ITLB missing
+        itlbQueryAddressReg := getNextCacheLineAddress(io.pc)
+        itlbVictim := itlbMatch(0) // Victim is another one
+        itlbTaskFlag := true.B
+        itlbStatusReg := Status.request // Request
+      }
+    }
+
+    is(Status.request) {
+      io.itlb.vaddress := itlbQueryAddressReg
+      io.itlb.valid := true.B
+      when(io.itlb.ready) {
+        when(itlbTaskFlag) {
+          // Not consider 4MiB pages
+          itlbWorkReg(itlbVictim).vpn := Sv32.getVPN(itlbQueryAddressReg)
+          itlbWorkReg(itlbVictim).ppn := Sv32.get32PPN(io.itlb.paddress)
+          itlbWorkReg(itlbVictim).error := io.itlb.error
+          itlbWorkReg(itlbVictim).valid := true.B
        }
-      itlbStatusReg := ITLBStatus.idle // Return
+        itlbStatusReg := Status.idle // Return
+      }
    }
  }
-  // Cache logic
-  private object ICacheStatus extends ChiselEnum {
-    val idle, req = Value
+  /* ---------------- ITLB FSM end ---------------- */
+
+  /* ---------------- ICache FSM start ---------------- */
+
+  // Match current cache line
+  private val cacheLineMatch = VecInit2(false.B)
+  cacheLineMatch.zipWithIndex.foreach { case (item, i) =>
+    item := cacheLineWorkReg(i).valid && cacheLineWorkReg(i).address === getCacheLineAddress(io.pc)
+  }
+
+  // Match next cache line
+  private val nextCacheLineMatch = VecInit2(false.B)
+  nextCacheLineMatch.zipWithIndex.foreach { case (item, i) =>
+    item := cacheLineWorkReg(i).valid && cacheLineWorkReg(i).address === getNextCacheLineAddress(io.pc)
  }
-  private val iCacheStatusReg = RegInit(ICacheStatus.idle)
+
+  private val iCacheStatusReg = RegInit(Status.idle)
  private val iCacheQueryVAddressReg = RegInit(DataType.address.zeroAsUInt)
  private val 
iCacheQueryPAddressReg = RegInit(DataType.address.zeroAsUInt) private val iCacheVictim = RegInit(0.U) private val iCacheTaskFlag = RegInit(false.B) - when(iCacheStatusReg === ICacheStatus.idle) { - when(!lineMatch.reduceTree(_ || _)) { // Cache line missing - for (i <- 0 until 2) { - when(itlbMatch(i)) { // ITLB exists - when(itlbWorkReg(i).error === MemoryErrorCode.none) { // OK - iCacheQueryVAddressReg := getCacheLineAddress(io.pc) - iCacheQueryPAddressReg := itlbWorkReg(i).ppn(19, 0) ## getCacheLineAddress(io.pc)(11, 0) - iCacheVictim := 0.U - iCacheTaskFlag := true.B - iCacheStatusReg := ICacheStatus.req - }.otherwise { // Error - lineWorkReg(0).pc := getCacheLineAddress(io.pc) - lineWorkReg(0).error := itlbWorkReg(i).error - lineWorkReg(0).valid := true.B + io.icache <> new ICacheLineRequestIO(cacheLineDepth).zero + + switch(iCacheStatusReg) { + is(Status.idle) { + when(!cacheLineMatch.reduceTree(_ || _)) { // Cache line missing + itlbMatch.zip(itlbWorkReg).foreach { case (matched, item) => + when(matched) { // ITLB exists + when(item.error === MemoryErrorCode.none) { // OK + iCacheQueryVAddressReg := getCacheLineAddress(io.pc) + iCacheQueryPAddressReg := item.ppn ## Sv32.getOffset(getCacheLineAddress(io.pc)) + iCacheVictim := 0.U + iCacheTaskFlag := true.B + iCacheStatusReg := Status.request // Request + }.otherwise { // ITLB error + cacheLineWorkReg(0).address := getCacheLineAddress(io.pc) + cacheLineWorkReg(0).error := item.error + cacheLineWorkReg(0).valid := true.B + } } } - } - }.elsewhen(!nextLineMatch.reduceTree(_ || _)) { // Next ITLB missing - val nextVictim = lineMatch(0) // Victim is another one - for (i <- 0 until 2) { - when(nextITLBMatch(i)) { // Next ITLB exists - when(itlbWorkReg(i).error === MemoryErrorCode.none) { // OK - iCacheQueryVAddressReg := getNextCacheLineAddress(io.pc) - iCacheQueryPAddressReg := itlbWorkReg(i).ppn(19, 0) ## getNextCacheLineAddress(io.pc)(11, 0) - iCacheVictim := nextVictim - iCacheTaskFlag := true.B - iCacheStatusReg := ICacheStatus.req - }.otherwise { - lineWorkReg(nextVictim).pc := getNextCacheLineAddress(io.pc) - lineWorkReg(nextVictim).error := itlbWorkReg(i).error - lineWorkReg(nextVictim).valid := true.B + }.elsewhen(!nextCacheLineMatch.reduceTree(_ || _)) { // Next ITLB missing + val nextVictim = cacheLineMatch(0) // Victim is another one + nextITLBMatch.zip(itlbWorkReg).foreach { case (matched, item) => + when(matched) { // Next ITLB exists + when(item.error === MemoryErrorCode.none) { // OK + iCacheQueryVAddressReg := getNextCacheLineAddress(io.pc) + iCacheQueryPAddressReg := item.ppn ## Sv32.getOffset(getNextCacheLineAddress(io.pc)) + iCacheVictim := nextVictim + iCacheTaskFlag := true.B + iCacheStatusReg := Status.request // Request + }.otherwise { // ITLB error + cacheLineWorkReg(nextVictim).address := getNextCacheLineAddress(io.pc) + cacheLineWorkReg(nextVictim).error := item.error + cacheLineWorkReg(nextVictim).valid := true.B + } } } } } - } - io.icache.valid := false.B - io.icache.address := 0.U - when(iCacheStatusReg === ICacheStatus.req) { - io.icache.address := iCacheQueryPAddressReg - io.icache.valid := true.B - - when(io.icache.ready) { - when(iCacheTaskFlag) { - lineWorkReg(iCacheVictim).pc := iCacheQueryVAddressReg - lineWorkReg(iCacheVictim).content := io.icache.data - lineWorkReg(iCacheVictim).error := Mux(io.icache.error, MemoryErrorCode.memoryFault, MemoryErrorCode.none) - lineWorkReg(iCacheVictim).valid := true.B - } + is(Status.request) { + io.icache.address := iCacheQueryPAddressReg + io.icache.valid := true.B 
+ + when(io.icache.ready) { + when(iCacheTaskFlag) { + cacheLineWorkReg(iCacheVictim).address := iCacheQueryVAddressReg + cacheLineWorkReg(iCacheVictim).content := io.icache.data + cacheLineWorkReg(iCacheVictim).error := Mux(io.icache.error, MemoryErrorCode.memoryFault, MemoryErrorCode.none) + cacheLineWorkReg(iCacheVictim).valid := true.B + } - iCacheStatusReg := ICacheStatus.idle // Return + iCacheStatusReg := Status.idle // Return + } } } - io.out.foreach(_ := new RawInstructionEntry().zero) - - // Merge logic - for (i <- 0 until 2) { - when(lineMatch(i)) { // i match PC - val pcValues = Wire(Vec(2, DataType.address)) - // Get 0 pc - pcValues(0) := io.pc - - // Get offset - val offset = Wire(Vec(2, UInt(cacheLineDepth.W))) - for (i <- 0 until 2) - offset(i) := getCacheLineOffset(pcValues(i)) - - val compress = VecInit(false.B, false.B) - // Get 0 compress - compress(0) := lineWorkReg(i).content(offset(0))(1, 0) =/= "b11".U - // Get 1 pc - pcValues(1) := Mux(compress(0), pcValues(0) + 2.U, pcValues(0) + 4.U) - - // Get next i - val nextI = Mux(getCacheLineAddress(pcValues(0)) === getCacheLineAddress(pcValues(1)), i.U, (1 - i).U) - - // Get 0 compress - compress(1) := lineWorkReg(nextI).content(offset(1))(1, 0) =/= "b11".U - - // Output - when(lineWorkReg(i).error =/= MemoryErrorCode.none) { // Error - io.out(0).error := lineWorkReg(i).error - io.out(0).valid := true.B - }.elsewhen(compress(0)) { // 16-bits OK - io.out(0).instruction := lineWorkReg(i).content(offset(0)) - io.out(0).compress := true.B - io.out(0).valid := true.B - }.elsewhen(!isInBoundary(pcValues(0))) { // 32-bits OK - io.out(0).instruction := lineWorkReg(i).content(offset(0) + 1.U) ## lineWorkReg(i).content(offset(0)) - io.out(0).compress := false.B - io.out(0).valid := true.B - }.elsewhen(nextLineMatch(1 - i)) { // 32-bits, crossing - when(lineWorkReg(1 - i).error =/= MemoryErrorCode.none) { // Error - io.out(0).error := lineWorkReg(1 - i).error - io.out(0).valid := true.B - }.otherwise { // OK - io.out(0).instruction := lineWorkReg(1 - i).content(0) ## lineWorkReg(i).content(offset(0)) - io.out(0).compress := false.B - io.out(0).valid := true.B + /* ---------------- ICache FSM end ---------------- */ + + // Merge output logic + io.out := Vec2(new RawInstructionEntry()).zero + + cacheLineMatch + .zip(nextCacheLineMatch.reverse) + .zip(cacheLineWorkReg.zip(cacheLineWorkReg.reverse)) + .foreach { case ((matched, oppositeMatched), (item, opposite)) => + when(matched) { // Cache line exists + // Collect the location of two instructions + val cachelines = Wire(Vec2(new ICacheLineWorkEntry())) + val pcValues = Wire(Vec2(DataType.address)) + val offsetValues = pcValues.map(getCacheLineOffset(_)) + val compress = cachelines.zip(offsetValues).map { case (cacheLine, offset) => + isCompressInstruction(cacheLine.content(offset)) } - } - when(lineWorkReg(i).error =/= MemoryErrorCode.none || lineWorkReg(nextI).pc =/= getCacheLineAddress(pcValues(1))) { // Skip - io.out(1).valid := false.B - }.elsewhen(lineWorkReg(nextI).error =/= MemoryErrorCode.none) { // Error - io.out(1).error := lineWorkReg(nextI).error - io.out(1).valid := true.B - }.elsewhen(compress(1)) { // 16-bits OK - io.out(1).instruction := lineWorkReg(nextI).content(offset(1)) - io.out(1).compress := true.B - io.out(1).valid := true.B - }.elsewhen(!isInBoundary(pcValues(1))) { // 32-bits OK - io.out(1).instruction := lineWorkReg(nextI).content(offset(1) + 1.U) ## lineWorkReg(nextI).content(offset(1)) - io.out(1).compress := false.B - io.out(1).valid := true.B - 
}.elsewhen(nextLineMatch(1 - i)) { // 32-bits, crossing - when(lineWorkReg(1 - i).error =/= MemoryErrorCode.none) { // Error - io.out(1).error := lineWorkReg(1 - i).error - io.out(1).valid := true.B - }.otherwise { // OK - io.out(1).instruction := lineWorkReg(1 - i).content(0) ## lineWorkReg(i).content(offset(1)) - io.out(1).compress := false.B - io.out(1).valid := true.B + pcValues(0) := io.pc + cachelines(0) := item + // The PC of next instruction + pcValues(1) := pcValues(0) + Mux(compress(0), CoreConstant.compressInstructionLength.U, CoreConstant.instructionLength.U) + cachelines(1) := Mux(getCacheLineAddress(pcValues(0)) === getCacheLineAddress(pcValues(1)), item, opposite) + + // Output + io.out.zipWithIndex.foreach { case (out, i) => + when(!cachelines(i).valid || cachelines(i).address =/= getCacheLineAddress(pcValues(i))) { // Skip + out.valid := false.B + }.elsewhen(cachelines(i).error =/= MemoryErrorCode.none) { // Error + out.error := cachelines(i).error + out.valid := true.B + }.elsewhen(compress(i)) { // 16-bits OK + out.instruction := cachelines(i).content(offsetValues(i)) + out.compress := true.B + out.valid := true.B + }.elsewhen(!isInBoundary(pcValues(i))) { // 32-bits OK + out.instruction := cachelines(i).content(offsetValues(i) + 1.U) ## cachelines(i).content(offsetValues(i)) + out.compress := false.B + out.valid := true.B + }.elsewhen(oppositeMatched) { // 32-bits, crossing cache line + when(opposite.error =/= MemoryErrorCode.none) { // Error + out.error := opposite.error + out.valid := true.B + }.otherwise { // OK + out.instruction := opposite.content(0) ## cachelines(i).content(offsetValues(i)) + out.compress := false.B + out.valid := true.B + } + } + } + // When the first instruction is error, the second instruction cannot be valid + when(cachelines(0).error =/= MemoryErrorCode.none) { + io.out(1).valid := false.B } } } - } // Recover logic when(io.recover) { // Clear buffers itlbWorkReg.foreach(_.valid := false.B) - lineWorkReg.foreach(_.valid := false.B) + cacheLineWorkReg.foreach(_.valid := false.B) // Undo cache tasks itlbTaskFlag := false.B iCacheTaskFlag := false.B @@ -308,58 +331,56 @@ class InstructionFetcher(cacheLineDepth: Int) extends Module { /** Instruction extender * - * Expanding 16 bits instructions to 32 bits - * - * C Extensions + * Expanding 16 bits(C Extensions) instructions to 32 bits */ class InstructionExtender extends Module { val io = IO(new Bundle { - val in = Input(Vec(2, new RawInstructionEntry())) - val out = Output(Vec(2, new RawInstructionEntry())) + val in = Input(Vec2(new RawInstructionEntry())) + val out = Output(Vec2(new RawInstructionEntry())) }) - def extend(instructionIn: UInt) = { - + // Extending function + private def extend(instructionIn: UInt) = { val instructionOut = WireInit(instructionIn) - val imm = WireInit(0.U(32.W)) - val eimm = WireInit(0.U(32.W)) - val rs1 = WireInit(0.U(5.W)) - val rs2 = WireInit(0.U(5.W)) - val rd = WireInit(0.U(5.W)) + val imm = WireInit(DataType.immediate.zeroAsUInt) + val simm = WireInit(DataType.immediate.zeroAsUInt) + val rs1 = WireInit(DataType.register.zeroAsUInt) + val rs2 = WireInit(DataType.register.zeroAsUInt) + val rd = WireInit(DataType.register.zeroAsUInt) switch(instructionIn(1, 0)) { is("b01".U) { switch(instructionIn(15, 13)) { is("b000".U) { // c.nop, c.addi imm := instructionIn(12) ## instructionIn(6, 2) - eimm := CoreUtils.signExtended(imm, 5) - instructionOut := eimm(11, 0) ## instructionIn(11, 7) ## "b000".U(3.W) ## instructionIn(11, 7) ## "b0010011".U(7.W) + simm := 
signExtended(imm, 5) + instructionOut := simm(11, 0) ## instructionIn(11, 7) ## "b000".U(3.W) ## instructionIn(11, 7) ## "b0010011".U(7.W) } is("b001".U) { // c.jal imm := instructionIn(12) ## instructionIn(8) ## instructionIn(10, 9) ## instructionIn(6) ## instructionIn(7) ## instructionIn(2) ## instructionIn(11) ## instructionIn(5, 3) ## 0.U - eimm := CoreUtils.signExtended(imm, 11) - instructionOut := eimm(20) ## eimm(10, 1) ## eimm(11) ## eimm(19, 12) ## "b00001".U(5.W) ## "b1101111".U(7.W) + simm := signExtended(imm, 11) + instructionOut := simm(20) ## simm(10, 1) ## simm(11) ## simm(19, 12) ## "b00001".U(5.W) ## "b1101111".U(7.W) } is("b010".U) { // c.li imm := instructionIn(12) ## instructionIn(6, 2) - eimm := CoreUtils.signExtended(imm, 5) - instructionOut := eimm(11, 0) ## "b00000".U(5.W) ## "b000".U(3.W) ## instructionIn(11, 7) ## "b0010011".U(7.W) + simm := signExtended(imm, 5) + instructionOut := simm(11, 0) ## "b00000".U(5.W) ## "b000".U(3.W) ## instructionIn(11, 7) ## "b0010011".U(7.W) } is("b011".U) { // c.addi16sp, c.lui when(instructionIn(11, 7) === "b00010".U) { // c.addi16sp imm := instructionIn(12) ## instructionIn(4, 3) ## instructionIn(5) ## instructionIn(2) ## instructionIn(6) ## "b0000".U(4.W) - eimm := CoreUtils.signExtended(imm, 9) - instructionOut := eimm(11, 0) ## instructionIn(11, 7) ## "b000".U(3.W) ## instructionIn(11, 7) ## "b0010011".U(7.W) + simm := signExtended(imm, 9) + instructionOut := simm(11, 0) ## instructionIn(11, 7) ## "b000".U(3.W) ## instructionIn(11, 7) ## "b0010011".U(7.W) }.otherwise { // c.lui imm := instructionIn(12) ## instructionIn(6, 2) ## "b000000000000".U(12.W) - eimm := CoreUtils.signExtended(imm, 17) - instructionOut := eimm(31, 12) ## instructionIn(11, 7) ## "b0110111".U(7.W) + simm := signExtended(imm, 17) + instructionOut := simm(31, 12) ## instructionIn(11, 7) ## "b0110111".U(7.W) } } is("b100".U) { // c.srli, c.srai, c.andi, c.sub, c.xor, c.or, c.and imm := instructionIn(12) ## instructionIn(6, 2) - eimm := CoreUtils.signExtended(imm, 5) + simm := signExtended(imm, 5) rs1 := "b01".U ## instructionIn(9, 7) rs2 := "b01".U ## instructionIn(4, 2) switch(instructionIn(11, 10)) { @@ -370,7 +391,7 @@ class InstructionExtender extends Module { instructionOut := "b0100000".U(7.W) ## imm(4, 0) ## rs1 ## "b101".U ## rs1 ## "b0010011".U(7.W) } is("b10".U) { // c.andi - instructionOut := eimm(11, 0) ## rs1 ## "b111".U ## rs1 ## "b0010011".U(7.W) + instructionOut := simm(11, 0) ## rs1 ## "b111".U ## rs1 ## "b0010011".U(7.W) } is("b11".U) { // c.sub, c.xor, c.or, c.and switch(instructionIn(6, 5)) { @@ -393,22 +414,22 @@ class InstructionExtender extends Module { is("b101".U) { // c.j imm := instructionIn(12) ## instructionIn(8) ## instructionIn(10, 9) ## instructionIn(6) ## instructionIn(7) ## instructionIn(2) ## instructionIn(11) ## instructionIn(5, 3) ## 0.U - eimm := CoreUtils.signExtended(imm, 11) - instructionOut := eimm(20) ## eimm(10, 1) ## eimm(11) ## eimm(19, 12) ## "b00000".U(5.W) ## "b1101111".U(7.W) + simm := signExtended(imm, 11) + instructionOut := simm(20) ## simm(10, 1) ## simm(11) ## simm(19, 12) ## "b00000".U(5.W) ## "b1101111".U(7.W) } is("b110".U) { // c.beqz rs1 := "b01".U ## instructionIn(9, 7) imm := instructionIn(12) ## instructionIn(6, 5) ## instructionIn(2) ## instructionIn(11, 10) ## instructionIn(4, 3) ## 0.U - eimm := CoreUtils.signExtended(imm, 8) - instructionOut := eimm(12) ## eimm(10, 5) ## "b00000".U(5.W) ## rs1 ## "b000".U(3.W) ## eimm(4, 1) ## eimm(11) ## "b1100011".U(7.W) + simm := signExtended(imm, 8) + 
instructionOut := simm(12) ## simm(10, 5) ## "b00000".U(5.W) ## rs1 ## "b000".U(3.W) ## simm(4, 1) ## simm(11) ## "b1100011".U(7.W) } is("b111".U) { // c.bnez rs1 := "b01".U ## instructionIn(9, 7) imm := instructionIn(12) ## instructionIn(6, 5) ## instructionIn(2) ## instructionIn(11, 10) ## instructionIn(4, 3) ## 0.U - eimm := CoreUtils.signExtended(imm, 8) - instructionOut := eimm(12) ## eimm(10, 5) ## "b00000".U(5.W) ## rs1 ## "b001".U(3.W) ## eimm(4, 1) ## eimm(11) ## "b1100011".U(7.W) + simm := signExtended(imm, 8) + instructionOut := simm(12) ## simm(10, 5) ## "b00000".U(5.W) ## rs1 ## "b001".U(3.W) ## simm(4, 1) ## simm(11) ## "b1100011".U(7.W) } } } @@ -479,71 +500,93 @@ class InstructionExtender extends Module { instructionOut // Return } - for (i <- 0 until 2) { - io.out(i).error := io.in(i).error - io.out(i).compress := io.in(i).compress - io.out(i).instruction := Mux(io.in(i).compress, extend(io.in(i).instruction), io.in(i).instruction) - io.out(i).valid := io.in(i).valid + // Output + io.out.zip(io.in).foreach { case (out, in) => + out.error := in.error + out.compress := in.compress + out.instruction := Mux(in.compress, extend(in.instruction), in.instruction) + out.valid := in.valid } } /** Instruction predictor + * + * Using predictor FSM for instruction prediction + * * @param depth * Branch predictor table depth */ class InstructionPredictor(depth: Int) extends Module { val io = IO(new Bundle { - val in = Input(Vec(2, new RawInstructionEntry())) + val in = Input(Vec2(new RawInstructionEntry())) + val out = Output(Vec2(new SpeculativeEntry())) + // The PC of instructions to be predicted val pc = Input(DataType.address) - val out = Output(Vec(2, new SpeculativeEntry())) - val nextPC = Output(DataType.address) - + // Address space ID val asid = Input(DataType.asid) + // Predicted next PC + val nextPC = Output(DataType.address) + // Predictor update interface val update = Flipped(new BranchPredictorUpdateIO()) }) - private val pcValues = Wire(Vec(2, DataType.address)) - pcValues(0) := io.pc - pcValues(1) := Mux(io.in(0).compress, io.pc + 2.U, io.pc + 4.U) - private val nextPCValues = Wire(Vec(2, DataType.address)) - nextPCValues(0) := pcValues(1) - nextPCValues(1) := Mux(io.in(1).compress, pcValues(1) + 2.U, pcValues(1) + 4.U) + // Instruction address + private val pcValues = VecInit( + io.pc, + io.pc + Mux(io.in(0).compress, CoreConstant.compressInstructionLength.U, CoreConstant.instructionLength.U) + ) - // Prediction + // Next instruction address + private val nextPCValues = VecInit( + pcValues(1), + pcValues(1) + Mux(io.in(1).compress, CoreConstant.compressInstructionLength.U, CoreConstant.instructionLength.U) + ) + + // Predictor FSM + // Here, You can choose different implementations private val branchPredictor = Module(new TwoBitsBranchPredictor(depth)) - private val specValues = Wire(Vec(2, DataType.address)) - for (i <- 0 until 2) { - branchPredictor.io.resuest.in(i).pc := pcValues(i) - branchPredictor.io.resuest.in(i).compress := io.in(i).compress - specValues(i) := branchPredictor.io.resuest.out(i) - } branchPredictor.io.asid := io.asid branchPredictor.io.update <> io.update - for (i <- 0 until 2) { - io.out(i).instruction := io.in(i).instruction - io.out(i).pc := pcValues(i) - io.out(i).spec := specValues(i) - io.out(i).next := nextPCValues(i) - io.out(i).error := io.in(i).error - io.out(i).valid := io.in(i).valid + private val specValues = Wire(Vec2(DataType.address)) + specValues.zipWithIndex.foreach { case (spec, i) => + branchPredictor.io.request.in(i).pc := 
pcValues(i) + branchPredictor.io.request.in(i).compress := io.in(i).compress + spec := branchPredictor.io.request.out(i) + } + + // Output + io.out.zipWithIndex.foreach { case (out, i) => + out.instruction := io.in(i).instruction + out.pc := pcValues(i) + out.spec := specValues(i) + out.next := nextPCValues(i) + out.error := io.in(i).error + out.valid := io.in(i).valid + } + + // Next PC + private val grant = VecInit2(false.B) + grant(0) := io.in(0).valid + for (i <- 1 until grant.length) { + grant(i) := grant(i - 1) && io.in(i).valid && specValues(i - 1) === pcValues(i) + // Not in the prediction chain + when(!grant(i)) { io.out(i).valid := false.B } } - when(io.in(0).valid) { - when(io.in(1).valid && specValues(0) === pcValues(1)) { // 0 -> 1 -> spec - io.nextPC := specValues(1) - }.otherwise { // 0 -> spec, 1 is masked - io.nextPC := specValues(0) - io.out(1).valid := false.B + io.nextPC := io.pc + grant.zip(specValues).foreach { case (granted, spec) => + when(granted) { + io.nextPC := spec } - }.otherwise { - io.nextPC := io.pc } } /** Speculation stage * - * Extend instructions and predict + * Sample, extend and predict instructions + * + * Single cycle stage * * @param depth * Branch predictor table depth @@ -552,15 +595,15 @@ class InstructionPredictor(depth: Int) extends Module { */ class SpeculationStage(depth: Int, pcInit: Int) extends Module { val io = IO(new Bundle { - // Pipeline interface - val in = Input(Vec(2, new RawInstructionEntry())) - val out = DecoupledIO(Vec(2, new SpeculativeEntry())) - - val satp = Input(DataType.operation) + val in = Input(Vec2(new RawInstructionEntry())) + val out = DecoupledIO(Vec2(new SpeculativeEntry())) + // Address space ID + val asid = Input(DataType.asid) + // Predictor update interface val update = Flipped(new BranchPredictorUpdateIO()) // Current PC val pc = Output(DataType.address) - // Prediction failure and correction PC + // Correction PC val correctPC = Input(DataType.address) // Recovery interface val recover = Input(Bool()) @@ -570,14 +613,11 @@ class SpeculationStage(depth: Int, pcInit: Int) extends Module { private val pcReg = RegInit(pcInit.U(DataType.address.getWidth.W)) // Pipeline logic - private val inReg = RegInit(Vec(2, new RawInstructionEntry()).zero) - val extender = Module(new InstructionExtender()) - extender.io.in <> inReg - - val predictor = Module(new InstructionPredictor(depth)) - predictor.io.in <> extender.io.out + private val inReg = RegInit(Vec2(new RawInstructionEntry()).zero) + private val extender = Module(new InstructionExtender()) + private val predictor = Module(new InstructionPredictor(depth)) predictor.io.pc := pcReg - predictor.io.asid := io.satp(30, 22) + predictor.io.asid := io.asid predictor.io.update <> io.update when(io.out.fire) { // Sample @@ -585,6 +625,9 @@ class SpeculationStage(depth: Int, pcInit: Int) extends Module { pcReg := predictor.io.nextPC } + inReg <> extender.io.in + extender.io.out <> predictor.io.in + // Next PC io.pc := predictor.io.nextPC @@ -595,11 +638,14 @@ class SpeculationStage(depth: Int, pcInit: Int) extends Module { // Recovery logic when(io.recover) { inReg.foreach(_.valid := false.B) + // Correct PC pcReg := io.correctPC } } /** Instruction queue + * + * Buffering instructions, implementing with loop pointer * * @param depth * Instruction queue depth @@ -608,17 +654,17 @@ class InstructionQueue(depth: Int) extends Module { require(depth > 0, "Instruction queue depth must be greater than 0") val io = IO(new Bundle { - val enq = Flipped(DecoupledIO(Vec(2, new 
SpeculativeEntry()))) - val deq = DecoupledIO(Vec(2, new DecodeStageEntry())) + val enq = Flipped(DecoupledIO(Vec2(new SpeculativeEntry()))) + val deq = DecoupledIO(Vec2(new DecodeStageEntry())) val recover = Input(Bool()) }) - private val queue = RegInit(Vec(depth, Vec(2, new SpeculativeEntry())).zero) + private val queue = RegInit(Vec(depth, Vec2(new SpeculativeEntry())).zero) private val incrRead = WireInit(false.B) private val incrWrite = WireInit(false.B) - private val (readPtr, nextRead) = CoreUtils.pointer(depth, incrRead) - private val (writePtr, nextWrite) = CoreUtils.pointer(depth, incrWrite) + private val (readPtr, nextRead) = pointer(depth, incrRead) + private val (writePtr, nextWrite) = pointer(depth, incrWrite) private val emptyReg = RegInit(true.B) private val fullReg = RegInit(false.B) @@ -626,8 +672,9 @@ class InstructionQueue(depth: Int) extends Module { io.enq.ready := !fullReg io.deq.valid := !emptyReg - // Queue logic - private val op = (io.enq.valid && io.enq.ready) ## (io.deq.valid && io.deq.ready) + // Prevent invalid instructions from blocking queues to accelerate fetching instructions + private val validEnqueue = VecInit(io.enq.bits.map(_.valid)).reduceTree(_ || _) + private val op = (io.enq.fire && validEnqueue) ## (io.deq.fire) private val doWrite = WireDefault(false.B) switch(op) { @@ -650,7 +697,7 @@ class InstructionQueue(depth: Int) extends Module { } } - for (i <- 0 until 2) io.deq.bits <> queue(readPtr) + io.deq.bits <> queue(readPtr) // Write logic when(doWrite) { diff --git a/src/main/scala/lltriscv/core/fetch/FetchEntry.scala b/src/main/scala/lltriscv/core/fetch/FetchEntry.scala index a5161c2..a0d282c 100644 --- a/src/main/scala/lltriscv/core/fetch/FetchEntry.scala +++ b/src/main/scala/lltriscv/core/fetch/FetchEntry.scala @@ -21,7 +21,7 @@ class ITLBWorkEntry extends Bundle { class ICacheLineWorkEntry extends Bundle { val content = Vec(8, UInt(16.W)) - val pc = DataType.address + val address = DataType.address val error = MemoryErrorCode() val valid = Bool() } diff --git a/src/main/scala/lltriscv/utils/CoreUtils.scala b/src/main/scala/lltriscv/utils/CoreUtils.scala index 8f274be..567fb13 100644 --- a/src/main/scala/lltriscv/utils/CoreUtils.scala +++ b/src/main/scala/lltriscv/utils/CoreUtils.scala @@ -68,4 +68,15 @@ object CoreUtils { output.receipt := source.data } } + + def Vec2[T <: Data](data: T) = Vec(2, data) + def VecInit2[T <: Data](data: T) = VecInit.fill(2)(data) + + def isCompressInstruction(instruction: UInt) = instruction(1, 0) =/= "b11".U +} + +object Sv32 { + def getVPN(vaddr: UInt) = vaddr(31, 12) + def getOffset(vaddr: UInt) = vaddr(11, 0) + def get32PPN(paddr: UInt) = paddr(31, 12) } diff --git a/src/test/scala/lltriscv/test/riscvtests/RV32PTest.scala b/src/test/scala/lltriscv/test/riscvtests/RV32PTest.scala index 0b5caae..384e0ec 100644 --- a/src/test/scala/lltriscv/test/riscvtests/RV32PTest.scala +++ b/src/test/scala/lltriscv/test/riscvtests/RV32PTest.scala @@ -41,7 +41,6 @@ class RV32PTest extends AnyFlatSpec with ChiselScalatestTester { private val ucTests = new File("riscv-tests/isa").listFiles().filter(_.getName().matches(raw"rv32uc-p-.*\.bin")) private val uaTests = new File("riscv-tests/isa").listFiles().filter(_.getName().matches(raw"rv32ua-p-.*\.bin")) private val umTests = new File("riscv-tests/isa").listFiles().filter(_.getName().matches(raw"rv32um-p-.*\.bin")) - private val needToTest = uiTests ++ ucTests ++ uaTests ++ umTests private def expectPass(memory: MemoryMock) = {
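
An illustrative sketch (not part of the patch above) showing how the helpers introduced by this change — CoreConstant, CoreUtils.Vec2/isCompressInstruction and Sv32 — are meant to be combined, in the same style as InstructionFetcher and InstructionPredictor. The module name NextPCExample, the package lltriscv.example and the io signal names are made up for the example; only the imported helpers come from the diff.

package lltriscv.example

import chisel3._
import lltriscv.core.{CoreConstant, DataType}
import lltriscv.utils.CoreUtils._
import lltriscv.utils.Sv32

// Hypothetical module: computes the sequential next PC and the Sv32 VPN of
// the current PC, mirroring the Mux(compress, +2, +4) pattern used in
// InstructionPredictor and the Sv32.getVPN lookups used in InstructionFetcher.
class NextPCExample extends Module {
  val io = IO(new Bundle {
    val pc          = Input(DataType.address)
    val instruction = Input(UInt(CoreConstant.XLEN.W))
    val nextPC      = Output(DataType.address)
    val vpn         = Output(UInt(20.W))
  })

  // Compressed (16-bit) instructions are identified by the low 2 bits != "b11"
  private val compress = isCompressInstruction(io.instruction)

  // Step by 2 bytes for a compressed instruction, otherwise by 4 bytes
  io.nextPC := io.pc + Mux(
    compress,
    CoreConstant.compressInstructionLength.U,
    CoreConstant.instructionLength.U
  )

  // Sv32: the virtual page number is bits 31..12 of the address
  io.vpn := Sv32.getVPN(io.pc)
}

Keeping these bit-field and length constants in Sv32/CoreConstant (rather than hard-coded slices like (31, 12)) is what lets Fetch.scala and BranchPredictor.scala share one definition of the page layout and instruction widths.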