Skip to content

Commit f49fee2

Browse files
committed
Implement RFC 31: Enumeration type safety.
1 parent b0b193f commit f49fee2

File tree

4 files changed

+456
-22
lines changed

4 files changed

+456
-22
lines changed

amaranth/lib/enum.py

+221-18
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import enum as py_enum
22
import warnings
3+
import operator
34

4-
from ..hdl.ast import Value, Shape, ShapeCastable, Const
5+
from ..hdl.ast import Value, ValueCastable, Shape, ShapeCastable, Const
56
from ..hdl._repr import *
67

78

8-
__all__ = py_enum.__all__
9+
__all__ = py_enum.__all__ + ["EnumView", "FlagView"]
910

1011

1112
for _member in py_enum.__all__:
@@ -23,14 +24,18 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta):
2324
:class:`enum.EnumMeta` class; if the ``shape=`` argument is not specified and
2425
:meth:`as_shape` is never called, it places no restrictions on the enumeration class
2526
or the values of its members.
27+
28+
When a :ref:`value-castable <lang-valuecasting>` is cast to an enum type that is an instance
29+
of this metaclass, it can be automatically wrapped in a view class. A custom view class
30+
can be specified by passing the ``view_class=`` keyword argument when creating the enum class.
2631
"""
2732

2833
# TODO: remove this shim once py3.8 support is dropped
2934
@classmethod
30-
def __prepare__(metacls, name, bases, shape=None, **kwargs):
35+
def __prepare__(metacls, name, bases, shape=None, view_class=None, **kwargs):
3136
return super().__prepare__(name, bases, **kwargs)
3237

33-
def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
38+
def __new__(metacls, name, bases, namespace, shape=None, view_class=None, **kwargs):
3439
if shape is not None:
3540
shape = Shape.cast(shape)
3641
# Prepare enumeration members for instantiation. This logic is unfortunately very
@@ -89,6 +94,8 @@ def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
8994
# Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that
9095
# the values of every member can be cast to the provided shape without truncation.
9196
cls._amaranth_shape_ = shape
97+
if view_class is not None:
98+
cls._amaranth_view_class_ = view_class
9299
else:
93100
# Shape is not provided explicitly. Behave the same as a standard enumeration;
94101
# the lack of `_amaranth_shape_` attribute is used to emit a warning when such
@@ -127,17 +134,32 @@ def as_shape(cls):
127134
return Shape._cast_plain_enum(cls)
128135

129136
def __call__(cls, value, *args, **kwargs):
130-
# :class:`py_enum.Enum` uses ``__call__()`` for type casting: ``E(x)`` returns
131-
# the enumeration member whose value equals ``x``. In this case, ``x`` must be a concrete
132-
# value.
133-
# Amaranth extends this to indefinite values, but conceptually the operation is the same:
134-
# :class:`View` calls :meth:`Enum.__call__` to go from a :class:`Value` to something
135-
# representing this enumeration with that value.
136-
# At the moment however, for historical reasons, this is just the value itself. This works
137-
# and is backwards-compatible but is limiting in that it does not allow us to e.g. catch
138-
# comparisons with enum members of the wrong type.
139-
if isinstance(value, Value):
140-
return value
137+
"""Cast the value to this enum type.
138+
139+
When given an integer constant, it returns the corresponding enum value, like a standard
140+
Python enumeration.
141+
142+
When given a :ref:`value-castable <lang-valuecasting>`, it is cast to a value, then wrapped
143+
in the ``view_class`` specified for this enum type (:class:`EnumView` for :class:`Enum`,
144+
:class:`FlagView` for :class:`Flag`, or a custom user-defined class). If the type has no
145+
``view_class`` (like :class:`IntEnum` or :class:`IntFlag`), a plain
146+
:class:`Value` is returned.
147+
148+
Returns
149+
-------
150+
instance of itself
151+
For integer values, or instances of itself.
152+
:class:`EnumView` or its subclass
153+
For value-castables, as defined by the ``view_class`` keyword argument.
154+
:class:`Value`
155+
For value-castables, when a view class is not specified for this enum.
156+
"""
157+
if isinstance(value, (Value, ValueCastable)):
158+
value = Value.cast(value)
159+
if cls._amaranth_view_class_ is None:
160+
return value
161+
else:
162+
return cls._amaranth_view_class_(cls, value)
141163
return super().__call__(value, *args, **kwargs)
142164

143165
def const(cls, init):
@@ -149,15 +171,15 @@ def const(cls, init):
149171
member = cls(0)
150172
else:
151173
member = cls(init)
152-
return Const(member.value, cls.as_shape())
174+
return cls(Const(member.value, cls.as_shape()))
153175

154176
def _value_repr(cls, value):
155177
yield Repr(FormatEnum(cls), value)
156178

157179

158180
class Enum(py_enum.Enum):
159181
"""Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as
160-
its metaclass."""
182+
its metaclass and :class:`EnumView` as its view class."""
161183

