Skip to content

Implement RFC 36 #1344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ omit =

[report]
exclude_lines =
:nocov:
:nocov:
partial_branches =
:nobr:
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ __pycache__/
# coverage
/.coverage
/htmlcov
/coverage.xml

# tests
/tests/spec_*/
Expand Down
14 changes: 14 additions & 0 deletions amaranth/hdl/_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, *, src_loc=None):
self.src_loc = src_loc
self.origins = None
self.domains_propagated_up = {}
self.domain_renames = {}

def add_domains(self, *domains):
for domain in flatten(domains):
Expand Down Expand Up @@ -655,6 +656,19 @@ def _assign_names(self, fragment: Fragment, hierarchy: "tuple[str]"):
subfragment_name = _add_name(frag_info.assigned_names, subfragment_name)
self._assign_names(subfragment, hierarchy=(*hierarchy, subfragment_name))

def lookup_domain(self, domain, context):
if domain == "comb":
raise KeyError("comb")
if context is not None:
try:
fragment = self.elaboratables[context]
except KeyError:
raise ValueError(f"Elaboratable {context!r} is not a part of the design")
else:
fragment = self.fragment
domain = fragment.domain_renames.get(domain, domain)
return fragment.domains[domain]


############################################################################################### >:3

Expand Down
13 changes: 13 additions & 0 deletions amaranth/hdl/_xfrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def map_statements(self, fragment, new_fragment):
for domain, statements in fragment.statements.items():
new_fragment.add_statements(domain, statements)

def map_domain_renames(self, fragment, new_fragment):
new_fragment.domain_renames = dict(fragment.domain_renames)

def map_memory_ports(self, fragment, new_fragment):
if hasattr(self, "on_value"):
for port in new_fragment._read_ports:
Expand Down Expand Up @@ -318,6 +321,7 @@ def on_fragment(self, fragment):
self.map_subfragments(fragment, new_fragment)
self.map_domains(fragment, new_fragment)
self.map_statements(fragment, new_fragment)
self.map_domain_renames(fragment, new_fragment)
return new_fragment

def __call__(self, value, *, src_loc_at=0):
Expand Down Expand Up @@ -513,6 +517,15 @@ def map_statements(self, fragment, new_fragment):
map(self.on_statement, statements)
)

def map_domain_renames(self, fragment, new_fragment):
new_fragment.domain_renames = {
src: self.domain_map.get(dst, dst)
for src, dst in fragment.domain_renames.items()
}
for src, dst in self.domain_map.items():
if src not in new_fragment.domain_renames:
new_fragment.domain_renames[src] = dst

def map_memory_ports(self, fragment, new_fragment):
super().map_memory_ports(fragment, new_fragment)
for port in new_fragment._read_ports:
Expand Down
6 changes: 5 additions & 1 deletion amaranth/sim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from .core import *


__all__ = ["Settle", "Delay", "Tick", "Passive", "Active", "Simulator"]
__all__ = [
"DomainReset", "BrokenTrigger", "Simulator",
# deprecated
"Settle", "Delay", "Tick", "Passive", "Active",
]
293 changes: 293 additions & 0 deletions amaranth/sim/_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
import typing
import operator
from contextlib import contextmanager

from ..hdl import *
from ..hdl._ast import Slice
from ._base import BaseProcess, BaseEngine


__all__ = [
"DomainReset", "BrokenTrigger",
"SampleTrigger", "ChangedTrigger", "EdgeTrigger", "DelayTrigger",
"TriggerCombination", "TickTrigger",
"SimulatorContext", "ProcessContext", "TestbenchContext", "AsyncProcess",
]


class DomainReset(Exception):
"""Exception raised when the domain of a a tick trigger that is repeatedly awaited has its
reset asserted."""


class BrokenTrigger(Exception):
"""Exception raised when a trigger that is repeatedly awaited in an `async for` loop has
a matching event occur while the body of the `async for` loop is executing."""


class SampleTrigger:
def __init__(self, value):
self.value = Value.cast(value)
if isinstance(value, ValueCastable):
self.shape = value.shape()
else:
self.shape = self.value.shape()


