from amaranth import C, Module, Shape, Signal, unsigned
from amaranth.utils import exact_log2
from amaranth.lib.wiring import Component, Out, In, connect, flipped, Signature
from amaranth.lib.data import StructLayout, View
from amaranth.lib.memory import Memory
from amaranth_soc import wishbone


def cache_addr_layout(index_len, tag_len):
    return StructLayout({
        "index": unsigned(index_len),
        "tag": unsigned(tag_len)
    })


def tag_data_layout(tag_len, num_ways):
    if num_ways == 2:
        return StructLayout({
            "tag": unsigned(tag_len),
            "last_way": unsigned(1),
            "valid": unsigned(1)
        })
    else:
        return StructLayout({
            "tag": unsigned(tag_len),
            "valid": unsigned(1)
        })


class WishboneMinimalICache(Component):
    def __init__(self, *, addr_width, data_width, granularity,
                 num_entries_per_way, num_ways=1):
        if num_ways not in (1, 2):
            raise ValueError("Number of ways must be 1 or 2.")

        self.addr_width = addr_width
        self.data_width = data_width
        self.granularity = granularity
        self.num_ways = num_ways

        self.index_len = exact_log2(num_entries_per_way)
        self.tag_len = self.addr_width - self.index_len
        self.sel_mask = C(-1, Shape(self.data_width // self.granularity,
                                    signed=False))

        self.data_0 = Memory(shape=self.data_width,
                             depth=2**self.index_len, init=[])
        self.tags_0 = Memory(shape=tag_data_layout(self.tag_len, num_ways),
                             depth=2**self.index_len, init=[])
        if self.num_ways == 2:
            self.data_1 = Memory(shape=self.data_width,
                                 depth=2**self.index_len, init=[])
            self.tags_1 = Memory(shape=tag_data_layout(self.tag_len,
                                                       num_ways),
                                 depth=2**self.index_len, init=[])

        sig = {
            "cpu": In(wishbone.Signature(addr_width=addr_width,
                                         data_width=data_width,
                                         granularity=granularity)),
            "en": In(1),
            "inval": In(Signature({
                "req": Out(1),
                "resp": In(1)
            })),
            "system": Out(wishbone.Signature(addr_width=addr_width,
                                             data_width=data_width,
                                             granularity=granularity))
        }

        super().__init__(sig)

    def elaborate(self, plat):
        m = Module()

        m.submodules.data_0 = self.data_0
        m.submodules.tags_0 = self.tags_0
        data_0_port_w = self.data_0.write_port()
        tags_0_port_w = self.tags_0.write_port()
        data_0_port_r = self.data_0.read_port(
            transparent_for=(data_0_port_w,))
        tags_0_port_r = self.tags_0.read_port(
            transparent_for=(tags_0_port_w,))

        if self.num_ways == 2:
            m.submodules.data_1 = self.data_1
            m.submodules.tags_1 = self.tags_1
            data_1_port_w = self.data_1.write_port()
            data_1_port_r = self.data_1.read_port(
                transparent_for=(data_1_port_w,))
            tags_1_port_w = self.tags_1.write_port()
            tags_1_port_r = self.tags_1.read_port(
                transparent_for=(tags_1_port_w,))

        cpu_uncached_path = wishbone.Interface(addr_width=self.addr_width,
                                               data_width=self.data_width,
                                               granularity=self.granularity)
        cpu_cached_path = wishbone.Interface(addr_width=self.addr_width,
                                             data_width=self.data_width,
                                             granularity=self.granularity)

        connect(m, flipped(cpu_uncached_path), flipped(self.cpu))
        connect(m, flipped(cpu_cached_path), flipped(self.cpu))

        with m.If(self.en):
            m.d.comb += [
                cpu_uncached_path.cyc.eq(0),
                cpu_uncached_path.stb.eq(0),
                cpu_cached_path.cyc.eq(self.cpu.cyc),
                cpu_cached_path.stb.eq(self.cpu.stb),
                self.cpu.ack.eq(cpu_cached_path.ack),
                self.cpu.dat_r.eq(cpu_cached_path.dat_r),
            ]
        with m.Else():
            m.d.comb += [
                cpu_uncached_path.cyc.eq(self.cpu.cyc),
                cpu_uncached_path.stb.eq(self.cpu.stb),
                cpu_cached_path.cyc.eq(0),
                cpu_cached_path.stb.eq(0),
                self.cpu.ack.eq(cpu_uncached_path.ack),
                self.cpu.dat_r.eq(cpu_uncached_path.dat_r),
            ]

        connect(m, flipped(self.system), cpu_uncached_path)

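        # View the CPU word address as {index, tag}: the low index_len bits
        # select a cache line and the remaining high bits are the tag stored
        # alongside it. With the parameters used under __main__ below
        # (addr_width=30, 256 entries per way), that is an 8-bit index and a
        # 22-bit tag.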
        cache_in = View(cache_addr_layout(self.index_len, self.tag_len),
                        cpu_cached_path.adr)
        tag_0_in = View(tag_data_layout(self.tag_len, self.num_ways),
                        tags_0_port_w.data)
        tag_0_out = View(tag_data_layout(self.tag_len, self.num_ways),
                         tags_0_port_r.data)
        if self.num_ways == 2:
            tag_1_in = View(tag_data_layout(self.tag_len, self.num_ways),
                            tags_1_port_w.data)
            tag_1_out = View(tag_data_layout(self.tag_len, self.num_ways),
                             tags_1_port_r.data)

        # Optimization for the one-way case that also works for the two-way
        # case; hopefully the synthesizer still optimizes it well with two
        # ways.
        m.d.comb += cpu_cached_path.dat_r.eq(data_0_port_r.data)

        m.d.comb += [
            data_0_port_r.en.eq(0),
            tags_0_port_r.en.eq(0),
        ]
        if self.num_ways == 2:
            m.d.comb += [
                data_1_port_r.en.eq(0),
                tags_1_port_r.en.eq(0),
            ]

        curr_line = Signal(self.index_len)
        if self.num_ways == 2:
            curr_way = Signal(1)

        with m.FSM(init="FLUSH"):
            with m.State("IDLE"):
                with m.If(self.inval.req == 1):
                    m.next = "FLUSH"
                with m.Elif(cpu_cached_path.cyc & cpu_cached_path.stb &
                            ~cpu_cached_path.we):
                    m.d.comb += [
                        data_0_port_r.addr.eq(cache_in.index),
                        tags_0_port_r.addr.eq(cache_in.index),
                        data_0_port_r.en.eq(1),
                        tags_0_port_r.en.eq(1),
                    ]
                    if self.num_ways == 2:
                        m.d.comb += [
                            data_1_port_r.addr.eq(cache_in.index),
                            tags_1_port_r.addr.eq(cache_in.index),
                            data_1_port_r.en.eq(1),
                            tags_1_port_r.en.eq(1),
                        ]
                    m.next = "CHECK"

            with m.State("CHECK"):
                with m.If((tag_0_out.tag == cache_in.tag) & tag_0_out.valid):
                    m.d.comb += cpu_cached_path.ack.eq(1)
                    if self.num_ways == 2:
                        m.d.comb += [
                            # Mark way 0 as last used. Use the
                            # previously-read data to set/clear the
                            # last_way bit.
                            tags_0_port_w.addr.eq(cache_in.index),
                            tags_1_port_w.addr.eq(cache_in.index),
                            tags_0_port_w.data.eq(tags_0_port_r.data),
                            tags_1_port_w.data.eq(tags_1_port_r.data),
                            tag_0_in.last_way.eq(1),
                            tag_1_in.last_way.eq(0),
                            tags_0_port_w.en.eq(1),
                            tags_1_port_w.en.eq(1),
                        ]
                    m.next = "IDLE"
                if self.num_ways == 2:
                    with m.Elif((tag_1_out.tag == cache_in.tag) &
                                tag_1_out.valid):
                        m.d.comb += [
                            # Override the default of always presenting
                            # way 0 on the output.
                            cpu_cached_path.dat_r.eq(data_1_port_r.data),
                            cpu_cached_path.ack.eq(1),
                            # Mark way 1 as last used.
                            tags_0_port_w.addr.eq(cache_in.index),
                            tags_1_port_w.addr.eq(cache_in.index),
                            tags_0_port_w.data.eq(tags_0_port_r.data),
                            tags_1_port_w.data.eq(tags_1_port_r.data),
                            tag_0_in.last_way.eq(0),
                            tag_1_in.last_way.eq(1),
                            tags_0_port_w.en.eq(1),
                            tags_1_port_w.en.eq(1),
                        ]
                        m.next = "IDLE"
                with m.Else():
                    # Get a head start on fetching.
                    connect(m, flipped(self.system), cpu_cached_path)

                    # And figure out which way to fill.
                    if self.num_ways == 2:
                        # If both ways are valid, refill the one that was
                        # not used last.
                        with m.If(tag_0_out.valid & tag_1_out.valid):
                            with m.If(~tag_0_out.last_way):
                                m.d.sync += curr_way.eq(0)
                            with m.Elif(~tag_1_out.last_way):
                                m.d.sync += curr_way.eq(1)
                            with m.Else():
                                # Just do something... this shouldn't
                                # happen.
                                m.d.sync += curr_way.eq(~curr_way)
                        with m.Elif(~tag_0_out.valid):
                            m.d.sync += curr_way.eq(0)
                        with m.Else():
                            m.d.sync += curr_way.eq(1)

                    m.next = "FETCH"

            with m.State("FETCH"):
                connect(m, flipped(self.system), cpu_cached_path)

                with m.If(cpu_cached_path.cyc & cpu_cached_path.stb &
                          cpu_cached_path.ack):
                    # For simplicity of the write path, we can only cache
                    # full-word reads.
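                    # self.sel_mask has every byte lane set (0b1111 for a
                    # 32-bit bus with 8-bit granularity); narrower reads are
                    # still forwarded to the CPU, just never cached.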
                    with m.If(cpu_cached_path.sel == self.sel_mask):
                        if self.num_ways == 1:
                            m.d.comb += [
                                data_0_port_w.addr.eq(cache_in.index),
                                tags_0_port_w.addr.eq(cache_in.index),
                                data_0_port_w.en.eq(1),
                                tags_0_port_w.en.eq(1),
                                tag_0_in.valid.eq(1),
                                tag_0_in.tag.eq(cache_in.tag),
                                data_0_port_w.data.eq(cpu_cached_path.dat_r),
                            ]
                        elif self.num_ways == 2:
                            with m.If(curr_way == 0):
                                m.d.comb += [
                                    data_0_port_w.addr.eq(cache_in.index),
                                    tags_0_port_w.addr.eq(cache_in.index),
                                    tags_1_port_w.addr.eq(cache_in.index),
                                    data_0_port_w.en.eq(1),
                                    tags_0_port_w.en.eq(1),
                                    tags_1_port_w.en.eq(1),
                                    tag_0_in.valid.eq(1),
                                    tag_0_in.last_way.eq(1),
                                    tag_0_in.tag.eq(cache_in.tag),
                                    data_0_port_w.data.eq(
                                        cpu_cached_path.dat_r),
                                    tags_1_port_w.data.eq(tags_1_port_r.data),
                                    tag_1_in.last_way.eq(0),
                                ]
                            with m.Elif(curr_way == 1):
                                m.d.comb += [
                                    data_1_port_w.addr.eq(cache_in.index),
                                    tags_0_port_w.addr.eq(cache_in.index),
                                    tags_1_port_w.addr.eq(cache_in.index),
                                    data_1_port_w.en.eq(1),
                                    tags_0_port_w.en.eq(1),
                                    tags_1_port_w.en.eq(1),
                                    tag_1_in.valid.eq(1),
                                    tag_1_in.last_way.eq(1),
                                    tag_1_in.tag.eq(cache_in.tag),
                                    data_1_port_w.data.eq(
                                        cpu_cached_path.dat_r),
                                    tags_0_port_w.data.eq(tags_0_port_r.data),
                                    tag_0_in.last_way.eq(0),
                                ]

                    # Even if we couldn't cache it, go back to idle.
                    m.next = "IDLE"

            with m.State("FLUSH"):
                m.d.sync += curr_line.eq(curr_line + 1)
                m.d.comb += [
                    tag_0_in.valid.eq(0),
                    tags_0_port_w.addr.eq(curr_line),
                    tags_0_port_w.en.eq(1),
                ]
                if self.num_ways == 2:
                    m.d.comb += [
                        tag_1_in.valid.eq(0),
                        tags_1_port_w.addr.eq(curr_line),
                        tags_1_port_w.en.eq(1),
                    ]
                with m.If((curr_line + 1)[0:self.index_len] == 0):
                    m.d.comb += self.inval.resp.eq(1)
                    m.next = "IDLE"

        return m


if __name__ == "__main__":
    from amaranth.back import verilog

    print(verilog.convert(WishboneMinimalICache(addr_width=30,
                                                data_width=32,
                                                granularity=8,
                                                num_entries_per_way=256,
                                                num_ways=2)))
    # print(verilog.convert(WishboneMinimalICache(addr_width=15,
    #                                             data_width=16,
    #                                             granularity=8,
    #                                             num_entries_per_way=256)))
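

# ---------------------------------------------------------------------------
# The smoke test below is not part of the original module. It is a minimal
# simulation sketch assuming Amaranth 0.5's async simulator API (Simulator,
# add_testbench, ctx.get/ctx.set/ctx.tick). The fake "system memory" always
# answers reads with a fixed fill value; that value, the 0x123 address, and
# the cycle bound on the responder are arbitrary. Call _smoke_test() manually
# (e.g. from a REPL) to run it.
def _smoke_test():
    from amaranth.sim import Simulator

    dut = WishboneMinimalICache(addr_width=30, data_width=32, granularity=8,
                                num_entries_per_way=256, num_ways=2)
    sim = Simulator(dut)
    sim.add_clock(1e-6)

    FILL = 0xC0FFEE00  # arbitrary data returned by the fake system bus

    async def system_bus(ctx):
        # Fake Wishbone target: assert ack (with FILL as read data) one cycle
        # after sampling cyc & stb at a clock edge.
        for _ in range(4096):
            await ctx.tick()
            if (ctx.get(dut.system.cyc) and ctx.get(dut.system.stb)
                    and not ctx.get(dut.system.ack)):
                await ctx.tick()
                ctx.set(dut.system.dat_r, FILL)
                ctx.set(dut.system.ack, 1)
            else:
                ctx.set(dut.system.ack, 0)

    async def cpu(ctx):
        ctx.set(dut.en, 1)
        # Wait for the power-on FLUSH pass to finish.
        while not ctx.get(dut.inval.resp):
            await ctx.tick()
        # Issue one full-word read. The first access misses, is fetched over
        # the system bus, and fills the selected cache line on the way.
        ctx.set(dut.cpu.adr, 0x123)
        ctx.set(dut.cpu.sel, 0xF)
        ctx.set(dut.cpu.cyc, 1)
        ctx.set(dut.cpu.stb, 1)
        while not ctx.get(dut.cpu.ack):
            await ctx.tick()
        data = ctx.get(dut.cpu.dat_r)
        # Hold cyc/stb through the clock edge at which ack is sampled, as a
        # synchronous Wishbone initiator would.
        await ctx.tick()
        ctx.set(dut.cpu.cyc, 0)
        ctx.set(dut.cpu.stb, 0)
        assert data == FILL, hex(data)

    sim.add_testbench(system_bus)
    sim.add_testbench(cpu)
    sim.run()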