Skip to content

Commit 229a4a8

Browse files
committed
Implement RFC 31: Enumeration type safety.
1 parent b0b193f commit 229a4a8

File tree

2 files changed

+306
-11
lines changed

2 files changed

+306
-11
lines changed

amaranth/lib/enum.py

+132-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import enum as py_enum
22
import warnings
3+
import operator
34

4-
from ..hdl.ast import Value, Shape, ShapeCastable, Const
5+
from ..hdl.ast import Value, ValueCastable, Shape, ShapeCastable, Const
56
from ..hdl._repr import *
67

78

8-
__all__ = py_enum.__all__
9+
__all__ = py_enum.__all__ + ["EnumView", "FlagView"]
910

1011

1112
for _member in py_enum.__all__:
@@ -27,10 +28,10 @@ class EnumMeta(ShapeCastable, py_enum.EnumMeta):
2728

2829
# TODO: remove this shim once py3.8 support is dropped
2930
@classmethod
30-
def __prepare__(metacls, name, bases, shape=None, **kwargs):
31+
def __prepare__(metacls, name, bases, shape=None, view_class=None, **kwargs):
3132
return super().__prepare__(name, bases, **kwargs)
3233

33-
def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
34+
def __new__(metacls, name, bases, namespace, shape=None, view_class=None, **kwargs):
3435
if shape is not None:
3536
shape = Shape.cast(shape)
3637
# Prepare enumeration members for instantiation. This logic is unfortunately very
@@ -89,6 +90,8 @@ def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
8990
# Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that
9091
# the values of every member can be cast to the provided shape without truncation.
9192
cls._amaranth_shape_ = shape
93+
if view_class is not None:
94+
cls._amaranth_view_class_ = view_class
9295
else:
9396
# Shape is not provided explicitly. Behave the same as a standard enumeration;
9497
# the lack of `_amaranth_shape_` attribute is used to emit a warning when such
@@ -136,8 +139,12 @@ def __call__(cls, value, *args, **kwargs):
136139
# At the moment however, for historical reasons, this is just the value itself. This works
137140
# and is backwards-compatible but is limiting in that it does not allow us to e.g. catch
138141
# comparisons with enum members of the wrong type.
139-
if isinstance(value, Value):
140-
return value
142+
if isinstance(value, (Value, ValueCastable)):
143+
value = Value.cast(value)
144+
if cls._amaranth_view_class_ is None:
145+
return value
146+
else:
147+
return cls._amaranth_view_class_(cls, value)
141148
return super().__call__(value, *args, **kwargs)
142149

143150
def const(cls, init):
@@ -149,7 +156,7 @@ def const(cls, init):
149156
member = cls(0)
150157
else:
151158
member = cls(init)
152-
return Const(member.value, cls.as_shape())
159+
return cls(Const(member.value, cls.as_shape()))
153160

154161
def _value_repr(cls, value):
155162
yield Repr(FormatEnum(cls), value)
@@ -174,9 +181,127 @@ class IntFlag(py_enum.IntFlag):
174181
"""Subclass of the standard :class:`enum.IntFlag` that has :class:`EnumMeta` as
175182
its metaclass."""
176183

