Skip to content

hdl._ast: Make AST nodes immutable. #1165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 186 additions & 64 deletions amaranth/hdl/_ast.py
Original file line number Diff line number Diff line change
@@ -1462,26 +1462,38 @@ def cast(obj):

def __init__(self, value, shape=None, *, src_loc_at=0):
# We deliberately do not call Value.__init__ here.
self.value = int(operator.index(value))
value = int(operator.index(value))
if shape is None:
shape = Shape(bits_for(self.value), signed=self.value < 0)
shape = Shape(bits_for(value), signed=value < 0)
elif isinstance(shape, int):
shape = Shape(shape, signed=self.value < 0)
shape = Shape(shape, signed=value < 0)
else:
if isinstance(shape, range) and self.value == shape.stop:
if isinstance(shape, range) and value == shape.stop:
warnings.warn(
message="Value {!r} equals the non-inclusive end of the constant "
"shape {!r}; this is likely an off-by-one error"
.format(self.value, shape),
message=f"Value {value!r} equals the non-inclusive end of the constant "
f"shape {shape!r}; this is likely an off-by-one error",
category=SyntaxWarning,
stacklevel=3)
shape = Shape.cast(shape, src_loc_at=1 + src_loc_at)
self.width = shape.width
self.signed = shape.signed
if self.signed and self.value >> (self.width - 1) & 1:
self.value |= -(1 << self.width)
self._width = shape.width
self._signed = shape.signed
if shape.signed and value >> (shape.width - 1) & 1:
value |= -(1 << shape.width)
else:
self.value &= (1 << self.width) - 1
value &= (1 << shape.width) - 1
self._value = value

@property
def value(self):
return self._value

@property
def width(self):
return self._width

@property
def signed(self):
return self._signed

def shape(self):
return Shape(self.width, self.signed)
@@ -1500,8 +1512,16 @@ def __repr__(self):
class Operator(Value):
def __init__(self, operator, operands, *, src_loc_at=0):
super().__init__(src_loc_at=1 + src_loc_at)
self.operator = operator
self.operands = [Value.cast(op) for op in operands]
self._operator = operator
self._operands = tuple(Value.cast(op) for op in operands)

@property
def operator(self):
return self._operator

@property
def operands(self):
return self._operands

def shape(self):
def _bitwise_binary_shape(a_shape, b_shape):
@@ -1614,9 +1634,21 @@ def __init__(self, value, start, stop, *, src_loc_at=0):
raise IndexError(f"Slice start {start} must be less than slice stop {stop}")

super().__init__(src_loc_at=src_loc_at)
self.value = value
self.start = int(operator.index(start))
self.stop = int(operator.index(stop))
self._value = value
self._start = int(operator.index(start))
self._stop = int(operator.index(stop))

@property
def value(self):
return self._value

@property
def start(self):
return self._start

@property
def stop(self):
return self._stop

def shape(self):
return Shape(self.stop - self.start)
@@ -1645,10 +1677,26 @@ def __init__(self, value, offset, width, stride=1, *, src_loc_at=0):
raise TypeError("Part offset must be unsigned")

super().__init__(src_loc_at=src_loc_at)
self.value = value
self.offset = offset
self.width = width
self.stride = stride
self._value = value
self._offset = offset
self._width = width
self._stride = stride

@property
def value(self):
return self._value

@property
def offset(self):
return self._offset

@property
def width(self):
return self._width

@property
def stride(self):
return self._stride

