Skip to content

hdl._ast: add Format.Enum, Format.Struct, Format.Array. #1316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 115 additions & 21 deletions amaranth/hdl/_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2559,8 +2559,27 @@ def __repr__(self):
return "(initial)"


class _FormatLike:
def _as_format(self) -> "Format":
raise NotImplementedError # :nocov:

def __add__(self, other):
if not isinstance(other, _FormatLike):
return NotImplemented
return Format._from_chunks(self._as_format()._chunks + other._as_format()._chunks)

def __format__(self, format_desc):
"""Forbidden formatting.

``Format`` objects cannot be directly formatted for the same reason as the ``Value``s
they contain.
"""
raise TypeError(f"Format object {self!r} cannot be converted to string. Use `repr` "
f"to print the AST, or pass it to the `Print` statement.")


@final
class Format:
class Format(_FormatLike):
def __init__(self, format, *args, **kwargs):
fmt = string.Formatter()
chunks = []
Expand Down Expand Up @@ -2615,17 +2634,17 @@ def subformat(sub_string):
shape = obj.shape()
if isinstance(shape, ShapeCastable):
fmt = shape.format(obj, format_spec)
if not isinstance(fmt, Format):
if not isinstance(fmt, _FormatLike):
raise TypeError(f"`ShapeCastable.format` must return a 'Format' instance, not {fmt!r}")
chunks += fmt._chunks
chunks += fmt._as_format()._chunks
else:
obj = Value.cast(obj)
self._parse_format_spec(format_spec, obj.shape())
chunks.append((obj, format_spec))
elif isinstance(obj, Format):
elif isinstance(obj, _FormatLike):
if format_spec != "":
raise ValueError(f"Format specifiers ({format_spec!r}) cannot be used for 'Format' objects")
chunks += obj._chunks
chunks += obj._as_format()._chunks
else:
chunks.append(fmt.format_field(obj, format_spec))

Expand All @@ -2638,6 +2657,9 @@ def subformat(sub_string):

self._chunks = self._clean_chunks(chunks)

def _as_format(self):
return self

@classmethod
def _from_chunks(cls, chunks):
res = object.__new__(cls)
Expand Down Expand Up @@ -2671,25 +2693,11 @@ def _to_format_string(self):
format_string.append("{}")
return ("".join(format_string), tuple(args))

def __add__(self, other):
if not isinstance(other, Format):
return NotImplemented
return Format._from_chunks(self._chunks + other._chunks)

def __repr__(self):
format_string, args = self._to_format_string()
args = "".join(f" {arg!r}" for arg in args)
return f"(format {format_string!r}{args})"

def __format__(self, format_desc):
"""Forbidden formatting.

``Format`` objects cannot be directly formatted for the same reason as the ``Value``s
they contain.
"""
raise TypeError(f"Format object {self!r} cannot be converted to string. Use `repr` "
f"to print the AST, or pass it to the `Print` statement.")

