Skip to content

Commit 1444a9e

Browse files
committed
hdl._nir, back.rtlil: use Format.* to emit enum attributes and wires for fields.
1 parent 4cb2dde commit 1444a9e

File tree

7 files changed

+203
-9
lines changed

7 files changed

+203
-9
lines changed

amaranth/back/rtlil.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import io
44

55
from ..utils import bits_for
6+
from .._utils import to_binary
67
from ..lib import wiring
78
from ..hdl import _repr, _ast, _ir, _nir
89

@@ -421,6 +422,7 @@ def emit(self):
421422
self.emit_cell_wires()
422423
self.emit_submodule_wires()
423424
self.emit_connects()
425+
self.emit_signal_fields()
424426
self.emit_submodules()
425427
self.emit_cells()
426428

@@ -491,12 +493,12 @@ def emit_signal_wires(self):
491493
attrs.update(signal.attrs)
492494
self.value_src_loc[value] = signal.src_loc
493495

494-
for repr in signal._value_repr:
495-
if repr.path == () and isinstance(repr.format, _repr.FormatEnum):
496-
enum = repr.format.enum
497-
attrs["enum_base_type"] = enum.__name__
498-
for enum_value in enum:
499-
attrs["enum_value_{:0{}b}".format(enum_value.value, len(signal))] = enum_value.name
496+
field = self.netlist.signal_fields[signal][()]
497+
if field.enum_name is not None:
498+
attrs["enum_base_type"] = field.enum_name
499+
if field.enum_variants is not None:
500+
for var_val, var_name in field.enum_variants.items():
501+
attrs["enum_value_" + to_binary(var_val, len(signal))] = var_name
500502

501503
if name in self.module.ports:
502504
port_value, _flow = self.module.ports[name]
@@ -666,6 +668,30 @@ def emit_connects(self):
666668
if name not in self.driven_sigports:
667669
self.builder.connect(wire.name, self.sigspec(value))
668670

671+
def emit_signal_fields(self):
672+
for signal, name in self.module.signal_names.items():
673+
fields = self.netlist.signal_fields[signal]
674+
for path, field in fields.items():
675+
if path == ():
676+
continue
677+
name_parts = [name]
678+
for component in path:
679+
if isinstance(component, str):
680+
name_parts.append(f".{component}")
681+
elif isinstance(component, int):
682+
name_parts.append(f"[{component}]")
683+
else:
684+
assert False # :nocov:
685+
attrs = {}
686+
if field.enum_name is not None:
687+
attrs["enum_base_type"] = field.enum_name
688+
if field.enum_variants is not None:
689+
for var_val, var_name in field.enum_variants.items():
690+
attrs["enum_value_" + to_binary(var_val, len(field.value))] = var_name
691+
wire = self.builder.wire(width=len(field.value), signed=field.signed, attrs=attrs,
692+
name="".join(name_parts), src_loc=signal.src_loc)
693+
self.builder.connect(wire.name, self.sigspec(field.value))
694+
669695
def emit_submodules(self):
670696
for submodule_idx in self.module.submodules:
671697
submodule = self.netlist.modules[submodule_idx]

amaranth/hdl/_ast.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2095,6 +2095,11 @@ def __init__(self, shape=None, *, name=None, init=None, reset=None, reset_less=F
20952095

20962096
self._attrs = OrderedDict(() if attrs is None else attrs)
20972097

2098+
if isinstance(orig_shape, ShapeCastable):
2099+
self._format = orig_shape.format(orig_shape(self), "")
2100+
else:
2101+
self._format = Format("{}", self)
2102+
20982103
if decoder is not None:
20992104
# The value representation is specified explicitly. Since we do not expose `hdl._repr`,
21002105
# this is the only way to add a custom filter to the signal right now.

amaranth/hdl/_ir.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ def emit_value(self, builder):
690690

691691

692692
class NetlistEmitter:
693-
def __init__(self, netlist: _nir.Netlist, design, *, all_undef_to_ff=False):
693+
def __init__(self, netlist: _nir.Netlist, design: Design, *, all_undef_to_ff=False):
694694
self.netlist = netlist
695695
self.design = design
696696
self.all_undef_to_ff = all_undef_to_ff
@@ -776,7 +776,7 @@ def extend(self, value: _nir.Value, signed: bool, width: int):
776776
def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src_loc):
777777
op = _nir.Operator(module_idx, operator=operator, inputs=inputs, src_loc=src_loc)
778778
return self.netlist.add_value_cell(op.width, op)
779-
779+
780780
def emit_matches(self, module_idx: int, value: _nir.Value, patterns, *, src_loc):
781781
key = module_idx, value, patterns, src_loc
782782
try:
@@ -1334,6 +1334,42 @@ def emit_top_ports(self, fragment: _ir.Fragment):
13341334
else:
13351335
raise ValueError(f"Invalid port direction {dir!r}")
13361336

