diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py
index f884be56b..9aaa6c802 100644
--- a/amaranth/back/rtlil.py
+++ b/amaranth/back/rtlil.py
@@ -831,7 +831,11 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
 
     if isinstance(fragment, mem.MemoryInstance):
         memory = fragment.memory
-        init = "".join(format(ast.Const(elem, ast.unsigned(memory.width)).value, f"0{memory.width}b") for elem in reversed(memory.init))
+        if isinstance(memory.shape, ast.ShapeCastable):
+            cast = lambda elem: ast.Value.cast(memory.shape.const(elem)).value
+        else:
+            cast = lambda elem: elem
+        init = "".join(format(ast.Const(cast(elem), ast.unsigned(memory.width)).value, f"0{memory.width}b") for elem in reversed(memory.init))
         init = ast.Const(int(init or "0", 2), memory.depth * memory.width)
         rd_clk = []
         rd_clk_enable = 0
diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py
index 76643a0b2..5db503964 100644
--- a/amaranth/hdl/ast.py
+++ b/amaranth/hdl/ast.py
@@ -1929,6 +1929,8 @@ class ValueSet(_MappedKeySet):
 
 class SignalKey:
     def __init__(self, signal):
+        if isinstance(signal, ValueCastable):
+            signal = Value.cast(signal)
         self.signal = signal
         if isinstance(signal, Signal):
             self._intern = (0, signal.duid)
diff --git a/amaranth/hdl/mem.py b/amaranth/hdl/mem.py
index 61b8bc0f6..7e8ea4b2c 100644
--- a/amaranth/hdl/mem.py
+++ b/amaranth/hdl/mem.py
@@ -1,4 +1,5 @@
 import operator
+import warnings
 from collections import OrderedDict
 
 from .. import tracer
@@ -14,53 +15,84 @@ class Memory(Elaboratable):
 
     Parameters
     ----------
+    shape : ShapeLike
+        Shape of each storage element of this memory.
     width : int
         Access granularity. Each storage element of this memory is ``width`` bits in size.
+        Deprecated in favor of ``shape``.
     depth : int
         Word count. This memory contains ``depth`` storage elements.
-    init : list of int
+    init : list
         Initial values. At power on, each storage element in this memory is initialized to
         the corresponding element of ``init``, if any, or to zero otherwise. Uninitialized
         memories are not currently supported.
     name : str
         Name hint for this memory. If ``None`` (default) the name is inferred from the variable
-        name this ``Signal`` is assigned to.
+        name this ``Memory`` is assigned to.
     attrs : dict
         Dictionary of synthesis attributes.
 
     Attributes
     ----------
+    shape : ShapeLike
     width : int
     depth : int
-    init : list of int
+    init : list
     attrs : dict
     """
-    def __init__(self, *, width, depth, init=None, name=None, attrs=None, simulate=True):
-        if not isinstance(width, int) or width < 0:
-            raise TypeError("Memory width must be a non-negative integer, not {!r}"
-                            .format(width))
+    def __init__(self, *, width=None, shape=None, depth, init=None, name=None, attrs=None, simulate=True):
+        if shape is None and width is None:
+            raise TypeError("Memory shape must be specified")
+        if shape is not None and width is not None:
+            raise TypeError("Memory shape and memory width cannot both be specified")
+        # TODO(amaranth-0.5): remove
+        if width is not None:
+            if not isinstance(width, int) or width < 0:
+                raise TypeError("Memory width must be a non-negative integer, not {!r}"
+                                .format(width))
+            warnings.warn("The `width` argument is deprecated; use `shape` instead",
+                          DeprecationWarning)
+            shape = width
+        if not isinstance(shape, ShapeLike):
+            raise TypeError("Memory shape must be shape-castable, not {!r}"
+                            .format(shape))
         if not isinstance(depth, int) or depth < 0:
             raise TypeError("Memory depth must be a non-negative integer, not {!r}"
                             .format(depth))
+        if not isinstance(shape, ShapeCastable):
+            shape = Shape.cast(shape)
+
         self.name = name or tracer.get_var_name(depth=2, default="$memory")
         self.src_loc = tracer.get_src_loc()
 
-        self.width = width
-        self.depth = depth
+        self._shape = shape
+        self._depth = depth
         self.attrs = OrderedDict(() if attrs is None else attrs)
 
         # Array of signals for simulation.
         self._array = Array()
         if simulate:
             for addr in range(self.depth):
-                self._array.append(Signal(self.width, name="{}({})"
+                self._array.append(Signal(self.shape, name="{}({})"
                                           .format(name or "memory", addr)))
 
         self.init = init
         self._read_ports = []
         self._write_ports = []
 
+    @property
+    def shape(self):
+        return self._shape
+
+    @property
+    def width(self):
+        return Shape.cast(self.shape).width
+
+    @property
+    def depth(self):
+        return self._depth
+
     @property
     def init(self):
         return self._init
@@ -75,7 +107,10 @@ def init(self, new_init):
         try:
             for addr in range(len(self._array)):
                 if addr < len(self._init):
-                    self._array[addr].reset = operator.index(self._init[addr])
+                    if isinstance(self.shape, ShapeCastable):
+                        self._array[addr].reset = self.shape.const(self._init[addr])
+                    else:
+                        self._array[addr].reset = operator.index(self._init[addr])
                 else:
                     self._array[addr].reset = 0
         except TypeError as e:
@@ -186,7 +221,7 @@ class ReadPort(Elaboratable):
     transparent : bool
     addr : Signal(range(memory.depth)), in
         Read address.
-    data : Signal(memory.width), out
+    data : Signal(memory.shape), out
         Read data.
     en : Signal or Const, in
         Read enable. If asserted, ``data`` is updated with the word stored at ``addr``.
@@ -205,7 +240,7 @@ def __init__(self, memory, *, domain="sync", transparent=True, src_loc_at=0):
 
         self.addr = Signal(range(memory.depth),
                            name=f"{memory.name}_r_addr", src_loc_at=1 + src_loc_at)
-        self.data = Signal(memory.width,
+        self.data = Signal(memory.shape,
                            name=f"{memory.name}_r_data", src_loc_at=1 + src_loc_at)
         if self.domain != "comb":
             self.en = Signal(name=f"{memory.name}_r_en", reset=1,
@@ -242,7 +277,7 @@ class WritePort(Elaboratable):
     granularity : int
     addr : Signal(range(memory.depth)), in
         Write address.
-    data : Signal(memory.width), in
+    data : Signal(memory.shape), in
         Write data.
     en : Signal(memory.width // granularity), in
         Write enable. Each bit selects a non-overlapping chunk of ``granularity`` bits on the
@@ -254,6 +289,8 @@ class WritePort(Elaboratable):
         divide memory width evenly.
     """
     def __init__(self, memory, *, domain="sync", granularity=None, src_loc_at=0):
+        if granularity is not None and (isinstance(memory.shape, ShapeCastable) or memory.shape.signed):
+            raise TypeError("Write port granularity can only be specified when the memory shape is an unsigned Shape")
         if granularity is None:
             granularity = memory.width
         if not isinstance(granularity, int) or granularity < 0:
@@ -272,7 +309,7 @@ def __init__(self, memory, *, domain="sync", granularity=None, src_loc_at=0):
 
         self.addr = Signal(range(memory.depth),
                            name=f"{memory.name}_w_addr", src_loc_at=1 + src_loc_at)
-        self.data = Signal(memory.width,
+        self.data = Signal(memory.shape,
                            name=f"{memory.name}_w_data", src_loc_at=1 + src_loc_at)
         self.en = Signal(memory.width // granularity,
                          name=f"{memory.name}_w_en", src_loc_at=1 + src_loc_at)
@@ -293,9 +330,15 @@ class DummyPort:
     It does not include any read/write port specific attributes, i.e. none besides ``"domain"``;
     any such attributes may be set manually.
     """
-    def __init__(self, *, data_width, addr_width, domain="sync", name=None, granularity=None):
+    def __init__(self, *, data_width=None, data_shape=None, addr_width, domain="sync", name=None, granularity=None):
         self.domain = domain
 
+        # TODO(amaranth-0.5): remove
+        if data_width is not None:
+            warnings.warn("The `data_width` argument is deprecated; use `data_shape` instead",
+                          DeprecationWarning)
+            data_shape = data_width
+        data_width = Shape.cast(data_shape).width
         if granularity is None:
             granularity = data_width
         if name is None:
@@ -303,7 +346,7 @@ def __init__(self, *, data_width, addr_width, domain="sync", name=None, granular
 
         self.addr = Signal(addr_width,
                            name=f"{name}_addr", src_loc_at=1)
-        self.data = Signal(data_width,
+        self.data = Signal(data_shape,
                            name=f"{name}_data", src_loc_at=1)
         self.en = Signal(data_width // granularity,
                          name=f"{name}_en", src_loc_at=1)
diff --git a/amaranth/hdl/xfrm.py b/amaranth/hdl/xfrm.py
index 02c640380..673fa11e7 100644
--- a/amaranth/hdl/xfrm.py
+++ b/amaranth/hdl/xfrm.py
@@ -276,11 +276,11 @@ def map_memory_ports(self, fragment, new_fragment):
         for port in new_fragment.read_ports:
             port.en = self.on_value(port.en)
             port.addr = self.on_value(port.addr)
-            port.data = self.on_value(port.data)
+            port.data = self.on_value(Value.cast(port.data))
         for port in new_fragment.write_ports:
             port.en = self.on_value(port.en)
             port.addr = self.on_value(port.addr)
-            port.data = self.on_value(port.data)
+            port.data = self.on_value(Value.cast(port.data))
 
     def on_fragment(self, fragment):
         if isinstance(fragment, MemoryInstance):
@@ -408,13 +408,13 @@ def on_fragment(self, fragment):
         if isinstance(fragment, MemoryInstance):
            for port in fragment.read_ports:
                 self.on_value(port.addr)
-                self.on_value(port.data)
+                self.on_value(Value.cast(port.data))
                 self.on_value(port.en)
                 if port.domain != "comb":
                     self._add_used_domain(port.domain)
             for port in fragment.write_ports:
                 self.on_value(port.addr)
-                self.on_value(port.data)
+                self.on_value(Value.cast(port.data))
                 self.on_value(port.en)
                 self._add_used_domain(port.domain)
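
For reviewers, a rough usage sketch of the API this diff introduces. It is not part of the patch; the `layout` value and init contents are made up for illustration, and it assumes `amaranth.lib.data` together with `Memory`'s existing `read_port()`/`write_port()` helpers:

```python
from amaranth.hdl.mem import Memory
from amaranth.lib import data

# Hypothetical shape-castable element type; Shape.cast(layout).width == 16,
# so each storage element still lowers to a 16-bit word in the RTLIL backend.
layout = data.StructLayout({"x": 8, "y": 8})

# `init` elements are now passed through layout.const(...) instead of
# operator.index(...), so they use the layout's constructor format.
mem = Memory(shape=layout, depth=4, init=[{"x": 1, "y": 2}])

rdport = mem.read_port()   # rdport.data is shaped by `layout`, with .x/.y fields
wrport = mem.write_port()  # passing granularity= here would raise TypeError,
                           # since the shape is not an unsigned Shape

# Deprecated spelling, kept until amaranth 0.5; emits a DeprecationWarning:
# mem = Memory(width=16, depth=4)
```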