_FORMAT_SPEC_PATTERN = re.compile(r"""
(?:
(?P<fill>.)?
Expand Down Expand Up @@ -2760,6 +2768,90 @@ def _rhs_signals(self):
return res


class Enum(_FormatLike):
def __init__(self, value, /, variants):
self._value = Value.cast(value)
if isinstance(variants, EnumMeta):
self._variants = {Const.cast(member.value).value: member.name for member in variants}
else:
self._variants = dict(variants)
for val, name in self._variants.items():
if not isinstance(val, int):
raise TypeError(f"Variant values must be integers, not {val!r}")
if not isinstance(name, str):
raise TypeError(f"Variant names must be strings, not {name!r}")

def _as_format(self):
def str_val(name):
name = name.encode()
return Const(int.from_bytes(name, "little"), len(name) * 8)
value = SwitchValue(self._value, [
(val, str_val(name))
for val, name in self._variants.items()
] + [(None, str_val("[unknown]"))])
return Format("{:s}", value)

def __repr__(self):
variants = "".join(
f" ({val!r} {name!r})"
for val, name in self._variants.items()
)
return f"(format-enum {self._value!r}{variants})"


class Struct(_FormatLike):
def __init__(self, value, /, fields):
self._value = Value.cast(value)
self._fields: dict[str, _FormatLike] = dict(fields)
for name, format in self._fields.items():
if not isinstance(name, str):
raise TypeError(f"Field names must be strings, not {name!r}")
if not isinstance(format, _FormatLike):
raise TypeError(f"Field format must be a 'Format', not {format!r}")

def _as_format(self):
chunks = ["{"]
for idx, (name, field) in enumerate(self._fields.items()):
if idx != 0:
chunks.append(", ")
chunks.append(f"{name}=")
chunks += field._as_format()._chunks
chunks.append("}")
return Format._from_chunks(chunks)

def __repr__(self):
fields = "".join(
f" ({name!r} {field!r})"
for name, field in self._fields.items()
)
return f"(format-struct {self._value!r}{fields})"


class Array(_FormatLike):
def __init__(self, value, /, fields):
self._value = Value.cast(value)
self._fields = list(fields)
for format in self._fields:
if not isinstance(format, (Format, Format.Enum, Format.Struct, Format.Array)):
raise TypeError(f"Field format must be a 'Format', not {format!r}")

def _as_format(self):
chunks = ["["]
for idx, field in enumerate(self._fields):
if idx != 0:
chunks.append(", ")
chunks += field._as_format()._chunks
chunks.append("]")
return Format._from_chunks(chunks)

def __repr__(self):
fields = "".join(
f" {field!r}"
for field in self._fields
)
return f"(format-array {self._value!r}{fields})"


class _StatementList(list):
def __repr__(self):
return "({})".format(" ".join(map(repr, self)))
Expand Down Expand Up @@ -2872,8 +2964,10 @@ def __init__(self, kind, test, message=None, *, src_loc_at=0):
self._test = Value.cast(test)
if isinstance(message, str):
message = Format._from_chunks([message])
if message is not None and not isinstance(message, Format):
raise TypeError(f"Property message must be None, str, or Format, not {message!r}")
if message is not None:
if not isinstance(message, _FormatLike):
raise TypeError(f"Property message must be None, str, or Format, not {message!r}")
message = message._as_format()
self._message = message
del self._MustUse__silence

Expand Down
91 changes: 91 additions & 0 deletions tests/test_hdl_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,6 +1736,97 @@ def test_format_wrong(self):
f"{fmt}"


class FormatEnumTestCase(FHDLTestCase):
def test_construct(self):
a = Signal(3)
fmt = Format.Enum(a, {1: "A", 2: "B", 3: "C"})
self.assertRepr(fmt, "(format-enum (sig a) (1 'A') (2 'B') (3 'C'))")
self.assertRepr(Format("{}", fmt), """
(format '{:s}' (switch-value (sig a)
(case 001 (const 8'd65))
(case 010 (const 8'd66))
(case 011 (const 8'd67))
(default (const 72'd1723507152241428428123))
))
""")

class MyEnum(Enum):
A = 0
B = 3
C = 4

fmt = Format.Enum(a, MyEnum)
self.assertRepr(fmt, "(format-enum (sig a) (0 'A') (3 'B') (4 'C'))")
self.assertRepr(Format("{}", fmt), """
(format '{:s}' (switch-value (sig a)
(case 000 (const 8'd65))
(case 011 (const 8'd66))
(case 100 (const 8'd67))
(default (const 72'd1723507152241428428123))
))
""")

def test_construct_wrong(self):
a = Signal(3)
with self.assertRaisesRegex(TypeError,
r"^Variant values must be integers, not 'a'$"):
Format.Enum(a, {"a": "B"})
with self.assertRaisesRegex(TypeError,
r"^Variant names must be strings, not 123$"):
Format.Enum(a, {1: 123})


class FormatStructTestCase(FHDLTestCase):
def test_construct(self):
sig = Signal(3)
fmt = Format.Struct(sig, {"a": Format("{}", sig[0]), "b": Format("{}", sig[1:3])})
self.assertRepr(fmt, """
(format-struct (sig sig)
('a' (format '{}' (slice (sig sig) 0:1)))
('b' (format '{}' (slice (sig sig) 1:3)))
)
""")
self.assertRepr(Format("{}", fmt), """
(format '{{a={}, b={}}}'
(slice (sig sig) 0:1)
(slice (sig sig) 1:3)
)
""")

def test_construct_wrong(self):
sig = Signal(3)
with self.assertRaisesRegex(TypeError,
r"^Field names must be strings, not 1$"):
Format.Struct(sig, {1: Format("{}", sig[1:3])})
with self.assertRaisesRegex(TypeError,
r"^Field format must be a 'Format', not \(slice \(sig sig\) 1:3\)$"):
Format.Struct(sig, {"a": sig[1:3]})


class FormatArrayTestCase(FHDLTestCase):
def test_construct(self):
sig = Signal(4)
fmt = Format.Array(sig, [Format("{}", sig[0:2]), Format("{}", sig[2:4])])
self.assertRepr(fmt, """
(format-array (sig sig)
(format '{}' (slice (sig sig) 0:2))
(format '{}' (slice (sig sig) 2:4))
)
""")
self.assertRepr(Format("{}", fmt), """
(format '[{}, {}]'
(slice (sig sig) 0:2)
(slice (sig sig) 2:4)
)
""")

def test_construct_wrong(self):
sig = Signal(3)
with self.assertRaisesRegex(TypeError,
r"^Field format must be a 'Format', not \(slice \(sig sig\) 1:3\)$"):
Format.Array(sig, [sig[1:3]])


class PrintTestCase(FHDLTestCase):
def test_construct(self):
a = Signal()
Expand Down