Skip to content

Commit dce576d

Browse files
committed
Implement RFC 31: Enumeration type safety.
1 parent abd74ea commit dce576d

File tree

2 files changed

+260
-8
lines changed

2 files changed

+260
-8
lines changed

amaranth/lib/enum.py

+121-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import enum as py_enum
22
import warnings
3+
import operator
34

4-
from ..hdl.ast import Value, Shape, ShapeCastable, Const
5+
from ..hdl.ast import Value, ValueCastable, Shape, ShapeCastable, Const
56

67

7-
__all__ = py_enum.__all__
8+
__all__ = py_enum.__all__ + ["EnumView", "FlagView"]
89

910

1011
for _member in py_enum.__all__:
@@ -136,7 +137,12 @@ def __call__(cls, value, *args, **kwargs):
136137
# and is backwards-compatible but is limiting in that it does not allow us to e.g. catch
137138
# comparisons with enum members of the wrong type.
138139
if isinstance(value, Value):
139-
return value
140+
if issubclass(cls, (IntEnum, IntFlag)):
141+
return value
142+
elif issubclass(cls, Flag):
143+
return FlagView(cls, value)
144+
else:
145+
return EnumView(cls, value)
140146
return super().__call__(value, *args, **kwargs)
141147

142148
def const(cls, init):
@@ -148,7 +154,7 @@ def const(cls, init):
148154
member = cls(0)
149155
else:
150156
member = cls(init)
151-
return Const(member.value, cls.as_shape())
157+
return cls(Const(member.value, cls.as_shape()))
152158

153159

154160
class Enum(py_enum.Enum):
@@ -176,3 +182,114 @@ class IntFlag(py_enum.IntFlag):
176182
IntEnum.__class__ = EnumMeta
177183
Flag.__class__ = EnumMeta
178184
IntFlag.__class__ = EnumMeta
185+
186+
187+
class EnumView(ValueCastable):
188+
def __init__(self, enum, target):
189+
if not isinstance(enum, EnumMeta) or not hasattr(enum, "_amaranth_shape_"):
190+
raise TypeError(f"EnumView enum must be an enum, not {enum!r}")
191+
try:
192+
cast_target = Value.cast(target)
193+
except TypeError as e:
194+
raise TypeError("EnumView target must be a value-castable object, not {!r}"
195+
.format(target)) from e
196+
if cast_target.shape() != enum.as_shape():
197+
raise TypeError("EnumView target must have the same shape as the enum")
198+
self.enum = enum
199+
self.target = cast_target
200+
201+
def shape(self):
202+
return self.enum
203+
204+
@ValueCastable.lowermethod
205+
def as_value(self):
206+
return self.target
207+
208+
def eq(self, other):
209+
"""Assign to the underlying value.
210+
211+
Returns
212+
-------
213+
:class:`Assign`
214+
``self.as_value().eq(other)``
215+
"""
216+
return self.as_value().eq(other)
217+
218+
def __add__(self, other):
219+
raise TypeError("cannot perform arithmetic operations on non-IntEnum enum")
220+
221+
__radd__ = __add__
222+
__sub__ = __add__
223+
__rsub__ = __add__
224+
__mul__ = __add__
225+
__rmul__ = __add__
226+
__floordiv__ = __add__
227+
__rfloordiv__ = __add__
228+
__mod__ = __add__
229+
__rmod__ = __add__
230+
__lshift__ = __add__
231+
__rlshift__ = __add__
232+
__rshift__ = __add__
233+
__rrshift__ = __add__
234+
__lt__ = __add__
235+
__le__ = __add__
236+
__gt__ = __add__
237+
__ge__ = __add__
238+
239+
def __and__(self, other):
240+
raise TypeError("cannot perform bitwise operations on non-IntEnum non-Flag enum")
241+
242+
__rand__ = __and__
243+
__or__ = __and__
244+
__ror__ = __and__
245+
__xor__ = __and__
246+
__rxor__ = __and__
247+
248+
def __eq__(self, other):
249+
if isinstance(other, self.enum):
250+
other = self.enum(Value.cast(other))
251+
if not isinstance(other, EnumView) or other.enum is not self.enum:
252+
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
253+
return self.target == other.target
254+
255+
def __ne__(self, other):
256+
if isinstance(other, self.enum):
257+
other = self.enum(Value.cast(other))
258+
if not isinstance(other, EnumView) or other.enum is not self.enum:
259+
raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type")
260+
return self.target != other.target
261+
262+
def __repr__(self):
263+
return f"({type(self).__name__} {self.enum.__name__} {self.target!r})"
264+
265+
266+
class FlagView(EnumView):
267+
def __invert__(self):
268+
if hasattr(self.enum, "_boundary_") and self.enum._boundary_ in (EJECT, KEEP):
269+
return FlagView(self.enum, ~self.target)
270+
else:
271+
singles_mask = 0
272+
for flag in self.enum:
273+
if (flag.value & (flag.value - 1)) == 0:
274+
singles_mask |= flag.value
275+
return FlagView(self.enum, ~self.target & singles_mask)
276+
277+
def __bitop(self, other, op):
278+
if isinstance(other, self.enum):
279+
other = self.enum(Value.cast(other))
280+
if not isinstance(other, FlagView) or other.enum is not self.enum:
281+
raise TypeError("a FlagView can only perform bitwise operation with a value or other FlagView of the same enum type")
282+
return FlagView(self.enum, op(self.target, other.target))
283+
284+
def __and__(self, other):
285+
return self.__bitop(other, operator.__and__)
286+
287+
def __or__(self, other):
288+
return self.__bitop(other, operator.__or__)
289+
290+
def __xor__(self, other):
291+
return self.__bitop(other, operator.__xor__)
292+
293+
__rand__ = __and__
294+
__ror__ = __or__
295+
__rxor__ = __xor__