184+
177185
# Fix up the metaclass after the fact: the metaclass __new__ requires these classes
178186
# to already be present, and also would not install itself on them due to lack of shape.
179187
Enum.__class__ = EnumMeta
180188
IntEnum.__class__ = EnumMeta
181189
Flag.__class__ = EnumMeta
182190
IntFlag.__class__ = EnumMeta
191+
192+
193+
class EnumView(ValueCastable):
194+
def __init__(self, enum, target):
195+
if not isinstance(enum, EnumMeta) or not hasattr(enum, "_amaranth_shape_"):
196+
raise TypeError(f"EnumView type must be an enum with shape, not {enum!r}")
197+
try:
198+
cast_target = Value.cast(target)
199+
except TypeError as e:
200+
raise TypeError("EnumView target must be a value-castable object, not {!r}"
201+
.format(target)) from e
202+
if cast_target.shape() != enum.as_shape():
203+
raise TypeError("EnumView target must have the same shape as the enum")
204+
self.enum = enum
205+
self.target = cast_target
206+
207+
def shape(self):
208+
return self.enum
209+
210+
@ValueCastable.lowermethod
211+
def as_value(self):
212+
return self.target
213+
214+
def eq(self, other):
215+
"""Assign to the underlying value.
216+
217+
Returns
218+
-------
219+
:class:`Assign`
220+
``self.as_value().eq(other)``
221+
"""
222+
return self.as_value().eq(other)
223+
224+
def __add__(self, other):
225+
raise TypeError("cannot perform arithmetic operations on non-IntEnum enum")
226+
227+
__radd__ = __add__
228+
__sub__ = __add__
229+
__rsub__ = __add__
230+
__mul__ = __add__
231+
__rmul__ = __add__
232+
__floordiv__ = __add__
233+
__rfloordiv__ = __add__
234+
__mod__ = __add__
235+
__rmod__ = __add__
236+
__lshift__ = __add__
237+
__rlshift__ = __add__
238+
__rshift__ = __add__
239+
__rrshift__ = __add__
240+
__lt__ = __add__
241+
__le__ = __add__
242+
__gt__ = __add__
243+
__ge__ = __add__
244+
245+
def __and__(self, other):
246+
raise TypeError("cannot perform bitwise operations on non-IntEnum non-Flag enum")
247+
248+
__rand__ = __and__
249+
__or__ = __and__
250+
__ror__ = __and__
251+
__xor__ = __and__
252+
__rxor__ = __and__
253+
254+
def __eq__(self, other):
255+
if isinstance(other, self.enum):
256+
other = self.enum(Value.cast(other))
257+
if not isinstance(other, EnumView) or other.enum is not self.enum:
258+
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
259+
return self.target == other.target
260+
261+
def __ne__(self, other):
262+
if isinstance(other, self.enum):
263+
other = self.enum(Value.cast(other))
264+
if not isinstance(other, EnumView) or other.enum is not self.enum:
265+
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
266+
return self.target != other.target
267+
268+
def __repr__(self):
269+
return f"{type(self).__name__}({self.enum.__name__}, {self.target!r})"
270+
271+
272+
class FlagView(EnumView):
273+
def __invert__(self):
274+
if hasattr(self.enum, "_boundary_") and self.enum._boundary_ in (EJECT, KEEP):
275+
return self.enum._amaranth_view_class_(self.enum, ~self.target)
276+
else:
277+
singles_mask = 0
278+
for flag in self.enum:
279+
if (flag.value & (flag.value - 1)) == 0:
280+
singles_mask |= flag.value
281+
return self.enum._amaranth_view_class_(self.enum, ~self.target & singles_mask)
282+
283+
def __bitop(self, other, op):
284+
if isinstance(other, self.enum):
285+
other = self.enum(Value.cast(other))
286+
if not isinstance(other, FlagView) or other.enum is not self.enum:
287+
raise TypeError("a FlagView can only perform bitwise operation with a value or other FlagView of the same enum type")
288+
return self.enum._amaranth_view_class_(self.enum, op(self.target, other.target))
289+
290+
def __and__(self, other):
291+
return self.__bitop(other, operator.__and__)
292+
293+
def __or__(self, other):
294+
return self.__bitop(other, operator.__or__)
295+
296+
def __xor__(self, other):
297+
return self.__bitop(other, operator.__xor__)
298+
299+
__rand__ = __and__
300+
__ror__ = __or__
301+
__rxor__ = __xor__
302+
303+
304+
Enum._amaranth_view_class_ = EnumView
305+
IntEnum._amaranth_view_class_ = None
306+
Flag._amaranth_view_class_ = FlagView
307+
IntFlag._amaranth_view_class_ = None

tests/test_lib_enum.py

+174-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import enum as py_enum
2+
import operator
3+
import sys
24

35
from amaranth import *
4-
from amaranth.lib.enum import Enum, EnumMeta
6+
from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView
57