class ChangedTrigger:
def __init__(self, signal):
cast_signal = Value.cast(signal)
if not isinstance(cast_signal, Signal):
raise TypeError(f"Change trigger can only be used with a signal, not {signal!r}")
self.shape = signal.shape()
self.signal = cast_signal

@property
def value(self):
return self.signal


class EdgeTrigger:
def __init__(self, signal, polarity):
cast_signal = Value.cast(signal)
if isinstance(cast_signal, Signal) and len(cast_signal) == 1:
self.signal, self.bit = cast_signal, 0
elif (isinstance(cast_signal, Slice) and
len(cast_signal) == 1 and
isinstance(cast_signal.value, Signal)):
self.signal, self.bit = cast_signal.value, cast_signal.start
else:
raise TypeError(f"Edge trigger can only be used with a single-bit signal or "
f"a single-bit slice of a signal, not {signal!r}")
if polarity not in (0, 1):
raise ValueError(f"Edge trigger polarity must be 0 or 1, not {polarity!r}")
self.polarity = polarity


class DelayTrigger:
def __init__(self, interval):
self.interval_fs = round(float(interval) * 1e15)


class TriggerCombination:
def __init__(self, engine: BaseEngine, process: BaseProcess, *,
triggers: 'tuple[DelayTrigger|ChangedTrigger|SampleTrigger|EdgeTrigger, ...]' = ()):
self._engine = engine
self._process = process # private but used by engines
self._triggers = triggers # private but used by engines

def sample(self, *values) -> 'TriggerCombination':
return TriggerCombination(self._engine, self._process, triggers=self._triggers +
tuple(SampleTrigger(value) for value in values))

def changed(self, *signals) -> 'TriggerCombination':
return TriggerCombination(self._engine, self._process, triggers=self._triggers +
tuple(ChangedTrigger(signal) for signal in signals))

def edge(self, signal, polarity) -> 'TriggerCombination':
return TriggerCombination(self._engine, self._process, triggers=self._triggers +
(EdgeTrigger(signal, polarity),))

def posedge(self, signal) -> 'TriggerCombination':
return self.edge(signal, 1)

def negedge(self, signal) -> 'TriggerCombination':
return self.edge(signal, 0)

def delay(self, interval) -> 'TriggerCombination':
return TriggerCombination(self._engine, self._process, triggers=self._triggers +
(DelayTrigger(interval),))

def __await__(self):
trigger = self._engine.add_trigger_combination(self, oneshot=True)
return trigger.__await__()

async def __aiter__(self):
trigger = self._engine.add_trigger_combination(self, oneshot=False)
while True:
yield await trigger


class TickTrigger:
def __init__(self, engine: BaseEngine, process: BaseProcess, *,
domain: ClockDomain, sampled: 'tuple[ValueLike]' = ()):
self._engine = engine
self._process = process
self._domain = domain
self._sampled = sampled

def sample(self, *values: ValueLike) -> 'TickTrigger':
return TickTrigger(self._engine, self._process,
domain=self._domain, sampled=(*self._sampled, *values))

async def until(self, condition: ValueLike):
if not isinstance(condition, ValueLike):
raise TypeError(f"Condition must be a value-like object, not {condition!r}")
tick = self.sample(condition).__aiter__()
done = False
while not done:
clk, rst, *values, done = await tick.__anext__()
if rst:
raise DomainReset
return tuple(values)

async def repeat(self, count: int):
count = operator.index(count)
if count <= 0:
raise ValueError(f"Repeat count must be a positive integer, not {count!r}")
tick = self.__aiter__()
for _ in range(count):
clk, rst, *values = await tick.__anext__()
if rst:
raise DomainReset
return tuple(values)

def _collect_trigger(self):
clk_polarity = (1 if self._domain.clk_edge == "pos" else 0)
if self._domain.async_reset and self._domain.rst is not None:
return (TriggerCombination(self._engine, self._process)
.edge(self._domain.clk, clk_polarity)
.edge(self._domain.rst, 1)
.sample(self._domain.rst)
.sample(*self._sampled))
else:
return (TriggerCombination(self._engine, self._process)
.edge(self._domain.clk, clk_polarity)
.sample(Const(0))
.sample(Const(0) if self._domain.rst is None else self._domain.rst)
.sample(*self._sampled))