162184

163185
class IntEnum(py_enum.IntEnum):
@@ -167,16 +189,197 @@ class IntEnum(py_enum.IntEnum):
167189

168190
class Flag(py_enum.Flag):
169191
"""Subclass of the standard :class:`enum.Flag` that has :class:`EnumMeta` as
170-
its metaclass."""
192+
its metaclass and :class:`FlagView` as its view class."""
171193

172194

173195
class IntFlag(py_enum.IntFlag):
174196
"""Subclass of the standard :class:`enum.IntFlag` that has :class:`EnumMeta` as
175197
its metaclass."""
176198

199+
177200
# Fix up the metaclass after the fact: the metaclass __new__ requires these classes
178201
# to already be present, and also would not install itself on them due to lack of shape.
179202
Enum.__class__ = EnumMeta
180203
IntEnum.__class__ = EnumMeta
181204
Flag.__class__ = EnumMeta
182205
IntFlag.__class__ = EnumMeta
206+
207+
208+
class EnumView(ValueCastable):
209+
"""The view class used for :class:`Enum`.
210+
211+
Wraps a :class:`Value` and only allows type-safe operations. The only operators allowed are
212+
equality comparisons (``==`` and ``!=``) with another :class:`EnumView` of the same enum type.
213+
"""
214+
215+
def __init__(self, enum, target):
216+
"""Constructs a view with the given enum type and target
217+
(a :ref:`value-castable <lang-valuecasting>`).
218+
"""
219+
if not isinstance(enum, EnumMeta) or not hasattr(enum, "_amaranth_shape_"):
220+
raise TypeError(f"EnumView type must be an enum with shape, not {enum!r}")
221+
try:
222+
cast_target = Value.cast(target)
223+
except TypeError as e:
224+
raise TypeError("EnumView target must be a value-castable object, not {!r}"
225+
.format(target)) from e
226+
if cast_target.shape() != enum.as_shape():
227+
raise TypeError("EnumView target must have the same shape as the enum")
228+
self.enum = enum
229+
self.target = cast_target
230+
231+
def shape(self):
232+
"""Returns the underlying enum type."""
233+
return self.enum
234+
235+
@ValueCastable.lowermethod
236+
def as_value(self):
237+
"""Returns the underlying value."""
238+
return self.target
239+
240+
def eq(self, other):
241+
"""Assign to the underlying value.
242+
243+
Returns
244+
-------
245+
:class:`Assign`
246+
``self.as_value().eq(other)``
247+
"""
248+
return self.as_value().eq(other)
249+
250+
def __add__(self, other):
251+
raise TypeError("cannot perform arithmetic operations on non-IntEnum enum")
252+
253+
__radd__ = __add__
254+
__sub__ = __add__
255+
__rsub__ = __add__
256+
__mul__ = __add__
257+
__rmul__ = __add__
258+
__floordiv__ = __add__
259+
__rfloordiv__ = __add__
260+
__mod__ = __add__
261+
__rmod__ = __add__
262+
__lshift__ = __add__
263+
__rlshift__ = __add__
264+
__rshift__ = __add__
265+
__rrshift__ = __add__
266+
__lt__ = __add__
267+
__le__ = __add__
268+
__gt__ = __add__
269+
__ge__ = __add__
270+
271+
def __and__(self, other):
272+
raise TypeError("cannot perform bitwise operations on non-IntEnum non-Flag enum")
273+
274+
__rand__ = __and__
275+
__or__ = __and__
276+
__ror__ = __and__
277+
__xor__ = __and__
278+
__rxor__ = __and__
279+
280+
def __eq__(self, other):
281+
"""Compares the underlying value for equality.
282+
283+
The other operand has to be either another :class:`EnumView` with the same enum type, or
284+
a plain value of the underlying enum.
285+
286+
Returns
287+
-------
288+
:class:`Value`
289+
The result of the equality comparison, as a single-bit value.
290+
"""
291+
if isinstance(other, self.enum):
292+
other = self.enum(Value.cast(other))
293+
if not isinstance(other, EnumView) or other.enum is not self.enum:
294+
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
295+
return self.target == other.target
296+
297+
def __ne__(self, other):
298+
if isinstance(other, self.enum):
299+
other = self.enum(Value.cast(other))
300+
if not isinstance(other, EnumView) or other.enum is not self.enum:
301+
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
302+
return self.target != other.target
303+
304+
def __repr__(self):
305+
return f"{type(self).__name__}({self.enum.__name__}, {self.target!r})"
306+
307+
308+
class FlagView(EnumView):
309+
"""The view class used for :class:`Flag`.
310+
311+
In addition to the operations allowed by :class:`EnumView`, it allows bitwise operations among
312+
values of the same enum type."""
313+
314+
def __invert__(self):
315+
"""Inverts all flags in this value and returns another :ref:`FlagView`.
316+
317+
Note that this is not equivalent to applying bitwise negation to the underlying value:
318+
just like the Python :class:`enum.Flag` class, only bits corresponding to flags actually
319+
defined in the enumeration are included in the result.
320+
321+
Returns
322+
-------
323+
:class:`FlagView`
324+
"""
325+
if hasattr(self.enum, "_boundary_") and self.enum._boundary_ in (EJECT, KEEP):
326+
return self.enum._amaranth_view_class_(self.enum, ~self.target)
327+
else:
328+
singles_mask = 0
329+
for flag in self.enum:
330+
if (flag.value & (flag.value - 1)) == 0:
331+
singles_mask |= flag.value
332+
return self.enum._amaranth_view_class_(self.enum, ~self.target & singles_mask)
333+
334+
def __bitop(self, other, op):
335+
if isinstance(other, self.enum):
336+
other = self.enum(Value.cast(other))
337+
if not isinstance(other, FlagView) or other.enum is not self.enum:
338+
raise TypeError("a FlagView can only perform bitwise operation with a value or other FlagView of the same enum type")
339+
return self.enum._amaranth_view_class_(self.enum, op(self.target, other.target))
340+
341+
def __and__(self, other):
342+
"""Performs a bitwise AND and returns another :class:`FlagView`.
343+
344+
The other operand has to be either another :class:`FlagView` of the same enum type, or
345+
a plain value of the underlying enum type.
346+
347+
Returns
348+
-------
349+
:class:`FlagView`
350+
"""
351+
return self.__bitop(other, operator.__and__)
352+
353+
def __or__(self, other):
354+
"""Performs a bitwise OR and returns another :class:`FlagView`.
355+
356+
The other operand has to be either another :class:`FlagView` of the same enum type, or
357+
a plain value of the underlying enum type.
358+
359+
Returns
360+
-------
361+
:class:`FlagView`
362+
"""
363+
return self.__bitop(other, operator.__or__)
364+
365+
def __xor__(self, other):
366+
"""Performs a bitwise XOR and returns another :class:`FlagView`.
367+
368+
The other operand has to be either another :class:`FlagView` of the same enum type, or
369+
a plain value of the underlying enum type.
370+
371+
Returns
372+
-------
373+
:class:`FlagView`
374+
"""
375+
return self.__bitop(other, operator.__xor__)
376+
377+
__rand__ = __and__
378+
__ror__ = __or__
379+
__rxor__ = __xor__
380+
381+
382+
Enum._amaranth_view_class_ = EnumView
383+
IntEnum._amaranth_view_class_ = None
384+
Flag._amaranth_view_class_ = FlagView
385+
IntFlag._amaranth_view_class_ = None

docs/changes.rst

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Implemented RFCs
6060
.. _RFC 20: https://amaranth-lang.org/rfcs/0020-deprecate-non-fwft-fifos.html
6161
.. _RFC 22: https://amaranth-lang.org/rfcs/0022-valuecastable-shape.html
6262
.. _RFC 28: https://amaranth-lang.org/rfcs/0028-override-value-operators.html
63+
.. _RFC 31: https://amaranth-lang.org/rfcs/0031-enumeration-type-safety.html
6364

6465

6566
* `RFC 1`_: Aggregate data structure library
@@ -77,6 +78,7 @@ Implemented RFCs
7778
* `RFC 20`_: Deprecate non-FWFT FIFOs
7879
* `RFC 22`_: Define ``ValueCastable.shape()``
7980
* `RFC 28`_: Allow overriding ``Value`` operators
81+
* `RFC 31`_: Enumeration type safety
8082

8183

8284
Language changes

docs/stdlib/enum.rst

+59
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ A shape can be specified for an enumeration with the ``shape=`` keyword argument
2424

2525
>>> Shape.cast(Funct)
2626
unsigned(4)
27+
>>> Value.cast(Funct.ADD)
28+
(const 4'd0)
2729

2830
Any :ref:`constant-castable <lang-constcasting>` expression can be used as the value of a member:
2931

@@ -57,6 +59,57 @@ The ``shape=`` argument is optional. If not specified, classes from this module
5759

5860
In this way, this module is a drop-in replacement for the standard :mod:`enum` module, and in an Amaranth project, all ``import enum`` statements may be replaced with ``from amaranth.lib import enum``.
5961

62+
Signals with :class:`Enum` or :class:`Flag` based shape are automatically wrapped in the :class:`EnumView` or :class:`FlagView` value-castable wrappers, which ensure type safety. Any :ref:`value-castable <lang-valuecasting>` can also be explicitly wrapped in a view class by casting it to the enum type:
63+
64+
.. doctest::
65+
66+
>>> a = Signal(Funct)
67+
>>> b = Signal(Op)
68+
>>> type(a)
69+
<class 'amaranth.lib.enum.EnumView'>
70+
>>> a == b
71+
Traceback (most recent call last):
72+
File "<stdin>", line 1, in <module>
73+
TypeError: an EnumView can only be compared to value or other EnumView of the same enum type
74+
>>> c = Signal(4)
75+
>>> type(Funct(c))
76+
<class 'amaranth.lib.enum.EnumView'>
77+
78+
Like the standard Python :class:`enum.IntEnum` and :class:`enum.IntFlag` classes, the Amaranth :class:`IntEnum` and :class:`IntFlag` classes are loosely typed and will not be subject to wrapping in view classes:
79+
80+
.. testcode::
81+
82+
class TransparentEnum(enum.IntEnum, shape=unsigned(4)):
83+
FOO = 0
84+
BAR = 1
85+
86+
.. doctest::
87+
88+
>>> a = Signal(TransparentEnum)
89+
>>> type(a)
90+
<class 'amaranth.hdl.ast.Signal'>
91+
92+
It is also possible to define a custom view class for a given enum:
93+
94+
.. testcode::
95+
96+
class InstrView(enum.EnumView):
97+
def has_immediate(self):
98+
return (self == Instr.ADDI) | (self == Instr.SUBI)
99+
100+
class Instr(enum.Enum, shape=5, view_class=InstrView):
101+
ADD = Cat(Funct.ADD, Op.REG)
102+
ADDI = Cat(Funct.ADD, Op.IMM)
103+
SUB = Cat(Funct.SUB, Op.REG)
104+
SUBI = Cat(Funct.SUB, Op.IMM)
105+
106+
.. doctest::
107+
108+
>>> a = Signal(Instr)
109+
>>> type(a)
110+
<class 'InstrView'>
111+
>>> a.has_immediate()
112+
(| (== (sig a) (const 5'd16)) (== (sig a) (const 5'd17)))
60113

61114
Metaclass
62115
=========
@@ -71,3 +124,9 @@ Base classes
71124
.. autoclass:: IntEnum()
72125
.. autoclass:: Flag()
73126
.. autoclass:: IntFlag()
127+
128+
View classes
129+
============
130+
131+
.. autoclass:: EnumView()
132+
.. autoclass:: FlagView()

0 commit comments

Comments
 (0)