68
from .utils import *
79

@@ -103,9 +105,9 @@ def test_const_shape(self):
103105
class EnumA(Enum, shape=8):
104106
Z = 0
105107
A = 10
106-
self.assertRepr(EnumA.const(None), "(const 8'd0)")
107-
self.assertRepr(EnumA.const(10), "(const 8'd10)")
108-
self.assertRepr(EnumA.const(EnumA.A), "(const 8'd10)")
108+
self.assertRepr(EnumA.const(None), "EnumView(EnumA, (const 8'd0))")
109+
self.assertRepr(EnumA.const(10), "EnumView(EnumA, (const 8'd10))")
110+
self.assertRepr(EnumA.const(EnumA.A), "EnumView(EnumA, (const 8'd10))")
109111

110112
def test_shape_implicit_wrong_in_concat(self):
111113
class EnumA(Enum):
@@ -118,3 +120,171 @@ class EnumA(Enum):
118120

119121
def test_functional(self):
120122
Enum("FOO", ["BAR", "BAZ"])
123+
124+
def test_int_enum(self):
125+
class EnumA(IntEnum, shape=signed(4)):
126+
A = 0
127+
B = -3
128+
a = Signal(EnumA)
129+
self.assertRepr(a, "(sig a)")
130+
131+
def test_enum_view(self):
132+
class EnumA(Enum, shape=signed(4)):
133+
A = 0
134+
B = -3
135+
class EnumB(Enum, shape=signed(4)):
136+
C = 0
137+
D = 5
138+
a = Signal(EnumA)
139+
b = Signal(EnumB)
140+
c = Signal(EnumA)
141+
d = Signal(4)
142+
self.assertIsInstance(a, EnumView)
143+
self.assertIs(a.shape(), EnumA)
144+
self.assertRepr(a, "EnumView(EnumA, (sig a))")
145+
self.assertRepr(a.as_value(), "(sig a)")
146+
self.assertRepr(a.eq(c), "(eq (sig a) (sig c))")
147+
for op in [
148+
operator.__add__,
149+
operator.__sub__,
150+
operator.__mul__,
151+
operator.__floordiv__,
152+
operator.__mod__,
153+
operator.__lshift__,
154+
operator.__rshift__,
155+
operator.__and__,
156+
operator.__or__,
157+
operator.__xor__,
158+
operator.__lt__,
159+
operator.__le__,
160+
operator.__gt__,
161+
operator.__ge__,
162+
]:
163+
with self.assertRaises(TypeError):
164+
op(a, a)
165+
with self.assertRaises(TypeError):
166+
op(a, d)
167+
with self.assertRaises(TypeError):
168+
op(d, a)
169+
with self.assertRaises(TypeError):
170+
op(a, 3)
171+
with self.assertRaises(TypeError):
172+
op(a, EnumA.A)
173+
for op in [
174+
operator.__eq__,
175+
operator.__ne__,
176+
]:
177+
with self.assertRaises(TypeError):
178+
op(a, b)
179+
with self.assertRaises(TypeError):
180+
op(a, d)
181+
with self.assertRaises(TypeError):
182+
op(d, a)
183+
with self.assertRaises(TypeError):
184+
op(a, 3)
185+
with self.assertRaises(TypeError):
186+
op(a, EnumB.C)
187+
self.assertRepr(a == c, "(== (sig a) (sig c))")
188+
self.assertRepr(a != c, "(!= (sig a) (sig c))")
189+
self.assertRepr(a == EnumA.B, "(== (sig a) (const 4'sd-3))")
190+
self.assertRepr(EnumA.B == a, "(== (sig a) (const 4'sd-3))")
191+
self.assertRepr(a != EnumA.B, "(!= (sig a) (const 4'sd-3))")
192+
193+
def test_flag_view(self):
194+
class FlagA(Flag, shape=unsigned(4)):
195+
A = 1
196+
B = 4
197+
class FlagB(Flag, shape=unsigned(4)):
198+
C = 1
199+
D = 2
200+
a = Signal(FlagA)
201+
b = Signal(FlagB)
202+
c = Signal(FlagA)
203+
d = Signal(4)
204+
self.assertIsInstance(a, FlagView)
205+
self.assertRepr(a, "FlagView(FlagA, (sig a))")
206+
for op in [
207+
operator.__add__,
208+
operator.__sub__,
209+
operator.__mul__,
210+
operator.__floordiv__,
211+
operator.__mod__,
212+
operator.__lshift__,
213+
operator.__rshift__,
214+
operator.__lt__,
215+
operator.__le__,
216+
operator.__gt__,
217+
operator.__ge__,
218+
]:
219+
with self.assertRaises(TypeError):
220+
op(a, a)
221+
with self.assertRaises(TypeError):
222+
op(a, d)
223+
with self.assertRaises(TypeError):
224+
op(d, a)
225+
with self.assertRaises(TypeError):
226+
op(a, 3)
227+
with self.assertRaises(TypeError):
228+
op(a, FlagA.A)
229+
for op in [
230+
operator.__eq__,
231+
operator.__ne__,
232+
operator.__and__,
233+
operator.__or__,
234+
operator.__xor__,
235+
]:
236+
with self.assertRaises(TypeError):
237+
op(a, b)
238+
with self.assertRaises(TypeError):
239+
op(a, d)
240+
with self.assertRaises(TypeError):
241+
op(d, a)
242+
with self.assertRaises(TypeError):
243+
op(a, 3)
244+
with self.assertRaises(TypeError):
245+
op(a, FlagB.C)
246+
self.assertRepr(a == c, "(== (sig a) (sig c))")
247+
self.assertRepr(a != c, "(!= (sig a) (sig c))")
248+
self.assertRepr(a == FlagA.B, "(== (sig a) (const 4'd4))")
249+
self.assertRepr(FlagA.B == a, "(== (sig a) (const 4'd4))")
250+
self.assertRepr(a != FlagA.B, "(!= (sig a) (const 4'd4))")
251+
self.assertRepr(a | c, "FlagView(FlagA, (| (sig a) (sig c)))")
252+
self.assertRepr(a & c, "FlagView(FlagA, (& (sig a) (sig c)))")
253+
self.assertRepr(a ^ c, "FlagView(FlagA, (^ (sig a) (sig c)))")
254+
self.assertRepr(~a, "FlagView(FlagA, (& (~ (sig a)) (const 3'd5)))")
255+
self.assertRepr(a | FlagA.B, "FlagView(FlagA, (| (sig a) (const 4'd4)))")
256+
if sys.version_info >= (3, 11):
257+
class FlagC(Flag, shape=unsigned(4), boundary=py_enum.KEEP):
258+
A = 1
259+
B = 4
260+
e = Signal(FlagC)
261+
self.assertRepr(~e, "FlagView(FlagC, (~ (sig e)))")
262+
263+
def test_enum_view_wrong(self):
264+
class EnumA(Enum, shape=signed(4)):
265+
A = 0
266+
B = -3
267+
268+
a = Signal(2)
269+
with self.assertRaisesRegex(TypeError,
270+
r'^EnumView target must have the same shape as the enum$'):
271+
EnumA(a)
272+
with self.assertRaisesRegex(TypeError,
273+
r'^EnumView target must be a value-castable object, not .*$'):
274+
EnumView(EnumA, "a")
275+
276+
class EnumB(Enum):
277+
C = 0
278+
D = 1
279+
with self.assertRaisesRegex(TypeError,
280+
r'^EnumView type must be an enum with shape, not .*$'):
281+
EnumView(EnumB, 3)
282+
283+
def test_enum_view_custom(self):
284+
class CustomView(EnumView):
285+
pass
286+
class EnumA(Enum, view_class=CustomView, shape=unsigned(2)):
287+
A = 0
288+
B = 1
289+
a = Signal(EnumA)
290+
assert isinstance(a, CustomView)

0 commit comments

Comments
 (0)