def __await__(self):
trigger = self._engine.add_trigger_combination(self._collect_trigger(), oneshot=True)
clk_edge, rst_edge, rst_sample, *values = yield from trigger.__await__()
return (clk_edge, bool(rst_edge or rst_sample), *values)

async def __aiter__(self):
trigger = self._engine.add_trigger_combination(self._collect_trigger(), oneshot=False)
while True:
clk_edge, rst_edge, rst_sample, *values = await trigger
yield (clk_edge, bool(rst_edge or rst_sample), *values)


class SimulatorContext:
def __init__(self, design, engine: BaseEngine, process: BaseProcess):
self._design = design
self._engine = engine
self._process = process

def delay(self, interval) -> TriggerCombination:
return TriggerCombination(self._engine, self._process).delay(interval)

def changed(self, *signals) -> TriggerCombination:
return TriggerCombination(self._engine, self._process).changed(*signals)

def edge(self, signal, polarity) -> TriggerCombination:
return TriggerCombination(self._engine, self._process).edge(signal, polarity)

def posedge(self, signal) -> TriggerCombination:
return TriggerCombination(self._engine, self._process).posedge(signal)

def negedge(self, signal) -> TriggerCombination:
return TriggerCombination(self._engine, self._process).negedge(signal)

@typing.overload
def tick(self, domain: str, *, context: Elaboratable = None) -> TickTrigger: ... # :nocov:

@typing.overload
def tick(self, domain: ClockDomain) -> TickTrigger: ... # :nocov:

def tick(self, domain="sync", *, context=None):
if domain == "comb":
raise ValueError("Combinational domain does not have a clock")
if isinstance(domain, ClockDomain):
if context is not None:
raise ValueError("Context cannot be provided if a clock domain is specified "
"directly")
else:
domain = self._design.lookup_domain(domain, context)
return TickTrigger(self._engine, self._process, domain=domain)

@contextmanager
def critical(self):
try:
old_critical, self._process.critical = self._process.critical, True
yield
finally:
self._process.critical = old_critical


class ProcessContext(SimulatorContext):
def get(self, expr: ValueLike) -> 'typing.Never':
raise TypeError("`.get()` cannot be used to sample values in simulator processes; use "
"`.sample()` on a trigger object instead")

@typing.overload
def set(self, expr: Value, value: int) -> None: ... # :nocov:

@typing.overload
def set(self, expr: ValueCastable, value: typing.Any) -> None: ... # :nocov:

def set(self, expr, value):
if isinstance(expr, ValueCastable):
shape = expr.shape()
if isinstance(shape, ShapeCastable):
value = shape.const(value)
value = Const.cast(value).value
self._engine.set_value(expr, value)


class TestbenchContext(SimulatorContext):
@typing.overload
def get(self, expr: Value) -> int: ... # :nocov:

@typing.overload
def get(self, expr: ValueCastable) -> typing.Any: ... # :nocov:

def get(self, expr):
value = self._engine.get_value(expr)
if isinstance(expr, ValueCastable):
shape = expr.shape()
if isinstance(shape, ShapeCastable):
return shape.from_bits(value)
return value

@typing.overload
def set(self, expr: Value, value: int) -> None: ... # :nocov:

@typing.overload
def set(self, expr: ValueCastable, value: typing.Any) -> None: ... # :nocov:

def set(self, expr, value):
if isinstance(expr, ValueCastable):
shape = expr.shape()
if isinstance(shape, ShapeCastable):
value = shape.const(value)
value = Const.cast(value).value
self._engine.set_value(expr, value)
self._engine.step_design()


class AsyncProcess(BaseProcess):
def __init__(self, design, engine, constructor, *, testbench, background):
self.constructor = constructor
if testbench:
self.context = TestbenchContext(design, engine, self)
else:
self.context = ProcessContext(design, engine, self)
self.background = background

self.reset()

def reset(self):
self.runnable = True
self.critical = not self.background
self.waits_on = None
self.coroutine = self.constructor(self.context)

def run(self):
try:
self.waits_on = self.coroutine.send(None)
except StopIteration:
self.critical = False
self.waits_on = None
self.coroutine = None
Loading