tests/test_lib_enum.py

+139-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import enum as py_enum
2+
import operator
23

34
from amaranth import *
4-
from amaranth.lib.enum import Enum, EnumMeta
5+
from amaranth.lib.enum import Enum, EnumMeta, Flag, EnumView, FlagView
56

67
from .utils import *
78

@@ -103,9 +104,9 @@ def test_const_shape(self):
103104
class EnumA(Enum, shape=8):
104105
Z = 0
105106
A = 10
106-
self.assertRepr(EnumA.const(None), "(const 8'd0)")
107-
self.assertRepr(EnumA.const(10), "(const 8'd10)")
108-
self.assertRepr(EnumA.const(EnumA.A), "(const 8'd10)")
107+
self.assertRepr(EnumA.const(None), "(EnumView EnumA (const 8'd0))")
108+
self.assertRepr(EnumA.const(10), "(EnumView EnumA (const 8'd10))")
109+
self.assertRepr(EnumA.const(EnumA.A), "(EnumView EnumA (const 8'd10))")
109110

110111
def test_shape_implicit_wrong_in_concat(self):
111112
class EnumA(Enum):
@@ -118,3 +119,137 @@ class EnumA(Enum):
118119

119120
def test_functional(self):
120121
Enum("FOO", ["BAR", "BAZ"])
122+
123+
def test_enum_view(self):
124+
class EnumA(Enum, shape=signed(4)):
125+
A = 0
126+
B = -3
127+
class EnumB(Enum, shape=signed(4)):
128+
C = 0
129+
D = 5
130+
a = Signal(EnumA)
131+
b = Signal(EnumB)
132+
c = Signal(EnumA)
133+
d = Signal(4)
134+
self.assertIsInstance(a, EnumView)
135+
for op in [
136+
operator.__add__,
137+
operator.__sub__,
138+
operator.__mul__,
139+
operator.__floordiv__,
140+
operator.__mod__,
141+
operator.__lshift__,
142+
operator.__rshift__,
143+
operator.__and__,
144+
operator.__or__,
145+
operator.__xor__,
146+
operator.__lt__,
147+
operator.__le__,
148+
operator.__gt__,
149+
operator.__ge__,
150+
]:
151+
with self.assertRaises(TypeError):
152+
op(a, a)
153+
with self.assertRaises(TypeError):
154+
op(a, d)
155+
with self.assertRaises(TypeError):
156+
op(d, a)
157+
with self.assertRaises(TypeError):
158+
op(a, 3)
159+
with self.assertRaises(TypeError):
160+
op(a, EnumA.A)
161+
for op in [
162+
operator.__eq__,
163+
operator.__ne__,
164+
]:
165+
with self.assertRaises(TypeError):
166+
op(a, b)
167+
with self.assertRaises(TypeError):
168+
op(a, d)
169+
with self.assertRaises(TypeError):
170+
op(d, a)
171+
with self.assertRaises(TypeError):
172+
op(a, 3)
173+
with self.assertRaises(TypeError):
174+
op(a, EnumB.C)
175+
v = a == c
176+
self.assertEqual(repr(v), "(== (sig a) (sig c))")
177+
v = a != c
178+
self.assertEqual(repr(v), "(!= (sig a) (sig c))")
179+
v = a == EnumA.B
180+
self.assertEqual(repr(v), "(== (sig a) (const 4'sd-3))")
181+
v = EnumA.B == a
182+
self.assertEqual(repr(v), "(== (sig a) (const 4'sd-3))")
183+
v = a != EnumA.B
184+
self.assertEqual(repr(v), "(!= (sig a) (const 4'sd-3))")
185+
186+
def test_flag_view(self):
187+
class FlagA(Flag, shape=unsigned(4)):
188+
A = 1
189+
B = 4
190+
class FlagB(Flag, shape=unsigned(4)):
191+
C = 1
192+
D = 2
193+
a = Signal(FlagA)
194+
b = Signal(FlagB)
195+
c = Signal(FlagA)
196+
d = Signal(4)
197+
self.assertIsInstance(a, FlagView)
198+
for op in [
199+
operator.__add__,
200+
operator.__sub__,
201+
operator.__mul__,
202+
operator.__floordiv__,
203+
operator.__mod__,
204+
operator.__lshift__,
205+
operator.__rshift__,
206+
operator.__lt__,
207+
operator.__le__,
208+
operator.__gt__,
209+
operator.__ge__,
210+
]:
211+
with self.assertRaises(TypeError):
212+
op(a, a)
213+
with self.assertRaises(TypeError):
214+
op(a, d)
215+
with self.assertRaises(TypeError):
216+
op(d, a)
217+
with self.assertRaises(TypeError):
218+
op(a, 3)
219+
with self.assertRaises(TypeError):
220+
op(a, FlagA.A)
221+
for op in [
222+
operator.__eq__,
223+
operator.__ne__,
224+
operator.__and__,
225+
operator.__or__,
226+
operator.__xor__,
227+
]:
228+
with self.assertRaises(TypeError):
229+
op(a, b)
230+
with self.assertRaises(TypeError):
231+
op(a, d)
232+
with self.assertRaises(TypeError):
233+
op(d, a)
234+
with self.assertRaises(TypeError):
235+
op(a, 3)
236+
with self.assertRaises(TypeError):
237+
op(a, FlagB.C)
238+
v = a == c
239+
self.assertEqual(repr(v), "(== (sig a) (sig c))")
240+
v = a != c
241+
self.assertEqual(repr(v), "(!= (sig a) (sig c))")
242+
v = a == FlagA.B
243+
self.assertEqual(repr(v), "(== (sig a) (const 4'd4))")
244+
v = FlagA.B == a
245+
self.assertEqual(repr(v), "(== (sig a) (const 4'd4))")
246+
v = a != FlagA.B
247+
self.assertEqual(repr(v), "(!= (sig a) (const 4'd4))")
248+
v = a | c
249+
self.assertEqual(repr(v), "(FlagView FlagA (| (sig a) (sig c)))")
250+
v = a & c
251+
self.assertEqual(repr(v), "(FlagView FlagA (& (sig a) (sig c)))")
252+
v = a ^ c
253+
self.assertEqual(repr(v), "(FlagView FlagA (^ (sig a) (sig c)))")
254+
v = ~a
255+
self.assertEqual(repr(v), "(FlagView FlagA (& (~ (sig a)) (const 3'd5)))")

0 commit comments

Comments
 (0)