From 22d237d38c342ccd64b96826d70700cf3fbd46c1 Mon Sep 17 00:00:00 2001 From: Wanda Date: Thu, 11 Apr 2024 12:48:21 +0200 Subject: [PATCH] lib.enum: add `.format()` implementation. --- amaranth/lib/enum.py | 9 +++++++-- tests/test_sim.py | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py index f9572220d..80267603e 100644 --- a/amaranth/lib/enum.py +++ b/amaranth/lib/enum.py @@ -2,8 +2,8 @@ import warnings import operator -from ..hdl import Value, ValueCastable, Shape, ShapeCastable, Const, SyntaxWarning -from ..hdl._repr import * +from ..hdl import Value, ValueCastable, Shape, ShapeCastable, Const, SyntaxWarning, Format +from ..hdl._repr import Repr, FormatEnum __all__ = py_enum.__all__ + ["EnumView", "FlagView"] @@ -176,6 +176,11 @@ def const(cls, init): def from_bits(cls, bits): return cls(bits) + def format(cls, value, format_spec): + if format_spec != "": + raise ValueError(f"Format specifier {format_spec!r} is not supported for enums") + return Format.Enum(value, cls) + def _value_repr(cls, value): yield Repr(FormatEnum(cls), value) diff --git a/tests/test_sim.py b/tests/test_sim.py index 6b9affdc2..3bb52b6b8 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -15,6 +15,7 @@ from amaranth.sim import * from amaranth.lib.memory import Memory from amaranth.lib.data import View, StructLayout +from amaranth.lib import enum from .utils import * from amaranth._utils import _ignore_deprecated @@ -1165,7 +1166,7 @@ def process(): Counter: 009 """)) - def test_print(self): + def test_print_str(self): def enc(s): return Cat( Const(b, 8) @@ -1196,6 +1197,38 @@ def process(): Counter: non-zero """)) + def test_print_enum(self): + class MyEnum(enum.Enum, shape=unsigned(2)): + A = 0 + B = 1 + CDE = 2 + + sig = Signal(MyEnum) + ctr = Signal(2) + m = Module() + m.d.comb += sig.eq(ctr) + m.d.sync += [ + Print(sig), + ctr.eq(ctr + 1), + ] + output = StringIO() + with redirect_stdout(output): + with self.assertSimulation(m) as sim: + sim.add_clock(1e-6, domain="sync") + def process(): + yield Tick() + yield Tick() + yield Tick() + yield Tick() + sim.add_testbench(process) + self.assertEqual(output.getvalue(), dedent("""\ + A + B + CDE + [unknown] + """)) + + def test_assert(self): m = Module() ctr = Signal(16)