def shape(self):
return Shape(self.width)
@@ -1691,7 +1739,7 @@ class Cat(Value):
"""
def __init__(self, *args, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
self.parts = []
parts = []
for index, arg in enumerate(flatten(args)):
if isinstance(arg, Enum) and (not isinstance(type(arg), ShapeCastable) or
not hasattr(arg, "_amaranth_shape_")):
@@ -1706,7 +1754,12 @@ def __init__(self, *args, src_loc_at=0):
"context; specify the width explicitly using C({}, {})"
.format(index + 1, arg, arg, bits_for(arg)),
SyntaxWarning, stacklevel=2 + src_loc_at)
self.parts.append(Value.cast(arg))
parts.append(Value.cast(arg))
self._parts = tuple(parts)

@property
def parts(self):
return self._parts

def shape(self):
return Shape(sum(len(part) for part in self.parts))
@@ -1784,8 +1837,8 @@ def __init__(self, shape=None, *, name=None, init=None, reset=None, reset_less=F
shape = unsigned(1)
else:
shape = Shape.cast(shape, src_loc_at=1 + src_loc_at)
self.width = shape.width
self.signed = shape.signed
self._width = shape.width
self._signed = shape.signed

# TODO(amaranth-0.7): remove
if reset is not None:
@@ -1831,8 +1884,8 @@ def __init__(self, shape=None, *, name=None, init=None, reset=None, reset_less=F
.format(orig_init, shape),
category=SyntaxWarning,
stacklevel=2)
self.init = init.value
self.reset_less = bool(reset_less)
self._init = init.value
self._reset_less = bool(reset_less)

if isinstance(orig_shape, range) and orig_init is not None and orig_init not in orig_shape:
if orig_init == orig_shape.stop:
@@ -1843,21 +1896,21 @@ def __init__(self, shape=None, *, name=None, init=None, reset=None, reset_less=F
raise SyntaxError(
f"Initial value {orig_init!r} is not within the signal shape {orig_shape!r}")

self.attrs = OrderedDict(() if attrs is None else attrs)
self._attrs = OrderedDict(() if attrs is None else attrs)

if decoder is not None:
# The value representation is specified explicitly. Since we do not expose `hdl._repr`,
# this is the only way to add a custom filter to the signal right now. The setter sets
# `self._value_repr` as well as the compatibility `self.decoder`.
self.decoder = decoder
pass
else:
# If it's an enum, expose it via `self.decoder` for compatibility, whether it's a Python
# enum or an Amaranth enum. This also sets the value representation, even for custom
# shape-castables that implement their own `_value_repr`.
if isinstance(orig_shape, type) and issubclass(orig_shape, Enum):
self.decoder = orig_shape
decoder = orig_shape
else:
self.decoder = None
decoder = None
# The value representation is specified implicitly in the shape of the signal.
if isinstance(orig_shape, ShapeCastable):
# A custom shape-castable always has a `_value_repr`, at least the default one.
@@ -1869,24 +1922,6 @@ def __init__(self, shape=None, *, name=None, init=None, reset=None, reset_less=F
# Any other case is formatted as a plain integer.
self._value_repr = (Repr(FormatInt(), self),)

@property
def reset(self):
warnings.warn("`Signal.reset` is deprecated, use `Signal.init` instead",
DeprecationWarning, stacklevel=2)
return self.init

@reset.setter
def reset(self, value):
warnings.warn("`Signal.reset` is deprecated, use `Signal.init` instead",
DeprecationWarning, stacklevel=2)
self.init = value

@property
def decoder(self):
return self._decoder

@decoder.setter
def decoder(self, decoder):
# Compute the value representation that will be used by Amaranth.
if decoder is None:
self._value_repr = (Repr(FormatInt(), self),)
@@ -1903,6 +1938,37 @@ def enum_decoder(value):
return str(value)
self._decoder = enum_decoder

@property
def width(self):
return self._width

@property
def signed(self):
return self._signed

@property
def init(self):
return self._init

@property
def reset(self):
warnings.warn("`Signal.reset` is deprecated, use `Signal.init` instead",
DeprecationWarning, stacklevel=2)
return self._init

@property
def reset_less(self):
return self._reset_less

@property
def attrs(self):
# Would ideally be frozendict...
return self._attrs

@property
def decoder(self):
return self._decoder

@classmethod
def like(cls, other, *, name=None, name_suffix=None, init=None, reset=None, src_loc_at=0, **kwargs):
"""Create Signal based on another.
@@ -1970,7 +2036,11 @@ def __init__(self, domain="sync", *, src_loc_at=0):
raise TypeError(f"Clock domain name must be a string, not {domain!r}")
if domain == "comb":
raise ValueError(f"Domain '{domain}' does not have a clock")
self.domain = domain
self._domain = domain

@property
def domain(self):
return self._domain

def shape(self):
return Shape(1)
@@ -2006,8 +2076,16 @@ def __init__(self, domain="sync", allow_reset_less=False, *, src_loc_at=0):
raise TypeError(f"Clock domain name must be a string, not {domain!r}")
if domain == "comb":
raise ValueError(f"Domain '{domain}' does not have a reset")
self.domain = domain
self.allow_reset_less = allow_reset_less
self._domain = domain
self._allow_reset_less = allow_reset_less

@property
def domain(self):
return self._domain

@property
def allow_reset_less(self):
return self._allow_reset_less

