From 6807d3932a24e43eb1093cb0910f81ed1d110854 Mon Sep 17 00:00:00 2001 From: Wanda Date: Thu, 11 Apr 2024 05:58:04 +0200 Subject: [PATCH] hdl._ast: add `Format.Enum`, `Format.Struct`, `Format.Array`. --- amaranth/hdl/_ast.py | 136 +++++++++++++++++++++++++++++++++++------- tests/test_hdl_ast.py | 91 ++++++++++++++++++++++++++++ 2 files changed, 206 insertions(+), 21 deletions(-) diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 77b3a05bb..4b56d079c 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -2559,8 +2559,27 @@ def __repr__(self): return "(initial)" +class _FormatLike: + def _as_format(self) -> "Format": + raise NotImplementedError # :nocov: + + def __add__(self, other): + if not isinstance(other, _FormatLike): + return NotImplemented + return Format._from_chunks(self._as_format()._chunks + other._as_format()._chunks) + + def __format__(self, format_desc): + """Forbidden formatting. + + ``Format`` objects cannot be directly formatted for the same reason as the ``Value``s + they contain. + """ + raise TypeError(f"Format object {self!r} cannot be converted to string. Use `repr` " + f"to print the AST, or pass it to the `Print` statement.") + + @final -class Format: +class Format(_FormatLike): def __init__(self, format, *args, **kwargs): fmt = string.Formatter() chunks = [] @@ -2615,17 +2634,17 @@ def subformat(sub_string): shape = obj.shape() if isinstance(shape, ShapeCastable): fmt = shape.format(obj, format_spec) - if not isinstance(fmt, Format): + if not isinstance(fmt, _FormatLike): raise TypeError(f"`ShapeCastable.format` must return a 'Format' instance, not {fmt!r}") - chunks += fmt._chunks + chunks += fmt._as_format()._chunks else: obj = Value.cast(obj) self._parse_format_spec(format_spec, obj.shape()) chunks.append((obj, format_spec)) - elif isinstance(obj, Format): + elif isinstance(obj, _FormatLike): if format_spec != "": raise ValueError(f"Format specifiers ({format_spec!r}) cannot be used for 'Format' objects") - chunks += obj._chunks + chunks += obj._as_format()._chunks else: chunks.append(fmt.format_field(obj, format_spec)) @@ -2638,6 +2657,9 @@ def subformat(sub_string): self._chunks = self._clean_chunks(chunks) + def _as_format(self): + return self + @classmethod def _from_chunks(cls, chunks): res = object.__new__(cls) @@ -2671,25 +2693,11 @@ def _to_format_string(self): format_string.append("{}") return ("".join(format_string), tuple(args)) - def __add__(self, other): - if not isinstance(other, Format): - return NotImplemented - return Format._from_chunks(self._chunks + other._chunks) - def __repr__(self): format_string, args = self._to_format_string() args = "".join(f" {arg!r}" for arg in args) return f"(format {format_string!r}{args})" - def __format__(self, format_desc): - """Forbidden formatting. - - ``Format`` objects cannot be directly formatted for the same reason as the ``Value``s - they contain. - """ - raise TypeError(f"Format object {self!r} cannot be converted to string. Use `repr` " - f"to print the AST, or pass it to the `Print` statement.") - _FORMAT_SPEC_PATTERN = re.compile(r""" (?: (?P.)? @@ -2760,6 +2768,90 @@ def _rhs_signals(self): return res + class Enum(_FormatLike): + def __init__(self, value, /, variants): + self._value = Value.cast(value) + if isinstance(variants, EnumMeta): + self._variants = {Const.cast(member.value).value: member.name for member in variants} + else: + self._variants = dict(variants) + for val, name in self._variants.items(): + if not isinstance(val, int): + raise TypeError(f"Variant values must be integers, not {val!r}") + if not isinstance(name, str): + raise TypeError(f"Variant names must be strings, not {name!r}") + + def _as_format(self): + def str_val(name): + name = name.encode() + return Const(int.from_bytes(name, "little"), len(name) * 8) + value = SwitchValue(self._value, [ + (val, str_val(name)) + for val, name in self._variants.items() + ] + [(None, str_val("[unknown]"))]) + return Format("{:s}", value) + + def __repr__(self): + variants = "".join( + f" ({val!r} {name!r})" + for val, name in self._variants.items() + ) + return f"(format-enum {self._value!r}{variants})" + + + class Struct(_FormatLike): + def __init__(self, value, /, fields): + self._value = Value.cast(value) + self._fields: dict[str, _FormatLike] = dict(fields) + for name, format in self._fields.items(): + if not isinstance(name, str): + raise TypeError(f"Field names must be strings, not {name!r}") + if not isinstance(format, _FormatLike): + raise TypeError(f"Field format must be a 'Format', not {format!r}") + + def _as_format(self): + chunks = ["{"] + for idx, (name, field) in enumerate(self._fields.items()): + if idx != 0: + chunks.append(", ") + chunks.append(f"{name}=") + chunks += field._as_format()._chunks + chunks.append("}") + return Format._from_chunks(chunks) + + def __repr__(self): + fields = "".join( + f" ({name!r} {field!r})" + for name, field in self._fields.items() + ) + return f"(format-struct {self._value!r}{fields})" + + + class Array(_FormatLike): + def __init__(self, value, /, fields): + self._value = Value.cast(value) + self._fields = list(fields) + for format in self._fields: + if not isinstance(format, (Format, Format.Enum, Format.Struct, Format.Array)): + raise TypeError(f"Field format must be a 'Format', not {format!r}") + + def _as_format(self): + chunks = ["["] + for idx, field in enumerate(self._fields): + if idx != 0: + chunks.append(", ") + chunks += field._as_format()._chunks + chunks.append("]") + return Format._from_chunks(chunks) + + def __repr__(self): + fields = "".join( + f" {field!r}" + for field in self._fields + ) + return f"(format-array {self._value!r}{fields})" + + class _StatementList(list): def __repr__(self): return "({})".format(" ".join(map(repr, self))) @@ -2872,8 +2964,10 @@ def __init__(self, kind, test, message=None, *, src_loc_at=0): self._test = Value.cast(test) if isinstance(message, str): message = Format._from_chunks([message]) - if message is not None and not isinstance(message, Format): - raise TypeError(f"Property message must be None, str, or Format, not {message!r}") + if message is not None: + if not isinstance(message, _FormatLike): + raise TypeError(f"Property message must be None, str, or Format, not {message!r}") + message = message._as_format() self._message = message del self._MustUse__silence diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 2b6c3cb86..d39da7ff7 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -1736,6 +1736,97 @@ def test_format_wrong(self): f"{fmt}" +class FormatEnumTestCase(FHDLTestCase): + def test_construct(self): + a = Signal(3) + fmt = Format.Enum(a, {1: "A", 2: "B", 3: "C"}) + self.assertRepr(fmt, "(format-enum (sig a) (1 'A') (2 'B') (3 'C'))") + self.assertRepr(Format("{}", fmt), """ + (format '{:s}' (switch-value (sig a) + (case 001 (const 8'd65)) + (case 010 (const 8'd66)) + (case 011 (const 8'd67)) + (default (const 72'd1723507152241428428123)) + )) + """) + + class MyEnum(Enum): + A = 0 + B = 3 + C = 4 + + fmt = Format.Enum(a, MyEnum) + self.assertRepr(fmt, "(format-enum (sig a) (0 'A') (3 'B') (4 'C'))") + self.assertRepr(Format("{}", fmt), """ + (format '{:s}' (switch-value (sig a) + (case 000 (const 8'd65)) + (case 011 (const 8'd66)) + (case 100 (const 8'd67)) + (default (const 72'd1723507152241428428123)) + )) + """) + + def test_construct_wrong(self): + a = Signal(3) + with self.assertRaisesRegex(TypeError, + r"^Variant values must be integers, not 'a'$"): + Format.Enum(a, {"a": "B"}) + with self.assertRaisesRegex(TypeError, + r"^Variant names must be strings, not 123$"): + Format.Enum(a, {1: 123}) + + +class FormatStructTestCase(FHDLTestCase): + def test_construct(self): + sig = Signal(3) + fmt = Format.Struct(sig, {"a": Format("{}", sig[0]), "b": Format("{}", sig[1:3])}) + self.assertRepr(fmt, """ + (format-struct (sig sig) + ('a' (format '{}' (slice (sig sig) 0:1))) + ('b' (format '{}' (slice (sig sig) 1:3))) + ) + """) + self.assertRepr(Format("{}", fmt), """ + (format '{{a={}, b={}}}' + (slice (sig sig) 0:1) + (slice (sig sig) 1:3) + ) + """) + + def test_construct_wrong(self): + sig = Signal(3) + with self.assertRaisesRegex(TypeError, + r"^Field names must be strings, not 1$"): + Format.Struct(sig, {1: Format("{}", sig[1:3])}) + with self.assertRaisesRegex(TypeError, + r"^Field format must be a 'Format', not \(slice \(sig sig\) 1:3\)$"): + Format.Struct(sig, {"a": sig[1:3]}) + + +class FormatArrayTestCase(FHDLTestCase): + def test_construct(self): + sig = Signal(4) + fmt = Format.Array(sig, [Format("{}", sig[0:2]), Format("{}", sig[2:4])]) + self.assertRepr(fmt, """ + (format-array (sig sig) + (format '{}' (slice (sig sig) 0:2)) + (format '{}' (slice (sig sig) 2:4)) + ) + """) + self.assertRepr(Format("{}", fmt), """ + (format '[{}, {}]' + (slice (sig sig) 0:2) + (slice (sig sig) 2:4) + ) + """) + + def test_construct_wrong(self): + sig = Signal(3) + with self.assertRaisesRegex(TypeError, + r"^Field format must be a 'Format', not \(slice \(sig sig\) 1:3\)$"): + Format.Array(sig, [sig[1:3]]) + + class PrintTestCase(FHDLTestCase): def test_construct(self): a = Signal()