1337+
def emit_signal_fields(self):
1338+
for signal, fragment in self.design.signal_lca.items():
1339+
module_idx = self.fragment_module_idx[fragment]
1340+
fields = {}
1341+
def emit_format(path, fmt):
1342+
if isinstance(fmt, _ast.Format):
1343+
specs = [
1344+
chunk[0]
1345+
for chunk in fmt._chunks
1346+
if not isinstance(chunk, str)
1347+
]
1348+
if len(specs) != 1:
1349+
return
1350+
val, signed = self.emit_rhs(module_idx, specs[0])
1351+
fields[path] = _nir.SignalField(val, signed=signed)
1352+
elif isinstance(fmt, _ast.Format.Enum):
1353+
val, signed = self.emit_rhs(module_idx, fmt._value)
1354+
fields[path] = _nir.SignalField(val, signed=signed,
1355+
enum_name=fmt._name,
1356+
enum_variants=fmt._variants)
1357+
elif isinstance(fmt, _ast.Format.Struct):
1358+
val, signed = self.emit_rhs(module_idx, fmt._value)
1359+
fields[path] = _nir.SignalField(val, signed=signed)
1360+
for name, subfmt in fmt._fields.items():
1361+
emit_format(path + (name,), subfmt)
1362+
elif isinstance(fmt, _ast.Format.Array):
1363+
val, signed = self.emit_rhs(module_idx, fmt._value)
1364+
fields[path] = _nir.SignalField(val, signed=signed)
1365+
for idx, subfmt in enumerate(fmt._fields):
1366+
emit_format(path + (idx,), subfmt)
1367+
emit_format((), signal._format)
1368+
val, signed = self.emit_rhs(module_idx, signal)
1369+
if () not in fields or fields[()].value != val:
1370+
fields[()] = _nir.SignalField(val, signed=signed)
1371+
self.netlist.signal_fields[signal] = fields
1372+
13371373
def emit_drivers(self):
13381374
for driver in self.drivers.values():
13391375
if (driver.domain is not None and
@@ -1452,6 +1488,7 @@ def emit_fragment(self, fragment: _ir.Fragment, parent_module_idx: 'int | None',
14521488
for subfragment, _name, sub_src_loc in fragment.subfragments:
14531489
self.emit_fragment(subfragment, module_idx, cell_src_loc=sub_src_loc)
14541490
if parent_module_idx is None:
1491+
self.emit_signal_fields()
14551492
self.emit_drivers()
14561493
self.emit_top_ports(fragment)
14571494
if self.all_undef_to_ff:

amaranth/hdl/_nir.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,22 @@ def resolve_nets(self, netlist: "Netlist"):
289289
chunk.value = netlist.resolve_value(chunk.value)
290290

291291

292+
class SignalField:
293+
"""Describes a single field of a signal."""
294+
def __init__(self, value, *, signed, enum_name=None, enum_variants=None):
295+
self.value = Value(value)
296+
self.signed = bool(signed)
297+
self.enum_name = enum_name
298+
self.enum_variants = enum_variants
299+
300+
def __eq__(self, other):
301+
return (type(self) is type(other) and
302+
self.value == other.value and
303+
self.signed == other.signed and
304+
self.enum_name == other.enum_name and
305+
self.enum_variants == other.enum_variants)
306+
307+
292308
class Netlist:
293309
"""A fine netlist. Consists of:
294310
@@ -321,6 +337,7 @@ class Netlist:
321337
connections : dict of (negative) int to int
322338
io_ports : list of ``IOPort``
323339
signals : dict of Signal to ``Value``
340+
signal_fields: dict of Signal to dict of tuple[str | int] to SignalField
324341
last_late_net: int
325342
"""
326343
def __init__(self):
@@ -329,6 +346,7 @@ def __init__(self):
329346
self.connections: dict[Net, Net] = {}
330347
self.io_ports: list[_ast.IOPort] = []
331348
self.signals = SignalDict()
349+
self.signal_fields = SignalDict()
332350
self.last_late_net = 0
333351

334352
def resolve_net(self, net: Net):
@@ -345,6 +363,9 @@ def resolve_all_nets(self):
345363
cell.resolve_nets(self)
346364
for sig in self.signals:
347365
self.signals[sig] = self.resolve_value(self.signals[sig])
366+
for fields in self.signal_fields.values():
367+
for field in fields.values():
368+
field.value = self.resolve_value(field.value)
348369

349370
def __repr__(self):
350371
result = ["("]

tests/test_back_rtlil.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from amaranth.back import rtlil
55
from amaranth.hdl import *
66
from amaranth.hdl._ast import *
7-
from amaranth.lib import memory, wiring
7+
from amaranth.lib import memory, wiring, data, enum
88

99
from .utils import *
1010

@@ -2010,6 +2010,67 @@ def test_print_align(self):
20102010
""")
20112011

20122012

2013+
class DetailTestCase(RTLILTestCase):
2014+
def test_enum(self):
2015+
class MyEnum(enum.Enum, shape=unsigned(2)):
2016+
A = 0
2017+
B = 1
2018+
C = 2
2019+
2020+
sig = Signal(MyEnum)
2021+
m = Module()
2022+
m.d.comb += sig.eq(MyEnum.A)
2023+
self.assertRTLIL(m, [sig.as_value()], R"""
2024+
attribute \generator "Amaranth"
2025+
attribute \top 1
2026+
module \top
2027+
attribute \enum_base_type "MyEnum"
2028+
attribute \enum_value_00 "A"
2029+
attribute \enum_value_01 "B"
2030+
attribute \enum_value_10 "C"
2031+
wire width 2 output 0 \sig
2032+
connect \sig 2'00
2033+
end
2034+
""")
2035+
2036+
def test_struct(self):
2037+
class MyEnum(enum.Enum, shape=unsigned(2)):
2038+
A = 0
2039+
B = 1
2040+
C = 2
2041+
2042+
class Meow(data.Struct):
2043+
a: MyEnum
2044+
b: 3
2045+
c: signed(4)
2046+
d: data.ArrayLayout(2, 2)
2047+
2048+
sig = Signal(Meow)
2049+
m = Module()
2050+
self.assertRTLIL(m, [sig.as_value()], R"""
2051+
attribute \generator "Amaranth"
2052+
attribute \top 1
2053+
module \top
2054+
wire width 13 input 0 \sig
2055+
attribute \enum_base_type "MyEnum"
2056+
attribute \enum_value_00 "A"
2057+
attribute \enum_value_01 "B"
2058+
attribute \enum_value_10 "C"
2059+
wire width 2 \sig.a
2060+
wire width 3 \sig.b
2061+
wire width 4 signed \sig.c
2062+
wire width 4 \sig.d
2063+
wire width 2 \sig.d[0]
2064+
wire width 2 \sig.d[1]
2065+
connect \sig.a \sig [1:0]
2066+
connect \sig.b \sig [4:2]
2067+
connect \sig.c \sig [8:5]
2068+
connect \sig.d \sig [12:9]
2069+
connect \sig.d[0] \sig [10:9]
2070+
connect \sig.d[1] \sig [12:11]
2071+
end
2072+
""")
2073+
20132074
class ComponentTestCase(RTLILTestCase):
20142075
def test_component(self):
20152076
class MyComponent(wiring.Component):

tests/test_hdl_ir.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from amaranth.hdl._dsl import *
88
from amaranth.hdl._ir import *
99
from amaranth.hdl._mem import *
10+
from amaranth.hdl._nir import SignalField
11+
12+
from amaranth.lib import enum, data
1013

1114
from .utils import *
1215

@@ -3501,3 +3504,41 @@ def test_undef_to_ff_partial(self):
35013504
(cell 3 0 (flipflop 3.0:5 10 pos 0 0))
35023505
)
35033506
""")
3507+
3508+
3509+
class FieldsTestCase(FHDLTestCase):
3510+
def test_fields(self):
3511+
class MyEnum(enum.Enum, shape=unsigned(2)):
3512+
A = 0
3513+
B = 1
3514+
C = 2
3515+
l = data.StructLayout({"a": MyEnum, "b": signed(3)})
3516+
s1 = Signal(l)
3517+
s2 = Signal(MyEnum)
3518+
s3 = Signal(signed(3))
3519+
s4 = Signal(unsigned(4))
3520+
nl = build_netlist(Fragment.get(Module(), None), [
3521+
s1.as_value(), s2.as_value(), s3, s4,
3522+
])
3523+
self.assertEqual(nl.signal_fields[s1.as_value()], {
3524+
(): SignalField(nl.signals[s1.as_value()], signed=False),
3525+
('a',): SignalField(nl.signals[s1.as_value()][0:2], signed=False, enum_name="MyEnum", enum_variants={
3526+
0: "A",
3527+
1: "B",
3528+
2: "C",
3529+
}),
3530+
('b',): SignalField(nl.signals[s1.as_value()][2:5], signed=True)
3531+
})
3532+
self.assertEqual(nl.signal_fields[s2.as_value()], {
3533+
(): SignalField(nl.signals[s2.as_value()], signed=False, enum_name="MyEnum", enum_variants={
3534+
0: "A",
3535+
1: "B",
3536+
2: "C",
3537+
}),
3538+
})
3539+
self.assertEqual(nl.signal_fields[s3], {
3540+
(): SignalField(nl.signals[s3], signed=True),
3541+
})
3542+
self.assertEqual(nl.signal_fields[s4], {
3543+
(): SignalField(nl.signals[s4], signed=False),
3544+
})

tests/test_lib_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,9 @@ def const(self, init):
647647
def from_bits(self, bits):
648648
return bits
649649

650+
def format(self, value, spec):
651+
return Format("")
652+
650653
v = Signal(data.StructLayout({
651654
"f": WrongCastable()
652655
}))

0 commit comments

Comments
 (0)