def shape(self):
return Shape(1)
@@ -2032,8 +2110,16 @@ def __init__(self, kind, shape, *, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
self.kind = self.Kind(kind)
shape = Shape.cast(shape, src_loc_at=1 + src_loc_at)
self.width = shape.width
self.signed = shape.signed
self._width = shape.width
self._signed = shape.signed

@property
def width(self):
return self._width

@property
def signed(self):
return self._signed

def shape(self):
return Shape(self.width, self.signed)
@@ -2147,8 +2233,16 @@ def __repr__(self):
class ArrayProxy(Value):
def __init__(self, elems, index, *, src_loc_at=0):
super().__init__(src_loc_at=1 + src_loc_at)
self.elems = elems
self.index = Value.cast(index)
self._elems = elems
self._index = Value.cast(index)

@property
def elems(self):
return self._elems

@property
def index(self):
return self._index

def __getattr__(self, attr):
return ArrayProxy([getattr(elem, attr) for elem in self.elems], self.index)
@@ -2245,8 +2339,16 @@ def cast(obj):
class Assign(Statement):
def __init__(self, lhs, rhs, *, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
self.lhs = Value.cast(lhs)
self.rhs = Value.cast(rhs)
self._lhs = Value.cast(lhs)
self._rhs = Value.cast(rhs)

@property
def lhs(self):
return self._lhs

@property
def rhs(self):
return self._rhs

def _lhs_signals(self):
return self.lhs._lhs_signals()
@@ -2273,13 +2375,25 @@ class Kind(Enum):

def __init__(self, kind, test, *, name=None, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
self.kind = self.Kind(kind)
self.test = Value.cast(test)
self.name = name
self._kind = self.Kind(kind)
self._test = Value.cast(test)
self._name = name
if not isinstance(self.name, str) and self.name is not None:
raise TypeError("Property name must be a string or None, not {!r}"
.format(self.name))

@property
def kind(self):
return self._kind

@property
def test(self):
return self._test

@property
def name(self):
return self._name

def _lhs_signals(self):
return set()

@@ -2322,8 +2436,8 @@ def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={})
# be automatically traced, so whatever constructs a Switch may optionally provide it.
self.case_src_locs = {}

self.test = Value.cast(test)
self.cases = OrderedDict()
self._test = Value.cast(test)
self._cases = OrderedDict()
for orig_keys, stmts in cases.items():
# Map: None -> (); key -> (key,); (key...) -> (key...)
keys = orig_keys
@@ -2354,10 +2468,18 @@ def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={})
new_keys = (*new_keys, key)
if not isinstance(stmts, Iterable):
stmts = [stmts]
self.cases[new_keys] = Statement.cast(stmts)
self._cases[new_keys] = Statement.cast(stmts)
if orig_keys in case_src_locs:
self.case_src_locs[new_keys] = case_src_locs[orig_keys]

@property
def test(self):
return self._test

@property
def cases(self):
return self._cases

def _lhs_signals(self):
return union((s._lhs_signals() for s in self.cases.values()), start=SignalSet())

4 changes: 0 additions & 4 deletions tests/test_hdl_ast.py
Original file line number Diff line number Diff line change
@@ -1237,10 +1237,6 @@ def test_reset(self):
with self.assertWarnsRegex(DeprecationWarning,
r"^`Signal.reset` is deprecated, use `Signal.init` instead$"):
self.assertEqual(s1.reset, 0b111)
with self.assertWarnsRegex(DeprecationWarning,
r"^`Signal.reset` is deprecated, use `Signal.init` instead$"):
s1.reset = 0b010
self.assertEqual(s1.init, 0b010)
with self.assertWarnsRegex(DeprecationWarning,
r"^`reset=` is deprecated, use `init=` instead$"):
s2 = Signal.like(s1, reset=3)
14 changes: 0 additions & 14 deletions tests/test_hdl_rec.py
Original file line number Diff line number Diff line change
@@ -194,20 +194,6 @@ def test_like(self):
r4 = Record.like(r1, name_suffix="foo")
self.assertEqual(r4.name, "r1foo")

def test_like_modifications(self):
r1 = Record([("a", 1), ("b", [("s", 1)])])
self.assertEqual(r1.a.name, "r1__a")
self.assertEqual(r1.b.name, "r1__b")
self.assertEqual(r1.b.s.name, "r1__b__s")
r1.a.init = 1
r1.b.s.init = 1
r2 = Record.like(r1)
self.assertEqual(r2.a.init, 1)
self.assertEqual(r2.b.s.init, 1)
self.assertEqual(r2.a.name, "r2__a")
self.assertEqual(r2.b.name, "r2__b")
self.assertEqual(r2.b.s.name, "r2__b__s")

def test_slice_tuple(self):
r1 = Record([("a", 1), ("b", 2), ("c", 3)])
r2 = r1["a", "c"]