Skip to content

Commit 1fdd9bf

Browse files
wanda-phiwhitequark
authored andcommitted
lib.enum: add .format() implementation.
1 parent 3c870d6 commit 1fdd9bf

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

amaranth/lib/enum.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import warnings
33
import operator
44

5-
from ..hdl import Value, ValueCastable, Shape, ShapeCastable, Const, SyntaxWarning
6-
from ..hdl._repr import *
5+
from ..hdl import Value, ValueCastable, Shape, ShapeCastable, Const, SyntaxWarning, Format
6+
from ..hdl._repr import Repr, FormatEnum
77

88

99
__all__ = py_enum.__all__ + ["EnumView", "FlagView"]
@@ -176,6 +176,11 @@ def const(cls, init):
176176
def from_bits(cls, bits):
177177
return cls(bits)
178178

179+
def format(cls, value, format_spec):
180+
if format_spec != "":
181+
raise ValueError(f"Format specifier {format_spec!r} is not supported for enums")
182+
return Format.Enum(value, cls)
183+
179184
def _value_repr(cls, value):
180185
yield Repr(FormatEnum(cls), value)
181186

tests/test_sim.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from amaranth.sim import *
1616
from amaranth.lib.memory import Memory
1717
from amaranth.lib.data import View, StructLayout
18+
from amaranth.lib import enum
1819

1920
from .utils import *
2021
from amaranth._utils import _ignore_deprecated
@@ -1165,7 +1166,7 @@ def process():
11651166
Counter: 009
11661167
"""))
11671168

1168-
def test_print(self):
1169+
def test_print_str(self):
11691170
def enc(s):
11701171
return Cat(
11711172
Const(b, 8)
@@ -1196,6 +1197,38 @@ def process():
11961197
Counter: non-zero
11971198
"""))
11981199

1200+
def test_print_enum(self):
1201+
class MyEnum(enum.Enum, shape=unsigned(2)):
1202+
A = 0
1203+
B = 1
1204+
CDE = 2
1205+
1206+
sig = Signal(MyEnum)
1207+
ctr = Signal(2)
1208+
m = Module()
1209+
m.d.comb += sig.eq(ctr)
1210+
m.d.sync += [
1211+
Print(sig),
1212+
ctr.eq(ctr + 1),
1213+
]
1214+
output = StringIO()
1215+
with redirect_stdout(output):
1216+
with self.assertSimulation(m) as sim:
1217+
sim.add_clock(1e-6, domain="sync")
1218+
def process():
1219+
yield Tick()
1220+
yield Tick()
1221+
yield Tick()
1222+
yield Tick()
1223+
sim.add_testbench(process)
1224+
self.assertEqual(output.getvalue(), dedent("""\
1225+
A
1226+
B
1227+
CDE
1228+
[unknown]
1229+
"""))
1230+
1231+
11991232
def test_assert(self):
12001233
m = Module()
12011234
ctr = Signal(16)

0 commit comments

Comments
 (0)